diff --git a/chainer_pytorch_migration/functions.py b/chainer_pytorch_migration/functions.py
index 9aff504..930d591 100644
--- a/chainer_pytorch_migration/functions.py
+++ b/chainer_pytorch_migration/functions.py
@@ -66,3 +66,59 @@ def chainer_torch_function(torch_fn, inputs, *args, **kwargs):
     if len(y) == 1:
         return y[0]
     return y
+
+
+class TorchChainerFunction(torch.autograd.Function):
+    @staticmethod
+    def chainer_fn():
+        raise RuntimeError('chainer_fn function must be overridden')
+
+    @classmethod
+    def forward(cls, ctx, *inputs):
+        chainer_fn = cls.chainer_fn()
+        ctx.save_for_backward(*inputs)
+        c_inputs = tuple((chainer.Variable(cpm.asarray(x)) for x in inputs))
+        ctx.c_inputs = c_inputs
+        c_outputs = chainer_fn(*c_inputs)
+        if not isinstance(c_outputs, tuple):
+            c_outputs = (c_outputs,)
+        t_outputs = [cpm.astensor(y.array) for y in c_outputs]
+        for t_y in t_outputs:
+            t_y.requires_grad = True
+        ctx.c_outputs = c_outputs
+        if len(t_outputs) == 1:
+            return t_outputs[0]
+        else:
+            return tuple(t_outputs)
+
+    @staticmethod
+    def backward(ctx, *grads):
+        grads = [ctx.c_outputs, ctx.c_inputs] + list(grads)
+        out_grads = _TorchChainerFunctionGrad.apply(*grads)
+        return out_grads
+
+
+class _TorchChainerFunctionGrad(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, *inputs):
+        c_outputs = inputs[0]
+        c_inputs = inputs[1]
+        inputs = inputs[2:]
+        ctx.save_for_backward(*inputs)
+        c_grads = tuple((chainer.Variable(cpm.asarray(g)) for g in inputs))
+        fwd_outputs = c_outputs
+        chainer.backward(fwd_outputs, c_grads, enable_double_backprop=True)
+        out_grads = tuple(
+            cpm.astensor(x.grad) for x in c_inputs
+        )
+        for t_y in out_grads:
+            t_y.requires_grad = True
+        ctx.c_outputs = [x.grad for x in c_inputs]
+        ctx.c_inputs = c_grads
+        return out_grads
+
+    @staticmethod
+    def backward(ctx, *grads):
+        grads = [ctx.c_outputs, ctx.c_inputs] + list(grads)
+        return _TorchChainerFunctionGrad.apply(*grads)
diff --git a/tests/test_functions.py b/tests/test_functions.py
index ac3fa86..97147c0 100644
--- a/tests/test_functions.py
+++ b/tests/test_functions.py
@@ -47,3 +47,59 @@ def test_multiple_outputs():
     z.backward(torch.ones((3, 5)))
     t_grad = x.grad
     assert torch.allclose(t_grad, cpm.astensor(c_grad))
+
+
+def test_torch_chainer_function():
+    class TorchChainerSigmoid(cpm.functions.TorchChainerFunction):
+        @staticmethod
+        def chainer_fn():
+            return chainer.functions.sigmoid
+    # Mixed torch + chainer graph
+    x = torch.ones(10)
+    x.requires_grad = True
+    y = torch.sin(x)
+    y = TorchChainerSigmoid.apply(y)
+    y = torch.sum(y)
+    y.backward()
+    ct_grad = x.grad
+
+    # All in torch
+    x = torch.ones(10)
+    x.requires_grad = True
+    y = torch.sin(x)
+    y = torch.sigmoid(y)
+    y = torch.sum(y)
+    y.backward()
+    assert torch.allclose(ct_grad, x.grad)
+
+
+def test_torch_chainer_function_2():
+    class TorchChainerAdd(cpm.functions.TorchChainerFunction):
+        @staticmethod
+        def chainer_fn():
+            return chainer.functions.add
+    # Mixed torch + chainer graph
+    a = torch.ones(10)
+    a.requires_grad = True
+    b = torch.ones(10) + 2
+    b.requires_grad = True
+    y = torch.sin(a)
+    z = torch.sin(b)
+    y = TorchChainerAdd.apply(y, z)
+    y = torch.sum(y)
+    y.backward()
+    a_ct_grad = a.grad
+    b_ct_grad = b.grad
+
+    # All in torch
+    a = torch.ones(10)
+    a.requires_grad = True
+    b = torch.ones(10) + 2
+    b.requires_grad = True
+    y = torch.sin(a)
+    z = torch.sin(b)
+    y = torch.add(y, z)
+    y = torch.sum(y)
+    y.backward()
+    assert torch.allclose(a_ct_grad, a.grad)
+    assert torch.allclose(b_ct_grad, b.grad)
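
For reference, a minimal usage sketch of the new TorchChainerFunction API (not part of the diff). It assumes chainer, torch and chainer_pytorch_migration (imported as cpm) are available; TorchChainerTanh is a hypothetical subclass, analogous to the ones in the tests above. A subclass only overrides chainer_fn to return the Chainer function to wrap; forward runs in Chainer and the backward pass is routed through chainer.backward.

    import chainer
    import torch

    import chainer_pytorch_migration as cpm


    class TorchChainerTanh(cpm.functions.TorchChainerFunction):
        # Only chainer_fn needs to be overridden; forward/backward are inherited.
        @staticmethod
        def chainer_fn():
            return chainer.functions.tanh


    x = torch.ones(10, requires_grad=True)
    y = TorchChainerTanh.apply(x)   # forward is computed by chainer.functions.tanh
    y.sum().backward()              # gradients flow back through the Chainer graph
    print(x.grad)                   # should match torch.tanh, i.e. 1 - tanh(x) ** 2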