diff --git a/src/train.py b/src/train.py index 6f41d39..ea865fa 100644 --- a/src/train.py +++ b/src/train.py @@ -467,10 +467,11 @@ def exclude_similar(input_dir, subject_sim_threshold: float = 1, object_sim_thre if object_sim_threshold < 1: df_targets = drop_similar(df_targets, "target", object_sim_threshold) - df_known_dt = df_known_dt.merge(df_drugs[["drug"]], on="drug").merge(df_targets[["target"]], on="target") + # TODO: remove drugs/targets for which we don't have smiles/AA seq? + df_known_dt = df_known_dt[df_known_dt['drug'].isin(df_drugs['drug']) & df_known_dt['target'].isin(df_targets['target'])] log.info(f"DF LENGTH AFTER DROPPING: {len(df_drugs)} drugs and {len(df_targets)} targets, and {len(df_known_dt)} known pairs") - + print(df_known_dt) return df_known_dt, df_drugs, df_targets