
Commit

pass list to load vectors and add call to reset db
vemonet committed Oct 17, 2023
1 parent e0f6bf9 commit 2a64eb6
Showing 7 changed files with 78 additions and 25 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
@@ -66,5 +66,8 @@ jobs:
        with:
          languages: python

      - name: Autobuild
        uses: github/codeql-action/autobuild@v2

      - name: Perform CodeQL Analysis
        uses: github/codeql-action/analyze@v2
6 changes: 6 additions & 0 deletions README.md
@@ -75,6 +75,12 @@ cd ..

## 🏋️ Run training

To force the use of a specific GPU, set the environment variable `CUDA_VISIBLE_DEVICES` (GPU indices start at 0, so with 3 GPUs you can choose between 0, 1, and 2):

```bash
export CUDA_VISIBLE_DEVICES=1
```

Train the model:

```bash
Expand Down
7 changes: 4 additions & 3 deletions docker-compose.yml
@@ -5,9 +5,10 @@ services:
    build: .
    restart: unless-stopped
    volumes:
      - ./data:/app/data
      - ./models:/app/models
      - ./MolecularTransformerEmbeddings:/app/MolecularTransformerEmbeddings
      - ./:/app
      # - ./data:/app/data
      # - ./models:/app/models
      # - ./MolecularTransformerEmbeddings:/app/MolecularTransformerEmbeddings
    environment:
      - VIRTUAL_HOST=predict-drug-target.137.120.31.160.nip.io
      - LETSENCRYPT_HOST=predict-drug-target.137.120.31.160.nip.io
11 changes: 11 additions & 0 deletions src/api.py
@@ -3,6 +3,8 @@
from trapi_predict_kit import TRAPI, settings

from src.predict import get_drug_target_predictions
from src.utils import COLLECTIONS
from src.vectordb import init_vectordb

log_level = logging.INFO
logging.basicConfig(level=log_level)
@@ -76,3 +78,12 @@
    trapi_example=trapi_example,
    # trapi_description=""
)

@app.post(
    "/reset-vectordb",
    name="Reset vector database",
    description="Reset the collections in the vectordb"
)
def reset_db(api_key: str):
    init_vectordb(COLLECTIONS, recreate=True, api_key=api_key)
    return {"status": "ok"}
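Because `api_key` is declared as a bare scalar parameter on the route, FastAPI exposes it as a query parameter. A minimal sketch of calling the new endpoint, assuming the API is served locally on port 8000 (host and port are assumptions, not part of this commit):

```python
import requests

API_URL = "http://localhost:8000"  # hypothetical deployment address

# Drops and recreates every collection defined in src.utils.COLLECTIONS;
# the api_key is forwarded to Qdrant through init_vectordb.
response = requests.post(
    f"{API_URL}/reset-vectordb",
    params={"api_key": "TOCHANGE"},  # replace with the real Qdrant API key
)
print(response.status_code, response.json())  # expect: 200 {'status': 'ok'}
```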
24 changes: 22 additions & 2 deletions src/predict.py
@@ -43,6 +43,7 @@ def compute_drug_embedding(

    os.makedirs("tmp", exist_ok=True)
    os.chdir("MolecularTransformerEmbeddings")
    vector_list = []
    for drug_id in drugs:
        from_vectordb = vectordb.get("drug", drug_id)
        if len(from_vectordb) > 0:
@@ -67,10 +68,19 @@
        # In this case we vectorize one by one, so only 1 row in the array
        embeddings = vectors[0].tolist()
        # TODO: add label also?
        vectordb.add("drug", drug_id, vector=embeddings, sequence=drug_smiles, label=drug_label)
        vector_list.append({
            "vector": embeddings,
            "payload": {
                "id": drug_id,
                "sequence": drug_smiles,
                "label": drug_label
            }
        })
        # vectordb.add("drug", drug_id, vector=embeddings, sequence=drug_smiles, label=drug_label)
        embeddings.insert(0, drug_id)
        df.loc[len(df)] = embeddings
    os.chdir("..")
    vectordb.add("drug", vector_list)
    return df


@@ -88,6 +98,7 @@ def compute_target_embedding(
        df = pd.DataFrame.from_records(targets_list)
        return df

    vector_list = []
    for target_id in targets:
        # Check if we can find it in the vectordb
        from_vectordb = vectordb.get("target", target_id)
@@ -121,9 +132,18 @@

        target_embeddings = torch.stack(sequence_representations, dim=0).numpy()  # numpy.ndarray 3775 x 1280
        embeddings = target_embeddings[0].tolist()
        vectordb.add("target", target_id, vector=embeddings, sequence=target_seq, label=target_label)
        vector_list.append({
            "vector": embeddings,
            "payload": {
                "id": target_id,
                "sequence": target_seq,
                "label": target_label
            }
        })
        # vectordb.add("target", target_id, vector=embeddings, sequence=target_seq, label=target_label)
        embeddings.insert(0, target_id)
        df.loc[len(df)] = embeddings
    vectordb.add("target", vector_list)
    return df


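Both `compute_drug_embedding` and `compute_target_embedding` now collect their vectors in `vector_list` and write them to Qdrant with a single `vectordb.add` call after the loop, instead of one upsert per entity. A minimal sketch of the item shape that batched call expects; the identifier, SMILES, and label below are illustrative examples, not values taken from the repository's data:

```python
# One dict per embedded entity: "vector" holds the raw embedding, while the
# entity identifier, sequence, and label travel in the Qdrant payload.
example_item = {
    "vector": [0.12, -0.53, 0.08],  # truncated; real embeddings are much longer
    "payload": {
        "id": "PUBCHEM.COMPOUND:2244",           # aspirin, used here only as an example
        "sequence": "CC(=O)OC1=CC=CC=C1C(=O)O",  # SMILES for drugs, amino-acid sequence for targets
        "label": "acetylsalicylic acid",
    },
}

# vectordb.add("drug", [example_item, ...])  # one upsert for the whole batch
```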
20 changes: 15 additions & 5 deletions src/train.py
@@ -254,20 +254,30 @@ def kfoldCV(sc, pairs_all, classes_all, embedding_df, clfs, n_run, n_fold, n_pro

failed_conversion = []
# Add drug embeddings to the vector db
vector_list = []
for _index, row in embeddings["drug"].iterrows():
    log.info(f"Drug {_index}/{len(embeddings['drug'])}")
    vector = [row[column] for column in embeddings["drug"].columns if column != "drug"]
    # if pubchem_id not in pubchem_ids:
    #     failed_conversion.append(row['drug'])
    #     continue
    pubchem_id = pubchem_ids[f"DRUGBANK:{row['drug']}"]
    if not pubchem_id or not pubchem_id.lower().startswith("pubchem.compound:"):
        failed_conversion.append(f"{row['drug']} > {pubchem_id}")
    drug_id = pubchem_ids[f"DRUGBANK:{row['drug']}"]
    if not drug_id or not drug_id.lower().startswith("pubchem.compound:"):
        failed_conversion.append(f"{row['drug']} > {drug_id}")
        continue

    # pubchem = normalize_id_to_translator()
    drug_smiles, drug_label = get_smiles_for_drug(pubchem_id)
    vectordb.add("drug", pubchem_id, vector=vector, sequence=drug_smiles, label=drug_label)
    drug_smiles, drug_label = get_smiles_for_drug(drug_id)
    vector_list.append({
        "vector": vector,
        "payload": {
            "id": drug_id,
            "sequence": drug_smiles,
            "label": drug_label
        }
    })

vectordb.add("drug", vector_list)

print(f"{len(failed_conversion)} drugs ignored:")
print("\n".join(failed_conversion))
32 changes: 17 additions & 15 deletions src/vectordb.py
@@ -88,22 +88,25 @@ def __init__(
        )

    def add(
        self, collection_name: str, entity_id: str, vector: list[float], sequence: str | None = None, label: str | None = None
        self, collection_name: str, item_list: list[dict]
    ) -> UpdateResult:
        payload = {"id": entity_id}
        if sequence:
            payload["sequence"] = sequence
        if label:
            payload["label"] = label
        # payload = {"id": entity_id}
        # if sequence:
        #     payload["sequence"] = sequence
        # if label:
        #     payload["label"] = label
        points_count = self.client.get_collection(collection_name).points_count
        points_list = [
            PointStruct(id=points_count + i + 1, vector=item["vector"], payload=item["payload"]) for i, item in enumerate(item_list)
        ]
        operation_info = self.client.upsert(
            collection_name=collection_name,
            wait=True,
            points=[
                PointStruct(
                    id=self.client.get_collection(collection_name).points_count + 1, vector=vector, payload=payload
                ),
                # PointStruct(id=2, vector=[0.19, 0.81, 0.75, 0.11], payload={"city": "London"}),
            ],
            points=points_list,
            # [PointStruct(
            #     id=self.client.get_collection(collection_name).points_count + 1, vector=vector, payload=payload
            # )],
            # PointStruct(id=2, vector=[0.19, 0.81, 0.75, 0.11], payload={"city": "London"}),
        )
        return operation_info

@@ -138,7 +141,6 @@ def search(
        return search_result[0]


def init_vectordb(collections: list[dict[str, str]], recreate: bool = False):
def init_vectordb(collections: list[dict[str, str]], recreate: bool = False, api_key: str = "TOCHANGE"):
    qdrant_url = "qdrant.137.120.31.148.nip.io"
    qdrant_apikey = "TOCHANGE"
    return QdrantDB(collections=collections, recreate=recreate, host=qdrant_url, port=443, api_key=qdrant_apikey)
    return QdrantDB(collections=collections, recreate=recreate, host=qdrant_url, port=443, api_key=api_key)
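With the new signature, `add` receives the whole batch and assigns point ids sequentially from the collection's current `points_count`, so the scheme assumes a single writer at a time. A minimal usage sketch under that assumption; the toy vectors and the api_key value are placeholders, not values from this repository:

```python
from src.utils import COLLECTIONS
from src.vectordb import init_vectordb

# recreate=False keeps existing points; recreate=True wipes the collections,
# which is what the new /reset-vectordb endpoint triggers.
vectordb = init_vectordb(COLLECTIONS, recreate=False, api_key="TOCHANGE")

items = [
    {"vector": [0.1, 0.2, 0.3], "payload": {"id": "drug:1", "sequence": "CCO", "label": "ethanol"}},
    {"vector": [0.4, 0.5, 0.6], "payload": {"id": "drug:2", "sequence": "C", "label": "methane"}},
]  # toy 3-dimensional vectors; a real collection expects its configured embedding size

# Single upsert for the whole list; ids become points_count + 1, points_count + 2, ...
vectordb.add("drug", items)
```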
