From d2045693e6d3e2788e8410f781c83953aaa5c7c6 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 15 Jan 2025 13:39:35 -0500 Subject: [PATCH 1/4] test --- tests/trace/test_uri_get.py | 46 +++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/trace/test_uri_get.py b/tests/trace/test_uri_get.py index 55c5595c3ac9..ae8f97d04484 100644 --- a/tests/trace/test_uri_get.py +++ b/tests/trace/test_uri_get.py @@ -1,9 +1,14 @@ +from typing import TypeVar + import pytest import weave from weave.flow.prompt.prompt import EasyPrompt +from weave.trace.refs import RefWithExtra from weave.trace_server.trace_server_interface import ObjectVersionFilter, ObjQueryReq +T = TypeVar("T") + @pytest.fixture( params=[ @@ -94,3 +99,44 @@ def model(a: int) -> int: for obj in res.objs: assert obj.version_index == 0 assert obj.is_latest == 1 + + +def resolve_ref_futures(ref: RefWithExtra) -> RefWithExtra: + return ref + # """This is a bit of a hack to resolve futures in an initally unsaved object's extra fields. + + # Currently, the extras are still a Future and not yet replaced with the actual value. + # This function resolves the futures and replaces them with the actual values. + # """ + # extras = ref._extra + # new_extras = [] + # for name, val in zip(extras[::2], extras[1::2]): + # if isinstance(val, Future): + # val = val.result() + # new_extras.append(name) + # new_extras.append(val) + # ref = replace(ref, _extra=tuple(new_extras)) + # return ref + + +def test_drill_down_dataset_refs_same_after_publishing(client): + ds = weave.Dataset( + name="test", + rows=[{"a": {"b": 1}}, {"a": {"b": 2}}, {"a": {"b": 3}}], + ) + ref = weave.publish(ds) + ds2 = ref.get() + ref2 = weave.publish(ds2) + ds3 = ref2.get() + + assert resolve_ref_futures(ds.rows.ref) == ds2.rows.ref + for row, row2 in zip(ds.rows, ds2.rows): + assert resolve_ref_futures(row.ref) == row2.ref + assert resolve_ref_futures(row["a"].ref) == row2["a"].ref + assert resolve_ref_futures(row["a"]["b"].ref) == row2["a"]["b"].ref + + assert resolve_ref_futures(ds2.ref) == ds3.ref + for row2, row3 in zip(ds2.rows, ds3.rows): + assert resolve_ref_futures(row2.ref) == row3.ref + assert resolve_ref_futures(row2["a"].ref) == row3["a"].ref + assert resolve_ref_futures(row2["a"]["b"].ref) == row3["a"]["b"].ref From 6ae44e904022921fe7f50871c3b5d221e10c2cd6 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 15 Jan 2025 23:30:06 -0500 Subject: [PATCH 2/4] test --- tests/trace/test_uri_get.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/tests/trace/test_uri_get.py b/tests/trace/test_uri_get.py index ae8f97d04484..7cbd1760f88f 100644 --- a/tests/trace/test_uri_get.py +++ b/tests/trace/test_uri_get.py @@ -1,3 +1,5 @@ +from concurrent.futures import Future +from dataclasses import replace from typing import TypeVar import pytest @@ -102,21 +104,20 @@ def model(a: int) -> int: def resolve_ref_futures(ref: RefWithExtra) -> RefWithExtra: + """This is a bit of a hack to resolve futures in an initally unsaved object's extra fields. + + Currently, the extras are still a Future and not yet replaced with the actual value. + This function resolves the futures and replaces them with the actual values. + """ + extras = ref._extra + new_extras = [] + for name, val in zip(extras[::2], extras[1::2]): + if isinstance(val, Future): + val = val.result() + new_extras.append(name) + new_extras.append(val) + ref = replace(ref, _extra=tuple(new_extras)) return ref - # """This is a bit of a hack to resolve futures in an initally unsaved object's extra fields. - - # Currently, the extras are still a Future and not yet replaced with the actual value. - # This function resolves the futures and replaces them with the actual values. - # """ - # extras = ref._extra - # new_extras = [] - # for name, val in zip(extras[::2], extras[1::2]): - # if isinstance(val, Future): - # val = val.result() - # new_extras.append(name) - # new_extras.append(val) - # ref = replace(ref, _extra=tuple(new_extras)) - # return ref def test_drill_down_dataset_refs_same_after_publishing(client): From 3db26d407caeccc972af2043536b05be20035b7b Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 15 Jan 2025 23:35:16 -0500 Subject: [PATCH 3/4] test --- tests/trace/{test_uri_get.py => test_objectify.py} | 14 ++++++++++++++ weave/trace/objectify.py | 5 ++++- 2 files changed, 18 insertions(+), 1 deletion(-) rename tests/trace/{test_uri_get.py => test_objectify.py} (91%) diff --git a/tests/trace/test_uri_get.py b/tests/trace/test_objectify.py similarity index 91% rename from tests/trace/test_uri_get.py rename to tests/trace/test_objectify.py index 7cbd1760f88f..7f334a5d24bf 100644 --- a/tests/trace/test_uri_get.py +++ b/tests/trace/test_objectify.py @@ -5,7 +5,9 @@ import pytest import weave +from weave.flow.obj import Object from weave.flow.prompt.prompt import EasyPrompt +from weave.trace.objectify import register_object from weave.trace.refs import RefWithExtra from weave.trace_server.trace_server_interface import ObjectVersionFilter, ObjQueryReq @@ -141,3 +143,15 @@ def test_drill_down_dataset_refs_same_after_publishing(client): assert resolve_ref_futures(row2.ref) == row3.ref assert resolve_ref_futures(row2["a"].ref) == row3["a"].ref assert resolve_ref_futures(row2["a"]["b"].ref) == row3["a"]["b"].ref + + +def test_registration(): + # This is a second class named Dataset. The first has already been registered + # in weave.flow.obj. This should raise an error. + + with pytest.raises(ValueError, match="Class Dataset already registered as"): + + @register_object + class Dataset(Object): + anything: str + doesnt_matter: int diff --git a/weave/trace/objectify.py b/weave/trace/objectify.py index 1cf9444391ad..d9fa26db46cd 100644 --- a/weave/trace/objectify.py +++ b/weave/trace/objectify.py @@ -19,7 +19,10 @@ def from_obj(cls, obj: WeaveObject) -> T_co: ... def register_object(cls: type[T_co]) -> type[T_co]: - _registry[cls.__name__] = cls + cls_name = cls.__name__ + if (existing_cls := _registry.get(cls_name)) is not None: + raise ValueError(f"Class {cls_name} already registered as {existing_cls}") + _registry[cls_name] = cls return cls From 0c321b5015dd4b6d7e4c8ccd95c1aab1b9bb8cb5 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 15 Jan 2025 23:46:53 -0500 Subject: [PATCH 4/4] test --- tests/trace/test_objectify.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/trace/test_objectify.py b/tests/trace/test_objectify.py index 7f334a5d24bf..f57915a92269 100644 --- a/tests/trace/test_objectify.py +++ b/tests/trace/test_objectify.py @@ -138,11 +138,17 @@ def test_drill_down_dataset_refs_same_after_publishing(client): assert resolve_ref_futures(row["a"].ref) == row2["a"].ref assert resolve_ref_futures(row["a"]["b"].ref) == row2["a"]["b"].ref - assert resolve_ref_futures(ds2.ref) == ds3.ref + assert ds2.ref == ds3.ref for row2, row3 in zip(ds2.rows, ds3.rows): - assert resolve_ref_futures(row2.ref) == row3.ref - assert resolve_ref_futures(row2["a"].ref) == row3["a"].ref - assert resolve_ref_futures(row2["a"]["b"].ref) == row3["a"]["b"].ref + assert row2.ref == row3.ref + assert row2["a"].ref == row3["a"].ref + assert row2["a"]["b"].ref == row3["a"]["b"].ref + + assert ds3.rows == [{"a": {"b": 1}}, {"a": {"b": 2}}, {"a": {"b": 3}}] + for i, row in enumerate(ds3.rows, 1): + assert row == {"a": {"b": i}} + assert row["a"] == {"b": i} + assert row["a"]["b"] == i def test_registration():