Skip to content

Commit

Permalink
add a test for Theta.pop()
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-garvey committed Aug 29, 2024
1 parent 7f3c963 commit b93aa68
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/token_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ def __init__(
self.dtype = dtype

def forward(self, input: torch.Tensor):
return ops.embedding_lookup(input, self.weight.to(device=input.device), dtype=self.dtype)
return ops.embedding_lookup(input, self.weight, dtype=self.dtype)
16 changes: 16 additions & 0 deletions sharktank/tests/types/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,22 @@ def testTransform(self):
self.assertIsNot(pt1, pt2)
torch.testing.assert_close(pt1, pt2)

def testPop(self):
t1 = Theta(
_flat_t_dict(
_t("a.b.c", 1, 2),
_t("a.c.d", 10, 11),
_t("a.b.3", 3, 4),
)
)
popped = t1.pop("a.b").flatten()
t1 = t1.flatten()

self.assertIsNotNone("a.c.d", t1.keys())
self.assertNotIn("a.b.c", t1.keys())
self.assertNotIn("a.b.3", t1.keys())
self.assertIn("a.b.3", popped.keys())


class DatasetTest(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit b93aa68

Please sign in to comment.