Skip to content

Commit

Permalink
Support thread level context (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
hibobmaster committed Apr 23, 2024
1 parent a37df65 commit 69ce5b4
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 39 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,4 @@ cython_debug/
sync_db
manage_db
element-keys.txt
context.db
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

## 1.7.0
- Support thread level context

## 1.6.0
- Add GPT Vision

Expand Down
17 changes: 11 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ This is a simple Matrix bot that support using OpenAI API, Langchain to generate
4. Langchain([Flowise](https://github.com/FlowiseAI/Flowise))
5. Image Generation with [DALL·E](https://platform.openai.com/docs/api-reference/images/create) or [LocalAI](https://localai.io/features/image-generation/) or [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)
6. GPT Vision(openai or [GPT Vision API](https://platform.openai.com/docs/guides/vision) compatible such as [LocalAI](https://localai.io/features/gpt-vision/))
7. Room level and thread level chat context

## Installation and Setup

Expand All @@ -21,10 +22,10 @@ For explanations and a complete parameter list see: https://github.com/hibobmaste
Create three empty files, used only to persist the databases<br>

```bash
touch sync_db manage_db
touch sync_db context.db manage_db
sudo docker compose up -d
```
manage_db(can be ignored) is for langchain agent, sync_db is for matrix sync database<br>
manage_db(can be ignored) is for langchain agent, sync_db is for matrix sync database, context.db is for bot chat context<br>
<hr>
Normal Method:<br>
system dependency: <code>libolm-dev</code>
Expand Down Expand Up @@ -115,12 +116,16 @@ LangChain(flowise) admin: https://github.com/hibobmaster/matrix_chatgpt_bot/wiki
![demo2](https://i.imgur.com/BKZktWd.jpg)
https://github.com/hibobmaster/matrix_chatgpt_bot/wiki/ <br>

## Thread level Context
Mention bot with prompt, bot will reply in thread.

To keep the context, just send your prompt in the thread directly, without mentioning the bot.

![thread level context 1](https://i.imgur.com/4vLvNCt.jpeg)
![thread level context 2](https://i.imgur.com/1eb1Lmd.jpeg)


## Thanks
1. [matrix-nio](https://github.com/poljar/matrix-nio)
2. [acheong08](https://github.com/acheong08)
3. [8go](https://github.com/8go/)

<a href="https://jb.gg/OpenSourceSupport" target="_blank">
<img src="https://resources.jetbrains.com/storage/products/company/brand/logos/jb_beam.png" alt="JetBrains Logo (Main) logo." width="200" height="200">
</a>
3 changes: 2 additions & 1 deletion compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ services:
# use env file or config.json
# - ./config.json:/app/config.json
# use touch to create empty db file, for persist database only
# manage_db(can be ignored) is for langchain agent, sync_db is for matrix sync database
# manage_db(can be ignored) is for langchain agent, sync_db is for matrix sync database, context.db is for bot chat context
- ./sync_db:/app/sync_db
- ./context.db:/app/context.db
# - ./manage_db:/app/manage_db
# import_keys path
# - ./element-keys.txt:/app/element-keys.txt
Expand Down
78 changes: 78 additions & 0 deletions src/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ def __init__(
self.new_prog = re.compile(r"\s*!new\s+(.+)$")

async def close(self, task: asyncio.Task) -> None:
self.chatbot.cursor.close()
self.chatbot.conn.close()
await self.httpx_client.aclose()
if self.lc_admin is not None:
self.lc_manager.c.close()
Expand All @@ -251,6 +253,9 @@ async def message_callback(self, room: MatrixRoom, event: RoomMessageText) -> No
# sender_id
sender_id = event.sender

# event source
event_source = event.source

# user_message
raw_user_message = event.body

Expand All @@ -265,6 +270,48 @@ async def message_callback(self, room: MatrixRoom, event: RoomMessageText) -> No
# remove newline character from event.body
content_body = re.sub("\r\n|\r|\n", " ", raw_user_message)

# @bot and reply in thread
if "m.mentions" in event_source["content"]:
if "user_ids" in event_source["content"]["m.mentions"]:
# @bot
if (
self.user_id
in event_source["content"]["m.mentions"]["user_ids"]
):
try:
asyncio.create_task(
self.thread_chat(
room_id,
reply_to_event_id,
sender_id=sender_id,
thread_root_id=reply_to_event_id,
prompt=content_body,
)
)
except Exception as e:
logger.error(e, exe_info=True)

# thread converstaion
if "m.relates_to" in event_source["content"]:
if "rel_type" in event_source["content"]["m.relates_to"]:
thread_root_id = event_source["content"]["m.relates_to"]["event_id"]
# thread is created by @bot
if thread_root_id in self.chatbot.conversation:
try:
asyncio.create_task(
self.thread_chat(
room_id,
reply_to_event_id,
sender_id=sender_id,
thread_root_id=thread_root_id,
prompt=content_body,
)
)
except Exception as e:
logger.error(e, exe_info=True)

# common command

# !gpt command
if (
self.openai_api_key is not None
Expand Down Expand Up @@ -1300,6 +1347,37 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None:
estr = traceback.format_exc()
logger.info(estr)

# thread chat
async def thread_chat(
    self, room_id, reply_to_event_id, thread_root_id, prompt, sender_id
):
    """Answer a prompt inside a Matrix thread, keeping per-thread context.

    The thread root event id doubles as the chatbot conversation id
    (``convo_id``), so every message in the same thread shares one chat
    history.

    Args:
        room_id: Matrix room to reply in.
        reply_to_event_id: Event id of the message being answered.
        thread_root_id: Root event of the thread; used as the convo_id.
        prompt: The user's prompt text.
        sender_id: Matrix user id of the prompt's author.
    """
    try:
        # Show a typing indicator while the model generates the reply.
        await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
        content = await self.chatbot.ask_async_v2(
            prompt=prompt,
            convo_id=thread_root_id,
        )
        await send_room_message(
            self.client,
            room_id,
            reply_message=content,
            reply_to_event_id=reply_to_event_id,
            sender_id=sender_id,
            reply_in_thread=True,
            thread_root_id=thread_root_id,
        )
    except Exception as e:
        # BUG FIX: the logging keyword is exc_info, not exe_info; the typo
        # made logger.error itself raise TypeError and masked the real error.
        logger.error(e, exc_info=True)
        # Best effort: tell the user in-thread that something went wrong.
        await send_room_message(
            self.client,
            room_id,
            reply_message=GENERAL_ERROR_MESSAGE,
            sender_id=sender_id,
            reply_to_event_id=reply_to_event_id,
            reply_in_thread=True,
            thread_root_id=thread_root_id,
        )

# !chat command
async def chat(self, room_id, reply_to_event_id, prompt, sender_id, user_message):
try:
Expand Down
73 changes: 51 additions & 22 deletions src/gptbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,15 @@
Code derived from https://github.com/acheong08/ChatGPT/blob/main/src/revChatGPT/V3.py
A simple wrapper for the official ChatGPT API
"""
import sqlite3
import json
from typing import AsyncGenerator
from tenacity import retry, wait_random_exponential, stop_after_attempt
import httpx
import tiktoken


ENGINES = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4",
"gpt-4-32k",
"gpt-4-0613",
"gpt-4-32k-0613",
]
ENGINES = ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k", "gpt-4-turbo"]


class Chatbot:
Expand All @@ -41,6 +33,7 @@ def __init__(
reply_count: int = 1,
truncate_limit: int = None,
system_prompt: str = None,
db_path: str = "context.db",
) -> None:
"""
Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
Expand All @@ -53,23 +46,24 @@ def __init__(
or "You are ChatGPT, \
a large language model trained by OpenAI. Respond conversationally"
)
# https://platform.openai.com/docs/models
self.max_tokens: int = max_tokens or (
31000
127000
if "gpt-4-turbo" in engine
else 31000
if "gpt-4-32k" in engine
else 7000
if "gpt-4" in engine
else 15000
if "gpt-3.5-turbo-16k" in engine
else 4000
else 16000
)
self.truncate_limit: int = truncate_limit or (
30500
126500
if "gpt-4-turbo" in engine
else 30500
if "gpt-4-32k" in engine
else 6500
if "gpt-4" in engine
else 14500
if "gpt-3.5-turbo-16k" in engine
else 3500
else 15500
)
self.temperature: float = temperature
self.top_p: float = top_p
Expand All @@ -80,17 +74,49 @@ def __init__(

self.aclient = aclient

self.conversation: dict[str, list[dict]] = {
self.db_path = db_path

self.conn = sqlite3.connect(self.db_path)
self.cursor = self.conn.cursor()

self._create_tables()

self.conversation = self._load_conversation()

if self.get_token_count("default") > self.max_tokens:
raise Exception("System prompt is too long")

def _create_tables(self) -> None:
    """Create the conversation-persistence table if it does not exist.

    One row per conversation: ``convo_id`` is UNIQUE (enables the
    INSERT OR REPLACE upsert in _save_conversation) and ``messages``
    holds the JSON-encoded message list.
    """
    self.conn.execute(
        """
        CREATE TABLE IF NOT EXISTS conversations(
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            convo_id TEXT UNIQUE,
            messages TEXT
        )
        """
    )

def _load_conversation(self) -> dict[str, list[dict]]:
conversations: dict[str, list[dict]] = {
"default": [
{
"role": "system",
"content": system_prompt,
"content": self.system_prompt,
},
],
}
self.cursor.execute("SELECT convo_id, messages FROM conversations")
for convo_id, messages in self.cursor.fetchall():
conversations[convo_id] = json.loads(messages)
return conversations

if self.get_token_count("default") > self.max_tokens:
raise Exception("System prompt is too long")
def _save_conversation(self, convo_id) -> None:
self.conn.execute(
"INSERT OR REPLACE INTO conversations (convo_id, messages) VALUES (?, ?)",
(convo_id, json.dumps(self.conversation[convo_id])),
)
self.conn.commit()

def add_to_conversation(
self,
Expand All @@ -102,6 +128,7 @@ def add_to_conversation(
Add a message to the conversation
"""
self.conversation[convo_id].append({"role": role, "content": message})
self._save_conversation(convo_id)

def __truncate_conversation(self, convo_id: str = "default") -> None:
"""
Expand All @@ -116,6 +143,7 @@ def __truncate_conversation(self, convo_id: str = "default") -> None:
self.conversation[convo_id].pop(1)
else:
break
self._save_conversation(convo_id)

# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def get_token_count(self, convo_id: str = "default") -> int:
Expand Down Expand Up @@ -305,6 +333,7 @@ def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
self.conversation[convo_id] = [
{"role": "system", "content": system_prompt or self.system_prompt},
]
self._save_conversation(convo_id)

@retry(wait=wait_random_exponential(min=2, max=5), stop=stop_after_attempt(3))
async def oneTimeAsk(
Expand Down
37 changes: 27 additions & 10 deletions src/send_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ async def send_room_message(
sender_id: str = "",
user_message: str = "",
reply_to_event_id: str = "",
reply_in_thread: bool = False,
thread_root_id: str = "",
) -> None:
if reply_to_event_id == "":
content = {
Expand All @@ -23,6 +25,23 @@ async def send_room_message(
extensions=["nl2br", "tables", "fenced_code"],
),
}
elif reply_in_thread and thread_root_id:
content = {
"msgtype": "m.text",
"body": reply_message,
"format": "org.matrix.custom.html",
"formatted_body": markdown.markdown(
reply_message,
extensions=["nl2br", "tables", "fenced_code"],
),
"m.relates_to": {
"m.in_reply_to": {"event_id": reply_to_event_id},
"rel_type": "m.thread",
"event_id": thread_root_id,
"is_falling_back": True,
},
}

else:
body = "> <" + sender_id + "> " + user_message + "\n\n" + reply_message
format = r"org.matrix.custom.html"
Expand Down Expand Up @@ -51,13 +70,11 @@ async def send_room_message(
"formatted_body": formatted_body,
"m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}},
}
try:
await client.room_send(
room_id,
message_type="m.room.message",
content=content,
ignore_unverified_devices=True,
)
await client.room_typing(room_id, typing_state=False)
except Exception as e:
logger.error(e)

await client.room_send(
room_id,
message_type="m.room.message",
content=content,
ignore_unverified_devices=True,
)
await client.room_typing(room_id, typing_state=False)

0 comments on commit 69ce5b4

Please sign in to comment.