diff --git a/chainer_pytorch_migration/__init__.py b/chainer_pytorch_migration/__init__.py
index 540b179..546ccd4 100644
--- a/chainer_pytorch_migration/__init__.py
+++ b/chainer_pytorch_migration/__init__.py
@@ -4,4 +4,4 @@
 from .links import TorchModule
 from .parameter import ChainerParameter, LinkAsTorchModel, Optimizer
 from .tensor import asarray, astensor, to_numpy_dtype
-from .device import to_chainer_device
+from .device import to_chainer_device, to_torch_device
diff --git a/chainer_pytorch_migration/device.py b/chainer_pytorch_migration/device.py
index 30ba0da..a2fe99d 100644
--- a/chainer_pytorch_migration/device.py
+++ b/chainer_pytorch_migration/device.py
@@ -9,12 +9,32 @@ def to_chainer_device(device):
         device (torch.device): Device to be converted.
 
     Returns:
-        A ``torch.device`` object corresponding to the given input.
+        A ``chainer.backend.Device`` object corresponding to the given input.
     """
-    assert isinstance(device, torch.device)
+    if not isinstance(device, torch.device):
+        raise TypeError('The argument should be a torch device.')
     if device.type == 'cpu':
         return chainer.get_device('@numpy')
     if device.type == 'cuda':
         device_index = 0 if device.index is None else device.index
         return chainer.get_device('@cupy:{}'.format(device_index))
     raise ValueError('{} is not supported.'.format(device.type))
+
+
+def to_torch_device(device):
+    """Create a torch device from a given chainer device.
+
+    Args:
+        device (chainer.backend.Device): Device to be converted.
+
+    Returns:
+        A ``torch.device`` object corresponding to the given input.
+    """
+    if not isinstance(device, chainer.backend.Device):
+        raise TypeError('The argument should be a chainer device.')
+    if device.name == '@numpy':
+        return torch.device('cpu')
+    if device.name.startswith('@cupy:'):
+        cuda_device_index = int(device.name.split(':')[1])
+        return torch.device('cuda', cuda_device_index)
+    raise ValueError('{} is not supported.'.format(device.name))
diff --git a/tests/test_device.py b/tests/test_device.py
index ab0f21a..3cf984f 100644
--- a/tests/test_device.py
+++ b/tests/test_device.py
@@ -1,3 +1,4 @@
+import chainer
 import torch
 
 import chainer_pytorch_migration as cpm
@@ -22,3 +23,20 @@ def test_to_chainer_device_gpu_1():
     device = torch.device('cuda:1')
     chainer_device = cpm.to_chainer_device(device)
     assert chainer_device.name == '@cupy:1'
+
+def test_to_torch_device_cpu():
+    device = chainer.get_device('@numpy')
+    torch_device = cpm.to_torch_device(device)
+    assert torch_device.type == 'cpu'
+
+def test_to_torch_device_gpu():
+    device = chainer.get_device('@cupy:0')
+    torch_device = cpm.to_torch_device(device)
+    assert torch_device.type == 'cuda'
+    assert torch_device.index == 0
+
+def test_to_torch_device_gpu_1():
+    device = chainer.get_device('@cupy:1')
+    torch_device = cpm.to_torch_device(device)
+    assert torch_device.type == 'cuda'
+    assert torch_device.index == 1
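
A minimal usage sketch for the two conversion helpers in this diff, assuming chainer, torch, and chainer_pytorch_migration are installed; constructing the '@cupy' device additionally assumes cupy is available:

import chainer
import torch

import chainer_pytorch_migration as cpm

# torch -> chainer: a CPU torch device maps to the '@numpy' backend.
chainer_device = cpm.to_chainer_device(torch.device('cpu'))
assert chainer_device.name == '@numpy'

# chainer -> torch: '@cupy:N' maps to torch.device('cuda', N).
# Creating the chainer device here assumes cupy is installed;
# to_torch_device itself only inspects the device name.
torch_device = cpm.to_torch_device(chainer.get_device('@cupy:0'))
assert torch_device.type == 'cuda' and torch_device.index == 0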