Added the reverse: call a Chainer function from torch
Emilio Castillo committed Jul 22, 2020
1 parent fd80272 commit dd4d223
Showing 2 changed files with 111 additions and 0 deletions.
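The new TorchChainerFunction base class is the reverse of the existing chainer_torch_function helper: a Chainer function is exposed to the PyTorch autograd graph by subclassing and overriding chainer_fn. A minimal usage sketch, mirroring the tests in this commit (the TorchChainerTanh name is illustrative, and chainer_pytorch_migration is assumed to be importable as cpm):

import chainer
import torch
import chainer_pytorch_migration as cpm

# Illustrative subclass: expose chainer.functions.tanh to torch autograd.
class TorchChainerTanh(cpm.functions.TorchChainerFunction):
    @staticmethod
    def chainer_fn():
        return chainer.functions.tanh

x = torch.ones(4)
x.requires_grad = True
y = TorchChainerTanh.apply(torch.sin(x))  # forward runs through chainer
torch.sum(y).backward()                   # gradients flow back into the torch graph
print(x.grad)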
55 changes: 55 additions & 0 deletions chainer_pytorch_migration/functions.py
@@ -66,3 +66,58 @@ def chainer_torch_function(torch_fn, inputs, *args, **kwargs):
    if len(y) == 1:
        return y[0]
    return y


class TorchChainerFunction(torch.autograd.Function):
    """Calls a Chainer function from the PyTorch autograd graph (override ``chainer_fn``)."""

    @staticmethod
    def chainer_fn():
        raise RuntimeError('chainer_fn method must be overridden')

    @classmethod
    def forward(cls, ctx, *inputs):
        chainer_fn = cls.chainer_fn()
        ctx.save_for_backward(*inputs)
        # Wrap the torch tensors as chainer Variables and run the chainer function.
        c_inputs = tuple(chainer.Variable(cpm.asarray(x)) for x in inputs)
        ctx.c_inputs = c_inputs
        c_outputs = chainer_fn(*c_inputs)
        if not isinstance(c_outputs, tuple):
            c_outputs = (c_outputs,)
        t_outputs = [cpm.astensor(y.array) for y in c_outputs]
        for t_y in t_outputs:
            t_y.requires_grad = True
        ctx.c_outputs = c_outputs
        if len(t_outputs) == 1:
            return t_outputs[0]
        else:
            return tuple(t_outputs)

    @staticmethod
    def backward(ctx, *grads):
        grads = [ctx.c_outputs, ctx.c_inputs] + list(grads)
        out_grads = _TorchChainerFunctionGrad.apply(*grads)
        return out_grads


class _TorchChainerFunctionGrad(torch.autograd.Function):

    @staticmethod
    def forward(ctx, *inputs):
        c_outputs = inputs[0]
        c_inputs = inputs[1]
        inputs = inputs[2:]
        ctx.save_for_backward(*inputs)
        c_grads = tuple(chainer.Variable(cpm.asarray(g)) for g in inputs)
        fwd_outputs = c_outputs
        chainer.backward(fwd_outputs, c_grads, enable_double_backprop=True)
        out_grads = tuple(
            cpm.astensor(x.grad) for x in c_inputs
        )
        for t_y in out_grads:
            t_y.requires_grad = True
        ctx.c_outputs = [x.grad for x in c_inputs]
        ctx.c_inputs = c_grads
        return out_grads

    @staticmethod
    def backward(ctx, *grads):
        grads = [ctx.c_outputs, ctx.c_inputs] + list(grads)
        return _TorchChainerFunctionGrad.apply(*grads)
56 changes: 56 additions & 0 deletions tests/test_functions.py
@@ -47,3 +47,59 @@ def test_multiple_outputs():
    z.backward(torch.ones((3, 5)))
    t_grad = x.grad
    assert torch.allclose(t_grad, cpm.astensor(c_grad))


def test_torch_chainer_function():
    class TorchChainerSigmoid(cpm.functions.TorchChainerFunction):
        @staticmethod
        def chainer_fn():
            return chainer.functions.sigmoid

    # Combined torch and chainer
    x = torch.ones(10)
    x.requires_grad = True
    y = torch.sin(x)
    y = TorchChainerSigmoid.apply(y)
    y = torch.sum(y)
    y.backward()
    ct_grad = x.grad

    # All in torch
    x = torch.ones(10)
    x.requires_grad = True
    y = torch.sin(x)
    y = torch.sigmoid(y)
    y = torch.sum(y)
    y.backward()
    assert torch.allclose(ct_grad, x.grad)


def test_torch_chainer_function_2():
    class TorchChainerAdd(cpm.functions.TorchChainerFunction):
        @staticmethod
        def chainer_fn():
            return chainer.functions.add

    # Combined torch and chainer
    a = torch.ones(10)
    a.requires_grad = True
    b = torch.ones(10) + 2
    b.requires_grad = True
    y = torch.sin(a)
    z = torch.sin(b)
    y = TorchChainerAdd.apply(y, z)
    y = torch.sum(y)
    y.backward()
    a_ct_grad = a.grad
    b_ct_grad = b.grad

    # All in torch
    a = torch.ones(10)
    a.requires_grad = True
    b = torch.ones(10) + 2
    b.requires_grad = True
    y = torch.sin(a)
    z = torch.sin(b)
    y = torch.add(y, z)
    y = torch.sum(y)
    y.backward()
    assert torch.allclose(a_ct_grad, a.grad)
    assert torch.allclose(b_ct_grad, b.grad)
