From 853b097801766f23ccd9658ecd1e34bda1f2cdea Mon Sep 17 00:00:00 2001 From: vgulerianb <90599235+vgulerianb@users.noreply.github.com> Date: Wed, 8 Nov 2023 18:13:29 +0530 Subject: [PATCH 1/8] fixes(dependency): Fix security alerts --- client/package.json | 3 +- client/yarn.lock | 122 +++++++++++++++++++++++++++++++++++--------- 2 files changed, 99 insertions(+), 26 deletions(-) diff --git a/client/package.json b/client/package.json index cb194b4..b1c05da 100644 --- a/client/package.json +++ b/client/package.json @@ -38,6 +38,7 @@ }, "resolutions": { "postcss": "^8.4.31", - "semver": "^6.3.1" + "semver": "^6.3.1", + "@babel/traverse": "^7.23.2" } } diff --git a/client/yarn.lock b/client/yarn.lock index 8090a2b..213309b 100644 --- a/client/yarn.lock +++ b/client/yarn.lock @@ -17,6 +17,14 @@ dependencies: "@babel/highlight" "^7.18.6" +"@babel/code-frame@^7.22.13": + version "7.22.13" + resolved "https://registry.yarnpkg.com/@babel/code-frame/-/code-frame-7.22.13.tgz#e3c1c099402598483b7a8c46a721d1038803755e" + integrity sha512-XktuhWlJ5g+3TJXc5upd9Ks1HutSArik6jf2eAjYFyIOf4ej3RN+184cZbzDvbPnuTJIUhPKKJE3cIsYTiAT3w== + dependencies: + "@babel/highlight" "^7.22.13" + chalk "^2.4.2" + "@babel/compat-data@^7.21.4": version "7.21.4" resolved "https://registry.yarnpkg.com/@babel/compat-data/-/compat-data-7.21.4.tgz#457ffe647c480dff59c2be092fc3acf71195c87f" @@ -53,6 +61,16 @@ "@jridgewell/trace-mapping" "^0.3.17" jsesc "^2.5.1" +"@babel/generator@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.23.0.tgz#df5c386e2218be505b34837acbcb874d7a983420" + integrity sha512-lN85QRR+5IbYrMWM6Y4pE/noaQtg4pNiqeNGX60eqOfo6gtEj6uw/JagelB8vVztSd7R6M5n1+PQkDbHbBRU4g== + dependencies: + "@babel/types" "^7.23.0" + "@jridgewell/gen-mapping" "^0.3.2" + "@jridgewell/trace-mapping" "^0.3.17" + jsesc "^2.5.1" + "@babel/helper-compilation-targets@^7.21.4": version "7.21.4" resolved "https://registry.yarnpkg.com/@babel/helper-compilation-targets/-/helper-compilation-targets-7.21.4.tgz#770cd1ce0889097ceacb99418ee6934ef0572656" @@ -69,20 +87,25 @@ resolved "https://registry.yarnpkg.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.18.9.tgz#0c0cee9b35d2ca190478756865bb3528422f51be" integrity sha512-3r/aACDJ3fhQ/EVgFy0hpj8oHyHpQc+LPtJoY9SzTThAsStm4Ptegq92vqKoE3vD706ZVFWITnMnxucw+S9Ipg== -"@babel/helper-function-name@^7.21.0": - version "7.21.0" - resolved "https://registry.yarnpkg.com/@babel/helper-function-name/-/helper-function-name-7.21.0.tgz#d552829b10ea9f120969304023cd0645fa00b1b4" - integrity sha512-HfK1aMRanKHpxemaY2gqBmL04iAPOPRj7DxtNbiDOrJK+gdwkiNRVpCpUJYbUT+aZyemKN8brqTOxzCaG6ExRg== +"@babel/helper-environment-visitor@^7.22.20": + version "7.22.20" + resolved "https://registry.yarnpkg.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.22.20.tgz#96159db61d34a29dba454c959f5ae4a649ba9167" + integrity sha512-zfedSIzFhat/gFhWfHtgWvlec0nqB9YEIVrpuwjruLlXfUSnA8cJB0miHKwqDnQ7d32aKo2xt88/xZptwxbfhA== + +"@babel/helper-function-name@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/helper-function-name/-/helper-function-name-7.23.0.tgz#1f9a3cdbd5b2698a670c30d2735f9af95ed52759" + integrity sha512-OErEqsrxjZTJciZ4Oo+eoZqeW9UIiOcuYKRJA4ZAgV9myA+pOXhhmpfNCKjEH/auVfEYVFJ6y1Tc4r0eIApqiw== dependencies: - "@babel/template" "^7.20.7" - "@babel/types" "^7.21.0" + "@babel/template" "^7.22.15" + "@babel/types" "^7.23.0" -"@babel/helper-hoist-variables@^7.18.6": - version "7.18.6" - resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.18.6.tgz#d4d2c8fb4baeaa5c68b99cc8245c56554f926678" - integrity sha512-UlJQPkFqFULIcyW5sbzgbkxn2FKRgwWiRexcuaR8RNJRy8+LLveqPjwZV/bwrLZCN0eUHD/x8D0heK1ozuoo6Q== +"@babel/helper-hoist-variables@^7.22.5": + version "7.22.5" + resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.22.5.tgz#c01a007dac05c085914e8fb652b339db50d823bb" + integrity sha512-wGjk9QZVzvknA6yKIUURb8zY3grXCcOZt+/7Wcy8O2uctxhplmUPkOdlgoNhmdVee2c92JXbf1xpMtVNbfoxRw== dependencies: - "@babel/types" "^7.18.6" + "@babel/types" "^7.22.5" "@babel/helper-module-imports@^7.16.7", "@babel/helper-module-imports@^7.18.6": version "7.21.4" @@ -124,16 +147,33 @@ dependencies: "@babel/types" "^7.18.6" +"@babel/helper-split-export-declaration@^7.22.6": + version "7.22.6" + resolved "https://registry.yarnpkg.com/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.22.6.tgz#322c61b7310c0997fe4c323955667f18fcefb91c" + integrity sha512-AsUnxuLhRYsisFiaJwvp1QF+I3KjD5FOxut14q/GzovUe6orHLesW2C7d754kRm53h5gqrz6sFl6sxc4BVtE/g== + dependencies: + "@babel/types" "^7.22.5" + "@babel/helper-string-parser@^7.19.4": version "7.19.4" resolved "https://registry.yarnpkg.com/@babel/helper-string-parser/-/helper-string-parser-7.19.4.tgz#38d3acb654b4701a9b77fb0615a96f775c3a9e63" integrity sha512-nHtDoQcuqFmwYNYPz3Rah5ph2p8PFeFCsZk9A/48dPc/rGocJ5J3hAAZ7pb76VWX3fZKu+uEr/FhH5jLx7umrw== +"@babel/helper-string-parser@^7.22.5": + version "7.22.5" + resolved "https://registry.yarnpkg.com/@babel/helper-string-parser/-/helper-string-parser-7.22.5.tgz#533f36457a25814cf1df6488523ad547d784a99f" + integrity sha512-mM4COjgZox8U+JcXQwPijIZLElkgEpO5rsERVDJTc2qfCDfERyob6k5WegS14SX18IIjv+XD+GrqNumY5JRCDw== + "@babel/helper-validator-identifier@^7.18.6", "@babel/helper-validator-identifier@^7.19.1": version "7.19.1" resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.19.1.tgz#7eea834cf32901ffdc1a7ee555e2f9c27e249ca2" integrity sha512-awrNfaMtnHUr653GgGEs++LlAvW6w+DcPrOliSMXWCKo597CwL5Acf/wWdNkf/tfEQE3mjkeD1YOVZOUV/od1w== +"@babel/helper-validator-identifier@^7.22.20": + version "7.22.20" + resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.22.20.tgz#c4ae002c61d2879e724581d96665583dbc1dc0e0" + integrity sha512-Y4OZ+ytlatR8AI+8KZfKuL5urKp7qey08ha31L8b3BwewJAoJamTzyvxPR/5D+KkdJCGPq/+8TukHBlY10FX9A== + "@babel/helper-validator-option@^7.21.0": version "7.21.0" resolved "https://registry.yarnpkg.com/@babel/helper-validator-option/-/helper-validator-option-7.21.0.tgz#8224c7e13ace4bafdc4004da2cf064ef42673180" @@ -157,11 +197,25 @@ chalk "^2.0.0" js-tokens "^4.0.0" +"@babel/highlight@^7.22.13": + version "7.22.20" + resolved "https://registry.yarnpkg.com/@babel/highlight/-/highlight-7.22.20.tgz#4ca92b71d80554b01427815e06f2df965b9c1f54" + integrity sha512-dkdMCN3py0+ksCgYmGG8jKeGA/8Tk+gJwSYYlFGxG5lmhfKNoAy004YpLxpS1W2J8m/EK2Ew+yOs9pVRwO89mg== + dependencies: + "@babel/helper-validator-identifier" "^7.22.20" + chalk "^2.4.2" + js-tokens "^4.0.0" + "@babel/parser@^7.20.7", "@babel/parser@^7.21.4": version "7.21.4" resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.21.4.tgz#94003fdfc520bbe2875d4ae557b43ddb6d880f17" integrity sha512-alVJj7k7zIxqBZ7BTRhz0IqJFxW1VJbm6N8JbcYhQ186df9ZBPbZBmWSqAMXwHGsCJdYks7z/voa3ibiS5bCIw== +"@babel/parser@^7.22.15", "@babel/parser@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.23.0.tgz#da950e622420bf96ca0d0f2909cdddac3acd8719" + integrity sha512-vvPKKdMemU85V9WE/l5wZEmImpCtLqbnTvqDS2U1fJ96KrxoW7KrXhNsNCblQlg8Ck4b85yxdTyelsMUgFUXiw== + "@babel/plugin-transform-react-jsx-self@^7.18.6": version "7.21.0" resolved "https://registry.yarnpkg.com/@babel/plugin-transform-react-jsx-self/-/plugin-transform-react-jsx-self-7.21.0.tgz#ec98d4a9baafc5a1eb398da4cf94afbb40254a54" @@ -192,19 +246,28 @@ "@babel/parser" "^7.20.7" "@babel/types" "^7.20.7" -"@babel/traverse@^7.21.0", "@babel/traverse@^7.21.2", "@babel/traverse@^7.21.4": - version "7.21.4" - resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.21.4.tgz#a836aca7b116634e97a6ed99976236b3282c9d36" - integrity sha512-eyKrRHKdyZxqDm+fV1iqL9UAHMoIg0nDaGqfIOd8rKH17m5snv7Gn4qgjBoFfLz9APvjFU/ICT00NVCv1Epp8Q== - dependencies: - "@babel/code-frame" "^7.21.4" - "@babel/generator" "^7.21.4" - "@babel/helper-environment-visitor" "^7.18.9" - "@babel/helper-function-name" "^7.21.0" - "@babel/helper-hoist-variables" "^7.18.6" - "@babel/helper-split-export-declaration" "^7.18.6" - "@babel/parser" "^7.21.4" - "@babel/types" "^7.21.4" +"@babel/template@^7.22.15": + version "7.22.15" + resolved "https://registry.yarnpkg.com/@babel/template/-/template-7.22.15.tgz#09576efc3830f0430f4548ef971dde1350ef2f38" + integrity sha512-QPErUVm4uyJa60rkI73qneDacvdvzxshT3kksGqlGWYdOTIUOwJ7RDUL8sGqslY1uXWSL6xMFKEXDS3ox2uF0w== + dependencies: + "@babel/code-frame" "^7.22.13" + "@babel/parser" "^7.22.15" + "@babel/types" "^7.22.15" + +"@babel/traverse@^7.21.0", "@babel/traverse@^7.21.2", "@babel/traverse@^7.21.4", "@babel/traverse@^7.23.2": + version "7.23.2" + resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.23.2.tgz#329c7a06735e144a506bdb2cad0268b7f46f4ad8" + integrity sha512-azpe59SQ48qG6nu2CzcMLbxUudtN+dOM9kDbUqGq3HXUJRlo7i8fvPoxQUzYgLZ4cMVmuZgm8vvBpNeRhd6XSw== + dependencies: + "@babel/code-frame" "^7.22.13" + "@babel/generator" "^7.23.0" + "@babel/helper-environment-visitor" "^7.22.20" + "@babel/helper-function-name" "^7.23.0" + "@babel/helper-hoist-variables" "^7.22.5" + "@babel/helper-split-export-declaration" "^7.22.6" + "@babel/parser" "^7.23.0" + "@babel/types" "^7.23.0" debug "^4.1.0" globals "^11.1.0" @@ -217,6 +280,15 @@ "@babel/helper-validator-identifier" "^7.19.1" to-fast-properties "^2.0.0" +"@babel/types@^7.22.15", "@babel/types@^7.22.5", "@babel/types@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/types/-/types-7.23.0.tgz#8c1f020c9df0e737e4e247c0619f58c68458aaeb" + integrity sha512-0oIyUfKoI3mSqMvsxBdclDwxXKXAUA8v/apZbc+iSyARYou1o8ZGDxbUYyLFoW2arqS2jDGqJuZvv1d/io1axg== + dependencies: + "@babel/helper-string-parser" "^7.22.5" + "@babel/helper-validator-identifier" "^7.22.20" + to-fast-properties "^2.0.0" + "@emotion/babel-plugin@^11.10.6": version "11.10.6" resolved "https://registry.yarnpkg.com/@emotion/babel-plugin/-/babel-plugin-11.10.6.tgz#a68ee4b019d661d6f37dec4b8903255766925ead" @@ -1060,7 +1132,7 @@ caniuse-lite@^1.0.30001449, caniuse-lite@^1.0.30001464: resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001474.tgz#13b6fe301a831fe666cce8ca4ef89352334133d5" integrity sha512-iaIZ8gVrWfemh5DG3T9/YqarVZoYf0r188IjaGwx68j4Pf0SGY6CQkmJUIE+NZHkkecQGohzXmBGEwWDr9aM3Q== -chalk@^2.0.0: +chalk@^2.0.0, chalk@^2.4.2: version "2.4.2" resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.2.tgz#cd42541677a54333cf541a49108c1432b44c9424" integrity sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ== From 3e7aabecdf3ece5bad6579932586d204fab3379c Mon Sep 17 00:00:00 2001 From: vgulerianb <90599235+vgulerianb@users.noreply.github.com> Date: Mon, 18 Dec 2023 17:10:50 +0530 Subject: [PATCH 2/8] chore(Chainfury): Vite version update --- client/package.json | 2 +- client/yarn.lock | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/client/package.json b/client/package.json index b1c05da..d79be78 100644 --- a/client/package.json +++ b/client/package.json @@ -34,7 +34,7 @@ "postcss": "^8.4.31", "tailwindcss": "^3.3.1", "typescript": "^4.9.3", - "vite": "^4.2.3" + "vite": "^4.4.12" }, "resolutions": { "postcss": "^8.4.31", diff --git a/client/yarn.lock b/client/yarn.lock index 213309b..bd4e761 100644 --- a/client/yarn.lock +++ b/client/yarn.lock @@ -2131,10 +2131,10 @@ util-deprecate@^1.0.2: resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" integrity sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw== -vite@^4.2.3: - version "4.4.11" - resolved "https://registry.yarnpkg.com/vite/-/vite-4.4.11.tgz#babdb055b08c69cfc4c468072a2e6c9ca62102b0" - integrity sha512-ksNZJlkcU9b0lBwAGZGGaZHCMqHsc8OpgtoYhsQ4/I2v5cnpmmmqe5pM4nv/4Hn6G/2GhTdj0DhZh2e+Er1q5A== +vite@^4.4.12: + version "4.5.1" + resolved "https://registry.yarnpkg.com/vite/-/vite-4.5.1.tgz#3370986e1ed5dbabbf35a6c2e1fb1e18555b968a" + integrity sha512-AXXFaAJ8yebyqzoNB9fu2pHoo/nWX+xZlaRwoeYUxEqBO+Zj4msE5G+BhGBll9lYEKv9Hfks52PAF2X7qDYXQA== dependencies: esbuild "^0.18.10" postcss "^8.4.27" From 5aaaa15e335b9aa63ea377b3289efa6322c7ec8f Mon Sep 17 00:00:00 2001 From: yashbonde Date: Sun, 3 Mar 2024 11:50:53 +0530 Subject: [PATCH 3/8] add new types for APIs --- chainfury/__init__.py | 9 +++- chainfury/base.py | 1 + chainfury/cli.py | 110 ++++++++++++++++++------------------------ chainfury/client.py | 1 - chainfury/types.py | 58 ++++++++++++++++++++-- chainfury/utils.py | 2 +- pyproject.toml | 2 +- 7 files changed, 111 insertions(+), 72 deletions(-) diff --git a/chainfury/__init__.py b/chainfury/__init__.py index 2111d9b..7763832 100644 --- a/chainfury/__init__.py +++ b/chainfury/__init__.py @@ -14,7 +14,14 @@ logger, CFEnv, ) -from chainfury.base import Var, Node, Secret, Chain, Model, Edge +from chainfury.base import ( + Var, + Node, + Secret, + Chain, + Model, + Edge, +) from chainfury.core import ( model_registry, programatic_actions_registry, diff --git a/chainfury/base.py b/chainfury/base.py index 27e9253..1074696 100644 --- a/chainfury/base.py +++ b/chainfury/base.py @@ -1,5 +1,6 @@ # Copyright © 2023- Frello Technology Private Limited +import os import copy import json import jinja2 diff --git a/chainfury/cli.py b/chainfury/cli.py index ae1067c..8c44d20 100644 --- a/chainfury/cli.py +++ b/chainfury/cli.py @@ -11,42 +11,6 @@ from chainfury.core import model_registry, programatic_actions_registry, memory_registry -def help(): - print( - f""" - ___ _ _ ___ - / __| |_ __ _(_)_ _ | __| _ _ _ _ _ -| (__| ' \/ _` | | ' \ | _| || | '_| || | - \___|_||_\__,_|_|_||_||_| \_,_|_| \_, | - |__/ -e0 a4 b8 e0 a4 a4 e0 a5 8d e0 a4 af e0 a4 -ae e0 a5 87 e0 a4 b5 20 e0 a4 9c e0 a4 af - e0 a4 a4 e0 a5 87 - -🦋 Welcome to ChainFury Engine! - -cf_version: {__version__} - -The chaining engine behind chat.tune.app - -A powerful way to program for the "Software 2.0" era. Read more: - -- https://blog.nimblebox.ai/new-flow-engine-from-scratch -- https://blog.nimblebox.ai/fury-actions -- https://gist.github.com/yashbonde/002c527853e04869bfaa04646f3e0974 -- https://tunehq.ai -- https://chat.tune.app -- https://studio.tune.app - -🌟 us on https://github.com/NimbleBoxAI/ChainFury - -Build with ♥️ by Tune AI - -🌊 Chennai, India -""" - ) - - def run( chain: str, inp: str, @@ -115,32 +79,50 @@ def run( f.close() -def main(): - Fire( - { - "comp": { - "all": lambda: print(all_items), - "model": { - "list": list(model_registry.get_models()), - "all": model_registry.get_models(), - "get": model_registry.get, - }, - "prog": { - "list": list(programatic_actions_registry.get_nodes()), - "all": programatic_actions_registry.get_nodes(), - }, - "memory": { - "list": list(memory_registry.get_nodes()), - "all": memory_registry.get_nodes(), - }, - }, - "help": help, - "run": run, - "version": lambda: print( - f"""ChainFury 🦋 Engine - -chainfury=={__version__} +class __CLI: + info = rf""" + ___ _ _ ___ + / __| |_ __ _(_)_ _ | __| _ _ _ _ _ +| (__| ' \/ _` | | ' \ | _| || | '_| || | + \___|_||_\__,_|_|_||_||_| \_,_|_| \_, | + |__/ +e0 a4 b8 e0 a4 a4 e0 a5 8d e0 a4 af e0 a4 +ae e0 a5 87 e0 a4 b5 20 e0 a4 9c e0 a4 af + e0 a4 a4 e0 a5 87 + + +cf_version: {__version__} + +🦋 The FOSS chaining engine behind chat.tune.app + +A powerful way to program for the "Software 2.0" era. Read more: + +- https://tunehq.ai +- https://chat.tune.app +- https://studio.tune.app +🌟 us on https://github.com/NimbleBoxAI/ChainFury + +Build with ♥️ by Tune AI from the Koro coast 🌊 Chennai, India """ - ), - } - ) + + comp = { + "all": lambda: print(all_items), + "model": { + "list": list(model_registry.get_models()), + "all": model_registry.get_models(), + "get": model_registry.get, + }, + "prog": { + "list": list(programatic_actions_registry.get_nodes()), + "all": programatic_actions_registry.get_nodes(), + }, + "memory": { + "list": list(memory_registry.get_nodes()), + "all": memory_registry.get_nodes(), + }, + } + run = run + + +def main(): + Fire(__CLI) diff --git a/chainfury/client.py b/chainfury/client.py index 5163a97..001e92e 100644 --- a/chainfury/client.py +++ b/chainfury/client.py @@ -1,6 +1,5 @@ # Copyright © 2023- Frello Technology Private Limited -import os import requests from functools import lru_cache from typing import Dict, Any, Tuple diff --git a/chainfury/types.py b/chainfury/types.py index a93e881..1829af4 100644 --- a/chainfury/types.py +++ b/chainfury/types.py @@ -2,7 +2,6 @@ from datetime import datetime from typing import Dict, Any, List, Optional - from pydantic import BaseModel, Field, ConfigDict # First is the set of types that are used in the chainfury itself @@ -72,6 +71,11 @@ class CFPromptResult(BaseModel): task_id: str = "" +class ApiLoginResponse(BaseModel): + message: str + token: Optional[str] = None + + class ApiResponse(BaseModel): """This is the default response body of the API""" @@ -149,18 +153,18 @@ class ApiActionUpdateRequest(BaseModel): update_fields: List[str] = Field(description="The fields to update.") -class ApiAuth(BaseModel): +class ApiAuthRequest(BaseModel): username: str password: str -class ApiSignUp(BaseModel): +class ApiSignUpRequest(BaseModel): username: str email: str password: str -class ApiChangePassword(BaseModel): +class ApiChangePasswordRequest(BaseModel): username: str old_password: str new_password: str @@ -168,3 +172,49 @@ class ApiChangePassword(BaseModel): class ApiPromptFeedback(BaseModel): score: int + + +class ApiPromptFeedbackResponse(BaseModel): + rating: int + + +class ApiSaveTokenRequest(BaseModel): + key: str + token: str + meta: Optional[Dict[str, Any]] = {} + + +class ApiListTokensResponse(BaseModel): + tokens: List[ApiSaveTokenRequest] + + +class ApiChainLog(BaseModel): + id: str + created_at: datetime + prompt_id: int + node_id: str + worker_id: str + message: Optional[str] = None + data: Optional[Dict[str, Any]] = None + + +class ApiListChainLogsResponse(BaseModel): + logs: List[ApiChainLog] + + +class ApiPrompt(BaseModel): + id: int + chatbot_id: str + input_prompt: str + created_at: datetime + session_id: str + meta: Optional[Dict[str, Any]] = None + response: Optional[str] = None + gpt_rating: Optional[str] = None + user_rating: Optional[int] = None + time_taken: Optional[float] = None + num_tokens: Optional[int] = None + + +class ApiListPromptsResponse(BaseModel): + prompts: List[ApiPrompt] diff --git a/chainfury/utils.py b/chainfury/utils.py index b21ab59..4276569 100644 --- a/chainfury/utils.py +++ b/chainfury/utils.py @@ -311,7 +311,7 @@ def threaded_map( results[i] = res except Exception as e: if safe: - results[i] = e + results[i] = e # type: ignore else: raise e return results diff --git a/pyproject.toml b/pyproject.toml index abbfbda..1e5f0c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "chainfury" version = "1.7.0a1" description = "ChainFury is a powerful tool that simplifies the creation and management of chains of prompts, making it easier to build complex chat applications using LLMs." -authors = ["NimbleBox Engineering "] +authors = ["Tune AI "] license = "Apache 2.0" readme = "README.md" repository = "https://github.com/NimbleBoxAI/ChainFury" From 16b08b8cc78fded7b4ca0f3266fffe7aab315e30 Mon Sep 17 00:00:00 2001 From: yashbonde Date: Sun, 3 Mar 2024 11:51:12 +0530 Subject: [PATCH 4/8] new tokens database + chainlog APIs --- client/src/redux/services/auth.ts | 4 +- server/chainfury_server/api/chains.py | 13 ++--- server/chainfury_server/api/prompts.py | 43 ++++++++++---- server/chainfury_server/api/user.py | 77 ++++++++++++++++++++++---- server/chainfury_server/app.py | 36 +++++++----- server/chainfury_server/database.py | 41 +++++++++++++- server/chainfury_server/engine.py | 71 +++++++++++++++++------- server/chainfury_server/version.py | 2 +- server/pyproject.toml | 6 +- 9 files changed, 224 insertions(+), 69 deletions(-) diff --git a/client/src/redux/services/auth.ts b/client/src/redux/services/auth.ts index c0e04fd..c846981 100644 --- a/client/src/redux/services/auth.ts +++ b/client/src/redux/services/auth.ts @@ -49,7 +49,7 @@ export const authApi = createApi({ } >({ query: ({ score, prompt_id }) => ({ - url: `${BASE_URL}/api/v1/prompts/${prompt_id}/feedback`, + url: `${BASE_URL}/api/v1/prompts/${prompt_id}/feedback/`, method: 'PUT', body: { score @@ -136,7 +136,7 @@ export const authApi = createApi({ } >({ query: ({ score, prompt_id, chatbot_id }) => ({ - url: `${BASE_URL}/api/prompts/${prompt_id}/feedback`, + url: `${BASE_URL}/api/prompts/${prompt_id}/feedback/`, method: 'PUT', body: { score diff --git a/server/chainfury_server/api/chains.py b/server/chainfury_server/api/chains.py index 1eef5c9..97bf143 100644 --- a/server/chainfury_server/api/chains.py +++ b/server/chainfury_server/api/chains.py @@ -38,7 +38,7 @@ def create_chain( ) # DB call - dag = chatbot_data.dag.dict() if chatbot_data.dag else {} + dag = chatbot_data.dag.model_dump() if chatbot_data.dag else {} chatbot = DB.ChatBot( name=chatbot_data.name, created_by=user.id, @@ -51,8 +51,7 @@ def create_chain( db.refresh(chatbot) # return - response = T.ApiChain(**chatbot.to_dict()) - return response + return chatbot.to_ApiChain() def get_chain( @@ -74,13 +73,13 @@ def get_chain( ] if tag_id: filters.append(DB.ChatBot.tag_id == tag_id) - chatbot = db.query(DB.ChatBot).filter(*filters).first() # type: ignore + chatbot: DB.ChatBot = db.query(DB.ChatBot).filter(*filters).first() # type: ignore if not chatbot: resp.status_code = 404 return T.ApiResponse(message="ChatBot not found") # return - return T.ApiChain(**chatbot.to_dict()) + return chatbot.to_ApiChain() def update_chain( @@ -130,7 +129,7 @@ def update_chain( db.refresh(chatbot) # return - return T.ApiChain(**chatbot.to_dict()) + return chatbot.to_ApiChain() def delete_chain( @@ -186,7 +185,7 @@ def list_chains( # return return T.ApiListChainsResponse( - chatbots=[T.ApiChain(**chatbot.to_dict()) for chatbot in chatbots], + chatbots=[chatbot.to_ApiChain() for chatbot in chatbots], ) diff --git a/server/chainfury_server/api/prompts.py b/server/chainfury_server/api/prompts.py index 8a7cb0b..1d3627f 100644 --- a/server/chainfury_server/api/prompts.py +++ b/server/chainfury_server/api/prompts.py @@ -3,7 +3,7 @@ from fastapi import Depends, Header, HTTPException from fastapi.requests import Request from fastapi.responses import Response -from typing import Annotated +from typing import Annotated, List from sqlalchemy.orm import Session import chainfury_server.database as DB @@ -17,7 +17,7 @@ def list_prompts( limit: int = 100, offset: int = 0, db: Session = Depends(DB.fastapi_db_session), -): +) -> T.ApiListPromptsResponse: # validate user user = DB.get_user_from_jwt(token=token, db=db) @@ -25,7 +25,7 @@ def list_prompts( if limit < 1 or limit > 100: limit = 100 offset = offset if offset > 0 else 0 - prompts = ( + prompts: List[DB.Prompt] = ( db.query(DB.Prompt) # type: ignore .filter(DB.Prompt.chatbot_id == chain_id) .order_by(DB.Prompt.created_at.desc()) # type: ignore @@ -33,14 +33,14 @@ def list_prompts( .offset(offset) .all() ) - return {"prompts": [p.to_dict() for p in prompts]} + return T.ApiListPromptsResponse(prompts=[p.to_ApiPrompt() for p in prompts]) def get_prompt( prompt_id: int, token: Annotated[str, Header()], db: Session = Depends(DB.fastapi_db_session), -): +) -> T.ApiPrompt: # validate user user = DB.get_user_from_jwt(token=token, db=db) @@ -49,14 +49,15 @@ def get_prompt( if not prompt: raise HTTPException(status_code=404, detail="Prompt not found") - return {"prompt": prompt.to_dict()} + # return {"prompt": prompt.to_dict()} # before + return prompt.to_ApiPrompt() def delete_prompt( prompt_id: int, token: Annotated[str, Header()], db: Session = Depends(DB.fastapi_db_session), -): +) -> T.ApiResponse: # validate user user = DB.get_user_from_jwt(token=token, db=db) @@ -67,7 +68,7 @@ def delete_prompt( db.delete(prompt) db.commit() - return {"msg": f"Prompt: '{prompt_id}' deleted"} + return T.ApiResponse(message=f"Prompt '{prompt.id}' deleted") def prompt_feedback( @@ -75,7 +76,7 @@ def prompt_feedback( inputs: T.ApiPromptFeedback, prompt_id: int, db: Session = Depends(DB.fastapi_db_session), -): +) -> T.ApiPromptFeedbackResponse: # validate user user = DB.get_user_from_jwt(token=token, db=db) @@ -94,4 +95,26 @@ def prompt_feedback( status_code=404, detail=f"Unable to find the prompt", ) - return {"rating": prompt.user_rating} + return T.ApiPromptFeedbackResponse(rating=prompt.user_rating) # type: ignore + + +def get_chain_logs( + token: Annotated[str, Header()], + prompt_id: int, + limit: int = 100, + offset: int = 0, + db: Session = Depends(DB.fastapi_db_session), +) -> T.ApiListChainLogsResponse: + # validate user + user = DB.get_user_from_jwt(token=token, db=db) + + # query the DB + chainlogs: List[DB.ChainLog] = ( + db.query(DB.ChainLog) # type: ignore + .filter(DB.ChainLog.prompt_id == prompt_id) + .order_by(DB.ChainLog.created_at.desc()) # type: ignore + .limit(limit) + .offset(offset) + .all() + ) + return T.ApiListChainLogsResponse(logs=[c.to_ApiChainLog() for c in chainlogs]) diff --git a/server/chainfury_server/api/user.py b/server/chainfury_server/api/user.py index 726a529..c34a908 100644 --- a/server/chainfury_server/api/user.py +++ b/server/chainfury_server/api/user.py @@ -12,20 +12,30 @@ import chainfury.types as T -def login(auth: T.ApiAuth, db: Session = Depends(DB.fastapi_db_session)): +def login( + req: Request, + resp: Response, + auth: T.ApiAuthRequest, + db: Session = Depends(DB.fastapi_db_session), +) -> T.ApiLoginResponse: user: DB.User = db.query(DB.User).filter(DB.User.username == auth.username).first() # type: ignore if user is not None and sha256_crypt.verify(auth.password, user.password): # type: ignore token = jwt.encode( payload=DB.JWTPayload(username=auth.username, user_id=user.id).to_dict(), key=Env.JWT_SECRET(), ) - response = {"msg": "success", "token": token} + return T.ApiLoginResponse(message="success", token=token) else: - response = {"msg": "failed"} - return response + resp.status_code = 401 + return T.ApiLoginResponse(message="failed") -def sign_up(auth: T.ApiSignUp, db: Session = Depends(DB.fastapi_db_session)): +def sign_up( + req: Request, + resp: Response, + auth: T.ApiSignUpRequest, + db: Session = Depends(DB.fastapi_db_session), +) -> T.ApiLoginResponse: user_exists = False email_exists = False user: DB.User = db.query(DB.User).filter(DB.User.username == auth.username).first() # type: ignore @@ -36,7 +46,8 @@ def sign_up(auth: T.ApiSignUp, db: Session = Depends(DB.fastapi_db_session)): email_exists = True if user_exists and email_exists: raise HTTPException( - status_code=400, detail="Username and email already registered" + status_code=400, + detail="Username and email already registered", ) elif user_exists: raise HTTPException(status_code=400, detail="Username is taken") @@ -54,17 +65,17 @@ def sign_up(auth: T.ApiSignUp, db: Session = Depends(DB.fastapi_db_session)): payload=DB.JWTPayload(username=auth.username, user_id=user.id).to_dict(), key=Env.JWT_SECRET(), ) - response = {"msg": "success", "token": token} + return T.ApiLoginResponse(message="success", token=token) else: - response = {"msg": "failed"} - return response + resp.status_code = 400 + return T.ApiLoginResponse(message="failed") def change_password( req: Request, resp: Response, token: Annotated[str, Header()], - inputs: T.ApiChangePassword, + inputs: T.ApiChangePasswordRequest, db: Session = Depends(DB.fastapi_db_session), ) -> T.ApiResponse: # validate user @@ -78,3 +89,49 @@ def change_password( else: resp.status_code = 400 return T.ApiResponse(message="password incorrect") + + +# TODO: @tunekoro - Implement the following functions + + +def create_token( + req: Request, + resp: Response, + token: Annotated[str, Header()], + inputs: T.ApiSaveTokenRequest, + db: Session = Depends(DB.fastapi_db_session), +) -> T.ApiResponse: + resp.status_code = 501 # + return T.ApiResponse(message="not implemented") + + +def get_token( + req: Request, + resp: Response, + key: str, + token: Annotated[str, Header()], + db: Session = Depends(DB.fastapi_db_session), +) -> T.ApiResponse: + resp.status_code = 501 # + return T.ApiResponse(message="not implemented") + + +def list_tokens( + req: Request, + resp: Response, + token: Annotated[str, Header()], + db: Session = Depends(DB.fastapi_db_session), +) -> T.ApiResponse: + resp.status_code = 501 # + return T.ApiResponse(message="not implemented") + + +def delete_token( + req: Request, + resp: Response, + key: str, + token: Annotated[str, Header()], + db: Session = Depends(DB.fastapi_db_session), +) -> T.ApiResponse: + resp.status_code = 501 # + return T.ApiResponse(message="not implemented") diff --git a/server/chainfury_server/app.py b/server/chainfury_server/app.py index 883cde4..d5c8445 100644 --- a/server/chainfury_server/app.py +++ b/server/chainfury_server/app.py @@ -20,7 +20,8 @@ description=""" chainfury server is a way to deploy and run chainfury engine over APIs. `chainfury` is [Tune AI](tunehq.ai)'s FOSS project released under [Apache-2 License](https://choosealicense.com/licenses/apache-2.0/) so you can use this for your commercial -projects. A version `chainfury` is used in production in [Tune.Chat](chat.tune.app) and serves thousands of users daily. +projects. A version `chainfury` is used in production in [Tune.Chat](chat.tune.app), serves and solves thousands of user +queries daily. """.strip(), version=__version__, docs_url="" if Env.CFS_DISABLE_DOCS() else "/docs", @@ -42,24 +43,29 @@ app.add_api_route("/api/v1/chatbot/{id}/prompt", api_chains.run_chain, methods=["POST"], tags=["deprecated"], response_model=None) # type: ignore # user -app.add_api_route("/user/login/", api_user.login, methods=["POST"], tags=["user"]) # type: ignore -app.add_api_route("/user/signup/", api_user.sign_up, methods=["POST"], tags=["user"]) # type: ignore -app.add_api_route("/user/change_password/", api_user.change_password, methods=["POST"], tags=["user"]) # type: ignore +app.add_api_route(methods=["POST"], path="/user/login/", endpoint=api_user.login, tags=["user"]) # type: ignore +app.add_api_route(methods=["POST"], path="/user/signup/", endpoint=api_user.sign_up, tags=["user"]) # type: ignore +app.add_api_route(methods=["POST"], path="/user/change_password/", endpoint=api_user.change_password, tags=["user"]) # type: ignore +app.add_api_route(methods=["PUT"], path="/user/token/", endpoint=api_user.create_token, tags=["user"]) # type: ignore +app.add_api_route(methods=["GET"], path="/user/token/", endpoint=api_user.get_token, tags=["user"]) # type: ignore +app.add_api_route(methods=["DELETE"], path="/user/token/", endpoint=api_user.delete_token, tags=["user"]) # type: ignore +app.add_api_route(methods=["GET"], path="/user/tokens/list/", endpoint=api_user.list_tokens, tags=["user"]) # type: ignore # chains -app.add_api_route("/api/chains/", api_chains.list_chains, methods=["GET"], tags=["chains"]) # type: ignore -app.add_api_route("/api/chains/", api_chains.create_chain, methods=["PUT"], tags=["chains"]) # type: ignore -app.add_api_route("/api/chains/{id}/", api_chains.get_chain, methods=["GET"], tags=["chains"]) # type: ignore -app.add_api_route("/api/chains/{id}/", api_chains.delete_chain, methods=["DELETE"], tags=["chains"]) # type: ignore -app.add_api_route("/api/chains/{id}/", api_chains.update_chain, methods=["PATCH"], tags=["chains"]) # type: ignore -app.add_api_route("/api/chains/{id}/", api_chains.run_chain, methods=["POST"], tags=["chains"], response_model=None) # type: ignore -app.add_api_route("/api/chains/{id}/metrics/", api_chains.get_chain_metrics, methods=["GET"], tags=["chains"]) # type: ignore +app.add_api_route(methods=["GET"], path="/api/chains/", endpoint=api_chains.list_chains, tags=["chains"]) # type: ignore +app.add_api_route(methods=["PUT"], path="/api/chains/", endpoint=api_chains.create_chain, tags=["chains"]) # type: ignore +app.add_api_route(methods=["GET"], path="/api/chains/{id}/", endpoint=api_chains.get_chain, tags=["chains"]) # type: ignore +app.add_api_route(methods=["DELETE"], path="/api/chains/{id}/", endpoint=api_chains.delete_chain, tags=["chains"]) # type: ignore +app.add_api_route(methods=["PATCH"], path="/api/chains/{id}/", endpoint=api_chains.update_chain, tags=["chains"]) # type: ignore +app.add_api_route(methods=["POST"], path="/api/chains/{id}/", endpoint=api_chains.run_chain, tags=["chains"], response_model=None) # type: ignore +app.add_api_route(methods=["GET"], path="/api/chains/{id}/metrics/", endpoint=api_chains.get_chain_metrics, tags=["chains"]) # type: ignore # prompts -app.add_api_route("/api/prompts/", api_prompts.list_prompts, methods=["GET"], tags=["prompts"]) # type: ignore -app.add_api_route("/api/prompts/{prompt_id}/", api_prompts.get_prompt, methods=["GET"], tags=["prompts"]) # type: ignore -app.add_api_route("/api/prompts/{prompt_id}/", api_prompts.delete_prompt, methods=["DELETE"], tags=["prompts"]) # type: ignore -app.add_api_route("/api/prompts/{prompt_id}/feedback", api_prompts.prompt_feedback, methods=["PUT"], tags=["prompts"]) # type: ignore +app.add_api_route(methods=["GET"], path="/api/prompts/", endpoint=api_prompts.list_prompts, tags=["prompts"]) # type: ignore +app.add_api_route(methods=["GET"], path="/api/prompts/{prompt_id}/", endpoint=api_prompts.get_prompt, tags=["prompts"]) # type: ignore +app.add_api_route(methods=["DELETE"], path="/api/prompts/{prompt_id}/", endpoint=api_prompts.delete_prompt, tags=["prompts"]) # type: ignore +app.add_api_route(methods=["PUT"], path="/api/prompts/{prompt_id}/feedback/", endpoint=api_prompts.prompt_feedback, tags=["prompts"]) # type: ignore +app.add_api_route(methods=["GET"], path="/api/prompts/{prompt_id}/logs/", endpoint=api_prompts.get_chain_logs, tags=["prompts"]) # type: ignore # UI files diff --git a/server/chainfury_server/database.py b/server/chainfury_server/database.py index 4d36c15..72a85ce 100644 --- a/server/chainfury_server/database.py +++ b/server/chainfury_server/database.py @@ -13,7 +13,7 @@ from sqlalchemy.pool import QueuePool from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session, scoped_session, sessionmaker +from sqlalchemy.orm import Session, scoped_session, sessionmaker, relationship from sqlalchemy import ( Column, ForeignKey, @@ -28,6 +28,7 @@ ) from chainfury_server.utils import logger, Env +import chainfury.types as T ######## # @@ -170,11 +171,29 @@ class User(Base): username: str = Column(String(80), unique=True, nullable=False) password: str = Column(String(80), nullable=False) meta: Dict[str, Any] = Column(JSON) + tokens = relationship("Tokens", back_populates="user") def __repr__(self): return f"User(id={self.id}, username={self.username}, meta={self.meta})" +class Tokens(Base): + __tablename__ = "tokens" + + MAXLEN_KEY = 80 + MAXLEN_VAL = 1024 + + id = Column(Integer, primary_key=True) + user_id = Column(String(ID_LENGTH), ForeignKey("user.id"), nullable=False) + key = Column(String(MAXLEN_KEY), nullable=False) + value = Column(String(MAXLEN_VAL), nullable=False) + meta = Column(JSON, nullable=True) + user = relationship("User", back_populates="tokens") + + def __repr__(self): + return f"Tokens(id={self.id}, user_id={self.user_id}, key={self.key}, value={self.value[:5]}..., meta={self.meta})" + + class ChatBot(Base): __tablename__ = "chatbot" @@ -208,6 +227,9 @@ def to_dict(self): "deleted_at": self.deleted_at, } + def to_ApiChain(self) -> T.ApiChain: + return T.ApiChain(**self.to_dict()) + def __repr__(self): return f"ChatBot(id={self.id}, name={self.name}, created_by={self.created_by}, dag={self.dag}, meta={self.meta})" @@ -255,6 +277,9 @@ def to_dict(self): "meta": self.meta, } + def to_ApiPrompt(self): + return T.ApiPrompt(**self.to_dict()) + class ChainLog(Base): __tablename__ = "chain_logs" @@ -271,6 +296,20 @@ class ChainLog(Base): message: str = Column(Text, nullable=False) data: Dict[str, Any] = Column(JSON, nullable=True) + def to_dict(self): + return { + "id": self.id, + "created_at": self.created_at, + "prompt_id": self.prompt_id, + "node_id": self.node_id, + "worker_id": self.worker_id, + "message": self.message, + "data": self.data, + } + + def to_ApiChainLog(self): + return T.ApiChainLog(**self.to_dict()) + class Template(Base): __tablename__ = "template" diff --git a/server/chainfury_server/engine.py b/server/chainfury_server/engine.py index 4c7be63..beb86ec 100644 --- a/server/chainfury_server/engine.py +++ b/server/chainfury_server/engine.py @@ -47,6 +47,35 @@ def run( thoughts_callback=callback, print_thoughts=False, ) + + # store the full_ir in the DB.ChainLog + if store_ir: + # group the logs by node_id + chain_logs_by_node = {} + for k, v in full_ir.items(): + node_id, varname = k.split("/") + chain_logs_by_node.setdefault(node_id, {"outputs": []}) + chain_logs_by_node[node_id]["outputs"].append( + { + "name": varname, + "data": v, + } + ) + + # iterate over node ids and create the logs + for k, v in chain_logs_by_node.items(): + db_chainlog = DB.ChainLog( + prompt_id=prompt_row.id, + created_at=SimplerTimes.get_now_datetime(), + node_id=k, + worker_id="cf_server", + message="step", + data=v, + ) # type: ignore + db.add(db_chainlog) + db.commit() + + # create the result result = T.CFPromptResult( result=( json.dumps(mainline_out) @@ -111,6 +140,28 @@ def stream( mainline_out = ir yield ir, False + if store_ir: + # in case of stream, every item is a fundamentally a step + data = { + "outputs": [ + { + "name": k.split("/")[-1], + "data": v, + } + for k, v in ir.items() + ] + } + k = next(iter(ir))[0].split("/")[0] + db_chainlog = DB.ChainLog( + prompt_id=prompt_row.id, + created_at=SimplerTimes.get_now_datetime(), + node_id=k, + worker_id="cf_server", + message="step", + data=data, + ) # type: ignore + db.add(db_chainlog) + result = T.CFPromptResult( result=str(mainline_out), prompt_id=prompt_row.id, # type: ignore @@ -194,26 +245,6 @@ def __call__(self, thought): self.count += 1 -# def create_intermediate_steps( -# db: Session, -# prompt_id: int, -# intermediate_prompt: str = "", -# intermediate_response: str = "", -# response_json: Dict = {}, -# ) -> DB.IntermediateStep: -# db_prompt = DB.IntermediateStep( -# prompt_id=prompt_id, -# intermediate_prompt=intermediate_prompt, -# intermediate_response=intermediate_response, -# response_json=response_json, -# created_at=SimplerTimes.get_now_datetime(), -# ) # type: ignore -# db.add(db_prompt) -# db.commit() -# db.refresh(db_prompt) -# return db_prompt - - def create_prompt( db: Session, chatbot_id: str, input_prompt: str, session_id: str ) -> DB.Prompt: diff --git a/server/chainfury_server/version.py b/server/chainfury_server/version.py index 635ae71..73d6f7c 100644 --- a/server/chainfury_server/version.py +++ b/server/chainfury_server/version.py @@ -1,6 +1,6 @@ # Copyright © 2023- Frello Technology Private Limited -__version__ = "2.1.0" +__version__ = "2.1.2a" _major, _minor, _patch = __version__.split(".") _major = int(_major) _minor = int(_minor) diff --git a/server/pyproject.toml b/server/pyproject.toml index 6de019e..58e26e6 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -2,8 +2,8 @@ [tool.poetry] name = "chainfury_server" -version = "2.1.1" -description = "ChainFury Server is the server for running ChainFury Engine!" +version = "2.1.2a" +description = "ChainFury Server is the DB + API server for managing the ChainFury engine in production. Used in production at chat.tune.app" authors = ["Tune AI "] license = "Apache 2.0" readme = "README.md" @@ -25,7 +25,7 @@ urllib3 = ">=1.26.18" "cryptography" = ">=41.0.6" [tool.poetry.scripts] -chainfury_server = "chainfury_server" +chainfury_server = "chainfury_server:__main__" [build-system] requires = ["poetry-core"] From 0f430388cdc69de91e3e04380a26ee8a7aebb796 Mon Sep 17 00:00:00 2001 From: yashbonde Date: Mon, 4 Mar 2024 10:25:10 +0530 Subject: [PATCH 5/8] add more tests + core is a single file again --- chainfury/components/functional/__init__.py | 2 +- chainfury/{core/actions.py => core.py} | 318 +++++++++++++++++++- chainfury/core/__init__.py | 9 - chainfury/core/memory.py | 240 --------------- chainfury/core/models.py | 105 ------- tests/__main__.py | 2 +- tests/base.py | 62 ++++ tests/chains.py | 33 -- 8 files changed, 381 insertions(+), 390 deletions(-) rename chainfury/{core/actions.py => core.py} (60%) delete mode 100644 chainfury/core/__init__.py delete mode 100644 chainfury/core/memory.py delete mode 100644 chainfury/core/models.py create mode 100644 tests/base.py delete mode 100644 tests/chains.py diff --git a/chainfury/components/functional/__init__.py b/chainfury/components/functional/__init__.py index c64c9b0..789086d 100644 --- a/chainfury/components/functional/__init__.py +++ b/chainfury/components/functional/__init__.py @@ -237,7 +237,7 @@ def echo(message: str) -> Tuple[Dict[str, Dict[str, str]], Optional[Exception]]: programatic_actions_registry.register( fn=echo, - outputs={"message": (0,)}, # type: ignore + outputs={"message": ()}, # type: ignore node_id="chainfury-echo", description="I stared into the abyss and it stared back at me. Echoes the message, used for debugging", ) diff --git a/chainfury/core/actions.py b/chainfury/core.py similarity index 60% rename from chainfury/core/actions.py rename to chainfury/core.py index d1d06a2..11370f0 100644 --- a/chainfury/core/actions.py +++ b/chainfury/core.py @@ -8,6 +8,7 @@ """ import copy +import random from uuid import uuid4 from typing import Any, List, Optional, Dict, Tuple @@ -24,7 +25,7 @@ put_value_by_keys, ) from chainfury.utils import logger -from chainfury.core.models import model_registry + # Programtic Actions Registry # --------------------------- @@ -497,6 +498,310 @@ def get_count_for_nodes(self, node_id: str) -> int: return self.counter.get(node_id, 0) +# Memory Registry +# --------------------------- +# All the components that have to do with storage and retreival of data from the DB. This sections is supppsed to act +# like the memory in an Von Neumann architecture. + + +class Memory: + """Class to wrap the DB functions as a callable. + + Args: + node_id (str): The id of the node + fn (object): The function that is used for this action + vector_key (str): The key for the vector in the DB + read_mode (bool, optional): If the function is a read function, if `False` then this is a write function. + """ + + fields_model = [ + Var( + name="items", + type=[Var(type="string"), Var(type="array", items=[Var(type="string")])], + required=True, + ), + Var(name="embedding_model", type="string", required=True), + Var( + name="embedding_model_params", + type="object", + additionalProperties=Var(type="string"), + ), + Var(name="embedding_model_key", type="string"), + Var( + name="translation_layer", + type="object", + additionalProperties=Var(type="string"), + ), + ] + """These are the fields that are used to map the input items to the embedding model, do not use directly""" + + def __init__( + self, node_id: str, fn: object, vector_key: str, read_mode: bool = False + ): + self.node_id = node_id + self.fn = fn + self.vector_key = vector_key + self.read_mode = read_mode + self.fields_fn = func_to_vars(fn) + self.fields = self.fields_fn + self.fields_model + + def to_dict(self) -> Dict[str, Any]: + """Serialize the Memory object to a dict.""" + return { + "node_id": self.node_id.split("-")[0], + "vector_key": self.vector_key, + "read_mode": self.read_mode, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]): + """Deserialize the Memory object from a dict.""" + read_mode = data["read_mode"] + if read_mode: + fn = memory_registry.get_read(data["node_id"]) + else: + fn = memory_registry.get_write(data["node_id"]) + + # here we do return Memory type but instead of creating one we use a previously existing Node and return + # the fn for the Node which is ultimately this precise Memory object + return fn.fn # type: ignore + + def __call__(self, **data: Dict[str, Any]) -> Any: + # the first thing we have to do is get the data for the model. This is actually a very hard problem because this + # function needs to call some other arbitrary function where we know the inputs to this function "items" but we + # do not know which variable to pass this to in the undelying model's function. Thus we need to take in a huge + # amount of things as more inputs ("embedding_model_key", "embedding_model_params"). Then we don't even know + # what the inputs to the underlying DB functionbare going to be, in which case we also need to add things like + # the translation that needs to be done ("translation_layer"). This makes the number of inputs a lot but + # ultimately is required to do the job for robust-ness. Which is why we provide a default for openai-embedding + # model. For any other model user will need to pass all the information. + model_fields: Dict[str, Any] = {} + for f in self.fields_model: + if f.required and f.name not in data: + raise Exception( + f"Field '{f.name}' is required in {self.node_id} but not present" + ) + if f.name in data: + model_fields[f.name] = data.pop(f.name) + + model_data = {**model_fields.get("embedding_model_params", {})} + model_id = model_fields.pop("embedding_model") + + # TODO: @yashbonde - clean this mess up + # DEFAULT_MEMORY_CONSTANTS = { + # "openai-embedding": { + # "embedding_model_key": "input_strings", + # "embedding_model_params": { + # "model": "text-embedding-ada-002", + # }, + # "translation_layer": { + # "embeddings": ["data", "*", "embedding"], + # }, + # } + # } + # embedding_model_default_config = DEFAULT_MEMORY_CONSTANTS.get(model_id, {}) + # if embedding_model_default_config: + # model_data = { + # **embedding_model_default_config.get("embedding_model_params", {}), + # **model_data, + # } + # model_key = embedding_model_default_config.get( + # "embedding_model_key", "items" + # ) or model_data.get("embedding_model_key") + # model_fields["translation_layer"] = model_fields.get( + # "translation_layer" + # ) or embedding_model_default_config.get("translation_layer") + # else: + + req_keys = [x.name for x in self.fields_model[2:]] + if not all([x in model_fields for x in req_keys]): + raise Exception(f"Model {model_id} requires {req_keys} to be passed") + model_key = model_fields.get("embedding_model_key") + model_data = { + **model_fields.get("embedding_model_params", {}), + **model_data, + } + model_data[model_key] = model_fields.pop("items") # type: ignore + model = model_registry.get(model_id) + embeddings, err = model(model_data=model_data) + if err: + logger.error(f"error: {err}") + logger.error(f"traceback: {embeddings}") + raise err + + # now that we have all the embeddings ready we now need to translate it to be fed into the DB function + translated_data = {} + for k, v in model_fields.get("translation_layer", {}).items(): + translated_data[k] = get_value_by_keys(embeddings, v) + + # create the dictionary to call the underlying function + db_data = {} + for f in self.fields_fn: + if f.required and not (f.name in data or f.name in translated_data): + raise Exception( + f"Field '{f.name}' is required in {self.node_id} but not present" + ) + if f.name in data: + db_data[f.name] = data.pop(f.name) + if f.name in translated_data: + db_data[f.name] = translated_data.pop(f.name) + out, err = self.fn(**db_data) # type: ignore + return out, err + + +class MemoryRegistry: + def __init__(self) -> None: + self._memories: Dict[str, Node] = {} + + def register_write( + self, + component_name: str, + fn: object, + outputs: Dict[str, Any], + vector_key: str, + description: str = "", + tags: List[str] = [], + ) -> Node: + node_id = f"{component_name}-write" + mem_fn = Memory(node_id=node_id, fn=fn, vector_key=vector_key, read_mode=False) + output_fields = func_to_return_vars(fn, returns=outputs) + node = Node( + id=node_id, + fn=mem_fn, + type=Node.types.MEMORY, + fields=mem_fn.fields, + outputs=output_fields, + description=description, + tags=tags, + ) + self._memories[node_id] = node + return node + + def register_read( + self, + component_name: str, + fn: object, + outputs: Dict[str, Any], + vector_key: str, + description: str = "", + tags: List[str] = [], + ) -> Node: + node_id = f"{component_name}-read" + mem_fn = Memory(node_id=node_id, fn=fn, vector_key=vector_key, read_mode=True) + output_fields = func_to_return_vars(fn, returns=outputs) + node = Node( + id=node_id, + fn=mem_fn, + type=Node.types.MEMORY, + fields=mem_fn.fields, + outputs=output_fields, + description=description, + tags=tags, + ) + self._memories[node_id] = node + return node + + def get_write(self, node_id: str) -> Optional[Node]: + out = self._memories.get(node_id + "-write", None) + if out is None: + raise ValueError(f"Memory '{node_id}' not found") + return out + + def get_read(self, node_id: str) -> Optional[Node]: + out = self._memories.get(node_id + "-read", None) + if out is None: + raise ValueError(f"Memory '{node_id}' not found") + return out + + def get_nodes(self): + return {k: v.to_dict() for k, v in self._memories.items()} + + +# Models Registry +# --------------- +# All the things below are for the models that are registered in the model registry, so that they can be used as inputs +# in the chain. There can be several models that can put as inputs in a single chatbot. + + +class ModelRegistry: + """Model registry contains metadata for all the models that are provided in the components""" + + def __init__(self): + self.models: Dict[str, Model] = {} + self.counter: Dict[str, int] = {} + self.tags_to_models: Dict[str, List[str]] = {} + + def has(self, id: str): + """A helper function to check if a model is registered or not""" + return id in self.models + + def register(self, model: Model): + """Register a model in the registry + + Args: + model (Model): Model to register + """ + id = model.id + logger.debug(f"Registering model {id} at {id}") + if id in self.models: + raise Exception(f"Model {id} already registered") + self.models[id] = model + for tag in model.tags: + self.tags_to_models[tag] = self.tags_to_models.get(tag, []) + [id] + return model + + def get_tags(self) -> List[str]: + """Get all the tags that are registered in the registry + + Returns: + List[str]: List of tags + """ + return list(self.tags_to_models.keys()) + + def get_models(self, tag: str = "") -> Dict[str, Dict[str, Any]]: + """Get all the models that are registered in the registry + + Args: + tag (str, optional): Filter models by tag. Defaults to "". + + Returns: + Dict[str, Dict[str, Any]]: Dictionary of models + """ + items = {k: v.to_dict() for k, v in self.models.items()} + if tag: + items = {k: v for k, v in items.items() if tag in v.get("tags", [])} + return items + + def get(self, id: str) -> Model: + """Get a model from the registry + + Args: + id (str): Id of the model + + Returns: + Model: Model + """ + self.counter[id] = self.counter.get(id, 0) + 1 + out = self.models.get(id, None) + if out is None: + raise ValueError(f"Model {id} not found") + return out + + def get_count_for_model(self, id: str) -> int: + """Get the number of times a model is used + + Args: + id (str): Id of the model + + Returns: + int: Number of times the model is used + """ + return self.counter.get(id, 0) + + def get_any_model(self) -> Model: + return random.choice(list(self.models.values())) + + # Initialise Registries # --------------------- @@ -511,3 +816,14 @@ def get_count_for_nodes(self, node_id: str) -> int: `ai_actions_registry` is a global instance of `AIActionsRegistry` class. This is used to register and unregister `AIAction` instances. This is used by the server to serve the registered actions. """ + +memory_registry = MemoryRegistry() +""" +`memory_registry` is a global instance of MemoryRegistry class. This is used to register and unregister Memory instances. +This is what the user should use when they want to use the memory elements in their chain. +""" + +model_registry = ModelRegistry() +""" +`model_registry` is a global variable that is used to register models. It is an instance of ModelRegistry class. +""" diff --git a/chainfury/core/__init__.py b/chainfury/core/__init__.py deleted file mode 100644 index f24cb1e..0000000 --- a/chainfury/core/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright © 2023- Frello Technology Private Limited - -from chainfury.core.models import model_registry -from chainfury.core.actions import ( - programatic_actions_registry, - ai_actions_registry, - AIAction, -) -from chainfury.core.memory import memory_registry, Memory diff --git a/chainfury/core/memory.py b/chainfury/core/memory.py deleted file mode 100644 index 069cf3e..0000000 --- a/chainfury/core/memory.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright © 2023- Frello Technology Private Limited - -""" -Actions -======= - -All actions that the AI can do. -""" - -from typing import Any, List, Optional, Dict - -from chainfury.base import ( - Node, - func_to_return_vars, - func_to_vars, - Var, - get_value_by_keys, -) -from chainfury.utils import logger -from chainfury.core.models import model_registry - - -class Memory: - """Class to wrap the DB functions as a callable. - - Args: - node_id (str): The id of the node - fn (object): The function that is used for this action - vector_key (str): The key for the vector in the DB - read_mode (bool, optional): If the function is a read function, if `False` then this is a write function. - """ - - fields_model = [ - Var( - name="items", - type=[Var(type="string"), Var(type="array", items=[Var(type="string")])], - required=True, - ), - Var(name="embedding_model", type="string", required=True), - Var( - name="embedding_model_params", - type="object", - additionalProperties=Var(type="string"), - ), - Var(name="embedding_model_key", type="string"), - Var( - name="translation_layer", - type="object", - additionalProperties=Var(type="string"), - ), - ] - """These are the fields that are used to map the input items to the embedding model, do not use directly""" - - def __init__( - self, node_id: str, fn: object, vector_key: str, read_mode: bool = False - ): - self.node_id = node_id - self.fn = fn - self.vector_key = vector_key - self.read_mode = read_mode - self.fields_fn = func_to_vars(fn) - self.fields = self.fields_fn + self.fields_model - - def to_dict(self) -> Dict[str, Any]: - """Serialize the Memory object to a dict.""" - return { - "node_id": self.node_id.split("-")[0], - "vector_key": self.vector_key, - "read_mode": self.read_mode, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]): - """Deserialize the Memory object from a dict.""" - read_mode = data["read_mode"] - if read_mode: - fn = memory_registry.get_read(data["node_id"]) - else: - fn = memory_registry.get_write(data["node_id"]) - - # here we do return Memory type but instead of creating one we use a previously existing Node and return - # the fn for the Node which is ultimately this precise Memory object - return fn.fn # type: ignore - - def __call__(self, **data: Dict[str, Any]) -> Any: - # the first thing we have to do is get the data for the model. This is actually a very hard problem because this - # function needs to call some other arbitrary function where we know the inputs to this function "items" but we - # do not know which variable to pass this to in the undelying model's function. Thus we need to take in a huge - # amount of things as more inputs ("embedding_model_key", "embedding_model_params"). Then we don't even know - # what the inputs to the underlying DB functionbare going to be, in which case we also need to add things like - # the translation that needs to be done ("translation_layer"). This makes the number of inputs a lot but - # ultimately is required to do the job for robust-ness. Which is why we provide a default for openai-embedding - # model. For any other model user will need to pass all the information. - model_fields: Dict[str, Any] = {} - for f in self.fields_model: - if f.required and f.name not in data: - raise Exception( - f"Field '{f.name}' is required in {self.node_id} but not present" - ) - if f.name in data: - model_fields[f.name] = data.pop(f.name) - - model_data = {**model_fields.get("embedding_model_params", {})} - model_id = model_fields.pop("embedding_model") - - # TODO: @yashbonde - clean this mess up - # DEFAULT_MEMORY_CONSTANTS = { - # "openai-embedding": { - # "embedding_model_key": "input_strings", - # "embedding_model_params": { - # "model": "text-embedding-ada-002", - # }, - # "translation_layer": { - # "embeddings": ["data", "*", "embedding"], - # }, - # } - # } - # embedding_model_default_config = DEFAULT_MEMORY_CONSTANTS.get(model_id, {}) - # if embedding_model_default_config: - # model_data = { - # **embedding_model_default_config.get("embedding_model_params", {}), - # **model_data, - # } - # model_key = embedding_model_default_config.get( - # "embedding_model_key", "items" - # ) or model_data.get("embedding_model_key") - # model_fields["translation_layer"] = model_fields.get( - # "translation_layer" - # ) or embedding_model_default_config.get("translation_layer") - # else: - - req_keys = [x.name for x in self.fields_model[2:]] - if not all([x in model_fields for x in req_keys]): - raise Exception(f"Model {model_id} requires {req_keys} to be passed") - model_key = model_fields.get("embedding_model_key") - model_data = { - **model_fields.get("embedding_model_params", {}), - **model_data, - } - model_data[model_key] = model_fields.pop("items") # type: ignore - model = model_registry.get(model_id) - embeddings, err = model(model_data=model_data) - if err: - logger.error(f"error: {err}") - logger.error(f"traceback: {embeddings}") - raise err - - # now that we have all the embeddings ready we now need to translate it to be fed into the DB function - translated_data = {} - for k, v in model_fields.get("translation_layer", {}).items(): - translated_data[k] = get_value_by_keys(embeddings, v) - - # create the dictionary to call the underlying function - db_data = {} - for f in self.fields_fn: - if f.required and not (f.name in data or f.name in translated_data): - raise Exception( - f"Field '{f.name}' is required in {self.node_id} but not present" - ) - if f.name in data: - db_data[f.name] = data.pop(f.name) - if f.name in translated_data: - db_data[f.name] = translated_data.pop(f.name) - out, err = self.fn(**db_data) # type: ignore - return out, err - - -class MemoryRegistry: - def __init__(self) -> None: - self._memories: Dict[str, Node] = {} - - def register_write( - self, - component_name: str, - fn: object, - outputs: Dict[str, Any], - vector_key: str, - description: str = "", - tags: List[str] = [], - ) -> Node: - node_id = f"{component_name}-write" - mem_fn = Memory(node_id=node_id, fn=fn, vector_key=vector_key, read_mode=False) - output_fields = func_to_return_vars(fn, returns=outputs) - node = Node( - id=node_id, - fn=mem_fn, - type=Node.types.MEMORY, - fields=mem_fn.fields, - outputs=output_fields, - description=description, - tags=tags, - ) - self._memories[node_id] = node - return node - - def register_read( - self, - component_name: str, - fn: object, - outputs: Dict[str, Any], - vector_key: str, - description: str = "", - tags: List[str] = [], - ) -> Node: - node_id = f"{component_name}-read" - mem_fn = Memory(node_id=node_id, fn=fn, vector_key=vector_key, read_mode=True) - output_fields = func_to_return_vars(fn, returns=outputs) - node = Node( - id=node_id, - fn=mem_fn, - type=Node.types.MEMORY, - fields=mem_fn.fields, - outputs=output_fields, - description=description, - tags=tags, - ) - self._memories[node_id] = node - return node - - def get_write(self, node_id: str) -> Optional[Node]: - out = self._memories.get(node_id + "-write", None) - if out is None: - raise ValueError(f"Memory '{node_id}' not found") - return out - - def get_read(self, node_id: str) -> Optional[Node]: - out = self._memories.get(node_id + "-read", None) - if out is None: - raise ValueError(f"Memory '{node_id}' not found") - return out - - def get_nodes(self): - return {k: v.to_dict() for k, v in self._memories.items()} - - -memory_registry = MemoryRegistry() -""" -`memory_registry` is a global instance of MemoryRegistry class. This is used to register and unregister Memory instances. -This is what the user should use when they want to use the memory elements in their chain. -""" diff --git a/chainfury/core/models.py b/chainfury/core/models.py deleted file mode 100644 index e5642bc..0000000 --- a/chainfury/core/models.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright © 2023- Frello Technology Private Limited - -""" -Models -====== - -All things required in a model. -""" - -import random -from typing import Any, List, Dict - -from chainfury.base import Model -from chainfury.utils import logger - - -# Models -# ------ -# All the things below are for the models that are registered in the model registry, so that they can be used as inputs -# in the chain. There can be several models that can put as inputs in a single chatbot. - - -class ModelRegistry: - """Model registry contains metadata for all the models that are provided in the components""" - - def __init__(self): - self.models: Dict[str, Model] = {} - self.counter: Dict[str, int] = {} - self.tags_to_models: Dict[str, List[str]] = {} - - def has(self, id: str): - """A helper function to check if a model is registered or not""" - return id in self.models - - def register(self, model: Model): - """Register a model in the registry - - Args: - model (Model): Model to register - """ - id = model.id - logger.debug(f"Registering model {id} at {id}") - if id in self.models: - raise Exception(f"Model {id} already registered") - self.models[id] = model - for tag in model.tags: - self.tags_to_models[tag] = self.tags_to_models.get(tag, []) + [id] - return model - - def get_tags(self) -> List[str]: - """Get all the tags that are registered in the registry - - Returns: - List[str]: List of tags - """ - return list(self.tags_to_models.keys()) - - def get_models(self, tag: str = "") -> Dict[str, Dict[str, Any]]: - """Get all the models that are registered in the registry - - Args: - tag (str, optional): Filter models by tag. Defaults to "". - - Returns: - Dict[str, Dict[str, Any]]: Dictionary of models - """ - items = {k: v.to_dict() for k, v in self.models.items()} - if tag: - items = {k: v for k, v in items.items() if tag in v.get("tags", [])} - return items - - def get(self, id: str) -> Model: - """Get a model from the registry - - Args: - id (str): Id of the model - - Returns: - Model: Model - """ - self.counter[id] = self.counter.get(id, 0) + 1 - out = self.models.get(id, None) - if out is None: - raise ValueError(f"Model {id} not found") - return out - - def get_count_for_model(self, id: str) -> int: - """Get the number of times a model is used - - Args: - id (str): Id of the model - - Returns: - int: Number of times the model is used - """ - return self.counter.get(id, 0) - - def get_any_model(self) -> Model: - return random.choice(list(self.models.values())) - - -model_registry = ModelRegistry() -""" -`model_registry` is a global variable that is used to register models. It is an instance of ModelRegistry class. -""" diff --git a/tests/__main__.py b/tests/__main__.py index ab7a2cd..5f90bd1 100644 --- a/tests/__main__.py +++ b/tests/__main__.py @@ -1,7 +1,7 @@ # Copyright © 2023- Frello Technology Private Limited from tests.getkv import TestGetValueByKeys -from tests.chains import TestChainSerDeser +from tests.base import TestSerDeser, TestNode import unittest if __name__ == "__main__": diff --git a/tests/base.py b/tests/base.py new file mode 100644 index 0000000..5611b48 --- /dev/null +++ b/tests/base.py @@ -0,0 +1,62 @@ +# Copyright © 2023- Frello Technology Private Limited + +from chainfury import programatic_actions_registry, Chain +from chainfury.components.functional import echo + +import unittest + + +chain = Chain( + name="echo-cf-public", + description="abyss", + nodes=[programatic_actions_registry.get("chainfury-echo")], # type: ignore + sample={"message": "hi there"}, + main_in="message", + main_out="chainfury-echo/message", +) + + +class TestSerDeser(unittest.TestCase): + def test_chain_dict(self): + Chain.from_dict(chain.to_dict()) + + def test_chain_apidict(self): + Chain.from_dict(chain.to_dict(api=True)) + + def test_chain_json(self): + Chain.from_json(chain.to_json()) + + def test_chain_dag(self): + Chain.from_dag(chain.to_dag()) + + def test_node_dict(self): + node = programatic_actions_registry.get("chainfury-echo") + if node is None: + self.fail("Node not found") + self.assertIsNotNone(node) + node.from_dict(node.to_dict()) + + def test_node_json(self): + node = programatic_actions_registry.get("chainfury-echo") + if node is None: + self.fail("Node not found") + self.assertIsNotNone(node) + node.from_json(node.to_json()) + + +class TestNode(unittest.TestCase): + def test_node_run(self): + node = programatic_actions_registry.get("chainfury-echo") + if node is None: + self.fail("Node not found") + self.assertIsNotNone(node) + out, err = node(data={"message": "hi there"}) + self.assertIsNone(err) + + # call the function directly + fn_out, _ = echo("hi there") + self.assertEqual(out, {"message": fn_out}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/chains.py b/tests/chains.py deleted file mode 100644 index 335688f..0000000 --- a/tests/chains.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright © 2023- Frello Technology Private Limited - -from chainfury import programatic_actions_registry, Chain - -import unittest - - -chain = Chain( - name="echo-cf-public", - description="abyss", - nodes=[programatic_actions_registry.get("chainfury-echo")], # type: ignore - sample={"message": "hi there"}, - main_in="message", - main_out="chainfury-echo/message", -) - - -class TestChainSerDeser(unittest.TestCase): - def test_dict(self): - Chain.from_dict(chain.to_dict()) - - def test_apidict(self): - Chain.from_dict(chain.to_dict(api=True)) - - def test_json(self): - Chain.from_json(chain.to_json()) - - def test_dag(self): - Chain.from_dag(chain.to_dag()) - - -if __name__ == "__main__": - unittest.main() From 094a454f2d8faa930026de422979dc86a479470a Mon Sep 17 00:00:00 2001 From: yashbonde Date: Tue, 12 Mar 2024 22:33:32 +0530 Subject: [PATCH 6/8] [chore] add secrets + improved tests + expanded types --- .env.sample | 29 ++ .gitignore | 3 +- api_docs/conf.py | 2 +- chainfury/__init__.py | 10 +- chainfury/base.py | 306 +++++++++++++++-- chainfury/chat.py | 431 ------------------------ chainfury/cli.py | 212 ++++++------ chainfury/components/const.py | 4 +- chainfury/components/openai/__init__.py | 23 +- chainfury/components/tune/__init__.py | 134 ++++++-- chainfury/types.py | 26 +- chainfury/utils.py | 4 +- chainfury/version.py | 2 +- extra/ex1/chain.py | 78 ----- pyproject.toml | 3 +- server/chainfury_server/__main__.py | 26 +- server/chainfury_server/api/chains.py | 23 +- server/chainfury_server/api/user.py | 136 +++++--- server/chainfury_server/app.py | 8 +- server/chainfury_server/database.py | 12 + server/chainfury_server/engine.py | 20 +- server/chainfury_server/utils.py | 65 +++- server/chainfury_server/version.py | 2 +- server/pyproject.toml | 5 +- tests/__main__.py | 8 - tests/main.py | 15 + tests/test_base_chain2.py | 61 ++++ tests/{base.py => test_base_types.py} | 75 ++++- tests/{getkv.py => test_getkv.py} | 0 29 files changed, 924 insertions(+), 799 deletions(-) create mode 100644 .env.sample delete mode 100644 chainfury/chat.py delete mode 100644 extra/ex1/chain.py delete mode 100644 tests/__main__.py create mode 100644 tests/main.py create mode 100644 tests/test_base_chain2.py rename tests/{base.py => test_base_types.py} (59%) rename tests/{getkv.py => test_getkv.py} (100%) diff --git a/.env.sample b/.env.sample new file mode 100644 index 0000000..728f648 --- /dev/null +++ b/.env.sample @@ -0,0 +1,29 @@ +# chainfury server +# ================ +# These are the environment variables that are used by the chainfury_server +# For chainfury jump below to the chainfury section + +# Required +# -------- + +# URL to the database for chainfury server, uses sqlalchemy, so most things should work +CFS_DATABASE="db_drivers://username:password@host:port/db_name" + +# (once in a lifetime) secret string for creating the JWT secrets +JWT_SECRET="secret" + +# (once in a lifetime) password to store the user secrets +CFS_SECRETS_PASSWORD="password" + +# chainfury +# ========= +# These are the environment variables that are used by the chainfury + +# To store all the file and data in the chainfury server +CF_FOLDER="~/cf" + +# (client mode) the URL for the chainfury server +CF_URL="" + +# (client mode) the token for the chainfury server +CF_TOKEN="" diff --git a/.gitignore b/.gitignore index be52001..00613c4 100644 --- a/.gitignore +++ b/.gitignore @@ -142,7 +142,7 @@ langflow dunmp.rdb *.ipynb server/chainfury_server/stories/fury.json -notebooks/* +notebooks stories/fury.json workers/ private.sh @@ -153,3 +153,4 @@ demo/ logs.py chunker/ chainfury/chains/ +gosrc/ diff --git a/api_docs/conf.py b/api_docs/conf.py index 0570017..8731d75 100644 --- a/api_docs/conf.py +++ b/api_docs/conf.py @@ -14,7 +14,7 @@ project = "ChainFury" copyright = "2023, NimbleBox Engineering" author = "NimbleBox Engineering" -release = "1.7.0a1" +release = "1.7.0a2" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/chainfury/__init__.py b/chainfury/__init__.py index 7763832..561671f 100644 --- a/chainfury/__init__.py +++ b/chainfury/__init__.py @@ -21,6 +21,8 @@ Chain, Model, Edge, + Tools, + Action, ) from chainfury.core import ( model_registry, @@ -31,11 +33,11 @@ Memory, ) from chainfury.client import get_client -from chainfury.chat import ( +from chainfury.types import ( Message, - Chat, - TuneChats, - TuneDataset, + Thread, + ThreadsList, + Dataset, human, system, assistant, diff --git a/chainfury/base.py b/chainfury/base.py index 1074696..04f1617 100644 --- a/chainfury/base.py +++ b/chainfury/base.py @@ -9,15 +9,17 @@ import importlib import traceback from pprint import pformat +from functools import partial from typing import Any, Union, Optional, Dict, List, Tuple, Callable, Generator from collections import deque, defaultdict, Counter import jinja2schema as j2s from jinja2schema import model as j2sm +from tuneapi.utils import load_module_from_path, to_json, from_json + from chainfury.utils import logger, terminal_top_with_text import chainfury.types as T -from chainfury import chat class Secret(str): @@ -26,6 +28,9 @@ class Secret(str): def __init__(self, value=""): self.value = value + def has_value(self) -> bool: + return self.value != "" + # # Vars: this is the base class for all the fields that the user can provide from the front end @@ -35,7 +40,7 @@ def __init__(self, value=""): class Var: def __init__( self, - type: Union[str, List["Var"]], + type: Union[str, List["Var"]] = "", format: str = "", items: List["Var"] = [], additionalProperties: Union[List["Var"], "Var"] = [], @@ -45,6 +50,7 @@ def __init__( placeholder: str = "", show: bool = False, name: str = "", + description: str = "", *, loc: Optional[Tuple] = (), ): @@ -62,6 +68,9 @@ def __init__( name (str, optional): The name of this field. Defaults to "". loc (Optional[Tuple], optional): The location of this field. Defaults to (). """ + if not type: + raise ValueError("type cannot be empty") + self.type = type self.format = format self.items = items or [] @@ -72,6 +81,7 @@ def __init__( self.placeholder = placeholder self.show = show self.name = name + self.description = description # self.value = None self.loc = loc # this is the location from which this value is extracted @@ -116,6 +126,8 @@ def to_dict(self) -> Dict[str, Any]: d["name"] = self.name if self.loc: d["loc"] = self.loc + if self.description: + d["description"] = self.description return d @classmethod @@ -138,6 +150,7 @@ def from_dict(cls, d: Dict[str, Any]) -> "Var": show_val = d.get("show", False) name_val = d.get("name", "") loc_val = d.get("loc", ()) + description_val = d.get("description", "") if isinstance(type_val, list): type_val = [ @@ -163,6 +176,7 @@ def from_dict(cls, d: Dict[str, Any]) -> "Var": placeholder=placeholder_val, show=show_val, name=name_val, + description=description_val, loc=loc_val, ) return var @@ -740,16 +754,27 @@ def __call__(self, model_data: Dict[str, Any]) -> Tuple[Any, Optional[Exception] except Exception as e: return traceback.format_exc(), e + def set_api_token(self, token: str) -> None: + raise NotImplementedError( + f"set_api_token method is not implemented for {self.id}" + ) + def completion(self, prompt: str, **kwargs): """Subclass and implement your own text completion API""" return NotImplementedError( f"completion method is not implemented for {self.id}" ) - def chat(self, chat: chat.Chat, **kwargs): + def chat(self, chat: T.Thread, **kwargs): """Subclass and implement your own chat API""" raise NotImplementedError("chat method is not implemented for this model") + def stream_chat(self, chat: T.Thread, **kwargs): + """Subclass and implement your own chat API""" + raise NotImplementedError( + "stream_chat method is not implemented for this model" + ) + # # Node: Each box that is drag and dropped in the UI is a Node, it will tell what kind of things are @@ -981,12 +1006,12 @@ def __call__( @classmethod def from_chat( cls, - chat: chat.Chat, + thread: T.Thread, node_id: str, model: Model, description: Optional[str] = None, ) -> "Node": - chat_dict = chat.to_dict() + chat_dict = thread.to_dict() # print(variables) fields = [] templates = [] @@ -996,7 +1021,7 @@ def from_chat( obj = get_value_by_keys(chat_dict, field[0]) if not obj: raise ValueError( - f"Field {field[0]} not found in {chat}, but was extraced. There is a bug in get_value_by_keys function" + f"Field {field[0]} not found in {thread}, but was extraced. There is a bug in get_value_by_keys function" ) templates.append((obj, jinja2.Template(obj), field[0])) @@ -1124,7 +1149,11 @@ def __init__( self.edges = edges if len(self.nodes) == 1: - assert len(self.edges) == 0, "Cannot have edges with only 1 node" + if len(self.edges) != 0: + logger.error(f"Got only one node: {self.nodes.keys()=}") + raise ValueError( + f"Cannot have edges with only 1 node. Got {self.edges}" + ) self.topo_order = [next(iter(self.nodes))] else: self.topo_order = topological_sort(self.edges) @@ -1134,6 +1163,7 @@ def __init__( if self.is_empty: # there is nothing to do here + logger.info("This is empty chain") return if "/" not in main_out: @@ -1185,7 +1215,7 @@ def __repr__(self) -> str: def add_thread( self, node_id: str, - chat: chat.Chat, + thread: T.Thread, model: Optional[Model] = None, description: str = "", ) -> "Chain": @@ -1197,30 +1227,35 @@ def add_thread( # build the node node = Node.from_chat( - chat, + thread, node_id=node_id, model=model or self.default_model, # type: ignore description=description, ) + logger.debug(f"Adding node (total nodes {len(self.nodes)}): {node.id=}") + # add edges as required for var in node.fields: if var.name in self.nodes: - self.edges.append( - Edge( - src_node_id=var.name, - src_node_var=var.name, - trg_node_id=node_id, - trg_node_var=self.nodes[var.name].outputs[0].name, - ) + e = Edge( + src_node_id=var.name, + src_node_var=var.name, + trg_node_id=node_id, + trg_node_var=self.nodes[var.name].outputs[0].name, ) + logger.debug(f"Adding (total edges {len(self.edges)}) {e=}") + self.edges.append(e) # assign the node self.nodes[node.id] = node # topo sort if len(self.nodes) == 1: - assert len(self.edges) == 0, "Cannot have edges with only 1 node" + if len(self.edges) != 0: + raise ValueError( + f"Cannot have edges with only 1 node. Got {self.edges}" + ) self.topo_order = [next(iter(self.nodes))] else: self.topo_order = topological_sort(self.edges) @@ -1339,9 +1374,9 @@ def to_dag(self) -> T.Dag: nodes = [] for i, node in enumerate(self.nodes.values()): nodes.append( - T.FENode( + T.UINode( id=node.id, - position=T.FENode.Position( + position=T.UINode.Position( x=i * 100, y=i * 100, ), @@ -1349,13 +1384,13 @@ def to_dag(self) -> T.Dag: width=100, height=100, selected=False, - position_absolute=T.FENode.Position( + position_absolute=T.UINode.Position( x=i * 100, y=i * 100, ), dragging=False, cf_id=node.id, - cf_data=T.FENode.CFData( + cf_data=T.UINode.CFData( id=node.id, type=node.type, node=node.to_dict(), @@ -1806,6 +1841,235 @@ def stream( yield out, True +# +# Tools: A new abstraction for AGI +# + + +class Action: + def __init__( + self, + name: str, + description: str, + properties: Optional[Dict[str, Any]] = {}, + required: Optional[List[str]] = [], + fn: Optional[Callable] = None, + fn_meta: Optional[Dict[str, Any]] = None, + ): + self.name = name + self.description = description + self.properties = properties + self.required = required + + if (fn is None and fn_meta is None) or (fn is not None and fn_meta is not None): + raise ValueError("Either fn or fn_meta is required") + if fn_meta is not None: + # first validate that the line content is same in the two files, then try to load the item + if not os.path.exists(fn_meta["file"]): + raise ValueError(f"File {fn_meta['file']} does not exist") + with open(fn_meta["file"], "r") as f: + for i, l in enumerate(f): + if i == fn_meta["line"] and l != fn_meta["line_val"]: + raise ValueError( + f"Line #{fn_meta['line']} does not match in {fn_meta['file']}\n" + f" Expected: {l}\n" + f" Found: {fn_meta['line_val']}" + ) + fn = load_module_from_path(fn_meta["name"], fn_meta["file"]) + elif fn is not None: + fn_meta = { + "file": inspect.getfile(fn), + "line": inspect.getsourcelines(fn)[1] - 1, + "line_val": inspect.getsourcelines(fn)[0][0], + "name": fn.__name__, + } + + self.fn_meta = fn_meta + self.fn: Callable = fn # type: ignore + + def __repr__(self) -> str: + return f"[Action] {self.name}: {self.description}" + + # ser/deser + + def to_dict(self) -> Dict[str, Any]: + """Serializes the action to a dictionary. + + Returns: + Dict[str, Any]: The dictionary representation of the action. + """ + return { + "name": self.name, + "description": self.description, + "required": self.required, + "properties": self.properties, + "fn_meta": self.fn_meta, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Action": + """Deserializes the action from a dictionary. + + Args: + data (Dict[str, Any]): The dictionary representation of the action. + """ + return cls( + name=data["name"], + description=data["description"], + required=data["required"], + properties=data["properties"], + fn_meta=data["fn_meta"], + ) + + def to_json(self, indent=0, tight=True) -> str: + """Serializes the action to a JSON string. + + Returns: + str: The JSON string representation of the action. + """ + return to_json(self.to_dict(), indent=indent, tight=tight) + + @classmethod + def from_json(cls, data: str) -> "Action": + """Creates an action from a JSON string. + + Args: + data (str): The JSON string representation of the action. + """ + return cls.from_dict(json.loads(data)) + + def __call__(self, *args, **kwargs): + # validate the data is in + return self.fn(*args, **kwargs) + + # usage + + def to_fn(self) -> Dict[str, Any]: + data = self.to_dict() + data.pop("fn_meta") + return data + + +class Tools: + """ + Usage: + + >>> from chainfury import Tool, Var + >>> my_tool = Tool("My Tool", "This is a test tool") + >>> @my_tool( + ... description = "this is test action", + ... props = { + ... "a": Var("int", "number 1"), + ... "b": Var("int", "number 2"), + ... "secret": Var("int", secret = True, key="MY_ENV_VAR"), # to implement + ... } + ... ) + ... def add_two_numbers(a: int, b: int, secret: int): + ... return (a + b) * secret + ... + >>> my_tool.to_json(indent = 2) + { + "name": "add_two_numbers", + ... + } + """ + + def __init__(self, name: str, description: str): + self.name = name + self.description = description + + # + self.actions: Dict[str, Action] = {} + + def __repr__(self) -> str: + return f"[Tool] {self.name}: {self.description}" + + def _register_action( + self, + fn: Callable, + description: str, + properties: Dict[str, Var], + name: Optional[str] = None, + ) -> Action: + name = name or fn.__name__ + props = {} + required = [] + for k, v in properties.items(): + props[k] = { + "type": v.type, + "description": v.description, + } + if v.required: + required.append(k) + + self.actions[name] = Action( + name=name, + description=description, + properties=props, + required=required, + fn=fn, + ) + return self.actions[name] + + def add(self, description: str, properties: Dict[str, Var] = {}) -> Action: + """ + Register the actions + """ + return partial( + self._register_action, + description=description, + properties=properties, + ) # type: ignore + + # ser/deser + + def to_dict(self) -> Dict[str, Any]: + """ + Register the actions + """ + return { + "name": self.name, + "description": self.description, + "actions": {k: v.to_dict() for k, v in self.actions.items()}, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Tools": + """Deserializes the Tool from a dictionary. + + Args: + data (Dict[str, Any]): The dictionary representation of the chain. + """ + self = cls( + name=data["name"], + description=data["description"], + ) + actions = data.get("actions", {}) + for a in actions.values(): + action = Action.from_dict(a) + self.actions[action.name] = action + + return self + + def to_json(self, indent=2, tight=False) -> str: + """Serializes the Tool to a JSON string. + + Returns: + str: The JSON string representation of the chain. + """ + return to_json(self.to_dict(), indent=indent, tight=tight) + + @classmethod + def from_json(cls, data: str) -> "Tools": + """ + Register the actions + """ + return cls.from_dict(json.loads(data)) + + def to_fn(self) -> Dict[str, Any]: + return {"name": self.name, "description": self.description, "properties": {}} + + # # helper functions # diff --git a/chainfury/chat.py b/chainfury/chat.py deleted file mode 100644 index 7ebfe9d..0000000 --- a/chainfury/chat.py +++ /dev/null @@ -1,431 +0,0 @@ -# Copyright © 2023- Frello Technology Private Limited - -import os -import json -import random -from functools import partial -from collections.abc import Iterable -from typing import Dict, List, Any, Tuple, Optional, Generator - -from chainfury.utils import to_json, get_random_string, logger - - -class Message: - SYSTEM = "system" - HUMAN = "human" - GPT = "gpt" - VALUE = "value" - FUNCTION = "function" - FUNCTION_RESPONSE = "function-response" - - # start initialization here - def __init__(self, value: str | float, role: str): - if role in ["system", "sys"]: - role = self.SYSTEM - elif role in ["user", "human"]: - role = self.HUMAN - elif role in ["gpt", "assistant", "machine"]: - role = self.GPT - elif role in ["value"]: - role = self.VALUE - elif role in ["function", "fn"]: - role = self.FUNCTION - elif role in ["function-response", "fn-resp"]: - role = self.FUNCTION_RESPONSE - else: - raise ValueError(f"Unknown role: {role}") - if value is None: - raise ValueError("value cannot be None") - - self.role = role - self.value = value - self._unq_value = get_random_string(6) - - def __str__(self) -> str: - try: - idx = max(os.get_terminal_size().columns - len(self.role) - 40, 10) - except OSError: - idx = 50 - return f"<{self.role}: {json.dumps(self.value)[:idx]}>" - - def __repr__(self) -> str: - return str(self.value) - - def to_dict(self, ft: bool = False): - """ - if `ft` then export to following format: `{"from": "system/human/gpt", "value": "..."}` - else export to following format: `{"role": "system/user/assistant", "content": "..."}` - """ - role = self.role - if not ft: - if self.role == self.HUMAN: - role = "user" - elif self.role == self.GPT: - role = "assistant" - - chat_message: Dict[str, str | float] - if ft: - chat_message = {"from": role} - else: - chat_message = {"role": role} - - if not ft: - chat_message["content"] = self.value - else: - chat_message["value"] = self.value - return chat_message - - @classmethod - def from_dict(cls, data): - return cls( - value=data.get("value") or data.get("content"), - role=data.get("from") or data.get("role"), - ) # type: ignore - - -### Aliases -human = partial(Message, role=Message.HUMAN) -system = partial(Message, role=Message.SYSTEM) -assistant = partial(Message, role=Message.GPT) - - -class Chat: - """ - If the last Message is a "value" then a special tag "koro.regression"="true" is added to the meta. - - Args: - chats (List[Message]): List of chat messages - jl (Dict[str, Any]): Optional json-logic - """ - - def __init__( - self, - chats: List[Message], - jl: Optional[Dict[str, Any]] = None, - model: Optional[str] = None, - **kwargs, - ): - self.chats = chats - self.jl = jl - self.model = model - - # check for regression - if self.chats[-1].role == Message.VALUE: - kwargs["koro.regression"] = True - - kwargs = {k: v for k, v in sorted(kwargs.items())} - self.meta = kwargs - self.keys = list(kwargs.keys()) - self.values = tuple(kwargs.values()) - - # avoid special character BS. - assert not any(["=" in x or "&" in x for x in self.keys]) - if self.values: - assert all([type(x) in [int, str, float, bool] for x in self.values]) - - self.value_hash = hash(self.values) - - def __repr__(self) -> str: - x = " Any: - if __name in self.meta: - return self.meta[__name] - raise AttributeError(f"Attribute {__name} not found") - - # ser/deser - - def to_dict(self, full: bool = False): - if full: - return { - "chats": [x.to_dict() for x in self.chats], - "jl": self.jl, - "model": self.model, - "meta": self.meta, - } - return { - "chats": [x.to_dict() for x in self.chats], - } - - def to_chat_template(self): - return self.to_dict()["chats"] - - @classmethod - def from_dict(cls, data: Dict[str, Any]): - chats = data.get("chats", None) or data.get("conversations", None) - if not chats: - raise ValueError("No chats found") - return cls( - chats=[Message.from_dict(x) for x in chats], - jl=data.get("jl"), - model=data.get("model"), - **data.get("meta", {}), - ) - - def to_ft( - self, id: Any = None, drop_last: bool = False - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - chats = self.chats if not drop_last else self.chats[:-1] - ft_dict = { - "id": id or get_random_string(6), - "conversations": [x.to_dict(ft=True) for x in chats], - } - if drop_last: - ft_dict["last"] = self.chats[-1].to_dict(ft=True) - return ft_dict, self.meta - - # modifications - - def copy(self) -> "Chat": - return Chat( - chats=[x for x in self.chats], - jl=self.jl, - model=self.model, - **self.meta, - ) - - def add(self, message: Message): - self.chats.append(message) - - -# these are the classes that we use for tune datasets from r-stack - - -class TuneChats(list): - """This class implements some basic container methods for a list of Chat objects""" - - def __init__(self): - self.keys = {} - self.items: List[Chat] = [] - self.idx_dict: Dict[int, Tuple[Any, ...]] = {} - self.key_to_items_idx: Dict[int, List[int]] = {} - - def __repr__(self) -> str: - return ( - f"TuneChats(unq_keys={len(self.key_to_items_idx)}, items={len(self.items)})" - ) - - def __len__(self) -> int: - return len(self.items) - - def __iter__(self) -> Generator[Chat, None, None]: - for x in self.items: - yield x - - def stream(self) -> Generator[Chat, None, None]: - for x in self: - yield x - - def __getitem__(self, __index) -> List[Chat]: - return self.items[__index] - - def table(self) -> str: - try: - from tabulate import tabulate - except ImportError: - raise ImportError("Install tabulate to use this method") - - table = [] - for k, v in self.idx_dict.items(): - table.append( - [ - *v, - len(self.key_to_items_idx[k]), - f"{len(self.key_to_items_idx[k])/len(self)*100:0.2f}%", - ] - ) - return tabulate(table, headers=[*list(self.keys), "count", "percentage"]) - - # data manipulation - - def append(self, __object: Any) -> None: - if not self.items: - self.keys = __object.meta.keys() - if self.keys != __object.meta.keys(): - raise ValueError("Keys should match") - self.idx_dict.setdefault(__object.value_hash, __object.values) - self.key_to_items_idx.setdefault(__object.value_hash, []) - self.key_to_items_idx[__object.value_hash].append(len(self.items)) - self.items.append(__object) - - def add(self, x: Chat): - return self.append(x) - - def extend(self, __iterable: Iterable) -> None: - if hasattr(__iterable, "items"): - for x in __iterable.items: # type: ignore - self.append(x) - elif isinstance(__iterable, Iterable): - for x in __iterable: - self.append(x) - else: - raise ValueError("Unknown iterable") - - def shuffle(self, seed: Optional[int] = None) -> None: - """Perform in place shuffle""" - # shuffle using indices, self.items and self.key_to_items_idx - idx = list(range(len(self.items))) - if seed: - rng = random.Random(seed) - rng.shuffle(idx) - else: - random.shuffle(idx) - self.items = [self.items[i] for i in idx] - self.key_to_items_idx = {} - for i, x in enumerate(self.items): - self.key_to_items_idx.setdefault(x.value_hash, []) - self.key_to_items_idx[x.value_hash].append(i) - - def create_te_split(self, test_items: int | float = 0.1) -> Tuple["TuneChats", ...]: - try: - import numpy as np - except ImportError: - raise ImportError("Install numpy to use `create_te_split` method") - - train_ds = TuneChats() - eval_ds = TuneChats() - items_np_arr = np.array(self.items) - for k, v in self.key_to_items_idx.items(): - if isinstance(test_items, float): - if int(len(v) * test_items) < 1: - raise ValueError( - f"Test percentage {test_items} is too high for the dataset key '{k}'" - ) - split_ids = random.sample(v, int(len(v) * test_items)) - else: - if test_items > len(v): - raise ValueError( - f"Test items {test_items} is too high for the dataset key '{k}'" - ) - split_ids = random.sample(v, test_items) - - # get items - eval_items = items_np_arr[split_ids] - train_items = items_np_arr[np.setdiff1d(v, split_ids)] - train_ds.extend(train_items) - eval_ds.extend(eval_items) - - return train_ds, eval_ds - - # ser / deser - - def to_dict(self): - return {"items": [x.to_dict() for x in self.items]} - - @classmethod - def from_dict(cls, data): - bench_dataset = cls() - for item in data["items"]: - bench_dataset.append(Chat.from_dict(item)) - return bench_dataset - - def to_disk(self, folder: str, fmt: Optional[str] = None): - if fmt: - logger.warn( - f"exporting to {fmt} format, you cannot recreate the dataset from this." - ) - os.makedirs(folder) - with open(f"{folder}/tuneds.jsonl", "w") as f: - for sample in self.items: - if fmt == "sharegpt": - item, _ = sample.to_ft() - elif fmt is None: - item = sample.to_dict() - else: - raise ValueError(f"Unknown format: {fmt}") - f.write(to_json(item, tight=True) + "\n") # type: ignore - - @classmethod - def from_disk(cls, folder: str): - bench_dataset = cls() - with open(f"{folder}/tuneds.jsonl", "r") as f: - for line in f: - item = json.loads(line) - bench_dataset.append(Chat.from_dict(item)) - return bench_dataset - - def to_hf_dataset(self) -> Tuple["datasets.Dataset", List]: # type: ignore - try: - import datasets as dst - except ImportError: - raise ImportError("Install huggingface datasets library to use this method") - - _ds_list = [] - meta_list = [] - for x in self.items: - sample, meta = x.to_ft() - _ds_list.append(sample) - meta_list.append(meta) - return dst.Dataset.from_list(_ds_list), meta_list - - # properties - - def can_train_koro_regression(self) -> bool: - return all(["koro.regression" in x.meta for x in self]) - - -class TuneDataset: - """This class is a container for training and evaulation datasets, useful for serialising items to and from disk""" - - def __init__(self, train: TuneChats, eval: TuneChats): - self.train_ds = train - self.eval_ds = eval - - def __repr__(self) -> str: - return f"TuneDataset(\n train={self.train_ds},\n eval={self.eval_ds}\n)" - - @classmethod - def from_list(cls, items: List["TuneDataset"]): - train_ds = TuneChats() - eval_ds = TuneChats() - for item in items: - train_ds.extend(item.train_ds) - eval_ds.extend(item.eval_ds) - return cls(train=train_ds, eval=eval_ds) - - def to_hf_dict(self) -> Tuple["datasets.DatasetDict", Dict[str, List]]: # type: ignore - try: - import datasets as dst - except ImportError: - raise ImportError("Install huggingface datasets library to use this method") - - train_ds, train_meta = self.train_ds.to_hf_dataset() - eval_ds, eval_meta = self.eval_ds.to_hf_dataset() - return dst.DatasetDict(train=train_ds, eval=eval_ds), { - "train": train_meta, - "eval": eval_meta, - } - - def to_disk(self, folder: str, fmt: Optional[str] = None): - config = {} - config["type"] = "tune" - config["hf_type"] = fmt - os.makedirs(folder) - self.train_ds.to_disk(f"{folder}/train", fmt=fmt) - self.eval_ds.to_disk(f"{folder}/eval", fmt=fmt) - to_json(config, fp=f"{folder}/tune_config.json", tight=True) - - @classmethod - def from_disk(cls, folder: str): - if not os.path.exists(folder): - raise ValueError(f"Folder '{folder}' does not exist") - if not os.path.exists(f"{folder}/train"): - raise ValueError(f"Folder '{folder}/train' does not exist") - if not os.path.exists(f"{folder}/eval"): - raise ValueError(f"Folder '{folder}/eval' does not exist") - if not os.path.exists(f"{folder}/tune_config.json"): - raise ValueError(f"File '{folder}/tune_config.json' does not exist") - - # not sure what to do with these - with open(f"{folder}/tune_config.json", "r") as f: - config = json.load(f) - return cls( - train=TuneChats.from_disk(f"{folder}/train"), - eval=TuneChats.from_disk(f"{folder}/eval"), - ) diff --git a/chainfury/cli.py b/chainfury/cli.py index 8c44d20..426f94e 100644 --- a/chainfury/cli.py +++ b/chainfury/cli.py @@ -1,85 +1,23 @@ # Copyright © 2023- Frello Technology Private Limited +import dotenv + +dotenv.load_dotenv() + import os import sys import json from fire import Fire +from typing import Optional from chainfury import Chain from chainfury.version import __version__ from chainfury.components import all_items from chainfury.core import model_registry, programatic_actions_registry, memory_registry +from chainfury.chat import Chat, Message -def run( - chain: str, - inp: str, - stream: bool = False, - print_thoughts: bool = False, - f=sys.stdout, -): - """ - Run a chain with input and write the outputs. - - Args: - chain (str): This can be one of json filepath (e.g. "/chain.json"), json string (e.g. '{"id": "99jcjs9j2", ...}'), - chain id (e.g. "99jcjs9j2") - inp (str): This can be one of json filepath (e.g. "/input.json"), json string (e.g. '{"foo": "bar", ...}') - stream (bool, optional): Whether to stream the output. Defaults to False. - print_thoughts (bool, optional): Whether to print thoughts. Defaults to False. - f (file, optional): File to write the output to. Defaults to `sys.stdout`. - - Examples: - >>> $ cf run ./sample.json {"foo": "bar"} - """ - # validate inputs - if isinstance(inp, str): - if os.path.exists(inp): - with open(inp, "r") as f: - inp = json.load(f) - else: - try: - inp = json.loads(inp) - except Exception as e: - raise ValueError( - "Input must be a valid json string or a json file path" - ) - assert isinstance(inp, dict), "Input must be a dict" - - # create chain - chain_obj = None - if isinstance(chain, str): - if os.path.exists(chain): - with open(chain, "w") as f: - chain = json.load(f) - if len(chain) == 8: - chain_obj = Chain.from_id(chain) - else: - chain = json.loads(chain) - elif isinstance(chain, dict): - chain_obj = Chain.from_dict(chain) - assert chain_obj is not None, "Chain not found" - - # output - if isinstance(f, str): - f = open(f, "w") - - # run the chain - if stream: - cf_response_gen = chain_obj.stream(inp, print_thoughts=print_thoughts) - for ir, done in cf_response_gen: - if not done: - f.write(json.dumps(ir) + "\n") - else: - out, buffer = chain_obj(inp, print_thoughts=print_thoughts) - for k, v in buffer.items(): - f.write(json.dumps({k: v}) + "\n") - - # close file - f.close() - - -class __CLI: +class CLI: info = rf""" ___ _ _ ___ / __| |_ __ _(_)_ _ | __| _ _ _ _ _ @@ -90,39 +28,121 @@ class __CLI: ae e0 a5 87 e0 a4 b5 20 e0 a4 9c e0 a4 af e0 a4 a4 e0 a5 87 - cf_version: {__version__} 🦋 The FOSS chaining engine behind chat.tune.app - -A powerful way to program for the "Software 2.0" era. Read more: - -- https://tunehq.ai -- https://chat.tune.app -- https://studio.tune.app 🌟 us on https://github.com/NimbleBoxAI/ChainFury - -Build with ♥️ by Tune AI from the Koro coast 🌊 Chennai, India +♥️ Built by [Tune AI](https://tunehq.ai) from ECR, Chennai 🌊 """ - comp = { - "all": lambda: print(all_items), - "model": { - "list": list(model_registry.get_models()), - "all": model_registry.get_models(), - "get": model_registry.get, - }, - "prog": { - "list": list(programatic_actions_registry.get_nodes()), - "all": programatic_actions_registry.get_nodes(), - }, - "memory": { - "list": list(memory_registry.get_nodes()), - "all": memory_registry.get_nodes(), - }, - } - run = run + def run( + self, + chain: str, + inp: str, + stream: bool = False, + print_thoughts: bool = False, + f=sys.stdout, + ): + """ + Run a chain with input and write the outputs. + + Args: + chain (str): This can be one of json filepath (e.g. "/chain.json"), json string (e.g. '{"id": "99jcjs9j2", ...}'), + chain id (e.g. "99jcjs9j2") + inp (str): This can be one of json filepath (e.g. "/input.json"), json string (e.g. '{"foo": "bar", ...}') + stream (bool, optional): Whether to stream the output. Defaults to False. + print_thoughts (bool, optional): Whether to print thoughts. Defaults to False. + f (file, optional): File to write the output to. Defaults to `sys.stdout`. + + Examples: + >>> $ cf run ./sample.json {"foo": "bar"} + """ + # validate inputs + if isinstance(inp, str): + if os.path.exists(inp): + with open(inp, "r") as f: + inp = json.load(f) + else: + try: + inp = json.loads(inp) + except Exception as e: + raise ValueError( + "Input must be a valid json string or a json file path" + ) + assert isinstance(inp, dict), "Input must be a dict" + + # create chain + chain_obj = None + if isinstance(chain, str): + if os.path.exists(chain): + with open(chain, "w") as f: + chain = json.load(f) + if len(chain) == 8: + chain_obj = Chain.from_id(chain) + else: + chain = json.loads(chain) + elif isinstance(chain, dict): + chain_obj = Chain.from_dict(chain) + assert chain_obj is not None, "Chain not found" + + # output + if isinstance(f, str): + f = open(f, "w") + + # run the chain + if stream: + cf_response_gen = chain_obj.stream(inp, print_thoughts=print_thoughts) + for ir, done in cf_response_gen: + if not done: + f.write(json.dumps(ir) + "\n") + else: + out, buffer = chain_obj(inp, print_thoughts=print_thoughts) + for k, v in buffer.items(): + f.write(json.dumps({k: v}) + "\n") + + # close file + f.close() + + def sh( + self, + api: str = "tuneapi", + model: str = "rohan/mixtral-8x7b-inst-v0-1-32k", # "kaushikaakash04/tune-blob" + token: Optional[str] = None, + stream: bool = True, + ): + cf_model = model_registry.get(api) + if token is not None: + cf_model.set_api_token(token) + + # loop for user input through command line + chat = Chat() + usr_cntr = 0 + while True: + try: + user_input = input( + f"\033[1m\033[33m [{usr_cntr:02d}] YOU \033[39m:\033[0m " + ) + except KeyboardInterrupt: + break + if user_input == "exit" or user_input == "quit" or user_input == "": + break + chat.add(Message(user_input, Message.HUMAN)) + + print(f"\033[1m\033[34m ASSISTANT \033[39m:\033[0m ", end="", flush=True) + if stream: + response = "" + for str_token in cf_model.stream_chat(chat, model=model): + response += str_token + print(str_token, end="", flush=True) + print() # new line + chat.add(Message(response, Message.GPT)) + else: + response = cf_model.chat(chat, model=model) + print(response) + + chat.add(Message(response, Message.GPT)) + usr_cntr += 1 def main(): - Fire(__CLI) + Fire(CLI) diff --git a/chainfury/components/const.py b/chainfury/components/const.py index fbb7d6e..59a99a6 100644 --- a/chainfury/components/const.py +++ b/chainfury/components/const.py @@ -12,7 +12,7 @@ class Env: * CF_URL: ChainFury API URL * NBX_DEPLOY_URL: NimbleBox Deploy URL * NBX_DEPLOY_KEY: NimbleBox Deploy API key - * TUNECHAT_KEY: ChatNBX API key, see chat.nbox.ai + * TUNEAPI_TOKEN: ChatNBX API key, see chat.nbox.ai * OPENAI_TOKEN: OpenAI API token, see platform.openai.com * SERPER_API_KEY: Serper API key, see serper.dev/ * STABILITY_KEY: Stability API key, see dreamstudio.ai @@ -29,7 +29,7 @@ class Env: NBX_DEPLOY_KEY = lambda x: x or os.getenv("NBX_DEPLOY_KEY", "") ## different keys for different 3rd party APIs - TUNECHAT_KEY = lambda x: x or os.getenv("TUNECHAT_KEY", "") + TUNEAPI_TOKEN = lambda x: x or os.getenv("TUNEAPI_TOKEN", "") OPENAI_TOKEN = lambda x: x or os.getenv("OPENAI_TOKEN", "") SERPER_API_KEY = lambda x: x or os.getenv("SERPER_API_KEY", "") diff --git a/chainfury/components/openai/__init__.py b/chainfury/components/openai/__init__.py index bfbcd35..0fddf4c 100644 --- a/chainfury/components/openai/__init__.py +++ b/chainfury/components/openai/__init__.py @@ -13,23 +13,27 @@ UnAuthException, ) from chainfury.components.const import Env -from chainfury.chat import Chat +from chainfury.types import Thread class OpenaiGPTModel(Model): def __init__(self, id: Optional[str] = None): self._openai_model_id = id + self.openai_api_key = Secret(Env.OPENAI_TOKEN("")) super().__init__( id="openai-chat", description="Use OpenAI chat models", usage=["usage", "total_tokens"], ) + def set_api_token(self, token: str) -> None: + self.openai_api_key = Secret(token) + def chat( self, - chats: Chat, + chats: Thread, model: Optional[str] = None, - openai_api_key: Secret = Secret(""), + token: Secret = Secret(""), temperature: float = 1.0, top_p: float = 1.0, n: int = 1, @@ -49,7 +53,7 @@ def chat( Args: messages: A list of messages describing the conversation so far model: ID of the model to use. See [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create). - openai_api_key (Secret): The OpenAI API key. Defaults to "" or the OPENAI_TOKEN environment variable. + token (Secret): The OpenAI API key. Defaults to "" or the OPENAI_TOKEN environment variable. temperature: Optional. What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both. Defaults to 1. top_p: Optional. An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. Defaults to 1. n: Optional. How many chat completion choices to generate for each input message. Defaults to 1. @@ -65,14 +69,13 @@ def chat( Returns: Any: The completion(s) generated by the API. """ - if not openai_api_key: - openai_api_key = Secret(Env.OPENAI_TOKEN("")).value # type: ignore - if not openai_api_key: + if not token and not self.openai_api_key.value: raise Exception( "OpenAI API key not found. Please set OPENAI_TOKEN environment variable or pass through function" ) - if isinstance(chats, Chat): - messages = chats.to_dict() + + if isinstance(chats, Thread): + messages = chats.to_dict()["chats"] else: messages = chats @@ -83,7 +86,7 @@ def _fn(): "https://api.openai.com/v1/chat/completions", headers={ "Content-Type": "application/json", - "Authorization": f"Bearer {openai_api_key}", + "Authorization": f"Bearer {token}", }, json={ "model": model, diff --git a/chainfury/components/tune/__init__.py b/chainfury/components/tune/__init__.py index b1caeda..8c556e2 100644 --- a/chainfury/components/tune/__init__.py +++ b/chainfury/components/tune/__init__.py @@ -7,7 +7,7 @@ from chainfury import Secret, model_registry, exponential_backoff, Model from chainfury.components.const import Env -from chainfury.chat import Chat +from chainfury.types import Thread class TuneModel(Model): @@ -15,23 +15,79 @@ class TuneModel(Model): def __init__(self, id: Optional[str] = None): self._tune_model_id = id + self.tune_api_token = Secret(Env.TUNEAPI_TOKEN("")) super().__init__( - id="chatnbx", - description="Chat with the ChatNBX API with OpenAI compatability, see more at https://chat.nbox.ai/", + id="tuneapi", + description="Chat with the Tune Studio APIs, see more at https://studio.tune.app/", usage=["usage", "total_tokens"], ) + def set_api_token(self, token: str) -> None: + self.tune_api_token = Secret(token) + def chat( self, - chats: Chat, - chatnbx_api_key: Secret = Secret(""), + chats: Thread, model: Optional[str] = None, max_tokens: int = 1024, temperature: float = 1, *, - retry_count: int = 3, - retry_delay: int = 1, + token: Secret = Secret(""), ) -> Dict[str, Any]: + """ + Chat with the Tune Studio APIs, see more at https://studio.tune.app/ + + Note: This is a API is partially compatible with OpenAI's API, so `messages` should be of type :code:`[{"role": ..., "content": ...}]` + + Args: + model (str): The model to use, see https://studio.nbox.ai/ for more info + messages (List[Dict[str, str]]): A list of messages to send to the API which are OpenAI compatible + token (Secret, optional): The API key to use or set TUNEAPI_TOKEN environment variable + max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 1024. + temperature (float, optional): The higher the temperature, the crazier the text. Defaults to 1. + + Returns: + Dict[str, Any]: The response from the API + """ + if not token and not self.tune_api_token.has_value(): # type: ignore + raise Exception( + "Tune API key not found. Please set TUNEAPI_TOKEN environment variable or pass through function" + ) + token = token or self.tune_api_token + if isinstance(chats, Thread): + messages = chats.to_dict()["chats"] + else: + messages = chats + + model = model or self._tune_model_id + url = "https://proxy.tune.app/chat/completions" + headers = { + "Authorization": token.value, + "Content-Type": "application/json", + } + data = { + "temperature": temperature, + "messages": messages, + "model": model, + "stream": False, + "max_tokens": max_tokens, + } + response = requests.post(url, headers=headers, json=data) + try: + response.raise_for_status() + except Exception as e: + raise e + return response.json()["choices"][0]["message"]["content"] + + def stream_chat( + self, + chats: Thread, + model: Optional[str] = None, + max_tokens: int = 1024, + temperature: float = 1, + *, + token: Secret = Secret(""), + ): """ Chat with the ChatNBX API with OpenAI compatability, see more at https://chat.nbox.ai/ @@ -40,45 +96,57 @@ def chat( Args: model (str): The model to use, see https://chat.nbox.ai/ for more info messages (List[Dict[str, str]]): A list of messages to send to the API which are OpenAI compatible - chatnbx_api_key (Secret, optional): The API key to use or set TUNECHAT_KEY environment variable + token (Secret, optional): The API key to use or set TUNEAPI_TOKEN environment variable max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 1024. temperature (float, optional): The higher the temperature, the crazier the text. Defaults to 1. Returns: Dict[str, Any]: The response from the API """ - if not chatnbx_api_key: - chatnbx_api_key = Secret(Env.TUNECHAT_KEY("")).value # type: ignore - if not chatnbx_api_key: + if not token and not self.tune_api_token.has_value(): # type: ignore raise Exception( - "OpenAI API key not found. Please set TUNECHAT_KEY environment variable or pass through function" + "Tune API key not found. Please set TUNEAPI_TOKEN environment variable or pass through function" ) - if isinstance(chats, Chat): - messages = chats.to_dict() + + token = token or self.tune_api_token + if isinstance(chats, Thread): + messages = chats.to_dict()["chats"] else: messages = chats model = model or self._tune_model_id - - def _fn(): - url = "https://proxy.tune.app/chat/completions" - headers = { - "Authorization": chatnbx_api_key, - "Content-Type": "application/json", - } - data = { - "temperature": temperature, - "messages": messages, - "model": model, - "stream": False, - "max_tokens": max_tokens, - } - response = requests.post(url, headers=headers, json=data) - return response.json()["choices"][0]["message"]["content"] - - return exponential_backoff( - _fn, max_retries=retry_count, retry_delay=retry_delay + url = "https://proxy.tune.app/chat/completions" + headers = { + "Authorization": token.value, + "Content-Type": "application/json", + } + data = { + "temperature": temperature, + "messages": messages, + "model": model, + "stream": True, + "max_tokens": max_tokens, + } + response = requests.post( + url, + headers=headers, + json=data, + stream=True, ) + try: + response.raise_for_status() + except Exception as e: + print(response.text) + raise e + for line in response.iter_lines(): + line = line.decode().strip() + if line: + try: + yield json.loads(line.replace("data: ", ""))["choices"][0]["delta"][ + "content" + ] + except: + break tune_model = model_registry.register(model=TuneModel()) diff --git a/chainfury/types.py b/chainfury/types.py index 1829af4..185607f 100644 --- a/chainfury/types.py +++ b/chainfury/types.py @@ -4,10 +4,23 @@ from typing import Dict, Any, List, Optional from pydantic import BaseModel, Field, ConfigDict +# some types that are copied from the tuneapi types + +from tuneapi.types.chats import ( + Message, + Thread, + ThreadsList, + Dataset, + human, + system, + assistant, +) + + # First is the set of types that are used in the chainfury itself -class FENode(BaseModel): +class UINode(BaseModel): """FENode is the node as required by the UI to render the node in the graph. If you do not care about the UI, you can populate either the ``cf_id`` or ``cf_data``.""" @@ -56,14 +69,14 @@ class Edge(BaseModel): class Dag(BaseModel): """This is visual representation of the chain. JSON of this is stored in the DB.""" - nodes: List[FENode] + nodes: List[UINode] edges: List[Edge] sample: Dict[str, Any] = Field(default_factory=dict) main_in: str = "" main_out: str = "" -class CFPromptResult(BaseModel): +class ChainResult(BaseModel): """This is a structured result of the prompt by the Chain. This is more useful for providing types on the server.""" result: str @@ -71,6 +84,9 @@ class CFPromptResult(BaseModel): task_id: str = "" +# Then a set of types that are used in the API (client mode) + + class ApiLoginResponse(BaseModel): message: str token: Optional[str] = None @@ -178,14 +194,14 @@ class ApiPromptFeedbackResponse(BaseModel): rating: int -class ApiSaveTokenRequest(BaseModel): +class ApiToken(BaseModel): key: str token: str meta: Optional[Dict[str, Any]] = {} class ApiListTokensResponse(BaseModel): - tokens: List[ApiSaveTokenRequest] + tokens: List[ApiToken] class ApiChainLog(BaseModel): diff --git a/chainfury/utils.py b/chainfury/utils.py index 4276569..2d730d5 100644 --- a/chainfury/utils.py +++ b/chainfury/utils.py @@ -32,13 +32,13 @@ class CFEnv: CF_LOG_LEVEL = lambda: os.getenv("CF_LOG_LEVEL", "info") CF_FOLDER = lambda: os.path.expanduser(os.getenv("CF_FOLDER", "~/cf")) + CF_URL = lambda: os.getenv("CF_URL", "") + CF_TOKEN = lambda: os.getenv("CF_TOKEN", "") CF_BLOB_STORAGE = lambda: os.path.join(CFEnv.CF_FOLDER(), "blob") CF_BLOB_ENGINE = lambda: os.getenv("CF_BLOB_ENGINE", "local") CF_BLOB_BUCKET = lambda: os.getenv("CF_BLOB_BUCKET", "") CF_BLOB_PREFIX = lambda: os.getenv("CF_BLOB_PREFIX", "") CF_BLOB_AWS_CLOUD_FRONT = lambda: os.getenv("CF_BLOB_AWS_CLOUD_FRONT", "") - CF_URL = lambda: os.getenv("CF_URL", "") - CF_TOKEN = lambda: os.getenv("CF_TOKEN", "") def store_blob(key: str, value: bytes, engine: str = "", bucket: str = "") -> str: diff --git a/chainfury/version.py b/chainfury/version.py index c7502fa..7eb5f4c 100644 --- a/chainfury/version.py +++ b/chainfury/version.py @@ -1,6 +1,6 @@ # Copyright © 2023- Frello Technology Private Limited -__version__ = "1.7.0a1" +__version__ = "1.7.0a2" _major, _minor, _patch = __version__.split(".") _major = int(_major) _minor = int(_minor) diff --git a/extra/ex1/chain.py b/extra/ex1/chain.py deleted file mode 100644 index ba890d8..0000000 --- a/extra/ex1/chain.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright © 2023- Frello Technology Private Limited - -from fire import Fire - -from chainfury.base import Chain -from chainfury.chat import human, Message, Chat - -from chainfury.components.openai import OpenaiGPTModel -from chainfury.components.tune import TuneModel - - -def main(q: str, openai: bool = False): - chain = Chain( - name="demo-one", - description=( - "Building the hardcore example of chain at https://nimbleboxai.github.io/ChainFury/examples/usage-hardcore.html " - "using threaded chains" - ), - main_in="stupid_question", - main_out="fight_scene/fight_scene", - default_model=( - OpenaiGPTModel("gpt-3.5-turbo") - if openai - else TuneModel("rohan/mixtral-8x7b-inst-v0-1-32k") - ), - ) - print("before:") - print(chain) - - chain = chain.add_thread( - "character_one", - Chat( - [ - human( - "You were who was running in the middle of desert. You see a McDonald's and the waiter ask a stupid " - "question like: '{{ stupid_question }}'? You are pissed and you say." - ) - ] - ), - ) - - chain = chain.add_thread( - "character_two", - Chat( - [ - human( - "Someone comes upto you in a bar and screams '{{ character_one }}'? You are a bartender give a funny response to it." - ) - ] - ), - ) - - chain = chain.add_thread( - "fight_scene", - Chat( - [ - human( - "Two men were fighting in a bar. One yelled '{{ character_one }}'. Other responded by yelling '{{ character_two }}'.\n" - "Continue this story for 3 more lines." - ) - ] - ), - ) - - print("---------------") - print(chain) - - print(chain.topo_order) - for ir, done in chain.stream(q): - # print(ir) - pass - - print("---------------") - print(ir) - - -if __name__ == "__main__": - Fire(main) diff --git a/pyproject.toml b/pyproject.toml index 1e5f0c4..05bedbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "chainfury" -version = "1.7.0a1" +version = "1.7.0a2" description = "ChainFury is a powerful tool that simplifies the creation and management of chains of prompts, making it easier to build complex chat applications using LLMs." authors = ["Tune AI "] license = "Apache 2.0" @@ -9,6 +9,7 @@ repository = "https://github.com/NimbleBoxAI/ChainFury" [tool.poetry.dependencies] python = "^3.9,<3.12" +tuneapi = "0.1.1" fire = "0.5.0" Jinja2 = "3.1.2" jinja2schema = "0.1.4" diff --git a/server/chainfury_server/__main__.py b/server/chainfury_server/__main__.py index 54d574f..c723015 100644 --- a/server/chainfury_server/__main__.py +++ b/server/chainfury_server/__main__.py @@ -10,17 +10,6 @@ if os.path.exists(_dotenv_fp): dotenv.load_dotenv(_dotenv_fp) -__CF_LOGO = """ - ___ _ _ ___ - / __| |_ __ _(_)_ _ | __| _ _ _ _ _ -| (__| ' \/ _` | | ' \ | _| || | '_| || | - \___|_||_\__,_|_|_||_||_| \_,_|_| \_, | - |__/ -e0 a4 b8 e0 a4 a4 e0 a5 8d e0 a4 af e0 a4 -ae e0 a5 87 e0 a4 b5 20 e0 a4 9c e0 a4 af - e0 a4 a4 e0 a5 87 -""" - def main( host: str = "0.0.0.0", @@ -38,19 +27,18 @@ def main( post (List[str], optional): List of modules to load after the server is imported. Defaults to []. """ # WARNING: ensure that nothing is being imported in the utils from chainfury_server + from chainfury.cli import CLI from chainfury_server.utils import logger - from chainfury.version import __version__ as cf_version from chainfury_server.version import __version__ as cfs_version logger.info( - f"{__CF_LOGO}\n" + f"{CLI.info}\n" f"Starting ChainFury server ...\n" - f" Host: {host}\n" - f" Port: {port}\n" - f" Pre: {pre}\n" - f" Post: {post}\n" - f" chainfury version: {cf_version}\n" - f" cf_server version: {cfs_version}" + f" Host: {host}\n" + f" Port: {port}\n" + f" Pre: {pre}\n" + f" Post: {post}\n" + f" cf_server: {cfs_version}" ) # load all things you need to preload the modules diff --git a/server/chainfury_server/api/chains.py b/server/chainfury_server/api/chains.py index 97bf143..10537fa 100644 --- a/server/chainfury_server/api/chains.py +++ b/server/chainfury_server/api/chains.py @@ -200,12 +200,14 @@ def run_chain( store_ir: bool = False, store_io: bool = False, db: Session = Depends(DB.fastapi_db_session), -) -> Union[StreamingResponse, T.CFPromptResult, T.ApiResponse]: +) -> Union[StreamingResponse, T.ChainResult, T.ApiResponse]: """ This is the master function to run any chain over the API. This can behave in a bunch of different formats like: - (default) this will wait for the entire chain to execute and return the response - if ``stream`` is passed it will give a streaming response with line by line JSON and last response containing ``"done"`` key - if ``as_task`` is passed then a task ID is received and you can poll for the results at ``/chains/{id}/results`` this supercedes the ``stream``. + + ``as_task`` is not implemented. """ # validate user user = DB.get_user_from_jwt(token=token, db=db) @@ -243,15 +245,16 @@ def run_chain( if as_task: # when run as a task this will return a task ID that will be submitted - result = engine.submit( - chatbot=chatbot, - prompt=prompt, - db=db, - start=time.time(), - store_ir=store_ir, - store_io=store_io, - ) - return result + raise HTTPException(501, detail="Not implemented yet") + # result = engine.submit( + # chatbot=chatbot, + # prompt=prompt, + # db=db, + # start=time.time(), + # store_ir=store_ir, + # store_io=store_io, + # ) + # return result elif stream: def _get_streaming_response(result): diff --git a/server/chainfury_server/api/user.py b/server/chainfury_server/api/user.py index c34a908..633124c 100644 --- a/server/chainfury_server/api/user.py +++ b/server/chainfury_server/api/user.py @@ -4,17 +4,17 @@ from fastapi import HTTPException from passlib.hash import sha256_crypt from sqlalchemy.orm import Session -from fastapi import Request, Response, Depends, Header -from typing import Annotated +from fastapi import Depends, Header +from typing import Annotated, List from chainfury_server.utils import logger, Env import chainfury_server.database as DB import chainfury.types as T +from tuneapi.utils import encrypt, decrypt + def login( - req: Request, - resp: Response, auth: T.ApiAuthRequest, db: Session = Depends(DB.fastapi_db_session), ) -> T.ApiLoginResponse: @@ -26,13 +26,10 @@ def login( ) return T.ApiLoginResponse(message="success", token=token) else: - resp.status_code = 401 - return T.ApiLoginResponse(message="failed") + raise HTTPException(status_code=401, detail="Invalid username or password") def sign_up( - req: Request, - resp: Response, auth: T.ApiSignUpRequest, db: Session = Depends(DB.fastapi_db_session), ) -> T.ApiLoginResponse: @@ -67,13 +64,10 @@ def sign_up( ) return T.ApiLoginResponse(message="success", token=token) else: - resp.status_code = 400 - return T.ApiLoginResponse(message="failed") + raise HTTPException(status_code=500, detail="Unknown error") def change_password( - req: Request, - resp: Response, token: Annotated[str, Header()], inputs: T.ApiChangePasswordRequest, db: Session = Depends(DB.fastapi_db_session), @@ -87,51 +81,111 @@ def change_password( db.commit() return T.ApiResponse(message="success") else: - resp.status_code = 400 - return T.ApiResponse(message="password incorrect") - + raise HTTPException(status_code=401, detail="Invalid old password") -# TODO: @tunekoro - Implement the following functions - -def create_token( - req: Request, - resp: Response, +def create_secret( token: Annotated[str, Header()], - inputs: T.ApiSaveTokenRequest, + inputs: T.ApiToken, db: Session = Depends(DB.fastapi_db_session), ) -> T.ApiResponse: - resp.status_code = 501 # - return T.ApiResponse(message="not implemented") + # validate user + user = DB.get_user_from_jwt(token=token, db=db) + + # validate inputs + if len(inputs.token) >= DB.Tokens.MAXLEN_TOKEN: + raise HTTPException( + status_code=400, + detail=f"Token too long, should be less than {DB.Tokens.MAXLEN_TOKEN} characters", + ) + if len(inputs.key) >= DB.Tokens.MAXLEN_KEY: + raise HTTPException( + status_code=400, + detail=f"Key too long, should be less than {DB.Tokens.MAXLEN_KEY} characters", + ) + + cfs_secrets_password = Env.CFS_SECRETS_PASSWORD() + if cfs_secrets_password is None: + logger.error("CFS_TOKEN_PASSWORD not set, cannot create secrets") + raise HTTPException(500, "internal server error") + + # create a token + token = DB.Tokens( + user_id=user.id, + key=inputs.key, + value=encrypt(inputs.token, cfs_secrets_password, user.id).decode("utf-8"), + meta=inputs.meta, + ) # type: ignore + db.add(token) + db.commit() + return T.ApiResponse(message="success") -def get_token( - req: Request, - resp: Response, +def get_secret( key: str, token: Annotated[str, Header()], db: Session = Depends(DB.fastapi_db_session), -) -> T.ApiResponse: - resp.status_code = 501 # - return T.ApiResponse(message="not implemented") +) -> T.ApiToken: + # validate user + user = DB.get_user_from_jwt(token=token, db=db) + + db_token: DB.Tokens = db.query(DB.Tokens).filter(DB.Tokens.key == key, user.id == user.id).first() # type: ignore + if db_token is None: + raise HTTPException(status_code=404, detail="Token not found") + cfs_token = Env.CFS_SECRETS_PASSWORD() + if cfs_token is None: + logger.error("CFS_TOKEN_PASSWORD not set, cannot create secrets") + raise HTTPException(500, "internal server error") -def list_tokens( - req: Request, - resp: Response, + try: + db_token.value = decrypt(db_token.value, cfs_token, user.id) + except Exception as e: + raise HTTPException(status_code=401, detail="Cannot get token") + return db_token.to_ApiToken() + + +def list_secret( token: Annotated[str, Header()], + limit: int = 100, + offset: int = 0, db: Session = Depends(DB.fastapi_db_session), -) -> T.ApiResponse: - resp.status_code = 501 # - return T.ApiResponse(message="not implemented") - +) -> T.ApiListTokensResponse: + """Returns a list of token keys, and metadata. The token values are not returned.""" + # validate user + user = DB.get_user_from_jwt(token=token, db=db) -def delete_token( - req: Request, - resp: Response, + # get tokens + tokens: List[DB.Tokens] = ( + db.query(DB.Tokens) + .filter(DB.Tokens.user_id == user.id) # type: ignore + .limit(limit) + .offset(offset) + .all() + ) + tokens_resp = [] + for t in tokens: + tok = t.to_ApiToken() + tok.token = "" + tokens_resp.append(tok) + return T.ApiListTokensResponse(tokens=tokens_resp) + + +def delete_secret( key: str, token: Annotated[str, Header()], db: Session = Depends(DB.fastapi_db_session), ) -> T.ApiResponse: - resp.status_code = 501 # - return T.ApiResponse(message="not implemented") + # validate user + user = DB.get_user_from_jwt(token=token, db=db) + + # validate the user can access the token + _ = get_secret(key=key, token=token, db=db) + + # delete token + db_token: DB.Tokens = db.query(DB.Tokens).filter(DB.Tokens.key == key, user.id == user.id).first() # type: ignore + if db_token is None: + raise HTTPException(status_code=404, detail="Token not found") + db.delete(db_token) + db.commit() + return T.ApiResponse(message="success") diff --git a/server/chainfury_server/app.py b/server/chainfury_server/app.py index d5c8445..6c93e75 100644 --- a/server/chainfury_server/app.py +++ b/server/chainfury_server/app.py @@ -46,10 +46,10 @@ app.add_api_route(methods=["POST"], path="/user/login/", endpoint=api_user.login, tags=["user"]) # type: ignore app.add_api_route(methods=["POST"], path="/user/signup/", endpoint=api_user.sign_up, tags=["user"]) # type: ignore app.add_api_route(methods=["POST"], path="/user/change_password/", endpoint=api_user.change_password, tags=["user"]) # type: ignore -app.add_api_route(methods=["PUT"], path="/user/token/", endpoint=api_user.create_token, tags=["user"]) # type: ignore -app.add_api_route(methods=["GET"], path="/user/token/", endpoint=api_user.get_token, tags=["user"]) # type: ignore -app.add_api_route(methods=["DELETE"], path="/user/token/", endpoint=api_user.delete_token, tags=["user"]) # type: ignore -app.add_api_route(methods=["GET"], path="/user/tokens/list/", endpoint=api_user.list_tokens, tags=["user"]) # type: ignore +app.add_api_route(methods=["PUT"], path="/user/secret/", endpoint=api_user.create_secret, tags=["user"]) # type: ignore +app.add_api_route(methods=["GET"], path="/user/secret/", endpoint=api_user.get_secret, tags=["user"]) # type: ignore +app.add_api_route(methods=["DELETE"], path="/user/secret/", endpoint=api_user.delete_secret, tags=["user"]) # type: ignore +app.add_api_route(methods=["GET"], path="/user/secret/list/", endpoint=api_user.list_secret, tags=["user"]) # type: ignore # chains app.add_api_route(methods=["GET"], path="/api/chains/", endpoint=api_chains.list_chains, tags=["chains"]) # type: ignore diff --git a/server/chainfury_server/database.py b/server/chainfury_server/database.py index 72a85ce..938d949 100644 --- a/server/chainfury_server/database.py +++ b/server/chainfury_server/database.py @@ -182,6 +182,7 @@ class Tokens(Base): MAXLEN_KEY = 80 MAXLEN_VAL = 1024 + MAXLEN_TOKEN = 703 # 703 long string can create 1016 long token id = Column(Integer, primary_key=True) user_id = Column(String(ID_LENGTH), ForeignKey("user.id"), nullable=False) @@ -189,10 +190,18 @@ class Tokens(Base): value = Column(String(MAXLEN_VAL), nullable=False) meta = Column(JSON, nullable=True) user = relationship("User", back_populates="tokens") + # (user_id, key) is a unique constraint def __repr__(self): return f"Tokens(id={self.id}, user_id={self.user_id}, key={self.key}, value={self.value[:5]}..., meta={self.meta})" + def to_ApiToken(self) -> T.ApiToken: + return T.ApiToken( + key=self.key, + token=self.value, + meta=self.meta, + ) + class ChatBot(Base): __tablename__ = "chatbot" @@ -262,6 +271,9 @@ class Prompt(Base): session_id: Dict[str, Any] = Column(String(80), nullable=False) meta: Dict[str, Any] = Column(JSON) + # migrate to snowflake ID + sf_id = Column(String(19), nullable=True) + def to_dict(self): return { "id": self.id, diff --git a/server/chainfury_server/engine.py b/server/chainfury_server/engine.py index beb86ec..9c87fbb 100644 --- a/server/chainfury_server/engine.py +++ b/server/chainfury_server/engine.py @@ -25,7 +25,7 @@ def run( start: float, store_ir: bool, store_io: bool, - ) -> T.CFPromptResult: + ) -> T.ChainResult: if prompt.new_message and prompt.data: raise HTTPException( status_code=400, detail="prompt cannot have both new_message and data" @@ -37,7 +37,7 @@ def run( # Create a Fury chain then run the chain while logging all the intermediate steps dag = T.Dag(**chatbot.dag) # type: ignore chain = Chain.from_dag(dag, check_server=False) - callback = FuryThoughts(db, prompt_row.id) + callback = FuryThoughtsCallback(db, prompt_row.id) if prompt.new_message: prompt.data = {chain.main_in: prompt.new_message} @@ -76,7 +76,7 @@ def run( db.commit() # create the result - result = T.CFPromptResult( + result = T.ChainResult( result=( json.dumps(mainline_out) if type(mainline_out) != str @@ -108,7 +108,7 @@ def stream( start: float, store_ir: bool, store_io: bool, - ) -> Generator[Tuple[Union[T.CFPromptResult, Dict[str, Any]], bool], None, None]: + ) -> Generator[Tuple[Union[T.ChainResult, Dict[str, Any]], bool], None, None]: if prompt.new_message and prompt.data: raise HTTPException( status_code=400, detail="prompt cannot have both new_message and data" @@ -120,7 +120,7 @@ def stream( # Create a Fury chain then run the chain while logging all the intermediate steps dag = T.Dag(**chatbot.dag) # type: ignore chain = Chain.from_dag(dag, check_server=False) - callback = FuryThoughts(db, prompt_row.id) + callback = FuryThoughtsCallback(db, prompt_row.id) if prompt.new_message: prompt.data = {chain.main_in: prompt.new_message} @@ -162,7 +162,7 @@ def stream( ) # type: ignore db.add(db_chainlog) - result = T.CFPromptResult( + result = T.ChainResult( result=str(mainline_out), prompt_id=prompt_row.id, # type: ignore ) @@ -189,7 +189,7 @@ def submit( start: float, store_ir: bool, store_io: bool, - ) -> T.CFPromptResult: + ) -> T.ChainResult: if prompt.new_message and prompt.data: raise HTTPException( status_code=400, detail="prompt cannot have both new_message and data" @@ -206,7 +206,7 @@ def submit( # call the chain task_id: str = str(uuid4()) - result = T.CFPromptResult( + result = T.ChainResult( result=f"Task '{task_id}' scheduled", prompt_id=prompt_row.id, task_id=task_id, @@ -224,12 +224,10 @@ def submit( raise HTTPException(status_code=500, detail=str(e)) from e -# engine_registry.register(FuryEngine()) - # helpers -class FuryThoughts: +class FuryThoughtsCallback: def __init__(self, db, prompt_id): self.db = db self.prompt_id = prompt_id diff --git a/server/chainfury_server/utils.py b/server/chainfury_server/utils.py index 34f7faa..608b8f3 100644 --- a/server/chainfury_server/utils.py +++ b/server/chainfury_server/utils.py @@ -1,7 +1,8 @@ # Copyright © 2023- Frello Technology Private Limited import os -import logging +from Cryptodome.Cipher import AES +from base64 import b64decode, b64encode # WARNING: do not import anything from anywhere here, this is the place where chainfury_server starts. # importing anything can cause the --pre and --post flags to fail when starting server. @@ -14,13 +15,11 @@ class Env: """ Single namespace for all environment variables. - - * CFS_DATABASE: database connection string - * JWT_SECRET: secret for JWT tokens """ # once a lifetime secret JWT_SECRET = lambda: os.getenv("JWT_SECRET", "hajime-shimamoto") + CFS_SECRETS_PASSWORD = lambda: os.getenv("CFS_SECRETS_PASSWORDs") # when you want to use chainfury as a client you need to set the following vars CFS_DATABASE = lambda: os.getenv("CFS_DATABASE", None) @@ -47,3 +46,61 @@ def folder(x: str) -> str: def joinp(x: str, *args) -> str: """convienience function for os.path.join""" return os.path.join(x, *args) + + +class Crypt: + + def __init__(self, salt="SlTKeYOpHygTYkP3"): + self.salt = salt.encode("utf8") + self.enc_dec_method = "utf-8" + + def encrypt(self, str_to_enc, str_key): + try: + aes_obj = AES.new(str_key.encode("utf-8"), AES.MODE_CFB, self.salt) + hx_enc = aes_obj.encrypt(str_to_enc.encode("utf8")) + mret = b64encode(hx_enc).decode(self.enc_dec_method) + return mret + except ValueError as value_error: + if value_error.args[0] == "IV must be 16 bytes long": + raise ValueError("Encryption Error: SALT must be 16 characters long") + elif ( + value_error.args[0] == "AES key must be either 16, 24, or 32 bytes long" + ): + raise ValueError( + "Encryption Error: Encryption key must be either 16, 24, or 32 characters long" + ) + else: + raise ValueError(value_error) + + def decrypt(self, enc_str, str_key): + try: + aes_obj = AES.new(str_key.encode("utf8"), AES.MODE_CFB, self.salt) + str_tmp = b64decode(enc_str.encode(self.enc_dec_method)) + str_dec = aes_obj.decrypt(str_tmp) + mret = str_dec.decode(self.enc_dec_method) + return mret + except ValueError as value_error: + if value_error.args[0] == "IV must be 16 bytes long": + raise ValueError("Decryption Error: SALT must be 16 characters long") + elif ( + value_error.args[0] == "AES key must be either 16, 24, or 32 bytes long" + ): + raise ValueError( + "Decryption Error: Encryption key must be either 16, 24, or 32 characters long" + ) + else: + raise ValueError(value_error) + + +CURRENT_EPOCH_START = 1705905900000 # UTC timezone +"""Start of the current epoch, used for generating snowflake ids""" + +from snowflake import SnowflakeGenerator + + +class SFGen: + def __init__(self, instance, epoch=CURRENT_EPOCH_START): + self.gen = SnowflakeGenerator(instance, epoch=epoch) + + def __call__(self): + return next(self.gen) diff --git a/server/chainfury_server/version.py b/server/chainfury_server/version.py index 73d6f7c..e23f8b6 100644 --- a/server/chainfury_server/version.py +++ b/server/chainfury_server/version.py @@ -1,6 +1,6 @@ # Copyright © 2023- Frello Technology Private Limited -__version__ = "2.1.2a" +__version__ = "2.1.3a0" _major, _minor, _patch = __version__.split(".") _major = int(_major) _minor = int(_minor) diff --git a/server/pyproject.toml b/server/pyproject.toml index 58e26e6..dc1342e 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -2,7 +2,7 @@ [tool.poetry] name = "chainfury_server" -version = "2.1.2a" +version = "2.1.3a0" description = "ChainFury Server is the DB + API server for managing the ChainFury engine in production. Used in production at chat.tune.app" authors = ["Tune AI "] license = "Apache 2.0" @@ -22,7 +22,8 @@ SQLAlchemy = "1.4.47" uvicorn = "0.27.1" PyMySQL = "1.0.3" urllib3 = ">=1.26.18" -"cryptography" = ">=41.0.6" +cryptography = ">=41.0.6" +snowflake_id = "1.0.1" [tool.poetry.scripts] chainfury_server = "chainfury_server:__main__" diff --git a/tests/__main__.py b/tests/__main__.py deleted file mode 100644 index 5f90bd1..0000000 --- a/tests/__main__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright © 2023- Frello Technology Private Limited - -from tests.getkv import TestGetValueByKeys -from tests.base import TestSerDeser, TestNode -import unittest - -if __name__ == "__main__": - unittest.main() diff --git a/tests/main.py b/tests/main.py new file mode 100644 index 0000000..2a8e4fc --- /dev/null +++ b/tests/main.py @@ -0,0 +1,15 @@ +# Copyright © 2023- Frello Technology Private Limited + +import os +from tuneapi.utils import folder, joinp + +tests = [] +curdir = folder(__file__) +for x in os.listdir(curdir): + if x.startswith("test_") and x.endswith(".py"): + tests.append(joinp(curdir, x)) + +for t in tests: + code = os.system(f"python3 {t} -v") + if code != 0: + raise Exception(f"Test {t} failed with code {code}") diff --git a/tests/test_base_chain2.py b/tests/test_base_chain2.py new file mode 100644 index 0000000..611727d --- /dev/null +++ b/tests/test_base_chain2.py @@ -0,0 +1,61 @@ +# Copyright © 2023- Frello Technology Private Limited + +from chainfury import ( + Chain, + Thread, + human, +) +from chainfury.components.tune import TuneModel +import unittest + + +chain = Chain( + name="demo-one", + description=( + "Building the hardcore example of chain at https://nimbleboxai.github.io/ChainFury/examples/usage-hardcore.html " + "using threaded chains" + ), + main_in="stupid_question", + main_out="fight_scene/fight_scene", + default_model=TuneModel("rohan/mixtral-8x7b-inst-v0-1-32k"), +) +chain.add_thread( + "character_one", + Thread( + human( + "You were who was running in the middle of desert. You see a McDonald's and the waiter ask a stupid " + "question like: '{{ stupid_question }}'? You are pissed and you say." + ), + ), +) +chain.add_thread( + "character_two", + Thread( + human( + "Someone comes upto you in a bar and screams '{{ character_one }}'? You are a bartender give a funny response to it." + ), + ), +) +chain.add_thread( + "fight_scene", + Thread( + human( + "Two men were fighting in a bar. One yelled '{{ character_one }}'. Other responded by yelling '{{ character_two }}'.\n" + "Continue this story for 3 more lines." + ) + ), +) + + +class TestChain(unittest.TestCase): + """Testing Chain specific functionality""" + + def test_chain_toposort(self): + self.assertEqual( + chain.topo_order, + ["character_one", "character_two", "fight_scene"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/base.py b/tests/test_base_types.py similarity index 59% rename from tests/base.py rename to tests/test_base_types.py index 5611b48..d081c17 100644 --- a/tests/base.py +++ b/tests/test_base_types.py @@ -1,22 +1,14 @@ # Copyright © 2023- Frello Technology Private Limited -from chainfury import programatic_actions_registry, Chain -from chainfury.components.functional import echo - import unittest - - -chain = Chain( - name="echo-cf-public", - description="abyss", - nodes=[programatic_actions_registry.get("chainfury-echo")], # type: ignore - sample={"message": "hi there"}, - main_in="message", - main_out="chainfury-echo/message", -) +from functools import cache +from chainfury.components.functional import echo +from chainfury import programatic_actions_registry, Chain, Var, Tools class TestSerDeser(unittest.TestCase): + """Tests Serialisation and Deserialisation of Nodes, Chains and Tools.""" + def test_chain_dict(self): Chain.from_dict(chain.to_dict()) @@ -43,8 +35,16 @@ def test_node_json(self): self.assertIsNotNone(node) node.from_json(node.to_json()) + def test_tool_dict(self): + Tools.from_dict(tool.to_dict()) + + def test_tool_json(self): + Tools.from_json(tool.to_json()) + class TestNode(unittest.TestCase): + """Test Node specific functionality.""" + def test_node_run(self): node = programatic_actions_registry.get("chainfury-echo") if node is None: @@ -58,5 +58,54 @@ def test_node_run(self): self.assertEqual(out, {"message": fn_out}) +# +# Chain definition +# + +chain = Chain( + name="echo-cf-public", + description="abyss", + nodes=[programatic_actions_registry.get("chainfury-echo")], # type: ignore + sample={"message": "hi there"}, + main_in="message", + main_out="chainfury-echo/message", +) + + +# +# Tool definition +# +tool = Tools( + name="calculator", + description=( + "This tool is a calculator, it can perform basica calculations. " + "Use this when you are trying to do some mathematical task" + ), +) + + +@tool.add( + description="This function adds two numbers", + properties={ + "a": Var("int", required=True, description="number one"), + "b": Var("int", description="number two"), + }, +) +def add_two_numbers(a: int, b: int = 10): + return a + b + + +@tool.add( + description="This calculates square root of a number", + properties={ + "a": Var("int", description="number to calculate square root of"), + }, +) +def square_root_number(a): + import math + + return math.sqrt(a) + + if __name__ == "__main__": unittest.main() diff --git a/tests/getkv.py b/tests/test_getkv.py similarity index 100% rename from tests/getkv.py rename to tests/test_getkv.py From 2dd79eb16e0569f5122f0ccc668ada2684818c76 Mon Sep 17 00:00:00 2001 From: yashbonde Date: Wed, 13 Mar 2024 14:25:56 +0530 Subject: [PATCH 7/8] [chore] add worker architecture --- README.md | 12 +- scripts/build_and_copy.sh => build_ui.sh | 3 - chainfury/cli.py | 17 ++- chainfury/utils.py | 4 + scripts/list_builtins.py | 58 --------- server/chainfury_server/api/chains.py | 24 ++-- server/chainfury_server/database.py | 12 +- server/chainfury_server/engine.py | 148 ++++++++++++++++++++++- server/chainfury_server/utils.py | 6 +- 9 files changed, 190 insertions(+), 94 deletions(-) rename scripts/build_and_copy.sh => build_ui.sh (95%) delete mode 100644 scripts/list_builtins.py diff --git a/README.md b/README.md index 603ad1d..6cd6c95 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,14 @@ ae e0 a5 87 e0 a4 b5 20 e0 a4 9c e0 a4 af The documentation page contains all the information on using `chainfury` and `chainfury_server`. +#### `chainfury` + + + +#### `chainfury_server` + + + # Looking for Inspirations? Here's a few example to get your journey started on Software 2.0: @@ -86,7 +94,7 @@ source venv/bin/activate You will need to have `yarn` installed to build the frontend and move it to the correct location on the server ```bash -sh stories/build_and_copy.sh +sh build_ui.sh ``` Once the static files are copied we can now proceed to install dependecies: @@ -104,7 +112,7 @@ You can now visit [localhost:8000](http://localhost:8000/ui/) to see the GUI and There are a few test cases for super hard problems like `get_kv` which checks the `chainfury.base.get_value_by_keys` function. ```bash -python3 -m tests -v +python3 tests/main.py ``` # Contibutions diff --git a/scripts/build_and_copy.sh b/build_ui.sh similarity index 95% rename from scripts/build_and_copy.sh rename to build_ui.sh index 47991e4..a1cd1b9 100755 --- a/scripts/build_and_copy.sh +++ b/build_ui.sh @@ -14,9 +14,6 @@ cd client yarn install yarn build -# Go back to the root directory -cd .. - # copy the dist folder to the server # Go into the server folder, remove the old static folder and copy the new dist folder, copy index.html to templates echo "Copying the generated files to the server" diff --git a/chainfury/cli.py b/chainfury/cli.py index 426f94e..650c019 100644 --- a/chainfury/cli.py +++ b/chainfury/cli.py @@ -12,9 +12,8 @@ from chainfury import Chain from chainfury.version import __version__ -from chainfury.components import all_items -from chainfury.core import model_registry, programatic_actions_registry, memory_registry -from chainfury.chat import Chat, Message +from chainfury.core import model_registry +from chainfury.types import Thread, Message class CLI: @@ -115,7 +114,7 @@ def sh( cf_model.set_api_token(token) # loop for user input through command line - chat = Chat() + thread = Thread() usr_cntr = 0 while True: try: @@ -126,21 +125,21 @@ def sh( break if user_input == "exit" or user_input == "quit" or user_input == "": break - chat.add(Message(user_input, Message.HUMAN)) + thread.add(Message(user_input, Message.HUMAN)) print(f"\033[1m\033[34m ASSISTANT \033[39m:\033[0m ", end="", flush=True) if stream: response = "" - for str_token in cf_model.stream_chat(chat, model=model): + for str_token in cf_model.stream_chat(thread, model=model): response += str_token print(str_token, end="", flush=True) print() # new line - chat.add(Message(response, Message.GPT)) + thread.add(Message(response, Message.GPT)) else: - response = cf_model.chat(chat, model=model) + response = cf_model.chat(thread, model=model) print(response) - chat.add(Message(response, Message.GPT)) + thread.add(Message(response, Message.GPT)) usr_cntr += 1 diff --git a/chainfury/utils.py b/chainfury/utils.py index 2d730d5..77efe55 100644 --- a/chainfury/utils.py +++ b/chainfury/utils.py @@ -418,6 +418,10 @@ def get_now_float() -> float: # type: ignore """Get the current datetime in UTC timezone as a float""" return SimplerTimes.get_now_datetime().timestamp() + def get_now_fp64() -> float: # type: ignore + """Get the current datetime in UTC timezone as a float""" + return SimplerTimes.get_now_datetime().timestamp() + def get_now_i64() -> int: # type: ignore """Get the current datetime in UTC timezone as a int""" return int(SimplerTimes.get_now_datetime().timestamp()) diff --git a/scripts/list_builtins.py b/scripts/list_builtins.py deleted file mode 100644 index e6afa4b..0000000 --- a/scripts/list_builtins.py +++ /dev/null @@ -1,58 +0,0 @@ -from fire import Fire -import jinja2 as j2 -from chainfury import programatic_actions_registry, ai_actions_registry, memory_registry, model_registry - - -def main(src_file: str, trg_file: str, v: bool = False): - with open(src_file, "r") as f: - temp = j2.Template(f.read()) - - # create the components list - pc = [] - for node_id, node in programatic_actions_registry.nodes.items(): - pc.append( - { - "id": node.id, - "description": node.description.rstrip(".") + f'. Copy: ``programatic_actions_registry.get("{node.id}")``', - } - ) - - ac = [] - for node_id, node in ai_actions_registry.nodes.items(): - ac.append( - { - "id": node.id, - "description": node.description.rstrip(".") + f'. Copy: ``ai_actions_registry.get("{node.id}")``', - } - ) - - mc = [] - for node_id, node in memory_registry._memories.items(): - fn = "get_read" if node.id.endswith("-read") else "get_write" - mc.append( - { - "id": node.id, - "description": node.description.rstrip(".") + f'. Copy: ``memory_registry.{fn}("{node.id}")``', - } - ) - - moc = [] - for model_id, model in model_registry.models.items(): - moc.append( - { - "id": model_id, - "description": model.description.rstrip(".") + f'. Copy: ``model_registry.get("{model_id}")``', - } - ) - - op = temp.render(pc=pc, ac=ac, mc=mc, moc=moc) - if v: - print(op) - print("Writing to", trg_file) - - with open(trg_file, "w") as f: - f.write(op) - - -if __name__ == "__main__": - Fire(main) diff --git a/server/chainfury_server/api/chains.py b/server/chainfury_server/api/chains.py index 10537fa..822ca8c 100644 --- a/server/chainfury_server/api/chains.py +++ b/server/chainfury_server/api/chains.py @@ -31,10 +31,10 @@ def create_chain( return T.ApiResponse(message="Name not specified") if chatbot_data.dag: for n in chatbot_data.dag.nodes: - if len(n.id) > Env.CFS_MAXLEN_CF_NDOE(): + if len(n.id) > Env.CFS_MAXLEN_CF_NODE(): raise HTTPException( status_code=400, - detail=f"Node ID length cannot be more than {Env.CFS_MAXLEN_CF_NDOE()}", + detail=f"Node ID length cannot be more than {Env.CFS_MAXLEN_CF_NODE()}", ) # DB call @@ -245,16 +245,16 @@ def run_chain( if as_task: # when run as a task this will return a task ID that will be submitted - raise HTTPException(501, detail="Not implemented yet") - # result = engine.submit( - # chatbot=chatbot, - # prompt=prompt, - # db=db, - # start=time.time(), - # store_ir=store_ir, - # store_io=store_io, - # ) - # return result + # raise HTTPException(501, detail="Not implemented yet") + result = engine.submit( + chatbot=chatbot, + prompt=prompt, + db=db, + start=time.time(), + store_ir=store_ir, + store_io=store_io, + ) + return result elif stream: def _get_streaming_response(result): diff --git a/server/chainfury_server/database.py b/server/chainfury_server/database.py index 938d949..799b538 100644 --- a/server/chainfury_server/database.py +++ b/server/chainfury_server/database.py @@ -10,7 +10,7 @@ from dataclasses import dataclass, asdict from typing import Dict, Any -from sqlalchemy.pool import QueuePool +from sqlalchemy.pool import QueuePool, NullPool from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, scoped_session, sessionmaker, relationship @@ -55,6 +55,8 @@ ) else: logger.info(f"Using via database URL") + # https://stackoverflow.com/a/73764136 + # engine = create_engine( db, poolclass=QueuePool, @@ -84,7 +86,7 @@ def get_random_number(length) -> int: return random_numbers -def get_local_session() -> sessionmaker: +def get_local_session(engine) -> sessionmaker: logger.debug("Database opened successfully") SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) return SessionLocal @@ -101,7 +103,7 @@ def db_session() -> Session: # type: ignore def fastapi_db_session(): - sess_cls = get_local_session() + sess_cls = get_local_session(engine) db = sess_cls() try: yield db @@ -272,7 +274,7 @@ class Prompt(Base): meta: Dict[str, Any] = Column(JSON) # migrate to snowflake ID - sf_id = Column(String(19), nullable=True) + # sf_id = Column(String(19), nullable=True) def to_dict(self): return { @@ -303,7 +305,7 @@ class ChainLog(Base): ) created_at: datetime = Column(DateTime, nullable=False) prompt_id: int = Column(Integer, ForeignKey("prompt.id"), nullable=False) - node_id: str = Column(String(Env.CFS_MAXLEN_CF_NDOE()), nullable=False) + node_id: str = Column(String(Env.CFS_MAXLEN_CF_NODE()), nullable=False) worker_id: str = Column(String(Env.CFS_MAXLEN_WORKER()), nullable=False) message: str = Column(Text, nullable=False) data: Dict[str, Any] = Column(JSON, nullable=True) diff --git a/server/chainfury_server/engine.py b/server/chainfury_server/engine.py index 9c87fbb..ec5c278 100644 --- a/server/chainfury_server/engine.py +++ b/server/chainfury_server/engine.py @@ -15,6 +15,116 @@ import chainfury_server.database as DB from chainfury_server.utils import logger +from celery import Celery + +from sqlalchemy.pool import NullPool +from sqlalchemy import create_engine + + +app = Celery() + + +@app.task(name="chainfury_server.engine.run_chain") +def run_chain( + chatbot_id: str, + prompt_id: str, + prompt_data: Dict, + store_ir: bool, + store_io: bool, + worker_id: str, +): + start = SimplerTimes.get_now_fp64() + + # create the DB session + sess = DB.get_local_session( + create_engine( + DB.db, + poolclass=NullPool, + ) + ) + db = sess() + + # get the db object + chatbot = db.query(DB.ChatBot).filter(DB.ChatBot.id == chatbot_id).first() # type: ignore + prompt_row: DB.Prompt = db.query(DB.Prompt).filter(DB.Prompt.id == prompt_id).first() # type: ignore + if prompt_row is None: + time.sleep(2) + prompt_row = db.query(DB.Prompt).filter(DB.Prompt.id == prompt_id).first() # type: ignore + if prompt_row is None: + raise RuntimeError(f"Prompt {prompt_id} not found") + + # Create a Fury chain then run the chain while logging all the intermediate steps + dag = T.Dag(**chatbot.dag) # type: ignore + chain = Chain.from_dag(dag, check_server=False) + callback = FuryThoughtsCallback(db, prompt_row.id) + + # print( + # f"starting chain execution: [{prompt_row.meta.get('task_id')=}] [{worker_id=}]" + # ) + iterator = chain.stream( + data=prompt_data, + thoughts_callback=callback, + print_thoughts=False, + ) + mainline_out = "" + last_db = 0 + for ir, done in iterator: + if done: + mainline_out = ir + break + + if store_ir: + # in case of stream, every item is a fundamentally a step + data = { + "outputs": [ + { + "name": k.split("/")[-1], + "data": v, + } + for k, v in ir.items() + ] + } + k = next(iter(ir)).split("/")[0] + db_chainlog = DB.ChainLog( + prompt_id=prompt_row.id, + created_at=SimplerTimes.get_now_datetime(), + node_id=k, + worker_id=worker_id, + message="step", + data=data, + ) # type: ignore + db.add(db_chainlog) + + # update the DB every 5 seconds + if time.time() - last_db > 5: + db.commit() + last_db = time.time() + + result = T.ChainResult( + result=str(mainline_out), + prompt_id=prompt_row.id, # type: ignore + ) + + db_chainlog = DB.ChainLog( + prompt_id=prompt_row.id, + created_at=SimplerTimes.get_now_datetime(), + node_id="end", + worker_id=worker_id, + message="completed", + ) # type: ignore + db.add(db_chainlog) + + # commit the prompt to DB + if store_io: + prompt_row.response = result.result # type: ignore + prompt_row.time_taken = float(time.time() - start) # type: ignore + + # update the DB after sleeping a bit + st = time.time() - last_db + if st < 2: + time.sleep(2 - st) # be nice to the db + db.commit() + class FuryEngine: def run( @@ -151,7 +261,7 @@ def stream( for k, v in ir.items() ] } - k = next(iter(ir))[0].split("/")[0] + k = next(iter(ir)).split("/")[0] db_chainlog = DB.ChainLog( prompt_id=prompt_row.id, created_at=SimplerTimes.get_now_datetime(), @@ -206,6 +316,17 @@ def submit( # call the chain task_id: str = str(uuid4()) + worker_id = task_id.split("-")[0] + + db_chainlog = DB.ChainLog( + prompt_id=prompt_row.id, + created_at=SimplerTimes.get_now_datetime(), + node_id="init", + worker_id=worker_id, + message=f"scheduling task {task_id}", + ) # type: ignore + db.add(db_chainlog) + result = T.ChainResult( result=f"Task '{task_id}' scheduled", prompt_id=prompt_row.id, @@ -214,8 +335,26 @@ def submit( if store_io: prompt_row.response = result.result # type: ignore prompt_row.time_taken = float(time.time() - start) # type: ignore - db.commit() + prompt_row.meta = {"task_id": task_id} # type: ignore + + app.send_task( + "chainfury_server.engine.run_chain", + queue="cfs", + kwargs={ + "chatbot_id": chatbot.id, + "prompt_id": prompt_row.id, + "prompt_data": prompt.data, + "store_ir": store_ir, + "store_io": store_io, + "worker_id": worker_id, + }, + task_id=task_id, + expires=600, # 10 mins + time_limit=240, # 4 mins + soft_time_limit=60, # 1 min + ) + db.commit() return result except Exception as e: @@ -244,7 +383,10 @@ def __call__(self, thought): def create_prompt( - db: Session, chatbot_id: str, input_prompt: str, session_id: str + db: Session, + chatbot_id: str, + input_prompt: str, + session_id: str, ) -> DB.Prompt: db_prompt = DB.Prompt( chatbot_id=chatbot_id, diff --git a/server/chainfury_server/utils.py b/server/chainfury_server/utils.py index 608b8f3..22c5832 100644 --- a/server/chainfury_server/utils.py +++ b/server/chainfury_server/utils.py @@ -21,10 +21,12 @@ class Env: JWT_SECRET = lambda: os.getenv("JWT_SECRET", "hajime-shimamoto") CFS_SECRETS_PASSWORD = lambda: os.getenv("CFS_SECRETS_PASSWORDs") + # not once a lifetime but require DB changes, might as well not change these + CFS_MAXLEN_CF_NODE = lambda: int(os.getenv("CFS_MAXLEN_CF_NODE", 80)) + CFS_MAXLEN_WORKER = lambda: int(os.getenv("CFS_MAXLEN_WORKER", 16)) + # when you want to use chainfury as a client you need to set the following vars CFS_DATABASE = lambda: os.getenv("CFS_DATABASE", None) - CFS_MAXLEN_CF_NDOE = lambda: int(os.getenv("CFS_MAXLEN_CF_NDOE", 80)) - CFS_MAXLEN_WORKER = lambda: int(os.getenv("CFS_MAXLEN_WORKER", 16)) CFS_ALLOW_CORS_ORIGINS = lambda: [ x.strip() for x in os.getenv("CFS_ALLOW_CORS_ORIGINS", "*").split(",") ] From d725954f879112f5e58afc4f6fb6adf803aa9e09 Mon Sep 17 00:00:00 2001 From: yashbonde Date: Fri, 15 Mar 2024 12:05:06 +0530 Subject: [PATCH 8/8] [chore] update return type --- chainfury/components/tune/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chainfury/components/tune/__init__.py b/chainfury/components/tune/__init__.py index 8c556e2..7aa550c 100644 --- a/chainfury/components/tune/__init__.py +++ b/chainfury/components/tune/__init__.py @@ -33,7 +33,7 @@ def chat( temperature: float = 1, *, token: Secret = Secret(""), - ) -> Dict[str, Any]: + ) -> str: """ Chat with the Tune Studio APIs, see more at https://studio.tune.app/