Skip to content

Commit

Permalink
Merge branch 'master' into griffin/get_calls-feature-parity
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning committed Jan 15, 2025
2 parents a0ee090 + 1730168 commit bf31740
Show file tree
Hide file tree
Showing 14 changed files with 210 additions and 48 deletions.
4 changes: 2 additions & 2 deletions docs/docusaurus.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ const config: Config = {
// Replace with your project's social card
image: "img/logo-large-padded.png",
navbar: {
title: "Weave",
title: "W&B Weave",
logo: {
alt: "My Site Logo",
src: "img/logo.svg",
Expand Down Expand Up @@ -231,7 +231,7 @@ const config: Config = {
],
},
],
copyright: `Weave by W&B`,
copyright: `Made with ❤️ by Weights & Biases`,
},
prism: {
// theme: prismThemes.nightOwl,
Expand Down
4 changes: 2 additions & 2 deletions tests/trace/test_client_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,7 +1380,7 @@ def test_dataset_row_ref(client):
d2 = weave.ref(ref.uri()).get()

inner = d2.rows[0]["a"]
exp_ref = "weave:///shawn/test-project/object/Dataset:0xTDJ6hEmsx8Wg9H75y42bL2WgvW5l4IXjuhHcrMh7A/attr/rows/id/XfhC9dNA5D4taMvhKT4MKN2uce7F56Krsyv4Q6mvVMA/key/a"
exp_ref = "weave:///shawn/test-project/object/Dataset:tiRVKBWTP7LOwjBEqe79WFS7HEibm1WG8nfe94VWZBo/attr/rows/id/XfhC9dNA5D4taMvhKT4MKN2uce7F56Krsyv4Q6mvVMA/key/a"
assert inner == 5
assert inner.ref.uri() == exp_ref
gotten = weave.ref(exp_ref).get()
Expand Down Expand Up @@ -2747,7 +2747,7 @@ def test_objects_and_keys_with_special_characters(client):
exp_key
== "n-a_m.e%3A%20%2F%2B_%28%29%7B%7D%7C%22%27%3C%3E%21%40%24%5E%26%2A%23%3A%2C.%5B%5D-%3D%3B~%60100"
)
exp_digest = "O66Mk7g91rlAUtcGYOFR1Y2Wk94YyPXJy2UEAzDQcYM"
exp_digest = "iVLhViJ3vm8vMMo3Qj35mK7GiyP8jv3OJqasIGXjN0s"

exp_obj_ref = f"{ref_base}/object/{exp_name}:{exp_digest}"
assert obj.ref.uri() == exp_obj_ref
Expand Down
1 change: 1 addition & 0 deletions tests/trace/test_deepcopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class Example(weave.Object):
attrs={
"name": None,
"description": None,
"ref": None,
"_class_name": "Example",
"_bases": ["Object", "BaseModel"],
"a": 1,
Expand Down
4 changes: 2 additions & 2 deletions tests/trace/test_op_decorator_behaviour.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_sync_method_call(client, weave_obj, py_obj):
entity="shawn",
project="test-project",
name="A",
_digest="tGCIGNe9xznnkoJvn2i75TOocSfV7ui1vldSrIP3ZZo",
_digest="nzAe1JtLJFEVeEo3yX0TOYYGhh7vAOFYRentYI9ik6U",
_extra=(),
),
"a": 1,
Expand Down Expand Up @@ -163,7 +163,7 @@ async def test_async_method_call(client, weave_obj, py_obj):
entity="shawn",
project="test-project",
name="A",
_digest="tGCIGNe9xznnkoJvn2i75TOocSfV7ui1vldSrIP3ZZo",
_digest="nzAe1JtLJFEVeEo3yX0TOYYGhh7vAOFYRentYI9ik6U",
_extra=(),
),
"a": 1,
Expand Down
1 change: 1 addition & 0 deletions tests/trace/test_prompt_easy.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def test_prompt_as_pydantic_dict():
assert prompt.as_pydantic_dict() == {
"name": None,
"description": None,
"ref": None,
"config": {
"model": "gpt-4o",
"temperature": 0.8,
Expand Down
96 changes: 96 additions & 0 deletions tests/trace/test_uri_get.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import pytest

import weave
from weave.flow.prompt.prompt import EasyPrompt
from weave.trace_server.trace_server_interface import ObjectVersionFilter, ObjQueryReq


@pytest.fixture(
params=[
"dataset",
"evaluation",
"string_prompt",
"messages_prompt",
"easy_prompt",
]
)
def obj(request):
examples = [
{"question": "What is 2+2?", "expected": "4"},
{"question": "What is 3+3?", "expected": "6"},
]

if request.param == "dataset":
return weave.Dataset(rows=examples)
elif request.param == "evaluation":
return weave.Evaluation(dataset=examples)
elif request.param == "string_prompt":
return weave.StringPrompt("Hello, world!")
elif request.param == "messages_prompt":
return weave.MessagesPrompt([{"role": "user", "content": "Hello, world!"}])
elif request.param == "easy_prompt":
return weave.EasyPrompt("Hello world!")


def test_ref_get(client, obj):
ref = weave.publish(obj)

obj_cls = type(obj)
obj2 = obj_cls.from_uri(ref.uri())
obj3 = ref.get()
assert isinstance(obj2, obj_cls)
assert isinstance(obj3, obj_cls)

for field_name in obj.model_fields:
obj_field_val = getattr(obj, field_name)
obj2_field_val = getattr(obj2, field_name)
obj3_field_val = getattr(obj3, field_name)

# This is a special case for EasyPrompt's unique init signature where `config`
# represents the kwargs passed into the class itself. Since the original object
# has not been published, there is no ref and the key is omitted from the first
# `config`. After publishing, there is a ref so the `config` dict has an
# additional `ref` key. For comparison purposes, we pop the key to ensure the
# rest of the config dict is the same.
if obj_cls is EasyPrompt and field_name == "config":
obj2_field_val.pop("ref")
obj3_field_val.pop("ref")

assert obj_field_val == obj2_field_val
assert obj_field_val == obj3_field_val


@pytest.mark.asyncio
async def test_gotten_methods(client):
@weave.op
def model(a: int) -> int:
return a + 1

ev = weave.Evaluation(dataset=[{"a": 1}])
await ev.evaluate(model)
ref = weave.publish(ev)

ev2 = weave.Evaluation.from_uri(ref.uri())
await ev2.evaluate(model)

# Ensure that the Evaluation object we get back is equivalent to the one published.
# If they are the same, calling evaluate again should not publish new versions of any
# relevant objects of ops.
relevant_object_ids = [
"model",
"Evaluation.evaluate",
"Evaluation.predict_and_score",
"Evaluation.summarize",
"Dataset",
"example_evaluation",
]
# TODO: Replace with client version of this query when available
res = client.server.objs_query(
ObjQueryReq(
project_id=client._project_id(),
filter=ObjectVersionFilter(object_ids=relevant_object_ids),
)
)
for obj in res.objs:
assert obj.version_index == 0
assert obj.is_latest == 1
2 changes: 1 addition & 1 deletion tests/trace/test_weave_client_threaded.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_flask_server(flask_server):
url = flask_server
response = requests.get(url)
assert response.status_code == 200
assert response.text == "0xTDJ6hEmsx8Wg9H75y42bL2WgvW5l4IXjuhHcrMh7A"
assert response.text == "tiRVKBWTP7LOwjBEqe79WFS7HEibm1WG8nfe94VWZBo"


def test_weave_client_global_accessible_in_thread(client):
Expand Down
14 changes: 13 additions & 1 deletion weave/flow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from typing import Any

from pydantic import field_validator
from typing_extensions import Self

import weave
from weave.flow.obj import Object
from weave.trace.vals import WeaveTable
from weave.trace.objectify import register_object
from weave.trace.vals import WeaveObject, WeaveTable


def short_str(obj: Any, limit: int = 25) -> str:
Expand All @@ -15,6 +17,7 @@ def short_str(obj: Any, limit: int = 25) -> str:
return str_val


@register_object
class Dataset(Object):
"""
Dataset object with easy saving and automatic versioning
Expand Down Expand Up @@ -42,6 +45,15 @@ class Dataset(Object):

rows: weave.Table

@classmethod
def from_obj(cls, obj: WeaveObject) -> Self:
return cls(
name=obj.name,
description=obj.description,
ref=obj.ref,
rows=obj.rows,
)

@field_validator("rows", mode="before")
def convert_to_table(cls, rows: Any) -> weave.Table:
if not isinstance(rows, weave.Table):
Expand Down
16 changes: 16 additions & 0 deletions weave/flow/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pydantic import PrivateAttr, model_validator
from rich import print
from rich.console import Console
from typing_extensions import Self

import weave
from weave.flow import util
Expand All @@ -30,6 +31,7 @@
from weave.trace.env import get_weave_parallelism
from weave.trace.errors import OpCallError
from weave.trace.isinstance import weave_isinstance
from weave.trace.objectify import register_object
from weave.trace.op import CallDisplayNameFunc, Op, as_op, is_op
from weave.trace.vals import WeaveObject
from weave.trace.weave_client import Call, get_ref
Expand Down Expand Up @@ -57,6 +59,7 @@ class EvaluationResults(Object):
ScorerLike = Union[Callable, Op, Scorer]


@register_object
class Evaluation(Object):
"""
Sets up an evaluation which includes a set of scorers and a dataset.
Expand Down Expand Up @@ -114,6 +117,19 @@ def function_to_evaluate(question: str):
# internal attr to track whether to use the new `output` or old `model_output` key for outputs
_output_key: Literal["output", "model_output"] = PrivateAttr("output")

@classmethod
def from_obj(cls, obj: WeaveObject) -> Self:
return cls(
name=obj.name,
description=obj.description,
ref=obj.ref,
dataset=obj.dataset,
scorers=obj.scorers,
preprocess_model_input=obj.preprocess_model_input,
trials=obj.trials,
evaluation_name=obj.evaluation_name,
)

@model_validator(mode="after")
def _update_display_name(self) -> "Evaluation":
if self.evaluation_name:
Expand Down
12 changes: 12 additions & 0 deletions weave/flow/obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
ValidatorFunctionWrapHandler,
model_validator,
)
from typing_extensions import Self

from weave.trace import api
from weave.trace.objectify import Objectifyable
from weave.trace.op import ObjectRef, Op
from weave.trace.vals import WeaveObject, pydantic_getattribute
from weave.trace.weave_client import get_ref
Expand Down Expand Up @@ -38,6 +41,7 @@ def setter(self: Any, value: T) -> None:
class Object(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
ref: Optional[ObjectRef] = None

# Allow Op attributes
model_config = ConfigDict(
Expand All @@ -51,6 +55,14 @@ class Object(BaseModel):

__str__ = BaseModel.__repr__

@classmethod
def from_uri(cls, uri: str, *, objectify: bool = True) -> Self:
if not isinstance(cls, Objectifyable):
raise NotImplementedError(
f"`{cls.__name__}` must implement `from_obj` to support deserialization from a URI."
)
return api.ref(uri).get(objectify=objectify)

# This is a "wrap" validator meaning we can run our own logic before
# and after the standard pydantic validation.
@model_validator(mode="wrap")
Expand Down
25 changes: 17 additions & 8 deletions weave/flow/prompt/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@
import textwrap
from collections import UserList
from pathlib import Path
from typing import IO, Any, Optional, SupportsIndex, TypedDict, Union, overload
from typing import IO, Any, Optional, SupportsIndex, TypedDict, Union, cast, overload

from pydantic import Field
from rich.table import Table
from typing_extensions import Self

from weave.flow.obj import Object
from weave.flow.prompt.common import ROLE_COLORS, color_role
from weave.trace.api import publish as weave_publish
from weave.trace.objectify import register_object
from weave.trace.op import op
from weave.trace.refs import ObjectRef
from weave.trace.rich import pydantic_util
from weave.trace.vals import WeaveObject


class Message(TypedDict):
Expand Down Expand Up @@ -76,6 +79,7 @@ def format(self, **kwargs: Any) -> Any:
raise NotImplementedError("Subclasses must implement format()")


@register_object
class StringPrompt(Prompt):
content: str = ""

Expand All @@ -87,13 +91,15 @@ def format(self, **kwargs: Any) -> str:
return self.content.format(**kwargs)

@classmethod
def from_obj(cls, obj: Any) -> "StringPrompt":
def from_obj(cls, obj: WeaveObject) -> Self:
prompt = cls(content=obj.content)
prompt.name = obj.name
prompt.description = obj.description
prompt.ref = cast(ObjectRef, obj.ref)
return prompt


@register_object
class MessagesPrompt(Prompt):
messages: list[dict] = Field(default_factory=list)

Expand All @@ -114,13 +120,15 @@ def format(self, **kwargs: Any) -> list:
return [self.format_message(m, **kwargs) for m in self.messages]

@classmethod
def from_obj(cls, obj: Any) -> "MessagesPrompt":
def from_obj(cls, obj: WeaveObject) -> Self:
prompt = cls(messages=obj.messages)
prompt.name = obj.name
prompt.description = obj.description
prompt.ref = cast(ObjectRef, obj.ref)
return prompt


@register_object
class EasyPrompt(UserList, Prompt):
data: list = Field(default_factory=list)
config: dict = Field(default_factory=dict)
Expand Down Expand Up @@ -420,21 +428,22 @@ def as_dict(self) -> dict[str, Any]:
}

@classmethod
def from_obj(cls, obj: Any) -> "EasyPrompt":
def from_obj(cls, obj: WeaveObject) -> Self:
messages = obj.messages if hasattr(obj, "messages") else obj.data
messages = [dict(m) for m in messages]
config = dict(obj.config)
requirements = dict(obj.requirements)
return cls(
name=obj.name,
description=obj.description,
ref=obj.ref,
messages=messages,
config=config,
requirements=requirements,
)

@staticmethod
def load(fp: IO) -> "EasyPrompt":
@classmethod
def load(cls, fp: IO) -> Self:
if isinstance(fp, str): # Common mistake
raise TypeError(
"Prompt.load() takes a file-like object, not a string. Did you mean Prompt.e()?"
Expand All @@ -443,8 +452,8 @@ def load(fp: IO) -> "EasyPrompt":
prompt = EasyPrompt(**data)
return prompt

@staticmethod
def load_file(filepath: Union[str, Path]) -> "Prompt":
@classmethod
def load_file(cls, filepath: Union[str, Path]) -> Self:
expanded_path = os.path.expanduser(str(filepath))
with open(expanded_path) as f:
return EasyPrompt.load(f)
Expand Down
Loading

0 comments on commit bf31740

Please sign in to comment.