From 097042092af3c02e22988d5078cc9f59e43be7b0 Mon Sep 17 00:00:00 2001 From: Akifumi Imanishi Date: Wed, 4 Mar 2020 11:44:44 +0000 Subject: [PATCH] Add tests --- tests/test_parameter.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 2009e03..dc2a605 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -274,3 +274,24 @@ def test_named_params(): numpy.testing.assert_array_equal(a_arr, n_params['a'].detach()) assert 'b' in n_params numpy.testing.assert_array_equal(b_arr, n_params['b'].detach()) + + +def test_link_to_device(): + a_arr = numpy.ones((3, 2), 'float32') + a_chainer_param = chainer.Parameter(a_arr) + # 0-size parameter + b_arr = numpy.ones((2, 0, 1), 'float32') + b_chainer_param = chainer.Parameter(b_arr) + + link = chainer.Link() + with link.init_scope(): + link.a = a_chainer_param + link.b = b_chainer_param + + torched = cpm.LinkAsTorchModel(link) + ret = torched.to('cuda') + + assert torched is ret + + for name, param in torched.named_parameters(): + assert param.device.type == 'cuda'