Skip to content

Commit

Permalink
Add to_torch_device
Browse files Browse the repository at this point in the history
  • Loading branch information
asi1024 committed Mar 3, 2020
1 parent bc77c00 commit a420975
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
2 changes: 1 addition & 1 deletion chainer_pytorch_migration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from .links import TorchModule
from .parameter import ChainerParameter, LinkAsTorchModel, Optimizer
from .tensor import asarray, astensor, to_numpy_dtype
from .device import to_chainer_device
from .device import to_chainer_device, to_torch_device
24 changes: 22 additions & 2 deletions chainer_pytorch_migration/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,32 @@ def to_chainer_device(device):
device (torch.device): Device to be converted.
Returns:
A ``torch.device`` object corresponding to the given input.
A ``chainer.device`` object corresponding to the given input.
"""
assert isinstance(device, torch.device)
if not isinstance(device, torch.device):
raise TypeError('The argument should be torch device.')
if device.type == 'cpu':
return chainer.get_device('@numpy')
if device.type == 'cuda':
device_index = 0 if device.index is None else device.index
return chainer.get_device('@cupy:{}'.format(device_index))
raise ValueError('{} is not supported.'.format(device.type))


def to_torch_device(device):
"""Create a torch device from a given chainer device.
Args:
device (chainer.Device): Device to be converted.
Returns:
A ``torch.device`` object corresponding to the given input.
"""
if not isinstance(device, chainer.backend.Device):
raise TypeError('The argument should be chainer device.')
if device.name == '@numpy':
return torch.device('cpu')
if device.name.startswith('@cupy:'):
cuda_device_index = int(device.name.split(':')[1])
return torch.device('cuda', cuda_device_index)
raise ValueError('{} is not supported.'.format(device.name))
18 changes: 18 additions & 0 deletions tests/test_device.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import chainer
import torch

import chainer_pytorch_migration as cpm
Expand All @@ -22,3 +23,20 @@ def test_to_chainer_device_gpu_1():
device = torch.device('cuda:1')
chainer_device = cpm.to_chainer_device(device)
assert chainer_device.name == '@cupy:1'

def test_to_torch_device_cpu():
device = chainer.get_device('@numpy')
torch_device = cpm.to_torch_device(device)
assert torch_device.type == 'cpu'

def test_to_torch_device_gpu():
device = chainer.get_device('@cupy:0')
torch_device = cpm.to_torch_device(device)
assert torch_device.type == 'cuda'
assert torch_device.index == 0

def test_to_torch_device_gpu_0():
device = chainer.get_device('@cupy:1')
torch_device = cpm.to_torch_device(device)
assert torch_device.type == 'cuda'
assert torch_device.index == 1

0 comments on commit a420975

Please sign in to comment.