Add chainer_torch_function and TorchChainerFunction #27

Open
wants to merge 2 commits into base: master
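A minimal usage sketch of the two entry points added here, mirroring the tests included in this PR (the class and variable names below are illustrative only):

import chainer
import numpy
import torch

import chainer_pytorch_migration as cpm

# Call a torch function from inside a Chainer graph.
x = chainer.Variable(numpy.ones((5, 5), dtype=numpy.float32))
y = cpm.chainer_torch_function(torch.sigmoid, chainer.functions.sin(x))
y.grad = numpy.ones((5, 5), dtype=numpy.float32)
y.backward()  # gradients flow back through both frameworks into x.grad


# Call a Chainer function from inside a torch graph.
class TorchChainerSigmoid(cpm.functions.TorchChainerFunction):
    @staticmethod
    def chainer_fn():
        return chainer.functions.sigmoid


t = torch.ones(10, requires_grad=True)
out = TorchChainerSigmoid.apply(torch.sin(t)).sum()
out.backward()  # t.grad is filled in via Chainer's autograd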
1 change: 1 addition & 0 deletions chainer_pytorch_migration/__init__.py
@@ -1,6 +1,7 @@
from . import links
from .allocator import use_mempool_in_cupy_malloc, use_torch_in_cupy_malloc
from .datasets import TransformDataset
from .functions import chainer_torch_function
from .links import TorchModule
from .parameter import ChainerParameter, LinkAsTorchModel, Optimizer
from .tensor import asarray, astensor, to_numpy_dtype
123 changes: 123 additions & 0 deletions chainer_pytorch_migration/functions.py
@@ -0,0 +1,123 @@
import chainer
import torch

import chainer_pytorch_migration as cpm


class _ChainerTorchFunction(chainer.FunctionNode):
def __init__(self, torch_fn, *args, **kwargs):
self.torch_fn = torch_fn
self.torch_fwd_inputs = None
self.torch_fwd_outputs = None
self.args = args
self.kwargs = kwargs

def forward(self, inputs):
t_inputs = [cpm.astensor(x) for x in inputs]
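        # Mark the converted tensors as requiring grad so that torch records
        # the forward operations for later backpropagation.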
for t_x in t_inputs:
t_x.requires_grad = True
self.torch_fwd_inputs = t_inputs
f_inputs = t_inputs + list(self.args)
        # The torch function may take arguments other than the input
        # tensors, so append them here.
t_outs = self.torch_fn(*f_inputs, **self.kwargs)
        if not isinstance(t_outs, (list, tuple)):
t_outs = (t_outs,)
self.torch_fwd_outputs = t_outs
        # The results must be accessible from Chainer, so convert them to arrays.
c_outs = tuple(cpm.asarray(out) for out in t_outs)
        # The outputs are needed again in the backward pass, so retain them.
self.retain_outputs(tuple(range(len(c_outs))))
return c_outs

def backward(self, indexes, grads):
out_grads = _ChainerTorchFunctionGrad(
self.torch_fwd_inputs, self.torch_fwd_outputs
).apply(grads)
return out_grads


class _ChainerTorchFunctionGrad(chainer.FunctionNode):
def __init__(self, inputs, outputs):
super(_ChainerTorchFunctionGrad, self).__init__()
self.inputs = inputs
self.outputs = outputs

def forward(self, inputs):
t_grads = tuple([cpm.astensor(g) for g in inputs])
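        # Let torch backpropagate the incoming gradients through the graph
        # recorded in the forward pass; this populates .grad on the inputs.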
torch.autograd.backward(self.outputs, t_grads)
out_grads = tuple(
cpm.asarray(x.grad) for x in self.inputs
)
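        # Swap the stored inputs and outputs so that applying this node again
        # (double backpropagation) differentiates the gradient computation.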
self.outputs = [x.grad for x in self.inputs]
self.inputs = t_grads
return out_grads

def backward(self, indexes, grads):
return _ChainerTorchFunctionGrad(
self.inputs, self.outputs).apply(grads)


def chainer_torch_function(torch_fn, inputs, *args, **kwargs):
    if not isinstance(inputs, (list, tuple)):
inputs = (inputs,)
y = _ChainerTorchFunction(torch_fn, *args, **kwargs).apply(inputs)
if len(y) == 1:
return y[0]
return y


class TorchChainerFunction(torch.autograd.Function):
@staticmethod
def chainer_fn():
        raise RuntimeError('chainer_fn function must be overridden')

@classmethod
def forward(cls, ctx, *inputs):
chainer_fn = cls.chainer_fn()
ctx.save_for_backward(*inputs)
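        # Convert the torch inputs to Chainer variables and run the
        # user-supplied Chainer function on them.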
c_inputs = tuple((chainer.Variable(cpm.asarray(x)) for x in inputs))
ctx.c_inputs = c_inputs
c_outputs = chainer_fn(*c_inputs)
        if not isinstance(c_outputs, tuple):
c_outputs = (c_outputs,)
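        # Convert the Chainer outputs back to torch tensors and mark them as
        # requiring grad so that subsequent torch operations are recorded.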
t_outputs = [cpm.astensor(y.array) for y in c_outputs]
for t_y in t_outputs:
t_y.requires_grad = True
ctx.c_outputs = c_outputs
if len(t_outputs) == 1:
return t_outputs[0]
else:
return tuple(t_outputs)

@staticmethod
def backward(ctx, *grads):
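        # Pack the saved Chainer outputs and inputs together with the incoming
        # torch gradients; the grad function below unpacks them again.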
grads = [ctx.c_outputs, ctx.c_inputs] + list(grads)
out_grads = _TorchChainerFunctionGrad.apply(*grads)
return out_grads


class _TorchChainerFunctionGrad(torch.autograd.Function):

@staticmethod
def forward(ctx, *inputs):
c_outputs = inputs[0]
c_inputs = inputs[1]
inputs = inputs[2:]
ctx.save_for_backward(*inputs)
c_grads = tuple((chainer.Variable(cpm.asarray(g)) for g in inputs))
fwd_outputs = c_outputs
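        # Backpropagate through the Chainer graph, keeping the graph of the
        # gradient computation so that double backprop remains possible.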
chainer.backward(fwd_outputs, c_grads, enable_double_backprop=True)
out_grads = tuple(
cpm.astensor(x.grad) for x in c_inputs
)
for t_y in out_grads:
t_y.requires_grad = True
ctx.c_outputs = [x.grad for x in c_inputs]
ctx.c_inputs = c_grads
return out_grads

    @staticmethod
    def backward(ctx, *grads):
grads = [ctx.c_outputs, ctx.c_inputs] + list(grads)
return _TorchChainerFunctionGrad.apply(*grads)
105 changes: 105 additions & 0 deletions tests/test_functions.py
@@ -0,0 +1,105 @@
import chainer
import numpy
import torch

import chainer_pytorch_migration as cpm


def test_one_output():
torch_fn = torch.sigmoid
x = chainer.Variable(numpy.ones((5, 5), dtype=numpy.float32))
z = chainer.functions.sin(x)
res = cpm.chainer_torch_function(torch_fn, z)
res = chainer.functions.sqrt(res)
res = cpm.chainer_torch_function(torch_fn, res)
res = chainer.functions.sqrt(res)
res.grad = numpy.ones((5, 5), dtype=numpy.float32)
res.backward()
c_grad = x.grad

# Do it now in pytorch and compare
x = torch.ones((5, 5), requires_grad=True)
z = torch.sin(x)
y = torch.sigmoid(torch.sigmoid(z).sqrt()).sqrt()
y.backward(torch.ones(5, 5))
t_grad = x.grad
assert torch.allclose(t_grad, cpm.astensor(c_grad))


def test_multiple_outputs():
torch_fn = torch.split
x = chainer.Variable(numpy.ones((6, 5), dtype=numpy.float32))
y = chainer.functions.sin(x)
y, z = cpm.chainer_torch_function(torch_fn, y, 3, dim=0)
y = chainer.functions.log(y)
z = chainer.functions.cos(z)
z = y + z
z.grad = numpy.ones((3, 5), dtype=numpy.float32)
z.backward()
c_grad = x.grad

x = torch.ones((6, 5), requires_grad=True)
z = torch.sin(x)
y, z = torch.split(z, 3, dim=0)
y = torch.log(y)
z = torch.cos(z)
z = y + z
z.backward(torch.ones((3, 5)))
t_grad = x.grad
assert torch.allclose(t_grad, cpm.astensor(c_grad))


def test_torch_chainer_function():
class TorchChainerSigmoid(cpm.functions.TorchChainerFunction):
@staticmethod
def chainer_fn():
return chainer.functions.sigmoid
    # Mixed torch and Chainer graph
x = torch.ones(10)
x.requires_grad = True
y = torch.sin(x)
y = TorchChainerSigmoid.apply(y)
y = torch.sum(y)
y.backward()
ct_grad = x.grad

# All in torch
x = torch.ones(10)
x.requires_grad = True
y = torch.sin(x)
y = torch.sigmoid(y)
y = torch.sum(y)
y.backward()
assert torch.allclose(ct_grad, x.grad)


def test_torch_chainer_function_2():
class TorchChainerAdd(cpm.functions.TorchChainerFunction):
@staticmethod
def chainer_fn():
return chainer.functions.add
    # Mixed torch and Chainer graph
a = torch.ones(10)
a.requires_grad = True
b = torch.ones(10)+2
b.requires_grad = True
y = torch.sin(a)
z = torch.sin(b)
y = TorchChainerAdd.apply(y, z)
y = torch.sum(y)
y.backward()
a_ct_grad = a.grad
b_ct_grad = b.grad

# All in torch
a = torch.ones(10)
a.requires_grad = True
b = torch.ones(10)+2
b.requires_grad = True
y = torch.sin(a)
z = torch.sin(b)
y = torch.add(y, z)
y = torch.sum(y)
y.backward()
assert torch.allclose(a_ct_grad, a.grad)
assert torch.allclose(b_ct_grad, b.grad)