-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #20 from Capsize-Games/develop
Develop
- Loading branch information
Showing
8 changed files
with
276 additions
and
254 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
intput | ||
output | ||
.idea | ||
.vscode | ||
__pycache__ | ||
build | ||
dist | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,20 +2,20 @@ | |
|
||
setup( | ||
name='chatairunner', | ||
version='1.0.7', | ||
version='1.0.8', | ||
author='Capsize LLC', | ||
description='Chat AI: A chatbot framework', | ||
long_description=open("README.md", "r", encoding="utf-8").read(), | ||
long_description_content_type="text/markdown", | ||
keywords="ai, chatbot, chat, ai", | ||
license="AGPL-3.0", | ||
author_email="[email protected]", | ||
url="https://github.com/w4ffl35/chat-ai", | ||
url="https://github.com/Capsize-Games/chat-ai", | ||
package_dir={"": "src"}, | ||
packages=find_packages("src"), | ||
include_package_data=True, | ||
python_requires=">=3.10.0", | ||
install_requires=[ | ||
"aihandler==1.8.14", | ||
"aihandler==1.8.16", | ||
] | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
import os | ||
import random | ||
from PyQt6 import uic | ||
from PyQt6.QtCore import pyqtSignal, pyqtSlot | ||
from PyQt6.QtWidgets import QMainWindow | ||
from PyQt6.QtGui import QGuiApplication | ||
from aihandler.pyqt_offline_client import OfflineClient | ||
from aihandler.llmrunner import LLMRunner | ||
from chatairunner.settings_manager import SettingsManager | ||
from aihandler.qtvar import TQDMVar, MessageHandlerVar, ErrorHandlerVar | ||
from chatairunner.conversation import ChatAIConversation | ||
|
||
|
||
class BaseWindow(QMainWindow): | ||
template = "" | ||
runner = None | ||
randomize_seed_on_generate = False | ||
message_signal = pyqtSignal(str) | ||
response_signal = pyqtSignal(dict) | ||
client = None | ||
|
||
def center(self): | ||
availableGeometry = QGuiApplication.primaryScreen().availableGeometry() | ||
frameGeometry = self.ui.frameGeometry() | ||
frameGeometry.moveCenter(availableGeometry.center()) | ||
self.ui.move(frameGeometry.topLeft()) | ||
|
||
@pyqtSlot(str) | ||
def message_received(self, message): | ||
if message == "initialized": | ||
self.enable_buttons() | ||
self.stop_progress_bar() | ||
|
||
def message_handler(self, *args, **kwargs): | ||
response = args[0]["response"]["response"] | ||
# remove <pad> tokens | ||
response = response.replace("<pad>", "") | ||
response = response.replace("<unk>", "") | ||
# check if </s> token exists | ||
incomplete = False | ||
if "</s>" not in response: | ||
# remove all tokens after </s> and </s> itself | ||
incomplete = True | ||
else: | ||
response = response[: response.find("</s>")] | ||
|
||
response = response.strip() | ||
self.ui.generated_text.appendPlainText(response) | ||
|
||
if incomplete: | ||
# if there is no </s> token, the response is incomplete | ||
# so we send another request | ||
self.generate() | ||
else: | ||
self.stop_progress_bar() | ||
self.enable_buttons() | ||
|
||
def error_handler(self, *args, **kwargs): | ||
self.ui.generated_text.appendPlainText(args[0]) | ||
self.stop_progress_bar() | ||
self.enable_buttons() | ||
|
||
def prep_prompt(self): | ||
return "" | ||
|
||
def prep_properties(self): | ||
pass | ||
|
||
def initialize_offline_client(self): | ||
self.tqdm_var = TQDMVar() | ||
self.message_var = MessageHandlerVar() | ||
self.error_var = ErrorHandlerVar() | ||
self.client = OfflineClient( | ||
app=self, | ||
tqdm_var=self.tqdm_var, | ||
message_var=self.message_var, | ||
error_var=self.error_var, | ||
runners=[LLMRunner] | ||
) | ||
|
||
def handle_generate(self): | ||
self.start_progress_bar() | ||
self.disable_buttons() | ||
self.generate() | ||
|
||
def get_seed(self): | ||
random_seed = self.ui.random_seed.isChecked() | ||
seed = random.randint(0, 1000000) if random_seed else int(self.ui.seed.toPlainText()) | ||
self.ui.seed.setPlainText(str(seed)) | ||
return seed | ||
|
||
def generate(self): | ||
action = "generate" | ||
userinput = " ".join([ | ||
self.ui.generated_text.toPlainText(), | ||
self.ui.prefix.toPlainText(), | ||
self.ui.prompt.toPlainText(), | ||
]) | ||
self.client.message = { | ||
"action": "llm", | ||
"type": action, | ||
"data": { | ||
"user_input": userinput, | ||
"username": self.conversation.username, | ||
"botname": self.conversation.botname, | ||
"seed": self.seed, | ||
"conversation": self, | ||
"properties": self.conversation.properties, | ||
} | ||
} | ||
print(prompt) | ||
|
||
def start_progress_bar(self): | ||
self.ui.progressBar.setValue(0) | ||
self.ui.progressBar.setRange(0, 0) | ||
|
||
def stop_progress_bar(self): | ||
self.ui.progressBar.setRange(0, 100) | ||
self.ui.progressBar.setValue(0) | ||
|
||
def enable_buttons(self): | ||
pass | ||
|
||
def disable_buttons(self): | ||
pass | ||
|
||
@pyqtSlot(dict) | ||
def process_response(self, response): | ||
pass | ||
|
||
def __init__(self, *args, **kwargs): | ||
self.client = kwargs.pop("client") | ||
self.parent = kwargs.pop("parent") | ||
super().__init__(*args, **kwargs) | ||
if self.client is None: | ||
self.initialize_offline_client() | ||
self.conversation = ChatAIConversation(client=self.client) | ||
self.client.tqdm_var.my_signal.connect(self.tqdm_callback) | ||
self.client.message_var.my_signal.connect(self.message_handler) | ||
self.client.error_var.my_signal.connect(self.error_handler) | ||
self.settings_manager = SettingsManager(app=self) | ||
self.response_signal.connect(self.process_response) | ||
self.message_signal.connect(self.message_received) | ||
self.seed = random.randint(0, 100000) | ||
self.load_template() | ||
self.center() | ||
self.ui.show() | ||
self.ui.closeEvent = self.handle_quit | ||
self.initialize_form() | ||
# self.exec() | ||
|
||
def handle_quit(self, *args, **kwargs): | ||
pass | ||
|
||
def initialize_form(self): | ||
pass | ||
|
||
def load_template(self): | ||
self.ui = uic.loadUi(f"pyqt/{self.template}.ui") | ||
# self.ui.setWindowIcon(QIcon('./assets/icon.png')) | ||
|
||
def tqdm_callback(self, *args, **kwargs): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.