diff --git a/tests/trace/test_call_object.py b/tests/trace/test_call_object.py new file mode 100644 index 000000000000..50ca71cde3e6 --- /dev/null +++ b/tests/trace/test_call_object.py @@ -0,0 +1,26 @@ +import weave + + +def test_call_to_dict(client): + @weave.op + def greet(name: str, age: int) -> str: + return f"Hello {name}, you are {age}!" + + _, call = greet.call("Alice", 30) + assert call.to_dict() == { + "op_name": call.op_name, + "display_name": call.display_name, + "inputs": call.inputs, + "output": call.output, + "exception": call.exception, + "summary": call.summary, + "attributes": call.attributes, + "started_at": call.started_at, + "ended_at": call.ended_at, + "deleted_at": call.deleted_at, + "id": call.id, + "parent_id": call.parent_id, + "trace_id": call.trace_id, + "project_id": call.project_id, + } + diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 7d3c4d070f7e..708c2543a1d4 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -15,6 +15,7 @@ Callable, Generic, Protocol, + TypedDict, TypeVar, cast, overload, @@ -336,6 +337,23 @@ def map_to_refs(obj: Any) -> Any: return obj +class CallDict(TypedDict): + op_name: str + trace_id: str + project_id: str + parent_id: str | None + inputs: dict + id: str | None + output: Any + exception: str | None + summary: dict | None + display_name: str | None + attributes: dict | None + started_at: datetime.datetime | None + ended_at: datetime.datetime | None + deleted_at: datetime.datetime | None + + @dataclasses.dataclass class Call: """A Call represents a single operation that was executed as part of a trace.""" @@ -534,12 +552,16 @@ class ApplyScorerSuccess: wc._send_score_call(self, score_call, scorer_ref_uri) return apply_scorer_result - def to_dict(self) -> dict: - d = {k: v for k, v in dataclasses.asdict(self).items() if not k.startswith("_")} + def to_dict(self) -> CallDict: + d = {} + for field in dataclasses.fields(self): + if field.name.startswith("_"): + continue + d[field.name] = getattr(self, field.name) d["op_name"] = self.op_name d["display_name"] = self.display_name - return d + return cast(CallDict, d) def make_client_call(