Added API endpoints for both the vector search and LLM predictions, for #96 Chat with your website functionality
m-i-l committed Dec 3, 2023
1 parent 7a61544 commit 8816694
Showing 3 changed files with 98 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/web/content/dynamic/config.py
@@ -4,6 +4,7 @@
DB_NAME = 'searchmysitedb'
DB_USER = 'postgres'
DB_HOST = 'db'
TORCHSERVE = 'http://models:8080/'

# POSTGRES_PASSWORD is normally set by docker from the .env file
# The .env file is normally in the main application root (searchmysite/src/)
93 changes: 92 additions & 1 deletion src/web/content/dynamic/searchmysite/api/searchapi.py
@@ -10,7 +10,8 @@
from searchmysite.db import get_db
import config
import searchmysite.solr
from searchmysite.searchutils import check_if_api_enabled_for_domain, get_search_params, get_filter_queries, get_start, do_search, get_no_of_results, get_links, get_display_results
from searchmysite.searchutils import check_if_api_enabled_for_domain, get_search_params, get_filter_queries, get_start, do_search, get_no_of_results, get_links, get_display_results, do_vector_search, get_query_vector_string
import requests


bp = Blueprint('searchapi', __name__)
@@ -189,6 +190,96 @@ def feed_newest(format, search_type='newest'):
    return error_response(404, 'xml', message="/{}/search/browse/ not found".format(format))


# Vector search API
# -----------------
#
# Full URL:
# /api/v1/knnsearch/?q=<query>&domain=<domain>
# e.g. /api/v1/knnsearch/?q=What%20is%20vector%20search&domain=*
#
# Parameters:
# <query> is the query text
# <domain> is the domain to search, or * for all domains
#
# Responses:
# Results:
# [
# {'id': 'https://url/!chunk010',
# 'content_chunk_text': '...',
# 'url': 'https://url/',
# 'score': 0.8489073},
# {...}, ...
# ]
#
#
@bp.route('/knnsearch/', methods=['GET', 'POST'])
def vector_search():
    params = get_search_params(request, 'search')
    query = params['q']
    domain = params['domain']
    query_vector_string = get_query_vector_string(query)
    response = do_vector_search(query_vector_string, domain)
    results = response['response']['docs']
    #current_app.logger.debug('results: {}'.format(results))
    return results
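
# Illustrative only (not part of this commit): a minimal client sketch showing
# how the vector search endpoint above might be called. The localhost host and
# port are assumptions for a local deployment.
#
#   import requests
#   r = requests.get('http://localhost:8080/api/v1/knnsearch/',
#                    params={'q': 'What is vector search', 'domain': '*'})
#   for doc in r.json():
#       print(doc['score'], doc['url'])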


# LLM Vector search API
# ---------------------
#
# Full URL:
# /api/v1/predictions/llama2?q=<query>&prompt=<prompt>&context=<context>
# e.g. /api/v1/predictions/llama2?q=How%20long%20does%20it%20take%20to%20climb%20Ben%20Nevis&prompt=qa&context=it%20took%204%20hours%20to%20climb%20ben%20nevis
#
# Parameters:
# <query> is the query text
# <prompt> indicates the prompt template to use, e.g. "qa" for question answering
# <context> is the context text for the prompt template
#
# Responses:
# Results:
# Text
#

@bp.route('/predictions/llama2', methods=['GET', 'POST'])
def predictions():
    # Get data from request
    query = request.args.get('q', '')
    prompt_type = request.args.get('prompt', 'qa')
    context = request.args.get('context', '')
    # Build LLM prompt
    llm_prompt = get_llm_prompt(query, prompt_type, context)
    llm_data = get_llm_data(llm_prompt)
    # Do request
    response = do_llm_prediction(llm_prompt, llm_data)
    return make_response(jsonify(response))

def get_llm_prompt(query, prompt_type, context):
    if prompt_type == 'qa':
        prompt = "<s>[INST] <<SYS>>Answer the question based on the context below.<</SYS>> \n [context]: {} \n [question]: {} [/INST]".format(context, query)
    else:
        prompt = "[INST] <<SYS>> You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.<</SYS>>{}[/INST]".format(query)
    return prompt
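
# For illustration, with the 'qa' template above and the example inputs from
# the endpoint comment block (query 'How long does it take to climb Ben Nevis',
# context 'it took 4 hours to climb ben nevis'), get_llm_prompt returns
# (shown across lines where the template inserts \n):
#   <s>[INST] <<SYS>>Answer the question based on the context below.<</SYS>>
#    [context]: it took 4 hours to climb ben nevis
#    [question]: How long does it take to climb Ben Nevis [/INST]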

def get_llm_data(prompt):
    data = json.dumps(
        {
            "prompt": prompt,
            "max_tokens": 100,
            "top_p": 0.95,
            "temperature": 0.8,
        }
    )
    return data
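
# For illustration, get_llm_data('<s>[INST] ... [/INST]') returns the JSON
# string '{"prompt": "<s>[INST] ... [/INST]", "max_tokens": 100, "top_p": 0.95,
# "temperature": 0.8}', which do_llm_prediction below posts to TorchServe.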

def do_llm_prediction(prompt, data):
    url = config.TORCHSERVE + "predictions/llama2"
    headers = {"Content-type": "application/json", "Accept": "text/plain"}
    response = requests.post(url=url, data=data, headers=headers)
    cleaned_response = response.text.removeprefix(prompt)
    return cleaned_response
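
# Illustrative only (not part of this commit): an end-to-end call to the
# predictions endpoint above. The localhost host and port are assumptions for
# a local deployment.
#
#   import requests
#   r = requests.get('http://localhost:8080/api/v1/predictions/llama2',
#                    params={'q': 'How long does it take to climb Ben Nevis',
#                            'prompt': 'qa',
#                            'context': 'it took 4 hours to climb ben nevis'})
#   print(r.json())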


# Utilities

def convert_results_to_xml_string(results, params, no_of_results_for_display, links, search_type):
6 changes: 5 additions & 1 deletion src/web/content/dynamic/searchmysite/searchutils.py
@@ -85,6 +85,9 @@ def get_search_params(request, search_type):
    except:
        resultsperpage = default_results_per_page
    search_params['resultsperpage'] = resultsperpage
    # domain (currently just used by the vector search API)
    domain = request.args.get('domain', '*')
    search_params['domain'] = domain
    #current_app.logger.debug('get_search_params: {}'.format(search_params))
    return search_params

@@ -156,10 +159,11 @@ def do_search(query_params, query_facets, params, start, default_filter_queries,
# Need double curly braces to escape the curly braces.
# Field in schema is content_chunk_vector.
# Vector has to be a string representation of a list like "[1.0, 2.0, 3.0, 4.0]"
def do_vector_search(query_vector_string):
def do_vector_search(query_vector_string, domain):
    solr_select_params_vector_search = {
        "q": '{{!knn f=content_chunk_vector topK=4}}{}'.format(query_vector_string),
        "fl": ["id", "url", "content_chunk_text", "score"],
        "fq": "domain:{}".format(domain)
    }
    solr_search = {}
    solr_search['params'] = solr_select_params_vector_search
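    # For illustration (with an assumed query vector and domain), the Solr
    # JSON request body built above would look like:
    #   {"params": {"q": "{!knn f=content_chunk_vector topK=4}[0.12, -0.03, ...]",
    #               "fl": ["id", "url", "content_chunk_text", "score"],
    #               "fq": "domain:example.com"}}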
