
chore(weave): Eliminate premature materialization of datasets to minimize network requests #3456

Draft · wants to merge 18 commits into master
5 changes: 5 additions & 0 deletions tests/trace/test_client_trace.py
@@ -116,7 +116,12 @@ def test_dataset(client):
d = Dataset(rows=[{"a": 5, "b": 6}, {"a": 7, "b": 10}])
ref = weave.publish(d)
d2 = weave.ref(ref.uri()).get()

# This might seem redundant, but it is useful to ensure that the
# dataset can be iterated multiple times and that equality is preserved.
assert list(d2.rows) == list(d2.rows)
assert list(d.rows) == list(d2.rows)
assert list(d.rows) == list(d.rows)
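
The failure mode these assertions guard against is one-shot iteration: if rows
were backed by a plain generator, a second pass would silently yield nothing.
A minimal illustration (not weave code):

rows = (r for r in [{"a": 5}, {"a": 7}])
assert list(rows) == [{"a": 5}, {"a": 7}]
assert list(rows) == []  # exhausted: a second pass yields nothing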


def test_trace_server_call_start_and_end(client):
77 changes: 77 additions & 0 deletions tests/trace/test_dataset.py
@@ -1,6 +1,7 @@
import pytest

import weave
from tests.trace.test_evaluate import Dataset


def test_basic_dataset_lifecycle(client):
@@ -35,6 +36,82 @@ def test_pythonic_access(client):
ds[-1]


def test_dataset_laziness(client):
"""
The intention of this test is to show that local construction of
a dataset does not trigger any remote operations.
"""
dataset = Dataset(rows=[{"input": i} for i in range(300)])
log = client.server.attribute_access_log
assert [l for l in log if not l.startswith("_")] == ["ensure_project_exists"]
client.server.attribute_access_log = []

length = len(dataset)
log = client.server.attribute_access_log
assert [l for l in log if not l.startswith("_")] == []

length2 = len(dataset)
log = client.server.attribute_access_log
assert [l for l in log if not l.startswith("_")] == []

assert length == length2

i = 0
for row in dataset:
log = client.server.attribute_access_log
assert [l for l in log if not l.startswith("_")] == []
i += 1


def test_published_dataset_laziness(client):
"""
The intention of this test is to show that publishing a dataset,
then iterating through the version retrieved via ref.get(), incurs
minimal remote operations and, importantly, delays fetching the rows
until they are actually needed.
"""
dataset = Dataset(rows=[{"input": i} for i in range(300)])
log = client.server.attribute_access_log
assert [l for l in log if not l.startswith("_")] == ["ensure_project_exists"]
client.server.attribute_access_log = []

ref = weave.publish(dataset)
log = client.server.attribute_access_log
assert [l for l in log if not l.startswith("_")] == ["table_create", "obj_create"]
client.server.attribute_access_log = []

dataset = ref.get()
log = client.server.attribute_access_log
assert [l for l in log if not l.startswith("_")] == ["obj_read"]
client.server.attribute_access_log = []

length = len(dataset)
log = client.server.attribute_access_log
assert [l for l in log if not l.startswith("_")] == ["table_query_stats"]
client.server.attribute_access_log = []

length2 = len(dataset)
log = client.server.attribute_access_log
assert [l for l in log if not l.startswith("_")] == []

assert length == length2

i = 0
for row in dataset:
log = client.server.attribute_access_log
# This is the critical part of the test: ensuring that
# rows are only fetched when they are actually needed.
#
# In a future improvement, we might eagerly prefetch the next
# page of results, in which case this assertion would change:
# there would always be one more "table_query" than the
# number of pages consumed so far.
assert [l for l in log if not l.startswith("_")] == ["table_query"] * (
(i // 100) + 1
)
i += 1
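
A sketch of the access pattern the assertion above encodes: one "table_query"
per 100-row page, issued only when iteration crosses a page boundary. The
helper names here are hypothetical; the real fetching goes through the trace
server and may differ.

PAGE_SIZE = 100  # assumed page size, matching the (i // 100) + 1 math above

def iter_rows_lazily(fetch_page):
    """Yield rows page by page, querying only when the next page is needed."""
    offset = 0
    while True:
        rows = fetch_page(offset=offset, limit=PAGE_SIZE)  # one "table_query"
        if not rows:
            return
        yield from rows
        offset += PAGE_SIZE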


def test_dataset_from_calls(client):
@weave.op
def greet(name: str, age: int) -> str:
68 changes: 68 additions & 0 deletions tests/trace/test_evaluate.py
@@ -111,6 +111,7 @@ def test_score_as_class(client):
class MyScorer(weave.Scorer):
@weave.op()
def score(self, target, output):
print("score", target, output, target == output)
return target == output

evaluation = Evaluation(
@@ -286,3 +287,70 @@ def score(self, target, model_output):

result = asyncio.run(evaluation.evaluate(model))
assert result["my-scorer"] == {"true_count": 1, "true_fraction": 0.5}


def test_evaluate_table_lazy_iter(client):
"""
The intention of this test is to show that an evaluation harness
lazily fetches rows from a table rather than eagerly fetching all
rows up front.
"""
dataset = Dataset(rows=[{"input": i} for i in range(300)])
ref = weave.publish(dataset)
dataset = ref.get()

@weave.op()
async def model_predict(input) -> int:
return input * 1

@weave.op()
def score_simple(input, output):
return input == output

log = client.server.attribute_access_log
assert [l for l in log if not l.startswith("_")] == [
"ensure_project_exists",
"table_create",
"obj_create",
"obj_read",
]
client.server.attribute_access_log = []

evaluation = Evaluation(
dataset=dataset,
scorers=[score_simple],
)
log = client.server.attribute_access_log
assert [l for l in log if not l.startswith("_")] == []

result = asyncio.run(evaluation.evaluate(model_predict))
assert result["output"] == {"mean": 149.5}
assert result["score_simple"] == {"true_count": 300, "true_fraction": 1.0}

log = client.server.attribute_access_log
log = [l for l in log if not l.startswith("_")]

# Make sure the length was determined via a stats query rather than by fetching rows
assert "table_query_stats" in log

counts_split_by_table_query = [0]
for log_entry in log:
if log_entry == "table_query":
counts_split_by_table_query.append(0)
else:
counts_split_by_table_query[-1] += 1

# Note: these exact numbers might change if we change the way eval traces work.
# However, the key property is that we get X + 2 splits, with the middle X
# counts equal. This ensures the table_query calls are not issued back-to-back
# up front, but rather lazily, one per batch of rows.
assert counts_split_by_table_query[0] <= 13
# Note: if this test suite is run in a different order, the low-level eval ops
# will already be saved, so the first count can differ.
assert counts_split_by_table_query == [
counts_split_by_table_query[0],
700,
700,
700,
5,
], log
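
To make the shape of that assertion concrete, here is the same splitting logic
applied to a tiny, made-up log (illustrative values only):

log = ["obj_read", "table_query", "a", "b", "table_query", "c"]
counts = [0]
for entry in log:
    if entry == "table_query":
        counts.append(0)  # start a new split at each table_query
    else:
        counts[-1] += 1
assert counts == [1, 2, 1]  # entries before, between, and after the queries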
131 changes: 131 additions & 0 deletions tests/trace/test_flow_util.py
@@ -0,0 +1,131 @@
import asyncio

import pytest

from weave.flow.util import async_foreach


@pytest.mark.asyncio
async def test_async_foreach_basic():
"""Test basic functionality of async_foreach."""
input_data = range(5)
results = []

async def process(x: int) -> int:
await asyncio.sleep(0.1) # Simulate async work
return x * 2

async for item, result in async_foreach(
input_data, process, max_concurrent_tasks=2
):
results.append((item, result))

assert len(results) == 5
assert all(result == item * 2 for item, result in results)
assert [item for item, _ in results] == list(range(5))
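
For reference, a minimal sketch of the semantics these tests pin down: bounded
concurrency, input-order results, and bounded iterator lookahead. This is an
assumption-laden illustration, not the actual implementation in
weave/flow/util.py, which may differ.

import asyncio
from typing import AsyncIterator, Awaitable, Callable, Iterable, TypeVar

T = TypeVar("T")
R = TypeVar("R")

async def async_foreach_sketch(
    sequence: Iterable[T],
    func: Callable[[T], Awaitable[R]],
    max_concurrent_tasks: int,
) -> AsyncIterator[tuple[T, R]]:
    iterator = iter(sequence)
    in_flight: list[tuple[T, asyncio.Task]] = []
    try:
        while True:
            # Pull items lazily: at most max_concurrent_tasks are running,
            # which also bounds how far ahead the input iterator is consumed.
            while len(in_flight) < max_concurrent_tasks:
                try:
                    item = next(iterator)
                except StopIteration:
                    break
                in_flight.append((item, asyncio.create_task(func(item))))
            if not in_flight:
                break
            # Yield in input order; exceptions from func propagate here.
            item, task = in_flight.pop(0)
            yield item, await task
    finally:
        # Cancel leftover tasks if the consumer exits early or is cancelled.
        for _, task in in_flight:
            task.cancel()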


@pytest.mark.asyncio
async def test_async_foreach_concurrency():
"""Test that max_concurrent_tasks is respected."""
currently_running = 0
max_running = 0
input_data = range(10)

async def process(x: int) -> int:
nonlocal currently_running, max_running
currently_running += 1
max_running = max(max_running, currently_running)
await asyncio.sleep(0.1) # Simulate async work
currently_running -= 1
return x

max_concurrent = 3
async for _, _ in async_foreach(
input_data, process, max_concurrent_tasks=max_concurrent
):
pass

assert max_running == max_concurrent


@pytest.mark.asyncio
async def test_async_foreach_lazy_loading():
"""Test that items are loaded lazily from the iterator."""
items_loaded = 0

def lazy_range(n: int):
nonlocal items_loaded
for i in range(n):
items_loaded += 1
yield i

async def process(x: int) -> int:
await asyncio.sleep(0.1)
return x

# Process first 3 items then break
async for _, _ in async_foreach(lazy_range(100), process, max_concurrent_tasks=2):
if items_loaded >= 3:
break

# Should have loaded at most 4 items (the 3 consumed, plus at most
# one prefetched due to the bounded lookahead)
assert items_loaded <= 4


@pytest.mark.asyncio
async def test_async_foreach_error_handling():
"""Test error handling in async_foreach."""
input_data = range(5)

async def process(x: int) -> int:
if x == 3:
raise ValueError("Test error")
return x

with pytest.raises(ValueError, match="Test error"):
async for _, _ in async_foreach(input_data, process, max_concurrent_tasks=2):
pass


@pytest.mark.asyncio
async def test_async_foreach_empty_input():
"""Test behavior with empty input sequence."""
results = []

async def process(x: int) -> int:
return x

async for item, result in async_foreach([], process, max_concurrent_tasks=2):
results.append((item, result))

assert len(results) == 0


@pytest.mark.asyncio
async def test_async_foreach_cancellation():
"""Test that tasks are properly cleaned up on cancellation."""
input_data = range(100)
results = []

async def slow_process(x: int) -> int:
await asyncio.sleep(0.5) # Longer delay to ensure tasks are running
return x

# Create a task we can cancel
async def run_foreach():
async for item, result in async_foreach(
input_data, slow_process, max_concurrent_tasks=3
):
results.append((item, result))
if len(results) >= 2: # Cancel after 2 results
raise asyncio.CancelledError()

with pytest.raises(asyncio.CancelledError):
await run_foreach()

# Give a moment for any lingering tasks to complete if cleanup failed
await asyncio.sleep(0.1)

# Check that we got the expected number of results before cancellation
assert len(results) == 2