Skip to content

Commit

Permalink
Merge pull request #81 from calpoly-csai/cameron_t
Browse files Browse the repository at this point in the history
working ask endpoint
  • Loading branch information
cameron-toy authored Feb 28, 2020
2 parents 02fe6db + ffd3d98 commit b5bde14
Show file tree
Hide file tree
Showing 34 changed files with 1,390 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,4 @@ credentials.json
folder_id.txt
settings.yaml
.export_env_vars
auth.json
172 changes: 168 additions & 4 deletions QA.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,39 @@
from typing import Callable, Dict, Any
import functools
import re
from Entity.Courses import Courses
from Entity.Locations import Locations
from Entity.Professors import Professors
from Entity.Clubs import Clubs
from Entity.Sections import Sections
from database_wrapper import NimbusMySQLAlchemy
from pandas import read_csv

Extracted_Vars = Dict[str, Any]
DB_Data = Dict[str, Any]
DB_Query = Callable[[Extracted_Vars], DB_Data]
Answer_Formatter = Callable[[Extracted_Vars, DB_Data], str]


tag_lookup = {
'PROF': Professors,
'CLUB': Clubs,
'COURSE': Courses,
'SECRET_HIDEOUT': Locations,
'SECTION': Sections
}

# TODO: Initialize this somewhere else. Currently here because of _get_property()
# Move into the Nimbus class below if possible.
db = NimbusMySQLAlchemy()


class QA:
"""
A class for wrapping functions used to answer a question.
"""

def __init__(self, q_format, db, db_query, format_function):
def __init__(self, q_format, db_query, format_answer):
"""
Args:
q_format (str): Question format string
Expand All @@ -24,20 +46,22 @@ def __init__(self, q_format, db, db_query, format_function):
data retrieved from the database--and returns a str.
"""
self.q_format = q_format
self.db = db
self.db_query = db_query
self.format_function = format_function
self.format_answer = format_answer

def _get_data_from_db(self, extracted_vars):
return self.db_query(extracted_vars)

def _format_answer(self, extracted_vars, db_data):
return self.format_function(extracted_vars, db_data)
return self.format_answer(extracted_vars, db_data)

def answer(self, extracted_vars):
db_data = self._get_data_from_db(extracted_vars)
return self._format_answer(extracted_vars, db_data)

def __repr__(self):
return self.q_format

def __hash__(self):
return hash(self.q_format)

Expand All @@ -51,3 +75,143 @@ def create_qa_mapping(qa_list):
qa_list (list(QA))
"""
return {qa.q_format: qa for qa in qa_list}


# def _string_sub(a_format, extracted_vars, db_data):
# """
# Substitutes values in a string based off the contents of the extracted_vars
# and db_data dictionaries. Keys from the dictionaries in the a_format string
# will be replaced with their associated value.
#
# Example input/output:
# a_format: "{professor1_ex}'s office is {office1_db}."
# extracted_vars: {"professor1": "Dr. Khosmood"}
# db_data: {"office1": "14-213"}
#
# "Dr. Khosmood's office is 14-213"
#
# Args:
# a_format (str): String to be formatted. Variables to be substituted should
# be in curly braces and end in "_ex" for keys from extracted_vars and "_db"
# for keys from db_data.
# extracted_vars (Extracted_Vars)
# db_data (Db_Data)
#
# Returns:
# A formatted answer string
# """
# # Adds "_ex" to the end of keys in extracted_vars
# extracted_vars = {
# k + "_ex": v for k, v in extracted_vars.items()
# }
# # Adds "_db" to the end of keys in db_data
# db_data = {
# k + "_db": v for k, v in db_data.items()
# }
# return a_format.format(**extracted_vars, **db_data)
#
#
# def _single_var_string_sub(a_format, extracted_vars, db_data):
# """
# Like _string_sub for cases where there's max one item in either dict
#
# Example input/output:
# a_format: "{ex}'s office is {db}."
# extracted_vars: {"professor1": "Dr. Khosmood"}
# db_data: {"office1": "14-213"}
#
# "Dr. Khosmood's office is 14-213"
#
# Args:
# a_format (str): String to be formatted. {ex} will be substituted with
# the value from extracted_vars and {db} will be substituted with the
# value from db_data
# extracted_vars (Extracted_Vars)
# db_data (Db_Data)
#
# Returns:
# A formatted answer string
# """
# # Gets value from a dictionary with a single item
# ex_val = next(iter(extracted_vars.values())) if extracted_vars else ''
# db_val = next(iter(db_data.values())) if db_data else ''
# return a_format.format(ex=ex_val, db=db_val)
#
#
# def string_sub(a_format):
# return functools.partial(_string_sub, a_format)
#
#
# def single_var_string_sub(a_format):
# return functools.partial(_single_var_string_sub, a_format)


def _string_sub(a_format, extracted_info, db_data):
if db_data is None:
return None
else:
return a_format.format(ex=extracted_info['normalized entity'], db=db_data)


def string_sub(a_format):
return functools.partial(_string_sub, a_format)


def _get_property(prop, extracted_info):
ent_string = extracted_info["normalized entity"]
ent = tag_lookup[extracted_info['tag']]
try:
value = db.get_property_from_entity(prop=prop, entity=ent, identifier=ent_string)
except IndexError:
return None
else:
return value


def get_property(prop):
return functools.partial(_get_property, prop)


def _yes_no(a_format, pred, extracted_info, db_data):
if pred is None:
result = 'Yes' if db_data else 'No'
elif type(pred) == str:
result = 'Yes' if re.search(pred, db_data) else 'No'
else:
result = 'Yes' if pred(db_data) else 'No'
return a_format.format(y_n=result, yes_no=result, ex=extracted_info['normalized entity'])


def yes_no(a_format, pred=None):
return functools.partial(_yes_no, a_format, pred)


def generate_fact_QA(csv):
df = read_csv(csv)
text_in_brackets = r'\[[^\[\]]*\]'
qa_objs = []
for i in range(len(df)):
q = df['question_format'][i]
a = df['answer_format'][i]
matches = re.findall(text_in_brackets, a)
extracted = None
if len(matches) == 1:
db_data = matches[0]
elif '..' in matches[1]:
db_data = matches[1]
extracted = matches[0]
else:
db_data = matches[0]
extracted = matches[1]
prop = db_data.split('..', 1)[1][0:-1]
a = a.replace(db_data, '{db}')
if extracted is not None:
a = a.replace(extracted, '{ex}')
o = QA(
q_format=q,
db_query=get_property(prop),
format_answer=string_sub(a)
)
qa_objs.append(o)

return qa_objs
18 changes: 14 additions & 4 deletions database_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@
]
UNION_PROPERTIES = Union[ProfessorsProperties]

default_tag_column_dict = {
Calendars: {"date"},
Courses: {"courseName", "courseNum", "dept"},
Locations: {"building_number", "name"},
Professors: {"firstName", "lastName"},
Clubs: {"club_name"},
Sections: {"section_name"}
}


class BadDictionaryKeyError(Exception):
"""Raised when the given JSON/dict is missing some required fields.
Expand Down Expand Up @@ -379,7 +388,8 @@ def full_fuzzy_match(self, tag_value, identifier):


def get_property_from_entity(
self, prop: str, entity: UNION_ENTITIES, identifier: str, tag_column_map: dict
self, prop: str, entity: UNION_ENTITIES, identifier: str,
tag_column_map: dict = default_tag_column_dict
):
"""
This function implements the abstractmethod to get a column of values
Expand Down Expand Up @@ -422,11 +432,11 @@ def get_property_from_entity(
total_similarity = 0
tags = []
for tag_prop in tag_props:
total_similarity += self.full_fuzzy_match(row.__dict__[tag_prop], identifier)
tags.append(row.__dict__[tag_prop])
total_similarity += self.full_fuzzy_match(str(row.__dict__[tag_prop]), identifier)
tags.append(str(row.__dict__[tag_prop]))

if total_similarity > MATCH_THRESHOLD:
results.append((total_similarity, tags, row.__dict__[prop]))
results.append((total_similarity, tags, str(row.__dict__[prop])))

if len(results) < 1:
return None
Expand Down
6 changes: 5 additions & 1 deletion flask_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from modules.formatters import WakeWordFormatter
from modules.validators import WakeWordValidator, WakeWordValidatorError

from nimbus import Nimbus

BAD_REQUEST = 400
SUCCESS = 200

Expand All @@ -24,6 +26,8 @@
app = Flask(__name__)
CORS(app)

# TODO: Initialize this somewhere else.
nimbus = Nimbus()

@app.route('/', methods=['GET', 'POST'])
def hello():
Expand Down Expand Up @@ -60,7 +64,7 @@ def handle_question():
return "request body should include the question", BAD_REQUEST

response = {
"answer": "answer of <<{}>>".format(question),
"answer": nimbus.answer_question(question)
}

if "session" in request_body:
Expand Down
31 changes: 31 additions & 0 deletions nimbus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from QA import create_qa_mapping, generate_fact_QA
from nimbus_nlp.NIMBUS_NLP import NIMBUS_NLP


class Nimbus:

def __init__(self):
self.qa_dict = create_qa_mapping(
generate_fact_QA("q_a_pairs.csv")
)

def answer_question(self, question):
ans_dict = NIMBUS_NLP.predict_question(question)
print(ans_dict)
try:
qa = self.qa_dict[ans_dict["question class"]]
except KeyError:
return "I'm sorry, I don't understand. Please try another question."
else:
answer = qa.answer(ans_dict)
if answer is None:
return("I'm sorry, I understand your question but was unable to find an answer. "
"Please try another question.")
else:
return answer

if __name__ == "__main__":
nimbus = Nimbus()
while True:
question = input("Enter a question: ")
print(nimbus.answer_question(question))
Loading

0 comments on commit b5bde14

Please sign in to comment.