From c70018d99b50b4a6d47ac3005db0705a08f7c8a0 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 21 Jan 2025 17:52:25 -0800 Subject: [PATCH 01/16] working --- tests/trace/test_evaluate.py | 23 ++++++ tests/trace/test_flow_util.py | 131 ++++++++++++++++++++++++++++++++++ weave/flow/dataset.py | 11 +-- weave/flow/eval.py | 14 +++- weave/flow/util.py | 75 +++++++++++++++++-- 5 files changed, 243 insertions(+), 11 deletions(-) create mode 100644 tests/trace/test_flow_util.py diff --git a/tests/trace/test_evaluate.py b/tests/trace/test_evaluate.py index 002ed34fee3a..21531b0e5503 100644 --- a/tests/trace/test_evaluate.py +++ b/tests/trace/test_evaluate.py @@ -286,3 +286,26 @@ def score(self, target, model_output): result = asyncio.run(evaluation.evaluate(model)) assert result["my-scorer"] == {"true_count": 1, "true_fraction": 0.5} + + +def test_evaluate_table_lazy_iter(client): + dataset = Dataset(rows=[{"input": i} for i in range(300)]) + ref = weave.publish(dataset) + dataset = ref.get() + + @weave.op() + async def model_predict(input) -> int: + return input * 1 + + log = client.server.attribute_access_log + # assert log == [] + evaluation = Evaluation( + dataset=dataset, + scorers=[score], + ) + log = client.server.attribute_access_log + # assert log == [] + + result = asyncio.run(evaluation.evaluate(model_predict)) + log = client.server.attribute_access_log + # assert log == [] diff --git a/tests/trace/test_flow_util.py b/tests/trace/test_flow_util.py new file mode 100644 index 000000000000..10b473b2f5d7 --- /dev/null +++ b/tests/trace/test_flow_util.py @@ -0,0 +1,131 @@ +import asyncio + +import pytest + +from weave.flow.util import async_foreach + + +@pytest.mark.asyncio +async def test_async_foreach_basic(): + """Test basic functionality of async_foreach.""" + input_data = range(5) + results = [] + + async def process(x: int) -> int: + await asyncio.sleep(0.1) # Simulate async work + return x * 2 + + async for item, result in async_foreach( + input_data, process, max_concurrent_tasks=2 + ): + results.append((item, result)) + + assert len(results) == 5 + assert all(result == item * 2 for item, result in results) + assert [item for item, _ in results] == list(range(5)) + + +@pytest.mark.asyncio +async def test_async_foreach_concurrency(): + """Test that max_concurrent_tasks is respected.""" + currently_running = 0 + max_running = 0 + input_data = range(10) + + async def process(x: int) -> int: + nonlocal currently_running, max_running + currently_running += 1 + max_running = max(max_running, currently_running) + await asyncio.sleep(0.1) # Simulate async work + currently_running -= 1 + return x + + max_concurrent = 3 + async for _, _ in async_foreach( + input_data, process, max_concurrent_tasks=max_concurrent + ): + pass + + assert max_running == max_concurrent + + +@pytest.mark.asyncio +async def test_async_foreach_lazy_loading(): + """Test that items are loaded lazily from the iterator.""" + items_loaded = 0 + + def lazy_range(n: int): + nonlocal items_loaded + for i in range(n): + items_loaded += 1 + yield i + + async def process(x: int) -> int: + await asyncio.sleep(0.1) + return x + + # Process first 3 items then break + async for _, _ in async_foreach(lazy_range(100), process, max_concurrent_tasks=2): + if items_loaded >= 3: + break + + # Should have loaded at most 4 items (3 + 1 for concurrency) + assert items_loaded <= 4 + + +@pytest.mark.asyncio +async def test_async_foreach_error_handling(): + """Test error handling in async_foreach.""" + input_data = range(5) + + async def process(x: int) -> int: + if x == 3: + raise ValueError("Test error") + return x + + with pytest.raises(ValueError, match="Test error"): + async for _, _ in async_foreach(input_data, process, max_concurrent_tasks=2): + pass + + +@pytest.mark.asyncio +async def test_async_foreach_empty_input(): + """Test behavior with empty input sequence.""" + results = [] + + async def process(x: int) -> int: + return x + + async for item, result in async_foreach([], process, max_concurrent_tasks=2): + results.append((item, result)) + + assert len(results) == 0 + + +@pytest.mark.asyncio +async def test_async_foreach_cancellation(): + """Test that tasks are properly cleaned up on cancellation.""" + input_data = range(100) + results = [] + + async def slow_process(x: int) -> int: + await asyncio.sleep(0.5) # Longer delay to ensure tasks are running + return x + + # Create a task we can cancel + async def run_foreach(): + async for item, result in async_foreach( + input_data, slow_process, max_concurrent_tasks=3 + ): + results.append((item, result)) + if len(results) >= 2: # Cancel after 2 results + raise asyncio.CancelledError() + + with pytest.raises(asyncio.CancelledError): + await run_foreach() + + # Give a moment for any lingering tasks to complete if cleanup failed + await asyncio.sleep(0.1) + + # Check that we got the expected number of results before cancellation + assert len(results) == 2 diff --git a/weave/flow/dataset.py b/weave/flow/dataset.py index d8090d6d27a8..4f64a341e39b 100644 --- a/weave/flow/dataset.py +++ b/weave/flow/dataset.py @@ -1,11 +1,12 @@ from collections.abc import Iterator -from typing import Any +from typing import Any, Union from pydantic import field_validator from typing_extensions import Self import weave from weave.flow.obj import Object +from weave.trace.isinstance import weave_isinstance from weave.trace.objectify import register_object from weave.trace.vals import WeaveObject, WeaveTable @@ -43,7 +44,7 @@ class Dataset(Object): ``` """ - rows: weave.Table + rows: Union[weave.Table, WeaveTable] @classmethod def from_obj(cls, obj: WeaveObject) -> Self: @@ -55,11 +56,11 @@ def from_obj(cls, obj: WeaveObject) -> Self: ) @field_validator("rows", mode="before") - def convert_to_table(cls, rows: Any) -> weave.Table: + def convert_to_table(cls, rows: Any) -> Union[weave.Table, WeaveTable]: + if weave_isinstance(rows, WeaveTable): + return rows if not isinstance(rows, weave.Table): table_ref = getattr(rows, "table_ref", None) - if isinstance(rows, WeaveTable): - rows = list(rows) rows = weave.Table(rows) if table_ref: rows.table_ref = table_ref diff --git a/weave/flow/eval.py b/weave/flow/eval.py index 38988d2aa3f9..ce92099c7148 100644 --- a/weave/flow/eval.py +++ b/weave/flow/eval.py @@ -1,8 +1,9 @@ import asyncio import logging import traceback +from collections.abc import Iterable from datetime import datetime -from typing import Any, Callable, Literal, Optional, Union, cast +from typing import Any, Callable, Literal, Optional, TypeVar, Union, cast from pydantic import PrivateAttr, model_validator from rich import print @@ -278,7 +279,7 @@ async def eval_example(example: dict) -> dict: # with console.status("Evaluating...") as status: dataset = self._post_init_dataset _rows = dataset.rows - trial_rows = list(_rows) * self.trials + trial_rows = list(repeated_iterable(_rows, self.trials)) async for example, eval_row in util.async_foreach( trial_rows, eval_example, get_weave_parallelism() ): @@ -345,3 +346,12 @@ def is_valid_model(model: Any) -> bool: and is_op(model.predict) ) ) + + +T = TypeVar("T") + + +def repeated_iterable(iterable: Iterable[T], n: int) -> Iterable[T]: + for val in iterable: + for _ in range(n): + yield val diff --git a/weave/flow/util.py b/weave/flow/util.py index ba35d5ebe4ab..6cf7da3efc9b 100644 --- a/weave/flow/util.py +++ b/weave/flow/util.py @@ -16,18 +16,85 @@ async def async_foreach( func: Callable[[T], Awaitable[U]], max_concurrent_tasks: int, ) -> AsyncIterator[tuple[T, U]]: + """Process items from a sequence concurrently with a maximum number of parallel tasks. + + This function loads items from the input sequence lazily to support large or infinite + sequences. + + Args: + sequence: An iterable of items to process. Items are loaded lazily. + func: An async function that processes each item from the sequence. + max_concurrent_tasks: Maximum number of items to process concurrently. + + Yields: + Tuples of (original_item, processed_result). + + Example: + ```python + async def process(x: int) -> str: + await asyncio.sleep(1) # Simulate async work + return str(x * 2) + + async for item, result in async_foreach(range(10), process, max_concurrent_tasks=3): + print(f"Processed {item} -> {result}") + ``` + + Notes: + - If func raises an exception, it will be propagated to the caller + - Memory usage is bounded by max_concurrent_tasks + - All pending tasks are properly cleaned up on error or cancellation + """ semaphore = asyncio.Semaphore(max_concurrent_tasks) + active_tasks: set[asyncio.Task] = set() async def process_item(item: T) -> tuple[T, U]: + """Process a single item using the provided function with semaphore control.""" async with semaphore: result = await func(item) return item, result - tasks = [asyncio.create_task(process_item(item)) for item in sequence] + def maybe_queue_next_task() -> None: + """Attempt to queue the next task from the iterator if available.""" + try: + item = next(iterator) + task = asyncio.create_task(process_item(item)) + active_tasks.add(task) + except StopIteration: + pass - for task in asyncio.as_completed(tasks): - item, result = await task - yield item, result + iterator = iter(sequence) + + try: + # Prime the initial set of tasks + for _ in range(max_concurrent_tasks): + maybe_queue_next_task() + + while active_tasks: + done, _ = await asyncio.wait( + active_tasks, return_when=asyncio.FIRST_COMPLETED + ) + + for task in done: + active_tasks.remove(task) # Remove task after we know it's done + try: + item, result = await task + yield item, result + + # Add a new task if there are more items + maybe_queue_next_task() + except Exception: + # Clean up remaining tasks before re-raising + for t in active_tasks: + t.cancel() + await asyncio.gather(*active_tasks, return_exceptions=True) + raise + + except asyncio.CancelledError: + # Clean up tasks if the caller cancels this coroutine + for task in active_tasks: + task.cancel() + await asyncio.gather(*active_tasks, return_exceptions=True) + raise def _subproc( From a0956f7a44efee5e39b10d9404f9010a7b55296e Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 21 Jan 2025 19:35:45 -0800 Subject: [PATCH 02/16] more --- tests/trace/test_evaluate.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/trace/test_evaluate.py b/tests/trace/test_evaluate.py index 21531b0e5503..372bebcf4cc0 100644 --- a/tests/trace/test_evaluate.py +++ b/tests/trace/test_evaluate.py @@ -297,11 +297,15 @@ def test_evaluate_table_lazy_iter(client): async def model_predict(input) -> int: return input * 1 + @weave.op() + def score_simple(input, output): + return input == output + log = client.server.attribute_access_log # assert log == [] evaluation = Evaluation( dataset=dataset, - scorers=[score], + scorers=[score_simple], ) log = client.server.attribute_access_log # assert log == [] From c6974ee5c6d413e1328de1b939e52cf46ff48f0c Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 21 Jan 2025 20:03:07 -0800 Subject: [PATCH 03/16] further refinement --- tests/trace/test_evaluate.py | 16 +++++++++++++++- weave/flow/eval.py | 6 ++++-- weave/trace/vals.py | 8 ++++---- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/tests/trace/test_evaluate.py b/tests/trace/test_evaluate.py index 372bebcf4cc0..e40107fac9e0 100644 --- a/tests/trace/test_evaluate.py +++ b/tests/trace/test_evaluate.py @@ -311,5 +311,19 @@ def score_simple(input, output): # assert log == [] result = asyncio.run(evaluation.evaluate(model_predict)) + assert result["output"] == {"mean": 149.5} + assert result["score_simple"] == {"true_count": 300, "true_fraction": 1.0} + log = client.server.attribute_access_log - # assert log == [] + counts_split_by_table_query = [0] + for log_entry in log: + if log_entry == "table_query": + counts_split_by_table_query.append(0) + else: + counts_split_by_table_query[-1] += 1 + + # Note: these exact numbers might change if we change the way eval traces work. + # However, the key part is that we have basically X + 2 splits, with the middle X + # being equal. We want to ensure that the table_query is not called in sequence, + # but rather lazily after each batch. + assert counts_split_by_table_query == [18, 700, 700, 700, 5] diff --git a/weave/flow/eval.py b/weave/flow/eval.py index ce92099c7148..beb66005b149 100644 --- a/weave/flow/eval.py +++ b/weave/flow/eval.py @@ -279,12 +279,14 @@ async def eval_example(example: dict) -> dict: # with console.status("Evaluating...") as status: dataset = self._post_init_dataset _rows = dataset.rows - trial_rows = list(repeated_iterable(_rows, self.trials)) + # TODO: This is incorrect + num_rows = 10 * self.trials + trial_rows = repeated_iterable(_rows, self.trials) async for example, eval_row in util.async_foreach( trial_rows, eval_example, get_weave_parallelism() ): n_complete += 1 - print(f"Evaluated {n_complete} of {len(trial_rows)} examples") + print(f"Evaluated {n_complete} of {num_rows} examples") # status.update( # f"Evaluating... {duration:.2f}s [{n_complete} / {len(self.dataset.rows)} complete]" # type:ignore # ) diff --git a/weave/trace/vals.py b/weave/trace/vals.py index 0a38cbe5a2f9..367a0fba97c5 100644 --- a/weave/trace/vals.py +++ b/weave/trace/vals.py @@ -280,14 +280,14 @@ def __init__( self.server = server self.root = root or self self.parent = parent - self._rows: Optional[list[dict]] = None + self._rows: Optional[Iterator[dict]] = None # _prefetched_rows is a local cache of rows that can be used to # avoid a remote call. Should only be used by internal code. self._prefetched_rows: Optional[list[dict]] = None @property - def rows(self) -> list[dict]: + def rows(self) -> Iterator[dict]: if self._rows is None: should_local_iter = ( self.ref is not None @@ -296,9 +296,9 @@ def rows(self) -> list[dict]: and self._prefetched_rows is not None ) if should_local_iter: - self._rows = list(self._local_iter_with_remote_fallback()) + self._rows = iter(self._local_iter_with_remote_fallback()) else: - self._rows = list(self._remote_iter()) + self._rows = iter(self._remote_iter()) return self._rows @rows.setter From 4c99e73f7576231e4095c4aac524e0443ce42f3e Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 21 Jan 2025 20:23:23 -0800 Subject: [PATCH 04/16] further refinement --- tests/trace/test_evaluate.py | 6 +++++- weave/flow/dataset.py | 3 +-- weave/flow/eval.py | 3 +-- weave/trace/vals.py | 30 ++++++++++++++++++++++++++++++ 4 files changed, 37 insertions(+), 5 deletions(-) diff --git a/tests/trace/test_evaluate.py b/tests/trace/test_evaluate.py index e40107fac9e0..7d433fc613c8 100644 --- a/tests/trace/test_evaluate.py +++ b/tests/trace/test_evaluate.py @@ -315,6 +315,10 @@ def score_simple(input, output): assert result["score_simple"] == {"true_count": 300, "true_fraction": 1.0} log = client.server.attribute_access_log + + # Make sure that the length was figured out deterministically + assert "table_query_stats" in log + counts_split_by_table_query = [0] for log_entry in log: if log_entry == "table_query": @@ -326,4 +330,4 @@ def score_simple(input, output): # However, the key part is that we have basically X + 2 splits, with the middle X # being equal. We want to ensure that the table_query is not called in sequence, # but rather lazily after each batch. - assert counts_split_by_table_query == [18, 700, 700, 700, 5] + assert counts_split_by_table_query == [19, 700, 700, 700, 5] diff --git a/weave/flow/dataset.py b/weave/flow/dataset.py index 4f64a341e39b..f6c87774367f 100644 --- a/weave/flow/dataset.py +++ b/weave/flow/dataset.py @@ -84,8 +84,7 @@ def __iter__(self) -> Iterator[dict]: return iter(self.rows) def __len__(self) -> int: - # TODO: This can be slow for large datasets... - return len(list(self.rows)) + return len(self.rows) def __getitem__(self, key: int) -> dict: if key < 0: diff --git a/weave/flow/eval.py b/weave/flow/eval.py index beb66005b149..d998f5a44d32 100644 --- a/weave/flow/eval.py +++ b/weave/flow/eval.py @@ -279,8 +279,7 @@ async def eval_example(example: dict) -> dict: # with console.status("Evaluating...") as status: dataset = self._post_init_dataset _rows = dataset.rows - # TODO: This is incorrect - num_rows = 10 * self.trials + num_rows = len(_rows) * self.trials trial_rows = repeated_iterable(_rows, self.trials) async for example, eval_row in util.async_foreach( trial_rows, eval_example, get_weave_parallelism() diff --git a/weave/trace/vals.py b/weave/trace/vals.py index 367a0fba97c5..06e17f24e3f6 100644 --- a/weave/trace/vals.py +++ b/weave/trace/vals.py @@ -33,6 +33,7 @@ from weave.trace_server.trace_server_interface import ( ObjReadReq, TableQueryReq, + TableQueryStatsReq, TableRowFilter, TraceServerInterface, ) @@ -264,6 +265,7 @@ def __eq__(self, other: Any) -> bool: class WeaveTable(Traceable): filter: TableRowFilter + _known_length: Optional[int] = None def __init__( self, @@ -324,14 +326,42 @@ def set_prefetched_rows(self, prefetched_rows: list[dict]) -> None: self._prefetched_rows = prefetched_rows def __len__(self) -> int: + # This should e a single query + if self._known_length is not None: + return self._known_length + + # Condition 1: we already have all the rows in memory + if ( + self.table_ref is not None + and self.table_ref._row_digests is not None + and self._prefetched_rows is not None + ): + if len(self._prefetched_rows) == len(self.table_ref._row_digests): + self._known_length = len(self._prefetched_rows) + return self._known_length + + # Condition 2: We don't know the length, in which case we can get it from the server + if self.table_ref is not None: + self._known_length = self._fetch_remote_length() + return self._known_length + return len(self.rows) + def _fetch_remote_length(self) -> int: + response = self.server.table_query_stats( + TableQueryStatsReq( + project_id=self.table_ref.project_id, digest=self.table_ref.digest + ) + ) + return response.count + def __eq__(self, other: Any) -> bool: return self.rows == other def _mark_dirty(self) -> None: self.table_ref = None self._prefetched_rows = None + self._known_length = None super()._mark_dirty() def _local_iter_with_remote_fallback(self) -> Generator[dict, None, None]: From 3a1cdf1132dc6a48debed6d7736c9ae63c3b78fb Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 21 Jan 2025 22:00:09 -0800 Subject: [PATCH 05/16] Lint fix --- tests/trace/test_dataset.py | 60 ++++++++++++++++++++++++++++++++++++ tests/trace/test_evaluate.py | 14 +++++++-- weave/trace/vals.py | 42 ++++++++++++++++++------- 3 files changed, 102 insertions(+), 14 deletions(-) diff --git a/tests/trace/test_dataset.py b/tests/trace/test_dataset.py index ff4056fdbd73..d55af19c0f37 100644 --- a/tests/trace/test_dataset.py +++ b/tests/trace/test_dataset.py @@ -1,6 +1,7 @@ import pytest import weave +from tests.trace.test_evaluate import Dataset def test_basic_dataset_lifecycle(client): @@ -33,3 +34,62 @@ def test_pythonic_access(client): with pytest.raises(IndexError): ds[-1] + + +def test_dataset_laziness(client): + dataset = Dataset(rows=[{"input": i} for i in range(300)]) + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == ["ensure_project_exists"] + client.server.attribute_access_log = [] + + length = len(dataset) + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == [] + + length2 = len(dataset) + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == [] + + assert length == length2 + + i = 0 + for row in dataset: + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == [] + i += 1 + + +def test_published_dataset_laziness(client): + dataset = Dataset(rows=[{"input": i} for i in range(300)]) + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == ["ensure_project_exists"] + client.server.attribute_access_log = [] + + ref = weave.publish(dataset) + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == ["table_create", "obj_create"] + client.server.attribute_access_log = [] + + dataset = ref.get() + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == ["obj_read"] + client.server.attribute_access_log = [] + + length = len(dataset) + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == ["table_query_stats"] + client.server.attribute_access_log = [] + + length2 = len(dataset) + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == [] + + assert length == length2 + + i = 0 + for row in dataset: + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == ["table_query"] * ( + (i // 100) + 1 + ) + i += 1 diff --git a/tests/trace/test_evaluate.py b/tests/trace/test_evaluate.py index 7d433fc613c8..80578fa96b25 100644 --- a/tests/trace/test_evaluate.py +++ b/tests/trace/test_evaluate.py @@ -302,19 +302,27 @@ def score_simple(input, output): return input == output log = client.server.attribute_access_log - # assert log == [] + assert [l for l in log if not l.startswith("_")] == [ + "ensure_project_exists", + "table_create", + "obj_create", + "obj_read", + ] + client.server.attribute_access_log = [] + evaluation = Evaluation( dataset=dataset, scorers=[score_simple], ) log = client.server.attribute_access_log - # assert log == [] + assert [l for l in log if not l.startswith("_")] == [] result = asyncio.run(evaluation.evaluate(model_predict)) assert result["output"] == {"mean": 149.5} assert result["score_simple"] == {"true_count": 300, "true_fraction": 1.0} log = client.server.attribute_access_log + log = [l for l in log if not l.startswith("_")] # Make sure that the length was figured out deterministically assert "table_query_stats" in log @@ -330,4 +338,4 @@ def score_simple(input, output): # However, the key part is that we have basically X + 2 splits, with the middle X # being equal. We want to ensure that the table_query is not called in sequence, # but rather lazily after each batch. - assert counts_split_by_table_query == [19, 700, 700, 700, 5] + assert counts_split_by_table_query == [13, 700, 700, 700, 5] diff --git a/weave/trace/vals.py b/weave/trace/vals.py index 06e17f24e3f6..6e981282332c 100644 --- a/weave/trace/vals.py +++ b/weave/trace/vals.py @@ -263,9 +263,16 @@ def __eq__(self, other: Any) -> bool: return self._val == other +ROWS_TYPES = Union[list[dict], Iterator[dict]] + + class WeaveTable(Traceable): filter: TableRowFilter _known_length: Optional[int] = None + _rows: Optional[ROWS_TYPES] = None + # _prefetched_rows is a local cache of rows that can be used to + # avoid a remote call. Should only be used by internal code. + _prefetched_rows: Optional[list[dict]] = None def __init__( self, @@ -282,14 +289,9 @@ def __init__( self.server = server self.root = root or self self.parent = parent - self._rows: Optional[Iterator[dict]] = None - - # _prefetched_rows is a local cache of rows that can be used to - # avoid a remote call. Should only be used by internal code. - self._prefetched_rows: Optional[list[dict]] = None @property - def rows(self) -> Iterator[dict]: + def rows(self) -> ROWS_TYPES: if self._rows is None: should_local_iter = ( self.ref is not None @@ -304,13 +306,21 @@ def rows(self) -> Iterator[dict]: return self._rows @rows.setter - def rows(self, value: list[dict]) -> None: + def rows(self, value: ROWS_TYPES) -> None: if not all(isinstance(row, dict) for row in value): raise ValueError("All table rows must be dicts") self._rows = value self._mark_dirty() + def _ensure_rows_are_local_list(self) -> list[dict]: + # Any uses of this are signs of a design problem arising + # from a remote table clashing with the need to feel like a + # local list. + if not isinstance(self.rows, list): + self.rows = list(self.rows) + return self.rows + def set_prefetched_rows(self, prefetched_rows: list[dict]) -> None: """Sets the rows to a local cache of rows that can be used to avoid a remote call. Should only be used by internal code. @@ -335,6 +345,7 @@ def __len__(self) -> int: self.table_ref is not None and self.table_ref._row_digests is not None and self._prefetched_rows is not None + and isinstance(self.table_ref._row_digests, list) ): if len(self._prefetched_rows) == len(self.table_ref._row_digests): self._known_length = len(self._prefetched_rows) @@ -345,9 +356,13 @@ def __len__(self) -> int: self._known_length = self._fetch_remote_length() return self._known_length - return len(self.rows) + rows_as_list = self._ensure_rows_are_local_list() + return len(rows_as_list) def _fetch_remote_length(self) -> int: + if self.table_ref is None: + raise ValueError("Cannot fetch remote length of table without table ref") + response = self.server.table_query_stats( TableQueryStatsReq( project_id=self.table_ref.project_id, digest=self.table_ref.digest @@ -462,7 +477,10 @@ def _remote_iter(self) -> Generator[dict, None, None]: page_index += 1 def __getitem__(self, key: Union[int, slice, str]) -> Any: - rows = self.rows + # TODO: we should have a better caching strategy that allows + # partial iteration over the table. + rows = self._ensure_rows_are_local_list() + if isinstance(key, (int, slice)): return rows[key] @@ -476,14 +494,16 @@ def __iter__(self) -> Iterator[dict]: return iter(self.rows) def append(self, val: dict) -> None: + rows = self._ensure_rows_are_local_list() if not isinstance(val, dict): raise TypeError("Can only append dicts to tables") self._mark_dirty() - self.rows.append(val) + rows.append(val) def pop(self, index: int) -> None: + rows = self._ensure_rows_are_local_list() self._mark_dirty() - self.rows.pop(index) + rows.pop(index) class WeaveList(Traceable, list): From 3ec7a08ac4739c3898149746fd444ea6a17212a7 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 22 Jan 2025 11:42:43 -0800 Subject: [PATCH 06/16] Functionality Checkpoint --- tests/trace/test_dataset.py | 17 +++++++++++++++++ tests/trace/test_evaluate.py | 5 +++++ weave/trace/vals.py | 28 ++++++++++++++++++---------- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/tests/trace/test_dataset.py b/tests/trace/test_dataset.py index d55af19c0f37..2a8dc03cce6b 100644 --- a/tests/trace/test_dataset.py +++ b/tests/trace/test_dataset.py @@ -37,6 +37,10 @@ def test_pythonic_access(client): def test_dataset_laziness(client): + """ + The intention of this test is to show that local construction of + a dataset does not trigger any remote operations. + """ dataset = Dataset(rows=[{"input": i} for i in range(300)]) log = client.server.attribute_access_log assert [l for l in log if not l.startswith("_")] == ["ensure_project_exists"] @@ -60,6 +64,12 @@ def test_dataset_laziness(client): def test_published_dataset_laziness(client): + """ + The intention of this test is to show that publishing a dataset, + then iterating through the "gotten" version of the dataset has + minimal remote operations - and importantly delays the fetching + of the rows until they are actually needed. + """ dataset = Dataset(rows=[{"input": i} for i in range(300)]) log = client.server.attribute_access_log assert [l for l in log if not l.startswith("_")] == ["ensure_project_exists"] @@ -89,6 +99,13 @@ def test_published_dataset_laziness(client): i = 0 for row in dataset: log = client.server.attribute_access_log + # This is the critical part of the test - ensuring that + # the rows are only fetched when they are actually needed. + # + # In a future improvement, we might eagerly fetch the next + # page of results, which would result in this assertion changing + # in that there would always be one more "table_query" than + # the number of pages. assert [l for l in log if not l.startswith("_")] == ["table_query"] * ( (i // 100) + 1 ) diff --git a/tests/trace/test_evaluate.py b/tests/trace/test_evaluate.py index 80578fa96b25..2a2ca4d407d2 100644 --- a/tests/trace/test_evaluate.py +++ b/tests/trace/test_evaluate.py @@ -289,6 +289,11 @@ def score(self, target, model_output): def test_evaluate_table_lazy_iter(client): + """ + The intention of this test is to show that an evaluation harness + lazily fetches rows from a table rather than eagerly fetching all + rows up front. + """ dataset = Dataset(rows=[{"input": i} for i in range(300)]) ref = weave.publish(dataset) dataset = ref.get() diff --git a/weave/trace/vals.py b/weave/trace/vals.py index 6e981282332c..f90df4eec95f 100644 --- a/weave/trace/vals.py +++ b/weave/trace/vals.py @@ -313,10 +313,14 @@ def rows(self, value: ROWS_TYPES) -> None: self._rows = value self._mark_dirty() - def _ensure_rows_are_local_list(self) -> list[dict]: - # Any uses of this are signs of a design problem arising - # from a remote table clashing with the need to feel like a - # local list. + def _inefficiently_materialize_rows_as_list(self) -> list[dict]: + # This method is named `inefficiently` to warn callers that + # it should be avoided. We have this nasty paradigm where sometimes + # a WeaveTable needs to act like a list, but it is actually a remote + # table. This method will force iteration through the remote data + # and materialize it into a list. Any uses of this are signs of a design + # problem arising from a remote table clashing with the need to feel like + # a local list. if not isinstance(self.rows, list): self.rows = list(self.rows) return self.rows @@ -356,7 +360,10 @@ def __len__(self) -> int: self._known_length = self._fetch_remote_length() return self._known_length - rows_as_list = self._ensure_rows_are_local_list() + # Finally, if we have no table ref, we can still get the length + # by materializing the rows as a list. I actually think this + # can never happen, but it is here for completeness. + rows_as_list = self._inefficiently_materialize_rows_as_list() return len(rows_as_list) def _fetch_remote_length(self) -> int: @@ -477,9 +484,10 @@ def _remote_iter(self) -> Generator[dict, None, None]: page_index += 1 def __getitem__(self, key: Union[int, slice, str]) -> Any: - # TODO: we should have a better caching strategy that allows - # partial iteration over the table. - rows = self._ensure_rows_are_local_list() + # TODO: ideally we would have some sort of intelligent + # LRU style caching that allows us to minimize materialization + # of the rows as a list. + rows = self._inefficiently_materialize_rows_as_list() if isinstance(key, (int, slice)): return rows[key] @@ -494,14 +502,14 @@ def __iter__(self) -> Iterator[dict]: return iter(self.rows) def append(self, val: dict) -> None: - rows = self._ensure_rows_are_local_list() + rows = self._inefficiently_materialize_rows_as_list() if not isinstance(val, dict): raise TypeError("Can only append dicts to tables") self._mark_dirty() rows.append(val) def pop(self, index: int) -> None: - rows = self._ensure_rows_are_local_list() + rows = self._inefficiently_materialize_rows_as_list() self._mark_dirty() rows.pop(index) From a381ce648be53e53ce4c1492d9146fbf96386148 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 22 Jan 2025 13:57:48 -0800 Subject: [PATCH 07/16] Added special iterator --- tests/trace/test_client_trace.py | 2 + weave/trace/vals.py | 18 ++--- weave/utils/iterators.py | 126 +++++++++++++++++++++++++++++++ 3 files changed, 137 insertions(+), 9 deletions(-) create mode 100644 weave/utils/iterators.py diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index 36dd8ffcaa00..e081f8420f52 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -117,6 +117,8 @@ def test_dataset(client): ref = weave.publish(d) d2 = weave.ref(ref.uri()).get() assert list(d2.rows) == list(d2.rows) + assert list(d.rows) == list(d2.rows) + assert list(d.rows) == list(d.rows) def test_trace_server_call_start_and_end(client): diff --git a/weave/trace/vals.py b/weave/trace/vals.py index f90df4eec95f..ce3aa7bca1e2 100644 --- a/weave/trace/vals.py +++ b/weave/trace/vals.py @@ -3,7 +3,7 @@ import logging import operator import typing -from collections.abc import Generator, Iterator +from collections.abc import Generator, Iterator, Sequence from copy import deepcopy from typing import Any, Literal, Optional, SupportsIndex, Union @@ -37,6 +37,7 @@ TableRowFilter, TraceServerInterface, ) +from weave.utils.iterators import ThreadSafeInMemoryIteratorAsSequence logger = logging.getLogger(__name__) @@ -263,13 +264,10 @@ def __eq__(self, other: Any) -> bool: return self._val == other -ROWS_TYPES = Union[list[dict], Iterator[dict]] - - class WeaveTable(Traceable): filter: TableRowFilter _known_length: Optional[int] = None - _rows: Optional[ROWS_TYPES] = None + _rows: Optional[Sequence[dict]] = None # _prefetched_rows is a local cache of rows that can be used to # avoid a remote call. Should only be used by internal code. _prefetched_rows: Optional[list[dict]] = None @@ -291,7 +289,7 @@ def __init__( self.parent = parent @property - def rows(self) -> ROWS_TYPES: + def rows(self) -> Sequence[dict]: if self._rows is None: should_local_iter = ( self.ref is not None @@ -300,13 +298,15 @@ def rows(self) -> ROWS_TYPES: and self._prefetched_rows is not None ) if should_local_iter: - self._rows = iter(self._local_iter_with_remote_fallback()) + self._rows = ThreadSafeInMemoryIteratorAsSequence( + self._local_iter_with_remote_fallback() + ) else: - self._rows = iter(self._remote_iter()) + self._rows = ThreadSafeInMemoryIteratorAsSequence(self._remote_iter()) return self._rows @rows.setter - def rows(self, value: ROWS_TYPES) -> None: + def rows(self, value: Sequence[dict]) -> None: if not all(isinstance(row, dict) for row in value): raise ValueError("All table rows must be dicts") diff --git a/weave/utils/iterators.py b/weave/utils/iterators.py new file mode 100644 index 000000000000..3938f2a84a9b --- /dev/null +++ b/weave/utils/iterators.py @@ -0,0 +1,126 @@ +from collections.abc import Generator, Iterator, Sequence +from threading import Lock +from typing import Optional, TypeVar, overload + +T = TypeVar("T") + + +class ThreadSafeInMemoryIteratorAsSequence(Sequence[T]): + """ + Provides a thread-safe, sequence-like interface to an iterator by caching results in memory. + + This class is thread-safe and supports multiple iterations over the same data and + concurrent access. + + Args: + single_use_iterator: The source iterator whose values will be cached. (must terminate!) + known_length: Optional pre-known length of the iterator. If provided, can improve + performance by avoiding the need to exhaust the iterator to determine length. + + Thread Safety: + All operations are thread-safe through the use of internal locking. + """ + + _single_use_iterator: Iterator[T] + + def __init__( + self, single_use_iterator: Iterator[T], known_length: Optional[int] = None + ) -> None: + self._lock = Lock() + self._single_use_iterator = single_use_iterator + self._list: list[T] = [] + self._stop_reached = False + self._known_length = known_length + + def _seek_to_index(self, index: int) -> None: + """ + Advances the iterator until the specified index is reached or iterator is exhausted. + Thread-safe operation. + """ + with self._lock: + while index >= len(self._list): + try: + self._list.append(next(self._single_use_iterator)) + except StopIteration: + self._stop_reached = True + return + + def _seek_to_end(self) -> None: + """ + Exhausts the iterator, caching all remaining values. + Thread-safe operation. + """ + with self._lock: + while not self._stop_reached: + try: + self._list.append(next(self._single_use_iterator)) + except StopIteration: + self._stop_reached = True + return + + @overload + def __getitem__(self, index: int) -> T: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[T]: ... + + def __getitem__(self, index: int | slice) -> T | Sequence[T]: + """ + Returns the item at the specified index. + + Args: + index: The index of the desired item. + + Returns: + The item at the specified index. + + Raises: + IndexError: If the index is out of range. + """ + if isinstance(index, slice): + self._seek_to_index(index.stop) + return self._list[index] + else: + self._seek_to_index(index) + return self._list[index] + + def __len__(self) -> int: + """ + Returns the total length of the sequence. + + If known_length was provided at initialization, returns that value. + Otherwise, exhausts the iterator to determine the length. + + Returns: + The total number of items in the sequence. + """ + if self._known_length is not None: + return self._known_length + + self._seek_to_end() + return len(self._list) + + def __iter__(self) -> Iterator[T]: + """ + Returns an iterator over the sequence. + + The returned iterator is safe to use concurrently with other operations + on this sequence. + + Returns: + An iterator yielding all items in the sequence. + """ + + def _iter() -> Generator[T, None, None]: + i = 0 + while True: + try: + val = self[i] + except IndexError: + return + try: + yield val + finally: + i += 1 + + return _iter() From f5cd8393a5ca1221c7c4340995c397033ee45d86 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 22 Jan 2025 14:17:48 -0800 Subject: [PATCH 08/16] Fixed tests --- weave/trace/vals.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/weave/trace/vals.py b/weave/trace/vals.py index ce3aa7bca1e2..39bd84e92aa0 100644 --- a/weave/trace/vals.py +++ b/weave/trace/vals.py @@ -345,15 +345,18 @@ def __len__(self) -> int: return self._known_length # Condition 1: we already have all the rows in memory + if self._prefetched_rows is not None: + self._known_length = len(self._prefetched_rows) + return self._known_length + + # Condition 2: we have the row digests and they are a list if ( self.table_ref is not None and self.table_ref._row_digests is not None - and self._prefetched_rows is not None and isinstance(self.table_ref._row_digests, list) ): - if len(self._prefetched_rows) == len(self.table_ref._row_digests): - self._known_length = len(self._prefetched_rows) - return self._known_length + self._known_length = len(self.table_ref._row_digests) + return self._known_length # Condition 2: We don't know the length, in which case we can get it from the server if self.table_ref is not None: From ffcd27a789d81b73097ee9d1d3b056672874b3b9 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 22 Jan 2025 14:25:48 -0800 Subject: [PATCH 09/16] Tests --- tests/trace/test_iterators.py | 142 ++++++++++++++++++++++++++++++++++ weave/utils/iterators.py | 5 +- 2 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 tests/trace/test_iterators.py diff --git a/tests/trace/test_iterators.py b/tests/trace/test_iterators.py new file mode 100644 index 000000000000..b70246b418e4 --- /dev/null +++ b/tests/trace/test_iterators.py @@ -0,0 +1,142 @@ +import threading + +import pytest + +from weave.utils.iterators import ThreadSafeInMemoryIteratorAsSequence + + +def test_basic_sequence_operations(): + # Test basic sequence operations + iterator = ThreadSafeInMemoryIteratorAsSequence(iter(range(10))) + assert len(iterator) == 10 + assert iterator[0] == 0 + assert iterator[1:3] == [1, 2] + assert list(iterator) == list(range(10)) + + +def test_empty_iterator(): + # Test behavior with empty iterator + iterator = ThreadSafeInMemoryIteratorAsSequence(iter([])) + assert len(iterator) == 0 + with pytest.raises(IndexError): + _ = iterator[0] + assert list(iterator) == [] + + +def test_known_length(): + # Test initialization with known length + iterator = ThreadSafeInMemoryIteratorAsSequence(iter(range(5)), known_length=5) + assert len(iterator) == 5 # Should not need to exhaust iterator + assert iterator[4] == 4 # Access last element + + +def test_multiple_iterations(): + # Test multiple iterations return same results + data = list(range(5)) + iterator = ThreadSafeInMemoryIteratorAsSequence(iter(data)) + + assert list(iterator) == data # First iteration + assert list(iterator) == data # Second iteration + assert list(iterator) == data # Third iteration + + +def test_concurrent_access(): + # Test thread-safe concurrent access + data = list(range(1000)) + iterator = ThreadSafeInMemoryIteratorAsSequence(iter(data)) + results = [] + + def reader_thread(): + results.append(list(iterator)) + + threads = [threading.Thread(target=reader_thread) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All threads should see the same data + assert all(r == data for r in results) + + +def test_slicing(): + # Test various slicing operations + iterator = ThreadSafeInMemoryIteratorAsSequence(iter(range(10))) + + assert iterator[2:5] == [2, 3, 4] + assert iterator[:3] == [0, 1, 2] + assert iterator[7:] == [7, 8, 9] + assert iterator[::2] == [0, 2, 4, 6, 8] + assert iterator[::-1] == [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] + + +def test_random_access(): + # Test random access patterns + iterator = ThreadSafeInMemoryIteratorAsSequence(iter(range(10))) + + assert iterator[5] == 5 # Middle access + assert iterator[1] == 1 # Earlier access + assert iterator[8] == 8 # Later access + assert iterator[0] == 0 # First element + assert iterator[9] == 9 # Last element + + +def test_concurrent_mixed_operations(): + # Test concurrent mixed operations (reads, slices, iterations) + data = list(range(100)) + iterator = ThreadSafeInMemoryIteratorAsSequence(iter(data)) + results = [] + + def mixed_ops_thread(): + local_results = [] + local_results.append(iterator[10]) # Single element access + local_results.append(list(iterator[20:25])) # Slice access + local_results.append(iterator[50]) # Another single element + local_results.extend(iterator[90:]) # End slice + results.append(local_results) + + threads = [threading.Thread(target=mixed_ops_thread) for _ in range(3)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Verify all threads got the same results + expected = [10, list(range(20, 25)), 50, list(range(90, 100))] + assert all(r == expected for r in results) + + +def test_index_out_of_range(): + # Test index out of range behavior + iterator = ThreadSafeInMemoryIteratorAsSequence(iter(range(5))) + + with pytest.raises(IndexError): + _ = iterator[10] + + assert iterator[-1] == 4 # Negative indices are supported + + +def test_iterator_exhaustion(): + # Test behavior when iterator is exhausted + class CountingIterator: + def __init__(self): + self.count = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.count < 5: + self.count += 1 + return self.count - 1 + raise StopIteration + + iterator = ThreadSafeInMemoryIteratorAsSequence(CountingIterator()) + + # Access beyond iterator length should raise IndexError + assert len(iterator) == 5 + with pytest.raises(IndexError): + _ = iterator[10] + + # Verify original data is still accessible + assert list(iterator) == [0, 1, 2, 3, 4] diff --git a/weave/utils/iterators.py b/weave/utils/iterators.py index 3938f2a84a9b..5e52e8534fed 100644 --- a/weave/utils/iterators.py +++ b/weave/utils/iterators.py @@ -78,7 +78,10 @@ def __getitem__(self, index: int | slice) -> T | Sequence[T]: IndexError: If the index is out of range. """ if isinstance(index, slice): - self._seek_to_index(index.stop) + if index.stop is None: + self._seek_to_end() + else: + self._seek_to_index(index.stop) return self._list[index] else: self._seek_to_index(index) From bd9f873ed8fd2c58e415ae6c1f9f1807efb5d412 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 22 Jan 2025 14:29:28 -0800 Subject: [PATCH 10/16] Lint --- weave/utils/iterators.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/weave/utils/iterators.py b/weave/utils/iterators.py index 5e52e8534fed..a19982144785 100644 --- a/weave/utils/iterators.py +++ b/weave/utils/iterators.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from collections.abc import Generator, Iterator, Sequence from threading import Lock -from typing import Optional, TypeVar, overload +from typing import TypeVar, overload T = TypeVar("T") @@ -24,7 +26,7 @@ class ThreadSafeInMemoryIteratorAsSequence(Sequence[T]): _single_use_iterator: Iterator[T] def __init__( - self, single_use_iterator: Iterator[T], known_length: Optional[int] = None + self, single_use_iterator: Iterator[T], known_length: int | None = None ) -> None: self._lock = Lock() self._single_use_iterator = single_use_iterator From a2b5b2f7e48d2c50ae0d6e38f33aa82e3a1c8630 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 22 Jan 2025 15:18:52 -0800 Subject: [PATCH 11/16] maybe complete --- tests/trace/test_client_trace.py | 3 +++ weave/trace/vals.py | 4 ++-- weave/utils/iterators.py | 17 ++++++++++++++--- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index e081f8420f52..7fb293b0676d 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -116,6 +116,9 @@ def test_dataset(client): d = Dataset(rows=[{"a": 5, "b": 6}, {"a": 7, "b": 10}]) ref = weave.publish(d) d2 = weave.ref(ref.uri()).get() + + # This might seem redundant, but it is useful to ensure that the + # dataset can be re-iterated over multiple times and equality is preserved. assert list(d2.rows) == list(d2.rows) assert list(d.rows) == list(d2.rows) assert list(d.rows) == list(d.rows) diff --git a/weave/trace/vals.py b/weave/trace/vals.py index 39bd84e92aa0..5c5e723f7d52 100644 --- a/weave/trace/vals.py +++ b/weave/trace/vals.py @@ -322,8 +322,8 @@ def _inefficiently_materialize_rows_as_list(self) -> list[dict]: # problem arising from a remote table clashing with the need to feel like # a local list. if not isinstance(self.rows, list): - self.rows = list(self.rows) - return self.rows + self._rows = list(iter(self.rows)) + return typing.cast(list[dict], self.rows) def set_prefetched_rows(self, prefetched_rows: list[dict]) -> None: """Sets the rows to a local cache of rows that can be used to diff --git a/weave/utils/iterators.py b/weave/utils/iterators.py index a19982144785..9ccd637762b6 100644 --- a/weave/utils/iterators.py +++ b/weave/utils/iterators.py @@ -1,13 +1,13 @@ from __future__ import annotations -from collections.abc import Generator, Iterator, Sequence +from collections.abc import Generator, Iterable, Iterator, Sequence from threading import Lock -from typing import TypeVar, overload +from typing import Any, TypeVar, overload T = TypeVar("T") -class ThreadSafeInMemoryIteratorAsSequence(Sequence[T]): +class ThreadSafeInMemoryIteratorAsSequence(Sequence[T], Iterable[T]): """ Provides a thread-safe, sequence-like interface to an iterator by caching results in memory. @@ -129,3 +129,14 @@ def _iter() -> Generator[T, None, None]: i += 1 return _iter() + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Sequence): + return False + if len(self) != len(other): + return False + self._seek_to_end() + for a, b in zip(self._list, other): + if a != b: + return False + return True From 8a7bac1dd193d5b50b3a32a719a7a68cb3ef8f79 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 22 Jan 2025 15:43:00 -0800 Subject: [PATCH 12/16] fix test --- weave/trace/vals.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/weave/trace/vals.py b/weave/trace/vals.py index 5c5e723f7d52..68e06ac4bc56 100644 --- a/weave/trace/vals.py +++ b/weave/trace/vals.py @@ -381,7 +381,8 @@ def _fetch_remote_length(self) -> int: return response.count def __eq__(self, other: Any) -> bool: - return self.rows == other + rows = self._inefficiently_materialize_rows_as_list() + return rows == other def _mark_dirty(self) -> None: self.table_ref = None From 68a6818cf4bc79832dba7189863cd4c28bfd282b Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 22 Jan 2025 16:06:04 -0800 Subject: [PATCH 13/16] fix test --- tests/trace/test_evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trace/test_evaluate.py b/tests/trace/test_evaluate.py index 2a2ca4d407d2..22fd8485b12a 100644 --- a/tests/trace/test_evaluate.py +++ b/tests/trace/test_evaluate.py @@ -343,4 +343,4 @@ def score_simple(input, output): # However, the key part is that we have basically X + 2 splits, with the middle X # being equal. We want to ensure that the table_query is not called in sequence, # but rather lazily after each batch. - assert counts_split_by_table_query == [13, 700, 700, 700, 5] + assert counts_split_by_table_query == [13, 700, 700, 700, 5], log From 4460cce0376bdc1604b5077ab48a3c7e11533015 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 22 Jan 2025 16:20:03 -0800 Subject: [PATCH 14/16] fix test again --- tests/trace/test_evaluate.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/trace/test_evaluate.py b/tests/trace/test_evaluate.py index 22fd8485b12a..029aa8fa5d51 100644 --- a/tests/trace/test_evaluate.py +++ b/tests/trace/test_evaluate.py @@ -343,4 +343,13 @@ def score_simple(input, output): # However, the key part is that we have basically X + 2 splits, with the middle X # being equal. We want to ensure that the table_query is not called in sequence, # but rather lazily after each batch. - assert counts_split_by_table_query == [13, 700, 700, 700, 5], log + assert counts_split_by_table_query[0] <= 13 + # Note: if this test suite is ran in a different order, then the low level eval ops will already be saved + # so the first count can be different. + assert counts_split_by_table_query == [ + counts_split_by_table_query[0], + 700, + 700, + 700, + 5, + ], log From e16fc3ccecf16e03a6e7b8f6b06ae6f0bb63e80e Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 22 Jan 2025 17:13:09 -0800 Subject: [PATCH 15/16] UNDO ME --- tests/trace/test_evaluate.py | 1 + weave/flow/eval.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/tests/trace/test_evaluate.py b/tests/trace/test_evaluate.py index 029aa8fa5d51..a3a1c8133a2e 100644 --- a/tests/trace/test_evaluate.py +++ b/tests/trace/test_evaluate.py @@ -111,6 +111,7 @@ def test_score_as_class(client): class MyScorer(weave.Scorer): @weave.op() def score(self, target, output): + print("score", target, output, target == output) return target == output evaluation = Evaluation( diff --git a/weave/flow/eval.py b/weave/flow/eval.py index d998f5a44d32..42afe9c84b7b 100644 --- a/weave/flow/eval.py +++ b/weave/flow/eval.py @@ -298,7 +298,9 @@ async def eval_example(example: dict) -> dict: scorer_name = scorer_attributes.scorer_name if scorer_name not in eval_row["scores"]: eval_row["scores"][scorer_name] = {} + print("eval_row", eval_row) eval_rows.append(eval_row) + print("eval_rows", eval_rows) return EvaluationResults(rows=weave.Table(eval_rows)) @weave.op(call_display_name=default_evaluation_display_name) @@ -355,4 +357,5 @@ def is_valid_model(model: Any) -> bool: def repeated_iterable(iterable: Iterable[T], n: int) -> Iterable[T]: for val in iterable: for _ in range(n): + print("yielding", val) yield val From f31984c7a90cd74a8d1c69994d2026b4d232afcf Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 22 Jan 2025 17:24:35 -0800 Subject: [PATCH 16/16] UNDO ME --- weave/flow/eval.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/weave/flow/eval.py b/weave/flow/eval.py index 42afe9c84b7b..71d1d5bce7df 100644 --- a/weave/flow/eval.py +++ b/weave/flow/eval.py @@ -238,8 +238,11 @@ async def predict_and_score(self, model: Union[Op, Model], example: dict) -> dic @weave.op() async def summarize(self, eval_table: EvaluationResults) -> dict: + print("summarize", eval_table.rows) eval_table_rows = list(eval_table.rows) + print("eval_table_rows", eval_table_rows) cols = transpose(eval_table_rows) + print("cols", cols) summary = {} for name, vals in cols.items(): @@ -249,7 +252,9 @@ async def summarize(self, eval_table: EvaluationResults) -> dict: scorer_attributes = get_scorer_attributes(scorer) scorer_name = scorer_attributes.scorer_name summarize_fn = scorer_attributes.summarize_fn + print("vals", scorer_name) scorer_stats = transpose(vals) + print("scorer_stats", scorer_name) score_table = scorer_stats[scorer_name] scored = summarize_fn(score_table) summary[scorer_name] = scored