From 77ce0be5d2609b21594c7535d814becfeafd4718 Mon Sep 17 00:00:00 2001 From: Adrian Swanberg Date: Tue, 14 Jan 2025 14:57:09 -0800 Subject: [PATCH 1/2] chore(weave): Change Weave to W&B Weave on docs site (#3407) * Change Weave to W&B Weave on docs site * revert weave_client changes * actually revert weave_client changes --- docs/docusaurus.config.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/docusaurus.config.ts b/docs/docusaurus.config.ts index d42692a52fda..7aa9072de3ae 100644 --- a/docs/docusaurus.config.ts +++ b/docs/docusaurus.config.ts @@ -125,7 +125,7 @@ const config: Config = { // Replace with your project's social card image: "img/logo-large-padded.png", navbar: { - title: "Weave", + title: "W&B Weave", logo: { alt: "My Site Logo", src: "img/logo.svg", @@ -231,7 +231,7 @@ const config: Config = { ], }, ], - copyright: `Weave by W&B`, + copyright: `Made with ❤️ by Weights & Biases`, }, prism: { // theme: prismThemes.nightOwl, From 173016890fea05317399fc4df12b386356572986 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Tue, 14 Jan 2025 18:31:34 -0500 Subject: [PATCH 2/2] feat(weave): Make ref-getting more pleasant (#3362) --- tests/trace/test_client_trace.py | 4 +- tests/trace/test_deepcopy.py | 1 + tests/trace/test_op_decorator_behaviour.py | 4 +- tests/trace/test_prompt_easy.py | 1 + tests/trace/test_uri_get.py | 96 ++++++++++++++++++++++ tests/trace/test_weave_client_threaded.py | 2 +- weave/flow/dataset.py | 14 +++- weave/flow/eval.py | 16 ++++ weave/flow/obj.py | 12 +++ weave/flow/prompt/prompt.py | 25 ++++-- weave/trace/objectify.py | 37 +++++++++ weave/trace/refs.py | 33 +------- weave/trace/weave_client.py | 9 +- 13 files changed, 208 insertions(+), 46 deletions(-) create mode 100644 tests/trace/test_uri_get.py create mode 100644 weave/trace/objectify.py diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index 212c0aef7565..36dd8ffcaa00 100644 --- a/tests/trace/test_client_trace.py +++ 
b/tests/trace/test_client_trace.py @@ -1380,7 +1380,7 @@ def test_dataset_row_ref(client): d2 = weave.ref(ref.uri()).get() inner = d2.rows[0]["a"] - exp_ref = "weave:///shawn/test-project/object/Dataset:0xTDJ6hEmsx8Wg9H75y42bL2WgvW5l4IXjuhHcrMh7A/attr/rows/id/XfhC9dNA5D4taMvhKT4MKN2uce7F56Krsyv4Q6mvVMA/key/a" + exp_ref = "weave:///shawn/test-project/object/Dataset:tiRVKBWTP7LOwjBEqe79WFS7HEibm1WG8nfe94VWZBo/attr/rows/id/XfhC9dNA5D4taMvhKT4MKN2uce7F56Krsyv4Q6mvVMA/key/a" assert inner == 5 assert inner.ref.uri() == exp_ref gotten = weave.ref(exp_ref).get() @@ -2747,7 +2747,7 @@ def test_objects_and_keys_with_special_characters(client): exp_key == "n-a_m.e%3A%20%2F%2B_%28%29%7B%7D%7C%22%27%3C%3E%21%40%24%5E%26%2A%23%3A%2C.%5B%5D-%3D%3B~%60100" ) - exp_digest = "O66Mk7g91rlAUtcGYOFR1Y2Wk94YyPXJy2UEAzDQcYM" + exp_digest = "iVLhViJ3vm8vMMo3Qj35mK7GiyP8jv3OJqasIGXjN0s" exp_obj_ref = f"{ref_base}/object/{exp_name}:{exp_digest}" assert obj.ref.uri() == exp_obj_ref diff --git a/tests/trace/test_deepcopy.py b/tests/trace/test_deepcopy.py index 87380b075f9b..3302eec7e030 100644 --- a/tests/trace/test_deepcopy.py +++ b/tests/trace/test_deepcopy.py @@ -28,6 +28,7 @@ class Example(weave.Object): attrs={ "name": None, "description": None, + "ref": None, "_class_name": "Example", "_bases": ["Object", "BaseModel"], "a": 1, diff --git a/tests/trace/test_op_decorator_behaviour.py b/tests/trace/test_op_decorator_behaviour.py index 03d9b6991a2a..b29fcc8d7e13 100644 --- a/tests/trace/test_op_decorator_behaviour.py +++ b/tests/trace/test_op_decorator_behaviour.py @@ -128,7 +128,7 @@ def test_sync_method_call(client, weave_obj, py_obj): entity="shawn", project="test-project", name="A", - _digest="tGCIGNe9xznnkoJvn2i75TOocSfV7ui1vldSrIP3ZZo", + _digest="nzAe1JtLJFEVeEo3yX0TOYYGhh7vAOFYRentYI9ik6U", _extra=(), ), "a": 1, @@ -163,7 +163,7 @@ async def test_async_method_call(client, weave_obj, py_obj): entity="shawn", project="test-project", name="A", - 
_digest="tGCIGNe9xznnkoJvn2i75TOocSfV7ui1vldSrIP3ZZo", + _digest="nzAe1JtLJFEVeEo3yX0TOYYGhh7vAOFYRentYI9ik6U", _extra=(), ), "a": 1, diff --git a/tests/trace/test_prompt_easy.py b/tests/trace/test_prompt_easy.py index 6d01db92a9f6..7086f8eff062 100644 --- a/tests/trace/test_prompt_easy.py +++ b/tests/trace/test_prompt_easy.py @@ -240,6 +240,7 @@ def test_prompt_as_pydantic_dict(): assert prompt.as_pydantic_dict() == { "name": None, "description": None, + "ref": None, "config": { "model": "gpt-4o", "temperature": 0.8, diff --git a/tests/trace/test_uri_get.py b/tests/trace/test_uri_get.py new file mode 100644 index 000000000000..55c5595c3ac9 --- /dev/null +++ b/tests/trace/test_uri_get.py @@ -0,0 +1,96 @@ +import pytest + +import weave +from weave.flow.prompt.prompt import EasyPrompt +from weave.trace_server.trace_server_interface import ObjectVersionFilter, ObjQueryReq + + +@pytest.fixture( + params=[ + "dataset", + "evaluation", + "string_prompt", + "messages_prompt", + "easy_prompt", + ] +) +def obj(request): + examples = [ + {"question": "What is 2+2?", "expected": "4"}, + {"question": "What is 3+3?", "expected": "6"}, + ] + + if request.param == "dataset": + return weave.Dataset(rows=examples) + elif request.param == "evaluation": + return weave.Evaluation(dataset=examples) + elif request.param == "string_prompt": + return weave.StringPrompt("Hello, world!") + elif request.param == "messages_prompt": + return weave.MessagesPrompt([{"role": "user", "content": "Hello, world!"}]) + elif request.param == "easy_prompt": + return weave.EasyPrompt("Hello world!") + + +def test_ref_get(client, obj): + ref = weave.publish(obj) + + obj_cls = type(obj) + obj2 = obj_cls.from_uri(ref.uri()) + obj3 = ref.get() + assert isinstance(obj2, obj_cls) + assert isinstance(obj3, obj_cls) + + for field_name in obj.model_fields: + obj_field_val = getattr(obj, field_name) + obj2_field_val = getattr(obj2, field_name) + obj3_field_val = getattr(obj3, field_name) + + # This is a special 
case for EasyPrompt's unique init signature where `config` + # represents the kwargs passed into the class itself. Since the original object + # has not been published, there is no ref and the key is omitted from the first + # `config`. After publishing, there is a ref so the `config` dict has an + # additional `ref` key. For comparison purposes, we pop the key to ensure the + # rest of the config dict is the same. + if obj_cls is EasyPrompt and field_name == "config": + obj2_field_val.pop("ref") + obj3_field_val.pop("ref") + + assert obj_field_val == obj2_field_val + assert obj_field_val == obj3_field_val + + +@pytest.mark.asyncio +async def test_gotten_methods(client): + @weave.op + def model(a: int) -> int: + return a + 1 + + ev = weave.Evaluation(dataset=[{"a": 1}]) + await ev.evaluate(model) + ref = weave.publish(ev) + + ev2 = weave.Evaluation.from_uri(ref.uri()) + await ev2.evaluate(model) + + # Ensure that the Evaluation object we get back is equivalent to the one published. + # If they are the same, calling evaluate again should not publish new versions of any + # relevant objects or ops. 
+ relevant_object_ids = [ + "model", + "Evaluation.evaluate", + "Evaluation.predict_and_score", + "Evaluation.summarize", + "Dataset", + "example_evaluation", + ] + # TODO: Replace with client version of this query when available + res = client.server.objs_query( + ObjQueryReq( + project_id=client._project_id(), + filter=ObjectVersionFilter(object_ids=relevant_object_ids), + ) + ) + for obj in res.objs: + assert obj.version_index == 0 + assert obj.is_latest == 1 diff --git a/tests/trace/test_weave_client_threaded.py b/tests/trace/test_weave_client_threaded.py index e1b5617cb9bf..2878d778d40e 100644 --- a/tests/trace/test_weave_client_threaded.py +++ b/tests/trace/test_weave_client_threaded.py @@ -37,7 +37,7 @@ def test_flask_server(flask_server): url = flask_server response = requests.get(url) assert response.status_code == 200 - assert response.text == "0xTDJ6hEmsx8Wg9H75y42bL2WgvW5l4IXjuhHcrMh7A" + assert response.text == "tiRVKBWTP7LOwjBEqe79WFS7HEibm1WG8nfe94VWZBo" def test_weave_client_global_accessible_in_thread(client): diff --git a/weave/flow/dataset.py b/weave/flow/dataset.py index 0bcd9c60b81d..051c6fb93bdc 100644 --- a/weave/flow/dataset.py +++ b/weave/flow/dataset.py @@ -2,10 +2,12 @@ from typing import Any from pydantic import field_validator +from typing_extensions import Self import weave from weave.flow.obj import Object -from weave.trace.vals import WeaveTable +from weave.trace.objectify import register_object +from weave.trace.vals import WeaveObject, WeaveTable def short_str(obj: Any, limit: int = 25) -> str: @@ -15,6 +17,7 @@ def short_str(obj: Any, limit: int = 25) -> str: return str_val +@register_object class Dataset(Object): """ Dataset object with easy saving and automatic versioning @@ -42,6 +45,15 @@ class Dataset(Object): rows: weave.Table + @classmethod + def from_obj(cls, obj: WeaveObject) -> Self: + return cls( + name=obj.name, + description=obj.description, + ref=obj.ref, + rows=obj.rows, + ) + @field_validator("rows", mode="before") 
def convert_to_table(cls, rows: Any) -> weave.Table: if not isinstance(rows, weave.Table): diff --git a/weave/flow/eval.py b/weave/flow/eval.py index f540dfb1dfe8..38988d2aa3f9 100644 --- a/weave/flow/eval.py +++ b/weave/flow/eval.py @@ -7,6 +7,7 @@ from pydantic import PrivateAttr, model_validator from rich import print from rich.console import Console +from typing_extensions import Self import weave from weave.flow import util @@ -30,6 +31,7 @@ from weave.trace.env import get_weave_parallelism from weave.trace.errors import OpCallError from weave.trace.isinstance import weave_isinstance +from weave.trace.objectify import register_object from weave.trace.op import CallDisplayNameFunc, Op, as_op, is_op from weave.trace.vals import WeaveObject from weave.trace.weave_client import Call, get_ref @@ -57,6 +59,7 @@ class EvaluationResults(Object): ScorerLike = Union[Callable, Op, Scorer] +@register_object class Evaluation(Object): """ Sets up an evaluation which includes a set of scorers and a dataset. 
@@ -114,6 +117,19 @@ def function_to_evaluate(question: str): # internal attr to track whether to use the new `output` or old `model_output` key for outputs _output_key: Literal["output", "model_output"] = PrivateAttr("output") + @classmethod + def from_obj(cls, obj: WeaveObject) -> Self: + return cls( + name=obj.name, + description=obj.description, + ref=obj.ref, + dataset=obj.dataset, + scorers=obj.scorers, + preprocess_model_input=obj.preprocess_model_input, + trials=obj.trials, + evaluation_name=obj.evaluation_name, + ) + @model_validator(mode="after") def _update_display_name(self) -> "Evaluation": if self.evaluation_name: diff --git a/weave/flow/obj.py b/weave/flow/obj.py index b59067608d3c..2720da8d6066 100644 --- a/weave/flow/obj.py +++ b/weave/flow/obj.py @@ -8,7 +8,10 @@ ValidatorFunctionWrapHandler, model_validator, ) +from typing_extensions import Self +from weave.trace import api +from weave.trace.objectify import Objectifyable from weave.trace.op import ObjectRef, Op from weave.trace.vals import WeaveObject, pydantic_getattribute from weave.trace.weave_client import get_ref @@ -38,6 +41,7 @@ def setter(self: Any, value: T) -> None: class Object(BaseModel): name: Optional[str] = None description: Optional[str] = None + ref: Optional[ObjectRef] = None # Allow Op attributes model_config = ConfigDict( @@ -51,6 +55,14 @@ class Object(BaseModel): __str__ = BaseModel.__repr__ + @classmethod + def from_uri(cls, uri: str, *, objectify: bool = True) -> Self: + if not isinstance(cls, Objectifyable): + raise NotImplementedError( + f"`{cls.__name__}` must implement `from_obj` to support deserialization from a URI." + ) + return api.ref(uri).get(objectify=objectify) + # This is a "wrap" validator meaning we can run our own logic before # and after the standard pydantic validation. 
@model_validator(mode="wrap") diff --git a/weave/flow/prompt/prompt.py b/weave/flow/prompt/prompt.py index df18eaab88b1..7a32520f08bd 100644 --- a/weave/flow/prompt/prompt.py +++ b/weave/flow/prompt/prompt.py @@ -5,17 +5,20 @@ import textwrap from collections import UserList from pathlib import Path -from typing import IO, Any, Optional, SupportsIndex, TypedDict, Union, overload +from typing import IO, Any, Optional, SupportsIndex, TypedDict, Union, cast, overload from pydantic import Field from rich.table import Table +from typing_extensions import Self from weave.flow.obj import Object from weave.flow.prompt.common import ROLE_COLORS, color_role from weave.trace.api import publish as weave_publish +from weave.trace.objectify import register_object from weave.trace.op import op from weave.trace.refs import ObjectRef from weave.trace.rich import pydantic_util +from weave.trace.vals import WeaveObject class Message(TypedDict): @@ -76,6 +79,7 @@ def format(self, **kwargs: Any) -> Any: raise NotImplementedError("Subclasses must implement format()") +@register_object class StringPrompt(Prompt): content: str = "" @@ -87,13 +91,15 @@ def format(self, **kwargs: Any) -> str: return self.content.format(**kwargs) @classmethod - def from_obj(cls, obj: Any) -> "StringPrompt": + def from_obj(cls, obj: WeaveObject) -> Self: prompt = cls(content=obj.content) prompt.name = obj.name prompt.description = obj.description + prompt.ref = cast(ObjectRef, obj.ref) return prompt +@register_object class MessagesPrompt(Prompt): messages: list[dict] = Field(default_factory=list) @@ -114,13 +120,15 @@ def format(self, **kwargs: Any) -> list: return [self.format_message(m, **kwargs) for m in self.messages] @classmethod - def from_obj(cls, obj: Any) -> "MessagesPrompt": + def from_obj(cls, obj: WeaveObject) -> Self: prompt = cls(messages=obj.messages) prompt.name = obj.name prompt.description = obj.description + prompt.ref = cast(ObjectRef, obj.ref) return prompt +@register_object class 
EasyPrompt(UserList, Prompt): data: list = Field(default_factory=list) config: dict = Field(default_factory=dict) @@ -420,7 +428,7 @@ def as_dict(self) -> dict[str, Any]: } @classmethod - def from_obj(cls, obj: Any) -> "EasyPrompt": + def from_obj(cls, obj: WeaveObject) -> Self: messages = obj.messages if hasattr(obj, "messages") else obj.data messages = [dict(m) for m in messages] config = dict(obj.config) @@ -428,13 +436,14 @@ def from_obj(cls, obj: Any) -> "EasyPrompt": return cls( name=obj.name, description=obj.description, + ref=obj.ref, messages=messages, config=config, requirements=requirements, ) - @staticmethod - def load(fp: IO) -> "EasyPrompt": + @classmethod + def load(cls, fp: IO) -> Self: if isinstance(fp, str): # Common mistake raise TypeError( "Prompt.load() takes a file-like object, not a string. Did you mean Prompt.e()?" @@ -443,8 +452,8 @@ def load(fp: IO) -> "EasyPrompt": prompt = EasyPrompt(**data) return prompt - @staticmethod - def load_file(filepath: Union[str, Path]) -> "Prompt": + @classmethod + def load_file(cls, filepath: Union[str, Path]) -> Self: expanded_path = os.path.expanduser(str(filepath)) with open(expanded_path) as f: return EasyPrompt.load(f) diff --git a/weave/trace/objectify.py b/weave/trace/objectify.py new file mode 100644 index 000000000000..1cf9444391ad --- /dev/null +++ b/weave/trace/objectify.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, TypeVar, runtime_checkable + +if TYPE_CHECKING: + from weave.flow.obj import Object + from weave.trace.vals import WeaveObject + +T_co = TypeVar("T_co", bound="Object", covariant=True) + + +@runtime_checkable +class Objectifyable(Protocol[T_co]): + @classmethod + def from_obj(cls, obj: WeaveObject) -> T_co: ... 
+ + +_registry: dict[str, type[Object]] = {} + + +def register_object(cls: type[T_co]) -> type[T_co]: + _registry[cls.__name__] = cls + return cls + + +def maybe_objectify(obj: WeaveObject) -> T_co | WeaveObject: + if (cls_name := getattr(obj, "_class_name", None)) is None: + return obj + + if (cls := _registry.get(cls_name)) is None: + return obj + + res = cls.from_obj(obj) + if ref := getattr(obj, "ref", None): + res.ref = ref + + return res diff --git a/weave/trace/refs.py b/weave/trace/refs.py index 9bec07d69757..3307cadb708c 100644 --- a/weave/trace/refs.py +++ b/weave/trace/refs.py @@ -164,32 +164,7 @@ def uri(self) -> str: u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra) return u - def objectify(self, obj: Any) -> Any: - """Convert back to higher level object.""" - class_name = getattr(obj, "_class_name", None) - if "EasyPrompt" == class_name: - from weave.flow.prompt.prompt import EasyPrompt - - prompt = EasyPrompt.from_obj(obj) - # We want to use the ref on the object (and not self) as it will have had - # version number or latest alias resolved to a specific digest. - prompt.__dict__["ref"] = obj.ref - return prompt - if "StringPrompt" == class_name: - from weave.flow.prompt.prompt import StringPrompt - - prompt = StringPrompt.from_obj(obj) - prompt.__dict__["ref"] = obj.ref - return prompt - if "MessagesPrompt" == class_name: - from weave.flow.prompt.prompt import MessagesPrompt - - prompt = MessagesPrompt.from_obj(obj) - prompt.__dict__["ref"] = obj.ref - return prompt - return obj - - def get(self) -> Any: + def get(self, *, objectify: bool = True) -> Any: # Move import here so that it only happens when the function is called. # This import is invalid in the trace server and represents a dependency # that should be removed. 
@@ -198,7 +173,7 @@ def get(self) -> Any: gc = get_weave_client() if gc is not None: - return self.objectify(gc.get(self)) + return gc.get(self, objectify=objectify) # Special case: If the user is attempting to fetch an object but has not # yet initialized the client, we can initialize a client to @@ -208,10 +183,10 @@ def get(self) -> Any: f"{self.entity}/{self.project}", ensure_project_exists=False ) try: - res = init_client.client.get(self) + res = init_client.client.get(self, objectify=objectify) finally: init_client.reset() - return self.objectify(res) + return res def is_descended_from(self, potential_ancestor: ObjectRef) -> bool: if self.entity != potential_ancestor.entity: diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 6e5d36695936..8003a0fc3754 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -37,6 +37,7 @@ dataclass_object_record, pydantic_object_record, ) +from weave.trace.objectify import maybe_objectify from weave.trace.op import Op, as_op, is_op, maybe_unbind_method from weave.trace.op import op as op_deco from weave.trace.refs import ( @@ -686,7 +687,7 @@ def save(self, val: Any, name: str, branch: str = "latest") -> Any: return self.get(ref) @trace_sentry.global_trace_sentry.watch() - def get(self, ref: ObjectRef) -> Any: + def get(self, ref: ObjectRef, *, objectify: bool = True) -> Any: project_id = f"{ref.entity}/{ref.project}" try: read_res = self.server.obj_read( @@ -736,8 +737,10 @@ def get(self, ref: ObjectRef) -> Any: data = read_res.obj.val val = from_json(data, project_id, self.server) - - return make_trace_obj(val, ref, self.server, None) + weave_obj = make_trace_obj(val, ref, self.server, None) + if objectify: + return maybe_objectify(weave_obj) + return weave_obj ################ Query API ################