Added the reverse: call a chainer function from torch
Emilio Castillo committed Jul 22, 2020
1 parent fd80272 commit ea3635c
Showing 2 changed files with 76 additions and 0 deletions.
52 changes: 52 additions & 0 deletions chainer_pytorch_migration/functions.py
@@ -66,3 +66,55 @@ def chainer_torch_function(torch_fn, inputs, *args, **kwargs):
    if len(y) == 1:
        return y[0]
    return y


class TorchChainerFunction(torch.autograd.Function):
    """Base class to call a chainer function from torch.

    Subclasses override ``chainer_fn`` to return the chainer function
    that will be executed inside the torch graph.
    """

    @staticmethod
    def chainer_fn():
        raise RuntimeError('chainer_fn function must be overridden')

    @classmethod
    def forward(cls, ctx, *inputs):
        chainer_fn = cls.chainer_fn()
        ctx.save_for_backward(*inputs)
        # Wrap the torch tensors in chainer variables and run the
        # chainer function on them.
        c_inputs = tuple(chainer.Variable(cpm.asarray(x)) for x in inputs)
        ctx.c_inputs = c_inputs
        c_outputs = chainer_fn(*c_inputs)
        if not isinstance(c_outputs, tuple):
            c_outputs = (c_outputs,)
        # Convert the chainer outputs back to torch tensors.
        t_outputs = [cpm.astensor(y.array) for y in c_outputs]
        for t_y in t_outputs:
            t_y.requires_grad = True
        ctx.c_outputs = c_outputs
        # Only the first output is returned for now.
        return t_outputs[0]

    @staticmethod
    def backward(ctx, *grads):
        # Delegate to a separate Function so that the chainer part of
        # the graph can also be differentiated for double backward.
        grads = [ctx.c_outputs, ctx.c_inputs] + list(grads)
        out_grads = _TorchChainerFunctionGrad.apply(*grads)
        return out_grads


class _TorchChainerFunctionGrad(torch.autograd.Function):

    @staticmethod
    def forward(ctx, *inputs):
        # The first two arguments are the chainer outputs and inputs
        # of the previous step; the remaining ones are the incoming
        # torch gradients.
        c_outputs = inputs[0]
        c_inputs = inputs[1]
        inputs = inputs[2:]
        ctx.save_for_backward(*inputs)
        c_grads = tuple(chainer.Variable(cpm.asarray(g)) for g in inputs)
        fwd_outputs = c_outputs
        # Run the chainer backward pass, keeping the graph alive so
        # that higher-order gradients can still be taken.
        chainer.backward(fwd_outputs, c_grads, enable_double_backprop=True)
        out_grads = tuple(cpm.astensor(x.grad) for x in c_inputs)
        for t_y in out_grads:
            t_y.requires_grad = True
        ctx.c_outputs = [x.grad for x in c_inputs]
        ctx.c_inputs = c_grads
        return out_grads

    @staticmethod
    def backward(ctx, *grads):
        # Chain one more grad step so that double backward works.
        grads = [ctx.c_outputs, ctx.c_inputs] + list(grads)
        return _TorchChainerFunctionGrad.apply(*grads)
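
With this change the bridge works in both directions: chainer_torch_function (above) calls a torch function from chainer, and the new TorchChainerFunction calls a chainer function from torch. A minimal usage sketch of the new direction; the subclass name and shapes are illustrative only:

import chainer
import torch

import chainer_pytorch_migration as cpm


class TorchChainerTanh(cpm.functions.TorchChainerFunction):
    # Return the chainer function to run inside the torch graph.
    @staticmethod
    def chainer_fn():
        return chainer.functions.tanh


x = torch.ones(4, requires_grad=True)
y = TorchChainerTanh.apply(x)  # forward runs through chainer
y.sum().backward()             # gradients flow back through chainer
print(x.grad)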
24 changes: 24 additions & 0 deletions tests/test_functions.py
@@ -47,3 +47,27 @@ def test_multiple_outputs():
    z.backward(torch.ones((3, 5)))
    t_grad = x.grad
    assert torch.allclose(t_grad, cpm.astensor(c_grad))


def test_torch_chainer_function():
    class TorchChainerSigmoid(cpm.functions.TorchChainerFunction):
        @staticmethod
        def chainer_fn():
            return chainer.functions.sigmoid

    # Mixed graph: torch -> chainer -> torch
    x = torch.ones(10)
    x.requires_grad = True
    y = torch.sin(x)
    y = TorchChainerSigmoid.apply(y)
    y = torch.sum(y)
    y.backward()
    ct_grad = x.grad

    # Same computation, all in torch
    x = torch.ones(10)
    x.requires_grad = True
    y = torch.sin(x)
    y = torch.sigmoid(y)
    y = torch.sum(y)
    y.backward()
    assert torch.allclose(ct_grad, x.grad)
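
Because _TorchChainerFunctionGrad re-chains itself in its own backward and chainer.backward is called with enable_double_backprop=True, second derivatives should in principle also flow through the chainer graph. A hedged sketch of what that would look like, not covered by the tests above:

x = torch.ones(10, requires_grad=True)
y = TorchChainerSigmoid.apply(x)
gx, = torch.autograd.grad(y.sum(), x, create_graph=True)
ggx, = torch.autograd.grad(gx.sum(), x)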
