diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index c2484c78a4d0..70d7356861ec 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -66,6 +66,7 @@ from weave.trace_server.constants import COMPLETIONS_CREATE_OP_NAME from weave.trace_server.emoji_util import detone_emojis from weave.trace_server.errors import ( + ClickhouseQueryError, InsertTooLarge, InvalidRequest, MissingLLMApiKeyError, @@ -756,7 +757,7 @@ def table_update(self, req: tsi.TableUpdateReq) -> tsi.TableUpdateRes: ORDER BY project_id, digest """ - row_digest_result_query = self.ch_client.query( + row_digest_result_query = self._query( query, parameters={ "project_id": req.project_id, @@ -935,7 +936,7 @@ def table_query_stats(self, req: tsi.TableQueryStatsReq) -> tsi.TableQueryStatsR WHERE project_id = {project_id:String} AND digest = {digest:String} """ - query_result = self.ch_client.query(query, parameters=parameters) + query_result = self._query(query, parameters=parameters) count = query_result.result_rows[0][0] if query_result.result_rows else 0 return tsi.TableQueryStatsRes(count=count) @@ -1260,7 +1261,7 @@ def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes: # The subquery is responsible for deduplication of file chunks by digest - query_result = self.ch_client.query( + query_result = self._query( """ SELECT n_chunks, val_bytes FROM ( @@ -1348,7 +1349,7 @@ def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes: query = query.order_by(req.sort_by) query = query.limit(req.limit).offset(req.offset) prepared = query.prepare(database_type="clickhouse") - query_result = self.ch_client.query(prepared.sql, prepared.parameters) + query_result = self._query(prepared.sql, prepared.parameters) results = LLM_TOKEN_PRICES_TABLE.tuples_to_rows( query_result.result_rows, prepared.fields ) @@ -1379,7 +1380,7 @@ def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes: query = LLM_TOKEN_PRICES_TABLE.purge() query = query.where(query_with_pricing_level) prepared = query.prepare(database_type="clickhouse") - self.ch_client.query(prepared.sql, prepared.parameters) + self._query(prepared.sql, prepared.parameters) return tsi.CostPurgeRes() def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: @@ -1438,7 +1439,7 @@ def feedback_query(self, req: tsi.FeedbackQueryReq) -> tsi.FeedbackQueryRes: query = query.order_by(req.sort_by) query = query.limit(req.limit).offset(req.offset) prepared = query.prepare(database_type="clickhouse") - query_result = self.ch_client.query(prepared.sql, prepared.parameters) + query_result = self._query(prepared.sql, prepared.parameters) result = TABLE_FEEDBACK.tuples_to_rows( query_result.result_rows, prepared.fields ) @@ -1454,7 +1455,7 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: query = query.project_id(req.project_id) query = query.where(req.query) prepared = query.prepare(database_type="clickhouse") - self.ch_client.query(prepared.sql, prepared.parameters) + self._query(prepared.sql, prepared.parameters) return tsi.FeedbackPurgeRes() def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes: @@ -1715,11 +1716,12 @@ def _query_stream( ) yield from stream except Exception as e: - logger.exception( - f"_query_stream exception: {e}", - extra={"query": query, "parameters": parameters}, + raise ClickhouseQueryError( + message=f"_query_stream exception: {e}", + query=query, + parameters=parameters, + summary=summary, ) - raise def _query( self, @@ -1729,9 +1731,19 @@ def _query( ) -> QueryResult: """Directly queries the database and returns the result.""" parameters = _process_parameters(parameters) - res = self.ch_client.query( - query, parameters=parameters, column_formats=column_formats, use_none=True - ) + try: + res = self.ch_client.query( + query, + parameters=parameters, + column_formats=column_formats, + use_none=True, + ) + except Exception as e: + raise ClickhouseQueryError( + message=f"_query exception: {e}", + query=query, + parameters=parameters, + ) logger.info( "clickhouse_query", extra={ diff --git a/weave/trace_server/errors.py b/weave/trace_server/errors.py index fac56024194a..e6716a05120e 100644 --- a/weave/trace_server/errors.py +++ b/weave/trace_server/errors.py @@ -1,4 +1,5 @@ import datetime +from typing import Any class Error(Exception): @@ -51,3 +52,22 @@ class ObjectDeletedError(Error): def __init__(self, message: str, deleted_at: datetime.datetime): self.deleted_at = deleted_at super().__init__(message) + + +class ClickhouseQueryError(Error): + """Raised when a query to Clickhouse fails.""" + + def __init__( + self, + message: str, + query: str, + parameters: dict[str, Any], + summary: dict[str, Any] | None = None, + ): + self.query = query + self.parameters = parameters + self.summary = summary + self.message = message + + def __repr__(self) -> str: + return f"ClickhouseQueryError message:{self.message}\nquery:{self.query}\nparameters:{self.parameters}"