From 5d4f3723147722fe06a83b7741157b3d69a72d65 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Fri, 19 Jan 2024 15:06:24 -0800 Subject: [PATCH] refactor: clean up stream exception logic --- .../instrumentation/openai/_stream.py | 40 ++++++++----------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_stream.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_stream.py index 34576a9c9..afaf090fb 100644 --- a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_stream.py +++ b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_stream.py @@ -66,25 +66,22 @@ def __next__(self) -> Any: # pass through mistaken calls if not hasattr(self.__wrapped__, "__next__"): self.__wrapped__.__next__() - iteration_is_finished = False - status_code: Optional[trace_api.StatusCode] = None try: chunk: Any = self.__wrapped__.__next__() except Exception as exception: - iteration_is_finished = True - if isinstance(exception, StopIteration): - status_code = trace_api.StatusCode.OK - else: - status_code = trace_api.StatusCode.ERROR - self._self_with_span.record_exception(exception) + if not self._self_is_finished: + if isinstance(exception, StopIteration): + status_code = trace_api.StatusCode.OK + else: + status_code = trace_api.StatusCode.ERROR + self._self_with_span.record_exception(exception) + if status_code is trace_api.StatusCode.ERROR: + self._self_with_span.record_exception(exception) + self._finish_tracing(status_code=status_code) raise else: self._process_chunk(chunk) - status_code = trace_api.StatusCode.OK return chunk - finally: - if iteration_is_finished and not self._self_is_finished: - self._finish_tracing(status_code=status_code) def __aiter__(self) -> AsyncIterator[Any]: return self @@ -93,25 +90,20 @@ async def __anext__(self) -> Any: # pass through mistaken calls if not hasattr(self.__wrapped__, "__anext__"): self.__wrapped__.__anext__() - iteration_is_finished = False - status_code: Optional[trace_api.StatusCode] = None try: chunk: Any = await self.__wrapped__.__anext__() except Exception as exception: - iteration_is_finished = True - if isinstance(exception, StopAsyncIteration): - status_code = trace_api.StatusCode.OK - else: - status_code = trace_api.StatusCode.ERROR - self._self_with_span.record_exception(exception) + if not self._self_is_finished: + if isinstance(exception, StopAsyncIteration): + status_code = trace_api.StatusCode.OK + else: + status_code = trace_api.StatusCode.ERROR + self._self_with_span.record_exception(exception) + self._finish_tracing(status_code=status_code) raise else: self._process_chunk(chunk) - status_code = trace_api.StatusCode.OK return chunk - finally: - if iteration_is_finished and not self._self_is_finished: - self._finish_tracing(status_code=status_code) def _process_chunk(self, chunk: Any) -> None: if not self._self_iteration_count: