diff --git a/src/predict.py b/src/predict.py
index e9f890a..f1d73b4 100644
--- a/src/predict.py
+++ b/src/predict.py
@@ -46,6 +46,8 @@ def compute_drug_embedding(
     os.chdir("MolecularTransformerEmbeddings")
     vector_list = []
     # embed_dict = get_smiles_embeddings(smiles_list)
+    drugs_without_embed = {}
+    drugs_labels = {}
     for drug_id in drugs:
        from_vectordb = vectordb.get("drug", drug_id)
         if len(from_vectordb) > 0:
@@ -55,22 +57,16 @@
             # df = pd.concat([df, pd.DataFrame(embeddings)], ignore_index = True)
             df.loc[len(df)] = embeddings
             continue
+        else:
+            drug_smiles, drug_label = get_smiles_for_drug(drug_id)
+            drugs_without_embed[drug_smiles] = drug_id
+            drugs_labels[drug_id] = drug_label
 
         drug_smiles, drug_label = get_smiles_for_drug(drug_id)
-        log.info(f"⏳💊 Drug {drug_id} not found in VectorDB, computing its embeddings from SMILES {drug_smiles}")
-        embed_dict = get_smiles_embeddings([drug_smiles])
-
-        # with open("../tmp/drug_smiles.txt", "w") as f:
-        #     f.write(drug_smiles)
-        # os.system("python embed.py --data_path=../tmp/drug_smiles.txt")
-        # o = np.load("embeddings/drug_smiles.npz")
-        # files = o.files # 1 file
-        # gen_embeddings = []
-        # for file in files:
-        #     gen_embeddings.append(o[file]) # 'numpy.ndarray' n length x 512
-        # vectors = np.stack([emb.mean(axis=0) for emb in gen_embeddings])
-        # # In this case we vectorize one by one, so only 1 row in the array
-        # embeddings = vectors[0].tolist()
+        embed_dict = get_smiles_embeddings([drug_smiles])
+        # log.info(f"⏳💊 Drug {drugs_without_embed.keys()} not found in VectorDB, computing its embeddings from SMILES {drug_smiles}")
+        # embed_dict = get_smiles_embeddings(list(drugs_without_embed.keys()))
+        # TODO: add label also?
 
         embeddings: np.array = embed_dict[drug_smiles].tolist()
         vector_list.append({
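
The new `drugs_without_embed` / `drugs_labels` dicts and the commented-out `get_smiles_embeddings(list(drugs_without_embed.keys()))` line suggest this change is working toward batching the embedding computation for every drug missing from the VectorDB, instead of calling `get_smiles_embeddings` once per drug. Below is a minimal sketch of that batched flow, not the committed implementation: the wrapper name `batch_missing_drug_embeddings` is hypothetical, and the helpers are assumed to behave as in the surrounding code (`vectordb.get("drug", drug_id)` returns a possibly empty list of cached vectors, `get_smiles_for_drug` returns a `(smiles, label)` tuple, and `get_smiles_embeddings` returns a dict keyed by SMILES string).

```python
import numpy as np


def batch_missing_drug_embeddings(drugs, vectordb, get_smiles_for_drug, get_smiles_embeddings):
    """Sketch: embed all drugs missing from the VectorDB in one batched call.

    Assumed contracts (hypothetical, inferred from the surrounding code):
      - vectordb.get("drug", drug_id) -> list of cached vectors (empty if missing)
      - get_smiles_for_drug(drug_id)  -> (smiles, label) tuple
      - get_smiles_embeddings(smiles_list) -> {smiles: np.ndarray}
    """
    drugs_without_embed = {}  # SMILES string -> drug_id
    drugs_labels = {}         # drug_id -> human-readable label
    for drug_id in drugs:
        if len(vectordb.get("drug", drug_id)) > 0:
            continue  # cached embedding exists; handled by the existing branch
        drug_smiles, drug_label = get_smiles_for_drug(drug_id)
        drugs_without_embed[drug_smiles] = drug_id
        drugs_labels[drug_id] = drug_label

    # One batched call for all missing drugs instead of one call per drug.
    embed_dict = get_smiles_embeddings(list(drugs_without_embed.keys()))

    # embed_dict is keyed by SMILES, so index with drug_smiles rather than drug_id.
    embeddings_by_drug = {}
    for drug_smiles, drug_id in drugs_without_embed.items():
        embeddings_by_drug[drug_id] = np.asarray(embed_dict[drug_smiles]).tolist()
    return embeddings_by_drug, drugs_labels
```

Keying `drugs_without_embed` by SMILES, as the diff does, makes the post-batch lookup in `embed_dict` a direct dictionary access and keeps the mapping back to each `drug_id` for building `vector_list`.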