Commit a5d0796: changes of work 16/10
komi786 committed Oct 16, 2024
2 parents 325335b + e077140

Showing 10 changed files with 300 additions and 197 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -10,3 +10,4 @@ uv.lock
__pycache__

*.jar
backend/variables.db
12 changes: 12 additions & 0 deletions backend/src/config.py
@@ -1,3 +1,4 @@
import logging
import os
from dataclasses import dataclass, field

@@ -64,9 +65,20 @@ def token_endpoint(self) -> str:
def admins_list(self) -> list[str]:
return self.admins.split(",")

@property
def logs_filepath(self) -> str:
return os.path.join(settings.data_folder, "logs.log")


settings = Settings()

# Disable the uvicorn loggers; this does not seem to do much
uvicorn_error = logging.getLogger("uvicorn.error")
uvicorn_error.disabled = True
uvicorn_access = logging.getLogger("uvicorn.access")
uvicorn_access.disabled = True

logging.basicConfig(filename=settings.logs_filepath, level=logging.INFO, format="%(asctime)s - %(message)s")

# import warnings

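
As the comment above notes, setting disabled = True on the uvicorn loggers is often ineffective, because uvicorn (re)configures its loggers at startup. A minimal sketch of a more reliable alternative, assuming uvicorn's standard logger names (illustration only, not part of the commit):

import logging

# Sketch: silence uvicorn's loggers by dropping their handlers and stopping
# propagation to the root logger. Passing log_config=None to uvicorn.run()
# is another option.
for name in ("uvicorn.access", "uvicorn.error"):
    log = logging.getLogger(name)
    log.handlers.clear()
    log.propagate = False
    log.setLevel(logging.CRITICAL)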
192 changes: 96 additions & 96 deletions backend/src/mapping_generation/llm_chain.py

Large diffs are not rendered by default.

42 changes: 24 additions & 18 deletions backend/src/mapping_generation/manager_llm.py
@@ -73,19 +73,25 @@ def _load_llm(model="llama3", hugging_face=False):
# temperature=0,
# )
elif model == "llama3.1":
active_model = ChatGroq(temperature=0,groq_api_key=groq_api, model="llama-3.1-70b-versatile",max_retries=3)
# active_model = ChatTogether(
# temperature=0,
# together_api_key=togather_api,
# model="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
# max_retries=3,
# verbose=True,
# ) # only 12 requests on the free tier, so switch to Groq --- add Groq API key
# active_model = ChatGroq(temperature=0,groq_api_key=groq_api, model="llama-3.1-70b-versatile",max_retries=3)
active_model = ChatTogether(
temperature=0,
together_api_key=togather_api,
model="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
max_retries=3,
verbose=True,
) # only 12 requests on the free tier, so switch to Groq --- add Groq API key

# active_model = ChatOllama(
# base_url="http://ollama:11434", # Ollama server endpoint
# model="llama3.2",
# temperature=0,
# )
# active_model = ChatOllama(
# base_url="http://ollama:11434", # Ollama server endpoint
# model="llama3.1:70b",
# temperature=0,
# )
elif model == "gpt4":
active_model = ChatOpenAI(
model="gpt-4-turbo",
@@ -159,7 +165,7 @@ def load_local_llm_instance(model_name="phi3"):

class CustomSemanticSimilarityExampleSelector(SemanticSimilarityExampleSelector):
"""Custom Selector to check for existing vector store before creating a new one."""

@classmethod
def from_examples(
cls,
@@ -174,10 +180,10 @@ def from_examples(
selector_path: Optional[str] = None,
content_key: Optional[str] = None,
**vectorstore_cls_kwargs: Any,) -> 'CustomSemanticSimilarityExampleSelector':

if selector_path is None:
selector_path = f'../data/faiss_index_{content_key}'

if os.path.exists(selector_path):
print(f"Selector path exist: {selector_path}")
# Load the existing FAISS index
@@ -196,17 +202,17 @@ def from_examples(
example_keys=example_keys,
vectorstore_kwargs=vectorstore_kwargs,
)

class ExampleSelectorManager:
_lock = threading.Lock()
_selectors = {}


@staticmethod
def get_example_selector(context_key: str, examples: List[Dict[str, str]], k=4, score_threshold=0.6, selector_path=None):
"""
Retrieves or creates a singleton example selector based on a context key.
Args:
context_key (str): A unique key to identify the selector configuration.
examples (List[Dict[str, str]]): List of example dictionaries.
@@ -218,7 +224,7 @@ def get_example_selector(context_key: str, examples: List[Dict[str, str]], k=4,
Returns:
SemanticSimilarityExampleSelector: An initialized example selector.
"""

with ExampleSelectorManager._lock:
if context_key not in ExampleSelectorManager._selectors:
try:
@@ -231,7 +237,7 @@ def get_example_selector(context_key: str, examples: List[Dict[str, str]], k=4,
# Initialize the selector using the vector store
selector = CustomSemanticSimilarityExampleSelector.from_examples(
examples=examples,
embeddings=embedding,
vectorstore_cls=FAISS,
k=k,
vectorstore_kwargs={"fetch_k": 40, "lambda_mult": 0.5},
Expand All @@ -240,8 +246,8 @@ def get_example_selector(context_key: str, examples: List[Dict[str, str]], k=4,
)
ExampleSelectorManager._selectors[context_key] = selector
logger.info(f"Example selector initialized for context: {context_key}." )

except Exception as e:
logger.error(f"Error initializing example selector for {context_key}: {e}", exc_info=True)
raise
return ExampleSelectorManager._selectors[context_key]
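
For reference, a hypothetical call site for the selector manager above; the example dicts, their field names, and the context key are illustrative assumptions, not values from this repository:

# Hypothetical usage sketch for ExampleSelectorManager.
examples = [
    {"input": "bmi", "output": "Body mass index"},
    {"input": "sbp", "output": "Systolic blood pressure"},
]
selector = ExampleSelectorManager.get_example_selector(
    context_key="ranking",  # cache key: later calls with this key reuse the selector
    examples=examples,
    k=2,
    score_threshold=0.6,
)
# SemanticSimilarityExampleSelector returns the stored examples most similar
# to the query:
print(selector.select_examples({"input": "systolic bp"}))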
5 changes: 3 additions & 2 deletions backend/src/mapping_generation/param.py
@@ -12,8 +12,9 @@
SYN_COLLECTION_NAME = "concept_mapping_1"
NEAREST_SAMPLE_NUM = 64
QUANT_TYPE = "scalar"
LLM_ID = "gpt-4o-mini" #'llama3.1'
DB_FILE = 'variables.db'
# LLM_ID = "gpt-4o-mini"
LLM_ID = "llama3.1"
DB_FILE = "variables.db"
CANDIDATE_GENERATOR_BATCH_SIZE = 64
CACHE_DIR = f"{DATA_DIR}/resources/models"
LLAMA_CACHE_DIR = f"{DATA_DIR}/resources/models/llama"
14 changes: 9 additions & 5 deletions backend/src/mapping_generation/retriever.py
@@ -16,21 +16,20 @@
post_process_candidates,
exact_match_found,
filter_irrelevant_domain_candidates,
init_logger,
)
from .utils import global_logger as logger
from .datamanager import DataManager
import os
from .vector_index import (
update_compressed_merger_retriever,
generate_vector_index,
initiate_api_retriever,
set_merger_retriever,
set_compression_retriever,

)




# Cache for retrievers based on domain
RETRIEVER_CACHE = {}

@@ -124,7 +123,7 @@ def full_query_processing(
except Exception as e:
logger.error(f"Error full processing query: {e}", exc_info=True)
return {}


def temp_process_query_details_db(
llm_query_obj: QueryDecomposedModel,
@@ -929,7 +928,7 @@ def filter_results(query, results):
# all_values[q_value] = post_process_candidates(matched_docs, max=1)
# elif categorical_value_results and len(categorical_value_results) > 0:
# if values_type == 'additional':
# q_value_ = f"{q_value}, context: {main_term}"
# else:
# q_value_ = q_value
# updated_results, _ = pass_to_chat_llm_chain(
@@ -1004,6 +1003,11 @@ def map_csv_to_standard_codes(meta_path: str):
data, is_mapped = load_data(meta_path, load_custom=True)
if is_mapped:
return data

cohort_folder = os.path.dirname(meta_path)
mapping_logger = init_logger(os.path.join(cohort_folder, "mapping_generation.log"))
mapping_logger.info(f"Logging mapping generation for {meta_path}")
# TODO: improve logging so that all logs are saved to a file named `mapping_generation.log` in the cohort_folder
embeddings = SAPEmbeddings()
sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm42-all-minilm-l6-v2-attentions")
hybrid_search = generate_vector_index(
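
Regarding the TODO above, one way to send every log record for a run into the per-cohort file is to attach a temporary FileHandler to the module's global logger; a sketch under that assumption (the helper name is made up):

import logging
import os

def attach_cohort_log_file(logger: logging.Logger, cohort_folder: str) -> logging.Handler:
    # Route all records emitted through `logger` to
    # <cohort_folder>/mapping_generation.log for the duration of a run.
    handler = logging.FileHandler(os.path.join(cohort_folder, "mapping_generation.log"))
    handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
    logger.addHandler(handler)
    return handler

# Usage sketch: remove the handler afterwards to avoid cross-cohort leakage.
# handler = attach_cohort_log_file(logger, cohort_folder)
# try:
#     ...  # run mapping generation
# finally:
#     logger.removeHandler(handler)
#     handler.close()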
9 changes: 4 additions & 5 deletions backend/src/mapping_generation/utils.py
@@ -220,8 +220,8 @@ def init_logger(log_file_path=LOG_FILE) -> logging.Logger:
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) # Set the logging level to DEBUG
# Create a file handler
# file_handler = logging.FileHandler(log_file_path)
# file_handler.setLevel(logging.DEBUG) # Set the logging level for the file handler
file_handler = logging.FileHandler(log_file_path)
file_handler.setLevel(logging.DEBUG) # Set the logging level for the file handler

# Create a stream handler (to print to console)
stream_handler = logging.StreamHandler()
@@ -233,7 +233,7 @@ def init_logger(log_file_path=LOG_FILE) -> logging.Logger:
stream_handler.setFormatter(formatter)

# Add the handlers to the logger
# logger.addHandler(file_handler)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)

return logger
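
One caveat with re-enabling the file handler: init_logger configures logging.getLogger(__name__), so calling it more than once (for instance once globally and once per cohort, as retriever.py now does) stacks duplicate handlers on the same logger, and each record is written repeatedly. A guarded variant, sketched as an assumption about the intended behavior:

import logging

def init_logger(log_file_path: str, name: str = __name__) -> logging.Logger:
    # Sketch: skip reconfiguration if this logger already has handlers.
    # Passing a distinct `name` yields a separate logger per log file.
    logger = logging.getLogger(name)
    if logger.handlers:
        return logger
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    for handler in (logging.FileHandler(log_file_path), logging.StreamHandler()):
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger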
@@ -355,7 +355,7 @@ def map_and_combine_fields(row):
'Unit Concept Code': row.get('Unit Concept Code', ''),
'Unit OMOP ID': row.get('Unit OMOP ID', ''),
}

# Combine fields
# label_ids = '|'.join(filter(None, [row.get('standard_concept_id'), row.get('additional_context_omop_ids')]))
# label_codes = '|'.join(filter(None, [row.get('standard_code'), row.get('additional_context_codes')]))
@@ -419,7 +419,6 @@ def load_mapping(filename, domain):

if not mapping:
return None, ranking_examples_string, relevance_examples_string

return (
{
"prompt": mapping.get("description", "No description provided."),
21 changes: 21 additions & 0 deletions backend/src/mapping_generation/vector_index.py
@@ -1,6 +1,7 @@
# !/usr/bin/env python3
# from langchain.schema import Document
import argparse
from math import e
import time

import qdrant_client.http.models as rest
@@ -85,6 +86,26 @@ def _create_payload_index(client, collection_name) -> None:
# )


def update_compressed_merger_retriever(
merger_retriever: CustomCompressionRetriever, domain="all", topk=10
) -> CustomCompressionRetriever:
try:
retrievers = merger_retriever.base_retriever.retrievers
api_retriever = update_api_search_filter(
retrievers[1].base_retriever, domain=domain, topk=topk
)
dense_retriever = update_qdrant_search_filter(
retrievers[0], domain=domain, topk=topk
)
merger_retriever = CustomMergeRetriever(
retrievers=[dense_retriever, api_retriever]
)
return set_compression_retriever(merger_retriever)
except Exception as e:
logger.info(f"Error updating merger retriever: {e}")
return merger_retriever


def generate_vector_index(
dense_embedding,
sparse_embedding,
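
Note that update_compressed_merger_retriever relies on a fixed ordering inside the merged retriever: index 0 is the dense Qdrant retriever and index 1 the wrapped API retriever. A hypothetical call site, with the variable names assumed:

# Hypothetical: `retriever` was produced by generate_vector_index(...) and
# wraps [dense_retriever, api_retriever] in exactly that order.
retriever = update_compressed_merger_retriever(retriever, domain="Drug", topk=15)
docs = retriever.invoke("metformin 500 mg")  # standard LangChain retriever call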
68 changes: 62 additions & 6 deletions backend/src/upload.py
@@ -1,4 +1,5 @@
import glob
import logging
import os
import shutil
from datetime import datetime
@@ -36,7 +37,7 @@ def publish_graph_to_endpoint(g: Graph, graph_uri: str | None = None) -> bool:
# response = requests.post(url, headers=headers, data=graph_data)
# Check response status and print result
if not response.ok:
print(f"Failed to upload data: {response.status_code}, {response.text}")
logging.warning(f"Failed to upload data: {response.status_code}, {response.text}")
return response.ok


@@ -220,6 +221,10 @@ def load_cohort_dict_file(dict_path: str, cohort_id: str) -> Dataset:
# Try to get IDs from old format multiple columns
df["concept_id"] = df.apply(lambda row: get_id_from_multi_columns(row), axis=1)

duplicate_variables = df[df.duplicated(subset=["VARIABLE NAME"], keep=False)]
if not duplicate_variables.empty:
errors.append(f"Duplicate VARIABLE NAME found: {', '.join(duplicate_variables['VARIABLE NAME'].unique())}")

cohort_uri = get_cohort_uri(cohort_id)
g = init_graph()
g.add((cohort_uri, RDF.type, ICARE.Cohort, cohort_uri))
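
For reference, a minimal illustration of the duplicate check added above (the data is hypothetical): keep=False marks every occurrence of a duplicated value, and .unique() then lists each offending name once.

import pandas as pd

df = pd.DataFrame({"VARIABLE NAME": ["age", "sex", "age"]})  # hypothetical rows
dups = df[df.duplicated(subset=["VARIABLE NAME"], keep=False)]  # rows 0 and 2
print(", ".join(dups["VARIABLE NAME"].unique()))  # -> age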
@@ -304,6 +309,53 @@ def load_cohort_dict_file(dict_path: str, cohort_id: str) -> Dataset:
return g


@router.post(
"/get-logs",
name="Get logs",
response_description="Logs",
)
async def get_logs(
user: Any = Depends(get_current_user),
) -> list[str]:
"""Delete a cohort from the triplestore and delete its metadata file from the server."""
user_email = user["email"]
if user_email not in settings.admins_list:
raise HTTPException(status_code=403, detail="You need to be an admin to perform this action.")
with open(settings.logs_filepath) as log_file:
logs = log_file.read()
return logs.split("\n")
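
Note that get_logs reads the entire log file into memory on each call; once the file grows, returning only a bounded tail is cheaper. A sketch, assuming the same settings.logs_filepath (the helper is illustrative, not part of the commit):

from collections import deque

def tail_logs(filepath: str, max_lines: int = 1000) -> list[str]:
    # Keep only the last `max_lines` lines, so memory stays bounded.
    with open(filepath) as log_file:
        return [line.rstrip("\n") for line in deque(log_file, maxlen=max_lines)]

# e.g. inside the endpoint: return tail_logs(settings.logs_filepath)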


@router.post(
"/delete-cohort",
name="Delete a cohort from the database",
response_description="Delete result",
)
async def delete_cohort(
user: Any = Depends(get_current_user),
cohort_id: str = Form(...),
) -> dict[str, Any]:
"""Delete a cohort from the triplestore and delete its metadata file from the server."""
user_email = user["email"]
if user_email not in settings.admins_list:
raise HTTPException(status_code=403, detail="You need to be an admin to perform this action.")
delete_existing_triples(
get_cohort_mapping_uri(cohort_id), f"<{get_cohort_uri(cohort_id)!s}>", "icare:previewEnabled"
)
delete_existing_triples(get_cohort_uri(cohort_id))
# Delete folder
cohort_folder_path = os.path.join(settings.data_folder, "cohorts", cohort_id)
if os.path.exists(cohort_folder_path) and os.path.isdir(cohort_folder_path):
shutil.rmtree(cohort_folder_path)
return {
"message": f"Cohort {cohort_id} has been successfully deleted.",
}



@router.post(
"/upload-cohort",
name="Upload cohort metadata file",
@@ -377,9 +429,9 @@ async def upload_cohort(
# NOTE: waiting for more tests before sending to production
background_tasks.add_task(generate_mappings, cohort_id, metadata_path, g)
# TODO: move all the "delete_existing_triples" and "publish_graph_to_endpoint" logic to the background task after mappings have been generated
# Return "The cohort has been successfully uploaded. The variables are being mapped to standard codes and will be available in the Cohort Explorer in a few minutes."

# # Delete previous graph for this file from triplestore
# Delete previous graph for this file from triplestore
# TODO: will move to background task
# delete_existing_triples(
# get_cohort_mapping_uri(cohort_id), f"<{get_cohort_uri(cohort_id)!s}>", "icare:previewEnabled"
# )
@@ -389,10 +441,14 @@
os.remove(metadata_path)
raise e

# return {
# "message": f"Metadata for cohort {cohort_id} have been successfully uploaded. The variables are being mapped to standard codes and will be available in the Cohort Explorer in a few minutes.",
# "identifier": cohort_id,
# # **cohort.dict(),
# }
return {
"message": f"Metadata for cohort {cohort_id} have been successfully uploaded. The variables are being mapped to standard codes and will be available in the Cohort Explorer in a few minutes.",
"message": f"Metadata for cohort {cohort_id} have been successfully uploaded.",
"identifier": cohort_id,
# **cohort.dict(),
}

def generate_mappings(cohort_id: str, metadata_path: str, g: Graph) -> None:
@@ -524,7 +580,7 @@ def init_triplestore() -> None:
# NOTE: default airlock preview to false if we ever need to reset cohorts,
# admins can easily ddl and reupload the cohorts with the correct airlock value
g = load_cohort_dict_file(file, folder)
g.serialize(f"{settings.data_folder}/cohort_explorer_triplestore.trig", format="trig")
# g.serialize(f"{settings.data_folder}/cohort_explorer_triplestore.trig", format="trig")
if publish_graph_to_endpoint(g):
print(f"💾 Triplestore initialization: added {len(g)} triples for cohorts {file}.")
