Skip to content

Commit

Permalink
Merge pull request #19 from asi1024/device
Browse files Browse the repository at this point in the history
Add `to_chainer_device` and `LinkAsTorchModel.to(device)`
  • Loading branch information
niboshi authored Mar 3, 2020
2 parents 864dec5 + a420975 commit 830713f
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 0 deletions.
1 change: 1 addition & 0 deletions chainer_pytorch_migration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .links import TorchModule
from .parameter import ChainerParameter, LinkAsTorchModel, Optimizer
from .tensor import asarray, astensor, to_numpy_dtype
from .device import to_chainer_device, to_torch_device
40 changes: 40 additions & 0 deletions chainer_pytorch_migration/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import chainer
import torch


def to_chainer_device(device):
"""Create a chainer device from a given torch device.
Args:
device (torch.device): Device to be converted.
Returns:
A ``chainer.device`` object corresponding to the given input.
"""
if not isinstance(device, torch.device):
raise TypeError('The argument should be torch device.')
if device.type == 'cpu':
return chainer.get_device('@numpy')
if device.type == 'cuda':
device_index = 0 if device.index is None else device.index
return chainer.get_device('@cupy:{}'.format(device_index))
raise ValueError('{} is not supported.'.format(device.type))


def to_torch_device(device):
"""Create a torch device from a given chainer device.
Args:
device (chainer.Device): Device to be converted.
Returns:
A ``torch.device`` object corresponding to the given input.
"""
if not isinstance(device, chainer.backend.Device):
raise TypeError('The argument should be chainer device.')
if device.name == '@numpy':
return torch.device('cpu')
if device.name.startswith('@cupy:'):
cuda_device_index = int(device.name.split(':')[1])
return torch.device('cuda', cuda_device_index)
raise ValueError('{} is not supported.'.format(device.name))
42 changes: 42 additions & 0 deletions tests/test_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import chainer
import torch

import chainer_pytorch_migration as cpm


def test_to_chainer_device_cpu():
device = torch.device('cpu')
chainer_device = cpm.to_chainer_device(device)
assert chainer_device.name == '@numpy'

def test_to_chainer_device_gpu():
device = torch.device('cuda')
chainer_device = cpm.to_chainer_device(device)
assert chainer_device.name == '@cupy:0'

def test_to_chainer_device_gpu_0():
device = torch.device('cuda:0')
chainer_device = cpm.to_chainer_device(device)
assert chainer_device.name == '@cupy:0'

def test_to_chainer_device_gpu_1():
device = torch.device('cuda:1')
chainer_device = cpm.to_chainer_device(device)
assert chainer_device.name == '@cupy:1'

def test_to_torch_device_cpu():
device = chainer.get_device('@numpy')
torch_device = cpm.to_torch_device(device)
assert torch_device.type == 'cpu'

def test_to_torch_device_gpu():
device = chainer.get_device('@cupy:0')
torch_device = cpm.to_torch_device(device)
assert torch_device.type == 'cuda'
assert torch_device.index == 0

def test_to_torch_device_gpu_0():
device = chainer.get_device('@cupy:1')
torch_device = cpm.to_torch_device(device)
assert torch_device.type == 'cuda'
assert torch_device.index == 1

0 comments on commit 830713f

Please sign in to comment.