diff --git a/tests/trace/test_dataset.py b/tests/trace/test_dataset.py index ff4056fdbd73..d55af19c0f37 100644 --- a/tests/trace/test_dataset.py +++ b/tests/trace/test_dataset.py @@ -1,6 +1,7 @@ import pytest import weave +from tests.trace.test_evaluate import Dataset def test_basic_dataset_lifecycle(client): @@ -33,3 +34,62 @@ def test_pythonic_access(client): with pytest.raises(IndexError): ds[-1] + + +def test_dataset_laziness(client): + dataset = Dataset(rows=[{"input": i} for i in range(300)]) + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == ["ensure_project_exists"] + client.server.attribute_access_log = [] + + length = len(dataset) + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == [] + + length2 = len(dataset) + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == [] + + assert length == length2 + + i = 0 + for row in dataset: + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == [] + i += 1 + + +def test_published_dataset_laziness(client): + dataset = Dataset(rows=[{"input": i} for i in range(300)]) + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == ["ensure_project_exists"] + client.server.attribute_access_log = [] + + ref = weave.publish(dataset) + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == ["table_create", "obj_create"] + client.server.attribute_access_log = [] + + dataset = ref.get() + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == ["obj_read"] + client.server.attribute_access_log = [] + + length = len(dataset) + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == ["table_query_stats"] + client.server.attribute_access_log = [] + + length2 = len(dataset) + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == [] + + assert length == length2 + + i = 0 + for row in dataset: + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == ["table_query"] * ( + (i // 100) + 1 + ) + i += 1 diff --git a/tests/trace/test_evaluate.py b/tests/trace/test_evaluate.py index 7d433fc613c8..80578fa96b25 100644 --- a/tests/trace/test_evaluate.py +++ b/tests/trace/test_evaluate.py @@ -302,19 +302,27 @@ def score_simple(input, output): return input == output log = client.server.attribute_access_log - # assert log == [] + assert [l for l in log if not l.startswith("_")] == [ + "ensure_project_exists", + "table_create", + "obj_create", + "obj_read", + ] + client.server.attribute_access_log = [] + evaluation = Evaluation( dataset=dataset, scorers=[score_simple], ) log = client.server.attribute_access_log - # assert log == [] + assert [l for l in log if not l.startswith("_")] == [] result = asyncio.run(evaluation.evaluate(model_predict)) assert result["output"] == {"mean": 149.5} assert result["score_simple"] == {"true_count": 300, "true_fraction": 1.0} log = client.server.attribute_access_log + log = [l for l in log if not l.startswith("_")] # Make sure that the length was figured out deterministically assert "table_query_stats" in log @@ -330,4 +338,4 @@ def score_simple(input, output): # However, the key part is that we have basically X + 2 splits, with the middle X # being equal. We want to ensure that the table_query is not called in sequence, # but rather lazily after each batch. - assert counts_split_by_table_query == [19, 700, 700, 700, 5] + assert counts_split_by_table_query == [13, 700, 700, 700, 5] diff --git a/weave/trace/vals.py b/weave/trace/vals.py index 06e17f24e3f6..6e981282332c 100644 --- a/weave/trace/vals.py +++ b/weave/trace/vals.py @@ -263,9 +263,16 @@ def __eq__(self, other: Any) -> bool: return self._val == other +ROWS_TYPES = Union[list[dict], Iterator[dict]] + + class WeaveTable(Traceable): filter: TableRowFilter _known_length: Optional[int] = None + _rows: Optional[ROWS_TYPES] = None + # _prefetched_rows is a local cache of rows that can be used to + # avoid a remote call. Should only be used by internal code. + _prefetched_rows: Optional[list[dict]] = None def __init__( self, @@ -282,14 +289,9 @@ def __init__( self.server = server self.root = root or self self.parent = parent - self._rows: Optional[Iterator[dict]] = None - - # _prefetched_rows is a local cache of rows that can be used to - # avoid a remote call. Should only be used by internal code. - self._prefetched_rows: Optional[list[dict]] = None @property - def rows(self) -> Iterator[dict]: + def rows(self) -> ROWS_TYPES: if self._rows is None: should_local_iter = ( self.ref is not None @@ -304,13 +306,21 @@ def rows(self) -> Iterator[dict]: return self._rows @rows.setter - def rows(self, value: list[dict]) -> None: + def rows(self, value: ROWS_TYPES) -> None: if not all(isinstance(row, dict) for row in value): raise ValueError("All table rows must be dicts") self._rows = value self._mark_dirty() + def _ensure_rows_are_local_list(self) -> list[dict]: + # Any uses of this are signs of a design problem arising + # from a remote table clashing with the need to feel like a + # local list. + if not isinstance(self.rows, list): + self.rows = list(self.rows) + return self.rows + def set_prefetched_rows(self, prefetched_rows: list[dict]) -> None: """Sets the rows to a local cache of rows that can be used to avoid a remote call. Should only be used by internal code. @@ -335,6 +345,7 @@ def __len__(self) -> int: self.table_ref is not None and self.table_ref._row_digests is not None and self._prefetched_rows is not None + and isinstance(self.table_ref._row_digests, list) ): if len(self._prefetched_rows) == len(self.table_ref._row_digests): self._known_length = len(self._prefetched_rows) @@ -345,9 +356,13 @@ def __len__(self) -> int: self._known_length = self._fetch_remote_length() return self._known_length - return len(self.rows) + rows_as_list = self._ensure_rows_are_local_list() + return len(rows_as_list) def _fetch_remote_length(self) -> int: + if self.table_ref is None: + raise ValueError("Cannot fetch remote length of table without table ref") + response = self.server.table_query_stats( TableQueryStatsReq( project_id=self.table_ref.project_id, digest=self.table_ref.digest @@ -462,7 +477,10 @@ def _remote_iter(self) -> Generator[dict, None, None]: page_index += 1 def __getitem__(self, key: Union[int, slice, str]) -> Any: - rows = self.rows + # TODO: we should have a better caching strategy that allows + # partial iteration over the table. + rows = self._ensure_rows_are_local_list() + if isinstance(key, (int, slice)): return rows[key] @@ -476,14 +494,16 @@ def __iter__(self) -> Iterator[dict]: return iter(self.rows) def append(self, val: dict) -> None: + rows = self._ensure_rows_are_local_list() if not isinstance(val, dict): raise TypeError("Can only append dicts to tables") self._mark_dirty() - self.rows.append(val) + rows.append(val) def pop(self, index: int) -> None: + rows = self._ensure_rows_are_local_list() self._mark_dirty() - self.rows.pop(index) + rows.pop(index) class WeaveList(Traceable, list):