diff --git a/.env.sample b/.env.sample
new file mode 100644
index 00000000..728f6483
--- /dev/null
+++ b/.env.sample
@@ -0,0 +1,29 @@
+# chainfury server
+# ================
+# These are the environment variables that are used by the chainfury_server
+# For chainfury jump below to the chainfury section
+
+# Required
+# --------
+
+# URL to the database for chainfury server, uses sqlalchemy, so most things should work
+CFS_DATABASE="db_drivers://username:password@host:port/db_name"
+
+# (once in a lifetime) secret string for creating the JWT secrets
+JWT_SECRET="secret"
+
+# (once in a lifetime) password to store the user secrets
+CFS_SECRETS_PASSWORD="password"
+
+# chainfury
+# =========
+# These are the environment variables that are used by the chainfury
+
+# To store all the file and data in the chainfury server
+CF_FOLDER="~/cf"
+
+# (client mode) the URL for the chainfury server
+CF_URL=""
+
+# (client mode) the token for the chainfury server
+CF_TOKEN=""
diff --git a/.gitignore b/.gitignore
index be520012..00613c48 100644
--- a/.gitignore
+++ b/.gitignore
@@ -142,7 +142,7 @@ langflow
dunmp.rdb
*.ipynb
server/chainfury_server/stories/fury.json
-notebooks/*
+notebooks
stories/fury.json
workers/
private.sh
@@ -153,3 +153,4 @@ demo/
logs.py
chunker/
chainfury/chains/
+gosrc/
diff --git a/README.md b/README.md
index 603ad1d5..6cd6c95c 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,14 @@ ae e0 a5 87 e0 a4 b5 20 e0 a4 9c e0 a4 af
The documentation page contains all the information on using `chainfury` and `chainfury_server`.
+#### `chainfury`
+
+
+
+#### `chainfury_server`
+
+
+
# Looking for Inspirations?
Here's a few example to get your journey started on Software 2.0:
@@ -86,7 +94,7 @@ source venv/bin/activate
You will need to have `yarn` installed to build the frontend and move it to the correct location on the server
```bash
-sh stories/build_and_copy.sh
+sh build_ui.sh
```
Once the static files are copied we can now proceed to install dependecies:
@@ -104,7 +112,7 @@ You can now visit [localhost:8000](http://localhost:8000/ui/) to see the GUI and
There are a few test cases for super hard problems like `get_kv` which checks the `chainfury.base.get_value_by_keys` function.
```bash
-python3 -m tests -v
+python3 tests/main.py
```
# Contibutions
diff --git a/api_docs/conf.py b/api_docs/conf.py
index 05700179..8731d75c 100644
--- a/api_docs/conf.py
+++ b/api_docs/conf.py
@@ -14,7 +14,7 @@
project = "ChainFury"
copyright = "2023, NimbleBox Engineering"
author = "NimbleBox Engineering"
-release = "1.7.0a1"
+release = "1.7.0a2"
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
diff --git a/scripts/build_and_copy.sh b/build_ui.sh
similarity index 95%
rename from scripts/build_and_copy.sh
rename to build_ui.sh
index 47991e42..a1cd1b95 100755
--- a/scripts/build_and_copy.sh
+++ b/build_ui.sh
@@ -14,9 +14,6 @@ cd client
yarn install
yarn build
-# Go back to the root directory
-cd ..
-
# copy the dist folder to the server
# Go into the server folder, remove the old static folder and copy the new dist folder, copy index.html to templates
echo "Copying the generated files to the server"
diff --git a/chainfury/__init__.py b/chainfury/__init__.py
index 2111d9b0..561671fa 100644
--- a/chainfury/__init__.py
+++ b/chainfury/__init__.py
@@ -14,7 +14,16 @@
logger,
CFEnv,
)
-from chainfury.base import Var, Node, Secret, Chain, Model, Edge
+from chainfury.base import (
+ Var,
+ Node,
+ Secret,
+ Chain,
+ Model,
+ Edge,
+ Tools,
+ Action,
+)
from chainfury.core import (
model_registry,
programatic_actions_registry,
@@ -24,11 +33,11 @@
Memory,
)
from chainfury.client import get_client
-from chainfury.chat import (
+from chainfury.types import (
Message,
- Chat,
- TuneChats,
- TuneDataset,
+ Thread,
+ ThreadsList,
+ Dataset,
human,
system,
assistant,
diff --git a/chainfury/base.py b/chainfury/base.py
index 27e92536..04f1617b 100644
--- a/chainfury/base.py
+++ b/chainfury/base.py
@@ -1,5 +1,6 @@
# Copyright © 2023- Frello Technology Private Limited
+import os
import copy
import json
import jinja2
@@ -8,15 +9,17 @@
import importlib
import traceback
from pprint import pformat
+from functools import partial
from typing import Any, Union, Optional, Dict, List, Tuple, Callable, Generator
from collections import deque, defaultdict, Counter
import jinja2schema as j2s
from jinja2schema import model as j2sm
+from tuneapi.utils import load_module_from_path, to_json, from_json
+
from chainfury.utils import logger, terminal_top_with_text
import chainfury.types as T
-from chainfury import chat
class Secret(str):
@@ -25,6 +28,9 @@ class Secret(str):
def __init__(self, value=""):
self.value = value
+ def has_value(self) -> bool:
+ return self.value != ""
+
#
# Vars: this is the base class for all the fields that the user can provide from the front end
@@ -34,7 +40,7 @@ def __init__(self, value=""):
class Var:
def __init__(
self,
- type: Union[str, List["Var"]],
+ type: Union[str, List["Var"]] = "",
format: str = "",
items: List["Var"] = [],
additionalProperties: Union[List["Var"], "Var"] = [],
@@ -44,6 +50,7 @@ def __init__(
placeholder: str = "",
show: bool = False,
name: str = "",
+ description: str = "",
*,
loc: Optional[Tuple] = (),
):
@@ -61,6 +68,9 @@ def __init__(
name (str, optional): The name of this field. Defaults to "".
loc (Optional[Tuple], optional): The location of this field. Defaults to ().
"""
+ if not type:
+ raise ValueError("type cannot be empty")
+
self.type = type
self.format = format
self.items = items or []
@@ -71,6 +81,7 @@ def __init__(
self.placeholder = placeholder
self.show = show
self.name = name
+ self.description = description
#
self.value = None
self.loc = loc # this is the location from which this value is extracted
@@ -115,6 +126,8 @@ def to_dict(self) -> Dict[str, Any]:
d["name"] = self.name
if self.loc:
d["loc"] = self.loc
+ if self.description:
+ d["description"] = self.description
return d
@classmethod
@@ -137,6 +150,7 @@ def from_dict(cls, d: Dict[str, Any]) -> "Var":
show_val = d.get("show", False)
name_val = d.get("name", "")
loc_val = d.get("loc", ())
+ description_val = d.get("description", "")
if isinstance(type_val, list):
type_val = [
@@ -162,6 +176,7 @@ def from_dict(cls, d: Dict[str, Any]) -> "Var":
placeholder=placeholder_val,
show=show_val,
name=name_val,
+ description=description_val,
loc=loc_val,
)
return var
@@ -739,16 +754,27 @@ def __call__(self, model_data: Dict[str, Any]) -> Tuple[Any, Optional[Exception]
except Exception as e:
return traceback.format_exc(), e
+ def set_api_token(self, token: str) -> None:
+ raise NotImplementedError(
+ f"set_api_token method is not implemented for {self.id}"
+ )
+
def completion(self, prompt: str, **kwargs):
"""Subclass and implement your own text completion API"""
return NotImplementedError(
f"completion method is not implemented for {self.id}"
)
- def chat(self, chat: chat.Chat, **kwargs):
+ def chat(self, chat: T.Thread, **kwargs):
"""Subclass and implement your own chat API"""
raise NotImplementedError("chat method is not implemented for this model")
+ def stream_chat(self, chat: T.Thread, **kwargs):
+ """Subclass and implement your own chat API"""
+ raise NotImplementedError(
+ "stream_chat method is not implemented for this model"
+ )
+
#
# Node: Each box that is drag and dropped in the UI is a Node, it will tell what kind of things are
@@ -980,12 +1006,12 @@ def __call__(
@classmethod
def from_chat(
cls,
- chat: chat.Chat,
+ thread: T.Thread,
node_id: str,
model: Model,
description: Optional[str] = None,
) -> "Node":
- chat_dict = chat.to_dict()
+ chat_dict = thread.to_dict()
# print(variables)
fields = []
templates = []
@@ -995,7 +1021,7 @@ def from_chat(
obj = get_value_by_keys(chat_dict, field[0])
if not obj:
raise ValueError(
- f"Field {field[0]} not found in {chat}, but was extraced. There is a bug in get_value_by_keys function"
+ f"Field {field[0]} not found in {thread}, but was extraced. There is a bug in get_value_by_keys function"
)
templates.append((obj, jinja2.Template(obj), field[0]))
@@ -1123,7 +1149,11 @@ def __init__(
self.edges = edges
if len(self.nodes) == 1:
- assert len(self.edges) == 0, "Cannot have edges with only 1 node"
+ if len(self.edges) != 0:
+ logger.error(f"Got only one node: {self.nodes.keys()=}")
+ raise ValueError(
+ f"Cannot have edges with only 1 node. Got {self.edges}"
+ )
self.topo_order = [next(iter(self.nodes))]
else:
self.topo_order = topological_sort(self.edges)
@@ -1133,6 +1163,7 @@ def __init__(
if self.is_empty:
# there is nothing to do here
+ logger.info("This is empty chain")
return
if "/" not in main_out:
@@ -1184,7 +1215,7 @@ def __repr__(self) -> str:
def add_thread(
self,
node_id: str,
- chat: chat.Chat,
+ thread: T.Thread,
model: Optional[Model] = None,
description: str = "",
) -> "Chain":
@@ -1196,30 +1227,35 @@ def add_thread(
# build the node
node = Node.from_chat(
- chat,
+ thread,
node_id=node_id,
model=model or self.default_model, # type: ignore
description=description,
)
+ logger.debug(f"Adding node (total nodes {len(self.nodes)}): {node.id=}")
+
# add edges as required
for var in node.fields:
if var.name in self.nodes:
- self.edges.append(
- Edge(
- src_node_id=var.name,
- src_node_var=var.name,
- trg_node_id=node_id,
- trg_node_var=self.nodes[var.name].outputs[0].name,
- )
+ e = Edge(
+ src_node_id=var.name,
+ src_node_var=var.name,
+ trg_node_id=node_id,
+ trg_node_var=self.nodes[var.name].outputs[0].name,
)
+ logger.debug(f"Adding (total edges {len(self.edges)}) {e=}")
+ self.edges.append(e)
# assign the node
self.nodes[node.id] = node
# topo sort
if len(self.nodes) == 1:
- assert len(self.edges) == 0, "Cannot have edges with only 1 node"
+ if len(self.edges) != 0:
+ raise ValueError(
+ f"Cannot have edges with only 1 node. Got {self.edges}"
+ )
self.topo_order = [next(iter(self.nodes))]
else:
self.topo_order = topological_sort(self.edges)
@@ -1338,9 +1374,9 @@ def to_dag(self) -> T.Dag:
nodes = []
for i, node in enumerate(self.nodes.values()):
nodes.append(
- T.FENode(
+ T.UINode(
id=node.id,
- position=T.FENode.Position(
+ position=T.UINode.Position(
x=i * 100,
y=i * 100,
),
@@ -1348,13 +1384,13 @@ def to_dag(self) -> T.Dag:
width=100,
height=100,
selected=False,
- position_absolute=T.FENode.Position(
+ position_absolute=T.UINode.Position(
x=i * 100,
y=i * 100,
),
dragging=False,
cf_id=node.id,
- cf_data=T.FENode.CFData(
+ cf_data=T.UINode.CFData(
id=node.id,
type=node.type,
node=node.to_dict(),
@@ -1805,6 +1841,235 @@ def stream(
yield out, True
+#
+# Tools: A new abstraction for AGI
+#
+
+
+class Action:
+ def __init__(
+ self,
+ name: str,
+ description: str,
+ properties: Optional[Dict[str, Any]] = {},
+ required: Optional[List[str]] = [],
+ fn: Optional[Callable] = None,
+ fn_meta: Optional[Dict[str, Any]] = None,
+ ):
+ self.name = name
+ self.description = description
+ self.properties = properties
+ self.required = required
+
+ if (fn is None and fn_meta is None) or (fn is not None and fn_meta is not None):
+ raise ValueError("Either fn or fn_meta is required")
+ if fn_meta is not None:
+ # first validate that the line content is same in the two files, then try to load the item
+ if not os.path.exists(fn_meta["file"]):
+ raise ValueError(f"File {fn_meta['file']} does not exist")
+ with open(fn_meta["file"], "r") as f:
+ for i, l in enumerate(f):
+ if i == fn_meta["line"] and l != fn_meta["line_val"]:
+ raise ValueError(
+ f"Line #{fn_meta['line']} does not match in {fn_meta['file']}\n"
+ f" Expected: {l}\n"
+ f" Found: {fn_meta['line_val']}"
+ )
+ fn = load_module_from_path(fn_meta["name"], fn_meta["file"])
+ elif fn is not None:
+ fn_meta = {
+ "file": inspect.getfile(fn),
+ "line": inspect.getsourcelines(fn)[1] - 1,
+ "line_val": inspect.getsourcelines(fn)[0][0],
+ "name": fn.__name__,
+ }
+
+ self.fn_meta = fn_meta
+ self.fn: Callable = fn # type: ignore
+
+ def __repr__(self) -> str:
+ return f"[Action] {self.name}: {self.description}"
+
+ # ser/deser
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Serializes the action to a dictionary.
+
+ Returns:
+ Dict[str, Any]: The dictionary representation of the action.
+ """
+ return {
+ "name": self.name,
+ "description": self.description,
+ "required": self.required,
+ "properties": self.properties,
+ "fn_meta": self.fn_meta,
+ }
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]) -> "Action":
+ """Deserializes the action from a dictionary.
+
+ Args:
+ data (Dict[str, Any]): The dictionary representation of the action.
+ """
+ return cls(
+ name=data["name"],
+ description=data["description"],
+ required=data["required"],
+ properties=data["properties"],
+ fn_meta=data["fn_meta"],
+ )
+
+ def to_json(self, indent=0, tight=True) -> str:
+ """Serializes the action to a JSON string.
+
+ Returns:
+ str: The JSON string representation of the action.
+ """
+ return to_json(self.to_dict(), indent=indent, tight=tight)
+
+ @classmethod
+ def from_json(cls, data: str) -> "Action":
+ """Creates an action from a JSON string.
+
+ Args:
+ data (str): The JSON string representation of the action.
+ """
+ return cls.from_dict(json.loads(data))
+
+ def __call__(self, *args, **kwargs):
+ # validate the data is in
+ return self.fn(*args, **kwargs)
+
+ # usage
+
+ def to_fn(self) -> Dict[str, Any]:
+ data = self.to_dict()
+ data.pop("fn_meta")
+ return data
+
+
+class Tools:
+ """
+ Usage:
+
+ >>> from chainfury import Tool, Var
+ >>> my_tool = Tool("My Tool", "This is a test tool")
+ >>> @my_tool(
+ ... description = "this is test action",
+ ... props = {
+ ... "a": Var("int", "number 1"),
+ ... "b": Var("int", "number 2"),
+ ... "secret": Var("int", secret = True, key="MY_ENV_VAR"), # to implement
+ ... }
+ ... )
+ ... def add_two_numbers(a: int, b: int, secret: int):
+ ... return (a + b) * secret
+ ...
+ >>> my_tool.to_json(indent = 2)
+ {
+ "name": "add_two_numbers",
+ ...
+ }
+ """
+
+ def __init__(self, name: str, description: str):
+ self.name = name
+ self.description = description
+
+ #
+ self.actions: Dict[str, Action] = {}
+
+ def __repr__(self) -> str:
+ return f"[Tool] {self.name}: {self.description}"
+
+ def _register_action(
+ self,
+ fn: Callable,
+ description: str,
+ properties: Dict[str, Var],
+ name: Optional[str] = None,
+ ) -> Action:
+ name = name or fn.__name__
+ props = {}
+ required = []
+ for k, v in properties.items():
+ props[k] = {
+ "type": v.type,
+ "description": v.description,
+ }
+ if v.required:
+ required.append(k)
+
+ self.actions[name] = Action(
+ name=name,
+ description=description,
+ properties=props,
+ required=required,
+ fn=fn,
+ )
+ return self.actions[name]
+
+ def add(self, description: str, properties: Dict[str, Var] = {}) -> Action:
+ """
+ Register the actions
+ """
+ return partial(
+ self._register_action,
+ description=description,
+ properties=properties,
+ ) # type: ignore
+
+ # ser/deser
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Register the actions
+ """
+ return {
+ "name": self.name,
+ "description": self.description,
+ "actions": {k: v.to_dict() for k, v in self.actions.items()},
+ }
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]) -> "Tools":
+ """Deserializes the Tool from a dictionary.
+
+ Args:
+ data (Dict[str, Any]): The dictionary representation of the chain.
+ """
+ self = cls(
+ name=data["name"],
+ description=data["description"],
+ )
+ actions = data.get("actions", {})
+ for a in actions.values():
+ action = Action.from_dict(a)
+ self.actions[action.name] = action
+
+ return self
+
+ def to_json(self, indent=2, tight=False) -> str:
+ """Serializes the Tool to a JSON string.
+
+ Returns:
+ str: The JSON string representation of the chain.
+ """
+ return to_json(self.to_dict(), indent=indent, tight=tight)
+
+ @classmethod
+ def from_json(cls, data: str) -> "Tools":
+ """
+ Register the actions
+ """
+ return cls.from_dict(json.loads(data))
+
+ def to_fn(self) -> Dict[str, Any]:
+ return {"name": self.name, "description": self.description, "properties": {}}
+
+
#
# helper functions
#
diff --git a/chainfury/chat.py b/chainfury/chat.py
deleted file mode 100644
index 7ebfe9d6..00000000
--- a/chainfury/chat.py
+++ /dev/null
@@ -1,431 +0,0 @@
-# Copyright © 2023- Frello Technology Private Limited
-
-import os
-import json
-import random
-from functools import partial
-from collections.abc import Iterable
-from typing import Dict, List, Any, Tuple, Optional, Generator
-
-from chainfury.utils import to_json, get_random_string, logger
-
-
-class Message:
- SYSTEM = "system"
- HUMAN = "human"
- GPT = "gpt"
- VALUE = "value"
- FUNCTION = "function"
- FUNCTION_RESPONSE = "function-response"
-
- # start initialization here
- def __init__(self, value: str | float, role: str):
- if role in ["system", "sys"]:
- role = self.SYSTEM
- elif role in ["user", "human"]:
- role = self.HUMAN
- elif role in ["gpt", "assistant", "machine"]:
- role = self.GPT
- elif role in ["value"]:
- role = self.VALUE
- elif role in ["function", "fn"]:
- role = self.FUNCTION
- elif role in ["function-response", "fn-resp"]:
- role = self.FUNCTION_RESPONSE
- else:
- raise ValueError(f"Unknown role: {role}")
- if value is None:
- raise ValueError("value cannot be None")
-
- self.role = role
- self.value = value
- self._unq_value = get_random_string(6)
-
- def __str__(self) -> str:
- try:
- idx = max(os.get_terminal_size().columns - len(self.role) - 40, 10)
- except OSError:
- idx = 50
- return f"<{self.role}: {json.dumps(self.value)[:idx]}>"
-
- def __repr__(self) -> str:
- return str(self.value)
-
- def to_dict(self, ft: bool = False):
- """
- if `ft` then export to following format: `{"from": "system/human/gpt", "value": "..."}`
- else export to following format: `{"role": "system/user/assistant", "content": "..."}`
- """
- role = self.role
- if not ft:
- if self.role == self.HUMAN:
- role = "user"
- elif self.role == self.GPT:
- role = "assistant"
-
- chat_message: Dict[str, str | float]
- if ft:
- chat_message = {"from": role}
- else:
- chat_message = {"role": role}
-
- if not ft:
- chat_message["content"] = self.value
- else:
- chat_message["value"] = self.value
- return chat_message
-
- @classmethod
- def from_dict(cls, data):
- return cls(
- value=data.get("value") or data.get("content"),
- role=data.get("from") or data.get("role"),
- ) # type: ignore
-
-
-### Aliases
-human = partial(Message, role=Message.HUMAN)
-system = partial(Message, role=Message.SYSTEM)
-assistant = partial(Message, role=Message.GPT)
-
-
-class Chat:
- """
- If the last Message is a "value" then a special tag "koro.regression"="true" is added to the meta.
-
- Args:
- chats (List[Message]): List of chat messages
- jl (Dict[str, Any]): Optional json-logic
- """
-
- def __init__(
- self,
- chats: List[Message],
- jl: Optional[Dict[str, Any]] = None,
- model: Optional[str] = None,
- **kwargs,
- ):
- self.chats = chats
- self.jl = jl
- self.model = model
-
- # check for regression
- if self.chats[-1].role == Message.VALUE:
- kwargs["koro.regression"] = True
-
- kwargs = {k: v for k, v in sorted(kwargs.items())}
- self.meta = kwargs
- self.keys = list(kwargs.keys())
- self.values = tuple(kwargs.values())
-
- # avoid special character BS.
- assert not any(["=" in x or "&" in x for x in self.keys])
- if self.values:
- assert all([type(x) in [int, str, float, bool] for x in self.values])
-
- self.value_hash = hash(self.values)
-
- def __repr__(self) -> str:
- x = ""
- return x
-
- def __getattr__(self, __name: str) -> Any:
- if __name in self.meta:
- return self.meta[__name]
- raise AttributeError(f"Attribute {__name} not found")
-
- # ser/deser
-
- def to_dict(self, full: bool = False):
- if full:
- return {
- "chats": [x.to_dict() for x in self.chats],
- "jl": self.jl,
- "model": self.model,
- "meta": self.meta,
- }
- return {
- "chats": [x.to_dict() for x in self.chats],
- }
-
- def to_chat_template(self):
- return self.to_dict()["chats"]
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]):
- chats = data.get("chats", None) or data.get("conversations", None)
- if not chats:
- raise ValueError("No chats found")
- return cls(
- chats=[Message.from_dict(x) for x in chats],
- jl=data.get("jl"),
- model=data.get("model"),
- **data.get("meta", {}),
- )
-
- def to_ft(
- self, id: Any = None, drop_last: bool = False
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
- chats = self.chats if not drop_last else self.chats[:-1]
- ft_dict = {
- "id": id or get_random_string(6),
- "conversations": [x.to_dict(ft=True) for x in chats],
- }
- if drop_last:
- ft_dict["last"] = self.chats[-1].to_dict(ft=True)
- return ft_dict, self.meta
-
- # modifications
-
- def copy(self) -> "Chat":
- return Chat(
- chats=[x for x in self.chats],
- jl=self.jl,
- model=self.model,
- **self.meta,
- )
-
- def add(self, message: Message):
- self.chats.append(message)
-
-
-# these are the classes that we use for tune datasets from r-stack
-
-
-class TuneChats(list):
- """This class implements some basic container methods for a list of Chat objects"""
-
- def __init__(self):
- self.keys = {}
- self.items: List[Chat] = []
- self.idx_dict: Dict[int, Tuple[Any, ...]] = {}
- self.key_to_items_idx: Dict[int, List[int]] = {}
-
- def __repr__(self) -> str:
- return (
- f"TuneChats(unq_keys={len(self.key_to_items_idx)}, items={len(self.items)})"
- )
-
- def __len__(self) -> int:
- return len(self.items)
-
- def __iter__(self) -> Generator[Chat, None, None]:
- for x in self.items:
- yield x
-
- def stream(self) -> Generator[Chat, None, None]:
- for x in self:
- yield x
-
- def __getitem__(self, __index) -> List[Chat]:
- return self.items[__index]
-
- def table(self) -> str:
- try:
- from tabulate import tabulate
- except ImportError:
- raise ImportError("Install tabulate to use this method")
-
- table = []
- for k, v in self.idx_dict.items():
- table.append(
- [
- *v,
- len(self.key_to_items_idx[k]),
- f"{len(self.key_to_items_idx[k])/len(self)*100:0.2f}%",
- ]
- )
- return tabulate(table, headers=[*list(self.keys), "count", "percentage"])
-
- # data manipulation
-
- def append(self, __object: Any) -> None:
- if not self.items:
- self.keys = __object.meta.keys()
- if self.keys != __object.meta.keys():
- raise ValueError("Keys should match")
- self.idx_dict.setdefault(__object.value_hash, __object.values)
- self.key_to_items_idx.setdefault(__object.value_hash, [])
- self.key_to_items_idx[__object.value_hash].append(len(self.items))
- self.items.append(__object)
-
- def add(self, x: Chat):
- return self.append(x)
-
- def extend(self, __iterable: Iterable) -> None:
- if hasattr(__iterable, "items"):
- for x in __iterable.items: # type: ignore
- self.append(x)
- elif isinstance(__iterable, Iterable):
- for x in __iterable:
- self.append(x)
- else:
- raise ValueError("Unknown iterable")
-
- def shuffle(self, seed: Optional[int] = None) -> None:
- """Perform in place shuffle"""
- # shuffle using indices, self.items and self.key_to_items_idx
- idx = list(range(len(self.items)))
- if seed:
- rng = random.Random(seed)
- rng.shuffle(idx)
- else:
- random.shuffle(idx)
- self.items = [self.items[i] for i in idx]
- self.key_to_items_idx = {}
- for i, x in enumerate(self.items):
- self.key_to_items_idx.setdefault(x.value_hash, [])
- self.key_to_items_idx[x.value_hash].append(i)
-
- def create_te_split(self, test_items: int | float = 0.1) -> Tuple["TuneChats", ...]:
- try:
- import numpy as np
- except ImportError:
- raise ImportError("Install numpy to use `create_te_split` method")
-
- train_ds = TuneChats()
- eval_ds = TuneChats()
- items_np_arr = np.array(self.items)
- for k, v in self.key_to_items_idx.items():
- if isinstance(test_items, float):
- if int(len(v) * test_items) < 1:
- raise ValueError(
- f"Test percentage {test_items} is too high for the dataset key '{k}'"
- )
- split_ids = random.sample(v, int(len(v) * test_items))
- else:
- if test_items > len(v):
- raise ValueError(
- f"Test items {test_items} is too high for the dataset key '{k}'"
- )
- split_ids = random.sample(v, test_items)
-
- # get items
- eval_items = items_np_arr[split_ids]
- train_items = items_np_arr[np.setdiff1d(v, split_ids)]
- train_ds.extend(train_items)
- eval_ds.extend(eval_items)
-
- return train_ds, eval_ds
-
- # ser / deser
-
- def to_dict(self):
- return {"items": [x.to_dict() for x in self.items]}
-
- @classmethod
- def from_dict(cls, data):
- bench_dataset = cls()
- for item in data["items"]:
- bench_dataset.append(Chat.from_dict(item))
- return bench_dataset
-
- def to_disk(self, folder: str, fmt: Optional[str] = None):
- if fmt:
- logger.warn(
- f"exporting to {fmt} format, you cannot recreate the dataset from this."
- )
- os.makedirs(folder)
- with open(f"{folder}/tuneds.jsonl", "w") as f:
- for sample in self.items:
- if fmt == "sharegpt":
- item, _ = sample.to_ft()
- elif fmt is None:
- item = sample.to_dict()
- else:
- raise ValueError(f"Unknown format: {fmt}")
- f.write(to_json(item, tight=True) + "\n") # type: ignore
-
- @classmethod
- def from_disk(cls, folder: str):
- bench_dataset = cls()
- with open(f"{folder}/tuneds.jsonl", "r") as f:
- for line in f:
- item = json.loads(line)
- bench_dataset.append(Chat.from_dict(item))
- return bench_dataset
-
- def to_hf_dataset(self) -> Tuple["datasets.Dataset", List]: # type: ignore
- try:
- import datasets as dst
- except ImportError:
- raise ImportError("Install huggingface datasets library to use this method")
-
- _ds_list = []
- meta_list = []
- for x in self.items:
- sample, meta = x.to_ft()
- _ds_list.append(sample)
- meta_list.append(meta)
- return dst.Dataset.from_list(_ds_list), meta_list
-
- # properties
-
- def can_train_koro_regression(self) -> bool:
- return all(["koro.regression" in x.meta for x in self])
-
-
-class TuneDataset:
- """This class is a container for training and evaulation datasets, useful for serialising items to and from disk"""
-
- def __init__(self, train: TuneChats, eval: TuneChats):
- self.train_ds = train
- self.eval_ds = eval
-
- def __repr__(self) -> str:
- return f"TuneDataset(\n train={self.train_ds},\n eval={self.eval_ds}\n)"
-
- @classmethod
- def from_list(cls, items: List["TuneDataset"]):
- train_ds = TuneChats()
- eval_ds = TuneChats()
- for item in items:
- train_ds.extend(item.train_ds)
- eval_ds.extend(item.eval_ds)
- return cls(train=train_ds, eval=eval_ds)
-
- def to_hf_dict(self) -> Tuple["datasets.DatasetDict", Dict[str, List]]: # type: ignore
- try:
- import datasets as dst
- except ImportError:
- raise ImportError("Install huggingface datasets library to use this method")
-
- train_ds, train_meta = self.train_ds.to_hf_dataset()
- eval_ds, eval_meta = self.eval_ds.to_hf_dataset()
- return dst.DatasetDict(train=train_ds, eval=eval_ds), {
- "train": train_meta,
- "eval": eval_meta,
- }
-
- def to_disk(self, folder: str, fmt: Optional[str] = None):
- config = {}
- config["type"] = "tune"
- config["hf_type"] = fmt
- os.makedirs(folder)
- self.train_ds.to_disk(f"{folder}/train", fmt=fmt)
- self.eval_ds.to_disk(f"{folder}/eval", fmt=fmt)
- to_json(config, fp=f"{folder}/tune_config.json", tight=True)
-
- @classmethod
- def from_disk(cls, folder: str):
- if not os.path.exists(folder):
- raise ValueError(f"Folder '{folder}' does not exist")
- if not os.path.exists(f"{folder}/train"):
- raise ValueError(f"Folder '{folder}/train' does not exist")
- if not os.path.exists(f"{folder}/eval"):
- raise ValueError(f"Folder '{folder}/eval' does not exist")
- if not os.path.exists(f"{folder}/tune_config.json"):
- raise ValueError(f"File '{folder}/tune_config.json' does not exist")
-
- # not sure what to do with these
- with open(f"{folder}/tune_config.json", "r") as f:
- config = json.load(f)
- return cls(
- train=TuneChats.from_disk(f"{folder}/train"),
- eval=TuneChats.from_disk(f"{folder}/eval"),
- )
diff --git a/chainfury/cli.py b/chainfury/cli.py
index ae1067c2..650c0190 100644
--- a/chainfury/cli.py
+++ b/chainfury/cli.py
@@ -1,19 +1,23 @@
# Copyright © 2023- Frello Technology Private Limited
+import dotenv
+
+dotenv.load_dotenv()
+
import os
import sys
import json
from fire import Fire
+from typing import Optional
from chainfury import Chain
from chainfury.version import __version__
-from chainfury.components import all_items
-from chainfury.core import model_registry, programatic_actions_registry, memory_registry
+from chainfury.core import model_registry
+from chainfury.types import Thread, Message
-def help():
- print(
- f"""
+class CLI:
+ info = rf"""
___ _ _ ___
/ __| |_ __ _(_)_ _ | __| _ _ _ _ _
| (__| ' \/ _` | | ' \ | _| || | '_| || |
@@ -23,124 +27,121 @@ def help():
ae e0 a5 87 e0 a4 b5 20 e0 a4 9c e0 a4 af
e0 a4 a4 e0 a5 87
-🦋 Welcome to ChainFury Engine!
-
cf_version: {__version__}
-The chaining engine behind chat.tune.app
-
-A powerful way to program for the "Software 2.0" era. Read more:
-
-- https://blog.nimblebox.ai/new-flow-engine-from-scratch
-- https://blog.nimblebox.ai/fury-actions
-- https://gist.github.com/yashbonde/002c527853e04869bfaa04646f3e0974
-- https://tunehq.ai
-- https://chat.tune.app
-- https://studio.tune.app
-
+🦋 The FOSS chaining engine behind chat.tune.app
🌟 us on https://github.com/NimbleBoxAI/ChainFury
-
-Build with ♥️ by Tune AI
-
-🌊 Chennai, India
+♥️ Built by [Tune AI](https://tunehq.ai) from ECR, Chennai 🌊
"""
- )
-
-
-def run(
- chain: str,
- inp: str,
- stream: bool = False,
- print_thoughts: bool = False,
- f=sys.stdout,
-):
- """
- Run a chain with input and write the outputs.
-
- Args:
- chain (str): This can be one of json filepath (e.g. "/chain.json"), json string (e.g. '{"id": "99jcjs9j2", ...}'),
- chain id (e.g. "99jcjs9j2")
- inp (str): This can be one of json filepath (e.g. "/input.json"), json string (e.g. '{"foo": "bar", ...}')
- stream (bool, optional): Whether to stream the output. Defaults to False.
- print_thoughts (bool, optional): Whether to print thoughts. Defaults to False.
- f (file, optional): File to write the output to. Defaults to `sys.stdout`.
-
- Examples:
- >>> $ cf run ./sample.json {"foo": "bar"}
- """
- # validate inputs
- if isinstance(inp, str):
- if os.path.exists(inp):
- with open(inp, "r") as f:
- inp = json.load(f)
+
+ def run(
+ self,
+ chain: str,
+ inp: str,
+ stream: bool = False,
+ print_thoughts: bool = False,
+ f=sys.stdout,
+ ):
+ """
+ Run a chain with input and write the outputs.
+
+ Args:
+ chain (str): This can be one of json filepath (e.g. "/chain.json"), json string (e.g. '{"id": "99jcjs9j2", ...}'),
+ chain id (e.g. "99jcjs9j2")
+ inp (str): This can be one of json filepath (e.g. "/input.json"), json string (e.g. '{"foo": "bar", ...}')
+ stream (bool, optional): Whether to stream the output. Defaults to False.
+ print_thoughts (bool, optional): Whether to print thoughts. Defaults to False.
+ f (file, optional): File to write the output to. Defaults to `sys.stdout`.
+
+ Examples:
+ >>> $ cf run ./sample.json {"foo": "bar"}
+ """
+ # validate inputs
+ if isinstance(inp, str):
+ if os.path.exists(inp):
+ with open(inp, "r") as f:
+ inp = json.load(f)
+ else:
+ try:
+ inp = json.loads(inp)
+ except Exception as e:
+ raise ValueError(
+ "Input must be a valid json string or a json file path"
+ )
+ assert isinstance(inp, dict), "Input must be a dict"
+
+ # create chain
+ chain_obj = None
+ if isinstance(chain, str):
+ if os.path.exists(chain):
+ with open(chain, "w") as f:
+ chain = json.load(f)
+ if len(chain) == 8:
+ chain_obj = Chain.from_id(chain)
+ else:
+ chain = json.loads(chain)
+ elif isinstance(chain, dict):
+ chain_obj = Chain.from_dict(chain)
+ assert chain_obj is not None, "Chain not found"
+
+ # output
+ if isinstance(f, str):
+ f = open(f, "w")
+
+ # run the chain
+ if stream:
+ cf_response_gen = chain_obj.stream(inp, print_thoughts=print_thoughts)
+ for ir, done in cf_response_gen:
+ if not done:
+ f.write(json.dumps(ir) + "\n")
else:
+ out, buffer = chain_obj(inp, print_thoughts=print_thoughts)
+ for k, v in buffer.items():
+ f.write(json.dumps({k: v}) + "\n")
+
+ # close file
+ f.close()
+
+ def sh(
+ self,
+ api: str = "tuneapi",
+ model: str = "rohan/mixtral-8x7b-inst-v0-1-32k", # "kaushikaakash04/tune-blob"
+ token: Optional[str] = None,
+ stream: bool = True,
+ ):
+ cf_model = model_registry.get(api)
+ if token is not None:
+ cf_model.set_api_token(token)
+
+ # loop for user input through command line
+ thread = Thread()
+ usr_cntr = 0
+ while True:
try:
- inp = json.loads(inp)
- except Exception as e:
- raise ValueError(
- "Input must be a valid json string or a json file path"
+ user_input = input(
+ f"\033[1m\033[33m [{usr_cntr:02d}] YOU \033[39m:\033[0m "
)
- assert isinstance(inp, dict), "Input must be a dict"
-
- # create chain
- chain_obj = None
- if isinstance(chain, str):
- if os.path.exists(chain):
- with open(chain, "w") as f:
- chain = json.load(f)
- if len(chain) == 8:
- chain_obj = Chain.from_id(chain)
- else:
- chain = json.loads(chain)
- elif isinstance(chain, dict):
- chain_obj = Chain.from_dict(chain)
- assert chain_obj is not None, "Chain not found"
-
- # output
- if isinstance(f, str):
- f = open(f, "w")
-
- # run the chain
- if stream:
- cf_response_gen = chain_obj.stream(inp, print_thoughts=print_thoughts)
- for ir, done in cf_response_gen:
- if not done:
- f.write(json.dumps(ir) + "\n")
- else:
- out, buffer = chain_obj(inp, print_thoughts=print_thoughts)
- for k, v in buffer.items():
- f.write(json.dumps({k: v}) + "\n")
-
- # close file
- f.close()
+ except KeyboardInterrupt:
+ break
+ if user_input == "exit" or user_input == "quit" or user_input == "":
+ break
+ thread.add(Message(user_input, Message.HUMAN))
+
+ print(f"\033[1m\033[34m ASSISTANT \033[39m:\033[0m ", end="", flush=True)
+ if stream:
+ response = ""
+ for str_token in cf_model.stream_chat(thread, model=model):
+ response += str_token
+ print(str_token, end="", flush=True)
+ print() # new line
+ thread.add(Message(response, Message.GPT))
+ else:
+ response = cf_model.chat(thread, model=model)
+ print(response)
+
+ thread.add(Message(response, Message.GPT))
+ usr_cntr += 1
def main():
- Fire(
- {
- "comp": {
- "all": lambda: print(all_items),
- "model": {
- "list": list(model_registry.get_models()),
- "all": model_registry.get_models(),
- "get": model_registry.get,
- },
- "prog": {
- "list": list(programatic_actions_registry.get_nodes()),
- "all": programatic_actions_registry.get_nodes(),
- },
- "memory": {
- "list": list(memory_registry.get_nodes()),
- "all": memory_registry.get_nodes(),
- },
- },
- "help": help,
- "run": run,
- "version": lambda: print(
- f"""ChainFury 🦋 Engine
-
-chainfury=={__version__}
-"""
- ),
- }
- )
+ Fire(CLI)
diff --git a/chainfury/client.py b/chainfury/client.py
index 5163a975..001e92e0 100644
--- a/chainfury/client.py
+++ b/chainfury/client.py
@@ -1,6 +1,5 @@
# Copyright © 2023- Frello Technology Private Limited
-import os
import requests
from functools import lru_cache
from typing import Dict, Any, Tuple
diff --git a/chainfury/components/const.py b/chainfury/components/const.py
index fbb7d6e3..59a99a64 100644
--- a/chainfury/components/const.py
+++ b/chainfury/components/const.py
@@ -12,7 +12,7 @@ class Env:
* CF_URL: ChainFury API URL
* NBX_DEPLOY_URL: NimbleBox Deploy URL
* NBX_DEPLOY_KEY: NimbleBox Deploy API key
- * TUNECHAT_KEY: ChatNBX API key, see chat.nbox.ai
+ * TUNEAPI_TOKEN: ChatNBX API key, see chat.nbox.ai
* OPENAI_TOKEN: OpenAI API token, see platform.openai.com
* SERPER_API_KEY: Serper API key, see serper.dev/
* STABILITY_KEY: Stability API key, see dreamstudio.ai
@@ -29,7 +29,7 @@ class Env:
NBX_DEPLOY_KEY = lambda x: x or os.getenv("NBX_DEPLOY_KEY", "")
## different keys for different 3rd party APIs
- TUNECHAT_KEY = lambda x: x or os.getenv("TUNECHAT_KEY", "")
+ TUNEAPI_TOKEN = lambda x: x or os.getenv("TUNEAPI_TOKEN", "")
OPENAI_TOKEN = lambda x: x or os.getenv("OPENAI_TOKEN", "")
SERPER_API_KEY = lambda x: x or os.getenv("SERPER_API_KEY", "")
diff --git a/chainfury/components/functional/__init__.py b/chainfury/components/functional/__init__.py
index c64c9b06..789086d2 100644
--- a/chainfury/components/functional/__init__.py
+++ b/chainfury/components/functional/__init__.py
@@ -237,7 +237,7 @@ def echo(message: str) -> Tuple[Dict[str, Dict[str, str]], Optional[Exception]]:
programatic_actions_registry.register(
fn=echo,
- outputs={"message": (0,)}, # type: ignore
+ outputs={"message": ()}, # type: ignore
node_id="chainfury-echo",
description="I stared into the abyss and it stared back at me. Echoes the message, used for debugging",
)
diff --git a/chainfury/components/openai/__init__.py b/chainfury/components/openai/__init__.py
index bfbcd35d..0fddf4c7 100644
--- a/chainfury/components/openai/__init__.py
+++ b/chainfury/components/openai/__init__.py
@@ -13,23 +13,27 @@
UnAuthException,
)
from chainfury.components.const import Env
-from chainfury.chat import Chat
+from chainfury.types import Thread
class OpenaiGPTModel(Model):
def __init__(self, id: Optional[str] = None):
self._openai_model_id = id
+ self.openai_api_key = Secret(Env.OPENAI_TOKEN(""))
super().__init__(
id="openai-chat",
description="Use OpenAI chat models",
usage=["usage", "total_tokens"],
)
+ def set_api_token(self, token: str) -> None:
+ self.openai_api_key = Secret(token)
+
def chat(
self,
- chats: Chat,
+ chats: Thread,
model: Optional[str] = None,
- openai_api_key: Secret = Secret(""),
+ token: Secret = Secret(""),
temperature: float = 1.0,
top_p: float = 1.0,
n: int = 1,
@@ -49,7 +53,7 @@ def chat(
Args:
messages: A list of messages describing the conversation so far
model: ID of the model to use. See [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create).
- openai_api_key (Secret): The OpenAI API key. Defaults to "" or the OPENAI_TOKEN environment variable.
+ token (Secret): The OpenAI API key. Defaults to "" or the OPENAI_TOKEN environment variable.
temperature: Optional. What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both. Defaults to 1.
top_p: Optional. An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. Defaults to 1.
n: Optional. How many chat completion choices to generate for each input message. Defaults to 1.
@@ -65,14 +69,13 @@ def chat(
Returns:
Any: The completion(s) generated by the API.
"""
- if not openai_api_key:
- openai_api_key = Secret(Env.OPENAI_TOKEN("")).value # type: ignore
- if not openai_api_key:
+ if not token and not self.openai_api_key.value:
raise Exception(
"OpenAI API key not found. Please set OPENAI_TOKEN environment variable or pass through function"
)
- if isinstance(chats, Chat):
- messages = chats.to_dict()
+
+ if isinstance(chats, Thread):
+ messages = chats.to_dict()["chats"]
else:
messages = chats
@@ -83,7 +86,7 @@ def _fn():
"https://api.openai.com/v1/chat/completions",
headers={
"Content-Type": "application/json",
- "Authorization": f"Bearer {openai_api_key}",
+ "Authorization": f"Bearer {token}",
},
json={
"model": model,
diff --git a/chainfury/components/tune/__init__.py b/chainfury/components/tune/__init__.py
index b1caeda0..7aa550cf 100644
--- a/chainfury/components/tune/__init__.py
+++ b/chainfury/components/tune/__init__.py
@@ -7,7 +7,7 @@
from chainfury import Secret, model_registry, exponential_backoff, Model
from chainfury.components.const import Env
-from chainfury.chat import Chat
+from chainfury.types import Thread
class TuneModel(Model):
@@ -15,23 +15,79 @@ class TuneModel(Model):
def __init__(self, id: Optional[str] = None):
self._tune_model_id = id
+ self.tune_api_token = Secret(Env.TUNEAPI_TOKEN(""))
super().__init__(
- id="chatnbx",
- description="Chat with the ChatNBX API with OpenAI compatability, see more at https://chat.nbox.ai/",
+ id="tuneapi",
+ description="Chat with the Tune Studio APIs, see more at https://studio.tune.app/",
usage=["usage", "total_tokens"],
)
+ def set_api_token(self, token: str) -> None:
+ self.tune_api_token = Secret(token)
+
def chat(
self,
- chats: Chat,
- chatnbx_api_key: Secret = Secret(""),
+ chats: Thread,
model: Optional[str] = None,
max_tokens: int = 1024,
temperature: float = 1,
*,
- retry_count: int = 3,
- retry_delay: int = 1,
- ) -> Dict[str, Any]:
+ token: Secret = Secret(""),
+ ) -> str:
+ """
+ Chat with the Tune Studio APIs, see more at https://studio.tune.app/
+
+ Note: This is a API is partially compatible with OpenAI's API, so `messages` should be of type :code:`[{"role": ..., "content": ...}]`
+
+ Args:
+ model (str): The model to use, see https://studio.nbox.ai/ for more info
+ messages (List[Dict[str, str]]): A list of messages to send to the API which are OpenAI compatible
+ token (Secret, optional): The API key to use or set TUNEAPI_TOKEN environment variable
+ max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 1024.
+ temperature (float, optional): The higher the temperature, the crazier the text. Defaults to 1.
+
+ Returns:
+ Dict[str, Any]: The response from the API
+ """
+ if not token and not self.tune_api_token.has_value(): # type: ignore
+ raise Exception(
+ "Tune API key not found. Please set TUNEAPI_TOKEN environment variable or pass through function"
+ )
+ token = token or self.tune_api_token
+ if isinstance(chats, Thread):
+ messages = chats.to_dict()["chats"]
+ else:
+ messages = chats
+
+ model = model or self._tune_model_id
+ url = "https://proxy.tune.app/chat/completions"
+ headers = {
+ "Authorization": token.value,
+ "Content-Type": "application/json",
+ }
+ data = {
+ "temperature": temperature,
+ "messages": messages,
+ "model": model,
+ "stream": False,
+ "max_tokens": max_tokens,
+ }
+ response = requests.post(url, headers=headers, json=data)
+ try:
+ response.raise_for_status()
+ except Exception as e:
+ raise e
+ return response.json()["choices"][0]["message"]["content"]
+
+ def stream_chat(
+ self,
+ chats: Thread,
+ model: Optional[str] = None,
+ max_tokens: int = 1024,
+ temperature: float = 1,
+ *,
+ token: Secret = Secret(""),
+ ):
"""
Chat with the ChatNBX API with OpenAI compatability, see more at https://chat.nbox.ai/
@@ -40,45 +96,57 @@ def chat(
Args:
model (str): The model to use, see https://chat.nbox.ai/ for more info
messages (List[Dict[str, str]]): A list of messages to send to the API which are OpenAI compatible
- chatnbx_api_key (Secret, optional): The API key to use or set TUNECHAT_KEY environment variable
+ token (Secret, optional): The API key to use or set TUNEAPI_TOKEN environment variable
max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 1024.
temperature (float, optional): The higher the temperature, the crazier the text. Defaults to 1.
Returns:
Dict[str, Any]: The response from the API
"""
- if not chatnbx_api_key:
- chatnbx_api_key = Secret(Env.TUNECHAT_KEY("")).value # type: ignore
- if not chatnbx_api_key:
+ if not token and not self.tune_api_token.has_value(): # type: ignore
raise Exception(
- "OpenAI API key not found. Please set TUNECHAT_KEY environment variable or pass through function"
+ "Tune API key not found. Please set TUNEAPI_TOKEN environment variable or pass through function"
)
- if isinstance(chats, Chat):
- messages = chats.to_dict()
+
+ token = token or self.tune_api_token
+ if isinstance(chats, Thread):
+ messages = chats.to_dict()["chats"]
else:
messages = chats
model = model or self._tune_model_id
-
- def _fn():
- url = "https://proxy.tune.app/chat/completions"
- headers = {
- "Authorization": chatnbx_api_key,
- "Content-Type": "application/json",
- }
- data = {
- "temperature": temperature,
- "messages": messages,
- "model": model,
- "stream": False,
- "max_tokens": max_tokens,
- }
- response = requests.post(url, headers=headers, json=data)
- return response.json()["choices"][0]["message"]["content"]
-
- return exponential_backoff(
- _fn, max_retries=retry_count, retry_delay=retry_delay
+ url = "https://proxy.tune.app/chat/completions"
+ headers = {
+ "Authorization": token.value,
+ "Content-Type": "application/json",
+ }
+ data = {
+ "temperature": temperature,
+ "messages": messages,
+ "model": model,
+ "stream": True,
+ "max_tokens": max_tokens,
+ }
+ response = requests.post(
+ url,
+ headers=headers,
+ json=data,
+ stream=True,
)
+ try:
+ response.raise_for_status()
+ except Exception as e:
+ print(response.text)
+ raise e
+ for line in response.iter_lines():
+ line = line.decode().strip()
+ if line:
+ try:
+ yield json.loads(line.replace("data: ", ""))["choices"][0]["delta"][
+ "content"
+ ]
+ except:
+ break
tune_model = model_registry.register(model=TuneModel())
diff --git a/chainfury/core/actions.py b/chainfury/core.py
similarity index 60%
rename from chainfury/core/actions.py
rename to chainfury/core.py
index d1d06a25..11370f05 100644
--- a/chainfury/core/actions.py
+++ b/chainfury/core.py
@@ -8,6 +8,7 @@
"""
import copy
+import random
from uuid import uuid4
from typing import Any, List, Optional, Dict, Tuple
@@ -24,7 +25,7 @@
put_value_by_keys,
)
from chainfury.utils import logger
-from chainfury.core.models import model_registry
+
# Programtic Actions Registry
# ---------------------------
@@ -497,6 +498,310 @@ def get_count_for_nodes(self, node_id: str) -> int:
return self.counter.get(node_id, 0)
+# Memory Registry
+# ---------------------------
+# All the components that have to do with storage and retreival of data from the DB. This sections is supppsed to act
+# like the memory in an Von Neumann architecture.
+
+
+class Memory:
+ """Class to wrap the DB functions as a callable.
+
+ Args:
+ node_id (str): The id of the node
+ fn (object): The function that is used for this action
+ vector_key (str): The key for the vector in the DB
+ read_mode (bool, optional): If the function is a read function, if `False` then this is a write function.
+ """
+
+ fields_model = [
+ Var(
+ name="items",
+ type=[Var(type="string"), Var(type="array", items=[Var(type="string")])],
+ required=True,
+ ),
+ Var(name="embedding_model", type="string", required=True),
+ Var(
+ name="embedding_model_params",
+ type="object",
+ additionalProperties=Var(type="string"),
+ ),
+ Var(name="embedding_model_key", type="string"),
+ Var(
+ name="translation_layer",
+ type="object",
+ additionalProperties=Var(type="string"),
+ ),
+ ]
+ """These are the fields that are used to map the input items to the embedding model, do not use directly"""
+
+ def __init__(
+ self, node_id: str, fn: object, vector_key: str, read_mode: bool = False
+ ):
+ self.node_id = node_id
+ self.fn = fn
+ self.vector_key = vector_key
+ self.read_mode = read_mode
+ self.fields_fn = func_to_vars(fn)
+ self.fields = self.fields_fn + self.fields_model
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Serialize the Memory object to a dict."""
+ return {
+ "node_id": self.node_id.split("-")[0],
+ "vector_key": self.vector_key,
+ "read_mode": self.read_mode,
+ }
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]):
+ """Deserialize the Memory object from a dict."""
+ read_mode = data["read_mode"]
+ if read_mode:
+ fn = memory_registry.get_read(data["node_id"])
+ else:
+ fn = memory_registry.get_write(data["node_id"])
+
+ # here we do return Memory type but instead of creating one we use a previously existing Node and return
+ # the fn for the Node which is ultimately this precise Memory object
+ return fn.fn # type: ignore
+
+ def __call__(self, **data: Dict[str, Any]) -> Any:
+ # the first thing we have to do is get the data for the model. This is actually a very hard problem because this
+ # function needs to call some other arbitrary function where we know the inputs to this function "items" but we
+ # do not know which variable to pass this to in the undelying model's function. Thus we need to take in a huge
+ # amount of things as more inputs ("embedding_model_key", "embedding_model_params"). Then we don't even know
+ # what the inputs to the underlying DB functionbare going to be, in which case we also need to add things like
+ # the translation that needs to be done ("translation_layer"). This makes the number of inputs a lot but
+ # ultimately is required to do the job for robust-ness. Which is why we provide a default for openai-embedding
+ # model. For any other model user will need to pass all the information.
+ model_fields: Dict[str, Any] = {}
+ for f in self.fields_model:
+ if f.required and f.name not in data:
+ raise Exception(
+ f"Field '{f.name}' is required in {self.node_id} but not present"
+ )
+ if f.name in data:
+ model_fields[f.name] = data.pop(f.name)
+
+ model_data = {**model_fields.get("embedding_model_params", {})}
+ model_id = model_fields.pop("embedding_model")
+
+ # TODO: @yashbonde - clean this mess up
+ # DEFAULT_MEMORY_CONSTANTS = {
+ # "openai-embedding": {
+ # "embedding_model_key": "input_strings",
+ # "embedding_model_params": {
+ # "model": "text-embedding-ada-002",
+ # },
+ # "translation_layer": {
+ # "embeddings": ["data", "*", "embedding"],
+ # },
+ # }
+ # }
+ # embedding_model_default_config = DEFAULT_MEMORY_CONSTANTS.get(model_id, {})
+ # if embedding_model_default_config:
+ # model_data = {
+ # **embedding_model_default_config.get("embedding_model_params", {}),
+ # **model_data,
+ # }
+ # model_key = embedding_model_default_config.get(
+ # "embedding_model_key", "items"
+ # ) or model_data.get("embedding_model_key")
+ # model_fields["translation_layer"] = model_fields.get(
+ # "translation_layer"
+ # ) or embedding_model_default_config.get("translation_layer")
+ # else:
+
+ req_keys = [x.name for x in self.fields_model[2:]]
+ if not all([x in model_fields for x in req_keys]):
+ raise Exception(f"Model {model_id} requires {req_keys} to be passed")
+ model_key = model_fields.get("embedding_model_key")
+ model_data = {
+ **model_fields.get("embedding_model_params", {}),
+ **model_data,
+ }
+ model_data[model_key] = model_fields.pop("items") # type: ignore
+ model = model_registry.get(model_id)
+ embeddings, err = model(model_data=model_data)
+ if err:
+ logger.error(f"error: {err}")
+ logger.error(f"traceback: {embeddings}")
+ raise err
+
+ # now that we have all the embeddings ready we now need to translate it to be fed into the DB function
+ translated_data = {}
+ for k, v in model_fields.get("translation_layer", {}).items():
+ translated_data[k] = get_value_by_keys(embeddings, v)
+
+ # create the dictionary to call the underlying function
+ db_data = {}
+ for f in self.fields_fn:
+ if f.required and not (f.name in data or f.name in translated_data):
+ raise Exception(
+ f"Field '{f.name}' is required in {self.node_id} but not present"
+ )
+ if f.name in data:
+ db_data[f.name] = data.pop(f.name)
+ if f.name in translated_data:
+ db_data[f.name] = translated_data.pop(f.name)
+ out, err = self.fn(**db_data) # type: ignore
+ return out, err
+
+
+class MemoryRegistry:
+ def __init__(self) -> None:
+ self._memories: Dict[str, Node] = {}
+
+ def register_write(
+ self,
+ component_name: str,
+ fn: object,
+ outputs: Dict[str, Any],
+ vector_key: str,
+ description: str = "",
+ tags: List[str] = [],
+ ) -> Node:
+ node_id = f"{component_name}-write"
+ mem_fn = Memory(node_id=node_id, fn=fn, vector_key=vector_key, read_mode=False)
+ output_fields = func_to_return_vars(fn, returns=outputs)
+ node = Node(
+ id=node_id,
+ fn=mem_fn,
+ type=Node.types.MEMORY,
+ fields=mem_fn.fields,
+ outputs=output_fields,
+ description=description,
+ tags=tags,
+ )
+ self._memories[node_id] = node
+ return node
+
+ def register_read(
+ self,
+ component_name: str,
+ fn: object,
+ outputs: Dict[str, Any],
+ vector_key: str,
+ description: str = "",
+ tags: List[str] = [],
+ ) -> Node:
+ node_id = f"{component_name}-read"
+ mem_fn = Memory(node_id=node_id, fn=fn, vector_key=vector_key, read_mode=True)
+ output_fields = func_to_return_vars(fn, returns=outputs)
+ node = Node(
+ id=node_id,
+ fn=mem_fn,
+ type=Node.types.MEMORY,
+ fields=mem_fn.fields,
+ outputs=output_fields,
+ description=description,
+ tags=tags,
+ )
+ self._memories[node_id] = node
+ return node
+
+ def get_write(self, node_id: str) -> Optional[Node]:
+ out = self._memories.get(node_id + "-write", None)
+ if out is None:
+ raise ValueError(f"Memory '{node_id}' not found")
+ return out
+
+ def get_read(self, node_id: str) -> Optional[Node]:
+ out = self._memories.get(node_id + "-read", None)
+ if out is None:
+ raise ValueError(f"Memory '{node_id}' not found")
+ return out
+
+ def get_nodes(self):
+ return {k: v.to_dict() for k, v in self._memories.items()}
+
+
+# Models Registry
+# ---------------
+# All the things below are for the models that are registered in the model registry, so that they can be used as inputs
+# in the chain. There can be several models that can put as inputs in a single chatbot.
+
+
+class ModelRegistry:
+ """Model registry contains metadata for all the models that are provided in the components"""
+
+ def __init__(self):
+ self.models: Dict[str, Model] = {}
+ self.counter: Dict[str, int] = {}
+ self.tags_to_models: Dict[str, List[str]] = {}
+
+ def has(self, id: str):
+ """A helper function to check if a model is registered or not"""
+ return id in self.models
+
+ def register(self, model: Model):
+ """Register a model in the registry
+
+ Args:
+ model (Model): Model to register
+ """
+ id = model.id
+ logger.debug(f"Registering model {id} at {id}")
+ if id in self.models:
+ raise Exception(f"Model {id} already registered")
+ self.models[id] = model
+ for tag in model.tags:
+ self.tags_to_models[tag] = self.tags_to_models.get(tag, []) + [id]
+ return model
+
+ def get_tags(self) -> List[str]:
+ """Get all the tags that are registered in the registry
+
+ Returns:
+ List[str]: List of tags
+ """
+ return list(self.tags_to_models.keys())
+
+ def get_models(self, tag: str = "") -> Dict[str, Dict[str, Any]]:
+ """Get all the models that are registered in the registry
+
+ Args:
+ tag (str, optional): Filter models by tag. Defaults to "".
+
+ Returns:
+ Dict[str, Dict[str, Any]]: Dictionary of models
+ """
+ items = {k: v.to_dict() for k, v in self.models.items()}
+ if tag:
+ items = {k: v for k, v in items.items() if tag in v.get("tags", [])}
+ return items
+
+ def get(self, id: str) -> Model:
+ """Get a model from the registry
+
+ Args:
+ id (str): Id of the model
+
+ Returns:
+ Model: Model
+ """
+ self.counter[id] = self.counter.get(id, 0) + 1
+ out = self.models.get(id, None)
+ if out is None:
+ raise ValueError(f"Model {id} not found")
+ return out
+
+ def get_count_for_model(self, id: str) -> int:
+ """Get the number of times a model is used
+
+ Args:
+ id (str): Id of the model
+
+ Returns:
+ int: Number of times the model is used
+ """
+ return self.counter.get(id, 0)
+
+ def get_any_model(self) -> Model:
+ return random.choice(list(self.models.values()))
+
+
# Initialise Registries
# ---------------------
@@ -511,3 +816,14 @@ def get_count_for_nodes(self, node_id: str) -> int:
`ai_actions_registry` is a global instance of `AIActionsRegistry` class. This is used to register and unregister
`AIAction` instances. This is used by the server to serve the registered actions.
"""
+
+memory_registry = MemoryRegistry()
+"""
+`memory_registry` is a global instance of MemoryRegistry class. This is used to register and unregister Memory instances.
+This is what the user should use when they want to use the memory elements in their chain.
+"""
+
+model_registry = ModelRegistry()
+"""
+`model_registry` is a global variable that is used to register models. It is an instance of ModelRegistry class.
+"""
diff --git a/chainfury/core/__init__.py b/chainfury/core/__init__.py
deleted file mode 100644
index f24cb1e5..00000000
--- a/chainfury/core/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# Copyright © 2023- Frello Technology Private Limited
-
-from chainfury.core.models import model_registry
-from chainfury.core.actions import (
- programatic_actions_registry,
- ai_actions_registry,
- AIAction,
-)
-from chainfury.core.memory import memory_registry, Memory
diff --git a/chainfury/core/memory.py b/chainfury/core/memory.py
deleted file mode 100644
index 069cf3ed..00000000
--- a/chainfury/core/memory.py
+++ /dev/null
@@ -1,240 +0,0 @@
-# Copyright © 2023- Frello Technology Private Limited
-
-"""
-Actions
-=======
-
-All actions that the AI can do.
-"""
-
-from typing import Any, List, Optional, Dict
-
-from chainfury.base import (
- Node,
- func_to_return_vars,
- func_to_vars,
- Var,
- get_value_by_keys,
-)
-from chainfury.utils import logger
-from chainfury.core.models import model_registry
-
-
-class Memory:
- """Class to wrap the DB functions as a callable.
-
- Args:
- node_id (str): The id of the node
- fn (object): The function that is used for this action
- vector_key (str): The key for the vector in the DB
- read_mode (bool, optional): If the function is a read function, if `False` then this is a write function.
- """
-
- fields_model = [
- Var(
- name="items",
- type=[Var(type="string"), Var(type="array", items=[Var(type="string")])],
- required=True,
- ),
- Var(name="embedding_model", type="string", required=True),
- Var(
- name="embedding_model_params",
- type="object",
- additionalProperties=Var(type="string"),
- ),
- Var(name="embedding_model_key", type="string"),
- Var(
- name="translation_layer",
- type="object",
- additionalProperties=Var(type="string"),
- ),
- ]
- """These are the fields that are used to map the input items to the embedding model, do not use directly"""
-
- def __init__(
- self, node_id: str, fn: object, vector_key: str, read_mode: bool = False
- ):
- self.node_id = node_id
- self.fn = fn
- self.vector_key = vector_key
- self.read_mode = read_mode
- self.fields_fn = func_to_vars(fn)
- self.fields = self.fields_fn + self.fields_model
-
- def to_dict(self) -> Dict[str, Any]:
- """Serialize the Memory object to a dict."""
- return {
- "node_id": self.node_id.split("-")[0],
- "vector_key": self.vector_key,
- "read_mode": self.read_mode,
- }
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]):
- """Deserialize the Memory object from a dict."""
- read_mode = data["read_mode"]
- if read_mode:
- fn = memory_registry.get_read(data["node_id"])
- else:
- fn = memory_registry.get_write(data["node_id"])
-
- # here we do return Memory type but instead of creating one we use a previously existing Node and return
- # the fn for the Node which is ultimately this precise Memory object
- return fn.fn # type: ignore
-
- def __call__(self, **data: Dict[str, Any]) -> Any:
- # the first thing we have to do is get the data for the model. This is actually a very hard problem because this
- # function needs to call some other arbitrary function where we know the inputs to this function "items" but we
- # do not know which variable to pass this to in the undelying model's function. Thus we need to take in a huge
- # amount of things as more inputs ("embedding_model_key", "embedding_model_params"). Then we don't even know
- # what the inputs to the underlying DB functionbare going to be, in which case we also need to add things like
- # the translation that needs to be done ("translation_layer"). This makes the number of inputs a lot but
- # ultimately is required to do the job for robust-ness. Which is why we provide a default for openai-embedding
- # model. For any other model user will need to pass all the information.
- model_fields: Dict[str, Any] = {}
- for f in self.fields_model:
- if f.required and f.name not in data:
- raise Exception(
- f"Field '{f.name}' is required in {self.node_id} but not present"
- )
- if f.name in data:
- model_fields[f.name] = data.pop(f.name)
-
- model_data = {**model_fields.get("embedding_model_params", {})}
- model_id = model_fields.pop("embedding_model")
-
- # TODO: @yashbonde - clean this mess up
- # DEFAULT_MEMORY_CONSTANTS = {
- # "openai-embedding": {
- # "embedding_model_key": "input_strings",
- # "embedding_model_params": {
- # "model": "text-embedding-ada-002",
- # },
- # "translation_layer": {
- # "embeddings": ["data", "*", "embedding"],
- # },
- # }
- # }
- # embedding_model_default_config = DEFAULT_MEMORY_CONSTANTS.get(model_id, {})
- # if embedding_model_default_config:
- # model_data = {
- # **embedding_model_default_config.get("embedding_model_params", {}),
- # **model_data,
- # }
- # model_key = embedding_model_default_config.get(
- # "embedding_model_key", "items"
- # ) or model_data.get("embedding_model_key")
- # model_fields["translation_layer"] = model_fields.get(
- # "translation_layer"
- # ) or embedding_model_default_config.get("translation_layer")
- # else:
-
- req_keys = [x.name for x in self.fields_model[2:]]
- if not all([x in model_fields for x in req_keys]):
- raise Exception(f"Model {model_id} requires {req_keys} to be passed")
- model_key = model_fields.get("embedding_model_key")
- model_data = {
- **model_fields.get("embedding_model_params", {}),
- **model_data,
- }
- model_data[model_key] = model_fields.pop("items") # type: ignore
- model = model_registry.get(model_id)
- embeddings, err = model(model_data=model_data)
- if err:
- logger.error(f"error: {err}")
- logger.error(f"traceback: {embeddings}")
- raise err
-
- # now that we have all the embeddings ready we now need to translate it to be fed into the DB function
- translated_data = {}
- for k, v in model_fields.get("translation_layer", {}).items():
- translated_data[k] = get_value_by_keys(embeddings, v)
-
- # create the dictionary to call the underlying function
- db_data = {}
- for f in self.fields_fn:
- if f.required and not (f.name in data or f.name in translated_data):
- raise Exception(
- f"Field '{f.name}' is required in {self.node_id} but not present"
- )
- if f.name in data:
- db_data[f.name] = data.pop(f.name)
- if f.name in translated_data:
- db_data[f.name] = translated_data.pop(f.name)
- out, err = self.fn(**db_data) # type: ignore
- return out, err
-
-
-class MemoryRegistry:
- def __init__(self) -> None:
- self._memories: Dict[str, Node] = {}
-
- def register_write(
- self,
- component_name: str,
- fn: object,
- outputs: Dict[str, Any],
- vector_key: str,
- description: str = "",
- tags: List[str] = [],
- ) -> Node:
- node_id = f"{component_name}-write"
- mem_fn = Memory(node_id=node_id, fn=fn, vector_key=vector_key, read_mode=False)
- output_fields = func_to_return_vars(fn, returns=outputs)
- node = Node(
- id=node_id,
- fn=mem_fn,
- type=Node.types.MEMORY,
- fields=mem_fn.fields,
- outputs=output_fields,
- description=description,
- tags=tags,
- )
- self._memories[node_id] = node
- return node
-
- def register_read(
- self,
- component_name: str,
- fn: object,
- outputs: Dict[str, Any],
- vector_key: str,
- description: str = "",
- tags: List[str] = [],
- ) -> Node:
- node_id = f"{component_name}-read"
- mem_fn = Memory(node_id=node_id, fn=fn, vector_key=vector_key, read_mode=True)
- output_fields = func_to_return_vars(fn, returns=outputs)
- node = Node(
- id=node_id,
- fn=mem_fn,
- type=Node.types.MEMORY,
- fields=mem_fn.fields,
- outputs=output_fields,
- description=description,
- tags=tags,
- )
- self._memories[node_id] = node
- return node
-
- def get_write(self, node_id: str) -> Optional[Node]:
- out = self._memories.get(node_id + "-write", None)
- if out is None:
- raise ValueError(f"Memory '{node_id}' not found")
- return out
-
- def get_read(self, node_id: str) -> Optional[Node]:
- out = self._memories.get(node_id + "-read", None)
- if out is None:
- raise ValueError(f"Memory '{node_id}' not found")
- return out
-
- def get_nodes(self):
- return {k: v.to_dict() for k, v in self._memories.items()}
-
-
-memory_registry = MemoryRegistry()
-"""
-`memory_registry` is a global instance of MemoryRegistry class. This is used to register and unregister Memory instances.
-This is what the user should use when they want to use the memory elements in their chain.
-"""
diff --git a/chainfury/core/models.py b/chainfury/core/models.py
deleted file mode 100644
index e5642bce..00000000
--- a/chainfury/core/models.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# Copyright © 2023- Frello Technology Private Limited
-
-"""
-Models
-======
-
-All things required in a model.
-"""
-
-import random
-from typing import Any, List, Dict
-
-from chainfury.base import Model
-from chainfury.utils import logger
-
-
-# Models
-# ------
-# All the things below are for the models that are registered in the model registry, so that they can be used as inputs
-# in the chain. There can be several models that can put as inputs in a single chatbot.
-
-
-class ModelRegistry:
- """Model registry contains metadata for all the models that are provided in the components"""
-
- def __init__(self):
- self.models: Dict[str, Model] = {}
- self.counter: Dict[str, int] = {}
- self.tags_to_models: Dict[str, List[str]] = {}
-
- def has(self, id: str):
- """A helper function to check if a model is registered or not"""
- return id in self.models
-
- def register(self, model: Model):
- """Register a model in the registry
-
- Args:
- model (Model): Model to register
- """
- id = model.id
- logger.debug(f"Registering model {id} at {id}")
- if id in self.models:
- raise Exception(f"Model {id} already registered")
- self.models[id] = model
- for tag in model.tags:
- self.tags_to_models[tag] = self.tags_to_models.get(tag, []) + [id]
- return model
-
- def get_tags(self) -> List[str]:
- """Get all the tags that are registered in the registry
-
- Returns:
- List[str]: List of tags
- """
- return list(self.tags_to_models.keys())
-
- def get_models(self, tag: str = "") -> Dict[str, Dict[str, Any]]:
- """Get all the models that are registered in the registry
-
- Args:
- tag (str, optional): Filter models by tag. Defaults to "".
-
- Returns:
- Dict[str, Dict[str, Any]]: Dictionary of models
- """
- items = {k: v.to_dict() for k, v in self.models.items()}
- if tag:
- items = {k: v for k, v in items.items() if tag in v.get("tags", [])}
- return items
-
- def get(self, id: str) -> Model:
- """Get a model from the registry
-
- Args:
- id (str): Id of the model
-
- Returns:
- Model: Model
- """
- self.counter[id] = self.counter.get(id, 0) + 1
- out = self.models.get(id, None)
- if out is None:
- raise ValueError(f"Model {id} not found")
- return out
-
- def get_count_for_model(self, id: str) -> int:
- """Get the number of times a model is used
-
- Args:
- id (str): Id of the model
-
- Returns:
- int: Number of times the model is used
- """
- return self.counter.get(id, 0)
-
- def get_any_model(self) -> Model:
- return random.choice(list(self.models.values()))
-
-
-model_registry = ModelRegistry()
-"""
-`model_registry` is a global variable that is used to register models. It is an instance of ModelRegistry class.
-"""
diff --git a/chainfury/types.py b/chainfury/types.py
index a93e8816..185607fd 100644
--- a/chainfury/types.py
+++ b/chainfury/types.py
@@ -2,13 +2,25 @@
from datetime import datetime
from typing import Dict, Any, List, Optional
-
from pydantic import BaseModel, Field, ConfigDict
+# some types that are copied from the tuneapi types
+
+from tuneapi.types.chats import (
+ Message,
+ Thread,
+ ThreadsList,
+ Dataset,
+ human,
+ system,
+ assistant,
+)
+
+
# First is the set of types that are used in the chainfury itself
-class FENode(BaseModel):
+class UINode(BaseModel):
"""FENode is the node as required by the UI to render the node in the graph. If you do not care about the UI, you can
populate either the ``cf_id`` or ``cf_data``."""
@@ -57,14 +69,14 @@ class Edge(BaseModel):
class Dag(BaseModel):
"""This is visual representation of the chain. JSON of this is stored in the DB."""
- nodes: List[FENode]
+ nodes: List[UINode]
edges: List[Edge]
sample: Dict[str, Any] = Field(default_factory=dict)
main_in: str = ""
main_out: str = ""
-class CFPromptResult(BaseModel):
+class ChainResult(BaseModel):
"""This is a structured result of the prompt by the Chain. This is more useful for providing types on the server."""
result: str
@@ -72,6 +84,14 @@ class CFPromptResult(BaseModel):
task_id: str = ""
+# Then a set of types that are used in the API (client mode)
+
+
+class ApiLoginResponse(BaseModel):
+ message: str
+ token: Optional[str] = None
+
+
class ApiResponse(BaseModel):
"""This is the default response body of the API"""
@@ -149,18 +169,18 @@ class ApiActionUpdateRequest(BaseModel):
update_fields: List[str] = Field(description="The fields to update.")
-class ApiAuth(BaseModel):
+class ApiAuthRequest(BaseModel):
username: str
password: str
-class ApiSignUp(BaseModel):
+class ApiSignUpRequest(BaseModel):
username: str
email: str
password: str
-class ApiChangePassword(BaseModel):
+class ApiChangePasswordRequest(BaseModel):
username: str
old_password: str
new_password: str
@@ -168,3 +188,49 @@ class ApiChangePassword(BaseModel):
class ApiPromptFeedback(BaseModel):
score: int
+
+
+class ApiPromptFeedbackResponse(BaseModel):
+ rating: int
+
+
+class ApiToken(BaseModel):
+ key: str
+ token: str
+ meta: Optional[Dict[str, Any]] = {}
+
+
+class ApiListTokensResponse(BaseModel):
+ tokens: List[ApiToken]
+
+
+class ApiChainLog(BaseModel):
+ id: str
+ created_at: datetime
+ prompt_id: int
+ node_id: str
+ worker_id: str
+ message: Optional[str] = None
+ data: Optional[Dict[str, Any]] = None
+
+
+class ApiListChainLogsResponse(BaseModel):
+ logs: List[ApiChainLog]
+
+
+class ApiPrompt(BaseModel):
+ id: int
+ chatbot_id: str
+ input_prompt: str
+ created_at: datetime
+ session_id: str
+ meta: Optional[Dict[str, Any]] = None
+ response: Optional[str] = None
+ gpt_rating: Optional[str] = None
+ user_rating: Optional[int] = None
+ time_taken: Optional[float] = None
+ num_tokens: Optional[int] = None
+
+
+class ApiListPromptsResponse(BaseModel):
+ prompts: List[ApiPrompt]
diff --git a/chainfury/utils.py b/chainfury/utils.py
index b21ab592..77efe555 100644
--- a/chainfury/utils.py
+++ b/chainfury/utils.py
@@ -32,13 +32,13 @@ class CFEnv:
CF_LOG_LEVEL = lambda: os.getenv("CF_LOG_LEVEL", "info")
CF_FOLDER = lambda: os.path.expanduser(os.getenv("CF_FOLDER", "~/cf"))
+ CF_URL = lambda: os.getenv("CF_URL", "")
+ CF_TOKEN = lambda: os.getenv("CF_TOKEN", "")
CF_BLOB_STORAGE = lambda: os.path.join(CFEnv.CF_FOLDER(), "blob")
CF_BLOB_ENGINE = lambda: os.getenv("CF_BLOB_ENGINE", "local")
CF_BLOB_BUCKET = lambda: os.getenv("CF_BLOB_BUCKET", "")
CF_BLOB_PREFIX = lambda: os.getenv("CF_BLOB_PREFIX", "")
CF_BLOB_AWS_CLOUD_FRONT = lambda: os.getenv("CF_BLOB_AWS_CLOUD_FRONT", "")
- CF_URL = lambda: os.getenv("CF_URL", "")
- CF_TOKEN = lambda: os.getenv("CF_TOKEN", "")
def store_blob(key: str, value: bytes, engine: str = "", bucket: str = "") -> str:
@@ -311,7 +311,7 @@ def threaded_map(
results[i] = res
except Exception as e:
if safe:
- results[i] = e
+ results[i] = e # type: ignore
else:
raise e
return results
@@ -418,6 +418,10 @@ def get_now_float() -> float: # type: ignore
"""Get the current datetime in UTC timezone as a float"""
return SimplerTimes.get_now_datetime().timestamp()
+ def get_now_fp64() -> float: # type: ignore
+ """Get the current datetime in UTC timezone as a float"""
+ return SimplerTimes.get_now_datetime().timestamp()
+
def get_now_i64() -> int: # type: ignore
"""Get the current datetime in UTC timezone as a int"""
return int(SimplerTimes.get_now_datetime().timestamp())
diff --git a/chainfury/version.py b/chainfury/version.py
index c7502faf..7eb5f4cf 100644
--- a/chainfury/version.py
+++ b/chainfury/version.py
@@ -1,6 +1,6 @@
# Copyright © 2023- Frello Technology Private Limited
-__version__ = "1.7.0a1"
+__version__ = "1.7.0a2"
_major, _minor, _patch = __version__.split(".")
_major = int(_major)
_minor = int(_minor)
diff --git a/client/package.json b/client/package.json
index cb194b47..d79be780 100644
--- a/client/package.json
+++ b/client/package.json
@@ -34,10 +34,11 @@
"postcss": "^8.4.31",
"tailwindcss": "^3.3.1",
"typescript": "^4.9.3",
- "vite": "^4.2.3"
+ "vite": "^4.4.12"
},
"resolutions": {
"postcss": "^8.4.31",
- "semver": "^6.3.1"
+ "semver": "^6.3.1",
+ "@babel/traverse": "^7.23.2"
}
}
diff --git a/client/src/redux/services/auth.ts b/client/src/redux/services/auth.ts
index c0e04fdd..c846981a 100644
--- a/client/src/redux/services/auth.ts
+++ b/client/src/redux/services/auth.ts
@@ -49,7 +49,7 @@ export const authApi = createApi({
}
>({
query: ({ score, prompt_id }) => ({
- url: `${BASE_URL}/api/v1/prompts/${prompt_id}/feedback`,
+ url: `${BASE_URL}/api/v1/prompts/${prompt_id}/feedback/`,
method: 'PUT',
body: {
score
@@ -136,7 +136,7 @@ export const authApi = createApi({
}
>({
query: ({ score, prompt_id, chatbot_id }) => ({
- url: `${BASE_URL}/api/prompts/${prompt_id}/feedback`,
+ url: `${BASE_URL}/api/prompts/${prompt_id}/feedback/`,
method: 'PUT',
body: {
score
diff --git a/client/yarn.lock b/client/yarn.lock
index 8090a2b7..bd4e761a 100644
--- a/client/yarn.lock
+++ b/client/yarn.lock
@@ -17,6 +17,14 @@
dependencies:
"@babel/highlight" "^7.18.6"
+"@babel/code-frame@^7.22.13":
+ version "7.22.13"
+ resolved "https://registry.yarnpkg.com/@babel/code-frame/-/code-frame-7.22.13.tgz#e3c1c099402598483b7a8c46a721d1038803755e"
+ integrity sha512-XktuhWlJ5g+3TJXc5upd9Ks1HutSArik6jf2eAjYFyIOf4ej3RN+184cZbzDvbPnuTJIUhPKKJE3cIsYTiAT3w==
+ dependencies:
+ "@babel/highlight" "^7.22.13"
+ chalk "^2.4.2"
+
"@babel/compat-data@^7.21.4":
version "7.21.4"
resolved "https://registry.yarnpkg.com/@babel/compat-data/-/compat-data-7.21.4.tgz#457ffe647c480dff59c2be092fc3acf71195c87f"
@@ -53,6 +61,16 @@
"@jridgewell/trace-mapping" "^0.3.17"
jsesc "^2.5.1"
+"@babel/generator@^7.23.0":
+ version "7.23.0"
+ resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.23.0.tgz#df5c386e2218be505b34837acbcb874d7a983420"
+ integrity sha512-lN85QRR+5IbYrMWM6Y4pE/noaQtg4pNiqeNGX60eqOfo6gtEj6uw/JagelB8vVztSd7R6M5n1+PQkDbHbBRU4g==
+ dependencies:
+ "@babel/types" "^7.23.0"
+ "@jridgewell/gen-mapping" "^0.3.2"
+ "@jridgewell/trace-mapping" "^0.3.17"
+ jsesc "^2.5.1"
+
"@babel/helper-compilation-targets@^7.21.4":
version "7.21.4"
resolved "https://registry.yarnpkg.com/@babel/helper-compilation-targets/-/helper-compilation-targets-7.21.4.tgz#770cd1ce0889097ceacb99418ee6934ef0572656"
@@ -69,20 +87,25 @@
resolved "https://registry.yarnpkg.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.18.9.tgz#0c0cee9b35d2ca190478756865bb3528422f51be"
integrity sha512-3r/aACDJ3fhQ/EVgFy0hpj8oHyHpQc+LPtJoY9SzTThAsStm4Ptegq92vqKoE3vD706ZVFWITnMnxucw+S9Ipg==
-"@babel/helper-function-name@^7.21.0":
- version "7.21.0"
- resolved "https://registry.yarnpkg.com/@babel/helper-function-name/-/helper-function-name-7.21.0.tgz#d552829b10ea9f120969304023cd0645fa00b1b4"
- integrity sha512-HfK1aMRanKHpxemaY2gqBmL04iAPOPRj7DxtNbiDOrJK+gdwkiNRVpCpUJYbUT+aZyemKN8brqTOxzCaG6ExRg==
+"@babel/helper-environment-visitor@^7.22.20":
+ version "7.22.20"
+ resolved "https://registry.yarnpkg.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.22.20.tgz#96159db61d34a29dba454c959f5ae4a649ba9167"
+ integrity sha512-zfedSIzFhat/gFhWfHtgWvlec0nqB9YEIVrpuwjruLlXfUSnA8cJB0miHKwqDnQ7d32aKo2xt88/xZptwxbfhA==
+
+"@babel/helper-function-name@^7.23.0":
+ version "7.23.0"
+ resolved "https://registry.yarnpkg.com/@babel/helper-function-name/-/helper-function-name-7.23.0.tgz#1f9a3cdbd5b2698a670c30d2735f9af95ed52759"
+ integrity sha512-OErEqsrxjZTJciZ4Oo+eoZqeW9UIiOcuYKRJA4ZAgV9myA+pOXhhmpfNCKjEH/auVfEYVFJ6y1Tc4r0eIApqiw==
dependencies:
- "@babel/template" "^7.20.7"
- "@babel/types" "^7.21.0"
+ "@babel/template" "^7.22.15"
+ "@babel/types" "^7.23.0"
-"@babel/helper-hoist-variables@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.18.6.tgz#d4d2c8fb4baeaa5c68b99cc8245c56554f926678"
- integrity sha512-UlJQPkFqFULIcyW5sbzgbkxn2FKRgwWiRexcuaR8RNJRy8+LLveqPjwZV/bwrLZCN0eUHD/x8D0heK1ozuoo6Q==
+"@babel/helper-hoist-variables@^7.22.5":
+ version "7.22.5"
+ resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.22.5.tgz#c01a007dac05c085914e8fb652b339db50d823bb"
+ integrity sha512-wGjk9QZVzvknA6yKIUURb8zY3grXCcOZt+/7Wcy8O2uctxhplmUPkOdlgoNhmdVee2c92JXbf1xpMtVNbfoxRw==
dependencies:
- "@babel/types" "^7.18.6"
+ "@babel/types" "^7.22.5"
"@babel/helper-module-imports@^7.16.7", "@babel/helper-module-imports@^7.18.6":
version "7.21.4"
@@ -124,16 +147,33 @@
dependencies:
"@babel/types" "^7.18.6"
+"@babel/helper-split-export-declaration@^7.22.6":
+ version "7.22.6"
+ resolved "https://registry.yarnpkg.com/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.22.6.tgz#322c61b7310c0997fe4c323955667f18fcefb91c"
+ integrity sha512-AsUnxuLhRYsisFiaJwvp1QF+I3KjD5FOxut14q/GzovUe6orHLesW2C7d754kRm53h5gqrz6sFl6sxc4BVtE/g==
+ dependencies:
+ "@babel/types" "^7.22.5"
+
"@babel/helper-string-parser@^7.19.4":
version "7.19.4"
resolved "https://registry.yarnpkg.com/@babel/helper-string-parser/-/helper-string-parser-7.19.4.tgz#38d3acb654b4701a9b77fb0615a96f775c3a9e63"
integrity sha512-nHtDoQcuqFmwYNYPz3Rah5ph2p8PFeFCsZk9A/48dPc/rGocJ5J3hAAZ7pb76VWX3fZKu+uEr/FhH5jLx7umrw==
+"@babel/helper-string-parser@^7.22.5":
+ version "7.22.5"
+ resolved "https://registry.yarnpkg.com/@babel/helper-string-parser/-/helper-string-parser-7.22.5.tgz#533f36457a25814cf1df6488523ad547d784a99f"
+ integrity sha512-mM4COjgZox8U+JcXQwPijIZLElkgEpO5rsERVDJTc2qfCDfERyob6k5WegS14SX18IIjv+XD+GrqNumY5JRCDw==
+
"@babel/helper-validator-identifier@^7.18.6", "@babel/helper-validator-identifier@^7.19.1":
version "7.19.1"
resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.19.1.tgz#7eea834cf32901ffdc1a7ee555e2f9c27e249ca2"
integrity sha512-awrNfaMtnHUr653GgGEs++LlAvW6w+DcPrOliSMXWCKo597CwL5Acf/wWdNkf/tfEQE3mjkeD1YOVZOUV/od1w==
+"@babel/helper-validator-identifier@^7.22.20":
+ version "7.22.20"
+ resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.22.20.tgz#c4ae002c61d2879e724581d96665583dbc1dc0e0"
+ integrity sha512-Y4OZ+ytlatR8AI+8KZfKuL5urKp7qey08ha31L8b3BwewJAoJamTzyvxPR/5D+KkdJCGPq/+8TukHBlY10FX9A==
+
"@babel/helper-validator-option@^7.21.0":
version "7.21.0"
resolved "https://registry.yarnpkg.com/@babel/helper-validator-option/-/helper-validator-option-7.21.0.tgz#8224c7e13ace4bafdc4004da2cf064ef42673180"
@@ -157,11 +197,25 @@
chalk "^2.0.0"
js-tokens "^4.0.0"
+"@babel/highlight@^7.22.13":
+ version "7.22.20"
+ resolved "https://registry.yarnpkg.com/@babel/highlight/-/highlight-7.22.20.tgz#4ca92b71d80554b01427815e06f2df965b9c1f54"
+ integrity sha512-dkdMCN3py0+ksCgYmGG8jKeGA/8Tk+gJwSYYlFGxG5lmhfKNoAy004YpLxpS1W2J8m/EK2Ew+yOs9pVRwO89mg==
+ dependencies:
+ "@babel/helper-validator-identifier" "^7.22.20"
+ chalk "^2.4.2"
+ js-tokens "^4.0.0"
+
"@babel/parser@^7.20.7", "@babel/parser@^7.21.4":
version "7.21.4"
resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.21.4.tgz#94003fdfc520bbe2875d4ae557b43ddb6d880f17"
integrity sha512-alVJj7k7zIxqBZ7BTRhz0IqJFxW1VJbm6N8JbcYhQ186df9ZBPbZBmWSqAMXwHGsCJdYks7z/voa3ibiS5bCIw==
+"@babel/parser@^7.22.15", "@babel/parser@^7.23.0":
+ version "7.23.0"
+ resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.23.0.tgz#da950e622420bf96ca0d0f2909cdddac3acd8719"
+ integrity sha512-vvPKKdMemU85V9WE/l5wZEmImpCtLqbnTvqDS2U1fJ96KrxoW7KrXhNsNCblQlg8Ck4b85yxdTyelsMUgFUXiw==
+
"@babel/plugin-transform-react-jsx-self@^7.18.6":
version "7.21.0"
resolved "https://registry.yarnpkg.com/@babel/plugin-transform-react-jsx-self/-/plugin-transform-react-jsx-self-7.21.0.tgz#ec98d4a9baafc5a1eb398da4cf94afbb40254a54"
@@ -192,19 +246,28 @@
"@babel/parser" "^7.20.7"
"@babel/types" "^7.20.7"
-"@babel/traverse@^7.21.0", "@babel/traverse@^7.21.2", "@babel/traverse@^7.21.4":
- version "7.21.4"
- resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.21.4.tgz#a836aca7b116634e97a6ed99976236b3282c9d36"
- integrity sha512-eyKrRHKdyZxqDm+fV1iqL9UAHMoIg0nDaGqfIOd8rKH17m5snv7Gn4qgjBoFfLz9APvjFU/ICT00NVCv1Epp8Q==
- dependencies:
- "@babel/code-frame" "^7.21.4"
- "@babel/generator" "^7.21.4"
- "@babel/helper-environment-visitor" "^7.18.9"
- "@babel/helper-function-name" "^7.21.0"
- "@babel/helper-hoist-variables" "^7.18.6"
- "@babel/helper-split-export-declaration" "^7.18.6"
- "@babel/parser" "^7.21.4"
- "@babel/types" "^7.21.4"
+"@babel/template@^7.22.15":
+ version "7.22.15"
+ resolved "https://registry.yarnpkg.com/@babel/template/-/template-7.22.15.tgz#09576efc3830f0430f4548ef971dde1350ef2f38"
+ integrity sha512-QPErUVm4uyJa60rkI73qneDacvdvzxshT3kksGqlGWYdOTIUOwJ7RDUL8sGqslY1uXWSL6xMFKEXDS3ox2uF0w==
+ dependencies:
+ "@babel/code-frame" "^7.22.13"
+ "@babel/parser" "^7.22.15"
+ "@babel/types" "^7.22.15"
+
+"@babel/traverse@^7.21.0", "@babel/traverse@^7.21.2", "@babel/traverse@^7.21.4", "@babel/traverse@^7.23.2":
+ version "7.23.2"
+ resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.23.2.tgz#329c7a06735e144a506bdb2cad0268b7f46f4ad8"
+ integrity sha512-azpe59SQ48qG6nu2CzcMLbxUudtN+dOM9kDbUqGq3HXUJRlo7i8fvPoxQUzYgLZ4cMVmuZgm8vvBpNeRhd6XSw==
+ dependencies:
+ "@babel/code-frame" "^7.22.13"
+ "@babel/generator" "^7.23.0"
+ "@babel/helper-environment-visitor" "^7.22.20"
+ "@babel/helper-function-name" "^7.23.0"
+ "@babel/helper-hoist-variables" "^7.22.5"
+ "@babel/helper-split-export-declaration" "^7.22.6"
+ "@babel/parser" "^7.23.0"
+ "@babel/types" "^7.23.0"
debug "^4.1.0"
globals "^11.1.0"
@@ -217,6 +280,15 @@
"@babel/helper-validator-identifier" "^7.19.1"
to-fast-properties "^2.0.0"
+"@babel/types@^7.22.15", "@babel/types@^7.22.5", "@babel/types@^7.23.0":
+ version "7.23.0"
+ resolved "https://registry.yarnpkg.com/@babel/types/-/types-7.23.0.tgz#8c1f020c9df0e737e4e247c0619f58c68458aaeb"
+ integrity sha512-0oIyUfKoI3mSqMvsxBdclDwxXKXAUA8v/apZbc+iSyARYou1o8ZGDxbUYyLFoW2arqS2jDGqJuZvv1d/io1axg==
+ dependencies:
+ "@babel/helper-string-parser" "^7.22.5"
+ "@babel/helper-validator-identifier" "^7.22.20"
+ to-fast-properties "^2.0.0"
+
"@emotion/babel-plugin@^11.10.6":
version "11.10.6"
resolved "https://registry.yarnpkg.com/@emotion/babel-plugin/-/babel-plugin-11.10.6.tgz#a68ee4b019d661d6f37dec4b8903255766925ead"
@@ -1060,7 +1132,7 @@ caniuse-lite@^1.0.30001449, caniuse-lite@^1.0.30001464:
resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001474.tgz#13b6fe301a831fe666cce8ca4ef89352334133d5"
integrity sha512-iaIZ8gVrWfemh5DG3T9/YqarVZoYf0r188IjaGwx68j4Pf0SGY6CQkmJUIE+NZHkkecQGohzXmBGEwWDr9aM3Q==
-chalk@^2.0.0:
+chalk@^2.0.0, chalk@^2.4.2:
version "2.4.2"
resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.2.tgz#cd42541677a54333cf541a49108c1432b44c9424"
integrity sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==
@@ -2059,10 +2131,10 @@ util-deprecate@^1.0.2:
resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf"
integrity sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==
-vite@^4.2.3:
- version "4.4.11"
- resolved "https://registry.yarnpkg.com/vite/-/vite-4.4.11.tgz#babdb055b08c69cfc4c468072a2e6c9ca62102b0"
- integrity sha512-ksNZJlkcU9b0lBwAGZGGaZHCMqHsc8OpgtoYhsQ4/I2v5cnpmmmqe5pM4nv/4Hn6G/2GhTdj0DhZh2e+Er1q5A==
+vite@^4.4.12:
+ version "4.5.1"
+ resolved "https://registry.yarnpkg.com/vite/-/vite-4.5.1.tgz#3370986e1ed5dbabbf35a6c2e1fb1e18555b968a"
+ integrity sha512-AXXFaAJ8yebyqzoNB9fu2pHoo/nWX+xZlaRwoeYUxEqBO+Zj4msE5G+BhGBll9lYEKv9Hfks52PAF2X7qDYXQA==
dependencies:
esbuild "^0.18.10"
postcss "^8.4.27"
diff --git a/extra/ex1/chain.py b/extra/ex1/chain.py
deleted file mode 100644
index ba890d80..00000000
--- a/extra/ex1/chain.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright © 2023- Frello Technology Private Limited
-
-from fire import Fire
-
-from chainfury.base import Chain
-from chainfury.chat import human, Message, Chat
-
-from chainfury.components.openai import OpenaiGPTModel
-from chainfury.components.tune import TuneModel
-
-
-def main(q: str, openai: bool = False):
- chain = Chain(
- name="demo-one",
- description=(
- "Building the hardcore example of chain at https://nimbleboxai.github.io/ChainFury/examples/usage-hardcore.html "
- "using threaded chains"
- ),
- main_in="stupid_question",
- main_out="fight_scene/fight_scene",
- default_model=(
- OpenaiGPTModel("gpt-3.5-turbo")
- if openai
- else TuneModel("rohan/mixtral-8x7b-inst-v0-1-32k")
- ),
- )
- print("before:")
- print(chain)
-
- chain = chain.add_thread(
- "character_one",
- Chat(
- [
- human(
- "You were who was running in the middle of desert. You see a McDonald's and the waiter ask a stupid "
- "question like: '{{ stupid_question }}'? You are pissed and you say."
- )
- ]
- ),
- )
-
- chain = chain.add_thread(
- "character_two",
- Chat(
- [
- human(
- "Someone comes upto you in a bar and screams '{{ character_one }}'? You are a bartender give a funny response to it."
- )
- ]
- ),
- )
-
- chain = chain.add_thread(
- "fight_scene",
- Chat(
- [
- human(
- "Two men were fighting in a bar. One yelled '{{ character_one }}'. Other responded by yelling '{{ character_two }}'.\n"
- "Continue this story for 3 more lines."
- )
- ]
- ),
- )
-
- print("---------------")
- print(chain)
-
- print(chain.topo_order)
- for ir, done in chain.stream(q):
- # print(ir)
- pass
-
- print("---------------")
- print(ir)
-
-
-if __name__ == "__main__":
- Fire(main)
diff --git a/pyproject.toml b/pyproject.toml
index abbfbda0..05bedbea 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,14 +1,15 @@
[tool.poetry]
name = "chainfury"
-version = "1.7.0a1"
+version = "1.7.0a2"
description = "ChainFury is a powerful tool that simplifies the creation and management of chains of prompts, making it easier to build complex chat applications using LLMs."
-authors = ["NimbleBox Engineering "]
+authors = ["Tune AI "]
license = "Apache 2.0"
readme = "README.md"
repository = "https://github.com/NimbleBoxAI/ChainFury"
[tool.poetry.dependencies]
python = "^3.9,<3.12"
+tuneapi = "0.1.1"
fire = "0.5.0"
Jinja2 = "3.1.2"
jinja2schema = "0.1.4"
diff --git a/scripts/list_builtins.py b/scripts/list_builtins.py
deleted file mode 100644
index e6afa4b3..00000000
--- a/scripts/list_builtins.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from fire import Fire
-import jinja2 as j2
-from chainfury import programatic_actions_registry, ai_actions_registry, memory_registry, model_registry
-
-
-def main(src_file: str, trg_file: str, v: bool = False):
- with open(src_file, "r") as f:
- temp = j2.Template(f.read())
-
- # create the components list
- pc = []
- for node_id, node in programatic_actions_registry.nodes.items():
- pc.append(
- {
- "id": node.id,
- "description": node.description.rstrip(".") + f'. Copy: ``programatic_actions_registry.get("{node.id}")``',
- }
- )
-
- ac = []
- for node_id, node in ai_actions_registry.nodes.items():
- ac.append(
- {
- "id": node.id,
- "description": node.description.rstrip(".") + f'. Copy: ``ai_actions_registry.get("{node.id}")``',
- }
- )
-
- mc = []
- for node_id, node in memory_registry._memories.items():
- fn = "get_read" if node.id.endswith("-read") else "get_write"
- mc.append(
- {
- "id": node.id,
- "description": node.description.rstrip(".") + f'. Copy: ``memory_registry.{fn}("{node.id}")``',
- }
- )
-
- moc = []
- for model_id, model in model_registry.models.items():
- moc.append(
- {
- "id": model_id,
- "description": model.description.rstrip(".") + f'. Copy: ``model_registry.get("{model_id}")``',
- }
- )
-
- op = temp.render(pc=pc, ac=ac, mc=mc, moc=moc)
- if v:
- print(op)
- print("Writing to", trg_file)
-
- with open(trg_file, "w") as f:
- f.write(op)
-
-
-if __name__ == "__main__":
- Fire(main)
diff --git a/server/chainfury_server/__main__.py b/server/chainfury_server/__main__.py
index 54d574f1..c723015a 100644
--- a/server/chainfury_server/__main__.py
+++ b/server/chainfury_server/__main__.py
@@ -10,17 +10,6 @@
if os.path.exists(_dotenv_fp):
dotenv.load_dotenv(_dotenv_fp)
-__CF_LOGO = """
- ___ _ _ ___
- / __| |_ __ _(_)_ _ | __| _ _ _ _ _
-| (__| ' \/ _` | | ' \ | _| || | '_| || |
- \___|_||_\__,_|_|_||_||_| \_,_|_| \_, |
- |__/
-e0 a4 b8 e0 a4 a4 e0 a5 8d e0 a4 af e0 a4
-ae e0 a5 87 e0 a4 b5 20 e0 a4 9c e0 a4 af
- e0 a4 a4 e0 a5 87
-"""
-
def main(
host: str = "0.0.0.0",
@@ -38,19 +27,18 @@ def main(
post (List[str], optional): List of modules to load after the server is imported. Defaults to [].
"""
# WARNING: ensure that nothing is being imported in the utils from chainfury_server
+ from chainfury.cli import CLI
from chainfury_server.utils import logger
- from chainfury.version import __version__ as cf_version
from chainfury_server.version import __version__ as cfs_version
logger.info(
- f"{__CF_LOGO}\n"
+ f"{CLI.info}\n"
f"Starting ChainFury server ...\n"
- f" Host: {host}\n"
- f" Port: {port}\n"
- f" Pre: {pre}\n"
- f" Post: {post}\n"
- f" chainfury version: {cf_version}\n"
- f" cf_server version: {cfs_version}"
+ f" Host: {host}\n"
+ f" Port: {port}\n"
+ f" Pre: {pre}\n"
+ f" Post: {post}\n"
+ f" cf_server: {cfs_version}"
)
# load all things you need to preload the modules
diff --git a/server/chainfury_server/api/chains.py b/server/chainfury_server/api/chains.py
index 1eef5c96..822ca8c5 100644
--- a/server/chainfury_server/api/chains.py
+++ b/server/chainfury_server/api/chains.py
@@ -31,14 +31,14 @@ def create_chain(
return T.ApiResponse(message="Name not specified")
if chatbot_data.dag:
for n in chatbot_data.dag.nodes:
- if len(n.id) > Env.CFS_MAXLEN_CF_NDOE():
+ if len(n.id) > Env.CFS_MAXLEN_CF_NODE():
raise HTTPException(
status_code=400,
- detail=f"Node ID length cannot be more than {Env.CFS_MAXLEN_CF_NDOE()}",
+ detail=f"Node ID length cannot be more than {Env.CFS_MAXLEN_CF_NODE()}",
)
# DB call
- dag = chatbot_data.dag.dict() if chatbot_data.dag else {}
+ dag = chatbot_data.dag.model_dump() if chatbot_data.dag else {}
chatbot = DB.ChatBot(
name=chatbot_data.name,
created_by=user.id,
@@ -51,8 +51,7 @@ def create_chain(
db.refresh(chatbot)
# return
- response = T.ApiChain(**chatbot.to_dict())
- return response
+ return chatbot.to_ApiChain()
def get_chain(
@@ -74,13 +73,13 @@ def get_chain(
]
if tag_id:
filters.append(DB.ChatBot.tag_id == tag_id)
- chatbot = db.query(DB.ChatBot).filter(*filters).first() # type: ignore
+ chatbot: DB.ChatBot = db.query(DB.ChatBot).filter(*filters).first() # type: ignore
if not chatbot:
resp.status_code = 404
return T.ApiResponse(message="ChatBot not found")
# return
- return T.ApiChain(**chatbot.to_dict())
+ return chatbot.to_ApiChain()
def update_chain(
@@ -130,7 +129,7 @@ def update_chain(
db.refresh(chatbot)
# return
- return T.ApiChain(**chatbot.to_dict())
+ return chatbot.to_ApiChain()
def delete_chain(
@@ -186,7 +185,7 @@ def list_chains(
# return
return T.ApiListChainsResponse(
- chatbots=[T.ApiChain(**chatbot.to_dict()) for chatbot in chatbots],
+ chatbots=[chatbot.to_ApiChain() for chatbot in chatbots],
)
@@ -201,12 +200,14 @@ def run_chain(
store_ir: bool = False,
store_io: bool = False,
db: Session = Depends(DB.fastapi_db_session),
-) -> Union[StreamingResponse, T.CFPromptResult, T.ApiResponse]:
+) -> Union[StreamingResponse, T.ChainResult, T.ApiResponse]:
"""
This is the master function to run any chain over the API. This can behave in a bunch of different formats like:
- (default) this will wait for the entire chain to execute and return the response
- if ``stream`` is passed it will give a streaming response with line by line JSON and last response containing ``"done"`` key
- if ``as_task`` is passed then a task ID is received and you can poll for the results at ``/chains/{id}/results`` this supercedes the ``stream``.
+
+ ``as_task`` is not implemented.
"""
# validate user
user = DB.get_user_from_jwt(token=token, db=db)
@@ -244,6 +245,7 @@ def run_chain(
if as_task:
# when run as a task this will return a task ID that will be submitted
+ # raise HTTPException(501, detail="Not implemented yet")
result = engine.submit(
chatbot=chatbot,
prompt=prompt,
diff --git a/server/chainfury_server/api/prompts.py b/server/chainfury_server/api/prompts.py
index 8a7cb0ba..1d3627fc 100644
--- a/server/chainfury_server/api/prompts.py
+++ b/server/chainfury_server/api/prompts.py
@@ -3,7 +3,7 @@
from fastapi import Depends, Header, HTTPException
from fastapi.requests import Request
from fastapi.responses import Response
-from typing import Annotated
+from typing import Annotated, List
from sqlalchemy.orm import Session
import chainfury_server.database as DB
@@ -17,7 +17,7 @@ def list_prompts(
limit: int = 100,
offset: int = 0,
db: Session = Depends(DB.fastapi_db_session),
-):
+) -> T.ApiListPromptsResponse:
# validate user
user = DB.get_user_from_jwt(token=token, db=db)
@@ -25,7 +25,7 @@ def list_prompts(
if limit < 1 or limit > 100:
limit = 100
offset = offset if offset > 0 else 0
- prompts = (
+ prompts: List[DB.Prompt] = (
db.query(DB.Prompt) # type: ignore
.filter(DB.Prompt.chatbot_id == chain_id)
.order_by(DB.Prompt.created_at.desc()) # type: ignore
@@ -33,14 +33,14 @@ def list_prompts(
.offset(offset)
.all()
)
- return {"prompts": [p.to_dict() for p in prompts]}
+ return T.ApiListPromptsResponse(prompts=[p.to_ApiPrompt() for p in prompts])
def get_prompt(
prompt_id: int,
token: Annotated[str, Header()],
db: Session = Depends(DB.fastapi_db_session),
-):
+) -> T.ApiPrompt:
# validate user
user = DB.get_user_from_jwt(token=token, db=db)
@@ -49,14 +49,15 @@ def get_prompt(
if not prompt:
raise HTTPException(status_code=404, detail="Prompt not found")
- return {"prompt": prompt.to_dict()}
+ # return {"prompt": prompt.to_dict()} # before
+ return prompt.to_ApiPrompt()
def delete_prompt(
prompt_id: int,
token: Annotated[str, Header()],
db: Session = Depends(DB.fastapi_db_session),
-):
+) -> T.ApiResponse:
# validate user
user = DB.get_user_from_jwt(token=token, db=db)
@@ -67,7 +68,7 @@ def delete_prompt(
db.delete(prompt)
db.commit()
- return {"msg": f"Prompt: '{prompt_id}' deleted"}
+ return T.ApiResponse(message=f"Prompt '{prompt.id}' deleted")
def prompt_feedback(
@@ -75,7 +76,7 @@ def prompt_feedback(
inputs: T.ApiPromptFeedback,
prompt_id: int,
db: Session = Depends(DB.fastapi_db_session),
-):
+) -> T.ApiPromptFeedbackResponse:
# validate user
user = DB.get_user_from_jwt(token=token, db=db)
@@ -94,4 +95,26 @@ def prompt_feedback(
status_code=404,
detail=f"Unable to find the prompt",
)
- return {"rating": prompt.user_rating}
+ return T.ApiPromptFeedbackResponse(rating=prompt.user_rating) # type: ignore
+
+
+def get_chain_logs(
+ token: Annotated[str, Header()],
+ prompt_id: int,
+ limit: int = 100,
+ offset: int = 0,
+ db: Session = Depends(DB.fastapi_db_session),
+) -> T.ApiListChainLogsResponse:
+ # validate user
+ user = DB.get_user_from_jwt(token=token, db=db)
+
+ # query the DB
+ chainlogs: List[DB.ChainLog] = (
+ db.query(DB.ChainLog) # type: ignore
+ .filter(DB.ChainLog.prompt_id == prompt_id)
+ .order_by(DB.ChainLog.created_at.desc()) # type: ignore
+ .limit(limit)
+ .offset(offset)
+ .all()
+ )
+ return T.ApiListChainLogsResponse(logs=[c.to_ApiChainLog() for c in chainlogs])
diff --git a/server/chainfury_server/api/user.py b/server/chainfury_server/api/user.py
index 726a529d..633124c8 100644
--- a/server/chainfury_server/api/user.py
+++ b/server/chainfury_server/api/user.py
@@ -4,28 +4,35 @@
from fastapi import HTTPException
from passlib.hash import sha256_crypt
from sqlalchemy.orm import Session
-from fastapi import Request, Response, Depends, Header
-from typing import Annotated
+from fastapi import Depends, Header
+from typing import Annotated, List
from chainfury_server.utils import logger, Env
import chainfury_server.database as DB
import chainfury.types as T
+from tuneapi.utils import encrypt, decrypt
-def login(auth: T.ApiAuth, db: Session = Depends(DB.fastapi_db_session)):
+
+def login(
+ auth: T.ApiAuthRequest,
+ db: Session = Depends(DB.fastapi_db_session),
+) -> T.ApiLoginResponse:
user: DB.User = db.query(DB.User).filter(DB.User.username == auth.username).first() # type: ignore
if user is not None and sha256_crypt.verify(auth.password, user.password): # type: ignore
token = jwt.encode(
payload=DB.JWTPayload(username=auth.username, user_id=user.id).to_dict(),
key=Env.JWT_SECRET(),
)
- response = {"msg": "success", "token": token}
+ return T.ApiLoginResponse(message="success", token=token)
else:
- response = {"msg": "failed"}
- return response
+ raise HTTPException(status_code=401, detail="Invalid username or password")
-def sign_up(auth: T.ApiSignUp, db: Session = Depends(DB.fastapi_db_session)):
+def sign_up(
+ auth: T.ApiSignUpRequest,
+ db: Session = Depends(DB.fastapi_db_session),
+) -> T.ApiLoginResponse:
user_exists = False
email_exists = False
user: DB.User = db.query(DB.User).filter(DB.User.username == auth.username).first() # type: ignore
@@ -36,7 +43,8 @@ def sign_up(auth: T.ApiSignUp, db: Session = Depends(DB.fastapi_db_session)):
email_exists = True
if user_exists and email_exists:
raise HTTPException(
- status_code=400, detail="Username and email already registered"
+ status_code=400,
+ detail="Username and email already registered",
)
elif user_exists:
raise HTTPException(status_code=400, detail="Username is taken")
@@ -54,17 +62,14 @@ def sign_up(auth: T.ApiSignUp, db: Session = Depends(DB.fastapi_db_session)):
payload=DB.JWTPayload(username=auth.username, user_id=user.id).to_dict(),
key=Env.JWT_SECRET(),
)
- response = {"msg": "success", "token": token}
+ return T.ApiLoginResponse(message="success", token=token)
else:
- response = {"msg": "failed"}
- return response
+ raise HTTPException(status_code=500, detail="Unknown error")
def change_password(
- req: Request,
- resp: Response,
token: Annotated[str, Header()],
- inputs: T.ApiChangePassword,
+ inputs: T.ApiChangePasswordRequest,
db: Session = Depends(DB.fastapi_db_session),
) -> T.ApiResponse:
# validate user
@@ -76,5 +81,111 @@ def change_password(
db.commit()
return T.ApiResponse(message="success")
else:
- resp.status_code = 400
- return T.ApiResponse(message="password incorrect")
+ raise HTTPException(status_code=401, detail="Invalid old password")
+
+
+def create_secret(
+ token: Annotated[str, Header()],
+ inputs: T.ApiToken,
+ db: Session = Depends(DB.fastapi_db_session),
+) -> T.ApiResponse:
+ # validate user
+ user = DB.get_user_from_jwt(token=token, db=db)
+
+ # validate inputs
+ if len(inputs.token) >= DB.Tokens.MAXLEN_TOKEN:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Token too long, should be less than {DB.Tokens.MAXLEN_TOKEN} characters",
+ )
+ if len(inputs.key) >= DB.Tokens.MAXLEN_KEY:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Key too long, should be less than {DB.Tokens.MAXLEN_KEY} characters",
+ )
+
+ cfs_secrets_password = Env.CFS_SECRETS_PASSWORD()
+ if cfs_secrets_password is None:
+ logger.error("CFS_TOKEN_PASSWORD not set, cannot create secrets")
+ raise HTTPException(500, "internal server error")
+
+ # create a token
+ token = DB.Tokens(
+ user_id=user.id,
+ key=inputs.key,
+ value=encrypt(inputs.token, cfs_secrets_password, user.id).decode("utf-8"),
+ meta=inputs.meta,
+ ) # type: ignore
+ db.add(token)
+ db.commit()
+ return T.ApiResponse(message="success")
+
+
+def get_secret(
+ key: str,
+ token: Annotated[str, Header()],
+ db: Session = Depends(DB.fastapi_db_session),
+) -> T.ApiToken:
+ # validate user
+ user = DB.get_user_from_jwt(token=token, db=db)
+
+ db_token: DB.Tokens = db.query(DB.Tokens).filter(DB.Tokens.key == key, user.id == user.id).first() # type: ignore
+ if db_token is None:
+ raise HTTPException(status_code=404, detail="Token not found")
+
+ cfs_token = Env.CFS_SECRETS_PASSWORD()
+ if cfs_token is None:
+ logger.error("CFS_TOKEN_PASSWORD not set, cannot create secrets")
+ raise HTTPException(500, "internal server error")
+
+ try:
+ db_token.value = decrypt(db_token.value, cfs_token, user.id)
+ except Exception as e:
+ raise HTTPException(status_code=401, detail="Cannot get token")
+ return db_token.to_ApiToken()
+
+
+def list_secret(
+ token: Annotated[str, Header()],
+ limit: int = 100,
+ offset: int = 0,
+ db: Session = Depends(DB.fastapi_db_session),
+) -> T.ApiListTokensResponse:
+ """Returns a list of token keys, and metadata. The token values are not returned."""
+ # validate user
+ user = DB.get_user_from_jwt(token=token, db=db)
+
+ # get tokens
+ tokens: List[DB.Tokens] = (
+ db.query(DB.Tokens)
+ .filter(DB.Tokens.user_id == user.id) # type: ignore
+ .limit(limit)
+ .offset(offset)
+ .all()
+ )
+ tokens_resp = []
+ for t in tokens:
+ tok = t.to_ApiToken()
+ tok.token = ""
+ tokens_resp.append(tok)
+ return T.ApiListTokensResponse(tokens=tokens_resp)
+
+
+def delete_secret(
+ key: str,
+ token: Annotated[str, Header()],
+ db: Session = Depends(DB.fastapi_db_session),
+) -> T.ApiResponse:
+ # validate user
+ user = DB.get_user_from_jwt(token=token, db=db)
+
+ # validate the user can access the token
+ _ = get_secret(key=key, token=token, db=db)
+
+ # delete token
+ db_token: DB.Tokens = db.query(DB.Tokens).filter(DB.Tokens.key == key, user.id == user.id).first() # type: ignore
+ if db_token is None:
+ raise HTTPException(status_code=404, detail="Token not found")
+ db.delete(db_token)
+ db.commit()
+ return T.ApiResponse(message="success")
diff --git a/server/chainfury_server/app.py b/server/chainfury_server/app.py
index 883cde4a..6c93e75b 100644
--- a/server/chainfury_server/app.py
+++ b/server/chainfury_server/app.py
@@ -20,7 +20,8 @@
description="""
chainfury server is a way to deploy and run chainfury engine over APIs. `chainfury` is [Tune AI](tunehq.ai)'s FOSS project
released under [Apache-2 License](https://choosealicense.com/licenses/apache-2.0/) so you can use this for your commercial
-projects. A version `chainfury` is used in production in [Tune.Chat](chat.tune.app) and serves thousands of users daily.
+projects. A version `chainfury` is used in production in [Tune.Chat](chat.tune.app), serves and solves thousands of user
+queries daily.
""".strip(),
version=__version__,
docs_url="" if Env.CFS_DISABLE_DOCS() else "/docs",
@@ -42,24 +43,29 @@
app.add_api_route("/api/v1/chatbot/{id}/prompt", api_chains.run_chain, methods=["POST"], tags=["deprecated"], response_model=None) # type: ignore
# user
-app.add_api_route("/user/login/", api_user.login, methods=["POST"], tags=["user"]) # type: ignore
-app.add_api_route("/user/signup/", api_user.sign_up, methods=["POST"], tags=["user"]) # type: ignore
-app.add_api_route("/user/change_password/", api_user.change_password, methods=["POST"], tags=["user"]) # type: ignore
+app.add_api_route(methods=["POST"], path="/user/login/", endpoint=api_user.login, tags=["user"]) # type: ignore
+app.add_api_route(methods=["POST"], path="/user/signup/", endpoint=api_user.sign_up, tags=["user"]) # type: ignore
+app.add_api_route(methods=["POST"], path="/user/change_password/", endpoint=api_user.change_password, tags=["user"]) # type: ignore
+app.add_api_route(methods=["PUT"], path="/user/secret/", endpoint=api_user.create_secret, tags=["user"]) # type: ignore
+app.add_api_route(methods=["GET"], path="/user/secret/", endpoint=api_user.get_secret, tags=["user"]) # type: ignore
+app.add_api_route(methods=["DELETE"], path="/user/secret/", endpoint=api_user.delete_secret, tags=["user"]) # type: ignore
+app.add_api_route(methods=["GET"], path="/user/secret/list/", endpoint=api_user.list_secret, tags=["user"]) # type: ignore
# chains
-app.add_api_route("/api/chains/", api_chains.list_chains, methods=["GET"], tags=["chains"]) # type: ignore
-app.add_api_route("/api/chains/", api_chains.create_chain, methods=["PUT"], tags=["chains"]) # type: ignore
-app.add_api_route("/api/chains/{id}/", api_chains.get_chain, methods=["GET"], tags=["chains"]) # type: ignore
-app.add_api_route("/api/chains/{id}/", api_chains.delete_chain, methods=["DELETE"], tags=["chains"]) # type: ignore
-app.add_api_route("/api/chains/{id}/", api_chains.update_chain, methods=["PATCH"], tags=["chains"]) # type: ignore
-app.add_api_route("/api/chains/{id}/", api_chains.run_chain, methods=["POST"], tags=["chains"], response_model=None) # type: ignore
-app.add_api_route("/api/chains/{id}/metrics/", api_chains.get_chain_metrics, methods=["GET"], tags=["chains"]) # type: ignore
+app.add_api_route(methods=["GET"], path="/api/chains/", endpoint=api_chains.list_chains, tags=["chains"]) # type: ignore
+app.add_api_route(methods=["PUT"], path="/api/chains/", endpoint=api_chains.create_chain, tags=["chains"]) # type: ignore
+app.add_api_route(methods=["GET"], path="/api/chains/{id}/", endpoint=api_chains.get_chain, tags=["chains"]) # type: ignore
+app.add_api_route(methods=["DELETE"], path="/api/chains/{id}/", endpoint=api_chains.delete_chain, tags=["chains"]) # type: ignore
+app.add_api_route(methods=["PATCH"], path="/api/chains/{id}/", endpoint=api_chains.update_chain, tags=["chains"]) # type: ignore
+app.add_api_route(methods=["POST"], path="/api/chains/{id}/", endpoint=api_chains.run_chain, tags=["chains"], response_model=None) # type: ignore
+app.add_api_route(methods=["GET"], path="/api/chains/{id}/metrics/", endpoint=api_chains.get_chain_metrics, tags=["chains"]) # type: ignore
# prompts
-app.add_api_route("/api/prompts/", api_prompts.list_prompts, methods=["GET"], tags=["prompts"]) # type: ignore
-app.add_api_route("/api/prompts/{prompt_id}/", api_prompts.get_prompt, methods=["GET"], tags=["prompts"]) # type: ignore
-app.add_api_route("/api/prompts/{prompt_id}/", api_prompts.delete_prompt, methods=["DELETE"], tags=["prompts"]) # type: ignore
-app.add_api_route("/api/prompts/{prompt_id}/feedback", api_prompts.prompt_feedback, methods=["PUT"], tags=["prompts"]) # type: ignore
+app.add_api_route(methods=["GET"], path="/api/prompts/", endpoint=api_prompts.list_prompts, tags=["prompts"]) # type: ignore
+app.add_api_route(methods=["GET"], path="/api/prompts/{prompt_id}/", endpoint=api_prompts.get_prompt, tags=["prompts"]) # type: ignore
+app.add_api_route(methods=["DELETE"], path="/api/prompts/{prompt_id}/", endpoint=api_prompts.delete_prompt, tags=["prompts"]) # type: ignore
+app.add_api_route(methods=["PUT"], path="/api/prompts/{prompt_id}/feedback/", endpoint=api_prompts.prompt_feedback, tags=["prompts"]) # type: ignore
+app.add_api_route(methods=["GET"], path="/api/prompts/{prompt_id}/logs/", endpoint=api_prompts.get_chain_logs, tags=["prompts"]) # type: ignore
# UI files
diff --git a/server/chainfury_server/database.py b/server/chainfury_server/database.py
index 4d36c15e..799b5381 100644
--- a/server/chainfury_server/database.py
+++ b/server/chainfury_server/database.py
@@ -10,10 +10,10 @@
from dataclasses import dataclass, asdict
from typing import Dict, Any
-from sqlalchemy.pool import QueuePool
+from sqlalchemy.pool import QueuePool, NullPool
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import Session, scoped_session, sessionmaker
+from sqlalchemy.orm import Session, scoped_session, sessionmaker, relationship
from sqlalchemy import (
Column,
ForeignKey,
@@ -28,6 +28,7 @@
)
from chainfury_server.utils import logger, Env
+import chainfury.types as T
########
#
@@ -54,6 +55,8 @@
)
else:
logger.info(f"Using via database URL")
+ # https://stackoverflow.com/a/73764136
+ #
engine = create_engine(
db,
poolclass=QueuePool,
@@ -83,7 +86,7 @@ def get_random_number(length) -> int:
return random_numbers
-def get_local_session() -> sessionmaker:
+def get_local_session(engine) -> sessionmaker:
logger.debug("Database opened successfully")
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
return SessionLocal
@@ -100,7 +103,7 @@ def db_session() -> Session: # type: ignore
def fastapi_db_session():
- sess_cls = get_local_session()
+ sess_cls = get_local_session(engine)
db = sess_cls()
try:
yield db
@@ -170,11 +173,38 @@ class User(Base):
username: str = Column(String(80), unique=True, nullable=False)
password: str = Column(String(80), nullable=False)
meta: Dict[str, Any] = Column(JSON)
+ tokens = relationship("Tokens", back_populates="user")
def __repr__(self):
return f"User(id={self.id}, username={self.username}, meta={self.meta})"
+class Tokens(Base):
+ __tablename__ = "tokens"
+
+ MAXLEN_KEY = 80
+ MAXLEN_VAL = 1024
+ MAXLEN_TOKEN = 703 # 703 long string can create 1016 long token
+
+ id = Column(Integer, primary_key=True)
+ user_id = Column(String(ID_LENGTH), ForeignKey("user.id"), nullable=False)
+ key = Column(String(MAXLEN_KEY), nullable=False)
+ value = Column(String(MAXLEN_VAL), nullable=False)
+ meta = Column(JSON, nullable=True)
+ user = relationship("User", back_populates="tokens")
+ # (user_id, key) is a unique constraint
+
+ def __repr__(self):
+ return f"Tokens(id={self.id}, user_id={self.user_id}, key={self.key}, value={self.value[:5]}..., meta={self.meta})"
+
+ def to_ApiToken(self) -> T.ApiToken:
+ return T.ApiToken(
+ key=self.key,
+ token=self.value,
+ meta=self.meta,
+ )
+
+
class ChatBot(Base):
__tablename__ = "chatbot"
@@ -208,6 +238,9 @@ def to_dict(self):
"deleted_at": self.deleted_at,
}
+ def to_ApiChain(self) -> T.ApiChain:
+ return T.ApiChain(**self.to_dict())
+
def __repr__(self):
return f"ChatBot(id={self.id}, name={self.name}, created_by={self.created_by}, dag={self.dag}, meta={self.meta})"
@@ -240,6 +273,9 @@ class Prompt(Base):
session_id: Dict[str, Any] = Column(String(80), nullable=False)
meta: Dict[str, Any] = Column(JSON)
+ # migrate to snowflake ID
+ # sf_id = Column(String(19), nullable=True)
+
def to_dict(self):
return {
"id": self.id,
@@ -255,6 +291,9 @@ def to_dict(self):
"meta": self.meta,
}
+ def to_ApiPrompt(self):
+ return T.ApiPrompt(**self.to_dict())
+
class ChainLog(Base):
__tablename__ = "chain_logs"
@@ -266,11 +305,25 @@ class ChainLog(Base):
)
created_at: datetime = Column(DateTime, nullable=False)
prompt_id: int = Column(Integer, ForeignKey("prompt.id"), nullable=False)
- node_id: str = Column(String(Env.CFS_MAXLEN_CF_NDOE()), nullable=False)
+ node_id: str = Column(String(Env.CFS_MAXLEN_CF_NODE()), nullable=False)
worker_id: str = Column(String(Env.CFS_MAXLEN_WORKER()), nullable=False)
message: str = Column(Text, nullable=False)
data: Dict[str, Any] = Column(JSON, nullable=True)
+ def to_dict(self):
+ return {
+ "id": self.id,
+ "created_at": self.created_at,
+ "prompt_id": self.prompt_id,
+ "node_id": self.node_id,
+ "worker_id": self.worker_id,
+ "message": self.message,
+ "data": self.data,
+ }
+
+ def to_ApiChainLog(self):
+ return T.ApiChainLog(**self.to_dict())
+
class Template(Base):
__tablename__ = "template"
diff --git a/server/chainfury_server/engine.py b/server/chainfury_server/engine.py
index 4c7be63f..ec5c278a 100644
--- a/server/chainfury_server/engine.py
+++ b/server/chainfury_server/engine.py
@@ -15,6 +15,116 @@
import chainfury_server.database as DB
from chainfury_server.utils import logger
+from celery import Celery
+
+from sqlalchemy.pool import NullPool
+from sqlalchemy import create_engine
+
+
+app = Celery()
+
+
+@app.task(name="chainfury_server.engine.run_chain")
+def run_chain(
+ chatbot_id: str,
+ prompt_id: str,
+ prompt_data: Dict,
+ store_ir: bool,
+ store_io: bool,
+ worker_id: str,
+):
+ start = SimplerTimes.get_now_fp64()
+
+ # create the DB session
+ sess = DB.get_local_session(
+ create_engine(
+ DB.db,
+ poolclass=NullPool,
+ )
+ )
+ db = sess()
+
+ # get the db object
+ chatbot = db.query(DB.ChatBot).filter(DB.ChatBot.id == chatbot_id).first() # type: ignore
+ prompt_row: DB.Prompt = db.query(DB.Prompt).filter(DB.Prompt.id == prompt_id).first() # type: ignore
+ if prompt_row is None:
+ time.sleep(2)
+ prompt_row = db.query(DB.Prompt).filter(DB.Prompt.id == prompt_id).first() # type: ignore
+ if prompt_row is None:
+ raise RuntimeError(f"Prompt {prompt_id} not found")
+
+ # Create a Fury chain then run the chain while logging all the intermediate steps
+ dag = T.Dag(**chatbot.dag) # type: ignore
+ chain = Chain.from_dag(dag, check_server=False)
+ callback = FuryThoughtsCallback(db, prompt_row.id)
+
+ # print(
+ # f"starting chain execution: [{prompt_row.meta.get('task_id')=}] [{worker_id=}]"
+ # )
+ iterator = chain.stream(
+ data=prompt_data,
+ thoughts_callback=callback,
+ print_thoughts=False,
+ )
+ mainline_out = ""
+ last_db = 0
+ for ir, done in iterator:
+ if done:
+ mainline_out = ir
+ break
+
+ if store_ir:
+ # in case of stream, every item is a fundamentally a step
+ data = {
+ "outputs": [
+ {
+ "name": k.split("/")[-1],
+ "data": v,
+ }
+ for k, v in ir.items()
+ ]
+ }
+ k = next(iter(ir)).split("/")[0]
+ db_chainlog = DB.ChainLog(
+ prompt_id=prompt_row.id,
+ created_at=SimplerTimes.get_now_datetime(),
+ node_id=k,
+ worker_id=worker_id,
+ message="step",
+ data=data,
+ ) # type: ignore
+ db.add(db_chainlog)
+
+ # update the DB every 5 seconds
+ if time.time() - last_db > 5:
+ db.commit()
+ last_db = time.time()
+
+ result = T.ChainResult(
+ result=str(mainline_out),
+ prompt_id=prompt_row.id, # type: ignore
+ )
+
+ db_chainlog = DB.ChainLog(
+ prompt_id=prompt_row.id,
+ created_at=SimplerTimes.get_now_datetime(),
+ node_id="end",
+ worker_id=worker_id,
+ message="completed",
+ ) # type: ignore
+ db.add(db_chainlog)
+
+ # commit the prompt to DB
+ if store_io:
+ prompt_row.response = result.result # type: ignore
+ prompt_row.time_taken = float(time.time() - start) # type: ignore
+
+ # update the DB after sleeping a bit
+ st = time.time() - last_db
+ if st < 2:
+ time.sleep(2 - st) # be nice to the db
+ db.commit()
+
class FuryEngine:
def run(
@@ -25,7 +135,7 @@ def run(
start: float,
store_ir: bool,
store_io: bool,
- ) -> T.CFPromptResult:
+ ) -> T.ChainResult:
if prompt.new_message and prompt.data:
raise HTTPException(
status_code=400, detail="prompt cannot have both new_message and data"
@@ -37,7 +147,7 @@ def run(
# Create a Fury chain then run the chain while logging all the intermediate steps
dag = T.Dag(**chatbot.dag) # type: ignore
chain = Chain.from_dag(dag, check_server=False)
- callback = FuryThoughts(db, prompt_row.id)
+ callback = FuryThoughtsCallback(db, prompt_row.id)
if prompt.new_message:
prompt.data = {chain.main_in: prompt.new_message}
@@ -47,7 +157,36 @@ def run(
thoughts_callback=callback,
print_thoughts=False,
)
- result = T.CFPromptResult(
+
+ # store the full_ir in the DB.ChainLog
+ if store_ir:
+ # group the logs by node_id
+ chain_logs_by_node = {}
+ for k, v in full_ir.items():
+ node_id, varname = k.split("/")
+ chain_logs_by_node.setdefault(node_id, {"outputs": []})
+ chain_logs_by_node[node_id]["outputs"].append(
+ {
+ "name": varname,
+ "data": v,
+ }
+ )
+
+ # iterate over node ids and create the logs
+ for k, v in chain_logs_by_node.items():
+ db_chainlog = DB.ChainLog(
+ prompt_id=prompt_row.id,
+ created_at=SimplerTimes.get_now_datetime(),
+ node_id=k,
+ worker_id="cf_server",
+ message="step",
+ data=v,
+ ) # type: ignore
+ db.add(db_chainlog)
+ db.commit()
+
+ # create the result
+ result = T.ChainResult(
result=(
json.dumps(mainline_out)
if type(mainline_out) != str
@@ -79,7 +218,7 @@ def stream(
start: float,
store_ir: bool,
store_io: bool,
- ) -> Generator[Tuple[Union[T.CFPromptResult, Dict[str, Any]], bool], None, None]:
+ ) -> Generator[Tuple[Union[T.ChainResult, Dict[str, Any]], bool], None, None]:
if prompt.new_message and prompt.data:
raise HTTPException(
status_code=400, detail="prompt cannot have both new_message and data"
@@ -91,7 +230,7 @@ def stream(
# Create a Fury chain then run the chain while logging all the intermediate steps
dag = T.Dag(**chatbot.dag) # type: ignore
chain = Chain.from_dag(dag, check_server=False)
- callback = FuryThoughts(db, prompt_row.id)
+ callback = FuryThoughtsCallback(db, prompt_row.id)
if prompt.new_message:
prompt.data = {chain.main_in: prompt.new_message}
@@ -111,7 +250,29 @@ def stream(
mainline_out = ir
yield ir, False
- result = T.CFPromptResult(
+ if store_ir:
+ # in case of stream, every item is a fundamentally a step
+ data = {
+ "outputs": [
+ {
+ "name": k.split("/")[-1],
+ "data": v,
+ }
+ for k, v in ir.items()
+ ]
+ }
+ k = next(iter(ir)).split("/")[0]
+ db_chainlog = DB.ChainLog(
+ prompt_id=prompt_row.id,
+ created_at=SimplerTimes.get_now_datetime(),
+ node_id=k,
+ worker_id="cf_server",
+ message="step",
+ data=data,
+ ) # type: ignore
+ db.add(db_chainlog)
+
+ result = T.ChainResult(
result=str(mainline_out),
prompt_id=prompt_row.id, # type: ignore
)
@@ -138,7 +299,7 @@ def submit(
start: float,
store_ir: bool,
store_io: bool,
- ) -> T.CFPromptResult:
+ ) -> T.ChainResult:
if prompt.new_message and prompt.data:
raise HTTPException(
status_code=400, detail="prompt cannot have both new_message and data"
@@ -155,7 +316,18 @@ def submit(
# call the chain
task_id: str = str(uuid4())
- result = T.CFPromptResult(
+ worker_id = task_id.split("-")[0]
+
+ db_chainlog = DB.ChainLog(
+ prompt_id=prompt_row.id,
+ created_at=SimplerTimes.get_now_datetime(),
+ node_id="init",
+ worker_id=worker_id,
+ message=f"scheduling task {task_id}",
+ ) # type: ignore
+ db.add(db_chainlog)
+
+ result = T.ChainResult(
result=f"Task '{task_id}' scheduled",
prompt_id=prompt_row.id,
task_id=task_id,
@@ -163,8 +335,26 @@ def submit(
if store_io:
prompt_row.response = result.result # type: ignore
prompt_row.time_taken = float(time.time() - start) # type: ignore
- db.commit()
+ prompt_row.meta = {"task_id": task_id} # type: ignore
+
+ app.send_task(
+ "chainfury_server.engine.run_chain",
+ queue="cfs",
+ kwargs={
+ "chatbot_id": chatbot.id,
+ "prompt_id": prompt_row.id,
+ "prompt_data": prompt.data,
+ "store_ir": store_ir,
+ "store_io": store_io,
+ "worker_id": worker_id,
+ },
+ task_id=task_id,
+ expires=600, # 10 mins
+ time_limit=240, # 4 mins
+ soft_time_limit=60, # 1 min
+ )
+ db.commit()
return result
except Exception as e:
@@ -173,12 +363,10 @@ def submit(
raise HTTPException(status_code=500, detail=str(e)) from e
-# engine_registry.register(FuryEngine())
-
# helpers
-class FuryThoughts:
+class FuryThoughtsCallback:
def __init__(self, db, prompt_id):
self.db = db
self.prompt_id = prompt_id
@@ -194,28 +382,11 @@ def __call__(self, thought):
self.count += 1
-# def create_intermediate_steps(
-# db: Session,
-# prompt_id: int,
-# intermediate_prompt: str = "",
-# intermediate_response: str = "",
-# response_json: Dict = {},
-# ) -> DB.IntermediateStep:
-# db_prompt = DB.IntermediateStep(
-# prompt_id=prompt_id,
-# intermediate_prompt=intermediate_prompt,
-# intermediate_response=intermediate_response,
-# response_json=response_json,
-# created_at=SimplerTimes.get_now_datetime(),
-# ) # type: ignore
-# db.add(db_prompt)
-# db.commit()
-# db.refresh(db_prompt)
-# return db_prompt
-
-
def create_prompt(
- db: Session, chatbot_id: str, input_prompt: str, session_id: str
+ db: Session,
+ chatbot_id: str,
+ input_prompt: str,
+ session_id: str,
) -> DB.Prompt:
db_prompt = DB.Prompt(
chatbot_id=chatbot_id,
diff --git a/server/chainfury_server/utils.py b/server/chainfury_server/utils.py
index 34f7faad..22c5832d 100644
--- a/server/chainfury_server/utils.py
+++ b/server/chainfury_server/utils.py
@@ -1,7 +1,8 @@
# Copyright © 2023- Frello Technology Private Limited
import os
-import logging
+from Cryptodome.Cipher import AES
+from base64 import b64decode, b64encode
# WARNING: do not import anything from anywhere here, this is the place where chainfury_server starts.
# importing anything can cause the --pre and --post flags to fail when starting server.
@@ -14,18 +15,18 @@
class Env:
"""
Single namespace for all environment variables.
-
- * CFS_DATABASE: database connection string
- * JWT_SECRET: secret for JWT tokens
"""
# once a lifetime secret
JWT_SECRET = lambda: os.getenv("JWT_SECRET", "hajime-shimamoto")
+ CFS_SECRETS_PASSWORD = lambda: os.getenv("CFS_SECRETS_PASSWORDs")
+
+ # not once a lifetime but require DB changes, might as well not change these
+ CFS_MAXLEN_CF_NODE = lambda: int(os.getenv("CFS_MAXLEN_CF_NODE", 80))
+ CFS_MAXLEN_WORKER = lambda: int(os.getenv("CFS_MAXLEN_WORKER", 16))
# when you want to use chainfury as a client you need to set the following vars
CFS_DATABASE = lambda: os.getenv("CFS_DATABASE", None)
- CFS_MAXLEN_CF_NDOE = lambda: int(os.getenv("CFS_MAXLEN_CF_NDOE", 80))
- CFS_MAXLEN_WORKER = lambda: int(os.getenv("CFS_MAXLEN_WORKER", 16))
CFS_ALLOW_CORS_ORIGINS = lambda: [
x.strip() for x in os.getenv("CFS_ALLOW_CORS_ORIGINS", "*").split(",")
]
@@ -47,3 +48,61 @@ def folder(x: str) -> str:
def joinp(x: str, *args) -> str:
"""convienience function for os.path.join"""
return os.path.join(x, *args)
+
+
+class Crypt:
+
+ def __init__(self, salt="SlTKeYOpHygTYkP3"):
+ self.salt = salt.encode("utf8")
+ self.enc_dec_method = "utf-8"
+
+ def encrypt(self, str_to_enc, str_key):
+ try:
+ aes_obj = AES.new(str_key.encode("utf-8"), AES.MODE_CFB, self.salt)
+ hx_enc = aes_obj.encrypt(str_to_enc.encode("utf8"))
+ mret = b64encode(hx_enc).decode(self.enc_dec_method)
+ return mret
+ except ValueError as value_error:
+ if value_error.args[0] == "IV must be 16 bytes long":
+ raise ValueError("Encryption Error: SALT must be 16 characters long")
+ elif (
+ value_error.args[0] == "AES key must be either 16, 24, or 32 bytes long"
+ ):
+ raise ValueError(
+ "Encryption Error: Encryption key must be either 16, 24, or 32 characters long"
+ )
+ else:
+ raise ValueError(value_error)
+
+ def decrypt(self, enc_str, str_key):
+ try:
+ aes_obj = AES.new(str_key.encode("utf8"), AES.MODE_CFB, self.salt)
+ str_tmp = b64decode(enc_str.encode(self.enc_dec_method))
+ str_dec = aes_obj.decrypt(str_tmp)
+ mret = str_dec.decode(self.enc_dec_method)
+ return mret
+ except ValueError as value_error:
+ if value_error.args[0] == "IV must be 16 bytes long":
+ raise ValueError("Decryption Error: SALT must be 16 characters long")
+ elif (
+ value_error.args[0] == "AES key must be either 16, 24, or 32 bytes long"
+ ):
+ raise ValueError(
+ "Decryption Error: Encryption key must be either 16, 24, or 32 characters long"
+ )
+ else:
+ raise ValueError(value_error)
+
+
+CURRENT_EPOCH_START = 1705905900000 # UTC timezone
+"""Start of the current epoch, used for generating snowflake ids"""
+
+from snowflake import SnowflakeGenerator
+
+
+class SFGen:
+ def __init__(self, instance, epoch=CURRENT_EPOCH_START):
+ self.gen = SnowflakeGenerator(instance, epoch=epoch)
+
+ def __call__(self):
+ return next(self.gen)
diff --git a/server/chainfury_server/version.py b/server/chainfury_server/version.py
index 635ae716..e23f8b60 100644
--- a/server/chainfury_server/version.py
+++ b/server/chainfury_server/version.py
@@ -1,6 +1,6 @@
# Copyright © 2023- Frello Technology Private Limited
-__version__ = "2.1.0"
+__version__ = "2.1.3a0"
_major, _minor, _patch = __version__.split(".")
_major = int(_major)
_minor = int(_minor)
diff --git a/server/pyproject.toml b/server/pyproject.toml
index 6de019e3..dc1342e7 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -2,8 +2,8 @@
[tool.poetry]
name = "chainfury_server"
-version = "2.1.1"
-description = "ChainFury Server is the server for running ChainFury Engine!"
+version = "2.1.3a0"
+description = "ChainFury Server is the DB + API server for managing the ChainFury engine in production. Used in production at chat.tune.app"
authors = ["Tune AI "]
license = "Apache 2.0"
readme = "README.md"
@@ -22,10 +22,11 @@ SQLAlchemy = "1.4.47"
uvicorn = "0.27.1"
PyMySQL = "1.0.3"
urllib3 = ">=1.26.18"
-"cryptography" = ">=41.0.6"
+cryptography = ">=41.0.6"
+snowflake_id = "1.0.1"
[tool.poetry.scripts]
-chainfury_server = "chainfury_server"
+chainfury_server = "chainfury_server:__main__"
[build-system]
requires = ["poetry-core"]
diff --git a/tests/__main__.py b/tests/__main__.py
deleted file mode 100644
index ab7a2cd7..00000000
--- a/tests/__main__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# Copyright © 2023- Frello Technology Private Limited
-
-from tests.getkv import TestGetValueByKeys
-from tests.chains import TestChainSerDeser
-import unittest
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/chains.py b/tests/chains.py
deleted file mode 100644
index 335688f2..00000000
--- a/tests/chains.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# Copyright © 2023- Frello Technology Private Limited
-
-from chainfury import programatic_actions_registry, Chain
-
-import unittest
-
-
-chain = Chain(
- name="echo-cf-public",
- description="abyss",
- nodes=[programatic_actions_registry.get("chainfury-echo")], # type: ignore
- sample={"message": "hi there"},
- main_in="message",
- main_out="chainfury-echo/message",
-)
-
-
-class TestChainSerDeser(unittest.TestCase):
- def test_dict(self):
- Chain.from_dict(chain.to_dict())
-
- def test_apidict(self):
- Chain.from_dict(chain.to_dict(api=True))
-
- def test_json(self):
- Chain.from_json(chain.to_json())
-
- def test_dag(self):
- Chain.from_dag(chain.to_dag())
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/main.py b/tests/main.py
new file mode 100644
index 00000000..2a8e4fc7
--- /dev/null
+++ b/tests/main.py
@@ -0,0 +1,15 @@
+# Copyright © 2023- Frello Technology Private Limited
+
+import os
+from tuneapi.utils import folder, joinp
+
+tests = []
+curdir = folder(__file__)
+for x in os.listdir(curdir):
+ if x.startswith("test_") and x.endswith(".py"):
+ tests.append(joinp(curdir, x))
+
+for t in tests:
+ code = os.system(f"python3 {t} -v")
+ if code != 0:
+ raise Exception(f"Test {t} failed with code {code}")
diff --git a/tests/test_base_chain2.py b/tests/test_base_chain2.py
new file mode 100644
index 00000000..611727da
--- /dev/null
+++ b/tests/test_base_chain2.py
@@ -0,0 +1,61 @@
+# Copyright © 2023- Frello Technology Private Limited
+
+from chainfury import (
+ Chain,
+ Thread,
+ human,
+)
+from chainfury.components.tune import TuneModel
+import unittest
+
+
+chain = Chain(
+ name="demo-one",
+ description=(
+ "Building the hardcore example of chain at https://nimbleboxai.github.io/ChainFury/examples/usage-hardcore.html "
+ "using threaded chains"
+ ),
+ main_in="stupid_question",
+ main_out="fight_scene/fight_scene",
+ default_model=TuneModel("rohan/mixtral-8x7b-inst-v0-1-32k"),
+)
+chain.add_thread(
+ "character_one",
+ Thread(
+ human(
+ "You were who was running in the middle of desert. You see a McDonald's and the waiter ask a stupid "
+ "question like: '{{ stupid_question }}'? You are pissed and you say."
+ ),
+ ),
+)
+chain.add_thread(
+ "character_two",
+ Thread(
+ human(
+ "Someone comes upto you in a bar and screams '{{ character_one }}'? You are a bartender give a funny response to it."
+ ),
+ ),
+)
+chain.add_thread(
+ "fight_scene",
+ Thread(
+ human(
+ "Two men were fighting in a bar. One yelled '{{ character_one }}'. Other responded by yelling '{{ character_two }}'.\n"
+ "Continue this story for 3 more lines."
+ )
+ ),
+)
+
+
+class TestChain(unittest.TestCase):
+ """Testing Chain specific functionality"""
+
+ def test_chain_toposort(self):
+ self.assertEqual(
+ chain.topo_order,
+ ["character_one", "character_two", "fight_scene"],
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_base_types.py b/tests/test_base_types.py
new file mode 100644
index 00000000..d081c170
--- /dev/null
+++ b/tests/test_base_types.py
@@ -0,0 +1,111 @@
+# Copyright © 2023- Frello Technology Private Limited
+
+import unittest
+from functools import cache
+from chainfury.components.functional import echo
+from chainfury import programatic_actions_registry, Chain, Var, Tools
+
+
+class TestSerDeser(unittest.TestCase):
+ """Tests Serialisation and Deserialisation of Nodes, Chains and Tools."""
+
+ def test_chain_dict(self):
+ Chain.from_dict(chain.to_dict())
+
+ def test_chain_apidict(self):
+ Chain.from_dict(chain.to_dict(api=True))
+
+ def test_chain_json(self):
+ Chain.from_json(chain.to_json())
+
+ def test_chain_dag(self):
+ Chain.from_dag(chain.to_dag())
+
+ def test_node_dict(self):
+ node = programatic_actions_registry.get("chainfury-echo")
+ if node is None:
+ self.fail("Node not found")
+ self.assertIsNotNone(node)
+ node.from_dict(node.to_dict())
+
+ def test_node_json(self):
+ node = programatic_actions_registry.get("chainfury-echo")
+ if node is None:
+ self.fail("Node not found")
+ self.assertIsNotNone(node)
+ node.from_json(node.to_json())
+
+ def test_tool_dict(self):
+ Tools.from_dict(tool.to_dict())
+
+ def test_tool_json(self):
+ Tools.from_json(tool.to_json())
+
+
+class TestNode(unittest.TestCase):
+ """Test Node specific functionality."""
+
+ def test_node_run(self):
+ node = programatic_actions_registry.get("chainfury-echo")
+ if node is None:
+ self.fail("Node not found")
+ self.assertIsNotNone(node)
+ out, err = node(data={"message": "hi there"})
+ self.assertIsNone(err)
+
+ # call the function directly
+ fn_out, _ = echo("hi there")
+ self.assertEqual(out, {"message": fn_out})
+
+
+#
+# Chain definition
+#
+
+chain = Chain(
+ name="echo-cf-public",
+ description="abyss",
+ nodes=[programatic_actions_registry.get("chainfury-echo")], # type: ignore
+ sample={"message": "hi there"},
+ main_in="message",
+ main_out="chainfury-echo/message",
+)
+
+
+#
+# Tool definition
+#
+tool = Tools(
+ name="calculator",
+ description=(
+ "This tool is a calculator, it can perform basica calculations. "
+ "Use this when you are trying to do some mathematical task"
+ ),
+)
+
+
+@tool.add(
+ description="This function adds two numbers",
+ properties={
+ "a": Var("int", required=True, description="number one"),
+ "b": Var("int", description="number two"),
+ },
+)
+def add_two_numbers(a: int, b: int = 10):
+ return a + b
+
+
+@tool.add(
+ description="This calculates square root of a number",
+ properties={
+ "a": Var("int", description="number to calculate square root of"),
+ },
+)
+def square_root_number(a):
+ import math
+
+ return math.sqrt(a)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/getkv.py b/tests/test_getkv.py
similarity index 100%
rename from tests/getkv.py
rename to tests/test_getkv.py