NatComm revision updates
Old-Shatterhand committed Jan 2, 2025
1 parent 5899a74 commit d3b9ce7
Showing 21 changed files with 677 additions and 306 deletions.
10 changes: 9 additions & 1 deletion datasail/__main__.py
@@ -1,5 +1,13 @@
from datasail.sail import sail
import os

#N_THREADS = "1"
#os.environ["OPENBLAS_NUM_THREADS"] = N_THREADS
#os.environ["OPENBLAS_MAX_THREADS"] = N_THREADS
#os.environ["GOTO_NUM_THREADS"] = N_THREADS
#os.environ["OMP_NUM_THREADS"] = N_THREADS
os.environ["GRB_LICENSE_FILE"] = "/home/rjo21/gurobi_mickey.lic"

from datasail.sail import sail

if __name__ == '__main__':
sail()
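
The hunk above caps BLAS/OpenMP thread counts (currently commented out) and points Gurobi at a license file before the heavy imports run, which is why the `from datasail.sail import sail` line moves below the environment setup. A minimal, self-contained sketch of the same pattern without a machine-specific path follows; the DATASAIL_NUM_THREADS variable and the ~/gurobi.lic fallback are illustrative assumptions, not values from this commit:

import os

# Thread caps for BLAS/OpenMP backends; they only take effect if set before
# numpy and friends are imported.
n_threads = os.environ.get("DATASAIL_NUM_THREADS", "1")  # hypothetical variable name
for var in ("OPENBLAS_NUM_THREADS", "OPENBLAS_MAX_THREADS", "GOTO_NUM_THREADS", "OMP_NUM_THREADS"):
    os.environ.setdefault(var, n_threads)

# Point Gurobi at a license file only if the caller has not configured one already.
os.environ.setdefault("GRB_LICENSE_FILE", os.path.expanduser("~/gurobi.lic"))

from datasail.sail import sail  # imported after the environment is prepared

if __name__ == "__main__":
    sail()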
3 changes: 2 additions & 1 deletion datasail/cluster/clustering.py
@@ -73,7 +73,7 @@ def cluster(dataset: DataSet, **kwargs) -> DataSet:
if len(dataset.cluster_names) > dataset.num_clusters:
dataset = force_clustering(dataset, kwargs[KW_LINKAGE])

store_to_cache(dataset, **kwargs)
# store_to_cache(dataset, **kwargs)

return dataset

@@ -212,6 +212,7 @@ def additional_clustering(
)
# cluster the clusters into new, fewer, and bigger clusters
labels = ca.fit_predict(cluster_matrix)
LOGGER.info("Clustering finished")
return labels2clusters(labels, dataset, cluster_matrix, linkage)


121 changes: 95 additions & 26 deletions datasail/cluster/foldseek.py
@@ -2,8 +2,12 @@
import shutil
from pathlib import Path
from typing import Optional
import pickle

import numpy as np
from pyarrow import compute, csv
from collections import defaultdict
from tqdm import tqdm

from datasail.parsers import MultiYAMLParser
from datasail.reader.utils import DataSet
@@ -23,25 +27,26 @@ def run_foldseek(dataset: DataSet, threads: int = 1, log_dir: Optional[Path] = N
raise ValueError("Foldseek is not installed.")
user_args = MultiYAMLParser(FOLDSEEK).get_user_arguments(dataset.args, [])

results_folder = Path("fs_results")
results_folder = Path("/scratch/SCRATCH_SAS/roman/DataSAIL/fs_results")

tmp = Path("tmp")
tmp = Path("/scratch/SCRATCH_SAS/roman/DataSAIL/fs_tmp")
tmp.mkdir(parents=True, exist_ok=True)
for name, filepath in dataset.data.items():
shutil.copy(filepath, tmp)
##for name in dataset.names:
## shutil.copy(dataset.data[name], tmp)

cmd = f"mkdir {results_folder} && " \
f"cd {results_folder} && " \
f"foldseek " \
f"easy-search " \
f"../tmp " \
f"../tmp " \
f"{str(tmp.resolve())} " \
f"{str(tmp.resolve())} " \
f"aln.m8 " \
f"tmp " \
f"--format-output 'query,target,fident' " \
f"--format-output 'query,target,fident,qlen,lddt' " \
f"-e inf " \
f"--threads {threads} " \
f"{user_args}" # && " \
# f"--exhaustive-search 1 " \
# f"rm -rf ../tmp"

if log_dir is None:
@@ -54,27 +59,91 @@ def run_foldseek(dataset: DataSet, threads: int = 1, log_dir: Optional[Path] = N

LOGGER.info("Start FoldSeek clustering")
LOGGER.info(cmd)
os.system(cmd)

if not (results_folder / "aln.m8").exists():
raise ValueError("Something went wrong with foldseek. The output file does not exist.")

namap = dict((n, i) for i, n in enumerate(dataset.names))
cluster_sim = np.zeros((len(dataset.names), len(dataset.names)))
with open(f"{results_folder}/aln.m8", "r") as data:
for line in data.readlines():
q1, q2, sim = line.strip().split("\t")[:3]
if "_" in q1 and "." in q1 and q1.rindex("_") > q1.index("."):
q1 = "_".join(q1.split("_")[:-1])
if "_" in q2 and "." in q2 and q2.rindex("_") > q2.index("."):
q2 = "_".join(q2.split("_")[:-1])
q1 = q1.replace(".pdb", "")
q2 = q2.replace(".pdb", "")
cluster_sim[namap[q1], namap[q2]] = sim
cluster_sim[namap[q2], namap[q1]] = sim
##os.system(cmd)

##if not (results_folder / "aln.m8").exists():
## raise ValueError("Something went wrong with foldseek. The output file does not exist.")

##ds = read_with_pyarrow(f"{results_folder}/aln.m8")
#with open("/scratch/SCRATCH_SAS/roman/DataSAIL/pyarrow.pkl", "rb") as data:
# ds = pickle.load(data)

#try:
#except Exception as e:
# print("pickling failed due to:", e)
# namap = dict((n, i) for i, n in enumerate(dataset.names))
##cluster_sim = np.zeros((len(dataset.names), len(dataset.names)))
#with open(f"{results_folder}/aln.m8", "r") as data:
# for line in data.readlines():
# q1, q2, sim = line.strip().split("\t")[:3]
# if "_" in q1 and "." in q1 and q1.rindex("_") > q1.index("."):
# q1 = "_".join(q1.split("_")[:-1])
# if "_" in q2 and "." in q2 and q2.rindex("_") > q2.index("."):
# q2 = "_".join(q2.split("_")[:-1])
# q1 = q1.replace(".pdb", "")
# q2 = q2.replace(".pdb", "")
# cluster_sim[namap[q1], namap[q2]] = sim
# cluster_sim[namap[q2], namap[q1]] = sim
# print("Additional names:", set(dataset.names).difference(set(ds.keys())))
# print("Additional hits:", set(ds.keys()).difference(set(dataset.names)))
# exit(0)
##for i, name1 in enumerate(dataset.names):
## cluster_sim[i, i] = 1
## for j, name2 in enumerate(dataset.names[i + 1:]):
## if name2 in ds[name1]:
## cluster_sim[i, j] = ds[name1][name2][2] / ds[name1][name2][3]
## if name1 in ds[name2]:
## cluster_sim[j, i] = ds[name2][name1][2] / ds[name2][name1][3]
##cluster_sim = (cluster_sim + cluster_sim.T) / 2

# with open("/scratch/SCRATCH_SAS/roman/DataSAIL/PLINDER/prot_sim_full_v12.pkl", "wb") as out:
# pickle.dump(ds, out)
with open("/scratch/SCRATCH_SAS/roman/DataSAIL/PLINDER/eval/full_v0/prots.pkl", "rb") as f:
ds = pickle.load(f)

shutil.rmtree(results_folder, ignore_errors=True)
shutil.rmtree(tmp, ignore_errors=True)

dataset.cluster_names = dataset.names
dataset.cluster_map = dict((n, n) for n in dataset.names)
dataset.cluster_similarity = cluster_sim
dataset.cluster_similarity = ds.cluster_similarity ## cluster_sim


def extract(tmp):
if len(tmp) == 1:
return tmp[0], "?"
else:
return "_".join(tmp[:-1]), tmp[-1]


def inner_list():
return ["", "", 0, 0]


def outer_dict():
return defaultdict(inner_list)


def read_with_pyarrow(file_path):
table = csv.read_csv(
file_path,
read_options=csv.ReadOptions(use_threads=True, column_names=["qid_chainid", "tid_chainid", "fident", "qlen", "lddt"]),
parse_options=csv.ParseOptions(delimiter="\t"),
)

indices = compute.sort_indices(table, [("lddt", "descending"), ("fident", "descending")])
ds = defaultdict(outer_dict)
for idx in tqdm(indices):
q_id, q_chain = extract(table["qid_chainid"][idx.as_py()].as_py().split("_"))
t_id, t_chain = extract(table["tid_chainid"][idx.as_py()].as_py().split("_"))
record = ds[q_id][t_id]
if q_chain in record[0] or t_chain in record[1]:
continue
fident = table["fident"][idx.as_py()].as_py()
q_len = table["qlen"][idx.as_py()].as_py()
record[0] += q_chain
record[1] += t_chain
record[2] += fident * q_len
record[3] += q_len
return ds
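
The commented-out block earlier in this file indicates how the per-pair records from read_with_pyarrow are meant to be folded into a symmetric similarity matrix (length-weighted fident sum in record[2], summed query length in record[3], averaged across both query/target directions). A runnable sketch of that aggregation, with an illustrative helper name and no claim to be the committed implementation:

import numpy as np

def foldseek_similarity(names, ds):
    # Mirror of the commented-out aggregation above: average the two directed
    # length-normalised identities for every pair of structure names.
    sim = np.zeros((len(names), len(names)))
    for i, name1 in enumerate(names):
        sim[i, i] = 1.0
        for j, name2 in enumerate(names[i + 1:], start=i + 1):
            rec = ds.get(name1, {}).get(name2)
            if rec is not None and rec[3] > 0:
                sim[i, j] = rec[2] / rec[3]
            rec = ds.get(name2, {}).get(name1)
            if rec is not None and rec[3] > 0:
                sim[j, i] = rec[2] / rec[3]
    return (sim + sim.T) / 2  # symmetrize, as in the commented-out code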

1 change: 1 addition & 0 deletions datasail/cluster/mash.py
@@ -58,6 +58,7 @@ def run_mash(dataset: DataSet, threads: int = 1, log_dir: Optional[Path] = None)
dataset.cluster_names = dataset.names

shutil.rmtree(results_folder, ignore_errors=True)
shutil.rmtree(tmp, ignore_errors=True)


def read_mash_tsv(filename: Path, num_entities: int) -> np.ndarray:
2 changes: 1 addition & 1 deletion datasail/reader/read_molecules.py
@@ -95,7 +95,7 @@ def remove_molecule_duplicates(dataset: DataSet) -> DataSet:
dataset: The dataset to remove duplicates from
Returns:
Update arguments as teh location of the data might change and an ID-Map file might be added.
Update arguments as the location of the data might change and an ID-Map file might be added.
"""
if isinstance(dataset.data[dataset.names[0]], (list, tuple, np.ndarray)):
# TODO: proper check for duplicate embeddings
9 changes: 8 additions & 1 deletion datasail/reader/utils.py
@@ -3,6 +3,7 @@
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Generator, Tuple, List, Optional, Dict, Union, Any, Callable, Iterable, Set
from collections.abc import Iterable

import h5py
import numpy as np
@@ -120,8 +121,13 @@ def strat2oh(self, name: Optional[str] = None, classes: Optional[Union[str, Set[
if name is None:
raise ValueError("Either name or class must be provided.")
classes = self.stratification[name]
if not isinstance(classes, Iterable):
classes = [classes]
if self.classes is not None:
# print(name, self.class_oh[[self.classes[class_] for class_ in classes]].sum(axis=0))
# print("classes", classes)
# print("self.cl", self.classes)
# print("self.oh", self.class_oh)
return self.class_oh[[self.classes[class_] for class_ in classes]].sum(axis=0)
return None
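
The guard added above wraps a bare scalar stratification label in a list before indexing into the one-hot matrix, so single-class entries follow the same code path as sets of classes. A small self-contained illustration of that path, with invented class labels rather than values from the commit:

import numpy as np
from collections.abc import Iterable

# Illustrative values: three stratification classes and one sample whose
# stratification entry is a single scalar rather than a set of labels.
class_index = {0: 0, 1: 1, 2: 2}
class_oh = np.eye(len(class_index), dtype=int)

entry = 1
if not isinstance(entry, Iterable):  # the guard added in the hunk above
    entry = [entry]
print(class_oh[[class_index[c] for c in entry]].sum(axis=0))  # -> [0 1 0]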

@@ -285,7 +291,8 @@ def read_data(
elif isinstance(weights, Generator):
dataset.weights = dict(weights)
elif inter is not None:
dataset.weights = dict(count_inter(inter, index))
dataset.weights = {k: 0 for k in dataset.data.keys()}
dataset.weights.update(dict(count_inter(inter, index)))
else:
dataset.weights = {k: 1 for k in dataset.data.keys()}
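The changed branch above initialises every entity's weight to 0 and only then overwrites it with the interaction counts, so entities that never appear in the interaction file keep an explicit zero entry instead of being missing from the mapping. A toy illustration with invented names, where a Counter stands in for count_inter:

from collections import Counter

data = {"P1": "...", "P2": "...", "P3": "..."}
inter = [("P1", "L1"), ("P1", "L2"), ("P3", "L1")]

weights = {k: 0 for k in data}                 # default every entity to 0
weights.update(Counter(e for e, _ in inter))   # Counter stands in for count_inter
print(weights)                                 # -> {'P1': 2, 'P2': 0, 'P3': 1}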

39 changes: 30 additions & 9 deletions datasail/routine.py
@@ -1,12 +1,13 @@
import time
import pickle
from typing import Dict, Tuple, Optional

from datasail.argparse_patch import remove_patch
from datasail.cluster.clustering import cluster
from datasail.reader.read import read_data
from datasail.reader.utils import DataSet
from datasail.report import report
from datasail.settings import LOGGER, KW_TECHNIQUES, KW_EPSILON, KW_RUNS, KW_SPLITS, KW_NAMES, \
from datasail.settings import LOGGER, KW_INTER, KW_TECHNIQUES, KW_EPSILON, KW_RUNS, KW_SPLITS, KW_NAMES, \
KW_MAX_SEC, KW_MAX_SOL, KW_SOLVER, KW_LOGDIR, NOT_ASSIGNED, KW_OUTDIR, MODE_E, MODE_F, DIM_2, SRC_CL, KW_DELTA, \
KW_E_CLUSTERS, KW_F_CLUSTERS, KW_CC, CDHIT, INSTALLED, FOLDSEEK, TMALIGN, CDHIT_EST, DIAMOND, MMSEQS, MASH
from datasail.solver.solve import run_solver
@@ -32,27 +33,38 @@ def datasail_main(**kwargs) -> Optional[Tuple[Dict, Dict, Dict]]:
**kwargs: Parsed commandline arguments to DataSAIL.
"""
kwargs = remove_patch(**kwargs)
if kwargs[KW_CC]:
if kwargs.get(KW_CC, False):
list_cluster_algos()
return None

start = time.time()
LOGGER.info("Read data")

# read e-entities and f-entities
e_dataset, f_dataset, inter = read_data(**kwargs)
e_dataset, f_dataset_tmp, inter = read_data(**kwargs)

# if required, cluster the input otherwise define the cluster-maps to be None
clusters = list(filter(lambda x: x[0].startswith(SRC_CL), kwargs[KW_TECHNIQUES]))
cluster_e = len(clusters) != 0 and any(c[-1] in {DIM_2, MODE_E} for c in clusters)
cluster_f = len(clusters) != 0 and any(c[-1] in {DIM_2, MODE_F} for c in clusters)

if cluster_e:
LOGGER.info("Cluster first set of entities.")
e_dataset = cluster(e_dataset, **kwargs)
if cluster_f:
LOGGER.info("Cluster second set of entities.")
f_dataset = cluster(f_dataset, **kwargs)
#if cluster_e:
# LOGGER.info("Cluster first set of entities.")
# e_dataset = cluster(e_dataset, **kwargs)
#if cluster_f:
# LOGGER.info("Cluster second set of entities.")
# f_dataset = cluster(f_dataset, **kwargs)

split = str(kwargs[KW_INTER]).split("/")[-2]
#with open(f"/scratch/SCRATCH_SAS/roman/DataSAIL/PLINDER/{split}.pkl", "wb") as f:
# pickle.dump((e_dataset, f_dataset), f)
with open(f"/scratch/SCRATCH_SAS/roman/DataSAIL/PLINDER/{split}.pkl", "rb") as f:
e_dataset, f_dataset = pickle.load(f)
f_dataset.id_map = f_dataset_tmp.id_map

#print("E_ID_Map is None:", e_dataset.id_map is None)
#print("F_ID_Map is None:", f_dataset.id_map is None)
#print("Nones in inter :", sum([x is None for x in inter]))

if inter is not None:
if e_dataset.type is not None and f_dataset.type is not None:
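
The first hunk above replaces the two cluster() calls with loading a pre-clustered dataset from a hardcoded scratch pickle. A hedged sketch of the underlying cache-or-compute pattern, with stand-in names and cache path rather than the committed values:

import pickle
from pathlib import Path

def cluster_with_cache(e_dataset, f_dataset, cluster_fn, cache: Path, **kwargs):
    # Load previously clustered datasets if a cache exists; otherwise cluster
    # both entity sets and persist the result for later runs.
    if cache.exists():
        with cache.open("rb") as handle:
            return pickle.load(handle)
    e_dataset = cluster_fn(e_dataset, **kwargs)
    f_dataset = cluster_fn(f_dataset, **kwargs)
    with cache.open("wb") as handle:
        pickle.dump((e_dataset, f_dataset), handle)
    return e_dataset, f_dataset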
@@ -88,13 +100,22 @@ def datasail_main(**kwargs) -> Optional[Tuple[Dict, Dict, Dict]]:

LOGGER.info("Store results")

#print("E name:", e_name_split_map.keys())
#print("F name:", f_name_split_map.keys())
#print("E cluster:", e_cluster_split_map.keys())
#print("F cluster:", f_cluster_split_map.keys())
#print("Inter:", inter_split_map.keys())

# infer interaction assignment from entity assignment if necessary and possible
output_inter_split_map = dict()
if new_inter is not None:
for technique in kwargs[KW_TECHNIQUES]:
output_inter_split_map[technique] = []
for run in range(kwargs[KW_RUNS]):
output_inter_split_map[technique].append(dict())
#print(e_name_split_map.keys())
#print(f_name_split_map.keys())
#print(techique)
for e, f in inter:
if technique.endswith(DIM_2) or technique == "R":
output_inter_split_map[technique][-1][(e, f)] = inter_split_map[technique][run].get(
1 change: 1 addition & 0 deletions datasail/solver/cluster_1d.py
@@ -64,6 +64,7 @@ def solve_c1(

loss = cvxpy.sum([t for tmp_list in tmp for t in tmp_list])
problem = solve(loss, constraints, max_sec, solver, log_file)
print(problem)

return None if problem is None else {
e: names[s] for s in range(len(splits)) for i, e in enumerate(clusters) if x[s, i].value > 0.1