Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong committed Jan 15, 2025
1 parent 66e64ff commit 58a9c22
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
26 changes: 26 additions & 0 deletions tests/trace/test_call_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import weave


def test_call_to_dict(client):
@weave.op
def greet(name: str, age: int) -> str:
return f"Hello {name}, you are {age}!"

_, call = greet.call("Alice", 30)
assert call.to_dict() == {
"op_name": call.op_name,
"display_name": call.display_name,
"inputs": call.inputs,
"output": call.output,
"exception": call.exception,
"summary": call.summary,
"attributes": call.attributes,
"started_at": call.started_at,
"ended_at": call.ended_at,
"deleted_at": call.deleted_at,
"id": call.id,
"parent_id": call.parent_id,
"trace_id": call.trace_id,
"project_id": call.project_id,
}

28 changes: 25 additions & 3 deletions weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Callable,
Generic,
Protocol,
TypedDict,
TypeVar,
cast,
overload,
Expand Down Expand Up @@ -336,6 +337,23 @@ def map_to_refs(obj: Any) -> Any:
return obj


class CallDict(TypedDict):
op_name: str
trace_id: str
project_id: str
parent_id: str | None
inputs: dict
id: str | None
output: Any
exception: str | None
summary: dict | None
display_name: str | None
attributes: dict | None
started_at: datetime.datetime | None
ended_at: datetime.datetime | None
deleted_at: datetime.datetime | None


@dataclasses.dataclass
class Call:
"""A Call represents a single operation that was executed as part of a trace."""
Expand Down Expand Up @@ -534,12 +552,16 @@ class ApplyScorerSuccess:
wc._send_score_call(self, score_call, scorer_ref_uri)
return apply_scorer_result

def to_dict(self) -> dict:
d = {k: v for k, v in dataclasses.asdict(self).items() if not k.startswith("_")}
def to_dict(self) -> CallDict:
d = {}
for field in dataclasses.fields(self):
if field.name.startswith("_"):
continue
d[field.name] = getattr(self, field.name)
d["op_name"] = self.op_name
d["display_name"] = self.display_name

return d
return cast(CallDict, d)


def make_client_call(
Expand Down

0 comments on commit 58a9c22

Please sign in to comment.