-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #262 from biolab/chat_constructor
ChatGPT Constructor: Add widget
- Loading branch information
Showing
1 changed file
with
200 additions
and
0 deletions.
There are no files selected for viewing
200 changes: 200 additions & 0 deletions
200
orangecontrib/prototypes/widgets/owchatgptconstructor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
from typing import Optional | ||
|
||
from AnyQt.QtCore import Signal, Qt | ||
from AnyQt.QtGui import QFocusEvent | ||
from AnyQt.QtWidgets import QLineEdit, QTextEdit | ||
|
||
import openai | ||
import tiktoken | ||
|
||
from Orange.data import Table, StringVariable | ||
from Orange.data.util import get_unique_names | ||
from Orange.widgets import gui | ||
from Orange.widgets.credentials import CredentialManager | ||
from Orange.widgets.utils.itemmodels import DomainModel | ||
from Orange.widgets.settings import Setting, DomainContextHandler, \ | ||
ContextSetting | ||
from Orange.widgets.widget import OWWidget, Input, Output, Msg | ||
|
||
MODELS = ["gpt-3.5-turbo", "gpt-4"] | ||
|
||
|
||
def run_gpt( | ||
api_key: str, | ||
model: str, | ||
text: str, | ||
prompt_start: str, | ||
prompt_end: str | ||
) -> str: | ||
openai.api_key = api_key | ||
enc = tiktoken.encoding_for_model(model) | ||
|
||
text = enc.decode(enc.encode(text)[:3500]) | ||
content = f"{prompt_start}\n{text}.\n{prompt_end}" | ||
response = openai.ChatCompletion.create( | ||
model=model, | ||
messages=[ | ||
{"role": "system", "content": "You are a helpful assistant."}, | ||
{"role": "user", "content": content}, | ||
] | ||
) | ||
return response.choices[0].message.content | ||
|
||
|
||
class TextEdit(QTextEdit): | ||
sigEditFinished = Signal() | ||
|
||
def focusOutEvent(self, ev: QFocusEvent): | ||
self.sigEditFinished.emit() | ||
super().focusOutEvent(ev) | ||
|
||
|
||
class OWChatGPTConstructor(OWWidget): | ||
name = "ChatGPT Constructor" | ||
description = "Construct a text field using a ChatGPT." | ||
icon = "icons/chatgpt.svg" | ||
priority = 11 | ||
keywords = ["text", "gpt"] | ||
|
||
settingsHandler = DomainContextHandler() | ||
access_key = "" | ||
model_index = Setting(0) | ||
text_var = ContextSetting(None) | ||
prompt_start = Setting("") | ||
prompt_end = Setting("") | ||
cache = Setting({}) | ||
auto_apply = Setting(False) | ||
|
||
want_main_area = False | ||
|
||
class Inputs: | ||
data = Input("Data", Table) | ||
|
||
class Outputs: | ||
data = Output("Data", Table) | ||
|
||
class Warning(OWWidget.Warning): | ||
missing_key = Msg("The Access key is missing.") | ||
missing_str_var = Msg("Data has no text variables.") | ||
|
||
class Error(OWWidget.Error): | ||
unknown_error = Msg("An error occurred while creating an answer.\n{}") | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.__data: Optional[Table] = None | ||
self.__text_var_model = DomainModel(valid_types=(StringVariable,)) | ||
self.__start_text_edit: QTextEdit = None | ||
self.__end_text_edit: QTextEdit = None | ||
|
||
self.__cm = CredentialManager("Ask") | ||
self.access_key = self.__cm.access_key or "" | ||
|
||
self.setup_gui() | ||
|
||
def setup_gui(self): | ||
box = gui.vBox(self.controlArea, "Model") | ||
edit: QLineEdit = gui.lineEdit(box, self, "access_key", "API Key:", | ||
orientation=Qt.Horizontal, | ||
callback=self.__on_access_key_changed) | ||
edit.setEchoMode(QLineEdit.Password) | ||
gui.comboBox(box, self, "model_index", label="Model:", | ||
orientation=Qt.Horizontal, | ||
items=MODELS, callback=self.commit.deferred) | ||
|
||
gui.comboBox(self.controlArea, self, "text_var", "Data", | ||
"Text variable:", model=self.__text_var_model, | ||
orientation=Qt.Horizontal, callback=self.commit.deferred) | ||
|
||
box = gui.vBox(self.controlArea, "Prompt") | ||
gui.label(box, self, "Start:") | ||
self.__start_text_edit = TextEdit(tabChangesFocus=True) | ||
self.__start_text_edit.setText(self.prompt_start) | ||
self.__start_text_edit.sigEditFinished.connect( | ||
self.__on_start_text_edit_changed) | ||
box.layout().addWidget(self.__start_text_edit) | ||
gui.label(box, self, "End:") | ||
self.__end_text_edit = TextEdit(tabChangesFocus=True) | ||
self.__end_text_edit.setText(self.prompt_end) | ||
self.__end_text_edit.sigEditFinished.connect( | ||
self.__on_end_text_edit_changed) | ||
box.layout().addWidget(self.__end_text_edit) | ||
|
||
gui.rubber(self.controlArea) | ||
|
||
gui.auto_apply(self.buttonsArea, self, "auto_apply") | ||
|
||
def __on_access_key_changed(self): | ||
self.__cm.access_key = self.access_key | ||
self.commit.deferred() | ||
|
||
def __on_start_text_edit_changed(self): | ||
prompt_start = self.__start_text_edit.toPlainText() | ||
if self.prompt_start != prompt_start: | ||
self.prompt_start = prompt_start | ||
self.commit.deferred() | ||
|
||
def __on_end_text_edit_changed(self): | ||
prompt_end = self.__end_text_edit.toPlainText() | ||
if self.prompt_end != prompt_end: | ||
self.prompt_end = prompt_end | ||
self.commit.deferred() | ||
|
||
@Inputs.data | ||
def set_data(self, data: Table): | ||
self.closeContext() | ||
self.clear_messages() | ||
self.__data = data | ||
self.__text_var_model.set_domain(data.domain if data else None) | ||
self.text_var = self.__text_var_model[0] if self.__text_var_model \ | ||
else None | ||
if data and not self.__text_var_model: | ||
self.Warning.missing_str_var() | ||
self.openContext(data) | ||
self.commit.deferred() | ||
|
||
@gui.deferred | ||
def commit(self): | ||
self.Warning.missing_key.clear() | ||
if self.access_key == "": | ||
self.Warning.missing_key() | ||
|
||
answers = self._get_answers() | ||
|
||
data = self.__data | ||
if data is not None: | ||
name = get_unique_names(data.domain, "Text") | ||
var = StringVariable(name) | ||
data = data.add_column(var, answers, to_metas=True) | ||
|
||
self.Outputs.data.send(data) | ||
|
||
def _get_answers(self) -> str: | ||
self.Error.unknown_error.clear() | ||
if not self.__data or not self.text_var or not self.access_key: | ||
return "" | ||
|
||
texts = self.__data.get_column(self.text_var) | ||
answers = [] | ||
for text in texts: | ||
args = (text.strip(), | ||
self.prompt_start.strip(), | ||
self.prompt_end.strip()) | ||
if args in self.cache: | ||
answer = self.cache[args] | ||
else: | ||
try: | ||
answer = run_gpt(self.access_key, MODELS[self.model_index], | ||
*args) | ||
self.cache[args] = answer | ||
except Exception as ex: | ||
answer = ex | ||
self.Error.unknown_error(ex) | ||
answers.append(answer) | ||
return answers | ||
|
||
|
||
if __name__ == "__main__": | ||
from Orange.widgets.utils.widgetpreview import WidgetPreview | ||
|
||
WidgetPreview(OWChatGPTConstructor).run(set_data=Table("zoo")) |