From e2e4d1153565f0bb1a4d0ed72cee87c9d72a84c5 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Mon, 15 Apr 2024 08:32:47 +0000 Subject: [PATCH] fix admin --- server/controllers/admin.py | 47 ++++++++++++++++++++++++++++--------- server/models/document.py | 2 +- server/nlp/embeddings.py | 9 +++---- 3 files changed, 42 insertions(+), 16 deletions(-) diff --git a/server/controllers/admin.py b/server/controllers/admin.py index 54223d865..25a60d6f3 100644 --- a/server/controllers/admin.py +++ b/server/controllers/admin.py @@ -6,6 +6,7 @@ import pandas as pd from apiflask import APIBlueprint from flask import request +from sqlalchemy import select from server import db from server.models.document import Document @@ -30,7 +31,9 @@ def upload_text(): def delete_text(): """POST /admin/delete_document""" data = request.form - document = Document.query.get(data["id"]) + document = db.session.execute( + select(Document).where(Document.id == data["id"]) + ).scalar() if document is None: return {"error": "Document not found"}, 404 if document.response_count > 0: @@ -47,7 +50,9 @@ def delete_text(): def update_text(): """POST /admin/edit_document""" data = request.form - document = Document.query.get(data["id"]) + document = db.session.execute( + select(Document).where(Document.id == data["id"]) + ).scalar() if document is None: return {"error": "Document not found"}, 404 document.question = data["question"] @@ -61,14 +66,22 @@ def update_text(): @admin.route("/get_documents", methods=["GET"]) def get_all(): """GET /admin/get_documents""" - documents = Document.query.order_by(Document.id.desc()).all() + documents = ( + db.session.execute(select(Document).order_by(Document.id.desc())) + .scalars() + .all() + ) return [document.map() for document in documents] @admin.route("/update_embeddings", methods=["GET"]) def update_embeddings(): """GET /admin/update_embeddings""" - documents = Document.query.order_by(Document.id.desc()).all() + documents = ( + db.session.execute(select(Document).order_by(Document.id.desc())) + .scalars() + .all() + ) docs = [document.map() for document in documents if not document.to_delete] modified_corpus = [ @@ -107,7 +120,11 @@ def upload_json(): @admin.route("/export_json", methods=["GET"]) def export_json(): """GET /admin/export_json""" - documents = Document.query.order_by(Document.id.desc()).all() + documents = ( + db.session.execute(select(Document).order_by(Document.id.desc())) + .scalars() + .all() + ) return json.dumps( [ { @@ -132,10 +149,10 @@ def import_csv(): for _, row in df.iterrows(): document = Document( - "" if pd.isna(row["question"]) else row["question"], - row["content"], - row["source"], - "" if pd.isna(row["label"]) else row["label"], + "" if pd.isna(row["question"]) else row["question"], # type: ignore + row["content"], # type: ignore + row["source"], # type: ignore + "" if pd.isna(row["label"]) else row["label"], # type: ignore ) db.session.add(document) db.session.commit() @@ -147,7 +164,11 @@ def import_csv(): @admin.route("/export_csv", methods=["GET"]) def export_csv(): """GET /admin/export_csv""" - documents = Document.query.order_by(Document.id.desc()).all() + documents = ( + db.session.execute(select(Document).order_by(Document.id.desc())) + .scalars() + .all() + ) df = pd.DataFrame( [ { @@ -168,7 +189,11 @@ def export_csv(): def clear_documents(): """POST /admin/clear_documents""" try: - documents = Document.query.order_by(Document.id.desc()).all() + documents = ( + db.session.execute(select(Document).order_by(Document.id.desc())) + .scalars() + .all() + ) for document in documents: if document.response_count > 0: document.to_delete = True diff --git a/server/models/document.py b/server/models/document.py index b582b04e6..20e5cc9a4 100644 --- a/server/models/document.py +++ b/server/models/document.py @@ -37,7 +37,7 @@ class Document(db.Model): response_count: Mapped[int] = mapped_column(default=0, init=False) responses: Mapped[List["Response"]] = relationship( - secondary=document_response_table, back_populates="documents" + secondary=document_response_table, back_populates="documents", init=False ) def map(self): diff --git a/server/nlp/embeddings.py b/server/nlp/embeddings.py index 368988f92..1a5fbd60b 100644 --- a/server/nlp/embeddings.py +++ b/server/nlp/embeddings.py @@ -69,7 +69,7 @@ def compute_openai_embeddings(texts): embeddings = [] for i in range(len(texts)): embeddings.append( - openai.Embedding.create(input=texts[i], model=embedding_model) + openai.embeddings.create(input=texts[i], model=embedding_model) .data[0] .embedding ) @@ -81,7 +81,7 @@ def compute_embeddings(): print("computing embeddings...") # get keys, questions, content - keys = sorted(client.keys("documents:*")) + keys = sorted(client.keys("documents:*")) # type: ignore questions = client.json().mget(keys, "$.question") content = client.json().mget(keys, "$.content") @@ -89,7 +89,8 @@ def compute_embeddings(): # compute embeddings question_and_content = [ - questions[i][0] + " " + content[i][0] for i in range(len(questions)) + questions[i][0] + " " + content[i][0] + for i in range(len(questions)) # type: ignore ] # embeddings = embedder.encode(question_and_content).astype(np.float32).tolist() @@ -222,7 +223,7 @@ def queries(query, queries: list[str]) -> list[dict]: query, {"query_vector": np.array(encoded_query, dtype=np.float32).tobytes()}, ) - .docs + .docs # type: ignore ) query_result = [] for doc in result_docs: