diff --git a/src/web/content/dynamic/config.py b/src/web/content/dynamic/config.py
index 3c3a427..7a25e43 100644
--- a/src/web/content/dynamic/config.py
+++ b/src/web/content/dynamic/config.py
@@ -4,6 +4,7 @@
 DB_NAME = 'searchmysitedb'
 DB_USER = 'postgres'
 DB_HOST = 'db'
+TORCHSERVE = 'http://models:8080/'
 
 # POSTGRES_PASSWORD is normally set by docker from the .env file
 # The .env file is normally in the main application root (searchmysite/src/)
diff --git a/src/web/content/dynamic/searchmysite/api/searchapi.py b/src/web/content/dynamic/searchmysite/api/searchapi.py
index 08e4b4b..4977a2a 100644
--- a/src/web/content/dynamic/searchmysite/api/searchapi.py
+++ b/src/web/content/dynamic/searchmysite/api/searchapi.py
@@ -10,7 +10,8 @@
 from searchmysite.db import get_db
 import config
 import searchmysite.solr
-from searchmysite.searchutils import check_if_api_enabled_for_domain, get_search_params, get_filter_queries, get_start, do_search, get_no_of_results, get_links, get_display_results
+from searchmysite.searchutils import check_if_api_enabled_for_domain, get_search_params, get_filter_queries, get_start, do_search, get_no_of_results, get_links, get_display_results, do_vector_search, get_query_vector_string
+import requests
 
 bp = Blueprint('searchapi', __name__)
 
@@ -189,6 +190,96 @@ def feed_newest(format, search_type='newest'):
     return error_response(404, 'xml', message="/{}/search/browse/ not found".format(format))
 
 
+
+# Vector search API
+# -----------------
+#
+# Full URL:
+# /api/v1/knnsearch/?q=<query>&domain=<domain>
+# e.g. /api/v1/knnsearch/?q=What%20is%20vector%20search&domain=*
+#
+# Parameters:
+# <query> is the query text
+# <domain> is the domain to search, or * for all domains
+#
+# Responses:
+# Results:
+# [
+#   {'id': 'https://url/!chunk010',
+#   'content_chunk_text': '...',
+#   'url': 'https://url/',
+#   'score': 0.8489073},
+#   {...}, ...
+# ]
+#
+#
+@bp.route('/knnsearch/', methods=['GET', 'POST'])
+def vector_search():
+    params = get_search_params(request, 'search')
+    query = params['q']
+    domain = params['domain']
+    query_vector_string = get_query_vector_string(query)
+    response = do_vector_search(query_vector_string, domain)
+    results = response['response']['docs']
+    #current_app.logger.debug('results: {}'.format(results))
+    return results
+
+
+# LLM Vector search API
+# ---------------------
+#
+# Full URL:
+# /api/v1/predictions/llama2/?q=<query>&prompt=<prompt>&context=<context>
+# e.g. /api/v1/predictions/llama2/?q=How%20long%20does%20it%20take%20to%20climb%20Ben%20Nevis&prompt=qa&context=it%20took%204%20hours%20to%20climb%20ben%20nevis
+#
+# Parameters:
+# <query> is the query text
+# <prompt> indicates the prompt template to use, e.g. "qa" for question answering
+# <context> is the context text for the prompt template
+#
+# Responses:
+# Results:
+# Text
+#
+
+@bp.route('/predictions/llama2', methods=['GET', 'POST'])
+def predictions():
+    # Get data from request
+    query = request.args.get('q', '')
+    prompt_type = request.args.get('prompt', 'qa')
+    context = request.args.get('context', '')
+    # Build LLM prompt
+    llm_prompt = get_llm_prompt(query, prompt_type, context)
+    llm_data = get_llm_data(llm_prompt)
+    # Do request
+    response = do_llm_prediction(llm_prompt, llm_data)
+    return make_response(jsonify(response))
+
+def get_llm_prompt(query, prompt_type, context):
+    if prompt_type == 'qa':
+        prompt = "[INST] <<SYS>>Answer the question based on the context below.<</SYS>> \n [context]: {} \n [question]: {} [/INST]".format(context, query)
+    else:
+        prompt = "[INST] <<SYS>> You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.<</SYS>>{}[/INST]".format(query)
+    return prompt
+
+def get_llm_data(prompt):
+    data = json.dumps(
+        {
+            "prompt": prompt,
+            "max_tokens": 100,
+            "top_p": 0.95,
+            "temperature": 0.8,
+        }
+    )
+    return data
+
+def do_llm_prediction(prompt, data):
+    url = config.TORCHSERVE + "predictions/llama2"
+    headers = {"Content-type": "application/json", "Accept": "text/plain"}
+    response = requests.post(url=url, data=data, headers=headers)
+    cleaned_response = response.text.removeprefix(prompt)
+    return cleaned_response
+
+
 # Utilities
 
 def convert_results_to_xml_string(results, params, no_of_results_for_display, links, search_type):
diff --git a/src/web/content/dynamic/searchmysite/searchutils.py b/src/web/content/dynamic/searchmysite/searchutils.py
index d0840aa..06776e2 100644
--- a/src/web/content/dynamic/searchmysite/searchutils.py
+++ b/src/web/content/dynamic/searchmysite/searchutils.py
@@ -85,6 +85,9 @@ def get_search_params(request, search_type):
     except:
         resultsperpage = default_results_per_page
     search_params['resultsperpage'] = resultsperpage
+    # domain (currently just used by the vector search API)
+    domain = request.args.get('domain', '*')
+    search_params['domain'] = domain
     #current_app.logger.debug('get_search_params: {}'.format(search_params))
     return search_params
 
@@ -156,10 +159,11 @@ def do_search(query_params, query_facets, params, start, default_filter_queries,
 # Need double curly braces to escape the curly braces.
 # Field in schema is content_chunk_vector.
 # Vector has to be a string representation of a list like "[1.0, 2.0, 3.0, 4.0]"
-def do_vector_search(query_vector_string):
+def do_vector_search(query_vector_string, domain):
     solr_select_params_vector_search = {
         "q": '{{!knn f=content_chunk_vector topK=4}}{}'.format(query_vector_string),
         "fl": ["id", "url", "content_chunk_text", "score"],
+        "fq": "domain:{}".format(domain)
     }
     solr_search = {}
     solr_search['params'] = solr_select_params_vector_search