Skip to content

Commit

Permalink
add more tests + core is a single file again
Browse files Browse the repository at this point in the history
  • Loading branch information
yashbonde committed Mar 4, 2024
1 parent 16b08b8 commit 0f43038
Show file tree
Hide file tree
Showing 8 changed files with 381 additions and 390 deletions.
2 changes: 1 addition & 1 deletion chainfury/components/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def echo(message: str) -> Tuple[Dict[str, Dict[str, str]], Optional[Exception]]:

programatic_actions_registry.register(
fn=echo,
outputs={"message": (0,)}, # type: ignore
outputs={"message": ()}, # type: ignore
node_id="chainfury-echo",
description="I stared into the abyss and it stared back at me. Echoes the message, used for debugging",
)
318 changes: 317 additions & 1 deletion chainfury/core/actions.py → chainfury/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import copy
import random
from uuid import uuid4
from typing import Any, List, Optional, Dict, Tuple

Expand All @@ -24,7 +25,7 @@
put_value_by_keys,
)
from chainfury.utils import logger
from chainfury.core.models import model_registry


# Programmatic Actions Registry
# -----------------------------
Expand Down Expand Up @@ -497,6 +498,310 @@ def get_count_for_nodes(self, node_id: str) -> int:
return self.counter.get(node_id, 0)


# Memory Registry
# ---------------------------
# All the components that have to do with storage and retrieval of data from the DB. This section is supposed to act
# like the memory in a Von Neumann architecture.


class Memory:
    """Class to wrap the DB functions as a callable.

    Args:
        node_id (str): The id of the node, always of the form ``"<component_name>-read"`` or
            ``"<component_name>-write"`` (see `MemoryRegistry`)
        fn (object): The function that is used for this action
        vector_key (str): The key for the vector in the DB
        read_mode (bool, optional): If the function is a read function, if `False` then this is a write function.
    """

    fields_model = [
        Var(
            name="items",
            type=[Var(type="string"), Var(type="array", items=[Var(type="string")])],
            required=True,
        ),
        Var(name="embedding_model", type="string", required=True),
        Var(
            name="embedding_model_params",
            type="object",
            additionalProperties=Var(type="string"),
        ),
        Var(name="embedding_model_key", type="string"),
        Var(
            name="translation_layer",
            type="object",
            additionalProperties=Var(type="string"),
        ),
    ]
    """These are the fields that are used to map the input items to the embedding model, do not use directly"""

    def __init__(
        self, node_id: str, fn: object, vector_key: str, read_mode: bool = False
    ):
        self.node_id = node_id
        self.fn = fn
        self.vector_key = vector_key
        self.read_mode = read_mode
        # fields taken by the underlying DB function, derived from its signature
        self.fields_fn = func_to_vars(fn)
        self.fields = self.fields_fn + self.fields_model

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the Memory object to a dict."""
        return {
            # node_id is "<component_name>-write" or "<component_name>-read"; strip only the trailing
            # mode suffix so component names that themselves contain "-" survive the to_dict/from_dict
            # round trip (split("-")[0] would truncate e.g. "my-db-write" to "my").
            "node_id": self.node_id.rsplit("-", 1)[0],
            "vector_key": self.vector_key,
            "read_mode": self.read_mode,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]):
        """Deserialize the Memory object from a dict produced by `to_dict`."""
        read_mode = data["read_mode"]
        if read_mode:
            fn = memory_registry.get_read(data["node_id"])
        else:
            fn = memory_registry.get_write(data["node_id"])

        # here we do return Memory type but instead of creating one we use a previously existing Node and return
        # the fn for the Node which is ultimately this precise Memory object
        return fn.fn  # type: ignore

    def __call__(self, **data: Dict[str, Any]) -> Any:
        """Embed the input ``items`` with the configured embedding model, then call the wrapped DB function.

        Returns:
            The ``(out, err)`` tuple produced by the underlying DB function.
        """
        # the first thing we have to do is get the data for the model. This is actually a very hard problem because this
        # function needs to call some other arbitrary function where we know the inputs to this function "items" but we
        # do not know which variable to pass this to in the underlying model's function. Thus we need to take in a huge
        # amount of things as more inputs ("embedding_model_key", "embedding_model_params"). Then we don't even know
        # what the inputs to the underlying DB function are going to be, in which case we also need to add things like
        # the translation that needs to be done ("translation_layer"). This makes the number of inputs a lot but
        # ultimately is required to do the job for robust-ness. Which is why we provide a default for openai-embedding
        # model. For any other model user will need to pass all the information.
        model_fields: Dict[str, Any] = {}
        for f in self.fields_model:
            if f.required and f.name not in data:
                raise Exception(
                    f"Field '{f.name}' is required in {self.node_id} but not present"
                )
            if f.name in data:
                model_fields[f.name] = data.pop(f.name)

        model_data = {**model_fields.get("embedding_model_params", {})}
        model_id = model_fields.pop("embedding_model")

        # TODO: @yashbonde - clean this mess up
        # DEFAULT_MEMORY_CONSTANTS = {
        #     "openai-embedding": {
        #         "embedding_model_key": "input_strings",
        #         "embedding_model_params": {
        #             "model": "text-embedding-ada-002",
        #         },
        #         "translation_layer": {
        #             "embeddings": ["data", "*", "embedding"],
        #         },
        #     }
        # }
        # embedding_model_default_config = DEFAULT_MEMORY_CONSTANTS.get(model_id, {})
        # if embedding_model_default_config:
        #     model_data = {
        #         **embedding_model_default_config.get("embedding_model_params", {}),
        #         **model_data,
        #     }
        #     model_key = embedding_model_default_config.get(
        #         "embedding_model_key", "items"
        #     ) or model_data.get("embedding_model_key")
        #     model_fields["translation_layer"] = model_fields.get(
        #         "translation_layer"
        #     ) or embedding_model_default_config.get("translation_layer")
        # else:

        # until per-model defaults exist (see TODO above), every non-required model field must be passed explicitly
        req_keys = [x.name for x in self.fields_model[2:]]
        if not all([x in model_fields for x in req_keys]):
            raise Exception(f"Model {model_id} requires {req_keys} to be passed")
        model_key = model_fields.get("embedding_model_key")
        model_data = {
            **model_fields.get("embedding_model_params", {}),
            **model_data,
        }
        model_data[model_key] = model_fields.pop("items")  # type: ignore
        model = model_registry.get(model_id)
        embeddings, err = model(model_data=model_data)
        if err:
            logger.error(f"error: {err}")
            logger.error(f"traceback: {embeddings}")
            raise err

        # now that we have all the embeddings ready we now need to translate it to be fed into the DB function
        translated_data = {}
        for k, v in model_fields.get("translation_layer", {}).items():
            translated_data[k] = get_value_by_keys(embeddings, v)

        # create the dictionary to call the underlying function
        db_data = {}
        for f in self.fields_fn:
            if f.required and not (f.name in data or f.name in translated_data):
                raise Exception(
                    f"Field '{f.name}' is required in {self.node_id} but not present"
                )
            if f.name in data:
                db_data[f.name] = data.pop(f.name)
            if f.name in translated_data:
                db_data[f.name] = translated_data.pop(f.name)
        out, err = self.fn(**db_data)  # type: ignore
        return out, err


class MemoryRegistry:
    """Registry for memory (vector DB) nodes, keyed by ``"<component_name>-read"`` / ``"<component_name>-write"``."""

    def __init__(self) -> None:
        self._memories: Dict[str, Node] = {}

    def _register(
        self,
        component_name: str,
        fn: object,
        outputs: Dict[str, Any],
        vector_key: str,
        description: str,
        tags: List[str],
        read_mode: bool,
    ) -> Node:
        # Shared implementation for register_read / register_write; the only difference
        # between the two public entry points is the node-id suffix and the Memory mode.
        node_id = f"{component_name}-{'read' if read_mode else 'write'}"
        mem_fn = Memory(
            node_id=node_id, fn=fn, vector_key=vector_key, read_mode=read_mode
        )
        output_fields = func_to_return_vars(fn, returns=outputs)
        node = Node(
            id=node_id,
            fn=mem_fn,
            type=Node.types.MEMORY,
            fields=mem_fn.fields,
            outputs=output_fields,
            description=description,
            tags=tags,
        )
        self._memories[node_id] = node
        return node

    def register_write(
        self,
        component_name: str,
        fn: object,
        outputs: Dict[str, Any],
        vector_key: str,
        description: str = "",
        tags: Optional[List[str]] = None,
    ) -> Node:
        """Register a write-mode memory node under ``"<component_name>-write"``.

        Args:
            component_name (str): Name of the DB component, used to build the node id
            fn (object): The DB write function to wrap
            outputs (Dict[str, Any]): Mapping used to extract the return fields of `fn`
            vector_key (str): The key for the vector in the DB
            description (str, optional): Human readable description of the node
            tags (List[str], optional): Tags for the node

        Returns:
            Node: The registered node
        """
        # tags defaults to None (not []) to avoid the shared mutable-default-argument pitfall
        return self._register(
            component_name, fn, outputs, vector_key, description, tags or [], read_mode=False
        )

    def register_read(
        self,
        component_name: str,
        fn: object,
        outputs: Dict[str, Any],
        vector_key: str,
        description: str = "",
        tags: Optional[List[str]] = None,
    ) -> Node:
        """Register a read-mode memory node under ``"<component_name>-read"``.

        Args:
            component_name (str): Name of the DB component, used to build the node id
            fn (object): The DB read function to wrap
            outputs (Dict[str, Any]): Mapping used to extract the return fields of `fn`
            vector_key (str): The key for the vector in the DB
            description (str, optional): Human readable description of the node
            tags (List[str], optional): Tags for the node

        Returns:
            Node: The registered node
        """
        return self._register(
            component_name, fn, outputs, vector_key, description, tags or [], read_mode=True
        )

    def get_write(self, node_id: str) -> Node:
        """Get the write-mode node for `node_id` (without the "-write" suffix).

        Raises:
            ValueError: If no such memory node is registered
        """
        out = self._memories.get(node_id + "-write", None)
        if out is None:
            raise ValueError(f"Memory '{node_id}' not found")
        return out

    def get_read(self, node_id: str) -> Node:
        """Get the read-mode node for `node_id` (without the "-read" suffix).

        Raises:
            ValueError: If no such memory node is registered
        """
        out = self._memories.get(node_id + "-read", None)
        if out is None:
            raise ValueError(f"Memory '{node_id}' not found")
        return out

    def get_nodes(self) -> Dict[str, Any]:
        """Serialize all registered memory nodes to a dict."""
        return {k: v.to_dict() for k, v in self._memories.items()}


# Models Registry
# ---------------
# All the things below are for the models that are registered in the model registry, so that they can be used as inputs
# in the chain. There can be several models that can be put as inputs in a single chatbot.


class ModelRegistry:
    """Model registry contains metadata for all the models that are provided in the components"""

    def __init__(self):
        self.models: Dict[str, Model] = {}  # model id -> Model
        self.counter: Dict[str, int] = {}  # model id -> number of times fetched via get()
        self.tags_to_models: Dict[str, List[str]] = {}  # tag -> list of model ids

    def has(self, id: str) -> bool:
        """A helper function to check if a model is registered or not"""
        return id in self.models

    def register(self, model: Model) -> Model:
        """Register a model in the registry

        Args:
            model (Model): Model to register

        Raises:
            Exception: If a model with the same id is already registered

        Returns:
            Model: The registered model (for chaining)
        """
        id = model.id
        # original message was f"Registering model {id} at {id}", which logged the id twice
        logger.debug(f"Registering model {id}")
        if id in self.models:
            raise Exception(f"Model {id} already registered")
        self.models[id] = model
        for tag in model.tags:
            self.tags_to_models[tag] = self.tags_to_models.get(tag, []) + [id]
        return model

    def get_tags(self) -> List[str]:
        """Get all the tags that are registered in the registry

        Returns:
            List[str]: List of tags
        """
        return list(self.tags_to_models.keys())

    def get_models(self, tag: str = "") -> Dict[str, Dict[str, Any]]:
        """Get all the models that are registered in the registry

        Args:
            tag (str, optional): Filter models by tag. Defaults to "".

        Returns:
            Dict[str, Dict[str, Any]]: Dictionary of models
        """
        items = {k: v.to_dict() for k, v in self.models.items()}
        if tag:
            items = {k: v for k, v in items.items() if tag in v.get("tags", [])}
        return items

    def get(self, id: str) -> Model:
        """Get a model from the registry and bump its usage counter

        Args:
            id (str): Id of the model

        Raises:
            ValueError: If the model is not registered

        Returns:
            Model: Model
        """
        self.counter[id] = self.counter.get(id, 0) + 1
        out = self.models.get(id, None)
        if out is None:
            raise ValueError(f"Model {id} not found")
        return out

    def get_count_for_model(self, id: str) -> int:
        """Get the number of times a model is used

        Args:
            id (str): Id of the model

        Returns:
            int: Number of times the model is used
        """
        return self.counter.get(id, 0)

    def get_any_model(self) -> Model:
        """Return a uniformly-random registered model.

        Raises:
            IndexError: If no models are registered
        """
        return random.choice(list(self.models.values()))


# Initialise Registries
# ---------------------

Expand All @@ -511,3 +816,14 @@ def get_count_for_nodes(self, node_id: str) -> int:
`ai_actions_registry` is a global instance of `AIActionsRegistry` class. This is used to register and unregister
`AIAction` instances. This is used by the server to serve the registered actions.
"""

memory_registry = MemoryRegistry()
"""
`memory_registry` is a global instance of MemoryRegistry class. This is used to register and unregister Memory instances.
This is what the user should use when they want to use the memory elements in their chain.
"""

model_registry = ModelRegistry()
"""
`model_registry` is a global variable that is used to register models. It is an instance of ModelRegistry class.
"""
9 changes: 0 additions & 9 deletions chainfury/core/__init__.py

This file was deleted.

Loading

0 comments on commit 0f43038

Please sign in to comment.