Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong committed Jan 22, 2025
1 parent e427eae commit df999ce
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from typing import Self

import pytest
Expand Down Expand Up @@ -72,7 +73,7 @@ def test_failed_publish_maintains_old_object_ref(client, custom_object, monkeypa
def fail_publish(*args, **kwargs):
raise Exception("Publish failed") # noqa: TRY002

m.setattr("weave.publish", fail_publish)
m.setattr("weave._publish", fail_publish)
custom_object.publish()

assert custom_object.ref == old_ref
Expand All @@ -87,3 +88,32 @@ class UnregisteredObject(weave.Object):

with pytest.raises(NotImplementedError):
unregistered_object.publish()


def test_saving_symmetry(client, custom_object):
custom_object2 = deepcopy(custom_object)

ref = weave.publish(custom_object)
ref2 = custom_object2.publish()

assert ref == ref2


def test_delete_symmetry(client, custom_object):
# publishing the same object will result in dedupe,
# so change a value to test deleting
custom_object2 = deepcopy(custom_object)
custom_object2.a = 2

ref = weave.publish(custom_object)
ref2 = weave.publish(custom_object2)

ref.delete()
obj2 = ref2.get()
obj2.delete()

with pytest.raises(ObjectDeletedError):
ref.get()

with pytest.raises(ObjectDeletedError):
ref2.get()
16 changes: 2 additions & 14 deletions weave/flow/obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing_extensions import Self

from weave.trace import api
from weave.trace.api import publish as weave_publish
from weave.trace.objectify import Objectifyable, is_registered
from weave.trace.op import ObjectRef, Op
from weave.trace.vals import WeaveObject, pydantic_getattribute
Expand Down Expand Up @@ -69,20 +70,7 @@ def publish(self, name: Union[str, None] = None) -> ObjectRef:
if not is_registered(cls_name):
raise NotImplementedError("Publish is not supported for this object!")

import weave

if name is None:
name = self.name

old_ref, self.ref = self.ref, None
try:
new_ref = weave.publish(self, name)
except Exception:
self.ref = old_ref
raise

self.ref = new_ref
return new_ref
return weave_publish(self, name)

def delete(self) -> None:
if self.ref is None:
Expand Down
32 changes: 31 additions & 1 deletion weave/trace/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import threading
import time
from collections.abc import Iterator
from typing import Any
from typing import TYPE_CHECKING, Any

# TODO: type_handlers is imported here to trigger registration of the image serializer.
# There is probably a better place for this, but including here for now to get the fix in.
Expand All @@ -28,6 +28,9 @@
from weave.trace.table import Table
from weave.trace_server.interface.builtin_object_classes import leaderboard

if TYPE_CHECKING:
from weave.flow.obj import Object

_global_postprocess_inputs: PostprocessInputsFunc | None = None
_global_postprocess_output: PostprocessOutputFunc | None = None

Expand Down Expand Up @@ -120,6 +123,23 @@ def publish(obj: Any, name: str | None = None) -> weave_client.ObjectRef:
Returns:
A weave Ref to the saved object.
"""
import weave

if not isinstance(obj, weave.Object):
return _publish(obj, name)

old_ref, obj.ref = obj.ref, None
try:
new_ref = _publish(obj, name)
except Exception:
obj.ref = old_ref
raise

obj.ref = new_ref
return new_ref


def _publish(obj: Any, name: str | None = None) -> weave_client.ObjectRef:
client = weave_client_context.require_weave_client()

save_name: str
Expand Down Expand Up @@ -156,9 +176,19 @@ def publish(obj: Any, name: str | None = None) -> weave_client.ObjectRef:
ref.digest,
)
print(f"{TRACE_OBJECT_EMOJI} Published to {url}")

return ref


def delete(obj: Object | ObjectRef) -> None:
import weave

if not isinstance(obj, (weave.Object, weave.ObjectRef)):
raise ValueError("Expected an Object or ObjectRef") # noqa: TRY004

obj.delete()


def ref(location: str) -> weave_client.ObjectRef:
"""Construct a Ref to a Weave object.
Expand Down

0 comments on commit df999ce

Please sign in to comment.