Skip to content

Commit

Permalink
[chore] add worker architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
yashbonde committed Mar 13, 2024
1 parent 094a454 commit 2dd79eb
Show file tree
Hide file tree
Showing 9 changed files with 190 additions and 94 deletions.
12 changes: 10 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ ae e0 a5 87 e0 a4 b5 20 e0 a4 9c e0 a4 af

The documentation page contains all the information on using `chainfury` and `chainfury_server`.

#### `chainfury`

<img src="https://d2e931syjhr5o9.cloudfront.net/tune-research/assets/cf_arch.png" align="center"/>

#### `chainfury_server`

<img src="https://d2e931syjhr5o9.cloudfront.net/tune-research/assets/cfs_arch.png" align="center"/>

# Looking for Inspirations?

Here's a few example to get your journey started on Software 2.0:
Expand Down Expand Up @@ -86,7 +94,7 @@ source venv/bin/activate
You will need to have `yarn` installed to build the frontend and move it to the correct location on the server

```bash
sh stories/build_and_copy.sh
sh build_ui.sh
```

Once the static files are copied we can now proceed to install dependecies:
Expand All @@ -104,7 +112,7 @@ You can now visit [localhost:8000](http://localhost:8000/ui/) to see the GUI and
There are a few test cases for super hard problems like `get_kv` which checks the `chainfury.base.get_value_by_keys` function.

```bash
python3 -m tests -v
python3 tests/main.py
```

# Contibutions
Expand Down
3 changes: 0 additions & 3 deletions scripts/build_and_copy.sh → build_ui.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ cd client
yarn install
yarn build

# Go back to the root directory
cd ..

# copy the dist folder to the server
# Go into the server folder, remove the old static folder and copy the new dist folder, copy index.html to templates
echo "Copying the generated files to the server"
Expand Down
17 changes: 8 additions & 9 deletions chainfury/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@

from chainfury import Chain
from chainfury.version import __version__
from chainfury.components import all_items
from chainfury.core import model_registry, programatic_actions_registry, memory_registry
from chainfury.chat import Chat, Message
from chainfury.core import model_registry
from chainfury.types import Thread, Message


class CLI:
Expand Down Expand Up @@ -115,7 +114,7 @@ def sh(
cf_model.set_api_token(token)

# loop for user input through command line
chat = Chat()
thread = Thread()
usr_cntr = 0
while True:
try:
Expand All @@ -126,21 +125,21 @@ def sh(
break
if user_input == "exit" or user_input == "quit" or user_input == "":
break
chat.add(Message(user_input, Message.HUMAN))
thread.add(Message(user_input, Message.HUMAN))

print(f"\033[1m\033[34m ASSISTANT \033[39m:\033[0m ", end="", flush=True)
if stream:
response = ""
for str_token in cf_model.stream_chat(chat, model=model):
for str_token in cf_model.stream_chat(thread, model=model):
response += str_token
print(str_token, end="", flush=True)
print() # new line
chat.add(Message(response, Message.GPT))
thread.add(Message(response, Message.GPT))
else:
response = cf_model.chat(chat, model=model)
response = cf_model.chat(thread, model=model)
print(response)

chat.add(Message(response, Message.GPT))
thread.add(Message(response, Message.GPT))
usr_cntr += 1


Expand Down
4 changes: 4 additions & 0 deletions chainfury/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,10 @@ def get_now_float() -> float: # type: ignore
"""Get the current datetime in UTC timezone as a float"""
return SimplerTimes.get_now_datetime().timestamp()

def get_now_fp64() -> float: # type: ignore
"""Get the current datetime in UTC timezone as a float"""
return SimplerTimes.get_now_datetime().timestamp()

def get_now_i64() -> int: # type: ignore
"""Get the current datetime in UTC timezone as a int"""
return int(SimplerTimes.get_now_datetime().timestamp())
Expand Down
58 changes: 0 additions & 58 deletions scripts/list_builtins.py

This file was deleted.

24 changes: 12 additions & 12 deletions server/chainfury_server/api/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def create_chain(
return T.ApiResponse(message="Name not specified")
if chatbot_data.dag:
for n in chatbot_data.dag.nodes:
if len(n.id) > Env.CFS_MAXLEN_CF_NDOE():
if len(n.id) > Env.CFS_MAXLEN_CF_NODE():
raise HTTPException(
status_code=400,
detail=f"Node ID length cannot be more than {Env.CFS_MAXLEN_CF_NDOE()}",
detail=f"Node ID length cannot be more than {Env.CFS_MAXLEN_CF_NODE()}",
)

# DB call
Expand Down Expand Up @@ -245,16 +245,16 @@ def run_chain(

if as_task:
# when run as a task this will return a task ID that will be submitted
raise HTTPException(501, detail="Not implemented yet")
# result = engine.submit(
# chatbot=chatbot,
# prompt=prompt,
# db=db,
# start=time.time(),
# store_ir=store_ir,
# store_io=store_io,
# )
# return result
# raise HTTPException(501, detail="Not implemented yet")
result = engine.submit(
chatbot=chatbot,
prompt=prompt,
db=db,
start=time.time(),
store_ir=store_ir,
store_io=store_io,
)
return result
elif stream:

def _get_streaming_response(result):
Expand Down
12 changes: 7 additions & 5 deletions server/chainfury_server/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dataclasses import dataclass, asdict
from typing import Dict, Any

from sqlalchemy.pool import QueuePool
from sqlalchemy.pool import QueuePool, NullPool
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, scoped_session, sessionmaker, relationship
Expand Down Expand Up @@ -55,6 +55,8 @@
)
else:
logger.info(f"Using via database URL")
# https://stackoverflow.com/a/73764136
#
engine = create_engine(
db,
poolclass=QueuePool,
Expand Down Expand Up @@ -84,7 +86,7 @@ def get_random_number(length) -> int:
return random_numbers


def get_local_session() -> sessionmaker:
def get_local_session(engine) -> sessionmaker:
logger.debug("Database opened successfully")
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
return SessionLocal
Expand All @@ -101,7 +103,7 @@ def db_session() -> Session: # type: ignore


def fastapi_db_session():
sess_cls = get_local_session()
sess_cls = get_local_session(engine)
db = sess_cls()
try:
yield db
Expand Down Expand Up @@ -272,7 +274,7 @@ class Prompt(Base):
meta: Dict[str, Any] = Column(JSON)

# migrate to snowflake ID
sf_id = Column(String(19), nullable=True)
# sf_id = Column(String(19), nullable=True)

def to_dict(self):
return {
Expand Down Expand Up @@ -303,7 +305,7 @@ class ChainLog(Base):
)
created_at: datetime = Column(DateTime, nullable=False)
prompt_id: int = Column(Integer, ForeignKey("prompt.id"), nullable=False)
node_id: str = Column(String(Env.CFS_MAXLEN_CF_NDOE()), nullable=False)
node_id: str = Column(String(Env.CFS_MAXLEN_CF_NODE()), nullable=False)
worker_id: str = Column(String(Env.CFS_MAXLEN_WORKER()), nullable=False)
message: str = Column(Text, nullable=False)
data: Dict[str, Any] = Column(JSON, nullable=True)
Expand Down
Loading

0 comments on commit 2dd79eb

Please sign in to comment.