From fa3ce6f4b35766608b2cb77b813b6b14efb3729e Mon Sep 17 00:00:00 2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Fri, 19 Jan 2024 15:29:02 -0800 Subject: [PATCH] refactor: clean up stream exception logic (#123) --- .../instrumentation/openai/_stream.py | 38 +++++++------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_stream.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_stream.py index 34576a9c9..1a1c750cc 100644 --- a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_stream.py +++ b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_stream.py @@ -66,25 +66,20 @@ def __next__(self) -> Any: # pass through mistaken calls if not hasattr(self.__wrapped__, "__next__"): self.__wrapped__.__next__() - iteration_is_finished = False - status_code: Optional[trace_api.StatusCode] = None try: chunk: Any = self.__wrapped__.__next__() except Exception as exception: - iteration_is_finished = True - if isinstance(exception, StopIteration): - status_code = trace_api.StatusCode.OK - else: - status_code = trace_api.StatusCode.ERROR - self._self_with_span.record_exception(exception) + if not self._self_is_finished: + if isinstance(exception, StopIteration): + status_code = trace_api.StatusCode.OK + else: + status_code = trace_api.StatusCode.ERROR + self._self_with_span.record_exception(exception) + self._finish_tracing(status_code=status_code) raise else: self._process_chunk(chunk) - status_code = trace_api.StatusCode.OK return chunk - finally: - if iteration_is_finished and not self._self_is_finished: - self._finish_tracing(status_code=status_code) def __aiter__(self) -> AsyncIterator[Any]: return self @@ -93,25 +88,20 @@ async def __anext__(self) -> Any: # pass through mistaken calls if not hasattr(self.__wrapped__, "__anext__"): self.__wrapped__.__anext__() - iteration_is_finished = False - status_code: Optional[trace_api.StatusCode] = None try: chunk: Any = await self.__wrapped__.__anext__() except Exception as exception: - iteration_is_finished = True - if isinstance(exception, StopAsyncIteration): - status_code = trace_api.StatusCode.OK - else: - status_code = trace_api.StatusCode.ERROR - self._self_with_span.record_exception(exception) + if not self._self_is_finished: + if isinstance(exception, StopAsyncIteration): + status_code = trace_api.StatusCode.OK + else: + status_code = trace_api.StatusCode.ERROR + self._self_with_span.record_exception(exception) + self._finish_tracing(status_code=status_code) raise else: self._process_chunk(chunk) - status_code = trace_api.StatusCode.OK return chunk - finally: - if iteration_is_finished and not self._self_is_finished: - self._finish_tracing(status_code=status_code) def _process_chunk(self, chunk: Any) -> None: if not self._self_iteration_count: