From 2a64eb62baa65eef40e6fa4f7b8476227448091b Mon Sep 17 00:00:00 2001 From: Vincent Emonet Date: Tue, 17 Oct 2023 17:54:26 +0200 Subject: [PATCH] pass list to load vectors and add call to reset db --- .github/workflows/test.yml | 3 +++ README.md | 6 ++++++ docker-compose.yml | 7 ++++--- src/api.py | 11 +++++++++++ src/predict.py | 24 ++++++++++++++++++++++-- src/train.py | 20 +++++++++++++++----- src/vectordb.py | 32 +++++++++++++++++--------------- 7 files changed, 78 insertions(+), 25 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3cae0c6..d4737f3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -66,5 +66,8 @@ jobs: with: languages: python + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + - name: Perform CodeQL Analysis uses: github/codeql-action/analyze@v2 diff --git a/README.md b/README.md index b2e721e..3f8aacb 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,12 @@ cd .. ## 🏋️ Run training +To force using a specific GPU set the environment variable `CUDA_VISIBLE_DEVICES` (starting from 0, so if you have 3 GPUs you can choose between 0,1 and 2): + +```bash +export CUDA_VISIBLE_DEVICES=1 +``` + Train the model: ```bash diff --git a/docker-compose.yml b/docker-compose.yml index 4a78890..33005c5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,9 +5,10 @@ services: build: . restart: unless-stopped volumes: - - ./data:/app/data - - ./models:/app/models - - ./MolecularTransformerEmbeddings:/app/MolecularTransformerEmbeddings + - ./:/app + # - ./data:/app/data + # - ./models:/app/models + # - ./MolecularTransformerEmbeddings:/app/MolecularTransformerEmbeddings environment: - VIRTUAL_HOST=predict-drug-target.137.120.31.160.nip.io - LETSENCRYPT_HOST=predict-drug-target.137.120.31.160.nip.io diff --git a/src/api.py b/src/api.py index 2df89b0..8d2f73b 100644 --- a/src/api.py +++ b/src/api.py @@ -3,6 +3,8 @@ from trapi_predict_kit import TRAPI, settings from src.predict import get_drug_target_predictions +from src.utils import COLLECTIONS +from src.vectordb import init_vectordb log_level = logging.INFO logging.basicConfig(level=log_level) @@ -76,3 +78,12 @@ trapi_example=trapi_example, # trapi_description="" ) + +@app.post( + "/reset-vectordb", + name="Reset vector database", + description="Reset the collections in the vectordb" +) +def reset_db(api_key: str): + init_vectordb(COLLECTIONS, recreate=True, api_key=api_key) + return {"status": "ok"} diff --git a/src/predict.py b/src/predict.py index b3d093e..273d08a 100644 --- a/src/predict.py +++ b/src/predict.py @@ -43,6 +43,7 @@ def compute_drug_embedding( os.makedirs("tmp", exist_ok=True) os.chdir("MolecularTransformerEmbeddings") + vector_list = [] for drug_id in drugs: from_vectordb = vectordb.get("drug", drug_id) if len(from_vectordb) > 0: @@ -67,10 +68,19 @@ def compute_drug_embedding( # In this case we vectorize one by one, so only 1 row in the array embeddings = vectors[0].tolist() # TODO: add label also? - vectordb.add("drug", drug_id, vector=embeddings, sequence=drug_smiles, label=drug_label) + vector_list.append({ + "vector": embeddings, + "payload": { + "id": drug_id, + "sequence":drug_smiles, + "label": drug_label + } + }) + # vectordb.add("drug", drug_id, vector=embeddings, sequence=drug_smiles, label=drug_label) embeddings.insert(0, drug_id) df.loc[len(df)] = embeddings os.chdir("..") + vectordb.add("drug", vector_list) return df @@ -88,6 +98,7 @@ def compute_target_embedding( df = pd.DataFrame.from_records(targets_list) return df + vector_list = [] for target_id in targets: # Check if we can find it in the vectordb from_vectordb = vectordb.get("target", target_id) @@ -121,9 +132,18 @@ def compute_target_embedding( target_embeddings = torch.stack(sequence_representations, dim=0).numpy() # numpy.ndarray 3775 x 1280 embeddings = target_embeddings[0].tolist() - vectordb.add("target", target_id, vector=embeddings, sequence=target_seq, label=target_label) + vector_list.append({ + "vector": embeddings, + "payload": { + "id": target_id, + "sequence":target_seq, + "label": target_label + } + }) + # vectordb.add("target", target_id, vector=embeddings, sequence=target_seq, label=target_label) embeddings.insert(0, target_id) df.loc[len(df)] = embeddings + vectordb.add("target", vector_list) return df diff --git a/src/train.py b/src/train.py index dd1258d..d7cd56b 100644 --- a/src/train.py +++ b/src/train.py @@ -254,20 +254,30 @@ def kfoldCV(sc, pairs_all, classes_all, embedding_df, clfs, n_run, n_fold, n_pro failed_conversion = [] # Add drug embeddings to the vector db +vector_list = [] for _index, row in embeddings["drug"].iterrows(): log.info(f"Drug {_index}/{len(embeddings['drug'])}") vector = [row[column] for column in embeddings["drug"].columns if column != "drug"] # if pubchem_id not in pubchem_ids: # failed_conversion.append(row['drug']) # continue - pubchem_id = pubchem_ids[f"DRUGBANK:{row['drug']}"] - if not pubchem_id or not pubchem_id.lower().startswith("pubchem.compound:"): - failed_conversion.append(f"{row['drug']} > {pubchem_id}") + drug_id = pubchem_ids[f"DRUGBANK:{row['drug']}"] + if not drug_id or not drug_id.lower().startswith("pubchem.compound:"): + failed_conversion.append(f"{row['drug']} > {drug_id}") continue # pubchem = normalize_id_to_translator() - drug_smiles, drug_label = get_smiles_for_drug(pubchem_id) - vectordb.add("drug", pubchem_id, vector=vector, sequence=drug_smiles, label=drug_label) + drug_smiles, drug_label = get_smiles_for_drug(drug_id) + vector_list.append({ + "vector": vector, + "payload": { + "id": drug_id, + "sequence":drug_smiles, + "label": drug_label + } + }) + +vectordb.add("drug", vector_list) print(f"{len(failed_conversion)} drugs ignored:") print("\n".join(failed_conversion)) diff --git a/src/vectordb.py b/src/vectordb.py index bc2f29a..c23efa0 100644 --- a/src/vectordb.py +++ b/src/vectordb.py @@ -88,22 +88,25 @@ def __init__( ) def add( - self, collection_name: str, entity_id: str, vector: list[float], sequence: str | None = None, label: str | None = None + self, collection_name: str, item_list: list[str] ) -> UpdateResult: - payload = {"id": entity_id} - if sequence: - payload["sequence"] = sequence - if label: - payload["label"] = label + # payload = {"id": entity_id} + # if sequence: + # payload["sequence"] = sequence + # if label: + # payload["label"] = label + points_count = self.client.get_collection(collection_name).points_count + points_list = [ + PointStruct(id=points_count + i + 1, vector=item["vector"], payload=item["payload"]) for i, item in enumerate(item_list) + ] operation_info = self.client.upsert( collection_name=collection_name, wait=True, - points=[ - PointStruct( - id=self.client.get_collection(collection_name).points_count + 1, vector=vector, payload=payload - ), - # PointStruct(id=2, vector=[0.19, 0.81, 0.75, 0.11], payload={"city": "London"}), - ], + points=points_list, + # [PointStruct( + # id=self.client.get_collection(collection_name).points_count + 1, vector=vector, payload=payload + # )], + # PointStruct(id=2, vector=[0.19, 0.81, 0.75, 0.11], payload={"city": "London"}), ) return operation_info @@ -138,7 +141,6 @@ def search( return search_result[0] -def init_vectordb(collections: list[dict[str, str]], recreate: bool = False): +def init_vectordb(collections: list[dict[str, str]], recreate: bool = False, api_key: str = "TOCHANGE"): qdrant_url = "qdrant.137.120.31.148.nip.io" - qdrant_apikey = "TOCHANGE" - return QdrantDB(collections=collections, recreate=recreate, host=qdrant_url, port=443, api_key=qdrant_apikey) + return QdrantDB(collections=collections, recreate=recreate, host=qdrant_url, port=443, api_key=api_key)