diff --git a/docs/docs/guides/core-types/env-vars.md b/docs/docs/guides/core-types/env-vars.md new file mode 100644 index 000000000000..5a21cebb91ce --- /dev/null +++ b/docs/docs/guides/core-types/env-vars.md @@ -0,0 +1,28 @@ +# Environment variables + +Weave provides a set of environment variables to configure and optimize its behavior. You can set these variables in your shell or within scripts to control specific functionality. + +```bash +# Example of setting environment variables in the shell +export WEAVE_PARALLELISM=10 # Controls the number of parallel workers +export WEAVE_PRINT_CALL_LINK=false # Disables call link output +``` + +```python +# Example of setting environment variables in Python +import os + +os.environ["WEAVE_PARALLELISM"] = "10" +os.environ["WEAVE_PRINT_CALL_LINK"] = "false" +``` + +## Environment variables reference + +| Variable Name | Description | +|--------------------------|-----------------------------------------------------------------| +| WEAVE_CAPTURE_CODE | Disable code capture for `weave.op` if set to `false`. | +| WEAVE_DEBUG_HTTP | If set to `1`, turns on HTTP request and response logging for debugging. | +| WEAVE_DISABLED | If set to `true`, all tracing to Weave is disabled. | +| WEAVE_PARALLELISM | In evaluations, the number of examples to evaluate in parallel. `1` runs examples sequentially. Default value is `20`. | +| WEAVE_PRINT_CALL_LINK | If set to `false`, call URL printing is suppressed. Default value is `true`. | +| WEAVE_TRACE_LANGCHAIN | When set to `false`, explicitly disable global tracing for LangChain. 
| | diff --git a/docs/sidebars.ts b/docs/sidebars.ts index 447e26e44bbb..6ce0325fb3d6 100644 --- a/docs/sidebars.ts +++ b/docs/sidebars.ts @@ -70,6 +70,7 @@ const sidebars: SidebarsConfig = { "guides/tools/comparison", "guides/tools/playground", "guides/core-types/media", + "guides/core-types/env-vars", { type: "category", collapsible: true, diff --git a/tests/trace/test_call_apply_scorer.py b/tests/trace/test_call_apply_scorer.py new file mode 100644 index 000000000000..ecb363b8efa3 --- /dev/null +++ b/tests/trace/test_call_apply_scorer.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import pytest + +import weave +from weave.scorers.base_scorer import ApplyScorerResult +from weave.trace.op import OpCallError +from weave.trace.refs import CallRef +from weave.trace.weave_client import Call, Op, WeaveClient + + +def do_assertions_for_scorer_op( + apply_score_res: ApplyScorerResult, + call: Call, + score_fn: Op | weave.Scorer, + client: WeaveClient, +): + assert apply_score_res.score_call.id is not None + assert apply_score_res.result == 0 + + feedbacks = list(call.feedback) + assert len(feedbacks) == 1 + target_feedback = feedbacks[0] + scorer_name = ( + score_fn.name if isinstance(score_fn, Op) else score_fn.__class__.__name__ + ) + assert target_feedback.feedback_type == "wandb.runnable." 
+ scorer_name + assert target_feedback.runnable_ref == score_fn.ref.uri() + assert ( + target_feedback.call_ref + == CallRef( + entity=client.entity, + project=client.project, + id=apply_score_res.score_call.id, + ).uri() + ) + assert target_feedback.payload == {"output": apply_score_res.result} + + +@pytest.mark.asyncio +async def test_scorer_op_no_context(client: WeaveClient): + @weave.op + def predict(x): + return x + 1 + + @weave.op + def score_fn(x, output): + return output - x - 1 + + _, call = predict.call(1) + apply_score_res = await call.apply_scorer(score_fn) + do_assertions_for_scorer_op(apply_score_res, call, score_fn, client) + + @weave.op + def score_fn_with_incorrect_args(y, output): + return output - y + + with pytest.raises(OpCallError): + apply_score_res = await call.apply_scorer(score_fn_with_incorrect_args) + + +@pytest.mark.asyncio +async def test_scorer_op_with_context(client: WeaveClient): + @weave.op + def predict(x): + return x + 1 + + @weave.op + def score_fn(x, output, correct_answer): + return output - correct_answer + + _, call = predict.call(1) + apply_score_res = await call.apply_scorer( + score_fn, additional_scorer_kwargs={"correct_answer": 2} + ) + do_assertions_for_scorer_op(apply_score_res, call, score_fn, client) + + @weave.op + def score_fn_with_incorrect_args(x, output, incorrect_arg): + return output - incorrect_arg + + with pytest.raises(OpCallError): + apply_score_res = await call.apply_scorer( + score_fn_with_incorrect_args, additional_scorer_kwargs={"correct_answer": 2} + ) + + +@pytest.mark.asyncio +async def test_async_scorer_op(client: WeaveClient): + @weave.op + def predict(x): + return x + 1 + + @weave.op + async def score_fn(x, output): + return output - x - 1 + + _, call = predict.call(1) + apply_score_res = await call.apply_scorer(score_fn) + do_assertions_for_scorer_op(apply_score_res, call, score_fn, client) + + @weave.op + async def score_fn_with_incorrect_args(y, output): + return output - y + + with 
pytest.raises(OpCallError): + apply_score_res = await call.apply_scorer(score_fn_with_incorrect_args) + + +@pytest.mark.asyncio +async def test_scorer_obj_no_context(client: WeaveClient): + @weave.op + def predict(x): + return x + 1 + + class MyScorer(weave.Scorer): + offset: int + + @weave.op + def score(self, x, output): + return output - x - self.offset + + scorer = MyScorer(offset=1) + + _, call = predict.call(1) + apply_score_res = await call.apply_scorer(scorer) + do_assertions_for_scorer_op(apply_score_res, call, scorer, client) + + class MyScorerWithIncorrectArgs(weave.Scorer): + offset: int + + @weave.op + def score(self, y, output): + return output - y - self.offset + + with pytest.raises(OpCallError): + apply_score_res = await call.apply_scorer(MyScorerWithIncorrectArgs(offset=1)) + + +@pytest.mark.asyncio +async def test_scorer_obj_with_context(client: WeaveClient): + @weave.op + def predict(x): + return x + 1 + + class MyScorer(weave.Scorer): + offset: int + + @weave.op + def score(self, x, output, correct_answer): + return output - correct_answer - self.offset + + scorer = MyScorer(offset=0) + + _, call = predict.call(1) + apply_score_res = await call.apply_scorer( + scorer, additional_scorer_kwargs={"correct_answer": 2} + ) + do_assertions_for_scorer_op(apply_score_res, call, scorer, client) + + class MyScorerWithIncorrectArgs(weave.Scorer): + offset: int + + @weave.op + def score(self, y, output, incorrect_arg): + return output - incorrect_arg - self.offset + + with pytest.raises(OpCallError): + apply_score_res = await call.apply_scorer( + MyScorerWithIncorrectArgs(offset=0), + additional_scorer_kwargs={"incorrect_arg": 2}, + ) + + class MyScorerWithIncorrectArgsButCorrectColumnMapping(weave.Scorer): + offset: int + + @weave.op + def score(self, y, output, incorrect_arg): + return output - incorrect_arg - self.offset + + scorer = MyScorerWithIncorrectArgsButCorrectColumnMapping( + offset=0, column_map={"y": "x", "incorrect_arg": "correct_answer"} + 
) + + _, call = predict.call(1) + apply_score_res = await call.apply_scorer( + scorer, additional_scorer_kwargs={"correct_answer": 2} + ) + do_assertions_for_scorer_op(apply_score_res, call, scorer, client) + + +@pytest.mark.asyncio +async def test_async_scorer_obj(client: WeaveClient): + @weave.op + def predict(x): + return x + 1 + + class MyScorer(weave.Scorer): + offset: int + + @weave.op + async def score(self, x, output): + return output - x - 1 + + scorer = MyScorer(offset=0) + + _, call = predict.call(1) + apply_score_res = await call.apply_scorer( + scorer, additional_scorer_kwargs={"correct_answer": 2} + ) + do_assertions_for_scorer_op(apply_score_res, call, scorer, client) diff --git a/tests/trace/test_feedback.py b/tests/trace/test_feedback.py index f5eae39ea39f..45ebb1be705a 100644 --- a/tests/trace/test_feedback.py +++ b/tests/trace/test_feedback.py @@ -346,7 +346,7 @@ def test_runnable_feedback(client: WeaveClient) -> None: } -def populate_feedback(client: WeaveClient) -> None: +async def populate_feedback(client: WeaveClient) -> None: @weave.op def my_scorer(x: int, output: int) -> int: expected = ["a", "b", "c", "d"][x] @@ -369,7 +369,7 @@ def my_model(x: int) -> str: for x in range(4): _, c = my_model.call(x) ids.append(c.id) - c._apply_scorer(my_scorer) + await c.apply_scorer(my_scorer) assert len(list(my_scorer.calls())) == 4 assert len(list(my_model.calls())) == 4 @@ -377,13 +377,14 @@ def my_model(x: int) -> str: return ids, my_scorer, my_model -def test_sort_by_feedback(client: WeaveClient) -> None: +@pytest.mark.asyncio +async def test_sort_by_feedback(client: WeaveClient) -> None: if client_is_sqlite(client): # Not implemented in sqlite - skip return pytest.skip() """Test sorting by feedback.""" - ids, my_scorer, my_model = populate_feedback(client) + ids, my_scorer, my_model = await populate_feedback(client) for fields, asc_ids in [ ( @@ -441,13 +442,14 @@ def test_sort_by_feedback(client: WeaveClient) -> None: ), f"Sorting by {fields} 
descending failed, expected {asc_ids[::-1]}, got {found_ids}" -def test_filter_by_feedback(client: WeaveClient) -> None: +@pytest.mark.asyncio +async def test_filter_by_feedback(client: WeaveClient) -> None: if client_is_sqlite(client): # Not implemented in sqlite - skip return pytest.skip() """Test filtering by feedback.""" - ids, my_scorer, my_model = populate_feedback(client) + ids, my_scorer, my_model = await populate_feedback(client) for field, value, eq_ids, gt_ids in [ ( "feedback.[wandb.runnable.my_scorer].payload.output.model_output", @@ -514,13 +516,14 @@ def __eq__(self, other): return isinstance(other, datetime.datetime) -def test_filter_and_sort_by_feedback(client: WeaveClient) -> None: +@pytest.mark.asyncio +async def test_filter_and_sort_by_feedback(client: WeaveClient) -> None: if client_is_sqlite(client): # Not implemented in sqlite - skip return pytest.skip() """Test filtering by feedback.""" - ids, my_scorer, my_model = populate_feedback(client) + ids, my_scorer, my_model = await populate_feedback(client) calls = client.server.calls_query_stream( tsi.CallsQueryReq( project_id=client._project_id(), diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx index 5194af00fe31..1622924c9771 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx @@ -7,12 +7,12 @@ import {TargetBlank} from '../../../../../common/util/links'; import {Alert} from '../../../../Alert'; import {Loading} from '../../../../Loading'; import {Tailwind} from '../../../../Tailwind'; -import {RUNNABLE_FEEDBACK_TYPE_PREFIX} from '../pages/CallPage/CallScoresViewer'; import {Empty} from '../pages/common/Empty'; import {useWFHooks} from '../pages/wfReactInterface/context'; import {useGetTraceServerClientContext} from 
'../pages/wfReactInterface/traceServerClientContext'; import {FeedbackGridInner} from './FeedbackGridInner'; import {HUMAN_ANNOTATION_BASE_TYPE} from './StructuredFeedback/humanAnnotationTypes'; +import {RUNNABLE_FEEDBACK_TYPE_PREFIX} from './StructuredFeedback/runnableFeedbackTypes'; const ANNOTATION_PREFIX = `${HUMAN_ANNOTATION_BASE_TYPE}.`; @@ -62,7 +62,11 @@ export const FeedbackGrid = ({ ); // only keep the most recent feedback for each (feedback_type, creator) const combinedFiltered = Object.values(combined).map( - fs => fs.sort((a, b) => b.created_at - a.created_at)[0] + fs => + fs.sort( + (a, b) => + new Date(b.created_at).getTime() - new Date(a.created_at).getTime() + )[0] ); // add the non-annotation feedback to the combined object combinedFiltered.push( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/HumanFeedback/tsScorerFeedback.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/HumanFeedback/tsScorerFeedback.ts new file mode 100644 index 000000000000..b43f6f337007 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/HumanFeedback/tsScorerFeedback.ts @@ -0,0 +1,45 @@ +import {RUNNABLE_FEEDBACK_TYPE_PREFIX} from '../StructuredFeedback/runnableFeedbackTypes'; + +export const RUNNABLE_FEEDBACK_IN_SUMMARY_PREFIX = + 'summary.weave.feedback.' 
+ RUNNABLE_FEEDBACK_TYPE_PREFIX; +export const RUNNABLE_FEEDBACK_OUTPUT_PART = 'payload.output'; + +export type ScorerFeedbackTypeParts = { + scorerName: string; + scorePath: string; +}; + +export const parseScorerFeedbackField = ( + inputField: string +): ScorerFeedbackTypeParts | null => { + const prefix = RUNNABLE_FEEDBACK_IN_SUMMARY_PREFIX + '.'; + if (!inputField.startsWith(prefix)) { + return null; + } + const res = inputField.replace(prefix, ''); + if (!res.includes('.')) { + return null; + } + const [scorerName, ...rest] = res.split('.'); + const prefixedScorePath = rest.join('.'); + const pathPrefix = RUNNABLE_FEEDBACK_OUTPUT_PART; + if (!prefixedScorePath.startsWith(pathPrefix)) { + return null; + } + const scorePath = prefixedScorePath.replace(pathPrefix, ''); + return { + scorerName, + scorePath, + }; +}; + +export const convertScorerFeedbackFieldToBackendFilter = ( + field: string +): string => { + const parsed = parseScorerFeedbackField(field); + if (parsed === null) { + return field; + } + const {scorerName, scorePath} = parsed; + return `feedback.[${RUNNABLE_FEEDBACK_TYPE_PREFIX}.${scorerName}].${RUNNABLE_FEEDBACK_OUTPUT_PART}${scorePath}`; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/ScorerFeedbackGrid.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/ScorerFeedbackGrid.tsx new file mode 100644 index 000000000000..507787b2465f --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/ScorerFeedbackGrid.tsx @@ -0,0 +1,123 @@ +import {Box} from '@mui/material'; +import _ from 'lodash'; +import React, {useEffect, useMemo} from 'react'; + +import {useViewerInfo} from '../../../../../common/hooks/useViewerInfo'; +import {TargetBlank} from '../../../../../common/util/links'; +import {Alert} from '../../../../Alert'; +import {Loading} from '../../../../Loading'; +import {Tailwind} from '../../../../Tailwind'; +import {Empty} from '../pages/common/Empty'; +import 
{useWFHooks} from '../pages/wfReactInterface/context'; +import {useGetTraceServerClientContext} from '../pages/wfReactInterface/traceServerClientContext'; +import {ScoresFeedbackGridInner} from './ScoresFeedbackGridInner'; +import {RUNNABLE_FEEDBACK_TYPE_PREFIX} from './StructuredFeedback/runnableFeedbackTypes'; + +type FeedbackGridProps = { + entity: string; + project: string; + weaveRef: string; + objectType?: string; +}; + +export const ScorerFeedbackGrid = ({ + entity, + project, + weaveRef, + objectType, +}: FeedbackGridProps) => { + /** + * This component is very similar to `FeedbackGrid`, but it only shows scores. + * While some of the code is duplicated, it is kept separate to make it easier + * to modify in the future. + */ + const {loading: loadingUserInfo, userInfo} = useViewerInfo(); + + const {useFeedback} = useWFHooks(); + const query = useFeedback({ + entity, + project, + weaveRef, + }); + + const getTsClient = useGetTraceServerClientContext(); + useEffect(() => { + return getTsClient().registerOnFeedbackListener(weaveRef, query.refetch); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + // Group by feedback on this object vs. descendent objects + const grouped = useMemo(() => { + // Keep only runnable (scorer) feedback; other feedback is presented in a different tab + const onlyRunnables = (query.result ?? []).filter(f => + f.feedback_type.startsWith(RUNNABLE_FEEDBACK_TYPE_PREFIX) + ); + + // Group by feedback on this object vs. descendent objects + return _.groupBy(onlyRunnables, f => + f.weave_ref.substring(weaveRef.length) + ); + }, [query.result, weaveRef]); + + const paths = useMemo(() => Object.keys(grouped).sort(), [grouped]); + + if (query.loading || loadingUserInfo) { + return ( + + + + ); + } + if (query.error) { + return ( +
+ + Error: {query.error.message ?? JSON.stringify(query.error)} + +
+ ); + } + + if (!paths.length) { + return ( + + Learn how to{' '} + + run evaluations + + . + + } + // Need to add this additional detail once the new API is released. + // description="You can add scores to calls by using the `Call.apply_scorer` method." + /> + ); + } + + const currentViewerId = userInfo ? userInfo.id : null; + return ( + + {paths.map(path => { + return ( +
+ {path &&
On {path}
} + +
+ ); + })} +
+ ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/ScoresFeedbackGridInner.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/ScoresFeedbackGridInner.tsx new file mode 100644 index 000000000000..8b13e46821e2 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/ScoresFeedbackGridInner.tsx @@ -0,0 +1,192 @@ +import {Box} from '@mui/material'; +import {GridColDef, GridRowHeightParams} from '@mui/x-data-grid-pro'; +import {isWeaveObjectRef, parseRefMaybe} from '@wandb/weave/react'; +import React from 'react'; + +import {Timestamp} from '../../../../Timestamp'; +import {UserLink} from '../../../../UserLink'; +import {CellValue} from '../../Browse2/CellValue'; +import {SmallRef} from '../../Browse2/SmallRef'; +import {CallRefLink} from '../pages/common/Links'; +import {Feedback} from '../pages/wfReactInterface/traceServerClientTypes'; +import {StyledDataGrid} from '../StyledDataGrid'; +import {FeedbackGridActions} from './FeedbackGridActions'; + +type FeedbackGridInnerProps = { + feedback: Feedback[]; + currentViewerId: string | null; + showAnnotationName?: boolean; +}; + +export const ScoresFeedbackGridInner = ({ + feedback, + currentViewerId, + showAnnotationName, +}: FeedbackGridInnerProps) => { + /** + * This component is very similar to `FeedbackGridInner`, but it only shows scores. + * While some of the code is duplicated, it is kept separate to make it easier + * to modify in the future. + */ + const columns: Array> = [ + { + field: 'runnable_ref', + headerName: 'Scorer', + display: 'flex', + flex: 1, + renderCell: params => { + const runnable_ref = params.row.runnable_ref; + if (!runnable_ref) { + return null; + } + const objRef = parseRefMaybe(runnable_ref); + if (!objRef) { + return null; + } + return ( +
+ +
+ ); + }, + }, + { + field: 'payload', + headerName: 'Score', + sortable: false, + flex: 1, + renderCell: params => { + const value = params.row.payload.output; + return ( + + + + ); + }, + }, + { + field: 'call_ref', + headerName: 'Score Call', + display: 'flex', + renderCell: params => { + const call_ref = params.row.call_ref; + if (!call_ref) { + return null; + } + const objRef = parseRefMaybe(call_ref); + if (!objRef) { + return null; + } + if (!isWeaveObjectRef(objRef)) { + return null; + } + return ( +
+ +
+ ); + }, + }, + { + field: 'created_at', + headerName: 'Timestamp', + minWidth: 105, + width: 105, + renderCell: params => ( + + ), + }, + { + field: 'wb_user_id', + headerName: 'Creator', + minWidth: 150, + width: 150, + // Might be confusing to enable as-is, because the user sees name / + // email but the underlying data is userId. + filterable: false, + sortable: false, + resizable: false, + disableColumnMenu: true, + renderCell: params => { + if ( + params.row.creator && + params.row.creator !== params.row.wb_user_id + ) { + return params.row.creator; + } + return ; + }, + }, + { + field: 'actions', + headerName: '', + width: 36, + minWidth: 36, + filterable: false, + sortable: false, + resizable: false, + disableColumnMenu: true, + display: 'flex', + renderCell: params => { + const projectId = params.row.project_id; + const feedbackId = params.row.id; + const creatorId = params.row.wb_user_id; + if (!currentViewerId || creatorId !== currentViewerId) { + return null; + } + return ( + + ); + }, + }, + ]; + const rows = feedback; + return ( + { + if (isWandbFeedbackType(params.model.feedback_type)) { + return 38; + } + return 'auto'; + }} + columns={columns} + disableRowSelectionOnClick + /> + ); +}; + +const isWandbFeedbackType = (feedbackType: string) => + feedbackType.startsWith('wandb.'); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/runnableFeedbackTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/runnableFeedbackTypes.ts new file mode 100644 index 000000000000..eb8ba6e87830 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/runnableFeedbackTypes.ts @@ -0,0 +1 @@ +export const RUNNABLE_FEEDBACK_TYPE_PREFIX = 'wandb.runnable'; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/common.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/common.ts index 
17cfbf45ac49..a0abb3c6801c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/common.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/common.ts @@ -12,6 +12,10 @@ import {isWeaveObjectRef, parseRefMaybe} from '@wandb/weave/react'; import _ from 'lodash'; import {parseFeedbackType} from '../feedback/HumanFeedback/tsHumanFeedback'; +import { + parseScorerFeedbackField, + RUNNABLE_FEEDBACK_IN_SUMMARY_PREFIX, +} from '../feedback/HumanFeedback/tsScorerFeedback'; import {WEAVE_REF_PREFIX} from '../pages/wfReactInterface/constants'; import {TraceCallSchema} from '../pages/wfReactInterface/traceServerClientTypes'; @@ -41,7 +45,7 @@ export const FIELD_LABELS: Record = { }; export const getFieldLabel = (field: string): string => { - if (field.startsWith('feedback.')) { + if (field.startsWith('feedback.wandb.annotation.')) { // Here the field is coming from convertFeedbackFieldToBackendFilter // so the field should start with 'feedback.' if feedback const parsed = parseFeedbackType(field); @@ -50,6 +54,13 @@ export const getFieldLabel = (field: string): string => { } return parsed.displayName; } + if (field.startsWith(RUNNABLE_FEEDBACK_IN_SUMMARY_PREFIX)) { + const parsed = parseScorerFeedbackField(field); + if (parsed === null) { + return field; + } + return parsed.scorerName + parsed.scorePath; + } return FIELD_LABELS[field] ?? 
field; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx index a3315266f65d..5219f390c810 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx @@ -24,6 +24,7 @@ import { WeaveflowPeekContext, } from '../../context'; import {FeedbackGrid} from '../../feedback/FeedbackGrid'; +import {ScorerFeedbackGrid} from '../../feedback/ScorerFeedbackGrid'; import {FeedbackSidebar} from '../../feedback/StructuredFeedback/FeedbackSidebar'; import {useHumanAnnotationSpecs} from '../../feedback/StructuredFeedback/tsHumanFeedback'; import {NotFoundPanel} from '../../NotFoundPanel'; @@ -42,7 +43,6 @@ import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; import {CallChat} from './CallChat'; import {CallDetails} from './CallDetails'; import {CallOverview} from './CallOverview'; -import {CallScoresViewer} from './CallScoresViewer'; import {CallSummary} from './CallSummary'; import {CallTraceView, useCallFlattenedTraceTree} from './CallTraceView'; import {PaginationControls} from './PaginationControls'; @@ -77,7 +77,6 @@ export const useShowRunnableUI = () => { }; const useCallTabs = (call: CallSchema) => { - const showScores = useShowRunnableUI(); const codeURI = call.opVersionRef; const {entity, project, callId} = call; const weaveRef = makeRefCall(entity, project, callId); @@ -171,6 +170,19 @@ const useCallTabs = (call: CallSchema) => { ), }, + { + label: 'Scores', + content: ( + + + + ), + }, { label: 'Summary', content: ( @@ -179,21 +191,6 @@ const useCallTabs = (call: CallSchema) => { ), }, - // For now, we are only showing this tab for W&B admins since the - // feature is in active development. We want to be able to get - // feedback without enabling for all users. - ...(showScores - ? 
[ - { - label: 'Scores', - content: ( - - - - ), - }, - ] - : []), { label: 'Use', content: ( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx deleted file mode 100644 index 9ec2f3ff80ec..000000000000 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx +++ /dev/null @@ -1,379 +0,0 @@ -import {Box} from '@material-ui/core'; -import {GridColDef} from '@mui/x-data-grid-pro'; -import {Button} from '@wandb/weave/components/Button/Button'; -import {Timestamp} from '@wandb/weave/components/Timestamp'; -import {parseRef} from '@wandb/weave/react'; -import {makeRefCall} from '@wandb/weave/util/refs'; -import _ from 'lodash'; -import React, {useMemo, useState} from 'react'; - -import {flattenObjectPreservingWeaveTypes} from '../../../Browse2/browse2Util'; -import {CellValue} from '../../../Browse2/CellValue'; -import {NotApplicable} from '../../../Browse2/NotApplicable'; -import {SmallRef} from '../../../Browse2/SmallRef'; -import {StyledDataGrid} from '../../StyledDataGrid'; // Import the StyledDataGrid component -import {WEAVE_REF_SCHEME} from '../wfReactInterface/constants'; -import {useWFHooks} from '../wfReactInterface/context'; -import { - TraceObjSchemaForBaseObjectClass, - useBaseObjectInstances, -} from '../wfReactInterface/objectClassQuery'; -import {useGetTraceServerClientContext} from '../wfReactInterface/traceServerClientContext'; -import {Feedback} from '../wfReactInterface/traceServerClientTypes'; -import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; -import {objectVersionKeyToRefUri} from '../wfReactInterface/utilities'; -import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; - -export const RUNNABLE_FEEDBACK_TYPE_PREFIX = 'wandb.runnable'; - -const useLatestActionDefinitionsForCall = (call: CallSchema) => { - const 
actionSpecs = ( - useBaseObjectInstances('ActionSpec', { - project_id: projectIdFromParts({ - entity: call.entity, - project: call.project, - }), - filter: {latest_only: true}, - }).result ?? [] - ).sort((a, b) => (a.val.name ?? '').localeCompare(b.val.name ?? '')); - return actionSpecs; -}; - -const useRunnableFeedbacksForCall = (call: CallSchema) => { - const {useFeedback} = useWFHooks(); - const weaveRef = makeRefCall(call.entity, call.project, call.callId); - const feedbackQuery = useFeedback({ - entity: call.entity, - project: call.project, - weaveRef, - }); - - const runnableFeedbacks: Feedback[] = useMemo(() => { - return (feedbackQuery.result ?? []).filter( - f => - f.feedback_type?.startsWith(RUNNABLE_FEEDBACK_TYPE_PREFIX) && - f.runnable_ref !== null - ); - }, [feedbackQuery.result]); - - return {runnableFeedbacks, refetchFeedback: feedbackQuery.refetch}; -}; - -const useRunnableFeedbackTypeToLatestActionRef = ( - call: CallSchema, - actionSpecs: Array> -): Record => { - return useMemo(() => { - return _.fromPairs( - actionSpecs.map(actionSpec => { - return [ - RUNNABLE_FEEDBACK_TYPE_PREFIX + '.' 
+ actionSpec.object_id, - objectVersionKeyToRefUri({ - scheme: WEAVE_REF_SCHEME, - weaveKind: 'object', - entity: call.entity, - project: call.project, - objectId: actionSpec.object_id, - versionHash: actionSpec.digest, - path: '', - }), - ]; - }) - ); - }, [actionSpecs, call.entity, call.project]); -}; - -type GroupedRowType = { - id: string; - displayName: string; - runnableActionRef?: string; - feedback?: Feedback; - runCount: number; -}; - -const useTableRowsForRunnableFeedbacks = ( - actionSpecs: Array>, - runnableFeedbacks: Feedback[], - runnableFeedbackTypeToLatestActionRef: Record -): GroupedRowType[] => { - const rows = useMemo(() => { - const scoredRows = Object.entries( - _.groupBy(runnableFeedbacks, f => f.feedback_type) - ).map(([feedbackType, fs]) => { - const val = _.reverse(_.sortBy(fs, 'created_at'))[0]; - return { - id: feedbackType, - displayName: feedbackType.slice( - RUNNABLE_FEEDBACK_TYPE_PREFIX.length + 1 - ), - runnableActionRef: - runnableFeedbackTypeToLatestActionRef[val.feedback_type], - feedback: val, - runCount: fs.length, - }; - }); - const additionalRows = actionSpecs - .map(actionSpec => { - const feedbackType = - RUNNABLE_FEEDBACK_TYPE_PREFIX + '.' 
+ actionSpec.object_id; - return { - id: feedbackType, - runnableActionRef: - runnableFeedbackTypeToLatestActionRef[feedbackType], - displayName: actionSpec.object_id, - runCount: 0, - }; - }) - .filter(row => !scoredRows.some(r => r.id === row.id)); - return _.sortBy([...scoredRows, ...additionalRows], s => s.id); - }, [actionSpecs, runnableFeedbackTypeToLatestActionRef, runnableFeedbacks]); - - return rows; -}; - -type FlattenedRowType = { - id: string; - displayName: string; - runnableActionRef?: string; - feedback?: Feedback; - runCount: number; - feedbackKey?: string; - feedbackValue?: any; -}; - -const useFlattenedRows = (rows: GroupedRowType[]): FlattenedRowType[] => { - return useMemo(() => { - return rows.flatMap(r => { - if (r.feedback == null) { - return [r]; - } - const feedback = flattenObjectPreservingWeaveTypes(r.feedback.payload); - return Object.entries(feedback).map(([k, v]) => ({ - ...r, - id: r.id + '::' + k, - feedbackKey: k, - feedbackValue: v, - })); - }); - }, [rows]); -}; - -export const CallScoresViewer: React.FC<{ - call: CallSchema; -}> = props => { - const actionSpecs = useLatestActionDefinitionsForCall(props.call); - const {runnableFeedbacks, refetchFeedback} = useRunnableFeedbacksForCall( - props.call - ); - const runnableFeedbackTypeToLatestActionRef = - useRunnableFeedbackTypeToLatestActionRef(props.call, actionSpecs); - const rows = useTableRowsForRunnableFeedbacks( - actionSpecs, - runnableFeedbacks, - runnableFeedbackTypeToLatestActionRef - ); - const flattenedRows = useFlattenedRows(rows); - - const columns: Array> = [ - { - field: 'scorer', - headerName: 'Scorer', - width: 150, - rowSpanValueGetter: (value, row) => row.displayName, - renderCell: params => { - const refToUse = - params.row.runnableActionRef || params.row.feedback?.runnable_ref; - const title = params.row.displayName; - return ( - - {' '} - {refToUse && ( - - - - )} - {title} - - ); - }, - }, - { - field: 'runCount', - headerName: 'Runs', - width: 55, - 
rowSpanValueGetter: (value, row) => row.displayName, - }, - { - field: 'lastRanAt', - headerName: 'Last Ran At', - width: 100, - rowSpanValueGetter: (value, row) => row.displayName, - renderCell: params => { - if (params.row.feedback == null) { - return null; - } - const createdAt = new Date(params.row.feedback.created_at + 'Z'); - const value = createdAt ? createdAt.getTime() / 1000 : undefined; - if (value == null) { - return ; - } - return ; - }, - }, - { - field: 'lastResultKey', - headerName: 'Key', - width: 100, - renderCell: params => { - let key = params.row.feedbackKey; - // Handle cases where the output is a primitive value vs a nested object - if (key?.startsWith('output.')) { - key = key.slice(7); - } - return key; - }, - }, - { - field: 'lastResultValue', - headerName: 'Value', - flex: 1, - renderCell: params => { - const value = params.row.feedbackValue; - if (value == null) { - return ; - } - return ( - - - - ); - }, - }, - { - field: 'run', - headerName: '', - width: 75, - rowSpanValueGetter: (value, row) => row.displayName, - renderCell: params => { - const actionRef = params.row.runnableActionRef; - return actionRef ? 
( - - ) : null; - }, - }, - ]; - - return ( - <> - - - ); -}; - -const RunButton: React.FC<{ - actionRef: string; - callId: string; - entity: string; - project: string; - refetchFeedback: () => void; -}> = ({actionRef, callId, entity, project, refetchFeedback}) => { - const getClient = useGetTraceServerClientContext(); - - const [isRunning, setIsRunning] = useState(false); - const [error, setError] = useState(null); - - const handleRunClick = async () => { - setIsRunning(true); - setError(null); - try { - await getClient().actionsExecuteBatch({ - project_id: projectIdFromParts({entity, project}), - call_ids: [callId], - action_ref: actionRef, - }); - refetchFeedback(); - } catch (err) { - setError('An error occurred while running the action.'); - } finally { - setIsRunning(false); - } - }; - - if (error) { - return ( - - ); - } - - return ( -
- -
- ); -}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx index 4eee09ece141..b48f37eef4bf 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx @@ -18,11 +18,18 @@ import {monthRoundedTime} from '../../../../../../common/util/time'; import {isWeaveObjectRef, parseRef} from '../../../../../../react'; import {makeRefCall} from '../../../../../../util/refs'; import {Timestamp} from '../../../../../Timestamp'; +import {CellValue} from '../../../Browse2/CellValue'; import {CellValueString} from '../../../Browse2/CellValueString'; import { convertFeedbackFieldToBackendFilter, parseFeedbackType, } from '../../feedback/HumanFeedback/tsHumanFeedback'; +import { + convertScorerFeedbackFieldToBackendFilter, + parseScorerFeedbackField, + RUNNABLE_FEEDBACK_IN_SUMMARY_PREFIX, + RUNNABLE_FEEDBACK_OUTPUT_PART, +} from '../../feedback/HumanFeedback/tsScorerFeedback'; import {Reactions} from '../../feedback/Reactions'; import {CellFilterWrapper, OnAddFilter} from '../../filters/CellFilterWrapper'; import {isWeaveRef} from '../../filters/common'; @@ -383,6 +390,72 @@ function buildCallsTableColumns( cols.push(...annotationColumns); } + const scoreColNames = allDynamicColumnNames.filter( + c => + c.startsWith(RUNNABLE_FEEDBACK_IN_SUMMARY_PREFIX) && + c.includes(RUNNABLE_FEEDBACK_OUTPUT_PART) + ); + if (scoreColNames.length > 0) { + // Add feedback group to grouping model + const scoreGroup = { + groupId: 'scores', + headerName: 'Scores', + children: [] as any[], + }; + groupingModel.push(scoreGroup); + + // Add feedback columns + const scoreColumns: Array> = scoreColNames.map( + c => { + const parsed = parseScorerFeedbackField(c); + const field = 
convertScorerFeedbackFieldToBackendFilter(c); + scoreGroup.children.push({ + field, + }); + if (parsed === null) { + return { + field, + headerName: c, + width: 150, + renderHeader: () => { + return
{c}
; + }, + valueGetter: (unused: any, row: any) => { + return row[c]; + }, + renderCell: (params: GridRenderCellParams) => { + return ; + }, + }; + } + return { + field, + headerName: 'Scores.' + parsed.scorerName + parsed.scorePath, + width: 150, + renderHeader: () => { + return
{parsed.scorerName + parsed.scorePath}
; + }, + valueGetter: (unused: any, row: any) => { + return row[c]; + }, + renderCell: (params: GridRenderCellParams) => { + return ( + + + + ); + }, + }; + } + ); + cols.push(...scoreColumns); + } + cols.push({ field: 'wb_user_id', headerName: 'User', diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx index 4060735cc67a..5f84121f0889 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx @@ -4,6 +4,7 @@ import { TEAL_500, TEAL_600, } from '@wandb/weave/common/css/color.styles'; +import {WeaveObjectRef} from '@wandb/weave/react'; import React from 'react'; import {Link as LinkComp, useHistory} from 'react-router-dom'; import styled, {css} from 'styled-components'; @@ -272,6 +273,47 @@ export const OpVersionLink: React.FC<{ ); }; +export const CallRefLink: React.FC<{ + callRef: WeaveObjectRef; +}> = props => { + const history = useHistory(); + const {peekingRouter} = useWeaveflowRouteContext(); + const callId = props.callRef.artifactName; + const to = peekingRouter.callUIUrl( + props.callRef.entityName, + props.callRef.projectName, + '', + callId + ); + const onClick = () => { + history.push(to); + }; + + if (props.callRef.weaveKind !== 'call') { + return null; + } + + return ( + + + + + + + + + + ); +}; + export const CallLink: React.FC<{ entityName: string; projectName: string; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts index e039747042cb..7c89efd44196 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts +++ 
b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts @@ -165,6 +165,7 @@ export type FeedbackQueryReq = { export type Feedback = { id: string; + project_id: string; weave_ref: string; wb_user_id: string; // authenticated creator username creator: string | null; // display name diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts index 6b08bbba26f4..34496ffeed04 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts @@ -275,7 +275,8 @@ export type WFDataModelHooksInterface = { useFeedback: ( key: FeedbackKey | null, sortBy?: traceServerClientTypes.SortBy[] - ) => LoadableWithError & Refetchable; + ) => LoadableWithError & + Refetchable; useTableUpdate: () => ( projectId: string, baseDigest: string, diff --git a/weave/flow/eval.py b/weave/flow/eval.py index 7400fc45ce38..f540dfb1dfe8 100644 --- a/weave/flow/eval.py +++ b/weave/flow/eval.py @@ -27,8 +27,6 @@ get_scorer_attributes, transpose, ) -from weave.scorers.base_scorer import apply_scorer_async -from weave.trace.context.weave_client_context import get_weave_client from weave.trace.env import get_weave_parallelism from weave.trace.errors import OpCallError from weave.trace.isinstance import weave_isinstance @@ -209,22 +207,8 @@ async def predict_and_score(self, model: Union[Op, Model], example: dict) -> dic scorers = self._post_init_scorers for scorer in scorers: - apply_scorer_result = await apply_scorer_async( - scorer, example, model_output - ) + apply_scorer_result = await model_call.apply_scorer(scorer, example) result = apply_scorer_result.result - score_call = apply_scorer_result.score_call - - wc 
= get_weave_client() - if wc: - scorer_ref_uri = None - if weave_isinstance(scorer, Scorer): - # Very important: if the score is generated from a Scorer subclass, - # then scorer_ref_uri will be None, and we will use the op_name from - # the score_call instead. - scorer_ref = get_ref(scorer) - scorer_ref_uri = scorer_ref.uri() if scorer_ref else None - wc._send_score_call(model_call, score_call, scorer_ref_uri) scorer_attributes = get_scorer_attributes(scorer) scorer_name = scorer_attributes.scorer_name scores[scorer_name] = result diff --git a/weave/scorers/base_scorer.py b/weave/scorers/base_scorer.py index 462246914f5b..9616fb0c4500 100644 --- a/weave/scorers/base_scorer.py +++ b/weave/scorers/base_scorer.py @@ -185,7 +185,7 @@ class ApplyScorerSuccess: async def apply_scorer_async( - scorer: Union[Op, Scorer], example: dict, model_output: dict + scorer: Union[Op, Scorer], example: dict, model_output: Any ) -> ApplyScorerResult: """Apply a scoring function to model output and example data asynchronously. diff --git a/weave/trace/op.py b/weave/trace/op.py index ea3f0c66dc0c..bf3ee58f7a79 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -417,15 +417,19 @@ def _do_call( # Handle all of the possible cases where we would skip tracing. if settings.should_disable_weave(): res = func(*pargs.args, **pargs.kwargs) + call.output = res return res, call if weave_client_context.get_weave_client() is None: res = func(*pargs.args, **pargs.kwargs) + call.output = res return res, call if not op._tracing_enabled: res = func(*pargs.args, **pargs.kwargs) + call.output = res return res, call if not get_tracing_enabled(): res = func(*pargs.args, **pargs.kwargs) + call.output = res return res, call current_call = call_context.get_current_call() @@ -435,6 +439,7 @@ def _do_call( # Disable tracing for this call and all descendants with tracing_disabled(): res = func(*pargs.args, **pargs.kwargs) + call.output = res return res, call # Proceed with tracing. 
Note that we don't check the sample rate here. @@ -478,15 +483,19 @@ async def _do_call_async( # Handle all of the possible cases where we would skip tracing. if settings.should_disable_weave(): res = await func(*args, **kwargs) + call.output = res return res, call if weave_client_context.get_weave_client() is None: res = await func(*args, **kwargs) + call.output = res return res, call if not op._tracing_enabled: res = await func(*args, **kwargs) + call.output = res return res, call if not get_tracing_enabled(): res = await func(*args, **kwargs) + call.output = res return res, call current_call = call_context.get_current_call() @@ -496,6 +505,7 @@ async def _do_call_async( # Disable tracing for this call and all descendants with tracing_disabled(): res = await func(*args, **kwargs) + call.output = res return res, call # Proceed with tracing diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 6decbc12cffc..0754b9f57eb5 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -2,7 +2,6 @@ import dataclasses import datetime -import inspect import logging import platform import re @@ -10,7 +9,16 @@ from collections.abc import Iterator, Sequence from concurrent.futures import Future from functools import lru_cache -from typing import Any, Callable, Generic, Protocol, TypeVar, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Protocol, + TypeVar, + cast, + overload, +) import pydantic from requests import HTTPError @@ -22,6 +30,7 @@ from weave.trace.context import weave_client_context as weave_client_context from weave.trace.exception import exception_to_json_str from weave.trace.feedback import FeedbackQuery, RefFeedbackQuery +from weave.trace.isinstance import weave_isinstance from weave.trace.object_record import ( ObjectRecord, dataclass_object_record, @@ -83,6 +92,10 @@ ) from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer +if TYPE_CHECKING: + from 
weave.scorers.base_scorer import ApplyScorerResult, Scorer + + + # Controls if objects can have refs to projects not the WeaveClient project. # If False, object refs with with mismatching projects will be recreated. # If True, use existing ref to object in other project. @@ -445,43 +458,60 @@ def set_display_name(self, name: str | None) -> None: def remove_display_name(self) -> None: self.set_display_name(None) - def _apply_scorer(self, scorer_op: Op) -> None: + async def apply_scorer( + self, scorer: Op | Scorer, additional_scorer_kwargs: dict | None = None + ) -> ApplyScorerResult: """ - This is a private method that applies a scorer to a call and records the feedback. - In the near future, this will be made public, but for now it is only used internally - for testing. + `apply_scorer` is a method that applies a Scorer to a Call. This is useful + for guarding application logic with a scorer and/or monitoring the quality + of critical ops. Scorers are automatically logged to Weave as Feedback and + can be used in queries & analysis. + + Args: + scorer: The Scorer to apply. + additional_scorer_kwargs: Additional kwargs to pass to the scorer. This is + useful for passing in additional context that is not part of the call + inputs, such as ground truth. These kwargs are merged over the call inputs, + so a key given here overrides a call input of the same name. - Before making this public, we should refactor such that the `predict_and_score` method - inside `eval.py` uses this method inside the scorer block. + Returns: + The result of the scorer application in the form of an `ApplyScorerResult`. - Current limitations: - - only works for ops (not Scorer class) - - no async support - - no context yet (ie. ground truth) + ```python + class ApplyScorerSuccess: + result: Any + score_call: Call + ``` + + Example usage: + + ```python + my_scorer = ...
# construct a scorer + prediction, prediction_call = my_op.call(input_data) + apply_scorer_result = await prediction_call.apply_scorer(my_scorer) + ``` """ - client = weave_client_context.require_weave_client() - scorer_signature = inspect.signature(scorer_op) - scorer_arg_names = list(scorer_signature.parameters.keys()) - score_args = {k: v for k, v in self.inputs.items() if k in scorer_arg_names} - if "output" in scorer_arg_names: - score_args["output"] = self.output - _, score_call = scorer_op.call(**score_args) - scorer_op_ref = get_ref(scorer_op) - if scorer_op_ref is None: - raise ValueError("Scorer op has no ref") - self_ref = get_ref(self) - if self_ref is None: - raise ValueError("Call has no ref") - score_results = score_call.output - score_call_ref = get_ref(score_call) - if score_call_ref is None: - raise ValueError("Score call has no ref") - client._add_runnable_feedback( - weave_ref_uri=self_ref.uri(), - output=score_results, - call_ref_uri=score_call_ref.uri(), - runnable_ref_uri=scorer_op_ref.uri(), - ) + from weave.scorers.base_scorer import Scorer, apply_scorer_async + + model_inputs = {k: v for k, v in self.inputs.items() if k != "self"} + example = {**model_inputs, **(additional_scorer_kwargs or {})} + output = self.output + if isinstance(output, ObjectRef): + output = output.get() + apply_scorer_result = await apply_scorer_async(scorer, example, output) + score_call = apply_scorer_result.score_call + + wc = weave_client_context.get_weave_client() + if wc: + scorer_ref_uri = None + if weave_isinstance(scorer, Scorer): + # Very important: if the score is generated from a Scorer subclass, + # then scorer_ref_uri will be None, and we will use the op_name from + # the score_call instead.
+ scorer_ref = get_ref(scorer) + scorer_ref_uri = scorer_ref.uri() if scorer_ref else None + wc._send_score_call(self, score_call, scorer_ref_uri) + return apply_scorer_result def make_client_call( @@ -807,6 +837,7 @@ def create_call( trace_id=trace_id, parent_id=parent_id, id=call_id, + # It feels like this should be inputs_postprocessed, not the refs. inputs=inputs_with_refs, attributes=attributes, ) @@ -875,8 +906,8 @@ def finish_call( postprocessed_output = _global_postprocess_output(postprocessed_output) self._save_nested_objects(postprocessed_output) - - call.output = map_to_refs(postprocessed_output) + output_as_refs = map_to_refs(postprocessed_output) + call.output = postprocessed_output # Summary handling summary = {} @@ -931,7 +962,7 @@ def finish_call( op._on_finish_handler(call, original_output, exception) def send_end_call() -> None: - output_json = to_json(call.output, project_id, self, use_dictify=False) + output_json = to_json(output_as_refs, project_id, self, use_dictify=False) self.server.call_end( CallEndReq( end=EndedCallSchemaForInsert(