Skip to content

Commit

Permalink
Lint fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Jan 22, 2025
1 parent 4c99e73 commit 3a1cdf1
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 14 deletions.
60 changes: 60 additions & 0 deletions tests/trace/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

import weave
from tests.trace.test_evaluate import Dataset


def test_basic_dataset_lifecycle(client):
Expand Down Expand Up @@ -33,3 +34,62 @@ def test_pythonic_access(client):

with pytest.raises(IndexError):
ds[-1]


def test_dataset_laziness(client):
    """A purely in-memory dataset must never talk to the server.

    Construction, len(), and iteration are all expected to be served
    locally; the only server access allowed is the project bootstrap
    performed by the client fixture itself.
    """
    dataset = Dataset(rows=[{"input": i} for i in range(300)])

    def public_calls():
        # Drop private (underscore-prefixed) attribute accesses; only
        # real server API calls matter for the laziness contract.
        return [
            entry
            for entry in client.server.attribute_access_log
            if not entry.startswith("_")
        ]

    assert public_calls() == ["ensure_project_exists"]
    client.server.attribute_access_log = []

    first_length = len(dataset)
    assert public_calls() == []

    second_length = len(dataset)
    assert public_calls() == []

    assert first_length == second_length

    for _row in dataset:
        assert public_calls() == []


def test_published_dataset_laziness(client):
    """A published-then-fetched dataset should hit the server lazily.

    Expected server traffic:
      * publish: one ``table_create`` + one ``obj_create``
      * ``ref.get()``: a single ``obj_read`` (no row data yet)
      * ``len()``: exactly one ``table_query_stats``, then cached
      * iteration: one ``table_query`` per 100-row page, fetched on demand
    """
    dataset = Dataset(rows=[{"input": i} for i in range(300)])

    def public_calls():
        # Drop private (underscore-prefixed) attribute accesses; only
        # real server API calls matter for the laziness contract.
        return [
            entry
            for entry in client.server.attribute_access_log
            if not entry.startswith("_")
        ]

    assert public_calls() == ["ensure_project_exists"]
    client.server.attribute_access_log = []

    ref = weave.publish(dataset)
    assert public_calls() == ["table_create", "obj_create"]
    client.server.attribute_access_log = []

    dataset = ref.get()
    assert public_calls() == ["obj_read"]
    client.server.attribute_access_log = []

    first_length = len(dataset)
    assert public_calls() == ["table_query_stats"]
    client.server.attribute_access_log = []

    second_length = len(dataset)
    assert public_calls() == []

    assert first_length == second_length

    for index, _row in enumerate(dataset):
        # One page of 100 rows is fetched each time iteration crosses a
        # page boundary, so after row `index` we expect this many queries.
        pages_fetched = (index // 100) + 1
        assert public_calls() == ["table_query"] * pages_fetched
14 changes: 11 additions & 3 deletions tests/trace/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,19 +302,27 @@ def score_simple(input, output):
return input == output

log = client.server.attribute_access_log
# assert log == []
assert [l for l in log if not l.startswith("_")] == [
"ensure_project_exists",
"table_create",
"obj_create",
"obj_read",
]
client.server.attribute_access_log = []

evaluation = Evaluation(
dataset=dataset,
scorers=[score_simple],
)
log = client.server.attribute_access_log
# assert log == []
assert [l for l in log if not l.startswith("_")] == []

result = asyncio.run(evaluation.evaluate(model_predict))
assert result["output"] == {"mean": 149.5}
assert result["score_simple"] == {"true_count": 300, "true_fraction": 1.0}

log = client.server.attribute_access_log
log = [l for l in log if not l.startswith("_")]

# Make sure that the length was figured out deterministically
assert "table_query_stats" in log
Expand All @@ -330,4 +338,4 @@ def score_simple(input, output):
# However, the key part is that we have basically X + 2 splits, with the middle X
# being equal. We want to ensure that the table_query is not called in sequence,
# but rather lazily after each batch.
assert counts_split_by_table_query == [19, 700, 700, 700, 5]
assert counts_split_by_table_query == [13, 700, 700, 700, 5]
42 changes: 31 additions & 11 deletions weave/trace/vals.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,16 @@ def __eq__(self, other: Any) -> bool:
return self._val == other


ROWS_TYPES = Union[list[dict], Iterator[dict]]


class WeaveTable(Traceable):
filter: TableRowFilter
_known_length: Optional[int] = None
_rows: Optional[ROWS_TYPES] = None
# _prefetched_rows is a local cache of rows that can be used to
# avoid a remote call. Should only be used by internal code.
_prefetched_rows: Optional[list[dict]] = None

def __init__(
self,
Expand All @@ -282,14 +289,9 @@ def __init__(
self.server = server
self.root = root or self
self.parent = parent
self._rows: Optional[Iterator[dict]] = None

# _prefetched_rows is a local cache of rows that can be used to
# avoid a remote call. Should only be used by internal code.
self._prefetched_rows: Optional[list[dict]] = None

@property
def rows(self) -> Iterator[dict]:
def rows(self) -> ROWS_TYPES:
if self._rows is None:
should_local_iter = (
self.ref is not None
Expand All @@ -304,13 +306,21 @@ def rows(self) -> Iterator[dict]:
return self._rows

@rows.setter
def rows(self, value: list[dict]) -> None:
def rows(self, value: ROWS_TYPES) -> None:
if not all(isinstance(row, dict) for row in value):
raise ValueError("All table rows must be dicts")

self._rows = value
self._mark_dirty()

def _ensure_rows_are_local_list(self) -> list[dict]:
# Any uses of this are signs of a design problem arising
# from a remote table clashing with the need to feel like a
# local list.
if not isinstance(self.rows, list):
self.rows = list(self.rows)
return self.rows

def set_prefetched_rows(self, prefetched_rows: list[dict]) -> None:
"""Sets the rows to a local cache of rows that can be used to
avoid a remote call. Should only be used by internal code.
Expand All @@ -335,6 +345,7 @@ def __len__(self) -> int:
self.table_ref is not None
and self.table_ref._row_digests is not None
and self._prefetched_rows is not None
and isinstance(self.table_ref._row_digests, list)
):
if len(self._prefetched_rows) == len(self.table_ref._row_digests):
self._known_length = len(self._prefetched_rows)
Expand All @@ -345,9 +356,13 @@ def __len__(self) -> int:
self._known_length = self._fetch_remote_length()
return self._known_length

return len(self.rows)
rows_as_list = self._ensure_rows_are_local_list()
return len(rows_as_list)

def _fetch_remote_length(self) -> int:
if self.table_ref is None:
raise ValueError("Cannot fetch remote length of table without table ref")

response = self.server.table_query_stats(
TableQueryStatsReq(
project_id=self.table_ref.project_id, digest=self.table_ref.digest
Expand Down Expand Up @@ -462,7 +477,10 @@ def _remote_iter(self) -> Generator[dict, None, None]:
page_index += 1

def __getitem__(self, key: Union[int, slice, str]) -> Any:
rows = self.rows
# TODO: we should have a better caching strategy that allows
# partial iteration over the table.
rows = self._ensure_rows_are_local_list()

if isinstance(key, (int, slice)):
return rows[key]

Expand All @@ -476,14 +494,16 @@ def __iter__(self) -> Iterator[dict]:
return iter(self.rows)

def append(self, val: dict) -> None:
rows = self._ensure_rows_are_local_list()
if not isinstance(val, dict):
raise TypeError("Can only append dicts to tables")
self._mark_dirty()
self.rows.append(val)
rows.append(val)

def pop(self, index: int) -> None:
rows = self._ensure_rows_are_local_list()
self._mark_dirty()
self.rows.pop(index)
rows.pop(index)


class WeaveList(Traceable, list):
Expand Down

0 comments on commit 3a1cdf1

Please sign in to comment.