diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index e9ba03aca801..1d5d54b9b23c 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -2,7 +2,7 @@ import dataclasses import datetime -import json +import inspect import logging import platform import re @@ -10,16 +10,7 @@ from collections.abc import Iterator, Sequence from concurrent.futures import Future from functools import lru_cache -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Generic, - Protocol, - TypeVar, - cast, - overload, -) +from typing import Any, Callable, Generic, Protocol, TypeVar, cast, overload import pydantic from requests import HTTPError @@ -31,7 +22,6 @@ from weave.trace.context import weave_client_context as weave_client_context from weave.trace.exception import exception_to_json_str from weave.trace.feedback import FeedbackQuery, RefFeedbackQuery -from weave.trace.isinstance import weave_isinstance from weave.trace.object_record import ( ObjectRecord, dataclass_object_record, @@ -48,7 +38,7 @@ parse_op_uri, parse_uri, ) -from weave.trace.sanitize import REDACTED_VALUE, should_redact +from weave.trace.sanitize import REDACT_KEYS, REDACTED_VALUE from weave.trace.serialize import from_json, isinstance_namedtuple, to_json from weave.trace.serializer import get_serializer_for_obj from weave.trace.settings import client_parallelism @@ -64,7 +54,6 @@ CallsDeleteReq, CallsFilter, CallsQueryReq, - CallsQueryStatsReq, CallStartReq, CallUpdateReq, CostCreateInput, @@ -79,7 +68,6 @@ FileCreateRes, ObjCreateReq, ObjCreateRes, - ObjDeleteReq, ObjectVersionFilter, ObjQueryReq, ObjReadReq, @@ -95,10 +83,6 @@ ) from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer -if TYPE_CHECKING: - from weave.scorers.base_scorer import ApplyScorerResult, Scorer - - # Controls if objects can have refs to projects not the WeaveClient project. # If False, object refs with with mismatching projects will be recreated. # If True, use existing ref to object in other project. @@ -115,7 +99,6 @@ def __call__(self, offset: int, limit: int) -> list[T]: ... TransformFunc = Callable[[T], R] -SizeFunc = Callable[[], int] class PaginatedIterator(Generic[T, R]): @@ -127,12 +110,10 @@ def __init__( fetch_func: FetchFunc[T], page_size: int = 1000, transform_func: TransformFunc[T, R] | None = None, - size_func: SizeFunc | None = None, ) -> None: self.fetch_func = fetch_func self.page_size = page_size self.transform_func = transform_func - self.size_func = size_func if page_size <= 0: raise ValueError("page_size must be greater than 0") @@ -201,13 +182,6 @@ def __iter__(self: PaginatedIterator[T, R]) -> Iterator[R]: ... def __iter__(self) -> Iterator[T] | Iterator[R]: return self._get_slice(slice(0, None, 1)) - def __len__(self) -> int: - """This method is included for convenience. It includes a network call, which - is typically slower than most other len() operations!""" - if not self.size_func: - raise TypeError("This iterator does not support len()") - return self.size_func() - # TODO: should be Call, not WeaveObject CallsIter = PaginatedIterator[CallSchema, WeaveObject] @@ -236,17 +210,7 @@ def transform_func(call: CallSchema) -> WeaveObject: entity, project = project_id.split("/") return make_client_call(entity, project, call, server) - def size_func() -> int: - response = server.calls_query_stats( - CallsQueryStatsReq(project_id=project_id, filter=filter) - ) - return response.count - - return PaginatedIterator( - fetch_func, - transform_func=transform_func, - size_func=size_func, - ) + return PaginatedIterator(fetch_func, transform_func=transform_func) class OpNameError(ValueError): @@ -481,60 +445,43 @@ def set_display_name(self, name: str | None) -> None: def remove_display_name(self) -> None: self.set_display_name(None) - async def apply_scorer( - self, scorer: Op | Scorer, additional_scorer_kwargs: dict | None = None - ) -> ApplyScorerResult: + def _apply_scorer(self, scorer_op: Op) -> None: """ - `apply_scorer` is a method that applies a Scorer to a Call. This is useful - for guarding application logic with a scorer and/or monitoring the quality - of critical ops. Scorers are automatically logged to Weave as Feedback and - can be used in queries & analysis. - - Args: - scorer: The Scorer to apply. - additional_scorer_kwargs: Additional kwargs to pass to the scorer. This is - useful for passing in additional context that is not part of the call - inputs.useful for passing in additional context that is not part of the call - inputs. + This is a private method that applies a scorer to a call and records the feedback. + In the near future, this will be made public, but for now it is only used internally + for testing. - Returns: - The result of the scorer application in the form of an `ApplyScorerResult`. - - ```python - class ApplyScorerSuccess: - result: Any - score_call: Call - ``` - - Example usage: + Before making this public, we should refactor such that the `predict_and_score` method + inside `eval.py` uses this method inside the scorer block. - ```python - my_scorer = ... # construct a scorer - prediction, prediction_call = my_op.call(input_data) - result, score_call = prediction.apply_scorer(my_scorer) - ``` + Current limitations: + - only works for ops (not Scorer class) + - no async support + - no context yet (ie. ground truth) """ - from weave.scorers.base_scorer import Scorer, apply_scorer_async - - model_inputs = {k: v for k, v in self.inputs.items() if k != "self"} - example = {**model_inputs, **(additional_scorer_kwargs or {})} - output = self.output - if isinstance(output, ObjectRef): - output = output.get() - apply_scorer_result = await apply_scorer_async(scorer, example, output) - score_call = apply_scorer_result.score_call - - wc = weave_client_context.get_weave_client() - if wc: - scorer_ref_uri = None - if weave_isinstance(scorer, Scorer): - # Very important: if the score is generated from a Scorer subclass, - # then scorer_ref_uri will be None, and we will use the op_name from - # the score_call instead. - scorer_ref = get_ref(scorer) - scorer_ref_uri = scorer_ref.uri() if scorer_ref else None - wc._send_score_call(self, score_call, scorer_ref_uri) - return apply_scorer_result + client = weave_client_context.require_weave_client() + scorer_signature = inspect.signature(scorer_op) + scorer_arg_names = list(scorer_signature.parameters.keys()) + score_args = {k: v for k, v in self.inputs.items() if k in scorer_arg_names} + if "output" in scorer_arg_names: + score_args["output"] = self.output + _, score_call = scorer_op.call(**score_args) + scorer_op_ref = get_ref(scorer_op) + if scorer_op_ref is None: + raise ValueError("Scorer op has no ref") + self_ref = get_ref(self) + if self_ref is None: + raise ValueError("Call has no ref") + score_results = score_call.output + score_call_ref = get_ref(score_call) + if score_call_ref is None: + raise ValueError("Score call has no ref") + client._add_runnable_feedback( + weave_ref_uri=self_ref.uri(), + output=score_results, + call_ref_uri=score_call_ref.uri(), + runnable_ref_uri=scorer_op_ref.uri(), + ) def make_client_call( @@ -697,15 +644,8 @@ def get(self, ref: ObjectRef) -> Any: ) ) except HTTPError as e: - if e.response is not None: - if e.response.content: - try: - reason = json.loads(e.response.content).get("reason") - raise ValueError(reason) - except json.JSONDecodeError: - raise ValueError(e.response.content) - if e.response.status_code == 404: - raise ValueError(f"Unable to find object for ref uri: {ref.uri()}") + if e.response is not None and e.response.status_code == 404: + raise ValueError(f"Unable to find object for ref uri: {ref.uri()}") raise # At this point, `ref.digest` is one of three things: @@ -815,8 +755,6 @@ def create_call( Returns: The created Call object. """ - from weave.trace.api import _global_postprocess_inputs - if isinstance(op, str): if op not in self._anonymous_ops: self._anonymous_ops[op] = _build_anonymous_op(op) @@ -831,9 +769,6 @@ def create_call( else: inputs_postprocessed = inputs_redacted - if _global_postprocess_inputs: - inputs_postprocessed = _global_postprocess_inputs(inputs_postprocessed) - self._save_nested_objects(inputs_postprocessed) inputs_with_refs = map_to_refs(inputs_postprocessed) @@ -857,7 +792,6 @@ def create_call( attributes._set_weave_item("os_version", platform.version()) attributes._set_weave_item("os_release", platform.release()) attributes._set_weave_item("sys_version", sys.version) - attributes._set_weave_item("tracing_sample_rate", op.tracing_sample_rate) op_name_future = self.future_executor.defer(lambda: op_def_ref.uri()) @@ -868,7 +802,6 @@ def create_call( trace_id=trace_id, parent_id=parent_id, id=call_id, - # It feels like this should be inputs_postprocessed, not the refs. inputs=inputs_with_refs, attributes=attributes, ) @@ -922,8 +855,6 @@ def finish_call( *, op: Op | None = None, ) -> None: - from weave.trace.api import _global_postprocess_output - ended_at = datetime.datetime.now(tz=datetime.timezone.utc) call.ended_at = ended_at original_output = output @@ -932,13 +863,9 @@ def finish_call( postprocessed_output = op.postprocess_output(original_output) else: postprocessed_output = original_output - - if _global_postprocess_output: - postprocessed_output = _global_postprocess_output(postprocessed_output) - self._save_nested_objects(postprocessed_output) - output_as_refs = map_to_refs(postprocessed_output) - call.output = postprocessed_output + + call.output = map_to_refs(postprocessed_output) # Summary handling summary = {} @@ -993,7 +920,7 @@ def finish_call( op._on_finish_handler(call, original_output, exception) def send_end_call() -> None: - output_json = to_json(output_as_refs, project_id, self, use_dictify=False) + output_json = to_json(call.output, project_id, self, use_dictify=False) self.server.call_end( CallEndReq( end=EndedCallSchemaForInsert( @@ -1025,26 +952,6 @@ def delete_call(self, call: Call) -> None: ) ) - @trace_sentry.global_trace_sentry.watch() - def delete_object_version(self, object: ObjectRef) -> None: - self.server.obj_delete( - ObjDeleteReq( - project_id=self._project_id(), - object_id=object.name, - digests=[object.digest], - ) - ) - - @trace_sentry.global_trace_sentry.watch() - def delete_op_version(self, op: OpRef) -> None: - self.server.obj_delete( - ObjDeleteReq( - project_id=self._project_id(), - object_id=op.name, - digests=[op.digest], - ) - ) - def get_feedback( self, query: Query | str | None = None, @@ -1741,7 +1648,7 @@ def redact_sensitive_keys(obj: Any) -> Any: if isinstance(obj, dict): dict_res = {} for k, v in obj.items(): - if isinstance(k, str) and should_redact(k): + if k in REDACT_KEYS: dict_res[k] = REDACTED_VALUE else: dict_res[k] = redact_sensitive_keys(v)