Skip to content

Commit

Permalink
fix admin
Browse files: browse the repository at this point in the history
  • Loading branch information
azliu0 committed Apr 15, 2024
1 parent cfa0db3 commit e2e4d11
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 16 deletions.
47 changes: 36 additions & 11 deletions server/controllers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
from apiflask import APIBlueprint
from flask import request
from sqlalchemy import select

from server import db
from server.models.document import Document
Expand All @@ -30,7 +31,9 @@ def upload_text():
def delete_text():
"""POST /admin/delete_document"""
data = request.form
document = Document.query.get(data["id"])
document = db.session.execute(
select(Document).where(Document.id == data["id"])
).scalar()
if document is None:
return {"error": "Document not found"}, 404
if document.response_count > 0:
Expand All @@ -47,7 +50,9 @@ def delete_text():
def update_text():
"""POST /admin/edit_document"""
data = request.form
document = Document.query.get(data["id"])
document = db.session.execute(
select(Document).where(Document.id == data["id"])
).scalar()
if document is None:
return {"error": "Document not found"}, 404
document.question = data["question"]
Expand All @@ -61,14 +66,22 @@ def update_text():
@admin.route("/get_documents", methods=["GET"])
def get_all():
"""GET /admin/get_documents"""
documents = Document.query.order_by(Document.id.desc()).all()
documents = (
db.session.execute(select(Document).order_by(Document.id.desc()))
.scalars()
.all()
)
return [document.map() for document in documents]


@admin.route("/update_embeddings", methods=["GET"])
def update_embeddings():
"""GET /admin/update_embeddings"""
documents = Document.query.order_by(Document.id.desc()).all()
documents = (
db.session.execute(select(Document).order_by(Document.id.desc()))
.scalars()
.all()
)

docs = [document.map() for document in documents if not document.to_delete]
modified_corpus = [
Expand Down Expand Up @@ -107,7 +120,11 @@ def upload_json():
@admin.route("/export_json", methods=["GET"])
def export_json():
"""GET /admin/export_json"""
documents = Document.query.order_by(Document.id.desc()).all()
documents = (
db.session.execute(select(Document).order_by(Document.id.desc()))
.scalars()
.all()
)
return json.dumps(
[
{
Expand All @@ -132,10 +149,10 @@ def import_csv():

for _, row in df.iterrows():
document = Document(
"" if pd.isna(row["question"]) else row["question"],
row["content"],
row["source"],
"" if pd.isna(row["label"]) else row["label"],
"" if pd.isna(row["question"]) else row["question"], # type: ignore
row["content"], # type: ignore
row["source"], # type: ignore
"" if pd.isna(row["label"]) else row["label"], # type: ignore
)
db.session.add(document)
db.session.commit()
Expand All @@ -147,7 +164,11 @@ def import_csv():
@admin.route("/export_csv", methods=["GET"])
def export_csv():
"""GET /admin/export_csv"""
documents = Document.query.order_by(Document.id.desc()).all()
documents = (
db.session.execute(select(Document).order_by(Document.id.desc()))
.scalars()
.all()
)
df = pd.DataFrame(
[
{
Expand All @@ -168,7 +189,11 @@ def export_csv():
def clear_documents():
"""POST /admin/clear_documents"""
try:
documents = Document.query.order_by(Document.id.desc()).all()
documents = (
db.session.execute(select(Document).order_by(Document.id.desc()))
.scalars()
.all()
)
for document in documents:
if document.response_count > 0:
document.to_delete = True
Expand Down
2 changes: 1 addition & 1 deletion server/models/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Document(db.Model):
response_count: Mapped[int] = mapped_column(default=0, init=False)

responses: Mapped[List["Response"]] = relationship(
secondary=document_response_table, back_populates="documents"
secondary=document_response_table, back_populates="documents", init=False
)

def map(self):
Expand Down
9 changes: 5 additions & 4 deletions server/nlp/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def compute_openai_embeddings(texts):
embeddings = []
for i in range(len(texts)):
embeddings.append(
openai.Embedding.create(input=texts[i], model=embedding_model)
openai.embeddings.create(input=texts[i], model=embedding_model)
.data[0]
.embedding
)
Expand All @@ -81,15 +81,16 @@ def compute_embeddings():
print("computing embeddings...")

# get keys, questions, content
keys = sorted(client.keys("documents:*"))
keys = sorted(client.keys("documents:*")) # type: ignore
questions = client.json().mget(keys, "$.question")
content = client.json().mget(keys, "$.content")

assert len(questions) == len(content)

# compute embeddings
question_and_content = [
questions[i][0] + " " + content[i][0] for i in range(len(questions))
questions[i][0] + " " + content[i][0]
for i in range(len(questions)) # type: ignore
]

# embeddings = embedder.encode(question_and_content).astype(np.float32).tolist()
Expand Down Expand Up @@ -222,7 +223,7 @@ def queries(query, queries: list[str]) -> list[dict]:
query,
{"query_vector": np.array(encoded_query, dtype=np.float32).tobytes()},
)
.docs
.docs # type: ignore
)
query_result = []
for doc in result_docs:
Expand Down

0 comments on commit e2e4d11

Please sign in to comment.