From 90442ed830069dcd20150d99f7074709d57b18c1 Mon Sep 17 00:00:00 2001
From: Balasubramanyam Evani
Date: Mon, 16 Dec 2024 20:41:21 +0000
Subject: [PATCH] adds streaming support in Generator._post_call

---
 adalflow/adalflow/core/generator.py | 67 +++++++++++++++++++++++------
 1 file changed, 53 insertions(+), 14 deletions(-)

diff --git a/adalflow/adalflow/core/generator.py b/adalflow/adalflow/core/generator.py
index baedd8fb..e8d8864d 100644
--- a/adalflow/adalflow/core/generator.py
+++ b/adalflow/adalflow/core/generator.py
@@ -7,7 +7,16 @@
 import os
 from pathlib import Path
 
-from typing import Any, Dict, Optional, Union, Callable, Tuple, List
+from typing import (
+    Any,
+    Dict,
+    Optional,
+    Union,
+    Callable,
+    Tuple,
+    List,
+    Generator as GeneratorType,
+)
 
 import logging
 
@@ -304,24 +313,54 @@ def _extra_repr(self) -> str:
         s = f"model_kwargs={self.model_kwargs}, model_type={self.model_type}"
         return s
 
+    def _process_chunk(self, chunk: Any) -> Tuple[Any, Optional[str]]:
+        """Process a single chunk of data using the output processors.
+
+        Args:
+            chunk: Raw chunk data to process
+
+        Returns:
+            Tuple[Any, Optional[str]]: The processed chunk and an error
+            string; the error is None when processing succeeds.
+        """
+        if not chunk or not self.output_processors:
+            return chunk, None
+
+        try:
+            processed_data = self.output_processors(chunk)
+            return processed_data, None
+        except Exception as e:
+            log.error(f"Error processing chunk using the output processors: {e}")
+            return None, str(e)
+
     def _post_call(self, completion: Any) -> GeneratorOutput:
-        r"""Get string completion and process it with the output_processors."""
-        # parse chat completion will only fill the raw_response
-        output: GeneratorOutput = self.model_client.parse_chat_completion(completion)
-        # Now adding the data filed to the output
-        data = output.raw_response
-        if self.output_processors:
-            if data:
+        """Process completion output, handling both streaming and non-streaming cases.
+
+        Args:
+            completion: Raw completion data from the LLM provider
+
+        Returns:
+            GeneratorOutput with processed data, or one wrapping a stream of chunks
+        """
+        # parse_chat_completion only fills the raw_response
+        output = self.model_client.parse_chat_completion(completion)
+        # Streaming case: the parsed output is a generator of chunks
+        if isinstance(output, GeneratorType):
+
+            def process_stream():
                 try:
-                    data = self.output_processors(data)
-                    output.data = data
+                    for out in output:
+                        log.debug(f"Processing raw chunk: {out.raw_response}")
+                        out.data, out.error = self._process_chunk(out.raw_response)
+                        yield out
                 except Exception as e:
-                    log.error(f"Error processing the output processors: {e}")
-                    output.error = str(e)
+                    log.error(f"Error in stream processing: {e}")
+                    yield GeneratorOutput(error=str(e))
+            return GeneratorOutput(data=process_stream(), raw_response=output)
         else:
-            output.data = data
-
+            # Non-streaming case: process the full raw response once
+            output.data, output.error = self._process_chunk(output.raw_response)
         return output
 
     def _pre_call(self, prompt_kwargs: Dict, model_kwargs: Dict) -> Dict[str, Any]:
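
---
Usage sketch (illustrative, not part of the patch): with this change, when the
model client's parse_chat_completion returns a generator (i.e. the provider is
streaming), _post_call wraps it so that output.data is itself a generator of
processed GeneratorOutput chunks. A minimal consumer might look like the code
below; the OpenAIClient, the "stream" model kwarg, the model name, and the
default template's input_str variable are assumptions for illustration, not
confirmed by this diff.

    from types import GeneratorType

    from adalflow.core import Generator
    from adalflow.components.model_client import OpenAIClient

    # Assumed setup: the client forwards model_kwargs (including "stream")
    # to the provider, and parse_chat_completion yields GeneratorOutput
    # chunks when the response is streamed.
    generator = Generator(
        model_client=OpenAIClient(),
        model_kwargs={"model": "gpt-3.5-turbo", "stream": True},
    )
    output = generator(prompt_kwargs={"input_str": "Tell me a short joke."})

    if isinstance(output.data, GeneratorType):
        # Streaming path: each item is a GeneratorOutput whose data/error
        # fields were set per chunk by _process_chunk.
        for chunk in output.data:
            if chunk.error:
                break
            print(chunk.data, end="", flush=True)
    else:
        # Non-streaming path: data holds the fully processed completion.
        print(output.data)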