Skip to content

Commit

Permalink
Add json-fix tools and AI tool call objects
Browse files Browse the repository at this point in the history
  • Loading branch information
ejohb committed Sep 30, 2024
1 parent cf08026 commit 24d7ee6
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 7 deletions.
5 changes: 5 additions & 0 deletions fmtr/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@
except ImportError as exception:
ai = MissingExtraMockModule('ai', exception)

try:
from fmtr.tools import json_fix_tools as json_fix
except ImportError as exception:
json_fix = MissingExtraMockModule('json_fix', exception)


__all__ = [
'config',
Expand Down
53 changes: 49 additions & 4 deletions fmtr/tools/ai_tools.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import torch
from datetime import datetime
from peft import PeftConfig, PeftModel
from pydantic import Field
from statistics import mean
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List

from fmtr.tools import logger
from fmtr.tools import logger, data_modelling_tools
from fmtr.tools.hfh_tools import get_hf_cache_path

CPU = 'cpu'
Expand Down Expand Up @@ -319,6 +320,45 @@ def get_output(self, prompt, **kwargs):
return output


class ToolCall(data_modelling_tools.Base):
"""
Tool call data
"""
name: str = Field(
...,
description="The function name"
)
arguments: dict = Field(
...,
description="The function arguments"
)

def apply(self, functions):
"""
Apply the specified functions to their arguments
"""
functions = {function.__name__: function for function in functions}
function = functions[self.name]
obj = function(**self.arguments)
return obj


class ToolsCall(data_modelling_tools.Root):
"""
Tool calls data
"""
root: List[ToolCall]

def apply(self, functions):
objs = [child.apply(functions) for child in self.root]
return objs

def tst():
"""
Expand Down Expand Up @@ -353,15 +393,20 @@ def get_current_weather(location: str, format: str):
"""
return "It's 25 degrees and sunny!"

class BiTools(BulkInferenceManager):
class BulkInferenceManagerTools(BulkInferenceManager):
TOOLS = [get_current_weather]

prompt = "What's the weather like in Paris?"
prompts = [prompt]
manager = BiTools()
manager = BulkInferenceManagerTools()
gen = manager.get_outputs(prompts, max_new_tokens=200, do_sample=True, temperature=1.2, top_p=0.5, top_k=50)
texts = list(gen)
return texts

for text in texts:
objs = ToolsCall.from_json(text).apply(BulkInferenceManagerTools.TOOLS)
obj = objs[0]
print(obj)




Expand Down
39 changes: 39 additions & 0 deletions fmtr/tools/data_modelling_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from pydantic import BaseModel, RootModel


class MixinFromJson:

@classmethod
def from_json(cls, json_str):
"""
Error-tolerant deserialization
"""
from fmtr.tools import json_fix
data = json_fix.from_json(json_str, default={})

if type(data) is dict:
self = cls(**data)
else:
self = cls(data)

return self


class Base(BaseModel, MixinFromJson):
"""
Base model
"""
...


class Root(RootModel, MixinFromJson):
"""
Root (list) model
"""
...
20 changes: 20 additions & 0 deletions fmtr/tools/json_fix_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import json
import json_repair

from fmtr.tools.logging_tools import logger
from fmtr.tools.tools import Raise


def from_json(json_string, default=None):
"""
Error-tolerant JSON deserialization
"""
try:
return json_repair.loads(json_string)
except json.JSONDecodeError as exception:
if default is Raise:
raise exception
logger.warning(f'Deserialization failed {repr(exception)}: {json_string}')
return default
2 changes: 1 addition & 1 deletion fmtr/tools/version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.9.1
0.9.2
6 changes: 4 additions & 2 deletions requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
'netrc': ['tinynetrc'],
'hfh': ['huggingface_hub'],
'merging': ['deepmerge'],
'api': ['fastapi', 'uvicorn', 'logging'],
'ai': ['peft', 'transformers[sentencepiece]', 'torchvision', 'torchaudio']
'api': ['fastapi', 'uvicorn', 'logging', 'dm'],
'ai': ['peft', 'transformers[sentencepiece]', 'torchvision', 'torchaudio', 'dm'],
'dm': ['pydantic'],
'json-fix': ['json_repair']
}

CONSOLE_SCRIPTS = [
Expand Down

0 comments on commit 24d7ee6

Please sign in to comment.