diff --git a/mknodes/templatenodes/mkllm/__init__.py b/mknodes/templatenodes/mkllm/__init__.py index 4c8fc244..7145ed29 100644 --- a/mknodes/templatenodes/mkllm/__init__.py +++ b/mknodes/templatenodes/mkllm/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations +import functools from typing import Any, TYPE_CHECKING from jinjarope import llmfilters @@ -17,6 +18,16 @@ logger = log.get_logger(__name__) +@functools.cache +def complete_llm(user_prompt: str, system_prompt: str, model: str, context: str) -> str: + return llmfilters.llm_complete( + user_prompt, + system_prompt, + model=model, + context=context, + ) + + class MkLlm(mktext.MkText): """Node for LLM-based text generation.""" @@ -88,7 +99,7 @@ def text(self) -> str: "\n".join(filter(None, [self._context, *context_items])) or None ) - return llmfilters.llm_complete( + return complete_llm( self.user_prompt, self.system_prompt, model=self._model,