Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong committed Jan 22, 2025
1 parent 5deee37 commit 8b3692d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
13 changes: 8 additions & 5 deletions tests/trace/test_evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,18 +1055,21 @@ def model(input: str) -> str:
assert call.display_name == "wow-custom!"


@pytest.mark.xfail(
reason="TODO: This test does not seem to work with the sqlite test server"
)
@pytest.mark.asyncio
async def test_get_evaluation_results(client):
@weave.op
def model(a: int, b: int) -> int:
return a + b

ev = weave.Evaluation(dataset=[{"a": 1, "b": 2}])
ev = weave.Evaluation(
dataset=[{"a": 1, "b": 2}],
# The evaluation name here is a hack for tests and is not required on the prod
# trace server. Sqlite seems to have a memory of the previous calls, and this
# was the only way I found to get the evaluation name to be unique in the test.
evaluation_name=lambda call: call.id,
)
await ev.evaluate(model)
await ev.evaluate(model)
res = ev.get_evaluation_calls()

res = ev.get_evaluation_results()
assert len(res) == 2
16 changes: 11 additions & 5 deletions weave/flow/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ async def evaluate(self, model: Union[Op, Model]) -> dict:

return summary

def get_evaluation_results(self) -> dict[str, CallsIter]:
if not self.evaluate.ref:
def get_evaluation_calls(self) -> dict[str, CallsIter]:
if not (eval_ref := self.evaluate.ref):
raise ValueError(
"Evaluation must be run or published before calling get_evaluation_results"
)
Expand All @@ -330,10 +330,16 @@ def get_evaluation_results(self) -> dict[str, CallsIter]:
eval_calls = self.evaluate.calls()
wc = weave_client_context.require_weave_client()
for call in eval_calls:
if call.display_name and isinstance(call.display_name, str):
res[call.display_name] = wc.get_calls(
display_name = call.display_name
if display_name and isinstance(display_name, str):
if display_name in res:
logger.warning(
f"Duplicate display name {display_name} found in evaluation results; omitting some results..."
)
continue
res[display_name] = wc.get_calls(
filter=CallsFilter(
parent_ids=[call.id], input_refs=[self.ref.uri()]
parent_ids=[call.id], input_refs=[eval_ref.uri()]
)
)
return res
Expand Down

0 comments on commit 8b3692d

Please sign in to comment.