From c5e37dd24b83304cc9ba9c69ec423a2cf4f777d8 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 22 Jan 2025 16:04:35 -0500 Subject: [PATCH] chore(weave): Add tests for other expected behaviours in objectify (#3420) --- .../{test_uri_get.py => test_objectify.py} | 67 +++++++++++++++++++ weave/trace/objectify.py | 5 +- 2 files changed, 71 insertions(+), 1 deletion(-) rename tests/trace/{test_uri_get.py => test_objectify.py} (58%) diff --git a/tests/trace/test_uri_get.py b/tests/trace/test_objectify.py similarity index 58% rename from tests/trace/test_uri_get.py rename to tests/trace/test_objectify.py index 55c5595c3ac9..f57915a92269 100644 --- a/tests/trace/test_uri_get.py +++ b/tests/trace/test_objectify.py @@ -1,9 +1,18 @@ +from concurrent.futures import Future +from dataclasses import replace +from typing import TypeVar + import pytest import weave +from weave.flow.obj import Object from weave.flow.prompt.prompt import EasyPrompt +from weave.trace.objectify import register_object +from weave.trace.refs import RefWithExtra from weave.trace_server.trace_server_interface import ObjectVersionFilter, ObjQueryReq +T = TypeVar("T") + @pytest.fixture( params=[ @@ -94,3 +103,61 @@ def model(a: int) -> int: for obj in res.objs: assert obj.version_index == 0 assert obj.is_latest == 1 + + +def resolve_ref_futures(ref: RefWithExtra) -> RefWithExtra: + """This is a bit of a hack to resolve futures in an initially unsaved object's extra fields. + + Currently, the extras are still a Future and not yet replaced with the actual value. + This function resolves the futures and replaces them with the actual values. 
+ """ + extras = ref._extra + new_extras = [] + for name, val in zip(extras[::2], extras[1::2]): + if isinstance(val, Future): + val = val.result() + new_extras.append(name) + new_extras.append(val) + ref = replace(ref, _extra=tuple(new_extras)) + return ref + + +def test_drill_down_dataset_refs_same_after_publishing(client): + ds = weave.Dataset( + name="test", + rows=[{"a": {"b": 1}}, {"a": {"b": 2}}, {"a": {"b": 3}}], + ) + ref = weave.publish(ds) + ds2 = ref.get() + ref2 = weave.publish(ds2) + ds3 = ref2.get() + + assert resolve_ref_futures(ds.rows.ref) == ds2.rows.ref + for row, row2 in zip(ds.rows, ds2.rows): + assert resolve_ref_futures(row.ref) == row2.ref + assert resolve_ref_futures(row["a"].ref) == row2["a"].ref + assert resolve_ref_futures(row["a"]["b"].ref) == row2["a"]["b"].ref + + assert ds2.ref == ds3.ref + for row2, row3 in zip(ds2.rows, ds3.rows): + assert row2.ref == row3.ref + assert row2["a"].ref == row3["a"].ref + assert row2["a"]["b"].ref == row3["a"]["b"].ref + + assert ds3.rows == [{"a": {"b": 1}}, {"a": {"b": 2}}, {"a": {"b": 3}}] + for i, row in enumerate(ds3.rows, 1): + assert row == {"a": {"b": i}} + assert row["a"] == {"b": i} + assert row["a"]["b"] == i + + +def test_registration(): + # This is a second class named Dataset. The first has already been registered + # in weave.flow.obj. This should raise an error. + + with pytest.raises(ValueError, match="Class Dataset already registered as"): + + @register_object + class Dataset(Object): + anything: str + doesnt_matter: int diff --git a/weave/trace/objectify.py b/weave/trace/objectify.py index 1cf9444391ad..d9fa26db46cd 100644 --- a/weave/trace/objectify.py +++ b/weave/trace/objectify.py @@ -19,7 +19,10 @@ def from_obj(cls, obj: WeaveObject) -> T_co: ... 
def register_object(cls: type[T_co]) -> type[T_co]: - _registry[cls.__name__] = cls + cls_name = cls.__name__ + if (existing_cls := _registry.get(cls_name)) is not None: + raise ValueError(f"Class {cls_name} already registered as {existing_cls}") + _registry[cls_name] = cls return cls