From 21215ad8728d6a69fd2552f929743e0f1d9aa7e1 Mon Sep 17 00:00:00 2001 From: Douglas Blank Date: Mon, 7 Oct 2024 14:09:19 -0400 Subject: [PATCH] Moved flush to after_call (#354) * Moved flush to after_call * pre-commit config in root is different from pre-commit in sdk/python --------- Co-authored-by: Douglas Blank --- .../src/opik/decorator/base_track_decorator.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sdks/python/src/opik/decorator/base_track_decorator.py b/sdks/python/src/opik/decorator/base_track_decorator.py index 5a759df08..efb01789c 100644 --- a/sdks/python/src/opik/decorator/base_track_decorator.py +++ b/sdks/python/src/opik/decorator/base_track_decorator.py @@ -41,6 +41,7 @@ def track( capture_input: bool = True, capture_output: bool = True, generations_aggregator: Optional[Callable[[List[Any]], Any]] = None, + flush: bool = False, ) -> Union[Callable, Callable[[Callable], Callable]]: """ Decorator to track the execution of a function. @@ -55,6 +56,7 @@ def track( capture_input: Whether to capture the input arguments. capture_output: Whether to capture the output result. generations_aggregator: Function to aggregate generation results. + flush: Whether to flush the client after logging. Returns: Callable: The decorated function(if used without parentheses) @@ -81,6 +83,7 @@ def track( capture_input=capture_input, capture_output=capture_output, generations_aggregator=generations_aggregator, + flush=flush, ) def decorator(func: Callable) -> Callable: @@ -93,6 +96,7 @@ def decorator(func: Callable) -> Callable: capture_input=capture_input, capture_output=capture_output, generations_aggregator=generations_aggregator, + flush=flush, ) return decorator @@ -107,6 +111,7 @@ def _decorate( capture_input: bool, capture_output: bool, generations_aggregator: Optional[Callable[[List[Any]], Any]], + flush: bool, ) -> Callable: if not inspect_helpers.is_async(func): return self._tracked_sync( @@ -118,6 +123,7 @@ def _decorate( capture_input=capture_input, capture_output=capture_output, generations_aggregator=generations_aggregator, + flush=flush, ) return self._tracked_async( @@ -129,6 +135,7 @@ def _decorate( capture_input=capture_input, capture_output=capture_output, generations_aggregator=generations_aggregator, + flush=flush, ) def _tracked_sync( @@ -141,6 +148,7 @@ def _tracked_sync( capture_input: bool, capture_output: bool, generations_aggregator: Optional[Callable[[List[Any]], str]], + flush: bool, ) -> Callable: @functools.wraps(func) def wrapper(*args, **kwargs) -> Any: # type: ignore @@ -179,6 +187,7 @@ def wrapper(*args, **kwargs) -> Any: # type: ignore self._after_call( output=result, capture_output=capture_output, + flush=flush, ) if result is not None: return result @@ -195,6 +204,7 @@ def _tracked_async( capture_input: bool, capture_output: bool, generations_aggregator: Optional[Callable[[List[Any]], str]], + flush: bool, ) -> Callable: @functools.wraps(func) async def wrapper(*args, **kwargs) -> Any: # type: ignore @@ -232,6 +242,7 @@ async def wrapper(*args, **kwargs) -> Any: # type: ignore self._after_call( output=result, capture_output=capture_output, + flush=flush, ) if result is not None: return result @@ -382,6 +393,7 @@ def _after_call( capture_output: bool, generators_span_to_end: Optional[span.SpanData] = None, generators_trace_to_end: Optional[trace.TraceData] = None, + flush: bool = False, ) -> None: try: if output is not None: @@ -413,6 +425,9 @@ def _after_call( client.trace(**trace_data_to_end.__dict__) + if flush: + client.flush() + except Exception as exception: LOGGER.error( logging_messages.UNEXPECTED_EXCEPTION_ON_SPAN_FINALIZATION_FOR_TRACKED_FUNCTION,