diff --git a/tests/trace/test_objects.py b/tests/trace/test_saving_symmetry.py similarity index 72% rename from tests/trace/test_objects.py rename to tests/trace/test_saving_symmetry.py index 29853834f15c..485a4e038fb5 100644 --- a/tests/trace/test_objects.py +++ b/tests/trace/test_saving_symmetry.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import Self import pytest @@ -72,7 +73,7 @@ def test_failed_publish_maintains_old_object_ref(client, custom_object, monkeypa def fail_publish(*args, **kwargs): raise Exception("Publish failed") # noqa: TRY002 - m.setattr("weave.publish", fail_publish) + m.setattr("weave._publish", fail_publish) custom_object.publish() assert custom_object.ref == old_ref @@ -87,3 +88,32 @@ class UnregisteredObject(weave.Object): with pytest.raises(NotImplementedError): unregistered_object.publish() + + +def test_saving_symmetry(client, custom_object): + custom_object2 = deepcopy(custom_object) + + ref = weave.publish(custom_object) + ref2 = custom_object2.publish() + + assert ref == ref2 + + +def test_delete_symmetry(client, custom_object): + # publishing the same object will result in dedupe, + # so change a value to test deleting + custom_object2 = deepcopy(custom_object) + custom_object2.a = 2 + + ref = weave.publish(custom_object) + ref2 = weave.publish(custom_object2) + + ref.delete() + obj2 = ref2.get() + obj2.delete() + + with pytest.raises(ObjectDeletedError): + ref.get() + + with pytest.raises(ObjectDeletedError): + ref2.get() diff --git a/weave/flow/obj.py b/weave/flow/obj.py index a8252c2261ed..20217cbd96e7 100644 --- a/weave/flow/obj.py +++ b/weave/flow/obj.py @@ -11,6 +11,7 @@ from typing_extensions import Self from weave.trace import api +from weave.trace.api import publish as weave_publish from weave.trace.objectify import Objectifyable, is_registered from weave.trace.op import ObjectRef, Op from weave.trace.vals import WeaveObject, pydantic_getattribute @@ -69,20 +70,7 @@ def publish(self, name: Union[str, None] = None) -> ObjectRef: if not is_registered(cls_name): raise NotImplementedError("Publish is not supported for this object!") - import weave - - if name is None: - name = self.name - - old_ref, self.ref = self.ref, None - try: - new_ref = weave.publish(self, name) - except Exception: - self.ref = old_ref - raise - - self.ref = new_ref - return new_ref + return weave_publish(self, name) def delete(self) -> None: if self.ref is None: diff --git a/weave/trace/api.py b/weave/trace/api.py index 7acb11781637..39ad728d44ed 100644 --- a/weave/trace/api.py +++ b/weave/trace/api.py @@ -7,7 +7,7 @@ import threading import time from collections.abc import Iterator -from typing import Any +from typing import TYPE_CHECKING, Any # TODO: type_handlers is imported here to trigger registration of the image serializer. # There is probably a better place for this, but including here for now to get the fix in. @@ -28,6 +28,9 @@ from weave.trace.table import Table from weave.trace_server.interface.builtin_object_classes import leaderboard +if TYPE_CHECKING: + from weave.flow.obj import Object + _global_postprocess_inputs: PostprocessInputsFunc | None = None _global_postprocess_output: PostprocessOutputFunc | None = None @@ -120,6 +123,23 @@ def publish(obj: Any, name: str | None = None) -> weave_client.ObjectRef: Returns: A weave Ref to the saved object. """ + import weave + + if not isinstance(obj, weave.Object): + return _publish(obj, name) + + old_ref, obj.ref = obj.ref, None + try: + new_ref = _publish(obj, name) + except Exception: + obj.ref = old_ref + raise + + obj.ref = new_ref + return new_ref + + +def _publish(obj: Any, name: str | None = None) -> weave_client.ObjectRef: client = weave_client_context.require_weave_client() save_name: str @@ -156,9 +176,19 @@ def publish(obj: Any, name: str | None = None) -> weave_client.ObjectRef: ref.digest, ) print(f"{TRACE_OBJECT_EMOJI} Published to {url}") + return ref +def delete(obj: Object | ObjectRef) -> None: + import weave + + if not isinstance(obj, (weave.Object, weave.ObjectRef)): + raise ValueError("Expected an Object or ObjectRef") # noqa: TRY004 + + obj.delete() + + def ref(location: str) -> weave_client.ObjectRef: """Construct a Ref to a Weave object.