diff --git a/skynet/modules/ttt/summaries/app.py b/skynet/modules/ttt/summaries/app.py index 12cae9a..9aeccd5 100644 --- a/skynet/modules/ttt/summaries/app.py +++ b/skynet/modules/ttt/summaries/app.py @@ -12,7 +12,6 @@ from .jobs import start_monitoring_jobs from .persistence import db -from .processor import initialize as initialize_summaries from .v1.router import router as v1_router @@ -57,7 +56,6 @@ async def executor_startup(): initialize_openai_api() - initialize_summaries() log.info('summaries:executor module initialized') await db.initialize() diff --git a/skynet/modules/ttt/summaries/processor.py b/skynet/modules/ttt/summaries/processor.py index a4b7d75..5a1ddb6 100644 --- a/skynet/modules/ttt/summaries/processor.py +++ b/skynet/modules/ttt/summaries/processor.py @@ -16,7 +16,6 @@ from .prompts.summary import summary_conversation, summary_emails, summary_meeting, summary_text from .v1.models import DocumentPayload, HintType, JobType -llm = None log = get_logger(__name__) @@ -36,10 +35,8 @@ } -def initialize(): - global llm - - llm = ChatOpenAI( +def get_local_llm(**kwargs): + return ChatOpenAI( model=llama_path, api_key='placeholder', # use a placeholder value to bypass validation, and allow the custom base url to be used base_url=f'{openai_api_base_url}/v1', @@ -47,11 +44,12 @@ def initialize(): frequency_penalty=1, max_retries=0, temperature=0, + **kwargs, ) async def process(payload: DocumentPayload, job_type: JobType, model: ChatOpenAI = None) -> str: - current_model = model or llm + current_model = model or get_local_llm(max_tokens=payload.max_tokens) chain = None text = payload.text @@ -99,6 +97,7 @@ async def process(payload: DocumentPayload, job_type: JobType, model: ChatOpenAI async def process_open_ai(payload: DocumentPayload, job_type: JobType, api_key: str, model_name=None) -> str: llm = ChatOpenAI( api_key=api_key, + max_tokens=payload.max_tokens, model_name=model_name, temperature=0, ) @@ -114,6 +113,7 @@ async def process_azure( api_version=azure_openai_api_version, azure_endpoint=endpoint, azure_deployment=deployment_name, + max_tokens=payload.max_tokens, temperature=0, ) diff --git a/skynet/modules/ttt/summaries/v1/models.py b/skynet/modules/ttt/summaries/v1/models.py index a7fbe20..e4ae3c3 100644 --- a/skynet/modules/ttt/summaries/v1/models.py +++ b/skynet/modules/ttt/summaries/v1/models.py @@ -19,6 +19,7 @@ class Priority(Enum): class DocumentPayload(BaseModel): text: str hint: HintType = HintType.MEETING + max_tokens: int = 0 priority: Priority = Priority.NORMAL prompt: str | None = None @@ -27,7 +28,8 @@ class DocumentPayload(BaseModel): 'examples': [ { 'text': 'Your text here', - 'hint': 'text', + 'hint': 'meeting', + 'max_tokens': 0, 'priority': 'normal', 'prompt': 'Summarize the following text', }