diff --git a/chainer_pytorch_migration/__init__.py b/chainer_pytorch_migration/__init__.py
index 095a6b6..546ccd4 100644
--- a/chainer_pytorch_migration/__init__.py
+++ b/chainer_pytorch_migration/__init__.py
@@ -4,3 +4,4 @@
 from .links import TorchModule
 from .parameter import ChainerParameter, LinkAsTorchModel, Optimizer
 from .tensor import asarray, astensor, to_numpy_dtype
+from .device import to_chainer_device, to_torch_device
diff --git a/chainer_pytorch_migration/device.py b/chainer_pytorch_migration/device.py
new file mode 100644
index 0000000..a2fe99d
--- /dev/null
+++ b/chainer_pytorch_migration/device.py
@@ -0,0 +1,40 @@
+import chainer
+import torch
+
+
+def to_chainer_device(device):
+    """Create a chainer device from a given torch device.
+
+    Args:
+        device (torch.device): Device to be converted.
+
+    Returns:
+        A ``chainer.backend.Device`` object corresponding to the given input.
+    """
+    if not isinstance(device, torch.device):
+        raise TypeError('The argument should be torch device.')
+    if device.type == 'cpu':
+        return chainer.get_device('@numpy')
+    if device.type == 'cuda':
+        device_index = 0 if device.index is None else device.index
+        return chainer.get_device('@cupy:{}'.format(device_index))
+    raise ValueError('{} is not supported.'.format(device.type))
+
+
+def to_torch_device(device):
+    """Create a torch device from a given chainer device.
+
+    Args:
+        device (chainer.backend.Device): Device to be converted.
+
+    Returns:
+        A ``torch.device`` object corresponding to the given input.
+    """
+    if not isinstance(device, chainer.backend.Device):
+        raise TypeError('The argument should be chainer device.')
+    if device.name == '@numpy':
+        return torch.device('cpu')
+    if device.name.startswith('@cupy:'):
+        cuda_device_index = int(device.name.split(':')[1])
+        return torch.device('cuda', cuda_device_index)
+    raise ValueError('{} is not supported.'.format(device.name))
diff --git a/tests/test_device.py b/tests/test_device.py
new file mode 100644
index 0000000..3cf984f
--- /dev/null
+++ b/tests/test_device.py
@@ -0,0 +1,42 @@
+import chainer
+import torch
+
+import chainer_pytorch_migration as cpm
+
+
+def test_to_chainer_device_cpu():
+    device = torch.device('cpu')
+    chainer_device = cpm.to_chainer_device(device)
+    assert chainer_device.name == '@numpy'
+
+def test_to_chainer_device_gpu():
+    device = torch.device('cuda')
+    chainer_device = cpm.to_chainer_device(device)
+    assert chainer_device.name == '@cupy:0'
+
+def test_to_chainer_device_gpu_0():
+    device = torch.device('cuda:0')
+    chainer_device = cpm.to_chainer_device(device)
+    assert chainer_device.name == '@cupy:0'
+
+def test_to_chainer_device_gpu_1():
+    device = torch.device('cuda:1')
+    chainer_device = cpm.to_chainer_device(device)
+    assert chainer_device.name == '@cupy:1'
+
+def test_to_torch_device_cpu():
+    device = chainer.get_device('@numpy')
+    torch_device = cpm.to_torch_device(device)
+    assert torch_device.type == 'cpu'
+
+def test_to_torch_device_gpu():
+    device = chainer.get_device('@cupy:0')
+    torch_device = cpm.to_torch_device(device)
+    assert torch_device.type == 'cuda'
+    assert torch_device.index == 0
+
+def test_to_torch_device_gpu_1():
+    device = chainer.get_device('@cupy:1')
+    torch_device = cpm.to_torch_device(device)
+    assert torch_device.type == 'cuda'
+    assert torch_device.index == 1