diff --git a/chainer_pytorch_migration/tensor.py b/chainer_pytorch_migration/tensor.py index 2438b53..56fa105 100644 --- a/chainer_pytorch_migration/tensor.py +++ b/chainer_pytorch_migration/tensor.py @@ -72,7 +72,11 @@ def astensor(array): # If the array is not allocated (empty) # we just create a new one if array.data.ptr == 0: - return torch.empty(array.shape, dtype=to_torch_dtype(array.dtype)) + return torch.empty( + array.shape, + dtype=to_torch_dtype(array.dtype), + device=array.device.id + ) return torch.as_tensor( _ArrayWithCudaArrayInterfaceHavingStrides(array), device=array.device.id,