From a2b5b2f7e48d2c50ae0d6e38f33aa82e3a1c8630 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 22 Jan 2025 15:18:52 -0800 Subject: [PATCH] maybe complete --- tests/trace/test_client_trace.py | 3 +++ weave/trace/vals.py | 4 ++-- weave/utils/iterators.py | 17 ++++++++++++++--- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index e081f8420f52..7fb293b0676d 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -116,6 +116,9 @@ def test_dataset(client): d = Dataset(rows=[{"a": 5, "b": 6}, {"a": 7, "b": 10}]) ref = weave.publish(d) d2 = weave.ref(ref.uri()).get() + + # This might seem redundant, but it is useful to ensure that the + # dataset can be re-iterated over multiple times and equality is preserved. assert list(d2.rows) == list(d2.rows) assert list(d.rows) == list(d2.rows) assert list(d.rows) == list(d.rows) diff --git a/weave/trace/vals.py b/weave/trace/vals.py index 39bd84e92aa0..5c5e723f7d52 100644 --- a/weave/trace/vals.py +++ b/weave/trace/vals.py @@ -322,8 +322,8 @@ def _inefficiently_materialize_rows_as_list(self) -> list[dict]: # problem arising from a remote table clashing with the need to feel like # a local list. if not isinstance(self.rows, list): - self.rows = list(self.rows) - return self.rows + self._rows = list(iter(self.rows)) + return typing.cast(list[dict], self.rows) def set_prefetched_rows(self, prefetched_rows: list[dict]) -> None: """Sets the rows to a local cache of rows that can be used to diff --git a/weave/utils/iterators.py b/weave/utils/iterators.py index a19982144785..9ccd637762b6 100644 --- a/weave/utils/iterators.py +++ b/weave/utils/iterators.py @@ -1,13 +1,13 @@ from __future__ import annotations -from collections.abc import Generator, Iterator, Sequence +from collections.abc import Generator, Iterable, Iterator, Sequence from threading import Lock -from typing import TypeVar, overload +from typing import Any, TypeVar, overload T = TypeVar("T") -class ThreadSafeInMemoryIteratorAsSequence(Sequence[T]): +class ThreadSafeInMemoryIteratorAsSequence(Sequence[T], Iterable[T]): """ Provides a thread-safe, sequence-like interface to an iterator by caching results in memory. @@ -129,3 +129,14 @@ def _iter() -> Generator[T, None, None]: i += 1 return _iter() + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Sequence): + return False + if len(self) != len(other): + return False + self._seek_to_end() + for a, b in zip(self._list, other): + if a != b: + return False + return True