Skip to content

Commit

Permalink
Add device assertion check
Browse files Browse the repository at this point in the history
  • Loading branch information
asi1024 committed Mar 4, 2020
1 parent 58a33bb commit 4ef9ac4
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,26 @@ def test_astensor_negative_stride():
def test_asarray_empty_cpu():
t = torch.tensor([], dtype=torch.float32)
a = tensor.asarray(t)
assert isinstance(a, numpy.ndarray)


def test_asarray_empty_gpu():
t = torch.tensor([], dtype=torch.float32, device='cuda')
a = tensor.asarray(t)
assert isinstance(a, cupy.ndarray)


def test_astensor_empty_cpu():
a = numpy.array([], dtype=numpy.float32)
t = tensor.astensor(a)
assert t.device.type == 'cpu'


def test_astensor_empty_gpu():
a = cupy.array([], dtype=cupy.float32)
t = tensor.astensor(a)
assert isinstance(t, torch.Tensor)
assert t.device.type == 'cuda'
t += 1
numpy.testing.assert_array_equal(a.get(), t.cpu().numpy())

Expand Down

0 comments on commit 4ef9ac4

Please sign in to comment.