diff --git a/sharktank/sharktank/layers/token_embedding.py b/sharktank/sharktank/layers/token_embedding.py
index acf233e50..aeeb34127 100644
--- a/sharktank/sharktank/layers/token_embedding.py
+++ b/sharktank/sharktank/layers/token_embedding.py
@@ -23,4 +23,4 @@ def __init__(
         self.dtype = dtype
 
     def forward(self, input: torch.Tensor):
-        return ops.embedding_lookup(input, self.weight.to(device=input.device), dtype=self.dtype)
+        return ops.embedding_lookup(input, self.weight, dtype=self.dtype)
diff --git a/sharktank/tests/types/dataset_test.py b/sharktank/tests/types/dataset_test.py
index 82c9723f0..99d176c5b 100644
--- a/sharktank/tests/types/dataset_test.py
+++ b/sharktank/tests/types/dataset_test.py
@@ -77,6 +77,22 @@ def testTransform(self):
         self.assertIsNot(pt1, pt2)
         torch.testing.assert_close(pt1, pt2)
 
+    def testPop(self):
+        t1 = Theta(
+            _flat_t_dict(
+                _t("a.b.c", 1, 2),
+                _t("a.c.d", 10, 11),
+                _t("a.b.3", 3, 4),
+            )
+        )
+        popped = t1.pop("a.b").flatten()
+        t1 = t1.flatten()
+
+        self.assertIn("a.c.d", t1.keys())
+        self.assertNotIn("a.b.c", t1.keys())
+        self.assertNotIn("a.b.3", t1.keys())
+        self.assertIn("a.b.3", popped.keys())
+
 
 class DatasetTest(unittest.TestCase):
     def setUp(self):