From 094a454f2d8faa930026de422979dc86a479470a Mon Sep 17 00:00:00 2001 From: yashbonde Date: Tue, 12 Mar 2024 22:33:32 +0530 Subject: [PATCH] [chore] add secrets + improved tests + expanded types --- .env.sample | 29 ++ .gitignore | 3 +- api_docs/conf.py | 2 +- chainfury/__init__.py | 10 +- chainfury/base.py | 306 +++++++++++++++-- chainfury/chat.py | 431 ------------------------ chainfury/cli.py | 212 ++++++------ chainfury/components/const.py | 4 +- chainfury/components/openai/__init__.py | 23 +- chainfury/components/tune/__init__.py | 134 ++++++-- chainfury/types.py | 26 +- chainfury/utils.py | 4 +- chainfury/version.py | 2 +- extra/ex1/chain.py | 78 ----- pyproject.toml | 3 +- server/chainfury_server/__main__.py | 26 +- server/chainfury_server/api/chains.py | 23 +- server/chainfury_server/api/user.py | 136 +++++--- server/chainfury_server/app.py | 8 +- server/chainfury_server/database.py | 12 + server/chainfury_server/engine.py | 20 +- server/chainfury_server/utils.py | 65 +++- server/chainfury_server/version.py | 2 +- server/pyproject.toml | 5 +- tests/__main__.py | 8 - tests/main.py | 15 + tests/test_base_chain2.py | 61 ++++ tests/{base.py => test_base_types.py} | 75 ++++- tests/{getkv.py => test_getkv.py} | 0 29 files changed, 924 insertions(+), 799 deletions(-) create mode 100644 .env.sample delete mode 100644 chainfury/chat.py delete mode 100644 extra/ex1/chain.py delete mode 100644 tests/__main__.py create mode 100644 tests/main.py create mode 100644 tests/test_base_chain2.py rename tests/{base.py => test_base_types.py} (59%) rename tests/{getkv.py => test_getkv.py} (100%) diff --git a/.env.sample b/.env.sample new file mode 100644 index 0000000..728f648 --- /dev/null +++ b/.env.sample @@ -0,0 +1,29 @@ +# chainfury server +# ================ +# These are the environment variables that are used by the chainfury_server +# For chainfury jump below to the chainfury section + +# Required +# -------- + +# URL to the database for chainfury server, uses 
sqlalchemy, so most things should work +CFS_DATABASE="db_drivers://username:password@host:port/db_name" + +# (once in a lifetime) secret string for creating the JWT secrets +JWT_SECRET="secret" + +# (once in a lifetime) password to store the user secrets +CFS_SECRETS_PASSWORD="password" + +# chainfury +# ========= +# These are the environment variables that are used by the chainfury + +# To store all the file and data in the chainfury server +CF_FOLDER="~/cf" + +# (client mode) the URL for the chainfury server +CF_URL="" + +# (client mode) the token for the chainfury server +CF_TOKEN="" diff --git a/.gitignore b/.gitignore index be52001..00613c4 100644 --- a/.gitignore +++ b/.gitignore @@ -142,7 +142,7 @@ langflow dunmp.rdb *.ipynb server/chainfury_server/stories/fury.json -notebooks/* +notebooks stories/fury.json workers/ private.sh @@ -153,3 +153,4 @@ demo/ logs.py chunker/ chainfury/chains/ +gosrc/ diff --git a/api_docs/conf.py b/api_docs/conf.py index 0570017..8731d75 100644 --- a/api_docs/conf.py +++ b/api_docs/conf.py @@ -14,7 +14,7 @@ project = "ChainFury" copyright = "2023, NimbleBox Engineering" author = "NimbleBox Engineering" -release = "1.7.0a1" +release = "1.7.0a2" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/chainfury/__init__.py b/chainfury/__init__.py index 7763832..561671f 100644 --- a/chainfury/__init__.py +++ b/chainfury/__init__.py @@ -21,6 +21,8 @@ Chain, Model, Edge, + Tools, + Action, ) from chainfury.core import ( model_registry, @@ -31,11 +33,11 @@ Memory, ) from chainfury.client import get_client -from chainfury.chat import ( +from chainfury.types import ( Message, - Chat, - TuneChats, - TuneDataset, + Thread, + ThreadsList, + Dataset, human, system, assistant, diff --git a/chainfury/base.py b/chainfury/base.py index 1074696..04f1617 100644 --- a/chainfury/base.py +++ b/chainfury/base.py @@ -9,15 +9,17 @@ 
import importlib import traceback from pprint import pformat +from functools import partial from typing import Any, Union, Optional, Dict, List, Tuple, Callable, Generator from collections import deque, defaultdict, Counter import jinja2schema as j2s from jinja2schema import model as j2sm +from tuneapi.utils import load_module_from_path, to_json, from_json + from chainfury.utils import logger, terminal_top_with_text import chainfury.types as T -from chainfury import chat class Secret(str): @@ -26,6 +28,9 @@ class Secret(str): def __init__(self, value=""): self.value = value + def has_value(self) -> bool: + return self.value != "" + # # Vars: this is the base class for all the fields that the user can provide from the front end @@ -35,7 +40,7 @@ def __init__(self, value=""): class Var: def __init__( self, - type: Union[str, List["Var"]], + type: Union[str, List["Var"]] = "", format: str = "", items: List["Var"] = [], additionalProperties: Union[List["Var"], "Var"] = [], @@ -45,6 +50,7 @@ def __init__( placeholder: str = "", show: bool = False, name: str = "", + description: str = "", *, loc: Optional[Tuple] = (), ): @@ -62,6 +68,9 @@ def __init__( name (str, optional): The name of this field. Defaults to "". loc (Optional[Tuple], optional): The location of this field. Defaults to (). 
""" + if not type: + raise ValueError("type cannot be empty") + self.type = type self.format = format self.items = items or [] @@ -72,6 +81,7 @@ def __init__( self.placeholder = placeholder self.show = show self.name = name + self.description = description # self.value = None self.loc = loc # this is the location from which this value is extracted @@ -116,6 +126,8 @@ def to_dict(self) -> Dict[str, Any]: d["name"] = self.name if self.loc: d["loc"] = self.loc + if self.description: + d["description"] = self.description return d @classmethod @@ -138,6 +150,7 @@ def from_dict(cls, d: Dict[str, Any]) -> "Var": show_val = d.get("show", False) name_val = d.get("name", "") loc_val = d.get("loc", ()) + description_val = d.get("description", "") if isinstance(type_val, list): type_val = [ @@ -163,6 +176,7 @@ def from_dict(cls, d: Dict[str, Any]) -> "Var": placeholder=placeholder_val, show=show_val, name=name_val, + description=description_val, loc=loc_val, ) return var @@ -740,16 +754,27 @@ def __call__(self, model_data: Dict[str, Any]) -> Tuple[Any, Optional[Exception] except Exception as e: return traceback.format_exc(), e + def set_api_token(self, token: str) -> None: + raise NotImplementedError( + f"set_api_token method is not implemented for {self.id}" + ) + def completion(self, prompt: str, **kwargs): """Subclass and implement your own text completion API""" return NotImplementedError( f"completion method is not implemented for {self.id}" ) - def chat(self, chat: chat.Chat, **kwargs): + def chat(self, chat: T.Thread, **kwargs): """Subclass and implement your own chat API""" raise NotImplementedError("chat method is not implemented for this model") + def stream_chat(self, chat: T.Thread, **kwargs): + """Subclass and implement your own chat API""" + raise NotImplementedError( + "stream_chat method is not implemented for this model" + ) + # # Node: Each box that is drag and dropped in the UI is a Node, it will tell what kind of things are @@ -981,12 +1006,12 @@ def 
__call__( @classmethod def from_chat( cls, - chat: chat.Chat, + thread: T.Thread, node_id: str, model: Model, description: Optional[str] = None, ) -> "Node": - chat_dict = chat.to_dict() + chat_dict = thread.to_dict() # print(variables) fields = [] templates = [] @@ -996,7 +1021,7 @@ def from_chat( obj = get_value_by_keys(chat_dict, field[0]) if not obj: raise ValueError( - f"Field {field[0]} not found in {chat}, but was extraced. There is a bug in get_value_by_keys function" + f"Field {field[0]} not found in {thread}, but was extraced. There is a bug in get_value_by_keys function" ) templates.append((obj, jinja2.Template(obj), field[0])) @@ -1124,7 +1149,11 @@ def __init__( self.edges = edges if len(self.nodes) == 1: - assert len(self.edges) == 0, "Cannot have edges with only 1 node" + if len(self.edges) != 0: + logger.error(f"Got only one node: {self.nodes.keys()=}") + raise ValueError( + f"Cannot have edges with only 1 node. Got {self.edges}" + ) self.topo_order = [next(iter(self.nodes))] else: self.topo_order = topological_sort(self.edges) @@ -1134,6 +1163,7 @@ def __init__( if self.is_empty: # there is nothing to do here + logger.info("This is empty chain") return if "/" not in main_out: @@ -1185,7 +1215,7 @@ def __repr__(self) -> str: def add_thread( self, node_id: str, - chat: chat.Chat, + thread: T.Thread, model: Optional[Model] = None, description: str = "", ) -> "Chain": @@ -1197,30 +1227,35 @@ def add_thread( # build the node node = Node.from_chat( - chat, + thread, node_id=node_id, model=model or self.default_model, # type: ignore description=description, ) + logger.debug(f"Adding node (total nodes {len(self.nodes)}): {node.id=}") + # add edges as required for var in node.fields: if var.name in self.nodes: - self.edges.append( - Edge( - src_node_id=var.name, - src_node_var=var.name, - trg_node_id=node_id, - trg_node_var=self.nodes[var.name].outputs[0].name, - ) + e = Edge( + src_node_id=var.name, + src_node_var=var.name, + trg_node_id=node_id, + 
trg_node_var=self.nodes[var.name].outputs[0].name, ) + logger.debug(f"Adding (total edges {len(self.edges)}) {e=}") + self.edges.append(e) # assign the node self.nodes[node.id] = node # topo sort if len(self.nodes) == 1: - assert len(self.edges) == 0, "Cannot have edges with only 1 node" + if len(self.edges) != 0: + raise ValueError( + f"Cannot have edges with only 1 node. Got {self.edges}" + ) self.topo_order = [next(iter(self.nodes))] else: self.topo_order = topological_sort(self.edges) @@ -1339,9 +1374,9 @@ def to_dag(self) -> T.Dag: nodes = [] for i, node in enumerate(self.nodes.values()): nodes.append( - T.FENode( + T.UINode( id=node.id, - position=T.FENode.Position( + position=T.UINode.Position( x=i * 100, y=i * 100, ), @@ -1349,13 +1384,13 @@ def to_dag(self) -> T.Dag: width=100, height=100, selected=False, - position_absolute=T.FENode.Position( + position_absolute=T.UINode.Position( x=i * 100, y=i * 100, ), dragging=False, cf_id=node.id, - cf_data=T.FENode.CFData( + cf_data=T.UINode.CFData( id=node.id, type=node.type, node=node.to_dict(), @@ -1806,6 +1841,235 @@ def stream( yield out, True +# +# Tools: A new abstraction for AGI +# + + +class Action: + def __init__( + self, + name: str, + description: str, + properties: Optional[Dict[str, Any]] = {}, + required: Optional[List[str]] = [], + fn: Optional[Callable] = None, + fn_meta: Optional[Dict[str, Any]] = None, + ): + self.name = name + self.description = description + self.properties = properties + self.required = required + + if (fn is None and fn_meta is None) or (fn is not None and fn_meta is not None): + raise ValueError("Either fn or fn_meta is required") + if fn_meta is not None: + # first validate that the line content is same in the two files, then try to load the item + if not os.path.exists(fn_meta["file"]): + raise ValueError(f"File {fn_meta['file']} does not exist") + with open(fn_meta["file"], "r") as f: + for i, l in enumerate(f): + if i == fn_meta["line"] and l != fn_meta["line_val"]: + 
raise ValueError( + f"Line #{fn_meta['line']} does not match in {fn_meta['file']}\n" + f" Expected: {l}\n" + f" Found: {fn_meta['line_val']}" + ) + fn = load_module_from_path(fn_meta["name"], fn_meta["file"]) + elif fn is not None: + fn_meta = { + "file": inspect.getfile(fn), + "line": inspect.getsourcelines(fn)[1] - 1, + "line_val": inspect.getsourcelines(fn)[0][0], + "name": fn.__name__, + } + + self.fn_meta = fn_meta + self.fn: Callable = fn # type: ignore + + def __repr__(self) -> str: + return f"[Action] {self.name}: {self.description}" + + # ser/deser + + def to_dict(self) -> Dict[str, Any]: + """Serializes the action to a dictionary. + + Returns: + Dict[str, Any]: The dictionary representation of the action. + """ + return { + "name": self.name, + "description": self.description, + "required": self.required, + "properties": self.properties, + "fn_meta": self.fn_meta, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Action": + """Deserializes the action from a dictionary. + + Args: + data (Dict[str, Any]): The dictionary representation of the action. + """ + return cls( + name=data["name"], + description=data["description"], + required=data["required"], + properties=data["properties"], + fn_meta=data["fn_meta"], + ) + + def to_json(self, indent=0, tight=True) -> str: + """Serializes the action to a JSON string. + + Returns: + str: The JSON string representation of the action. + """ + return to_json(self.to_dict(), indent=indent, tight=tight) + + @classmethod + def from_json(cls, data: str) -> "Action": + """Creates an action from a JSON string. + + Args: + data (str): The JSON string representation of the action. 
+ """ + return cls.from_dict(json.loads(data)) + + def __call__(self, *args, **kwargs): + # validate the data is in + return self.fn(*args, **kwargs) + + # usage + + def to_fn(self) -> Dict[str, Any]: + data = self.to_dict() + data.pop("fn_meta") + return data + + +class Tools: + """ + Usage: + + >>> from chainfury import Tool, Var + >>> my_tool = Tool("My Tool", "This is a test tool") + >>> @my_tool( + ... description = "this is test action", + ... props = { + ... "a": Var("int", "number 1"), + ... "b": Var("int", "number 2"), + ... "secret": Var("int", secret = True, key="MY_ENV_VAR"), # to implement + ... } + ... ) + ... def add_two_numbers(a: int, b: int, secret: int): + ... return (a + b) * secret + ... + >>> my_tool.to_json(indent = 2) + { + "name": "add_two_numbers", + ... + } + """ + + def __init__(self, name: str, description: str): + self.name = name + self.description = description + + # + self.actions: Dict[str, Action] = {} + + def __repr__(self) -> str: + return f"[Tool] {self.name}: {self.description}" + + def _register_action( + self, + fn: Callable, + description: str, + properties: Dict[str, Var], + name: Optional[str] = None, + ) -> Action: + name = name or fn.__name__ + props = {} + required = [] + for k, v in properties.items(): + props[k] = { + "type": v.type, + "description": v.description, + } + if v.required: + required.append(k) + + self.actions[name] = Action( + name=name, + description=description, + properties=props, + required=required, + fn=fn, + ) + return self.actions[name] + + def add(self, description: str, properties: Dict[str, Var] = {}) -> Action: + """ + Register the actions + """ + return partial( + self._register_action, + description=description, + properties=properties, + ) # type: ignore + + # ser/deser + + def to_dict(self) -> Dict[str, Any]: + """ + Register the actions + """ + return { + "name": self.name, + "description": self.description, + "actions": {k: v.to_dict() for k, v in self.actions.items()}, + } + + 
@classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Tools": + """Deserializes the Tool from a dictionary. + + Args: + data (Dict[str, Any]): The dictionary representation of the chain. + """ + self = cls( + name=data["name"], + description=data["description"], + ) + actions = data.get("actions", {}) + for a in actions.values(): + action = Action.from_dict(a) + self.actions[action.name] = action + + return self + + def to_json(self, indent=2, tight=False) -> str: + """Serializes the Tool to a JSON string. + + Returns: + str: The JSON string representation of the chain. + """ + return to_json(self.to_dict(), indent=indent, tight=tight) + + @classmethod + def from_json(cls, data: str) -> "Tools": + """ + Register the actions + """ + return cls.from_dict(json.loads(data)) + + def to_fn(self) -> Dict[str, Any]: + return {"name": self.name, "description": self.description, "properties": {}} + + # # helper functions # diff --git a/chainfury/chat.py b/chainfury/chat.py deleted file mode 100644 index 7ebfe9d..0000000 --- a/chainfury/chat.py +++ /dev/null @@ -1,431 +0,0 @@ -# Copyright © 2023- Frello Technology Private Limited - -import os -import json -import random -from functools import partial -from collections.abc import Iterable -from typing import Dict, List, Any, Tuple, Optional, Generator - -from chainfury.utils import to_json, get_random_string, logger - - -class Message: - SYSTEM = "system" - HUMAN = "human" - GPT = "gpt" - VALUE = "value" - FUNCTION = "function" - FUNCTION_RESPONSE = "function-response" - - # start initialization here - def __init__(self, value: str | float, role: str): - if role in ["system", "sys"]: - role = self.SYSTEM - elif role in ["user", "human"]: - role = self.HUMAN - elif role in ["gpt", "assistant", "machine"]: - role = self.GPT - elif role in ["value"]: - role = self.VALUE - elif role in ["function", "fn"]: - role = self.FUNCTION - elif role in ["function-response", "fn-resp"]: - role = self.FUNCTION_RESPONSE - else: - raise 
ValueError(f"Unknown role: {role}") - if value is None: - raise ValueError("value cannot be None") - - self.role = role - self.value = value - self._unq_value = get_random_string(6) - - def __str__(self) -> str: - try: - idx = max(os.get_terminal_size().columns - len(self.role) - 40, 10) - except OSError: - idx = 50 - return f"<{self.role}: {json.dumps(self.value)[:idx]}>" - - def __repr__(self) -> str: - return str(self.value) - - def to_dict(self, ft: bool = False): - """ - if `ft` then export to following format: `{"from": "system/human/gpt", "value": "..."}` - else export to following format: `{"role": "system/user/assistant", "content": "..."}` - """ - role = self.role - if not ft: - if self.role == self.HUMAN: - role = "user" - elif self.role == self.GPT: - role = "assistant" - - chat_message: Dict[str, str | float] - if ft: - chat_message = {"from": role} - else: - chat_message = {"role": role} - - if not ft: - chat_message["content"] = self.value - else: - chat_message["value"] = self.value - return chat_message - - @classmethod - def from_dict(cls, data): - return cls( - value=data.get("value") or data.get("content"), - role=data.get("from") or data.get("role"), - ) # type: ignore - - -### Aliases -human = partial(Message, role=Message.HUMAN) -system = partial(Message, role=Message.SYSTEM) -assistant = partial(Message, role=Message.GPT) - - -class Chat: - """ - If the last Message is a "value" then a special tag "koro.regression"="true" is added to the meta. 
- - Args: - chats (List[Message]): List of chat messages - jl (Dict[str, Any]): Optional json-logic - """ - - def __init__( - self, - chats: List[Message], - jl: Optional[Dict[str, Any]] = None, - model: Optional[str] = None, - **kwargs, - ): - self.chats = chats - self.jl = jl - self.model = model - - # check for regression - if self.chats[-1].role == Message.VALUE: - kwargs["koro.regression"] = True - - kwargs = {k: v for k, v in sorted(kwargs.items())} - self.meta = kwargs - self.keys = list(kwargs.keys()) - self.values = tuple(kwargs.values()) - - # avoid special character BS. - assert not any(["=" in x or "&" in x for x in self.keys]) - if self.values: - assert all([type(x) in [int, str, float, bool] for x in self.values]) - - self.value_hash = hash(self.values) - - def __repr__(self) -> str: - x = " Any: - if __name in self.meta: - return self.meta[__name] - raise AttributeError(f"Attribute {__name} not found") - - # ser/deser - - def to_dict(self, full: bool = False): - if full: - return { - "chats": [x.to_dict() for x in self.chats], - "jl": self.jl, - "model": self.model, - "meta": self.meta, - } - return { - "chats": [x.to_dict() for x in self.chats], - } - - def to_chat_template(self): - return self.to_dict()["chats"] - - @classmethod - def from_dict(cls, data: Dict[str, Any]): - chats = data.get("chats", None) or data.get("conversations", None) - if not chats: - raise ValueError("No chats found") - return cls( - chats=[Message.from_dict(x) for x in chats], - jl=data.get("jl"), - model=data.get("model"), - **data.get("meta", {}), - ) - - def to_ft( - self, id: Any = None, drop_last: bool = False - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - chats = self.chats if not drop_last else self.chats[:-1] - ft_dict = { - "id": id or get_random_string(6), - "conversations": [x.to_dict(ft=True) for x in chats], - } - if drop_last: - ft_dict["last"] = self.chats[-1].to_dict(ft=True) - return ft_dict, self.meta - - # modifications - - def copy(self) -> "Chat": - 
return Chat( - chats=[x for x in self.chats], - jl=self.jl, - model=self.model, - **self.meta, - ) - - def add(self, message: Message): - self.chats.append(message) - - -# these are the classes that we use for tune datasets from r-stack - - -class TuneChats(list): - """This class implements some basic container methods for a list of Chat objects""" - - def __init__(self): - self.keys = {} - self.items: List[Chat] = [] - self.idx_dict: Dict[int, Tuple[Any, ...]] = {} - self.key_to_items_idx: Dict[int, List[int]] = {} - - def __repr__(self) -> str: - return ( - f"TuneChats(unq_keys={len(self.key_to_items_idx)}, items={len(self.items)})" - ) - - def __len__(self) -> int: - return len(self.items) - - def __iter__(self) -> Generator[Chat, None, None]: - for x in self.items: - yield x - - def stream(self) -> Generator[Chat, None, None]: - for x in self: - yield x - - def __getitem__(self, __index) -> List[Chat]: - return self.items[__index] - - def table(self) -> str: - try: - from tabulate import tabulate - except ImportError: - raise ImportError("Install tabulate to use this method") - - table = [] - for k, v in self.idx_dict.items(): - table.append( - [ - *v, - len(self.key_to_items_idx[k]), - f"{len(self.key_to_items_idx[k])/len(self)*100:0.2f}%", - ] - ) - return tabulate(table, headers=[*list(self.keys), "count", "percentage"]) - - # data manipulation - - def append(self, __object: Any) -> None: - if not self.items: - self.keys = __object.meta.keys() - if self.keys != __object.meta.keys(): - raise ValueError("Keys should match") - self.idx_dict.setdefault(__object.value_hash, __object.values) - self.key_to_items_idx.setdefault(__object.value_hash, []) - self.key_to_items_idx[__object.value_hash].append(len(self.items)) - self.items.append(__object) - - def add(self, x: Chat): - return self.append(x) - - def extend(self, __iterable: Iterable) -> None: - if hasattr(__iterable, "items"): - for x in __iterable.items: # type: ignore - self.append(x) - elif 
isinstance(__iterable, Iterable): - for x in __iterable: - self.append(x) - else: - raise ValueError("Unknown iterable") - - def shuffle(self, seed: Optional[int] = None) -> None: - """Perform in place shuffle""" - # shuffle using indices, self.items and self.key_to_items_idx - idx = list(range(len(self.items))) - if seed: - rng = random.Random(seed) - rng.shuffle(idx) - else: - random.shuffle(idx) - self.items = [self.items[i] for i in idx] - self.key_to_items_idx = {} - for i, x in enumerate(self.items): - self.key_to_items_idx.setdefault(x.value_hash, []) - self.key_to_items_idx[x.value_hash].append(i) - - def create_te_split(self, test_items: int | float = 0.1) -> Tuple["TuneChats", ...]: - try: - import numpy as np - except ImportError: - raise ImportError("Install numpy to use `create_te_split` method") - - train_ds = TuneChats() - eval_ds = TuneChats() - items_np_arr = np.array(self.items) - for k, v in self.key_to_items_idx.items(): - if isinstance(test_items, float): - if int(len(v) * test_items) < 1: - raise ValueError( - f"Test percentage {test_items} is too high for the dataset key '{k}'" - ) - split_ids = random.sample(v, int(len(v) * test_items)) - else: - if test_items > len(v): - raise ValueError( - f"Test items {test_items} is too high for the dataset key '{k}'" - ) - split_ids = random.sample(v, test_items) - - # get items - eval_items = items_np_arr[split_ids] - train_items = items_np_arr[np.setdiff1d(v, split_ids)] - train_ds.extend(train_items) - eval_ds.extend(eval_items) - - return train_ds, eval_ds - - # ser / deser - - def to_dict(self): - return {"items": [x.to_dict() for x in self.items]} - - @classmethod - def from_dict(cls, data): - bench_dataset = cls() - for item in data["items"]: - bench_dataset.append(Chat.from_dict(item)) - return bench_dataset - - def to_disk(self, folder: str, fmt: Optional[str] = None): - if fmt: - logger.warn( - f"exporting to {fmt} format, you cannot recreate the dataset from this." 
- ) - os.makedirs(folder) - with open(f"{folder}/tuneds.jsonl", "w") as f: - for sample in self.items: - if fmt == "sharegpt": - item, _ = sample.to_ft() - elif fmt is None: - item = sample.to_dict() - else: - raise ValueError(f"Unknown format: {fmt}") - f.write(to_json(item, tight=True) + "\n") # type: ignore - - @classmethod - def from_disk(cls, folder: str): - bench_dataset = cls() - with open(f"{folder}/tuneds.jsonl", "r") as f: - for line in f: - item = json.loads(line) - bench_dataset.append(Chat.from_dict(item)) - return bench_dataset - - def to_hf_dataset(self) -> Tuple["datasets.Dataset", List]: # type: ignore - try: - import datasets as dst - except ImportError: - raise ImportError("Install huggingface datasets library to use this method") - - _ds_list = [] - meta_list = [] - for x in self.items: - sample, meta = x.to_ft() - _ds_list.append(sample) - meta_list.append(meta) - return dst.Dataset.from_list(_ds_list), meta_list - - # properties - - def can_train_koro_regression(self) -> bool: - return all(["koro.regression" in x.meta for x in self]) - - -class TuneDataset: - """This class is a container for training and evaulation datasets, useful for serialising items to and from disk""" - - def __init__(self, train: TuneChats, eval: TuneChats): - self.train_ds = train - self.eval_ds = eval - - def __repr__(self) -> str: - return f"TuneDataset(\n train={self.train_ds},\n eval={self.eval_ds}\n)" - - @classmethod - def from_list(cls, items: List["TuneDataset"]): - train_ds = TuneChats() - eval_ds = TuneChats() - for item in items: - train_ds.extend(item.train_ds) - eval_ds.extend(item.eval_ds) - return cls(train=train_ds, eval=eval_ds) - - def to_hf_dict(self) -> Tuple["datasets.DatasetDict", Dict[str, List]]: # type: ignore - try: - import datasets as dst - except ImportError: - raise ImportError("Install huggingface datasets library to use this method") - - train_ds, train_meta = self.train_ds.to_hf_dataset() - eval_ds, eval_meta = 
self.eval_ds.to_hf_dataset() - return dst.DatasetDict(train=train_ds, eval=eval_ds), { - "train": train_meta, - "eval": eval_meta, - } - - def to_disk(self, folder: str, fmt: Optional[str] = None): - config = {} - config["type"] = "tune" - config["hf_type"] = fmt - os.makedirs(folder) - self.train_ds.to_disk(f"{folder}/train", fmt=fmt) - self.eval_ds.to_disk(f"{folder}/eval", fmt=fmt) - to_json(config, fp=f"{folder}/tune_config.json", tight=True) - - @classmethod - def from_disk(cls, folder: str): - if not os.path.exists(folder): - raise ValueError(f"Folder '{folder}' does not exist") - if not os.path.exists(f"{folder}/train"): - raise ValueError(f"Folder '{folder}/train' does not exist") - if not os.path.exists(f"{folder}/eval"): - raise ValueError(f"Folder '{folder}/eval' does not exist") - if not os.path.exists(f"{folder}/tune_config.json"): - raise ValueError(f"File '{folder}/tune_config.json' does not exist") - - # not sure what to do with these - with open(f"{folder}/tune_config.json", "r") as f: - config = json.load(f) - return cls( - train=TuneChats.from_disk(f"{folder}/train"), - eval=TuneChats.from_disk(f"{folder}/eval"), - ) diff --git a/chainfury/cli.py b/chainfury/cli.py index 8c44d20..426f94e 100644 --- a/chainfury/cli.py +++ b/chainfury/cli.py @@ -1,85 +1,23 @@ # Copyright © 2023- Frello Technology Private Limited +import dotenv + +dotenv.load_dotenv() + import os import sys import json from fire import Fire +from typing import Optional from chainfury import Chain from chainfury.version import __version__ from chainfury.components import all_items from chainfury.core import model_registry, programatic_actions_registry, memory_registry +from chainfury.chat import Chat, Message -def run( - chain: str, - inp: str, - stream: bool = False, - print_thoughts: bool = False, - f=sys.stdout, -): - """ - Run a chain with input and write the outputs. - - Args: - chain (str): This can be one of json filepath (e.g. "/chain.json"), json string (e.g. 
'{"id": "99jcjs9j2", ...}'), - chain id (e.g. "99jcjs9j2") - inp (str): This can be one of json filepath (e.g. "/input.json"), json string (e.g. '{"foo": "bar", ...}') - stream (bool, optional): Whether to stream the output. Defaults to False. - print_thoughts (bool, optional): Whether to print thoughts. Defaults to False. - f (file, optional): File to write the output to. Defaults to `sys.stdout`. - - Examples: - >>> $ cf run ./sample.json {"foo": "bar"} - """ - # validate inputs - if isinstance(inp, str): - if os.path.exists(inp): - with open(inp, "r") as f: - inp = json.load(f) - else: - try: - inp = json.loads(inp) - except Exception as e: - raise ValueError( - "Input must be a valid json string or a json file path" - ) - assert isinstance(inp, dict), "Input must be a dict" - - # create chain - chain_obj = None - if isinstance(chain, str): - if os.path.exists(chain): - with open(chain, "w") as f: - chain = json.load(f) - if len(chain) == 8: - chain_obj = Chain.from_id(chain) - else: - chain = json.loads(chain) - elif isinstance(chain, dict): - chain_obj = Chain.from_dict(chain) - assert chain_obj is not None, "Chain not found" - - # output - if isinstance(f, str): - f = open(f, "w") - - # run the chain - if stream: - cf_response_gen = chain_obj.stream(inp, print_thoughts=print_thoughts) - for ir, done in cf_response_gen: - if not done: - f.write(json.dumps(ir) + "\n") - else: - out, buffer = chain_obj(inp, print_thoughts=print_thoughts) - for k, v in buffer.items(): - f.write(json.dumps({k: v}) + "\n") - - # close file - f.close() - - -class __CLI: +class CLI: info = rf""" ___ _ _ ___ / __| |_ __ _(_)_ _ | __| _ _ _ _ _ @@ -90,39 +28,121 @@ class __CLI: ae e0 a5 87 e0 a4 b5 20 e0 a4 9c e0 a4 af e0 a4 a4 e0 a5 87 - cf_version: {__version__} 🦋 The FOSS chaining engine behind chat.tune.app - -A powerful way to program for the "Software 2.0" era. 
Read more: - -- https://tunehq.ai -- https://chat.tune.app -- https://studio.tune.app 🌟 us on https://github.com/NimbleBoxAI/ChainFury - -Build with ♥️ by Tune AI from the Koro coast 🌊 Chennai, India +♥️ Built by [Tune AI](https://tunehq.ai) from ECR, Chennai 🌊 """ - comp = { - "all": lambda: print(all_items), - "model": { - "list": list(model_registry.get_models()), - "all": model_registry.get_models(), - "get": model_registry.get, - }, - "prog": { - "list": list(programatic_actions_registry.get_nodes()), - "all": programatic_actions_registry.get_nodes(), - }, - "memory": { - "list": list(memory_registry.get_nodes()), - "all": memory_registry.get_nodes(), - }, - } - run = run + def run( + self, + chain: str, + inp: str, + stream: bool = False, + print_thoughts: bool = False, + f=sys.stdout, + ): + """ + Run a chain with input and write the outputs. + + Args: + chain (str): This can be one of json filepath (e.g. "/chain.json"), json string (e.g. '{"id": "99jcjs9j2", ...}'), + chain id (e.g. "99jcjs9j2") + inp (str): This can be one of json filepath (e.g. "/input.json"), json string (e.g. '{"foo": "bar", ...}') + stream (bool, optional): Whether to stream the output. Defaults to False. + print_thoughts (bool, optional): Whether to print thoughts. Defaults to False. + f (file, optional): File to write the output to. Defaults to `sys.stdout`. 
+ + Examples: + >>> $ cf run ./sample.json {"foo": "bar"} + """ + # validate inputs + if isinstance(inp, str): + if os.path.exists(inp): + with open(inp, "r") as f: + inp = json.load(f) + else: + try: + inp = json.loads(inp) + except Exception as e: + raise ValueError( + "Input must be a valid json string or a json file path" + ) + assert isinstance(inp, dict), "Input must be a dict" + + # create chain + chain_obj = None + if isinstance(chain, str): + if os.path.exists(chain): + with open(chain, "w") as f: + chain = json.load(f) + if len(chain) == 8: + chain_obj = Chain.from_id(chain) + else: + chain = json.loads(chain) + elif isinstance(chain, dict): + chain_obj = Chain.from_dict(chain) + assert chain_obj is not None, "Chain not found" + + # output + if isinstance(f, str): + f = open(f, "w") + + # run the chain + if stream: + cf_response_gen = chain_obj.stream(inp, print_thoughts=print_thoughts) + for ir, done in cf_response_gen: + if not done: + f.write(json.dumps(ir) + "\n") + else: + out, buffer = chain_obj(inp, print_thoughts=print_thoughts) + for k, v in buffer.items(): + f.write(json.dumps({k: v}) + "\n") + + # close file + f.close() + + def sh( + self, + api: str = "tuneapi", + model: str = "rohan/mixtral-8x7b-inst-v0-1-32k", # "kaushikaakash04/tune-blob" + token: Optional[str] = None, + stream: bool = True, + ): + cf_model = model_registry.get(api) + if token is not None: + cf_model.set_api_token(token) + + # loop for user input through command line + chat = Chat() + usr_cntr = 0 + while True: + try: + user_input = input( + f"\033[1m\033[33m [{usr_cntr:02d}] YOU \033[39m:\033[0m " + ) + except KeyboardInterrupt: + break + if user_input == "exit" or user_input == "quit" or user_input == "": + break + chat.add(Message(user_input, Message.HUMAN)) + + print(f"\033[1m\033[34m ASSISTANT \033[39m:\033[0m ", end="", flush=True) + if stream: + response = "" + for str_token in cf_model.stream_chat(chat, model=model): + response += str_token + print(str_token, 
end="", flush=True) + print() # new line + chat.add(Message(response, Message.GPT)) + else: + response = cf_model.chat(chat, model=model) + print(response) + + chat.add(Message(response, Message.GPT)) + usr_cntr += 1 def main(): - Fire(__CLI) + Fire(CLI) diff --git a/chainfury/components/const.py b/chainfury/components/const.py index fbb7d6e..59a99a6 100644 --- a/chainfury/components/const.py +++ b/chainfury/components/const.py @@ -12,7 +12,7 @@ class Env: * CF_URL: ChainFury API URL * NBX_DEPLOY_URL: NimbleBox Deploy URL * NBX_DEPLOY_KEY: NimbleBox Deploy API key - * TUNECHAT_KEY: ChatNBX API key, see chat.nbox.ai + * TUNEAPI_TOKEN: ChatNBX API key, see chat.nbox.ai * OPENAI_TOKEN: OpenAI API token, see platform.openai.com * SERPER_API_KEY: Serper API key, see serper.dev/ * STABILITY_KEY: Stability API key, see dreamstudio.ai @@ -29,7 +29,7 @@ class Env: NBX_DEPLOY_KEY = lambda x: x or os.getenv("NBX_DEPLOY_KEY", "") ## different keys for different 3rd party APIs - TUNECHAT_KEY = lambda x: x or os.getenv("TUNECHAT_KEY", "") + TUNEAPI_TOKEN = lambda x: x or os.getenv("TUNEAPI_TOKEN", "") OPENAI_TOKEN = lambda x: x or os.getenv("OPENAI_TOKEN", "") SERPER_API_KEY = lambda x: x or os.getenv("SERPER_API_KEY", "") diff --git a/chainfury/components/openai/__init__.py b/chainfury/components/openai/__init__.py index bfbcd35..0fddf4c 100644 --- a/chainfury/components/openai/__init__.py +++ b/chainfury/components/openai/__init__.py @@ -13,23 +13,27 @@ UnAuthException, ) from chainfury.components.const import Env -from chainfury.chat import Chat +from chainfury.types import Thread class OpenaiGPTModel(Model): def __init__(self, id: Optional[str] = None): self._openai_model_id = id + self.openai_api_key = Secret(Env.OPENAI_TOKEN("")) super().__init__( id="openai-chat", description="Use OpenAI chat models", usage=["usage", "total_tokens"], ) + def set_api_token(self, token: str) -> None: + self.openai_api_key = Secret(token) + def chat( self, - chats: Chat, + chats: Thread, 
model: Optional[str] = None, - openai_api_key: Secret = Secret(""), + token: Secret = Secret(""), temperature: float = 1.0, top_p: float = 1.0, n: int = 1, @@ -49,7 +53,7 @@ def chat( Args: messages: A list of messages describing the conversation so far model: ID of the model to use. See [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create). - openai_api_key (Secret): The OpenAI API key. Defaults to "" or the OPENAI_TOKEN environment variable. + token (Secret): The OpenAI API key. Defaults to "" or the OPENAI_TOKEN environment variable. temperature: Optional. What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both. Defaults to 1. top_p: Optional. An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. Defaults to 1. n: Optional. How many chat completion choices to generate for each input message. Defaults to 1. @@ -65,14 +69,13 @@ def chat( Returns: Any: The completion(s) generated by the API. """ - if not openai_api_key: - openai_api_key = Secret(Env.OPENAI_TOKEN("")).value # type: ignore - if not openai_api_key: + if not token and not self.openai_api_key.value: raise Exception( "OpenAI API key not found. 
Please set OPENAI_TOKEN environment variable or pass through function" ) - if isinstance(chats, Chat): - messages = chats.to_dict() + + if isinstance(chats, Thread): + messages = chats.to_dict()["chats"] else: messages = chats @@ -83,7 +86,7 @@ def _fn(): "https://api.openai.com/v1/chat/completions", headers={ "Content-Type": "application/json", - "Authorization": f"Bearer {openai_api_key}", + "Authorization": f"Bearer {token}", }, json={ "model": model, diff --git a/chainfury/components/tune/__init__.py b/chainfury/components/tune/__init__.py index b1caeda..8c556e2 100644 --- a/chainfury/components/tune/__init__.py +++ b/chainfury/components/tune/__init__.py @@ -7,7 +7,7 @@ from chainfury import Secret, model_registry, exponential_backoff, Model from chainfury.components.const import Env -from chainfury.chat import Chat +from chainfury.types import Thread class TuneModel(Model): @@ -15,23 +15,79 @@ class TuneModel(Model): def __init__(self, id: Optional[str] = None): self._tune_model_id = id + self.tune_api_token = Secret(Env.TUNEAPI_TOKEN("")) super().__init__( - id="chatnbx", - description="Chat with the ChatNBX API with OpenAI compatability, see more at https://chat.nbox.ai/", + id="tuneapi", + description="Chat with the Tune Studio APIs, see more at https://studio.tune.app/", usage=["usage", "total_tokens"], ) + def set_api_token(self, token: str) -> None: + self.tune_api_token = Secret(token) + def chat( self, - chats: Chat, - chatnbx_api_key: Secret = Secret(""), + chats: Thread, model: Optional[str] = None, max_tokens: int = 1024, temperature: float = 1, *, - retry_count: int = 3, - retry_delay: int = 1, + token: Secret = Secret(""), ) -> Dict[str, Any]: + """ + Chat with the Tune Studio APIs, see more at https://studio.tune.app/ + + Note: This is a API is partially compatible with OpenAI's API, so `messages` should be of type :code:`[{"role": ..., "content": ...}]` + + Args: + model (str): The model to use, see https://studio.nbox.ai/ for more info + 
messages (List[Dict[str, str]]): A list of messages to send to the API which are OpenAI compatible + token (Secret, optional): The API key to use or set TUNEAPI_TOKEN environment variable + max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 1024. + temperature (float, optional): The higher the temperature, the crazier the text. Defaults to 1. + + Returns: + Dict[str, Any]: The response from the API + """ + if not token and not self.tune_api_token.has_value(): # type: ignore + raise Exception( + "Tune API key not found. Please set TUNEAPI_TOKEN environment variable or pass through function" + ) + token = token or self.tune_api_token + if isinstance(chats, Thread): + messages = chats.to_dict()["chats"] + else: + messages = chats + + model = model or self._tune_model_id + url = "https://proxy.tune.app/chat/completions" + headers = { + "Authorization": token.value, + "Content-Type": "application/json", + } + data = { + "temperature": temperature, + "messages": messages, + "model": model, + "stream": False, + "max_tokens": max_tokens, + } + response = requests.post(url, headers=headers, json=data) + try: + response.raise_for_status() + except Exception as e: + raise e + return response.json()["choices"][0]["message"]["content"] + + def stream_chat( + self, + chats: Thread, + model: Optional[str] = None, + max_tokens: int = 1024, + temperature: float = 1, + *, + token: Secret = Secret(""), + ): """ Chat with the ChatNBX API with OpenAI compatability, see more at https://chat.nbox.ai/ @@ -40,45 +96,57 @@ def chat( Args: model (str): The model to use, see https://chat.nbox.ai/ for more info messages (List[Dict[str, str]]): A list of messages to send to the API which are OpenAI compatible - chatnbx_api_key (Secret, optional): The API key to use or set TUNECHAT_KEY environment variable + token (Secret, optional): The API key to use or set TUNEAPI_TOKEN environment variable max_tokens (int, optional): The maximum number of tokens to generate. 
Defaults to 1024. temperature (float, optional): The higher the temperature, the crazier the text. Defaults to 1. Returns: Dict[str, Any]: The response from the API """ - if not chatnbx_api_key: - chatnbx_api_key = Secret(Env.TUNECHAT_KEY("")).value # type: ignore - if not chatnbx_api_key: + if not token and not self.tune_api_token.has_value(): # type: ignore raise Exception( - "OpenAI API key not found. Please set TUNECHAT_KEY environment variable or pass through function" + "Tune API key not found. Please set TUNEAPI_TOKEN environment variable or pass through function" ) - if isinstance(chats, Chat): - messages = chats.to_dict() + + token = token or self.tune_api_token + if isinstance(chats, Thread): + messages = chats.to_dict()["chats"] else: messages = chats model = model or self._tune_model_id - - def _fn(): - url = "https://proxy.tune.app/chat/completions" - headers = { - "Authorization": chatnbx_api_key, - "Content-Type": "application/json", - } - data = { - "temperature": temperature, - "messages": messages, - "model": model, - "stream": False, - "max_tokens": max_tokens, - } - response = requests.post(url, headers=headers, json=data) - return response.json()["choices"][0]["message"]["content"] - - return exponential_backoff( - _fn, max_retries=retry_count, retry_delay=retry_delay + url = "https://proxy.tune.app/chat/completions" + headers = { + "Authorization": token.value, + "Content-Type": "application/json", + } + data = { + "temperature": temperature, + "messages": messages, + "model": model, + "stream": True, + "max_tokens": max_tokens, + } + response = requests.post( + url, + headers=headers, + json=data, + stream=True, ) + try: + response.raise_for_status() + except Exception as e: + print(response.text) + raise e + for line in response.iter_lines(): + line = line.decode().strip() + if line: + try: + yield json.loads(line.replace("data: ", ""))["choices"][0]["delta"][ + "content" + ] + except: + break tune_model = 
model_registry.register(model=TuneModel()) diff --git a/chainfury/types.py b/chainfury/types.py index 1829af4..185607f 100644 --- a/chainfury/types.py +++ b/chainfury/types.py @@ -4,10 +4,23 @@ from typing import Dict, Any, List, Optional from pydantic import BaseModel, Field, ConfigDict +# some types that are copied from the tuneapi types + +from tuneapi.types.chats import ( + Message, + Thread, + ThreadsList, + Dataset, + human, + system, + assistant, +) + + # First is the set of types that are used in the chainfury itself -class FENode(BaseModel): +class UINode(BaseModel): """FENode is the node as required by the UI to render the node in the graph. If you do not care about the UI, you can populate either the ``cf_id`` or ``cf_data``.""" @@ -56,14 +69,14 @@ class Edge(BaseModel): class Dag(BaseModel): """This is visual representation of the chain. JSON of this is stored in the DB.""" - nodes: List[FENode] + nodes: List[UINode] edges: List[Edge] sample: Dict[str, Any] = Field(default_factory=dict) main_in: str = "" main_out: str = "" -class CFPromptResult(BaseModel): +class ChainResult(BaseModel): """This is a structured result of the prompt by the Chain. 
This is more useful for providing types on the server.""" result: str @@ -71,6 +84,9 @@ class CFPromptResult(BaseModel): task_id: str = "" +# Then a set of types that are used in the API (client mode) + + class ApiLoginResponse(BaseModel): message: str token: Optional[str] = None @@ -178,14 +194,14 @@ class ApiPromptFeedbackResponse(BaseModel): rating: int -class ApiSaveTokenRequest(BaseModel): +class ApiToken(BaseModel): key: str token: str meta: Optional[Dict[str, Any]] = {} class ApiListTokensResponse(BaseModel): - tokens: List[ApiSaveTokenRequest] + tokens: List[ApiToken] class ApiChainLog(BaseModel): diff --git a/chainfury/utils.py b/chainfury/utils.py index 4276569..2d730d5 100644 --- a/chainfury/utils.py +++ b/chainfury/utils.py @@ -32,13 +32,13 @@ class CFEnv: CF_LOG_LEVEL = lambda: os.getenv("CF_LOG_LEVEL", "info") CF_FOLDER = lambda: os.path.expanduser(os.getenv("CF_FOLDER", "~/cf")) + CF_URL = lambda: os.getenv("CF_URL", "") + CF_TOKEN = lambda: os.getenv("CF_TOKEN", "") CF_BLOB_STORAGE = lambda: os.path.join(CFEnv.CF_FOLDER(), "blob") CF_BLOB_ENGINE = lambda: os.getenv("CF_BLOB_ENGINE", "local") CF_BLOB_BUCKET = lambda: os.getenv("CF_BLOB_BUCKET", "") CF_BLOB_PREFIX = lambda: os.getenv("CF_BLOB_PREFIX", "") CF_BLOB_AWS_CLOUD_FRONT = lambda: os.getenv("CF_BLOB_AWS_CLOUD_FRONT", "") - CF_URL = lambda: os.getenv("CF_URL", "") - CF_TOKEN = lambda: os.getenv("CF_TOKEN", "") def store_blob(key: str, value: bytes, engine: str = "", bucket: str = "") -> str: diff --git a/chainfury/version.py b/chainfury/version.py index c7502fa..7eb5f4c 100644 --- a/chainfury/version.py +++ b/chainfury/version.py @@ -1,6 +1,6 @@ # Copyright © 2023- Frello Technology Private Limited -__version__ = "1.7.0a1" +__version__ = "1.7.0a2" _major, _minor, _patch = __version__.split(".") _major = int(_major) _minor = int(_minor) diff --git a/extra/ex1/chain.py b/extra/ex1/chain.py deleted file mode 100644 index ba890d8..0000000 --- a/extra/ex1/chain.py +++ /dev/null @@ -1,78 +0,0 @@ -# 
Copyright © 2023- Frello Technology Private Limited - -from fire import Fire - -from chainfury.base import Chain -from chainfury.chat import human, Message, Chat - -from chainfury.components.openai import OpenaiGPTModel -from chainfury.components.tune import TuneModel - - -def main(q: str, openai: bool = False): - chain = Chain( - name="demo-one", - description=( - "Building the hardcore example of chain at https://nimbleboxai.github.io/ChainFury/examples/usage-hardcore.html " - "using threaded chains" - ), - main_in="stupid_question", - main_out="fight_scene/fight_scene", - default_model=( - OpenaiGPTModel("gpt-3.5-turbo") - if openai - else TuneModel("rohan/mixtral-8x7b-inst-v0-1-32k") - ), - ) - print("before:") - print(chain) - - chain = chain.add_thread( - "character_one", - Chat( - [ - human( - "You were who was running in the middle of desert. You see a McDonald's and the waiter ask a stupid " - "question like: '{{ stupid_question }}'? You are pissed and you say." - ) - ] - ), - ) - - chain = chain.add_thread( - "character_two", - Chat( - [ - human( - "Someone comes upto you in a bar and screams '{{ character_one }}'? You are a bartender give a funny response to it." - ) - ] - ), - ) - - chain = chain.add_thread( - "fight_scene", - Chat( - [ - human( - "Two men were fighting in a bar. One yelled '{{ character_one }}'. Other responded by yelling '{{ character_two }}'.\n" - "Continue this story for 3 more lines." 
- ) - ] - ), - ) - - print("---------------") - print(chain) - - print(chain.topo_order) - for ir, done in chain.stream(q): - # print(ir) - pass - - print("---------------") - print(ir) - - -if __name__ == "__main__": - Fire(main) diff --git a/pyproject.toml b/pyproject.toml index 1e5f0c4..05bedbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "chainfury" -version = "1.7.0a1" +version = "1.7.0a2" description = "ChainFury is a powerful tool that simplifies the creation and management of chains of prompts, making it easier to build complex chat applications using LLMs." authors = ["Tune AI "] license = "Apache 2.0" @@ -9,6 +9,7 @@ repository = "https://github.com/NimbleBoxAI/ChainFury" [tool.poetry.dependencies] python = "^3.9,<3.12" +tuneapi = "0.1.1" fire = "0.5.0" Jinja2 = "3.1.2" jinja2schema = "0.1.4" diff --git a/server/chainfury_server/__main__.py b/server/chainfury_server/__main__.py index 54d574f..c723015 100644 --- a/server/chainfury_server/__main__.py +++ b/server/chainfury_server/__main__.py @@ -10,17 +10,6 @@ if os.path.exists(_dotenv_fp): dotenv.load_dotenv(_dotenv_fp) -__CF_LOGO = """ - ___ _ _ ___ - / __| |_ __ _(_)_ _ | __| _ _ _ _ _ -| (__| ' \/ _` | | ' \ | _| || | '_| || | - \___|_||_\__,_|_|_||_||_| \_,_|_| \_, | - |__/ -e0 a4 b8 e0 a4 a4 e0 a5 8d e0 a4 af e0 a4 -ae e0 a5 87 e0 a4 b5 20 e0 a4 9c e0 a4 af - e0 a4 a4 e0 a5 87 -""" - def main( host: str = "0.0.0.0", @@ -38,19 +27,18 @@ def main( post (List[str], optional): List of modules to load after the server is imported. Defaults to []. 
""" # WARNING: ensure that nothing is being imported in the utils from chainfury_server + from chainfury.cli import CLI from chainfury_server.utils import logger - from chainfury.version import __version__ as cf_version from chainfury_server.version import __version__ as cfs_version logger.info( - f"{__CF_LOGO}\n" + f"{CLI.info}\n" f"Starting ChainFury server ...\n" - f" Host: {host}\n" - f" Port: {port}\n" - f" Pre: {pre}\n" - f" Post: {post}\n" - f" chainfury version: {cf_version}\n" - f" cf_server version: {cfs_version}" + f" Host: {host}\n" + f" Port: {port}\n" + f" Pre: {pre}\n" + f" Post: {post}\n" + f" cf_server: {cfs_version}" ) # load all things you need to preload the modules diff --git a/server/chainfury_server/api/chains.py b/server/chainfury_server/api/chains.py index 97bf143..10537fa 100644 --- a/server/chainfury_server/api/chains.py +++ b/server/chainfury_server/api/chains.py @@ -200,12 +200,14 @@ def run_chain( store_ir: bool = False, store_io: bool = False, db: Session = Depends(DB.fastapi_db_session), -) -> Union[StreamingResponse, T.CFPromptResult, T.ApiResponse]: +) -> Union[StreamingResponse, T.ChainResult, T.ApiResponse]: """ This is the master function to run any chain over the API. This can behave in a bunch of different formats like: - (default) this will wait for the entire chain to execute and return the response - if ``stream`` is passed it will give a streaming response with line by line JSON and last response containing ``"done"`` key - if ``as_task`` is passed then a task ID is received and you can poll for the results at ``/chains/{id}/results`` this supercedes the ``stream``. + + ``as_task`` is not implemented. 
""" # validate user user = DB.get_user_from_jwt(token=token, db=db) @@ -243,15 +245,16 @@ def run_chain( if as_task: # when run as a task this will return a task ID that will be submitted - result = engine.submit( - chatbot=chatbot, - prompt=prompt, - db=db, - start=time.time(), - store_ir=store_ir, - store_io=store_io, - ) - return result + raise HTTPException(501, detail="Not implemented yet") + # result = engine.submit( + # chatbot=chatbot, + # prompt=prompt, + # db=db, + # start=time.time(), + # store_ir=store_ir, + # store_io=store_io, + # ) + # return result elif stream: def _get_streaming_response(result): diff --git a/server/chainfury_server/api/user.py b/server/chainfury_server/api/user.py index c34a908..633124c 100644 --- a/server/chainfury_server/api/user.py +++ b/server/chainfury_server/api/user.py @@ -4,17 +4,17 @@ from fastapi import HTTPException from passlib.hash import sha256_crypt from sqlalchemy.orm import Session -from fastapi import Request, Response, Depends, Header -from typing import Annotated +from fastapi import Depends, Header +from typing import Annotated, List from chainfury_server.utils import logger, Env import chainfury_server.database as DB import chainfury.types as T +from tuneapi.utils import encrypt, decrypt + def login( - req: Request, - resp: Response, auth: T.ApiAuthRequest, db: Session = Depends(DB.fastapi_db_session), ) -> T.ApiLoginResponse: @@ -26,13 +26,10 @@ def login( ) return T.ApiLoginResponse(message="success", token=token) else: - resp.status_code = 401 - return T.ApiLoginResponse(message="failed") + raise HTTPException(status_code=401, detail="Invalid username or password") def sign_up( - req: Request, - resp: Response, auth: T.ApiSignUpRequest, db: Session = Depends(DB.fastapi_db_session), ) -> T.ApiLoginResponse: @@ -67,13 +64,10 @@ def sign_up( ) return T.ApiLoginResponse(message="success", token=token) else: - resp.status_code = 400 - return T.ApiLoginResponse(message="failed") + raise 
HTTPException(status_code=500, detail="Unknown error") def change_password( - req: Request, - resp: Response, token: Annotated[str, Header()], inputs: T.ApiChangePasswordRequest, db: Session = Depends(DB.fastapi_db_session), @@ -87,51 +81,111 @@ def change_password( db.commit() return T.ApiResponse(message="success") else: - resp.status_code = 400 - return T.ApiResponse(message="password incorrect") - + raise HTTPException(status_code=401, detail="Invalid old password") -# TODO: @tunekoro - Implement the following functions - -def create_token( - req: Request, - resp: Response, +def create_secret( token: Annotated[str, Header()], - inputs: T.ApiSaveTokenRequest, + inputs: T.ApiToken, db: Session = Depends(DB.fastapi_db_session), ) -> T.ApiResponse: - resp.status_code = 501 # - return T.ApiResponse(message="not implemented") + # validate user + user = DB.get_user_from_jwt(token=token, db=db) + + # validate inputs + if len(inputs.token) >= DB.Tokens.MAXLEN_TOKEN: + raise HTTPException( + status_code=400, + detail=f"Token too long, should be less than {DB.Tokens.MAXLEN_TOKEN} characters", + ) + if len(inputs.key) >= DB.Tokens.MAXLEN_KEY: + raise HTTPException( + status_code=400, + detail=f"Key too long, should be less than {DB.Tokens.MAXLEN_KEY} characters", + ) + + cfs_secrets_password = Env.CFS_SECRETS_PASSWORD() + if cfs_secrets_password is None: + logger.error("CFS_TOKEN_PASSWORD not set, cannot create secrets") + raise HTTPException(500, "internal server error") + + # create a token + token = DB.Tokens( + user_id=user.id, + key=inputs.key, + value=encrypt(inputs.token, cfs_secrets_password, user.id).decode("utf-8"), + meta=inputs.meta, + ) # type: ignore + db.add(token) + db.commit() + return T.ApiResponse(message="success") -def get_token( - req: Request, - resp: Response, +def get_secret( key: str, token: Annotated[str, Header()], db: Session = Depends(DB.fastapi_db_session), -) -> T.ApiResponse: - resp.status_code = 501 # - return T.ApiResponse(message="not 
implemented") +) -> T.ApiToken: + # validate user + user = DB.get_user_from_jwt(token=token, db=db) + + db_token: DB.Tokens = db.query(DB.Tokens).filter(DB.Tokens.key == key, user.id == user.id).first() # type: ignore + if db_token is None: + raise HTTPException(status_code=404, detail="Token not found") + cfs_token = Env.CFS_SECRETS_PASSWORD() + if cfs_token is None: + logger.error("CFS_TOKEN_PASSWORD not set, cannot create secrets") + raise HTTPException(500, "internal server error") -def list_tokens( - req: Request, - resp: Response, + try: + db_token.value = decrypt(db_token.value, cfs_token, user.id) + except Exception as e: + raise HTTPException(status_code=401, detail="Cannot get token") + return db_token.to_ApiToken() + + +def list_secret( token: Annotated[str, Header()], + limit: int = 100, + offset: int = 0, db: Session = Depends(DB.fastapi_db_session), -) -> T.ApiResponse: - resp.status_code = 501 # - return T.ApiResponse(message="not implemented") - +) -> T.ApiListTokensResponse: + """Returns a list of token keys, and metadata. 
The token values are not returned.""" + # validate user + user = DB.get_user_from_jwt(token=token, db=db) -def delete_token( - req: Request, - resp: Response, + # get tokens + tokens: List[DB.Tokens] = ( + db.query(DB.Tokens) + .filter(DB.Tokens.user_id == user.id) # type: ignore + .limit(limit) + .offset(offset) + .all() + ) + tokens_resp = [] + for t in tokens: + tok = t.to_ApiToken() + tok.token = "" + tokens_resp.append(tok) + return T.ApiListTokensResponse(tokens=tokens_resp) + + +def delete_secret( key: str, token: Annotated[str, Header()], db: Session = Depends(DB.fastapi_db_session), ) -> T.ApiResponse: - resp.status_code = 501 # - return T.ApiResponse(message="not implemented") + # validate user + user = DB.get_user_from_jwt(token=token, db=db) + + # validate the user can access the token + _ = get_secret(key=key, token=token, db=db) + + # delete token + db_token: DB.Tokens = db.query(DB.Tokens).filter(DB.Tokens.key == key, user.id == user.id).first() # type: ignore + if db_token is None: + raise HTTPException(status_code=404, detail="Token not found") + db.delete(db_token) + db.commit() + return T.ApiResponse(message="success") diff --git a/server/chainfury_server/app.py b/server/chainfury_server/app.py index d5c8445..6c93e75 100644 --- a/server/chainfury_server/app.py +++ b/server/chainfury_server/app.py @@ -46,10 +46,10 @@ app.add_api_route(methods=["POST"], path="/user/login/", endpoint=api_user.login, tags=["user"]) # type: ignore app.add_api_route(methods=["POST"], path="/user/signup/", endpoint=api_user.sign_up, tags=["user"]) # type: ignore app.add_api_route(methods=["POST"], path="/user/change_password/", endpoint=api_user.change_password, tags=["user"]) # type: ignore -app.add_api_route(methods=["PUT"], path="/user/token/", endpoint=api_user.create_token, tags=["user"]) # type: ignore -app.add_api_route(methods=["GET"], path="/user/token/", endpoint=api_user.get_token, tags=["user"]) # type: ignore -app.add_api_route(methods=["DELETE"], 
path="/user/token/", endpoint=api_user.delete_token, tags=["user"]) # type: ignore -app.add_api_route(methods=["GET"], path="/user/tokens/list/", endpoint=api_user.list_tokens, tags=["user"]) # type: ignore +app.add_api_route(methods=["PUT"], path="/user/secret/", endpoint=api_user.create_secret, tags=["user"]) # type: ignore +app.add_api_route(methods=["GET"], path="/user/secret/", endpoint=api_user.get_secret, tags=["user"]) # type: ignore +app.add_api_route(methods=["DELETE"], path="/user/secret/", endpoint=api_user.delete_secret, tags=["user"]) # type: ignore +app.add_api_route(methods=["GET"], path="/user/secret/list/", endpoint=api_user.list_secret, tags=["user"]) # type: ignore # chains app.add_api_route(methods=["GET"], path="/api/chains/", endpoint=api_chains.list_chains, tags=["chains"]) # type: ignore diff --git a/server/chainfury_server/database.py b/server/chainfury_server/database.py index 72a85ce..938d949 100644 --- a/server/chainfury_server/database.py +++ b/server/chainfury_server/database.py @@ -182,6 +182,7 @@ class Tokens(Base): MAXLEN_KEY = 80 MAXLEN_VAL = 1024 + MAXLEN_TOKEN = 703 # 703 long string can create 1016 long token id = Column(Integer, primary_key=True) user_id = Column(String(ID_LENGTH), ForeignKey("user.id"), nullable=False) @@ -189,10 +190,18 @@ class Tokens(Base): value = Column(String(MAXLEN_VAL), nullable=False) meta = Column(JSON, nullable=True) user = relationship("User", back_populates="tokens") + # (user_id, key) is a unique constraint def __repr__(self): return f"Tokens(id={self.id}, user_id={self.user_id}, key={self.key}, value={self.value[:5]}..., meta={self.meta})" + def to_ApiToken(self) -> T.ApiToken: + return T.ApiToken( + key=self.key, + token=self.value, + meta=self.meta, + ) + class ChatBot(Base): __tablename__ = "chatbot" @@ -262,6 +271,9 @@ class Prompt(Base): session_id: Dict[str, Any] = Column(String(80), nullable=False) meta: Dict[str, Any] = Column(JSON) + # migrate to snowflake ID + sf_id = 
Column(String(19), nullable=True) + def to_dict(self): return { "id": self.id, diff --git a/server/chainfury_server/engine.py b/server/chainfury_server/engine.py index beb86ec..9c87fbb 100644 --- a/server/chainfury_server/engine.py +++ b/server/chainfury_server/engine.py @@ -25,7 +25,7 @@ def run( start: float, store_ir: bool, store_io: bool, - ) -> T.CFPromptResult: + ) -> T.ChainResult: if prompt.new_message and prompt.data: raise HTTPException( status_code=400, detail="prompt cannot have both new_message and data" @@ -37,7 +37,7 @@ def run( # Create a Fury chain then run the chain while logging all the intermediate steps dag = T.Dag(**chatbot.dag) # type: ignore chain = Chain.from_dag(dag, check_server=False) - callback = FuryThoughts(db, prompt_row.id) + callback = FuryThoughtsCallback(db, prompt_row.id) if prompt.new_message: prompt.data = {chain.main_in: prompt.new_message} @@ -76,7 +76,7 @@ def run( db.commit() # create the result - result = T.CFPromptResult( + result = T.ChainResult( result=( json.dumps(mainline_out) if type(mainline_out) != str @@ -108,7 +108,7 @@ def stream( start: float, store_ir: bool, store_io: bool, - ) -> Generator[Tuple[Union[T.CFPromptResult, Dict[str, Any]], bool], None, None]: + ) -> Generator[Tuple[Union[T.ChainResult, Dict[str, Any]], bool], None, None]: if prompt.new_message and prompt.data: raise HTTPException( status_code=400, detail="prompt cannot have both new_message and data" @@ -120,7 +120,7 @@ def stream( # Create a Fury chain then run the chain while logging all the intermediate steps dag = T.Dag(**chatbot.dag) # type: ignore chain = Chain.from_dag(dag, check_server=False) - callback = FuryThoughts(db, prompt_row.id) + callback = FuryThoughtsCallback(db, prompt_row.id) if prompt.new_message: prompt.data = {chain.main_in: prompt.new_message} @@ -162,7 +162,7 @@ def stream( ) # type: ignore db.add(db_chainlog) - result = T.CFPromptResult( + result = T.ChainResult( result=str(mainline_out), prompt_id=prompt_row.id, # 
type: ignore ) @@ -189,7 +189,7 @@ def submit( start: float, store_ir: bool, store_io: bool, - ) -> T.CFPromptResult: + ) -> T.ChainResult: if prompt.new_message and prompt.data: raise HTTPException( status_code=400, detail="prompt cannot have both new_message and data" @@ -206,7 +206,7 @@ def submit( # call the chain task_id: str = str(uuid4()) - result = T.CFPromptResult( + result = T.ChainResult( result=f"Task '{task_id}' scheduled", prompt_id=prompt_row.id, task_id=task_id, @@ -224,12 +224,10 @@ def submit( raise HTTPException(status_code=500, detail=str(e)) from e -# engine_registry.register(FuryEngine()) - # helpers -class FuryThoughts: +class FuryThoughtsCallback: def __init__(self, db, prompt_id): self.db = db self.prompt_id = prompt_id diff --git a/server/chainfury_server/utils.py b/server/chainfury_server/utils.py index 34f7faa..608b8f3 100644 --- a/server/chainfury_server/utils.py +++ b/server/chainfury_server/utils.py @@ -1,7 +1,8 @@ # Copyright © 2023- Frello Technology Private Limited import os -import logging +from Cryptodome.Cipher import AES +from base64 import b64decode, b64encode # WARNING: do not import anything from anywhere here, this is the place where chainfury_server starts. # importing anything can cause the --pre and --post flags to fail when starting server. @@ -14,13 +15,11 @@ class Env: """ Single namespace for all environment variables. 
- - * CFS_DATABASE: database connection string - * JWT_SECRET: secret for JWT tokens """ # once a lifetime secret JWT_SECRET = lambda: os.getenv("JWT_SECRET", "hajime-shimamoto") + CFS_SECRETS_PASSWORD = lambda: os.getenv("CFS_SECRETS_PASSWORDs") # when you want to use chainfury as a client you need to set the following vars CFS_DATABASE = lambda: os.getenv("CFS_DATABASE", None) @@ -47,3 +46,61 @@ def folder(x: str) -> str: def joinp(x: str, *args) -> str: """convienience function for os.path.join""" return os.path.join(x, *args) + + +class Crypt: + + def __init__(self, salt="SlTKeYOpHygTYkP3"): + self.salt = salt.encode("utf8") + self.enc_dec_method = "utf-8" + + def encrypt(self, str_to_enc, str_key): + try: + aes_obj = AES.new(str_key.encode("utf-8"), AES.MODE_CFB, self.salt) + hx_enc = aes_obj.encrypt(str_to_enc.encode("utf8")) + mret = b64encode(hx_enc).decode(self.enc_dec_method) + return mret + except ValueError as value_error: + if value_error.args[0] == "IV must be 16 bytes long": + raise ValueError("Encryption Error: SALT must be 16 characters long") + elif ( + value_error.args[0] == "AES key must be either 16, 24, or 32 bytes long" + ): + raise ValueError( + "Encryption Error: Encryption key must be either 16, 24, or 32 characters long" + ) + else: + raise ValueError(value_error) + + def decrypt(self, enc_str, str_key): + try: + aes_obj = AES.new(str_key.encode("utf8"), AES.MODE_CFB, self.salt) + str_tmp = b64decode(enc_str.encode(self.enc_dec_method)) + str_dec = aes_obj.decrypt(str_tmp) + mret = str_dec.decode(self.enc_dec_method) + return mret + except ValueError as value_error: + if value_error.args[0] == "IV must be 16 bytes long": + raise ValueError("Decryption Error: SALT must be 16 characters long") + elif ( + value_error.args[0] == "AES key must be either 16, 24, or 32 bytes long" + ): + raise ValueError( + "Decryption Error: Encryption key must be either 16, 24, or 32 characters long" + ) + else: + raise ValueError(value_error) + + 
+CURRENT_EPOCH_START = 1705905900000 # UTC timezone +"""Start of the current epoch, used for generating snowflake ids""" + +from snowflake import SnowflakeGenerator + + +class SFGen: + def __init__(self, instance, epoch=CURRENT_EPOCH_START): + self.gen = SnowflakeGenerator(instance, epoch=epoch) + + def __call__(self): + return next(self.gen) diff --git a/server/chainfury_server/version.py b/server/chainfury_server/version.py index 73d6f7c..e23f8b6 100644 --- a/server/chainfury_server/version.py +++ b/server/chainfury_server/version.py @@ -1,6 +1,6 @@ # Copyright © 2023- Frello Technology Private Limited -__version__ = "2.1.2a" +__version__ = "2.1.3a0" _major, _minor, _patch = __version__.split(".") _major = int(_major) _minor = int(_minor) diff --git a/server/pyproject.toml b/server/pyproject.toml index 58e26e6..dc1342e 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -2,7 +2,7 @@ [tool.poetry] name = "chainfury_server" -version = "2.1.2a" +version = "2.1.3a0" description = "ChainFury Server is the DB + API server for managing the ChainFury engine in production. 
# Copyright © 2023- Frello Technology Private Limited
#
# Test runner: executes every tests/test_*.py module in its own interpreter
# process (so module state cannot leak between files) and fails loudly on
# the first non-zero exit code.

import os
import sys
import subprocess

# FIX: was `from tuneapi.utils import folder, joinp` — tuneapi is not a
# declared dependency of this repo; chainfury.utils ships the same helpers.
from chainfury.utils import folder, joinp

# Collect every test file living next to this runner.
tests = []
curdir = folder(__file__)
for x in os.listdir(curdir):
    if x.startswith("test_") and x.endswith(".py"):
        tests.append(joinp(curdir, x))

for t in tests:
    # FIX: subprocess.call returns the child's real exit code, unlike
    # os.system whose return value is a platform-dependent wait status;
    # sys.executable runs the tests with the same interpreter/venv as the
    # runner instead of whatever `python3` resolves to on PATH.
    code = subprocess.call([sys.executable, t, "-v"])
    if code != 0:
        raise Exception(f"Test {t} failed with code {code}")
# Register the remaining threads of the demo chain. Order and templates are
# the fixture: `character_two` consumes {{ character_one }}, and `fight_scene`
# consumes both, which fixes the expected topological order asserted below.
chain.add_thread(
    "character_two",
    Thread(
        human(
            "Someone comes upto you in a bar and screams '{{ character_one }}'? You are a bartender give a funny response to it."
        ),
    ),
)
chain.add_thread(
    "fight_scene",
    Thread(
        human(
            "Two men were fighting in a bar. One yelled '{{ character_one }}'. Other responded by yelling '{{ character_two }}'.\n"
            "Continue this story for 3 more lines."
        )
    ),
)


class TestChain(unittest.TestCase):
    """Testing Chain specific functionality"""

    def test_chain_toposort(self):
        # The dependency edges implied by the {{ ... }} template variables
        # must yield exactly this execution order.
        self.assertEqual(
            chain.topo_order,
            ["character_one", "character_two", "fight_scene"],
        )


if __name__ == "__main__":
    unittest.main()
#
# Chain definition
#

# Minimal single-node chain used by the serialisation round-trip tests above.
chain = Chain(
    name="echo-cf-public",
    description="abyss",
    nodes=[programatic_actions_registry.get("chainfury-echo")],  # type: ignore
    sample={"message": "hi there"},
    main_in="message",
    main_out="chainfury-echo/message",
)


#
# Tool definition
#
# NOTE(review): "basica" looks like a typo for "basic" — this is a runtime
# string (part of the tool's contract), so it is left untouched here.
tool = Tools(
    name="calculator",
    description=(
        "This tool is a calculator, it can perform basica calculations. "
        "Use this when you are trying to do some mathematical task"
    ),
)


@tool.add(
    description="This function adds two numbers",
    properties={
        "a": Var("int", required=True, description="number one"),
        "b": Var("int", description="number two"),
    },
)
def add_two_numbers(a: int, b: int = 10):
    # `b` is optional in the schema above, mirroring its Python default.
    return a + b


@tool.add(
    description="This calculates square root of a number",
    properties={
        "a": Var("int", description="number to calculate square root of"),
    },
)
def square_root_number(a):
    # Local import keeps the test module import-light; math.sqrt returns a
    # float and raises ValueError for negative input.
    import math

    return math.sqrt(a)


if __name__ == "__main__":
    unittest.main()