From 8b3692d6bbd065b6e4dfe014a4bffd605cbc059e Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 22 Jan 2025 01:47:24 -0500 Subject: [PATCH] test --- tests/trace/test_evaluations.py | 13 ++++++++----- weave/flow/eval.py | 16 +++++++++++----- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/tests/trace/test_evaluations.py b/tests/trace/test_evaluations.py index 48db2d0c588b..5886afb3287c 100644 --- a/tests/trace/test_evaluations.py +++ b/tests/trace/test_evaluations.py @@ -1055,18 +1055,21 @@ def model(input: str) -> str: assert call.display_name == "wow-custom!" -@pytest.mark.xfail( - reason="TODO: This test does not seem to work with the sqlite test server" -) @pytest.mark.asyncio async def test_get_evaluation_results(client): @weave.op def model(a: int, b: int) -> int: return a + b - ev = weave.Evaluation(dataset=[{"a": 1, "b": 2}]) + ev = weave.Evaluation( + dataset=[{"a": 1, "b": 2}], + # The evaluation name here is a hack for tests and is not required on the prod + # trace server. Sqlite seems to have a memory of the previous calls, and this + # was the only way I found to get the evaluation name to be unique in the test. + evaluation_name=lambda call: call.id, + ) await ev.evaluate(model) await ev.evaluate(model) + res = ev.get_evaluation_calls() - res = ev.get_evaluation_results() assert len(res) == 2 diff --git a/weave/flow/eval.py b/weave/flow/eval.py index be78d058a424..61804e7467cc 100644 --- a/weave/flow/eval.py +++ b/weave/flow/eval.py @@ -320,8 +320,8 @@ async def evaluate(self, model: Union[Op, Model]) -> dict: return summary - def get_evaluation_results(self) -> dict[str, CallsIter]: - if not self.evaluate.ref: + def get_evaluation_calls(self) -> dict[str, CallsIter]: + if not (eval_ref := self.evaluate.ref): raise ValueError( "Evaluation must be run or published before calling get_evaluation_results" ) @@ -330,10 +330,16 @@ def get_evaluation_results(self) -> dict[str, CallsIter]: eval_calls = self.evaluate.calls() wc = weave_client_context.require_weave_client() for call in eval_calls: - if call.display_name and isinstance(call.display_name, str): - res[call.display_name] = wc.get_calls( + display_name = call.display_name + if display_name and isinstance(display_name, str): + if display_name in res: + logger.warning( + f"Duplicate display name {display_name} found in evaluation results; omitting some results..." + ) + continue + res[display_name] = wc.get_calls( filter=CallsFilter( - parent_ids=[call.id], input_refs=[self.ref.uri()] + parent_ids=[call.id], input_refs=[eval_ref.uri()] ) ) return res