Skip to content

Commit

Permalink
revert weave_client changes
Browse files Browse the repository at this point in the history
  • Loading branch information
adrnswanberg committed Jan 14, 2025
1 parent c91f0f5 commit 803fa13
Showing 1 changed file with 43 additions and 136 deletions.
179 changes: 43 additions & 136 deletions weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,15 @@

import dataclasses
import datetime
import json
import inspect
import logging
import platform
import re
import sys
from collections.abc import Iterator, Sequence
from concurrent.futures import Future
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Protocol,
TypeVar,
cast,
overload,
)
from typing import Any, Callable, Generic, Protocol, TypeVar, cast, overload

import pydantic
from requests import HTTPError
Expand All @@ -31,7 +22,6 @@
from weave.trace.context import weave_client_context as weave_client_context
from weave.trace.exception import exception_to_json_str
from weave.trace.feedback import FeedbackQuery, RefFeedbackQuery
from weave.trace.isinstance import weave_isinstance
from weave.trace.object_record import (
ObjectRecord,
dataclass_object_record,
Expand All @@ -48,7 +38,7 @@
parse_op_uri,
parse_uri,
)
from weave.trace.sanitize import REDACTED_VALUE, should_redact
from weave.trace.sanitize import REDACT_KEYS, REDACTED_VALUE
from weave.trace.serialize import from_json, isinstance_namedtuple, to_json
from weave.trace.serializer import get_serializer_for_obj
from weave.trace.settings import client_parallelism
Expand All @@ -64,7 +54,6 @@
CallsDeleteReq,
CallsFilter,
CallsQueryReq,
CallsQueryStatsReq,
CallStartReq,
CallUpdateReq,
CostCreateInput,
Expand All @@ -79,7 +68,6 @@
FileCreateRes,
ObjCreateReq,
ObjCreateRes,
ObjDeleteReq,
ObjectVersionFilter,
ObjQueryReq,
ObjReadReq,
Expand All @@ -95,10 +83,6 @@
)
from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer

if TYPE_CHECKING:
from weave.scorers.base_scorer import ApplyScorerResult, Scorer


# Controls if objects can have refs to projects not the WeaveClient project.
# If False, object refs with with mismatching projects will be recreated.
# If True, use existing ref to object in other project.
Expand All @@ -115,7 +99,6 @@ def __call__(self, offset: int, limit: int) -> list[T]: ...


TransformFunc = Callable[[T], R]
SizeFunc = Callable[[], int]


class PaginatedIterator(Generic[T, R]):
Expand All @@ -127,12 +110,10 @@ def __init__(
fetch_func: FetchFunc[T],
page_size: int = 1000,
transform_func: TransformFunc[T, R] | None = None,
size_func: SizeFunc | None = None,
) -> None:
self.fetch_func = fetch_func
self.page_size = page_size
self.transform_func = transform_func
self.size_func = size_func

if page_size <= 0:
raise ValueError("page_size must be greater than 0")
Expand Down Expand Up @@ -201,13 +182,6 @@ def __iter__(self: PaginatedIterator[T, R]) -> Iterator[R]: ...
def __iter__(self) -> Iterator[T] | Iterator[R]:
return self._get_slice(slice(0, None, 1))

def __len__(self) -> int:
"""This method is included for convenience. It includes a network call, which
is typically slower than most other len() operations!"""
if not self.size_func:
raise TypeError("This iterator does not support len()")
return self.size_func()


# TODO: should be Call, not WeaveObject
CallsIter = PaginatedIterator[CallSchema, WeaveObject]
Expand Down Expand Up @@ -236,17 +210,7 @@ def transform_func(call: CallSchema) -> WeaveObject:
entity, project = project_id.split("/")
return make_client_call(entity, project, call, server)

def size_func() -> int:
response = server.calls_query_stats(
CallsQueryStatsReq(project_id=project_id, filter=filter)
)
return response.count

return PaginatedIterator(
fetch_func,
transform_func=transform_func,
size_func=size_func,
)
return PaginatedIterator(fetch_func, transform_func=transform_func)


class OpNameError(ValueError):
Expand Down Expand Up @@ -481,60 +445,43 @@ def set_display_name(self, name: str | None) -> None:
def remove_display_name(self) -> None:
self.set_display_name(None)

async def apply_scorer(
self, scorer: Op | Scorer, additional_scorer_kwargs: dict | None = None
) -> ApplyScorerResult:
def _apply_scorer(self, scorer_op: Op) -> None:
"""
`apply_scorer` is a method that applies a Scorer to a Call. This is useful
for guarding application logic with a scorer and/or monitoring the quality
of critical ops. Scorers are automatically logged to Weave as Feedback and
can be used in queries & analysis.
Args:
scorer: The Scorer to apply.
additional_scorer_kwargs: Additional kwargs to pass to the scorer. This is
useful for passing in additional context that is not part of the call
inputs.useful for passing in additional context that is not part of the call
inputs.
This is a private method that applies a scorer to a call and records the feedback.
In the near future, this will be made public, but for now it is only used internally
for testing.
Returns:
The result of the scorer application in the form of an `ApplyScorerResult`.
```python
class ApplyScorerSuccess:
result: Any
score_call: Call
```
Example usage:
Before making this public, we should refactor such that the `predict_and_score` method
inside `eval.py` uses this method inside the scorer block.
```python
my_scorer = ... # construct a scorer
prediction, prediction_call = my_op.call(input_data)
result, score_call = prediction.apply_scorer(my_scorer)
```
Current limitations:
- only works for ops (not Scorer class)
- no async support
- no context yet (ie. ground truth)
"""
from weave.scorers.base_scorer import Scorer, apply_scorer_async

model_inputs = {k: v for k, v in self.inputs.items() if k != "self"}
example = {**model_inputs, **(additional_scorer_kwargs or {})}
output = self.output
if isinstance(output, ObjectRef):
output = output.get()
apply_scorer_result = await apply_scorer_async(scorer, example, output)
score_call = apply_scorer_result.score_call

wc = weave_client_context.get_weave_client()
if wc:
scorer_ref_uri = None
if weave_isinstance(scorer, Scorer):
# Very important: if the score is generated from a Scorer subclass,
# then scorer_ref_uri will be None, and we will use the op_name from
# the score_call instead.
scorer_ref = get_ref(scorer)
scorer_ref_uri = scorer_ref.uri() if scorer_ref else None
wc._send_score_call(self, score_call, scorer_ref_uri)
return apply_scorer_result
client = weave_client_context.require_weave_client()
scorer_signature = inspect.signature(scorer_op)
scorer_arg_names = list(scorer_signature.parameters.keys())
score_args = {k: v for k, v in self.inputs.items() if k in scorer_arg_names}
if "output" in scorer_arg_names:
score_args["output"] = self.output
_, score_call = scorer_op.call(**score_args)
scorer_op_ref = get_ref(scorer_op)
if scorer_op_ref is None:
raise ValueError("Scorer op has no ref")
self_ref = get_ref(self)
if self_ref is None:
raise ValueError("Call has no ref")
score_results = score_call.output
score_call_ref = get_ref(score_call)
if score_call_ref is None:
raise ValueError("Score call has no ref")
client._add_runnable_feedback(
weave_ref_uri=self_ref.uri(),
output=score_results,
call_ref_uri=score_call_ref.uri(),
runnable_ref_uri=scorer_op_ref.uri(),
)


def make_client_call(
Expand Down Expand Up @@ -697,15 +644,8 @@ def get(self, ref: ObjectRef) -> Any:
)
)
except HTTPError as e:
if e.response is not None:
if e.response.content:
try:
reason = json.loads(e.response.content).get("reason")
raise ValueError(reason)
except json.JSONDecodeError:
raise ValueError(e.response.content)
if e.response.status_code == 404:
raise ValueError(f"Unable to find object for ref uri: {ref.uri()}")
if e.response is not None and e.response.status_code == 404:
raise ValueError(f"Unable to find object for ref uri: {ref.uri()}")
raise

# At this point, `ref.digest` is one of three things:
Expand Down Expand Up @@ -815,8 +755,6 @@ def create_call(
Returns:
The created Call object.
"""
from weave.trace.api import _global_postprocess_inputs

if isinstance(op, str):
if op not in self._anonymous_ops:
self._anonymous_ops[op] = _build_anonymous_op(op)
Expand All @@ -831,9 +769,6 @@ def create_call(
else:
inputs_postprocessed = inputs_redacted

if _global_postprocess_inputs:
inputs_postprocessed = _global_postprocess_inputs(inputs_postprocessed)

self._save_nested_objects(inputs_postprocessed)
inputs_with_refs = map_to_refs(inputs_postprocessed)

Expand All @@ -857,7 +792,6 @@ def create_call(
attributes._set_weave_item("os_version", platform.version())
attributes._set_weave_item("os_release", platform.release())
attributes._set_weave_item("sys_version", sys.version)
attributes._set_weave_item("tracing_sample_rate", op.tracing_sample_rate)

op_name_future = self.future_executor.defer(lambda: op_def_ref.uri())

Expand All @@ -868,7 +802,6 @@ def create_call(
trace_id=trace_id,
parent_id=parent_id,
id=call_id,
# It feels like this should be inputs_postprocessed, not the refs.
inputs=inputs_with_refs,
attributes=attributes,
)
Expand Down Expand Up @@ -922,8 +855,6 @@ def finish_call(
*,
op: Op | None = None,
) -> None:
from weave.trace.api import _global_postprocess_output

ended_at = datetime.datetime.now(tz=datetime.timezone.utc)
call.ended_at = ended_at
original_output = output
Expand All @@ -932,13 +863,9 @@ def finish_call(
postprocessed_output = op.postprocess_output(original_output)
else:
postprocessed_output = original_output

if _global_postprocess_output:
postprocessed_output = _global_postprocess_output(postprocessed_output)

self._save_nested_objects(postprocessed_output)
output_as_refs = map_to_refs(postprocessed_output)
call.output = postprocessed_output

call.output = map_to_refs(postprocessed_output)

# Summary handling
summary = {}
Expand Down Expand Up @@ -993,7 +920,7 @@ def finish_call(
op._on_finish_handler(call, original_output, exception)

def send_end_call() -> None:
output_json = to_json(output_as_refs, project_id, self, use_dictify=False)
output_json = to_json(call.output, project_id, self, use_dictify=False)
self.server.call_end(
CallEndReq(
end=EndedCallSchemaForInsert(
Expand Down Expand Up @@ -1025,26 +952,6 @@ def delete_call(self, call: Call) -> None:
)
)

@trace_sentry.global_trace_sentry.watch()
def delete_object_version(self, object: ObjectRef) -> None:
self.server.obj_delete(
ObjDeleteReq(
project_id=self._project_id(),
object_id=object.name,
digests=[object.digest],
)
)

@trace_sentry.global_trace_sentry.watch()
def delete_op_version(self, op: OpRef) -> None:
self.server.obj_delete(
ObjDeleteReq(
project_id=self._project_id(),
object_id=op.name,
digests=[op.digest],
)
)

def get_feedback(
self,
query: Query | str | None = None,
Expand Down Expand Up @@ -1741,7 +1648,7 @@ def redact_sensitive_keys(obj: Any) -> Any:
if isinstance(obj, dict):
dict_res = {}
for k, v in obj.items():
if isinstance(k, str) and should_redact(k):
if k in REDACT_KEYS:
dict_res[k] = REDACTED_VALUE
else:
dict_res[k] = redact_sensitive_keys(v)
Expand Down

0 comments on commit 803fa13

Please sign in to comment.