From 24d7ee654c01f039cc22bbf1b30a33d4b9a48dcf Mon Sep 17 00:00:00 2001 From: Edward Brown Date: Mon, 30 Sep 2024 13:33:08 +0100 Subject: [PATCH] Add json-fix tools and AI tool call objects --- fmtr/tools/__init__.py | 5 +++ fmtr/tools/ai_tools.py | 53 +++++++++++++++++++++++++++--- fmtr/tools/data_modelling_tools.py | 39 ++++++++++++++++++++++ fmtr/tools/json_fix_tools.py | 20 +++++++++++ fmtr/tools/version | 2 +- requirements.py | 6 ++-- 6 files changed, 118 insertions(+), 7 deletions(-) create mode 100644 fmtr/tools/data_modelling_tools.py create mode 100644 fmtr/tools/json_fix_tools.py diff --git a/fmtr/tools/__init__.py b/fmtr/tools/__init__.py index 69c1432..e57006e 100644 --- a/fmtr/tools/__init__.py +++ b/fmtr/tools/__init__.py @@ -101,6 +101,11 @@ except ImportError as exception: ai = MissingExtraMockModule('ai', exception) +try: + from fmtr.tools import json_fix_tools as json_fix +except ImportError as exception: + json_fix = MissingExtraMockModule('json_fix', exception) + __all__ = [ 'config', diff --git a/fmtr/tools/ai_tools.py b/fmtr/tools/ai_tools.py index 7f2b137..a27df4e 100644 --- a/fmtr/tools/ai_tools.py +++ b/fmtr/tools/ai_tools.py @@ -1,11 +1,12 @@ import torch from datetime import datetime from peft import PeftConfig, PeftModel +from pydantic import Field from statistics import mean from transformers import AutoModelForCausalLM, AutoTokenizer from typing import List -from fmtr.tools import logger +from fmtr.tools import logger, data_modelling_tools from fmtr.tools.hfh_tools import get_hf_cache_path CPU = 'cpu' @@ -319,6 +320,45 @@ def get_output(self, prompt, **kwargs): return output +class ToolCall(data_modelling_tools.Base): + """ + + Tool call data + + """ + name: str = Field( + ..., + description="The function name" + ) + arguments: dict = Field( + ..., + description="The function arguments" + ) + + def apply(self, functions): + """ + + Apply the specified functions to their arguments + + """ + functions = {function.__name__: function for function in functions} + function = functions[self.name] + obj = function(**self.arguments) + return obj + + +class ToolsCall(data_modelling_tools.Root): + """ + + Tool calls data + + """ + root: List[ToolCall] + + def apply(self, functions): + objs = [child.apply(functions) for child in self.root] + return objs + def tst(): """ @@ -353,15 +393,20 @@ def get_current_weather(location: str, format: str): """ return "It's 25 degrees and sunny!" - class BiTools(BulkInferenceManager): + class BulkInferenceManagerTools(BulkInferenceManager): TOOLS = [get_current_weather] prompt = "What's the weather like in Paris?" prompts = [prompt] - manager = BiTools() + manager = BulkInferenceManagerTools() gen = manager.get_outputs(prompts, max_new_tokens=200, do_sample=True, temperature=1.2, top_p=0.5, top_k=50) texts = list(gen) - return texts + + for text in texts: + objs = ToolsCall.from_json(text).apply(BulkInferenceManagerTools.TOOLS) + obj = objs[0] + print(obj) + diff --git a/fmtr/tools/data_modelling_tools.py b/fmtr/tools/data_modelling_tools.py new file mode 100644 index 0000000..5cc70ef --- /dev/null +++ b/fmtr/tools/data_modelling_tools.py @@ -0,0 +1,39 @@ +from pydantic import BaseModel, RootModel + + +class MixinFromJson: + + @classmethod + def from_json(cls, json_str): + """ + + Error-tolerant deserialization + + """ + from fmtr.tools import json_fix + data = json_fix.from_json(json_str, default={}) + + if type(data) is dict: + self = cls(**data) + else: + self = cls(data) + + return self + + +class Base(BaseModel, MixinFromJson): + """ + + Base model + + """ + ... + + +class Root(RootModel, MixinFromJson): + """ + + Root (list) model + + """ + ... diff --git a/fmtr/tools/json_fix_tools.py b/fmtr/tools/json_fix_tools.py new file mode 100644 index 0000000..890cf4f --- /dev/null +++ b/fmtr/tools/json_fix_tools.py @@ -0,0 +1,20 @@ +import json +import json_repair + +from fmtr.tools.logging_tools import logger +from fmtr.tools.tools import Raise + + +def from_json(json_string, default=None): + """ + + Error-tolerant JSON deserialization + + """ + try: + return json_repair.loads(json_string) + except json.JSONDecodeError as exception: + if default is Raise: + raise exception + logger.warning(f'Deserialization failed {repr(exception)}: {json_string}') + return default diff --git a/fmtr/tools/version b/fmtr/tools/version index f514a2f..f76f913 100644 --- a/fmtr/tools/version +++ b/fmtr/tools/version @@ -1 +1 @@ -0.9.1 \ No newline at end of file +0.9.2 \ No newline at end of file diff --git a/requirements.py b/requirements.py index 526a25a..00d0fb7 100644 --- a/requirements.py +++ b/requirements.py @@ -20,8 +20,10 @@ 'netrc': ['tinynetrc'], 'hfh': ['huggingface_hub'], 'merging': ['deepmerge'], - 'api': ['fastapi', 'uvicorn', 'logging'], - 'ai': ['peft', 'transformers[sentencepiece]', 'torchvision', 'torchaudio'] + 'api': ['fastapi', 'uvicorn', 'logging', 'dm'], + 'ai': ['peft', 'transformers[sentencepiece]', 'torchvision', 'torchaudio', 'dm'], + 'dm': ['pydantic'], + 'json-fix': ['json_repair'] } CONSOLE_SCRIPTS = [