From f16dd06fe4b392998c6b43fe00dac96c5cdd46f7 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 27 Feb 2024 18:54:51 -0500 Subject: [PATCH 01/39] simple testing on cellxgene api --- environment.yml | 2 + tdc/cellxgene-census-loaders/__init__.py | 0 .../cellxgene-census.py | 78 +++++++++++++++++++ 3 files changed, 80 insertions(+) create mode 100644 tdc/cellxgene-census-loaders/__init__.py create mode 100644 tdc/cellxgene-census-loaders/cellxgene-census.py diff --git a/environment.yml b/environment.yml index 4aa999fe..5bc18717 100644 --- a/environment.yml +++ b/environment.yml @@ -18,3 +18,5 @@ dependencies: - cellxgene-census==1.10.2 - PyTDC==0.4.1 - rdkit==2023.9.5 + - tiledbsoma==1.7.2 + - yapf==0.40.2 diff --git a/tdc/cellxgene-census-loaders/__init__.py b/tdc/cellxgene-census-loaders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tdc/cellxgene-census-loaders/cellxgene-census.py b/tdc/cellxgene-census-loaders/cellxgene-census.py new file mode 100644 index 00000000..7be20608 --- /dev/null +++ b/tdc/cellxgene-census-loaders/cellxgene-census.py @@ -0,0 +1,78 @@ +import cellxgene_census +from pandas import concat +import tiledbsoma + +from tdc import base_dataset +""" + +Are we only supporting memory-efficient queries? +https://chanzuckerberg.github.io/cellxgene-census/cellxgene_census_docsite_quick_start.html#memory-efficient-queries + + +""" + + +class CXGDataLoader(base_dataset.DataLoader): + + def __init__(self, + num_slices=None, + census_version="2023-12-15", + dataset="census_data", + organism="homo_sapiens", + measurement_name="RNA", + value_filter="", + column_names=None): + if column_names is None: + raise ValueError("column_names is required for this loader") + self.column_names = column_names + num_slices = num_slices if num_slices is not None else 1 + self.num_slices = num_slices + self.df = None + self.fetch_data(census_version, dataset, organism, measurement_name, + value_filter) + + def fetch_data(self, census_version, dataset, organism, measurement_name, + value_filter): + """TODO: docs + outputs a dataframe with specified query params on census data SOMA collection object + """ + if self.column_names is None: + raise ValueError( + "Column names must be provided to CXGDataLoader class") + + with cellxgene_census.open_soma( + census_version=census_version) as census: + # Reads SOMADataFrame as a slice + cell_metadata = census[dataset][organism].obs.read( + value_filter=value_filter, column_names=self.column_names) + self.df = cell_metadata.concat().to_pandas() + # TODO: not latency on memory-efficient queries is poor... + # organismCollection = census[dataset][organism] + # query = organismCollection.axis_query( + # measurement_name = measurement_name, + # obs_query = tiledbsoma.AxisQuery( + # value_filter = value_filter + # ) + # ) + # it = query.X("raw").tables() + # dfs =[] + # for _ in range(self.num_slices): + # slice = next (it) + # df_slice = slice.to_pandas() + # dfs.append(df_slice) + # self.df = concat(dfs) + + def get_dataframe(self): + if self.df is None: + raise Exception( + "Haven't instantiated a DataFrame yet. You can call self.fetch_data first." + ) + return self.df + + +if __name__ == "__main__": + # TODO: tmp, run testing suite when this file is called as main + loader = CXGDataLoader(value_filter="tissue == 'brain' and sex == 'male'", + column_names=["assay", "cell_type", "tissue"]) + df = loader.get_dataframe() + print(df.head()) From 1d479b1be657322ed7a6d8aff7a66d40ccf5be8a Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Fri, 1 Mar 2024 14:48:04 -0500 Subject: [PATCH 02/39] cellxgene census is a resource --- tdc/cellxgene-census-loaders/__init__.py | 0 tdc/{cellxgene-census-loaders => resource}/cellxgene-census.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tdc/cellxgene-census-loaders/__init__.py rename tdc/{cellxgene-census-loaders => resource}/cellxgene-census.py (100%) diff --git a/tdc/cellxgene-census-loaders/__init__.py b/tdc/cellxgene-census-loaders/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tdc/cellxgene-census-loaders/cellxgene-census.py b/tdc/resource/cellxgene-census.py similarity index 100% rename from tdc/cellxgene-census-loaders/cellxgene-census.py rename to tdc/resource/cellxgene-census.py From f2b7caa2e245cd8760c57c13d7b73253cc4a6bac Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Sat, 2 Mar 2024 12:52:10 -0500 Subject: [PATCH 03/39] makae census api a resource --- tdc/resource/cellxgene-census.py | 95 ++++++++++++++++++++++++++------ 1 file changed, 79 insertions(+), 16 deletions(-) diff --git a/tdc/resource/cellxgene-census.py b/tdc/resource/cellxgene-census.py index 7be20608..1170a53b 100644 --- a/tdc/resource/cellxgene-census.py +++ b/tdc/resource/cellxgene-census.py @@ -2,17 +2,8 @@ from pandas import concat import tiledbsoma -from tdc import base_dataset -""" - -Are we only supporting memory-efficient queries? -https://chanzuckerberg.github.io/cellxgene-census/cellxgene_census_docsite_quick_start.html#memory-efficient-queries - -""" - - -class CXGDataLoader(base_dataset.DataLoader): +class CensusResource: def __init__(self, num_slices=None, @@ -43,9 +34,67 @@ def fetch_data(self, census_version, dataset, organism, measurement_name, with cellxgene_census.open_soma( census_version=census_version) as census: # Reads SOMADataFrame as a slice - cell_metadata = census[dataset][organism].obs.read( - value_filter=value_filter, column_names=self.column_names) - self.df = cell_metadata.concat().to_pandas() + print("reading data...") + obs = census[dataset][organism].obs + print("got var df") + print("filtering") + varread = obs.read(value_filter=value_filter, column_names = self.column_names) + print("converting to pandas.. first pyarrow") + cc = varread.concat() + print("now to pandas") + self.df = cc.to_pandas() + print(self.df.head()) + print("now running var test") + + var = census[dataset][organism].ms[measurement_name].var + self.df = var.read(column_names=["feature_name", "feature_reference"], value_filter="feature_id in ['ENSG00000161798', 'ENSG00000188229']").concat().to_pandas() + print("printing var") + print(self.df.head()) + # TODO: gene / var columns are var: 'soma_joinid', 'feature_id', 'feature_name', 'feature_length' + + print("now testing queries on data matrices (X)") + n_obs = len(obs) + n_var =len(var) + X = census[dataset][organism].ms[measurement_name].X["raw"] + slc = X.read((slice(0, 5),)).coos((n_obs,n_var)) # need bounding boxes + self.df = slc.concat() + print(self.df) + print("prinnting X") + print(self.df.to_scipy().todense()) + + print("now testing feature dataset presence matrix") + fMatrix = census[dataset][organism].ms[measurement_name]["feature_dataset_presence_matrix"] + slc = fMatrix.read((slice(0, 5),)).coos((n_obs,n_var)) # need bounding boxes + self.df = slc.concat() + print(self.df) + print("printing ftp matrix") + print(self.df.to_scipy().todense()) + + print("can we do full read on X?") + bded = X.read().coos((n_obs, n_var)) # still need bounding boxes + print("can get the sparse array coos()") + print("we cannot get pyarrow") + # print("pyarrow") + # bded.concat() + print("yes we can") + + # X = census[dataset][organism].ms[measurement_name].X["raw"] + # sparse_array = X.read() + # print("spare array...") + # print(sparse_array) + # # TODO: tmp + # print("converting to pandas") + # print("first to pyarrow") + # self.df = sparse_array.coos().concat() + # print("done") + # print("now pandas") + # self.df = self.df.to_pandas() + # print("done") + + + # .read( + # value_filter=value_filter, column_names=self.column_names) + # self.df = cell_metadata.concat().to_pandas() # TODO: not latency on memory-efficient queries is poor... # organismCollection = census[dataset][organism] # query = organismCollection.axis_query( @@ -69,10 +118,24 @@ def get_dataframe(self): ) return self.df + def get_data(self, type="df"): + if type == "df": + return self.get_dataframe() + elif type == "pyarrow": + raise Exception("PyArrow format not supported by TDC yet.") + else: + raise Exception("Type must be set to df or pyarrow") + if __name__ == "__main__": # TODO: tmp, run testing suite when this file is called as main - loader = CXGDataLoader(value_filter="tissue == 'brain' and sex == 'male'", + print("initializing object") + loader = CensusResource(value_filter="tissue == 'brain' and sex == 'male'", column_names=["assay", "cell_type", "tissue"]) - df = loader.get_dataframe() - print(df.head()) + print("getting") + df = loader.get_data() + print("getting head()") + # print(df.head()) + print("no dense") + print(df.to_scipy()) + print("done!") From 8ffa1104a1a1bd19798d71c4044d0f06beecb4e2 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Mon, 4 Mar 2024 16:42:48 -0500 Subject: [PATCH 04/39] mvp cellxgene resource implemented --- tdc/resource/cellxgene-census.py | 285 +++++++++++++++++-------------- 1 file changed, 161 insertions(+), 124 deletions(-) diff --git a/tdc/resource/cellxgene-census.py b/tdc/resource/cellxgene-census.py index 1170a53b..4ee2ea4b 100644 --- a/tdc/resource/cellxgene-census.py +++ b/tdc/resource/cellxgene-census.py @@ -1,141 +1,178 @@ +# TODO: tmp fix +import os +os.environ['KMP_DUPLICATE_LIB_OK'] = "TRUE" +# TODO: find better fix or encode in environment / docker ^^^ import cellxgene_census +from functools import wraps from pandas import concat import tiledbsoma class CensusResource: + + _CENSUS_DATA = "census_data" + _FEATURE_PRESENCE = "feature_dataset_presence_matrix" + _LATEST_CENSUS = "2023-12-15" + _HUMAN = "homo_sapiens" - def __init__(self, - num_slices=None, - census_version="2023-12-15", - dataset="census_data", - organism="homo_sapiens", - measurement_name="RNA", - value_filter="", - column_names=None): - if column_names is None: - raise ValueError("column_names is required for this loader") - self.column_names = column_names - num_slices = num_slices if num_slices is not None else 1 - self.num_slices = num_slices - self.df = None - self.fetch_data(census_version, dataset, organism, measurement_name, - value_filter) + class decorators: + @classmethod + def check_dataset_is_census_data(cls,func): + # @wraps(func) + def check(*args, **kwargs): + self = args[0] + if self.dataset != self._CENSUS_DATA: + raise ValueError("This function requires the '{}' dataset".format(self._CENSUS_DATA)) + return func(*args, **kwargs) + return check - def fetch_data(self, census_version, dataset, organism, measurement_name, - value_filter): - """TODO: docs - outputs a dataframe with specified query params on census data SOMA collection object + def __init__(self, + census_version=None, + dataset=None, + organism=None + ): + """Initialize the Census Resource. + + Args: + census_version (str): The date of the census data release in YYYY- + TODO: complete """ - if self.column_names is None: - raise ValueError( - "Column names must be provided to CXGDataLoader class") + self.census_version = census_version if census_version is not None else self._LATEST_CENSUS + self.dataset = dataset if dataset is not None else self._CENSUS_DATA + self.organism = organism if organism is not None else self._HUMAN + def fmt_cellxgene_data(self, tiledb_ptr, fmt=None): + if fmt is None: + raise Exception("format not provided to fmt_cellxgene_data(), please provide fmt variable") + elif fmt == "pandas": + return tiledb_ptr.concat().to_pandas() + elif fmt == "pyarrow": + return tiledb_ptr.concat() + elif fmt == "scipy": + return tiledb_ptr.concat().to_scipy() + else: + raise Exception("fmt not in [pandas, pyarrow, scipy] for fmt_cellxgene_data()") + + @decorators.check_dataset_is_census_data + def get_cell_metadata(self, value_filter=None, column_names=None, fmt=None): + """Get the cell metadata (obs) data from the Census API""" + if value_filter is None: + raise Exception("No value filter was provided, dataset is too large to fit in memory. \ + Memory-Efficient queries are not supported yet.") + fmt = fmt if fmt is not None else "pandas" with cellxgene_census.open_soma( - census_version=census_version) as census: - # Reads SOMADataFrame as a slice - print("reading data...") - obs = census[dataset][organism].obs - print("got var df") - print("filtering") - varread = obs.read(value_filter=value_filter, column_names = self.column_names) - print("converting to pandas.. first pyarrow") - cc = varread.concat() - print("now to pandas") - self.df = cc.to_pandas() - print(self.df.head()) - print("now running var test") - - var = census[dataset][organism].ms[measurement_name].var - self.df = var.read(column_names=["feature_name", "feature_reference"], value_filter="feature_id in ['ENSG00000161798', 'ENSG00000188229']").concat().to_pandas() - print("printing var") - print(self.df.head()) - # TODO: gene / var columns are var: 'soma_joinid', 'feature_id', 'feature_name', 'feature_length' - - print("now testing queries on data matrices (X)") - n_obs = len(obs) - n_var =len(var) - X = census[dataset][organism].ms[measurement_name].X["raw"] - slc = X.read((slice(0, 5),)).coos((n_obs,n_var)) # need bounding boxes - self.df = slc.concat() - print(self.df) - print("prinnting X") - print(self.df.to_scipy().todense()) - - print("now testing feature dataset presence matrix") - fMatrix = census[dataset][organism].ms[measurement_name]["feature_dataset_presence_matrix"] - slc = fMatrix.read((slice(0, 5),)).coos((n_obs,n_var)) # need bounding boxes - self.df = slc.concat() - print(self.df) - print("printing ftp matrix") - print(self.df.to_scipy().todense()) - - print("can we do full read on X?") - bded = X.read().coos((n_obs, n_var)) # still need bounding boxes - print("can get the sparse array coos()") - print("we cannot get pyarrow") - # print("pyarrow") - # bded.concat() - print("yes we can") - - # X = census[dataset][organism].ms[measurement_name].X["raw"] - # sparse_array = X.read() - # print("spare array...") - # print(sparse_array) - # # TODO: tmp - # print("converting to pandas") - # print("first to pyarrow") - # self.df = sparse_array.coos().concat() - # print("done") - # print("now pandas") - # self.df = self.df.to_pandas() - # print("done") - - - # .read( - # value_filter=value_filter, column_names=self.column_names) - # self.df = cell_metadata.concat().to_pandas() - # TODO: not latency on memory-efficient queries is poor... - # organismCollection = census[dataset][organism] - # query = organismCollection.axis_query( - # measurement_name = measurement_name, - # obs_query = tiledbsoma.AxisQuery( - # value_filter = value_filter - # ) - # ) - # it = query.X("raw").tables() - # dfs =[] - # for _ in range(self.num_slices): - # slice = next (it) - # df_slice = slice.to_pandas() - # dfs.append(df_slice) - # self.df = concat(dfs) + census_version=self.census_version) as census: + obs = census[self.dataset][self.organism].obs + obsread = None + if column_names: + obsread = obs.read(value_filter=value_filter, column_names=column_names) + else: + obsread = obs.read(value_filter=value_filter) + return self.fmt_cellxgene_data(obsread, fmt) + + @decorators.check_dataset_is_census_data + def get_gene_metadata(self, value_filter=None, column_names=None, measurement_name=None, fmt=None): + """Get the gene metadata (var) data from the Census API""" + if value_filter is None: + raise Exception("No value filter was provided, dataset is too large to fit in memory. \ + Memory-Efficient queries are not supported yet.") + elif measurement_name is None: + raise ValueError("measurment_name must be provided , i.e. 'RNA'") + fmt = fmt if fmt is not None else "pandas" + with cellxgene_census.open_soma( + census_version=self.census_version + ) as census: + var = census[self.dataset][self.organism].ms[measurement_name].var + varread = None + if column_names: + varread = var.read(value_filter=value_filter, column_names=column_names) + else: + varread = var.read(value_filter=value_filter) + return self.fmt_cellxgene_data(varread, fmt) + + @decorators.check_dataset_is_census_data + def get_measurement_matrix(self, upper=None, lower=None, value_adjustment=None, measurement_name=None, fmt=None, todense=None): + """Count matrix for an input measurement by slice - def get_dataframe(self): - if self.df is None: - raise Exception( - "Haven't instantiated a DataFrame yet. You can call self.fetch_data first." - ) - return self.df + Args: + upper (_type_, optional): _description_. Defaults to None. + lower (_type_, optional): _description_. Defaults to None. + value_adjustment (_type_, optional): _description_. Defaults to None. + measurement_name (_type_, optional): _description_. Defaults to None. - def get_data(self, type="df"): - if type == "df": - return self.get_dataframe() - elif type == "pyarrow": - raise Exception("PyArrow format not supported by TDC yet.") - else: - raise Exception("Type must be set to df or pyarrow") + Raises: + Exception: _description_ + Exception: _description_ + """ + if upper is None or lower is None: + raise Exception("No upper and/or lower bound for slicing was provided. Dataset is too large to fit in memory. \ + Memory-Efficient queries are not supported yet.") + elif measurement_name is None: + raise Exception("measurement_name was not provided.") + elif fmt is not None and fmt not in ["scipy", "pyarrow"]: + raise ValueError("measurement_matrix only supports 'scipy' or 'pyarrow' format") + value_adjustment = value_adjustment if value_adjustment is not None else "raw" + todense = todense if todense is not None else False + fmt = fmt if fmt is not None else "scipy" + if todense and fmt != "scipy": + raise ValueError("dense representation only available in scipy format") + with cellxgene_census.open_soma( + census_version=self.census_version + ) as census: + n_obs = len(census[self.dataset][self.organism].obs) + n_var = len(census[self.dataset][self.organism].ms[measurement_name].var) + X = census[self.dataset][self.organism].ms[measurement_name].X[value_adjustment] + slc = X.read((slice(lower, upper),)).coos((n_obs, n_var)) + out = self.fmt_cellxgene_data(slc, fmt) + return out if not todense else out.todense() + + @decorators.check_dataset_is_census_data + def get_feature_dataset_presence_matrix(self, upper=None, lower=None, measurement_name=None, fmt=None, todense=None): + if upper is None or lower is None: + raise ValueError("No upper and/or lower bound for slicing was provided. Dataset is too large to fit in memory. \ + Memory-Efficient queries are not supported yet.") + elif measurement_name is None: + raise ValueError("measurement_name was not provided") + elif fmt is not None and fmt not in ["scipy", "pyarrow"]: + raise ValueError("feature dataset presence matrix only supports 'scipy' or 'pyarrow' formats") + todense = todense if todense is not None else False + fmt = fmt if fmt is not None else "scipy" + if todense and fmt!="scipy": + raise ValueError("dense representation only available in scipy format") + with cellxgene_census.open_soma( + census_version=self.census_version + ) as census: + n_obs = len(census[self.dataset][self.organism].obs) + n_var = len(census[self.dataset][self.organism].ms[measurement_name].var) + fMatrix = census[self.dataset][self.organism].ms[measurement_name]["feature_dataset_presence_matrix"] + slc = fMatrix.read((slice(0, 5),)).coos((n_obs,n_var)) + out = self.fmt_cellxgene_data(slc, fmt) + return out if not todense else out.todense() if __name__ == "__main__": # TODO: tmp, run testing suite when this file is called as main - print("initializing object") - loader = CensusResource(value_filter="tissue == 'brain' and sex == 'male'", - column_names=["assay", "cell_type", "tissue"]) - print("getting") - df = loader.get_data() - print("getting head()") - # print(df.head()) - print("no dense") - print(df.to_scipy()) - print("done!") + print("running tests for census resource") + print("instantiating resource") + resource = CensusResource() + cell_value_filter = "tissue == 'brain' and sex == 'male'" + cell_column_names = ["assay", "cell_type", "tissue"] + gene_value_filter = "feature_id in ['ENSG00000161798', 'ENSG00000188229']" + gene_column_names = ["feature_name", "feature_reference"] + print("getting cell metadata as pandas dataframe") + obsdf = resource.get_cell_metadata(value_filter=cell_value_filter, column_names=cell_column_names, fmt="pandas") + print("success!") + print(obsdf.head()) + print("geting gene metadata as pyarrow") + varpyarrow = resource.get_gene_metadata(value_filter=gene_value_filter, column_names=gene_column_names, fmt="pyarrow", measurement_name="RNA") + print("success!") + print(varpyarrow) + print("getting sample count matrix, checking todense() and scipy") + Xslice = resource.get_measurement_matrix(upper=5, lower=0, measurement_name="RNA", fmt="scipy", todense=True) + print("success") + print(Xslice) + print("getting feature presence matrix, checking pyarrow") + FMslice = resource.get_feature_dataset_presence_matrix(upper=5, lower=0, measurement_name="RNA", fmt="pyarrow", todense=False) + print("success") + print(FMslice) + print("all tests passed") \ No newline at end of file From bb83822714f86f55da00f1ac0a32de60cde52a6d Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Mon, 4 Mar 2024 20:14:43 -0500 Subject: [PATCH 05/39] implement memory-efficient retrieval of count matrix and update environment. Also support AnnData retrieval --- environment.yml | 4 +- tdc/resource/cellxgene-census.py | 152 +++++++++++++++++++++++++++++-- 2 files changed, 148 insertions(+), 8 deletions(-) diff --git a/environment.yml b/environment.yml index 5bc18717..a9ef8915 100644 --- a/environment.yml +++ b/environment.yml @@ -16,7 +16,9 @@ dependencies: - tqdm=4.65.0 - pip: - cellxgene-census==1.10.2 - - PyTDC==0.4.1 + - gget==0.28.4 - rdkit==2023.9.5 - tiledbsoma==1.7.2 - yapf==0.40.2 +variables: + KMP_DUPLICATE_LIB_OK: "TRUE" diff --git a/tdc/resource/cellxgene-census.py b/tdc/resource/cellxgene-census.py index 4ee2ea4b..d400e608 100644 --- a/tdc/resource/cellxgene-census.py +++ b/tdc/resource/cellxgene-census.py @@ -3,14 +3,14 @@ os.environ['KMP_DUPLICATE_LIB_OK'] = "TRUE" # TODO: find better fix or encode in environment / docker ^^^ import cellxgene_census -from functools import wraps -from pandas import concat +import gget import tiledbsoma class CensusResource: _CENSUS_DATA = "census_data" + _CENSUS_META = "census_info" _FEATURE_PRESENCE = "feature_dataset_presence_matrix" _LATEST_CENSUS = "2023-12-15" _HUMAN = "homo_sapiens" @@ -21,14 +21,20 @@ def check_dataset_is_census_data(cls,func): # @wraps(func) def check(*args, **kwargs): self = args[0] - if self.dataset != self._CENSUS_DATA: - raise ValueError("This function requires the '{}' dataset".format(self._CENSUS_DATA)) + self.dataset = self._CENSUS_DATA + return func(*args, **kwargs) + return check + + @classmethod + def check_dataset_is_census_info(cls,func): + def check(*args, **kwargs): + self = args[0] + self.dataset = self._CENSUS_META return func(*args, **kwargs) return check def __init__(self, census_version=None, - dataset=None, organism=None ): """Initialize the Census Resource. @@ -38,8 +44,8 @@ def __init__(self, TODO: complete """ self.census_version = census_version if census_version is not None else self._LATEST_CENSUS - self.dataset = dataset if dataset is not None else self._CENSUS_DATA self.organism = organism if organism is not None else self._HUMAN + self.dataset = None # variable to set target census collection to either info or data def fmt_cellxgene_data(self, tiledb_ptr, fmt=None): if fmt is None: @@ -149,6 +155,138 @@ def get_feature_dataset_presence_matrix(self, upper=None, lower=None, measuremen out = self.fmt_cellxgene_data(slc, fmt) return out if not todense else out.todense() + @decorators.check_dataset_is_census_info + def get_metadata(self): + """Get the metadata for the Cell Census.""" + with cellxgene_census.open_soma( + census_version=self.census_version + ) as census: + return census[self.dataset]["summary"] + + @decorators.check_dataset_is_census_info + def get_dataset_metadata(self): + """Get the metadata for the Cell Census's datasets.""" + with cellxgene_census.open_soma( + census_version=self.census_version + ) as census: + return census[self.dataset]["datasets"] + + @decorators.check_dataset_is_census_info + def get_cell_count_metadata(self): + """Get the cell counts across cell metadata for the Cell Census.""" + with cellxgene_census.open_soma( + census_version=self.census_version + ) as census: + return census[self.dataset]["summary_cell_counts"] + + @decorators.check_dataset_is_census_data + def query_measurement_matrix( + self, + value_filter=None, + value_adjustment=None, + measurement_name=None, + fmt=None, + todense=None + ): + """Query the Census Measurement Matrix. Function returns a Python generator. + + Args: + value_filter (_type_, optional): _description_. Defaults to None. + value_adjustment (_type_, optional): _description_. Defaults to None. + measurement_name (_type_, optional): _description_. Defaults to None. + fmt (_type_, optional): _description_. Defaults to None. + todense (_type_, optional): _description_. Defaults to None. + + Raises: + ValueError: _description_ + Exception: _description_ + ValueError: _description_ + ValueError: _description_ + + Yields: + a slice of the output query in the specified format + """ + if value_filter is None: + raise ValueError("query_measurement_matrix expects a value_filter. if you don't plan to apply a filter, use get_measurement_matrix()") + elif measurement_name is None: + raise Exception("measurement_name was not provided.") + elif fmt is not None and fmt not in ["scipy", "pyarrow"]: + raise ValueError("measurement_matrix only supports 'scipy' or 'pyarrow' format") + value_adjustment = value_adjustment if value_adjustment is not None else "raw" + todense = todense if todense is not None else False + fmt = fmt if fmt is not None else "scipy" + if todense and fmt != "scipy": + raise ValueError("dense representation only available in scipy format") + with cellxgene_census.open_soma( + census_version=self.census_version + ) as census: + organism = census[self.dataset][self.organism] + query = organism.axis_query( + measurement_name = measurement_name, + obs_query = tiledbsoma.AxisQuery( + value_filter = value_filter + ) + ) + it = query.X(value_adjustment).tables() + for slc in it: + out = self.fmt_cellxgene_data(slc, fmt) + out = out if not todense else out.todense() + yield out + + + @classmethod + def gget_czi_cellxgene(cls, **kwargs): + """Wrapper for cellxgene gget() + https://chanzuckerberg.github.io/cellxgene-census/notebooks/api_demo/census_gget_demo.html + Support for AnnData or DataFrame. Params included below + + General args: + - species Choice of 'homo_sapiens' or 'mus_musculus'. Default: 'homo_sapiens'. + - gene Str or list of gene name(s) or Ensembl ID(s), e.g. ['ACE2', 'SLC5A1'] or ['ENSG00000130234', 'ENSG00000100170']. Default: None. + NOTE: Set ensembl=True when providing Ensembl ID(s) instead of gene name(s). + See https://cellxgene.cziscience.com/gene-expression for examples of available genes. + - ensembl True/False (default: False). Set to True when genes are provided as Ensembl IDs. + - column_names List of metadata columns to return (stored in AnnData.obs when meta_only=False). + Default: ["dataset_id", "assay", "suspension_type", "sex", "tissue_general", "tissue", "cell_type"] + For more options see: https://api.cellxgene.cziscience.com/curation/ui/#/ -> Schemas -> dataset + - meta_only True/False (default: False). If True, returns only metadata dataframe (corresponds to AnnData.obs). + - census_version Str defining version of Census, e.g. "2023-05-15" or "latest" or "stable". Default: "stable". + - verbose True/False whether to print progress information. Default True. + - out If provided, saves the generated AnnData h5ad (or csv when meta_only=True) file with the specified path. Default: None. + + Cell metadata attributes: + - tissue Str or list of tissue(s), e.g. ['lung', 'blood']. Default: None. + See https://cellxgene.cziscience.com/gene-expression for examples of available tissues. + - cell_type Str or list of celltype(s), e.g. ['mucus secreting cell', 'neuroendocrine cell']. Default: None. + See https://cellxgene.cziscience.com/gene-expression and select a tissue to see examples of available celltypes. + - development_stage Str or list of development stage(s). Default: None. + - disease Str or list of disease(s). Default: None. + - sex Str or list of sex(es), e.g. 'female'. Default: None. + - is_primary_data True/False (default: True). If True, returns only the canonical instance of the cellular observation. + This is commonly set to False for meta-analyses reusing data or for secondary views of data. + - dataset_id Str or list of CELLxGENE dataset ID(s). Default: None. + - tissue_general_ontology_term_id Str or list of high-level tissue UBERON ID(s). Default: None. + Also see: https://github.com/chanzuckerberg/single-cell-data-portal/blob/9b94ccb0a2e0a8f6182b213aa4852c491f6f6aff/backend/wmg/data/tissue_mapper.py + - tissue_general Str or list of high-level tissue label(s). Default: None. + Also see: https://github.com/chanzuckerberg/single-cell-data-portal/blob/9b94ccb0a2e0a8f6182b213aa4852c491f6f6aff/backend/wmg/data/tissue_mapper.py + - tissue_ontology_term_id Str or list of tissue ontology term ID(s) as defined in the CELLxGENE dataset schema. Default: None. + - assay_ontology_term_id Str or list of assay ontology term ID(s) as defined in the CELLxGENE dataset schema. Default: None. + - assay Str or list of assay(s) as defined in the CELLxGENE dataset schema. Default: None. + - cell_type_ontology_term_id Str or list of celltype ontology term ID(s) as defined in the CELLxGENE dataset schema. Default: None. + - development_stage_ontology_term_id Str or list of development stage ontology term ID(s) as defined in the CELLxGENE dataset schema. Default: None. + - disease_ontology_term_id Str or list of disease ontology term ID(s) as defined in the CELLxGENE dataset schema. Default: None. + - donor_id Str or list of donor ID(s) as defined in the CELLxGENE dataset schema. Default: None. + - self_reported_ethnicity_ontology_term_id Str or list of self reported ethnicity ontology ID(s) as defined in the CELLxGENE dataset schema. Default: None. + - self_reported_ethnicity Str or list of self reported ethnicity as defined in the CELLxGENE dataset schema. Default: None. + - sex_ontology_term_id Str or list of sex ontology ID(s) as defined in the CELLxGENE dataset schema. Default: None. + - suspension_type Str or list of suspension type(s) as defined in the CELLxGENE dataset schema. Default: None. + + Returns AnnData object (when meta_only=False) or dataframe (when meta_only=True). + + """ + gget.setup("cellxgene") + return gget.cellxgene(**kwargs) + if __name__ == "__main__": # TODO: tmp, run testing suite when this file is called as main @@ -158,7 +296,7 @@ def get_feature_dataset_presence_matrix(self, upper=None, lower=None, measuremen cell_value_filter = "tissue == 'brain' and sex == 'male'" cell_column_names = ["assay", "cell_type", "tissue"] gene_value_filter = "feature_id in ['ENSG00000161798', 'ENSG00000188229']" - gene_column_names = ["feature_name", "feature_reference"] + gene_column_names = ["feature_name", "feature_length"] print("getting cell metadata as pandas dataframe") obsdf = resource.get_cell_metadata(value_filter=cell_value_filter, column_names=cell_column_names, fmt="pandas") print("success!") From 5cc822d963a941cd1c13266658ee11508783547e Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 10:48:34 -0500 Subject: [PATCH 06/39] lint all base files --- run_tests.py | 2 +- setup.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/run_tests.py b/run_tests.py index 981f21b7..cd250420 100644 --- a/run_tests.py +++ b/run_tests.py @@ -6,4 +6,4 @@ suite = loader.discover(start_dir) runner = unittest.TextTestRunner() - runner.run(suite) \ No newline at end of file + runner.run(suite) diff --git a/setup.py b/setup.py index 063fca8b..22635ad6 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,8 @@ def readme(): # read the contents of requirements.txt -with open(path.join(this_directory, "requirements.txt"), encoding="utf-8") as f: +with open(path.join(this_directory, "requirements.txt"), + encoding="utf-8") as f: requirements = f.read().splitlines() setup( From b80432344e893e4e6917243284a412559136f7a9 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 11:15:16 -0500 Subject: [PATCH 07/39] yapr lint google format on tdc/benchmark_group, chem_utils, generation, multi_pred --- tdc/benchmark_group/base_group.py | 46 +-- tdc/benchmark_group/docking_group.py | 55 ++-- tdc/benchmark_group/drugcombo_group.py | 8 +- tdc/chem_utils/evaluator.py | 58 ++-- tdc/chem_utils/featurize/_smiles2pubchem.py | 35 ++- tdc/chem_utils/featurize/_xyz2mol.py | 76 ++--- tdc/chem_utils/featurize/molconvert.py | 70 ++--- tdc/chem_utils/oracle/docking.py | 6 +- tdc/chem_utils/oracle/filter.py | 29 +- tdc/chem_utils/oracle/oracle.py | 294 +++++++++++--------- tdc/generation/bi_generation_dataset.py | 8 +- tdc/generation/generation_dataset.py | 31 ++- tdc/generation/ligandmolgen.py | 1 - tdc/generation/molgen.py | 7 +- tdc/generation/reaction.py | 1 - tdc/generation/retrosyn.py | 7 +- tdc/generation/sbdd.py | 7 +- tdc/multi_pred/antibodyaff.py | 1 - tdc/multi_pred/bi_pred_dataset.py | 104 ++++--- tdc/multi_pred/catalyst.py | 9 +- tdc/multi_pred/ddi.py | 15 +- tdc/multi_pred/drugres.py | 16 +- tdc/multi_pred/drugsyn.py | 8 +- tdc/multi_pred/dti.py | 32 +-- tdc/multi_pred/gda.py | 9 +- tdc/multi_pred/mti.py | 9 +- tdc/multi_pred/multi_pred_dataset.py | 18 +- tdc/multi_pred/peptidemhc.py | 1 - tdc/multi_pred/ppi.py | 15 +- tdc/multi_pred/tcr_epi.py | 8 +- tdc/multi_pred/test_multi_pred.py | 1 - tdc/multi_pred/trialoutcome.py | 8 +- 32 files changed, 527 insertions(+), 466 deletions(-) diff --git a/tdc/benchmark_group/base_group.py b/tdc/benchmark_group/base_group.py index ab2fb1ab..e85accad 100644 --- a/tdc/benchmark_group/base_group.py +++ b/tdc/benchmark_group/base_group.py @@ -29,7 +29,6 @@ class BenchmarkGroup: - """Boilerplate of benchmark group class. It downloads, processes, and loads a set of benchmark classes along with their splits. It also provides evaluators and train/valid splitters.""" def __init__(self, name, path="./data", file_format="csv"): @@ -118,15 +117,19 @@ def get_train_valid_split(self, seed, benchmark, split_type="default"): frac = [0.875, 0.125, 0.0] if split_method == "scaffold": - out = create_scaffold_split(train_val, seed, frac=frac, entity="Drug") + out = create_scaffold_split(train_val, + seed, + frac=frac, + entity="Drug") elif split_method == "random": out = create_fold(train_val, seed, frac=frac) elif split_method == "combination": out = create_combination_split(train_val, seed, frac=frac) elif split_method == "group": - out = create_group_split( - train_val, seed, holdout_frac=0.2, group_column="Year" - ) + out = create_group_split(train_val, + seed, + holdout_frac=0.2, + group_column="Year") else: raise NotImplementedError return out["train"], out["valid"] @@ -178,8 +181,11 @@ def evaluate(self, pred, testing=True, benchmark=None, save_dict=True): elif self.file_format == "pkl": test = pd.read_pickle(os.path.join(data_path, "test.pkl")) y = test.Y.values - evaluator = eval("Evaluator(name = '" + metric_dict[data_name] + "')") - out[data_name] = {metric_dict[data_name]: round(evaluator(y, pred_), 3)} + evaluator = eval("Evaluator(name = '" + metric_dict[data_name] + + "')") + out[data_name] = { + metric_dict[data_name]: round(evaluator(y, pred_), 3) + } # If reporting accuracy across target classes if "target_class" in test.columns: @@ -190,13 +196,11 @@ def evaluate(self, pred, testing=True, benchmark=None, save_dict=True): y_subset = test_subset.Y.values pred_subset = test_subset.pred.values - evaluator = eval( - "Evaluator(name = '" + metric_dict[data_name_subset] + "')" - ) + evaluator = eval("Evaluator(name = '" + + metric_dict[data_name_subset] + "')") out[data_name_subset] = { - metric_dict[data_name_subset]: round( - evaluator(y_subset, pred_subset), 3 - ) + metric_dict[data_name_subset]: + round(evaluator(y_subset, pred_subset), 3) } return out else: @@ -207,10 +211,14 @@ def evaluate(self, pred, testing=True, benchmark=None, save_dict=True): ) data_name = fuzzy_search(benchmark, self.dataset_names) metric_dict = bm_metric_names[self.name] - evaluator = eval("Evaluator(name = '" + metric_dict[data_name] + "')") + evaluator = eval("Evaluator(name = '" + metric_dict[data_name] + + "')") return {metric_dict[data_name]: round(evaluator(true, pred), 3)} - def evaluate_many(self, preds, save_file_name=None, results_individual=None): + def evaluate_many(self, + preds, + save_file_name=None, + results_individual=None): """ This function returns the data in a format needed to submit to the Leaderboard @@ -225,11 +233,9 @@ def evaluate_many(self, preds, save_file_name=None, results_individual=None): min_requirement = 5 if len(preds) < min_requirement: - return ValueError( - "Must have predictions from at least " - + str(min_requirement) - + " runs for leaderboard submission" - ) + return ValueError("Must have predictions from at least " + + str(min_requirement) + + " runs for leaderboard submission") if results_individual is None: individual_results = [] for pred in preds: diff --git a/tdc/benchmark_group/docking_group.py b/tdc/benchmark_group/docking_group.py index ad8abb46..1a683a20 100644 --- a/tdc/benchmark_group/docking_group.py +++ b/tdc/benchmark_group/docking_group.py @@ -34,9 +34,11 @@ class docking_group(BenchmarkGroup): """ - def __init__( - self, path="./data", num_workers=None, num_cpus=None, num_max_call=5000 - ): + def __init__(self, + path="./data", + num_workers=None, + num_cpus=None, + num_max_call=5000): """Create a docking group benchmark loader. Raises: @@ -157,7 +159,12 @@ def get(self, benchmark, num_max_call=5000): data = pd.read_csv(os.path.join(self.path, "zinc.tab"), sep="\t") return {"oracle": oracle, "data": data, "name": dataset} - def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True): + def evaluate(self, + pred, + true=None, + benchmark=None, + m1_api=None, + save_dict=True): """Summary Args: @@ -227,7 +234,9 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) docking_scores = oracle(pred_) print_sys("---- Calculating average docking scores ----") - if len(np.where(np.array(list(docking_scores.values())) > 0)[0]) > 0.7: + if len( + np.where(np.array(list(docking_scores.values())) > 0) + [0]) > 0.7: ## check if the scores are all positive.. if so, make them all negative docking_scores = {j: -k for j, k in docking_scores.items()} if save_dict: @@ -275,7 +284,8 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) if save_dict: results["pass_list"] = pred_filter results["%pass"] = float(len(pred_filter)) / 100 - results["top1_%pass"] = min([docking_scores[i] for i in pred_filter]) + results["top1_%pass"] = min( + [docking_scores[i] for i in pred_filter]) print_sys("---- Calculating diversity ----") from ..evaluator import Evaluator @@ -284,19 +294,23 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) results["diversity"] = score print_sys("---- Calculating novelty ----") evaluator = Evaluator(name="Novelty") - training = pd.read_csv(os.path.join(self.path, "zinc.tab"), sep="\t") + training = pd.read_csv(os.path.join(self.path, "zinc.tab"), + sep="\t") score = evaluator(pred_, training.smiles.values) results["novelty"] = score results["top smiles"] = [ - i[0] for i in sorted(docking_scores.items(), key=lambda x: x[1]) + i[0] + for i in sorted(docking_scores.items(), key=lambda x: x[1]) ] results_max_call[num_max_call] = results results_all[data_name] = results_max_call return results_all - def evaluate_many( - self, preds, save_file_name=None, m1_api=None, results_individual=None - ): + def evaluate_many(self, + preds, + save_file_name=None, + m1_api=None, + results_individual=None): """evaluate many runs together and output submission ready pkl file. Args: @@ -310,11 +324,9 @@ def evaluate_many( """ min_requirement = 3 if len(preds) < min_requirement: - return ValueError( - "Must have predictions from at least " - + str(min_requirement) - + " runs for leaderboard submission" - ) + return ValueError("Must have predictions from at least " + + str(min_requirement) + + " runs for leaderboard submission") if results_individual is None: individual_results = [] for pred in preds: @@ -345,13 +357,10 @@ def evaluate_many( for metric in metrics: if metric == "top smiles": results_agg_target_call[metric] = np.unique( - np.array( - [ - individual_results[fold][target][num_calls][metric] - for fold in range(num_folds) - ] - ).reshape(-1) - ).tolist() + np.array([ + individual_results[fold][target][num_calls] + [metric] for fold in range(num_folds) + ]).reshape(-1)).tolist() else: res = [ individual_results[fold][target][num_calls][metric] diff --git a/tdc/benchmark_group/drugcombo_group.py b/tdc/benchmark_group/drugcombo_group.py index 801b411c..a1b5421a 100644 --- a/tdc/benchmark_group/drugcombo_group.py +++ b/tdc/benchmark_group/drugcombo_group.py @@ -15,11 +15,11 @@ class drugcombo_group(BenchmarkGroup): def __init__(self, path="./data"): """create a drug combination benchmark group""" super().__init__(name="DrugCombo_Group", path=path, file_format="pkl") - - + def get_cell_line_meta_data(self): import os from ..utils.load import download_wrapper from ..utils import load_dict - name = download_wrapper('drug_comb_meta_data', self.path, ['drug_comb_meta_data']) - return load_dict(os.path.join(self.path, name + '.pkl')) \ No newline at end of file + name = download_wrapper('drug_comb_meta_data', self.path, + ['drug_comb_meta_data']) + return load_dict(os.path.join(self.path, name + '.pkl')) diff --git a/tdc/chem_utils/evaluator.py b/tdc/chem_utils/evaluator.py index 8a00f290..e2e60fb3 100644 --- a/tdc/chem_utils/evaluator.py +++ b/tdc/chem_utils/evaluator.py @@ -12,7 +12,8 @@ rdBase.DisableLog("rdApp.error") except: - raise ImportError("Please install rdkit by 'conda install -c conda-forge rdkit'! ") + raise ImportError( + "Please install rdkit by 'conda install -c conda-forge rdkit'! ") def single_molecule_validity(smiles): @@ -57,7 +58,8 @@ def canonicalize(smiles): def unique_lst_of_smiles(list_of_smiles): canonical_smiles_lst = list(map(canonicalize, list_of_smiles)) - canonical_smiles_lst = list(filter(lambda x: x is not None, canonical_smiles_lst)) + canonical_smiles_lst = list( + filter(lambda x: x is not None, canonical_smiles_lst)) canonical_smiles_lst = list(set(canonical_smiles_lst)) return canonical_smiles_lst @@ -88,11 +90,9 @@ def novelty(generated_smiles_lst, training_smiles_lst): """ generated_smiles_lst = unique_lst_of_smiles(generated_smiles_lst) training_smiles_lst = unique_lst_of_smiles(training_smiles_lst) - novel_ratio = ( - sum([1 if i in training_smiles_lst else 0 for i in generated_smiles_lst]) - * 1.0 - / len(generated_smiles_lst) - ) + novel_ratio = (sum( + [1 if i in training_smiles_lst else 0 for i in generated_smiles_lst]) * + 1.0 / len(generated_smiles_lst)) return 1 - novel_ratio @@ -107,14 +107,19 @@ def diversity(list_of_smiles): div: float """ list_of_unique_smiles = unique_lst_of_smiles(list_of_smiles) - list_of_mol = [Chem.MolFromSmiles(smiles) for smiles in list_of_unique_smiles] + list_of_mol = [ + Chem.MolFromSmiles(smiles) for smiles in list_of_unique_smiles + ] list_of_fp = [ - AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048, useChirality=False) + AllChem.GetMorganFingerprintAsBitVect(mol, + 2, + nBits=2048, + useChirality=False) for mol in list_of_mol ] avg_lst = [] for idx, fp in enumerate(list_of_fp): - for fp2 in list_of_fp[idx + 1 :]: + for fp2 in list_of_fp[idx + 1:]: sim = DataStructs.TanimotoSimilarity(fp, fp2) ### option I distance = 1 - sim @@ -235,7 +240,9 @@ def get_fingerprints(mols, radius=2, length=4096): Returns: a list of fingerprints """ - return [AllChem.GetMorganFingerprintAsBitVect(m, radius, length) for m in mols] + return [ + AllChem.GetMorganFingerprintAsBitVect(m, radius, length) for m in mols + ] def get_mols(smiles_list): @@ -267,10 +274,8 @@ def calculate_internal_pairwise_similarities(smiles_list): Symmetric matrix of pairwise similarities. Diagonal is set to zero. """ if len(smiles_list) > 10000: - logger.warning( - f"Calculating internal similarity on large set of " - f"SMILES strings ({len(smiles_list)})" - ) + logger.warning(f"Calculating internal similarity on large set of " + f"SMILES strings ({len(smiles_list)})") mols = get_mols(smiles_list) fps = get_fingerprints(mols) @@ -313,7 +318,8 @@ def kl_divergence(generated_smiles_lst, training_smiles_lst): def canonical(smiles): mol = Chem.MolFromSmiles(smiles) if mol is not None: - return Chem.MolToSmiles(mol, isomericSmiles=True) ### todo double check + return Chem.MolToSmiles(mol, + isomericSmiles=True) ### todo double check else: return None @@ -323,17 +329,20 @@ def canonical(smiles): generated_lst_mol = list(filter(filter_out_func, generated_lst_mol)) training_lst_mol = list(filter(filter_out_func, training_lst_mol)) - d_sampled = calculate_pc_descriptors(generated_lst_mol, pc_descriptor_subset) + d_sampled = calculate_pc_descriptors(generated_lst_mol, + pc_descriptor_subset) d_chembl = calculate_pc_descriptors(training_lst_mol, pc_descriptor_subset) kldivs = {} for i in range(4): - kldiv = continuous_kldiv(X_baseline=d_chembl[:, i], X_sampled=d_sampled[:, i]) + kldiv = continuous_kldiv(X_baseline=d_chembl[:, i], + X_sampled=d_sampled[:, i]) kldivs[pc_descriptor_subset[i]] = kldiv # ... and for the int valued ones. for i in range(4, 9): - kldiv = discrete_kldiv(X_baseline=d_chembl[:, i], X_sampled=d_sampled[:, i]) + kldiv = discrete_kldiv(X_baseline=d_chembl[:, i], + X_sampled=d_sampled[:, i]) kldivs[pc_descriptor_subset[i]] = kldiv # pairwise similarity @@ -344,7 +353,8 @@ def canonical(smiles): sampled_sim = calculate_internal_pairwise_similarities(generated_lst_mol) sampled_sim = sampled_sim.max(axis=1) - kldiv_int_int = continuous_kldiv(X_baseline=chembl_sim, X_sampled=sampled_sim) + kldiv_int_int = continuous_kldiv(X_baseline=chembl_sim, + X_sampled=sampled_sim) kldivs["internal_similarity"] = kldiv_int_int """ # for some reason, this runs into problems when both sets are identical. @@ -395,10 +405,14 @@ def _calculate_distribution_statistics(chemnet, molecules): cov = np.cov(gen_mol_act.T) return mu, cov - mu_ref, cov_ref = _calculate_distribution_statistics(chemnet, training_smiles_lst) + mu_ref, cov_ref = _calculate_distribution_statistics( + chemnet, training_smiles_lst) mu, cov = _calculate_distribution_statistics(chemnet, generated_smiles_lst) - FCD = fcd.calculate_frechet_distance(mu1=mu_ref, mu2=mu, sigma1=cov_ref, sigma2=cov) + FCD = fcd.calculate_frechet_distance(mu1=mu_ref, + mu2=mu, + sigma1=cov_ref, + sigma2=cov) fcd_distance = np.exp(-0.2 * FCD) return fcd_distance diff --git a/tdc/chem_utils/featurize/_smiles2pubchem.py b/tdc/chem_utils/featurize/_smiles2pubchem.py index f41fc82a..0cd21cbb 100644 --- a/tdc/chem_utils/featurize/_smiles2pubchem.py +++ b/tdc/chem_utils/featurize/_smiles2pubchem.py @@ -7,8 +7,8 @@ rdBase.DisableLog("rdApp.error") except: - raise ImportError("Please install rdkit by 'conda install -c conda-forge rdkit'! ") - + raise ImportError( + "Please install rdkit by 'conda install -c conda-forge rdkit'! ") try: import networkx as nx @@ -409,9 +409,11 @@ def func_4(mol, bits): for bondIdx in ring: BeginAtom = mol.GetBondWithIdx(bondIdx).GetBeginAtom() EndAtom = mol.GetBondWithIdx(bondIdx).GetEndAtom() - if BeginAtom.GetAtomicNum() not in [1, 6] or EndAtom.GetAtomicNum() not in [ - 1, - 6, + if BeginAtom.GetAtomicNum() not in [ + 1, 6 + ] or EndAtom.GetAtomicNum() not in [ + 1, + 6, ]: heteroatom = True break @@ -752,9 +754,11 @@ def func_7(mol, bits): for bondIdx in ring: BeginAtom = mol.GetBondWithIdx(bondIdx).GetBeginAtom() EndAtom = mol.GetBondWithIdx(bondIdx).GetEndAtom() - if BeginAtom.GetAtomicNum() not in [1, 6] or EndAtom.GetAtomicNum() not in [ - 1, - 6, + if BeginAtom.GetAtomicNum() not in [ + 1, 6 + ] or EndAtom.GetAtomicNum() not in [ + 1, + 6, ]: heteroatom = True break @@ -862,9 +866,11 @@ def func_8(mol, bits): for bondIdx in ring: BeginAtom = mol.GetBondWithIdx(bondIdx).GetBeginAtom() EndAtom = mol.GetBondWithIdx(bondIdx).GetEndAtom() - if BeginAtom.GetAtomicNum() not in [1, 6] or EndAtom.GetAtomicNum() not in [ - 1, - 6, + if BeginAtom.GetAtomicNum() not in [ + 1, 6 + ] or EndAtom.GetAtomicNum() not in [ + 1, + 6, ]: heteroatom = True break @@ -936,6 +942,7 @@ def calcPubChemFingerAll(s): AllBits[index3 + 115] = 1 return np.array(AllBits) + def canonicalize(smiles): mol = Chem.MolFromSmiles(smiles) if mol is not None: @@ -943,13 +950,13 @@ def canonicalize(smiles): else: return None + def smiles2pubchem(s): s = canonicalize(s) try: features = calcPubChemFingerAll(s) except: - print( - "pubchem fingerprint not working for smiles: " + s + " convert to 0 vectors" - ) + print("pubchem fingerprint not working for smiles: " + s + + " convert to 0 vectors") features = np.zeros((881,)) return np.array(features) diff --git a/tdc/chem_utils/featurize/_xyz2mol.py b/tdc/chem_utils/featurize/_xyz2mol.py index 92af581f..9c91d2c6 100644 --- a/tdc/chem_utils/featurize/_xyz2mol.py +++ b/tdc/chem_utils/featurize/_xyz2mol.py @@ -2,15 +2,14 @@ from collections import defaultdict from typing import List - try: from rdkit import Chem from rdkit import rdBase rdBase.DisableLog("rdApp.error") except: - raise ImportError("Please install rdkit by 'conda install -c conda-forge rdkit'! ") - + raise ImportError( + "Please install rdkit by 'conda install -c conda-forge rdkit'! ") try: import networkx as nx @@ -19,7 +18,6 @@ from ...utils import print_sys - ############## begin xyz2mol ################ # from https://github.com/jensengroup/xyz2mol/blob/master/xyz2mol.py @@ -121,7 +119,6 @@ "pu", ] - global atomic_valence global atomic_valence_electrons @@ -179,7 +176,8 @@ def get_UA(maxValence_list, valence_list): """ """ UA = [] DU = [] - for i, (maxValence, valence) in enumerate(zip(maxValence_list, valence_list)): + for i, (maxValence, valence) in enumerate(zip(maxValence_list, + valence_list)): if not maxValence - valence > 0: continue UA.append(i) @@ -235,7 +233,8 @@ def charge_is_OK( BO_valences = list(BO.sum(axis=1)) for i, atom in enumerate(atoms): - q = get_atomic_charge(atom, atomic_valence_electrons[atom], BO_valences[i]) + q = get_atomic_charge(atom, atomic_valence_electrons[atom], + BO_valences[i]) Q += q if atom == 6: number_of_single_bonds_to_C = list(BO[i, :]).count(1) @@ -381,8 +380,7 @@ def BO2mol( if l != l2: raise RuntimeError( - "sizes of adjMat ({0:d}) and Atoms {1:d} differ".format(l, l2) - ) + "sizes of adjMat ({0:d}) and Atoms {1:d} differ".format(l, l2)) rwMol = Chem.RWMol(mol) @@ -403,23 +401,23 @@ def BO2mol( mol = rwMol.GetMol() if allow_charged_fragments: - mol = set_atomic_charges( - mol, atoms, atomic_valence_electrons, BO_valences, BO_matrix, mol_charge - ) + mol = set_atomic_charges(mol, atoms, atomic_valence_electrons, + BO_valences, BO_matrix, mol_charge) else: - mol = set_atomic_radicals(mol, atoms, atomic_valence_electrons, BO_valences) + mol = set_atomic_radicals(mol, atoms, atomic_valence_electrons, + BO_valences) return mol -def set_atomic_charges( - mol, atoms, atomic_valence_electrons, BO_valences, BO_matrix, mol_charge -): +def set_atomic_charges(mol, atoms, atomic_valence_electrons, BO_valences, + BO_matrix, mol_charge): """ """ q = 0 for i, atom in enumerate(atoms): a = mol.GetAtomWithIdx(i) - charge = get_atomic_charge(atom, atomic_valence_electrons[atom], BO_valences[i]) + charge = get_atomic_charge(atom, atomic_valence_electrons[atom], + BO_valences[i]) q += charge if atom == 6: number_of_single_bonds_to_C = list(BO_matrix[i, :]).count(1) @@ -444,7 +442,8 @@ def set_atomic_radicals(mol, atoms, atomic_valence_electrons, BO_valences): """ for i, atom in enumerate(atoms): a = mol.GetAtomWithIdx(i) - charge = get_atomic_charge(atom, atomic_valence_electrons[atom], BO_valences[i]) + charge = get_atomic_charge(atom, atomic_valence_electrons[atom], + BO_valences[i]) if abs(charge) > 0: a.SetNumRadicalElectrons(abs(int(charge))) @@ -457,7 +456,7 @@ def get_bonds(UA, AC): bonds = [] for k, i in enumerate(UA): - for j in UA[k + 1 :]: + for j in UA[k + 1:]: if AC[i, j] == 1: bonds.append(tuple(sorted([i, j]))) @@ -510,7 +509,9 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True): for i, (atomicNum, valence) in enumerate(zip(atoms, AC_valence)): # valence can't be smaller than number of neighbourgs - possible_valence = [x for x in atomic_valence[atomicNum] if x >= valence] + possible_valence = [ + x for x in atomic_valence[atomicNum] if x >= valence + ] if not possible_valence: print_sys( "Valence of atom", @@ -553,7 +554,12 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True): UA_pairs_list = get_UA_pairs(UA, AC, use_graph=use_graph) for UA_pairs in UA_pairs_list: - BO = get_BO(AC, UA, DU_from_AC, valences, UA_pairs, use_graph=use_graph) + BO = get_BO(AC, + UA, + DU_from_AC, + valences, + UA_pairs, + use_graph=use_graph) status = BO_is_OK( BO, AC, @@ -577,17 +583,19 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True): if status: return BO, atomic_valence_electrons - elif ( - BO.sum() >= best_BO.sum() - and valences_not_too_large(BO, valences) - and charge_OK - ): + elif (BO.sum() >= best_BO.sum() and + valences_not_too_large(BO, valences) and charge_OK): best_BO = BO.copy() return best_BO, atomic_valence_electrons -def AC2mol(mol, AC, atoms, charge, allow_charged_fragments=True, use_graph=True): +def AC2mol(mol, + AC, + atoms, + charge, + allow_charged_fragments=True, + use_graph=True): """ """ # convert AC matrix to bond order (BO) matrix @@ -614,9 +622,8 @@ def AC2mol(mol, AC, atoms, charge, allow_charged_fragments=True, use_graph=True) # return [] # BO2mol returns an arbitrary resonance form. Let's make the rest - mols = rdchem.ResonanceMolSupplier( - mol, Chem.UNCONSTRAINED_CATIONS, Chem.UNCONSTRAINED_ANIONS - ) + mols = rdchem.ResonanceMolSupplier(mol, Chem.UNCONSTRAINED_CATIONS, + Chem.UNCONSTRAINED_ANIONS) mols = [mol for mol in mols] return mols, BO @@ -754,15 +761,14 @@ def xyz2AC_huckel(atomicNumList, xyz, charge): mol_huckel = Chem.Mol(mol) mol_huckel.GetAtomWithIdx(0).SetFormalCharge( - charge - ) # mol charge arbitrarily added to 1st atom + charge) # mol charge arbitrarily added to 1st atom passed, result = rdEHTTools.RunMol(mol_huckel) opop = result.GetReducedOverlapPopulationMatrix() tri = np.zeros((num_atoms, num_atoms)) - tri[ - np.tril(np.ones((num_atoms, num_atoms), dtype=bool)) - ] = opop # lower triangular to square matrix + tri[np.tril(np.ones( + (num_atoms, num_atoms), + dtype=bool))] = opop # lower triangular to square matrix for i in range(num_atoms): for j in range(i + 1, num_atoms): pair_pop = abs(tri[j, i]) diff --git a/tdc/chem_utils/featurize/molconvert.py b/tdc/chem_utils/featurize/molconvert.py index 4741d5d4..6a217f58 100644 --- a/tdc/chem_utils/featurize/molconvert.py +++ b/tdc/chem_utils/featurize/molconvert.py @@ -5,7 +5,6 @@ import numpy as np from typing import List - try: from rdkit import Chem, DataStructs from rdkit.Chem import AllChem @@ -15,8 +14,8 @@ from rdkit.Chem.Fingerprints import FingerprintMols from rdkit.Chem import MACCSkeys except: - raise ImportError("Please install rdkit by 'conda install -c conda-forge rdkit'! ") - + raise ImportError( + "Please install rdkit by 'conda install -c conda-forge rdkit'! ") from ...utils import print_sys from ..oracle.oracle import ( @@ -52,15 +51,14 @@ def smiles2morgan(s, radius=2, nBits=1024): try: s = canonicalize(s) mol = Chem.MolFromSmiles(s) - features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits) + features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, + radius, + nBits=nBits) features = np.zeros((1,)) DataStructs.ConvertToNumpyArray(features_vec, features) except: - print_sys( - "rdkit not found this smiles for morgan: " - + s - + " convert to all 0 features" - ) + print_sys("rdkit not found this smiles for morgan: " + s + + " convert to all 0 features") features = np.zeros((nBits,)) return features @@ -89,9 +87,8 @@ def smiles2rdkit2d(s): NaNs = np.isnan(features) features[NaNs] = 0 except: - print_sys( - "descriptastorus not found this smiles: " + s + " convert to all 0 features" - ) + print_sys("descriptastorus not found this smiles: " + s + + " convert to all 0 features") features = np.zeros((200,)) return np.array(features) @@ -115,7 +112,8 @@ def smiles2daylight(s): features = np.zeros((NumFinger,)) features[np.array(temp)] = 1 except: - print_sys("rdkit not found this smiles: " + s + " convert to all 0 features") + print_sys("rdkit not found this smiles: " + s + + " convert to all 0 features") features = np.zeros((2048,)) return np.array(features) @@ -210,7 +208,6 @@ def smiles2ECFP6(smiles): class MoleculeFingerprint: - """ Example: MolFP = MoleculeFingerprint(fp = 'ECFP6') @@ -239,10 +236,9 @@ def __init__(self, fp="ECFP4"): try: assert fp in fp2func except: - raise Exception( - "The fingerprint you specify are not supported. \ + raise Exception("The fingerprint you specify are not supported. \ It can only among 'ECFP2', 'ECFP4', 'ECFP6', 'MACCS', 'Daylight', 'RDKit2D', 'Morgan', 'PubChem'" - ) + ) self.fp = fp self.func = fp2func[fp] @@ -388,12 +384,11 @@ def onek_encoding_unk(x, allowable_set): def get_atom_features(atom): return torch.Tensor( - onek_encoding_unk(atom.GetSymbol(), ELEM_LIST) - + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5]) - + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0]) - + onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3]) - + [atom.GetIsAromatic()] - ) + onek_encoding_unk(atom.GetSymbol(), ELEM_LIST) + + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5]) + + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0]) + + onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3]) + + [atom.GetIsAromatic()]) def smiles2PyG(smiles): @@ -413,8 +408,9 @@ def smiles2PyG(smiles): atom_features = torch.stack(atom_features) y = [atom.GetSymbol() for atom in mol.GetAtoms()] y = list( - map(lambda x: ELEM_LIST.index(x) if x in ELEM_LIST else len(ELEM_LIST) - 1, y) - ) + map( + lambda x: ELEM_LIST.index(x) + if x in ELEM_LIST else len(ELEM_LIST) - 1, y)) y = torch.LongTensor(y) bond_features = [] for bond in mol.GetBonds(): @@ -438,6 +434,7 @@ def molfile2PyG(molfile): ############### PyG end ############### + ############### DGL begin ############### def smiles2DGL(smiles): """convert SMILES string into dgl.DGLGraph @@ -468,7 +465,6 @@ def smiles2DGL(smiles): ############### DGL end ############### - from ._xyz2mol import xyzfile2mol @@ -511,7 +507,8 @@ def xyzfile2selfies(xyzfile): def distance3d(coordinate_1, coordinate_2): - return np.sqrt(sum([(c1 - c2) ** 2 for c1, c2 in zip(coordinate_1, coordinate_2)])) + return np.sqrt( + sum([(c1 - c2)**2 for c1, c2 in zip(coordinate_1, coordinate_2)])) def upper_atom(atomsymbol): @@ -526,7 +523,9 @@ def xyzfile2graph3d(xyzfile): for j in range(i + 1, num_atoms): distance = distance3d(xyz_coordinates[i], xyz_coordinates[j]) distance_adj_matrix[i, j] = distance_adj_matrix[j, i] = distance - idx2atom = {idx: upper_atom(str_atom(atom)) for idx, atom in enumerate(atoms)} + idx2atom = { + idx: upper_atom(str_atom(atom)) for idx, atom in enumerate(atoms) + } mol, BO = xyzfile2mol(xyzfile) return idx2atom, distance_adj_matrix, BO @@ -599,9 +598,9 @@ def mol_conformer2graph3d(mol_conformer_lst): positions = np.concatenate(positions, 0) for i in range(atom_num): for j in range(i + 1, atom_num): - distance_adj_matrix[i, j] = distance_adj_matrix[j, i] = distance3d( - positions[i], positions[j] - ) + distance_adj_matrix[i, + j] = distance_adj_matrix[j, i] = distance3d( + positions[i], positions[j]) for bond in mol.GetBonds(): a1 = bond.GetBeginAtom().GetIdx() a2 = bond.GetEndAtom().GetIdx() @@ -687,6 +686,7 @@ def xyzfile2coulomb(xyzfile): # 2D_format = ['SMILES', 'SELFIES', 'Graph2D', 'PyG', 'DGL', 'ECFP2', 'ECFP4', 'ECFP6', 'MACCS', 'Daylight', 'RDKit2D', 'Morgan', 'PubChem'] # 3D_format = ['Graph3D', 'Coulumb'] + ## XXX2smiles def molfile2smiles(molfile): """convert molfile into SMILES string @@ -722,7 +722,6 @@ def mol2file2smiles(molfile): ## smiles2xxx - atom_types = ["C", "N", "O", "H", "F", "unknown"] ### Cl, S? @@ -868,7 +867,6 @@ def raw3D2pyg(raw3d_feature): class MolConvert: - """MolConvert: convert the molecule from src formet to dst format. @@ -902,7 +900,8 @@ def __init__(self, src="SMILES", dst="Graph2D", radius=2, nBits=1024): global sf except: - raise Exception("Please install selfies via 'pip install selfies'") + raise Exception( + "Please install selfies via 'pip install selfies'") if "Coulumb" == dst: try: @@ -1023,7 +1022,8 @@ def __call__(self, x): else: lst = [] for x0 in x: - lst.append(self.func(x0, radius=self._radius, nBits=self._nbits)) + lst.append( + self.func(x0, radius=self._radius, nBits=self._nbits)) out = lst if self._dst in fingerprints_list: out = np.array(out) diff --git a/tdc/chem_utils/oracle/docking.py b/tdc/chem_utils/oracle/docking.py index a92314a5..9a920c78 100644 --- a/tdc/chem_utils/oracle/docking.py +++ b/tdc/chem_utils/oracle/docking.py @@ -9,8 +9,11 @@ center = [float(i) for i in center] box_size = [sys.argv[7], sys.argv[8], sys.argv[9]] box_size = [float(i) for i in box_size] + + # print(ligand_pdbqt_file, receptor_pdbqt_file, output_file, center, box_size) -def docking(ligand_pdbqt_file, receptor_pdbqt_file, output_file, center, box_size): +def docking(ligand_pdbqt_file, receptor_pdbqt_file, output_file, center, + box_size): t1 = time() v = Vina(sf_name="vina") v.set_receptor(rigid_pdbqt_filename=receptor_pdbqt_file) @@ -25,7 +28,6 @@ def docking(ligand_pdbqt_file, receptor_pdbqt_file, output_file, center, box_siz docking(ligand_pdbqt_file, receptor_pdbqt_file, output_file, center, box_size) - """ Example: python XXXX.py data/1iep_ligand.pdbqt ./data/1iep_receptor.pdbqt ./data/out 15.190 53.903 16.917 20 20 20 diff --git a/tdc/chem_utils/oracle/filter.py b/tdc/chem_utils/oracle/filter.py index 06b4e0c3..47e6bde0 100644 --- a/tdc/chem_utils/oracle/filter.py +++ b/tdc/chem_utils/oracle/filter.py @@ -9,7 +9,8 @@ rdBase.DisableLog("rdApp.error") except: - raise ImportError("Please install rdkit by 'conda install -c conda-forge rdkit'! ") + raise ImportError( + "Please install rdkit by 'conda install -c conda-forge rdkit'! ") from ...utils import print_sys, install @@ -73,16 +74,14 @@ def __init__( for i in filters: if i not in all_filters: raise ValueError( - i - + " not found; Please choose from a list of available filters from 'BMS', 'Dundee', 'Glaxo', 'Inpharmatica', 'LINT', 'MLSMR', 'PAINS', 'SureChEMBL'" + i + + " not found; Please choose from a list of available filters from 'BMS', 'Dundee', 'Glaxo', 'Inpharmatica', 'LINT', 'MLSMR', 'PAINS', 'SureChEMBL'" ) alert_file_name = pkg_resources.resource_filename( - "rd_filters", "data/alert_collection.csv" - ) + "rd_filters", "data/alert_collection.csv") rules_file_path = pkg_resources.resource_filename( - "rd_filters", "data/rules.json" - ) + "rd_filters", "data/rules.json") self.rf = RDFilters(alert_file_name) self.rule_dict = read_rules(rules_file_path) self.rule_dict["Rule_Inpharmatica"] = False @@ -163,15 +162,13 @@ def __call__(self, input_data): "Rot", ], ) - df_ok = df[ - (df.FILTER == "OK") - & df.MW.between(*self.rule_dict["MW"]) - & df.LogP.between(*self.rule_dict["LogP"]) - & df.HBD.between(*self.rule_dict["HBD"]) - & df.HBA.between(*self.rule_dict["HBA"]) - & df.TPSA.between(*self.rule_dict["TPSA"]) - & df.Rot.between(*self.rule_dict["Rot"]) - ] + df_ok = df[(df.FILTER == "OK") & + df.MW.between(*self.rule_dict["MW"]) & + df.LogP.between(*self.rule_dict["LogP"]) & + df.HBD.between(*self.rule_dict["HBD"]) & + df.HBA.between(*self.rule_dict["HBA"]) & + df.TPSA.between(*self.rule_dict["TPSA"]) & + df.Rot.between(*self.rule_dict["Rot"])] else: df = pd.DataFrame( diff --git a/tdc/chem_utils/oracle/oracle.py b/tdc/chem_utils/oracle/oracle.py index 64d93092..f29f1848 100644 --- a/tdc/chem_utils/oracle/oracle.py +++ b/tdc/chem_utils/oracle/oracle.py @@ -12,7 +12,6 @@ from packaging import version import pkg_resources - try: import rdkit from rdkit import Chem, DataStructs @@ -24,7 +23,8 @@ from rdkit.Chem import rdMolDescriptors from rdkit.six import iteritems except: - raise ImportError("Please install rdkit by 'conda install -c conda-forge rdkit'! ") + raise ImportError( + "Please install rdkit by 'conda install -c conda-forge rdkit'! ") try: from scipy.stats.mstats import gmean @@ -43,7 +43,8 @@ "geometric": gmean, "arithmetic": np.mean, } -SKLEARN_VERSION = version.parse(pkg_resources.get_distribution("scikit-learn").version) +SKLEARN_VERSION = version.parse( + pkg_resources.get_distribution("scikit-learn").version) def smiles_to_rdkit_mol(smiles): @@ -259,9 +260,11 @@ class ClippedScoreModifier(ScoreModifier): Then the generated values are clipped between low and high scores. """ - def __init__( - self, upper_x: float, lower_x=0.0, high_score=1.0, low_score=0.0 - ) -> None: + def __init__(self, + upper_x: float, + lower_x=0.0, + high_score=1.0, + low_score=0.0) -> None: """ Args: upper_x: x-value from which (or until which if smaller than lower_x) the score is maximal @@ -292,9 +295,11 @@ class SmoothClippedScoreModifier(ScoreModifier): center of the logistic function. """ - def __init__( - self, upper_x: float, lower_x=0.0, high_score=1.0, low_score=0.0 - ) -> None: + def __init__(self, + upper_x: float, + lower_x=0.0, + high_score=1.0, + low_score=0.0) -> None: """ Args: upper_x: x-value from which (or until which if smaller than lower_x) the score approaches high_score @@ -315,7 +320,8 @@ def __init__( self.L = high_score - low_score def __call__(self, x): - return self.low_score + self.L / (1 + np.exp(-self.k * (x - self.middle_x))) + return self.low_score + self.L / (1 + np.exp(-self.k * + (x - self.middle_x))) class ThresholdedLinearModifier(ScoreModifier): @@ -371,8 +377,7 @@ def calculateScore(m): # fragment score fp = rdMolDescriptors.GetMorganFingerprint( - m, 2 - ) # <- 2 is the *radius* of the circular fingerprint + m, 2) # <- 2 is the *radius* of the circular fingerprint fps = fp.GetNonzeroElements() score1 = 0.0 nf = 0 @@ -404,14 +409,8 @@ def calculateScore(m): if nMacrocycles > 0: macrocyclePenalty = math.log10(2) - score2 = ( - 0.0 - - sizePenalty - - stereoPenalty - - spiroPenalty - - bridgePenalty - - macrocyclePenalty - ) + score2 = (0.0 - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - + macrocyclePenalty) # correction for the fingerprint density # not in the original publication, added in version 1.1 @@ -519,7 +518,8 @@ def cyp3a4_veith(smiles): try: from DeepPurpose import utils except: - raise ImportError("Please install DeepPurpose by 'pip install DeepPurpose'") + raise ImportError( + "Please install DeepPurpose by 'pip install DeepPurpose'") import os @@ -535,9 +535,10 @@ def cyp3a4_veith(smiles): X_drug = [smiles] drug_encoding = "CNN" y = [1] - X_pred = utils.data_process( - X_drug=X_drug, y=y, drug_encoding=drug_encoding, split_method="no_split" - ) + X_pred = utils.data_process(X_drug=X_drug, + y=y, + drug_encoding=drug_encoding, + split_method="no_split") # cyp3a4_veith_model = cyp3a4_veith_model.to("cuda:0") y_pred = cyp3a4_veith_model.predict(X_pred) return y_pred[0] @@ -564,8 +565,14 @@ def similarity(smiles_a, smiles_b): bmol = Chem.MolFromSmiles(smiles_b) if amol is None or bmol is None: return 0.0 - fp1 = AllChem.GetMorganFingerprintAsBitVect(amol, 2, nBits=2048, useChirality=False) - fp2 = AllChem.GetMorganFingerprintAsBitVect(bmol, 2, nBits=2048, useChirality=False) + fp1 = AllChem.GetMorganFingerprintAsBitVect(amol, + 2, + nBits=2048, + useChirality=False) + fp2 = AllChem.GetMorganFingerprintAsBitVect(bmol, + 2, + nBits=2048, + useChirality=False) return DataStructs.TanimotoSimilarity(fp1, fp2) @@ -709,6 +716,7 @@ def __call__(self, smiles): class AtomCounter: + def __init__(self, element): """ Args: @@ -789,6 +797,7 @@ def canonicalize(smiles: str, include_stereocenters=True): class Isomer_scoring_prev: + def __init__(self, target_smiles, means="geometric"): assert means in ["geometric", "arithmetic"] if means == "geometric": @@ -797,11 +806,11 @@ def __init__(self, target_smiles, means="geometric"): self.mean_func = np.mean atom2cnt_lst = parse_molecular_formula(target_smiles) total_atom_num = sum([cnt for atom, cnt in atom2cnt_lst]) - self.total_atom_modifier = GaussianModifier(mu=total_atom_num, sigma=2.0) - self.AtomCounter_Modifier_lst = [ - ((AtomCounter(atom)), GaussianModifier(mu=cnt, sigma=1.0)) - for atom, cnt in atom2cnt_lst - ] + self.total_atom_modifier = GaussianModifier(mu=total_atom_num, + sigma=2.0) + self.AtomCounter_Modifier_lst = [((AtomCounter(atom)), + GaussianModifier(mu=cnt, sigma=1.0)) + for atom, cnt in atom2cnt_lst] def __call__(self, test_smiles): molecule = smiles_to_rdkit_mol(test_smiles) @@ -818,6 +827,7 @@ def __call__(self, test_smiles): class Isomer_scoring: + def __init__(self, target_smiles, means="geometric"): assert means in ["geometric", "arithmetic"] if means == "geometric": @@ -826,11 +836,11 @@ def __init__(self, target_smiles, means="geometric"): self.mean_func = np.mean atom2cnt_lst = parse_molecular_formula(target_smiles) total_atom_num = sum([cnt for atom, cnt in atom2cnt_lst]) - self.total_atom_modifier = GaussianModifier(mu=total_atom_num, sigma=2.0) - self.AtomCounter_Modifier_lst = [ - ((AtomCounter(atom)), GaussianModifier(mu=cnt, sigma=1.0)) - for atom, cnt in atom2cnt_lst - ] + self.total_atom_modifier = GaussianModifier(mu=total_atom_num, + sigma=2.0) + self.AtomCounter_Modifier_lst = [((AtomCounter(atom)), + GaussianModifier(mu=cnt, sigma=1.0)) + for atom, cnt in atom2cnt_lst] def __call__(self, test_smiles): #### difference 1 @@ -864,29 +874,34 @@ def isomer_meta(target_smiles, means="geometric"): return Isomer_scoring(target_smiles, means=means) -isomers_c7h8n2o2_prev = isomer_meta_prev(target_smiles="C7H8N2O2", means="geometric") -isomers_c9h10n2o2pf2cl_prev = isomer_meta_prev( - target_smiles="C9H10N2O2PF2Cl", means="geometric" -) -isomers_c11h24_prev = isomer_meta_prev(target_smiles="C11H24", means="geometric") +isomers_c7h8n2o2_prev = isomer_meta_prev(target_smiles="C7H8N2O2", + means="geometric") +isomers_c9h10n2o2pf2cl_prev = isomer_meta_prev(target_smiles="C9H10N2O2PF2Cl", + means="geometric") +isomers_c11h24_prev = isomer_meta_prev(target_smiles="C11H24", + means="geometric") isomers_c7h8n2o2 = isomer_meta(target_smiles="C7H8N2O2", means="geometric") -isomers_c9h10n2o2pf2cl = isomer_meta(target_smiles="C9H10N2O2PF2Cl", means="geometric") +isomers_c9h10n2o2pf2cl = isomer_meta(target_smiles="C9H10N2O2PF2Cl", + means="geometric") isomers_c11h24 = isomer_meta(target_smiles="C11H24", means="geometric") class rediscovery_meta: + def __init__(self, target_smiles, fp="ECFP4"): self.similarity_func = fp2fpfunc[fp] self.target_fp = self.similarity_func(target_smiles) def __call__(self, test_smiles): test_fp = self.similarity_func(test_smiles) - similarity_value = DataStructs.TanimotoSimilarity(self.target_fp, test_fp) + similarity_value = DataStructs.TanimotoSimilarity( + self.target_fp, test_fp) return similarity_value class similarity_meta: + def __init__(self, target_smiles, fp="FCFP4", modifier_func=None): self.similarity_func = fp2fpfunc[fp] self.target_fp = self.similarity_func(target_smiles) @@ -894,7 +909,8 @@ def __init__(self, target_smiles, fp="FCFP4", modifier_func=None): def __call__(self, test_smiles): test_fp = self.similarity_func(test_smiles) - similarity_value = DataStructs.TanimotoSimilarity(self.target_fp, test_fp) + similarity_value = DataStructs.TanimotoSimilarity( + self.target_fp, test_fp) if self.modifier_func is None: modifier_score = similarity_value else: @@ -903,14 +919,14 @@ def __call__(self, test_smiles): celecoxib_rediscovery = rediscovery_meta( - target_smiles="CC1=CC=C(C=C1)C1=CC(=NN1C1=CC=C(C=C1)S(N)(=O)=O)C(F)(F)F", fp="ECFP4" -) + target_smiles="CC1=CC=C(C=C1)C1=CC(=NN1C1=CC=C(C=C1)S(N)(=O)=O)C(F)(F)F", + fp="ECFP4") troglitazone_rediscovery = rediscovery_meta( - target_smiles="Cc1c(C)c2OC(C)(COc3ccc(CC4SC(=O)NC4=O)cc3)CCc2c(C)c1O", fp="ECFP4" -) + target_smiles="Cc1c(C)c2OC(C)(COc3ccc(CC4SC(=O)NC4=O)cc3)CCc2c(C)c1O", + fp="ECFP4") thiothixene_rediscovery = rediscovery_meta( - target_smiles="CN(C)S(=O)(=O)c1ccc2Sc3ccccc3C(=CCCN4CCN(C)CC4)c2c1", fp="ECFP4" -) + target_smiles="CN(C)S(=O)(=O)c1ccc2Sc3ccccc3C(=CCCN4CCN(C)CC4)c2c1", + fp="ECFP4") similarity_modifier = ClippedScoreModifier(upper_x=0.75) aripiprazole_similarity = similarity_meta( @@ -933,6 +949,7 @@ def __call__(self, test_smiles): class median_meta: + def __init__( self, target_smiles_1, @@ -954,13 +971,12 @@ def __init__( def __call__(self, test_smiles): test_fp1 = self.similarity_func1(test_smiles) - test_fp2 = ( - test_fp1 - if self.similarity_func2 == self.similarity_func1 - else self.similarity_func2(test_smiles) - ) - similarity_value1 = DataStructs.TanimotoSimilarity(self.target_fp1, test_fp1) - similarity_value2 = DataStructs.TanimotoSimilarity(self.target_fp2, test_fp2) + test_fp2 = (test_fp1 if self.similarity_func2 == self.similarity_func1 + else self.similarity_func2(test_smiles)) + similarity_value1 = DataStructs.TanimotoSimilarity( + self.target_fp1, test_fp1) + similarity_value2 = DataStructs.TanimotoSimilarity( + self.target_fp2, test_fp2) if self.modifier_func1 is None: modifier_score1 = similarity_value1 else: @@ -1000,6 +1016,7 @@ def __call__(self, test_smiles): class MPO_meta: + def __init__(self, means): """ target_smiles, fp in ['ECFP4', 'AP', ..., ] @@ -1023,8 +1040,7 @@ def osimertinib_mpo(test_smiles): if "osimertinib_fp_fcfc4" not in globals().keys(): global osimertinib_fp_fcfc4, osimertinib_fp_ecfc6 osimertinib_smiles = ( - "COc1cc(N(C)CCN(C)C)c(NC(=O)C=C)cc1Nc2nccc(n2)c3cn(C)c4ccccc34" - ) + "COc1cc(N(C)CCN(C)C)c(NC(=O)C=C)cc1Nc2nccc(n2)c3cn(C)c4ccccc34") osimertinib_fp_fcfc4 = smiles_2_fingerprint_FCFP4(osimertinib_smiles) osimertinib_fp_ecfc6 = smiles_2_fingerprint_ECFP6(osimertinib_smiles) @@ -1039,13 +1055,12 @@ def osimertinib_mpo(test_smiles): tpsa_score = tpsa_modifier(Descriptors.TPSA(molecule)) logp_score = logp_modifier(Descriptors.MolLogP(molecule)) similarity_v1 = sim_v1_modifier( - DataStructs.TanimotoSimilarity(osimertinib_fp_fcfc4, fp_fcfc4) - ) + DataStructs.TanimotoSimilarity(osimertinib_fp_fcfc4, fp_fcfc4)) similarity_v2 = sim_v2_modifier( - DataStructs.TanimotoSimilarity(osimertinib_fp_ecfc6, fp_ecfc6) - ) + DataStructs.TanimotoSimilarity(osimertinib_fp_ecfc6, fp_ecfc6)) - osimertinib_gmean = gmean([tpsa_score, logp_score, similarity_v1, similarity_v2]) + osimertinib_gmean = gmean( + [tpsa_score, logp_score, similarity_v1, similarity_v2]) return osimertinib_gmean @@ -1053,8 +1068,7 @@ def fexofenadine_mpo(test_smiles): if "fexofenadine_fp" not in globals().keys(): global fexofenadine_fp fexofenadine_smiles = ( - "CC(C)(C(=O)O)c1ccc(cc1)C(O)CCCN2CCC(CC2)C(O)(c3ccccc3)c4ccccc4" - ) + "CC(C)(C(=O)O)c1ccc(cc1)C(O)CCCN2CCC(CC2)C(O)(c3ccccc3)c4ccccc4") fexofenadine_fp = smiles_2_fingerprint_AP(fexofenadine_smiles) similar_modifier = ClippedScoreModifier(upper_x=0.8) @@ -1066,8 +1080,7 @@ def fexofenadine_mpo(test_smiles): tpsa_score = tpsa_modifier(Descriptors.TPSA(molecule)) logp_score = logp_modifier(Descriptors.MolLogP(molecule)) similarity_value = similar_modifier( - DataStructs.TanimotoSimilarity(fp_ap, fexofenadine_fp) - ) + DataStructs.TanimotoSimilarity(fp_ap, fexofenadine_fp)) fexofenadine_gmean = gmean([tpsa_score, logp_score, similarity_value]) return fexofenadine_gmean @@ -1089,11 +1102,11 @@ def ranolazine_mpo(test_smiles): tpsa_score = tpsa_modifier(Descriptors.TPSA(molecule)) logp_score = logp_modifier(Descriptors.MolLogP(molecule)) similarity_value = similar_modifier( - DataStructs.TanimotoSimilarity(fp_ap, ranolazine_fp) - ) + DataStructs.TanimotoSimilarity(fp_ap, ranolazine_fp)) fluorine_value = fluorine_modifier(fluorine_counter(molecule)) - ranolazine_gmean = gmean([tpsa_score, logp_score, similarity_value, fluorine_value]) + ranolazine_gmean = gmean( + [tpsa_score, logp_score, similarity_value, fluorine_value]) return ranolazine_gmean @@ -1147,7 +1160,8 @@ def zaleplon_mpo_prev(test_smiles): global zaleplon_fp, isomer_scoring_C19H17N3O2 zaleplon_smiles = "O=C(C)N(CC)C1=CC=CC(C2=CC=NC3=C(C=NN23)C#N)=C1" zaleplon_fp = smiles_2_fingerprint_ECFP4(zaleplon_smiles) - isomer_scoring_C19H17N3O2 = Isomer_scoring_prev(target_smiles="C19H17N3O2") + isomer_scoring_C19H17N3O2 = Isomer_scoring_prev( + target_smiles="C19H17N3O2") fp = smiles_2_fingerprint_ECFP4(test_smiles) similarity_value = DataStructs.TanimotoSimilarity(fp, zaleplon_fp) @@ -1176,8 +1190,10 @@ def sitagliptin_mpo_prev(test_smiles): sitagliptin_mol = Chem.MolFromSmiles(sitagliptin_smiles) sitagliptin_logp = Descriptors.MolLogP(sitagliptin_mol) sitagliptin_tpsa = Descriptors.TPSA(sitagliptin_mol) - sitagliptin_logp_modifier = GaussianModifier(mu=sitagliptin_logp, sigma=0.2) - sitagliptin_tpsa_modifier = GaussianModifier(mu=sitagliptin_tpsa, sigma=5) + sitagliptin_logp_modifier = GaussianModifier(mu=sitagliptin_logp, + sigma=0.2) + sitagliptin_tpsa_modifier = GaussianModifier(mu=sitagliptin_tpsa, + sigma=5) isomers_scoring_C16H15F6N5O = Isomer_scoring_prev("C16H15F6N5O") sitagliptin_similar_modifier = GaussianModifier(mu=0, sigma=0.1) @@ -1189,8 +1205,7 @@ def sitagliptin_mpo_prev(test_smiles): tpsa_score = sitagliptin_tpsa_modifier(tpsa_score) isomer_score = isomers_scoring_C16H15F6N5O(test_smiles) similarity_value = sitagliptin_similar_modifier( - DataStructs.TanimotoSimilarity(fp_ecfp4, sitagliptin_fp_ecfp4) - ) + DataStructs.TanimotoSimilarity(fp_ecfp4, sitagliptin_fp_ecfp4)) return gmean([similarity_value, logp_score, tpsa_score, isomer_score]) @@ -1202,8 +1217,10 @@ def sitagliptin_mpo(test_smiles): sitagliptin_mol = Chem.MolFromSmiles(sitagliptin_smiles) sitagliptin_logp = Descriptors.MolLogP(sitagliptin_mol) sitagliptin_tpsa = Descriptors.TPSA(sitagliptin_mol) - sitagliptin_logp_modifier = GaussianModifier(mu=sitagliptin_logp, sigma=0.2) - sitagliptin_tpsa_modifier = GaussianModifier(mu=sitagliptin_tpsa, sigma=5) + sitagliptin_logp_modifier = GaussianModifier(mu=sitagliptin_logp, + sigma=0.2) + sitagliptin_tpsa_modifier = GaussianModifier(mu=sitagliptin_tpsa, + sigma=5) isomers_scoring_C16H15F6N5O = Isomer_scoring("C16H15F6N5O") sitagliptin_similar_modifier = GaussianModifier(mu=0, sigma=0.1) @@ -1215,8 +1232,7 @@ def sitagliptin_mpo(test_smiles): tpsa_score = sitagliptin_tpsa_modifier(tpsa_score) isomer_score = isomers_scoring_C16H15F6N5O(test_smiles) similarity_value = sitagliptin_similar_modifier( - DataStructs.TanimotoSimilarity(fp_ecfp4, sitagliptin_fp_ecfp4) - ) + DataStructs.TanimotoSimilarity(fp_ecfp4, sitagliptin_fp_ecfp4)) return gmean([similarity_value, logp_score, tpsa_score, isomer_score]) @@ -1228,6 +1244,7 @@ def get_PHCO_fingerprint(mol): class SMARTS_scoring: + def __init__(self, target_smarts, inverse): self.target_mol = Chem.MolFromSmarts(target_smarts) self.inverse = inverse @@ -1253,12 +1270,10 @@ def deco_hop(test_smiles): pharmacophor_mol = smiles_to_rdkit_mol(pharmacophor_smiles) pharmacophor_fp = get_PHCO_fingerprint(pharmacophor_mol) - deco1_smarts_scoring = SMARTS_scoring( - target_smarts="CS([#6])(=O)=O", inverse=True - ) + deco1_smarts_scoring = SMARTS_scoring(target_smarts="CS([#6])(=O)=O", + inverse=True) deco2_smarts_scoring = SMARTS_scoring( - target_smarts="[#7]-c1ccc2ncsc2c1", inverse=True - ) + target_smarts="[#7]-c1ccc2ncsc2c1", inverse=True) scaffold_smarts_scoring = SMARTS_scoring( target_smarts="[#7]-c1n[c;h1]nc2[c;h1]c(-[#8])[c;h0][c;h1]c12", inverse=False, @@ -1269,43 +1284,41 @@ def deco_hop(test_smiles): similarity_modifier = ClippedScoreModifier(upper_x=0.85) similarity_value = similarity_modifier( - DataStructs.TanimotoSimilarity(fp, pharmacophor_fp) - ) + DataStructs.TanimotoSimilarity(fp, pharmacophor_fp)) deco1_score = deco1_smarts_scoring(molecule) deco2_score = deco2_smarts_scoring(molecule) scaffold_score = scaffold_smarts_scoring(molecule) - all_scores = np.mean([similarity_value, deco1_score, deco2_score, scaffold_score]) + all_scores = np.mean( + [similarity_value, deco1_score, deco2_score, scaffold_score]) return all_scores def scaffold_hop(test_smiles): - if ( - "pharmacophor_fp" not in globals().keys() - or "scaffold_smarts_scoring" not in globals().keys() - or "deco_smarts_scoring" not in globals().keys() - ): + if ("pharmacophor_fp" not in globals().keys() or + "scaffold_smarts_scoring" not in globals().keys() or + "deco_smarts_scoring" not in globals().keys()): global pharmacophor_fp, deco_smarts_scoring, scaffold_smarts_scoring pharmacophor_smiles = "CCCOc1cc2ncnc(Nc3ccc4ncsc4c3)c2cc1S(=O)(=O)C(C)(C)C" pharmacophor_mol = smiles_to_rdkit_mol(pharmacophor_smiles) pharmacophor_fp = get_PHCO_fingerprint(pharmacophor_mol) deco_smarts_scoring = SMARTS_scoring( - target_smarts="[#6]-[#6]-[#6]-[#8]-[#6]~[#6]~[#6]~[#6]~[#6]-[#7]-c1ccc2ncsc2c1", + target_smarts= + "[#6]-[#6]-[#6]-[#8]-[#6]~[#6]~[#6]~[#6]~[#6]-[#7]-c1ccc2ncsc2c1", inverse=False, ) scaffold_smarts_scoring = SMARTS_scoring( - target_smarts="[#7]-c1n[c;h1]nc2[c;h1]c(-[#8])[c;h0][c;h1]c12", inverse=True - ) + target_smarts="[#7]-c1n[c;h1]nc2[c;h1]c(-[#8])[c;h0][c;h1]c12", + inverse=True) molecule = smiles_to_rdkit_mol(test_smiles) fp = get_PHCO_fingerprint(molecule) similarity_modifier = ClippedScoreModifier(upper_x=0.75) similarity_value = similarity_modifier( - DataStructs.TanimotoSimilarity(fp, pharmacophor_fp) - ) + DataStructs.TanimotoSimilarity(fp, pharmacophor_fp)) deco_score = deco_smarts_scoring(molecule) scaffold_score = scaffold_smarts_scoring(molecule) @@ -1349,8 +1362,6 @@ def valsartan_smarts(test_smiles): ########################################################################### ### END of Guacamol ########################################################################### - - """ Synthesizability from a full retrosynthetic analysis Including: @@ -1502,9 +1513,9 @@ def askcos( # For each entry, repeat to test up to num_trials times if got error message for _ in range(num_trials): print("Trying to send the request, for the %i times now" % (_ + 1)) - resp = requests.get( - host_ip + "/api/treebuilder/", params=params, verify=False - ) + resp = requests.get(host_ip + "/api/treebuilder/", + params=params, + verify=False) if "error" not in resp.json().keys(): break @@ -1513,8 +1524,7 @@ def askcos( json.dump(resp.json(), f_data) num_path, status, depth, p_score, synthesizability, price = tree_analysis( - resp.json() - ) + resp.json()) if output == "plausibility": return p_score @@ -1539,13 +1549,13 @@ def ibm_rxn(smiles, api_key, output="confidence", sleep_time=30): rxn4chemistry_wrapper = RXN4ChemistryWrapper(api_key=api_key) response = rxn4chemistry_wrapper.create_project("test") time.sleep(sleep_time) - response = rxn4chemistry_wrapper.predict_automatic_retrosynthesis(product=smiles) + response = rxn4chemistry_wrapper.predict_automatic_retrosynthesis( + product=smiles) status = "" while status != "SUCCESS": time.sleep(sleep_time) results = rxn4chemistry_wrapper.get_predict_automatic_retrosynthesis_results( - response["prediction_id"] - ) + response["prediction_id"]) status = results["status"] if output == "confidence": @@ -1557,19 +1567,22 @@ def ibm_rxn(smiles, api_key, output="confidence", sleep_time=30): class molecule_one_retro: + def __init__(self, api_token): try: from m1wrapper import MoleculeOneWrapper except: try: - install("git+https://github.com/molecule-one/m1wrapper-python@v1") + install( + "git+https://github.com/molecule-one/m1wrapper-python@v1") from m1wrapper import MoleculeOneWrapper except: raise ImportError( "Install Molecule.One Wrapper via pip install git+https://github.com/molecule-one/m1wrapper-python@v1" ) - self.m1wrapper = MoleculeOneWrapper(api_token, "https://tdc.molecule.one") + self.m1wrapper = MoleculeOneWrapper(api_token, + "https://tdc.molecule.one") def __call__(self, smiles): if isinstance(smiles, str): @@ -1577,7 +1590,10 @@ def __call__(self, smiles): search = self.m1wrapper.run_batch_search( targets=smiles, - parameters={"exploratory_search": False, "detail_level": "score"}, + parameters={ + "exploratory_search": False, + "detail_level": "score" + }, ) status_cur = search.get_status() @@ -1594,7 +1610,8 @@ def __call__(self, smiles): if status_cur != status: print_sys(status) status_cur = status - result = search.get_results(precision=5, only=["targetSmiles", "result"]) + result = search.get_results(precision=5, + only=["targetSmiles", "result"]) return {i["targetSmiles"]: i["result"] for i in result} @@ -1666,7 +1683,11 @@ def __call__(self, test_smiles, error_value=None): class Score_3d: """Evaluate Vina score (force field) for a conformer binding to a receptor""" - def __init__(self, receptor_pdbqt_file, center, box_size, scorefunction="vina"): + def __init__(self, + receptor_pdbqt_file, + center, + box_size, + scorefunction="vina"): try: from vina import Vina except: @@ -1702,7 +1723,11 @@ def __call__(self, ligand_pdbqt_file, minimize=True): class Vina_3d: """Perform docking search from a conformer.""" - def __init__(self, receptor_pdbqt_file, center, box_size, scorefunction="vina"): + def __init__(self, + receptor_pdbqt_file, + center, + box_size, + scorefunction="vina"): try: from vina import Vina except: @@ -1722,9 +1747,11 @@ def __init__(self, receptor_pdbqt_file, center, box_size, scorefunction="vina"): "Cannot compute the affinity map, please check center and box_size" ) - def __call__( - self, ligand_pdbqt_file, output_file="out.pdbqt", exhaustiveness=8, n_poses=10 - ): + def __call__(self, + ligand_pdbqt_file, + output_file="out.pdbqt", + exhaustiveness=8, + n_poses=10): try: self.v.set_ligand_from_file(ligand_pdbqt_file) self.v.dock(exhaustiveness=exhaustiveness, n_poses=n_poses) @@ -1739,7 +1766,11 @@ def __call__( class Vina_smiles: """Perform docking search from a conformer.""" - def __init__(self, receptor_pdbqt_file, center, box_size, scorefunction="vina"): + def __init__(self, + receptor_pdbqt_file, + center, + box_size, + scorefunction="vina"): try: from vina import Vina except: @@ -1760,9 +1791,11 @@ def __init__(self, receptor_pdbqt_file, center, box_size, scorefunction="vina"): "Cannot compute the affinity map, please check center and box_size" ) - def __call__( - self, ligand_smiles, output_file="out.pdbqt", exhaustiveness=8, n_poses=10 - ): + def __call__(self, + ligand_smiles, + output_file="out.pdbqt", + exhaustiveness=8, + n_poses=10): try: m = Chem.MolFromSmiles(ligand_smiles) m = Chem.AddHs(m) @@ -1818,15 +1851,12 @@ def smina(ligand, protein, score_only=False, raw_input=False): f.write("%d\n\n" % n_atoms) for atom_i in range(n_atoms): atom = mol_atom[atom_i] - f.write( - "%s %.9f %.9f %.9f\n" - % ( - atom, - mol_coord[atom_i, 0], - mol_coord[atom_i, 1], - mol_coord[atom_i, 2], - ) - ) + f.write("%s %.9f %.9f %.9f\n" % ( + atom, + mol_coord[atom_i, 0], + mol_coord[atom_i, 1], + mol_coord[atom_i, 2], + )) f.close() # 2. convert to sdf file try: @@ -1838,8 +1868,8 @@ def smina(ligand, protein, score_only=False, raw_input=False): ligand = "temp_ligand.sdf" if score_only: msg = os.popen( - f"./{smina_model_path} -l {ligand} -r {protein} --score_only" - ).read() + f"./{smina_model_path} -l {ligand} -r {protein} --score_only").read( + ) return float(msg.split("\n")[-7].split(" ")[-2]) else: os.system(f"./{smina_model_path} -l {ligand} -r {protein} --score_only") diff --git a/tdc/generation/bi_generation_dataset.py b/tdc/generation/bi_generation_dataset.py index 229a9b20..6c95bce3 100644 --- a/tdc/generation/bi_generation_dataset.py +++ b/tdc/generation/bi_generation_dataset.py @@ -15,7 +15,6 @@ class DataLoader(base_dataset.DataLoader): - """A base dataset loader class. Attributes: @@ -34,7 +33,9 @@ def __init__( threshold=15, remove_Hs=True, keep_het=False, - allowed_atom_list=["C", "N", "O", "S", "H", "B", "Br", "Cl", "P", "I", "F"], + allowed_atom_list=[ + "C", "N", "O", "S", "H", "B", "Br", "Cl", "P", "I", "F" + ], ): """To create a base dataloader object that each generation task can inherit from. @@ -115,6 +116,7 @@ def get_split(self, method="random", seed=42, frac=[0.7, 0.1, 0.2]): protein, ligand = data["protein"], data["ligand"] if method == "random": - return create_combination_generation_split(protein, ligand, seed, frac) + return create_combination_generation_split(protein, ligand, seed, + frac) else: raise AttributeError("Please use the correct split method") diff --git a/tdc/generation/generation_dataset.py b/tdc/generation/generation_dataset.py index cf008761..704eef19 100644 --- a/tdc/generation/generation_dataset.py +++ b/tdc/generation/generation_dataset.py @@ -20,7 +20,6 @@ class DataLoader(base_dataset.DataLoader): - """A base dataset loader class. Attributes: @@ -42,8 +41,7 @@ def __init__(self, name, path, print_stats, column_name): from ..metadata import single_molecule_dataset_names self.smiles_lst = distribution_dataset_load( - name, path, single_molecule_dataset_names, column_name=column_name - ) + name, path, single_molecule_dataset_names, column_name=column_name) ### including fuzzy-search self.name = name self.path = path @@ -102,7 +100,6 @@ def get_split(self, method="random", seed=42, frac=[0.7, 0.1, 0.2]): class PairedDataLoader(base_dataset.DataLoader): - """A basic class for generation of biomedical entities conditioned on other entities, such as reaction prediction. Attributes: @@ -125,8 +122,8 @@ def __init__(self, name, path, print_stats, input_name, output_name): from ..metadata import paired_dataset_names self.input_smiles_lst, self.output_smiles_lst = generation_paired_dataset_load( - name, path, paired_dataset_names, input_name, output_name - ) ### including fuzzy-search + name, path, paired_dataset_names, input_name, + output_name) ### including fuzzy-search self.name = name self.path = path self.dataset_names = paired_dataset_names @@ -155,11 +152,15 @@ def get_data(self, format="df"): AttributeError: Use the correct format as input (df, dict) """ if format == "df": - return pd.DataFrame( - {"input": self.input_smiles_lst, "output": self.output_smiles_lst} - ) + return pd.DataFrame({ + "input": self.input_smiles_lst, + "output": self.output_smiles_lst + }) elif format == "dict": - return {"input": self.input_smiles_lst, "output": self.output_smiles_lst} + return { + "input": self.input_smiles_lst, + "output": self.output_smiles_lst + } else: raise AttributeError("Please use the correct format input") @@ -187,7 +188,6 @@ def get_split(self, method="random", seed=42, frac=[0.7, 0.1, 0.2]): class DataLoader3D(base_dataset.DataLoader): - """A basic class for generation of 3D biomedical entities. (under construction) Attributes: @@ -209,8 +209,7 @@ def __init__(self, name, path, print_stats, dataset_names, column_name): column_name (str): The name of the column containing smiles strings. """ self.df, self.path, self.name = three_dim_dataset_load( - name, path, dataset_names - ) + name, path, dataset_names) if print_stats: self.print_stats() print_sys("Done!") @@ -240,7 +239,8 @@ def get_data(self, format="df", more_features="None"): """ if more_features in ["None", "SMILES"]: pass - elif more_features in ["Graph3D", "Coulumb", "SELFIES"]: # why SELFIES here? + elif more_features in ["Graph3D", "Coulumb", + "SELFIES"]: # why SELFIES here? try: from rdkit.Chem.PandasTools import LoadSDF from rdkit import rdBase @@ -256,7 +256,8 @@ def get_data(self, format="df", more_features="None"): convert = MolConvert(src="SDF", dst=more_features) for i in sdf_file_names[self.name]: - self.df[i + "_" + more_features] = convert(self.path + i + ".sdf") + self.df[i + "_" + more_features] = convert(self.path + i + + ".sdf") if format == "df": return self.df diff --git a/tdc/generation/ligandmolgen.py b/tdc/generation/ligandmolgen.py index aeecd538..18ec63af 100644 --- a/tdc/generation/ligandmolgen.py +++ b/tdc/generation/ligandmolgen.py @@ -11,7 +11,6 @@ class LigandMolGen(bi_generation_dataset.DataLoader): - """Data loader class accessing to pocket-based ligand generation task.""" def __init__(self, name, path="./data", print_stats=False): diff --git a/tdc/generation/molgen.py b/tdc/generation/molgen.py index 0c75c486..8e76cb5c 100644 --- a/tdc/generation/molgen.py +++ b/tdc/generation/molgen.py @@ -11,10 +11,13 @@ class MolGen(generation_dataset.DataLoader): - """Data loader class accessing to molecular generation task (distribution learning)""" - def __init__(self, name, path="./data", print_stats=False, column_name="smiles"): + def __init__(self, + name, + path="./data", + print_stats=False, + column_name="smiles"): """To create an data loader object for molecular generation task. The goal is to generate diverse, novel molecules that has desirable chemical properties. One can combined with oracle functions. diff --git a/tdc/generation/reaction.py b/tdc/generation/reaction.py index 79f11080..4a3d83ca 100644 --- a/tdc/generation/reaction.py +++ b/tdc/generation/reaction.py @@ -11,7 +11,6 @@ class Reaction(generation_dataset.PairedDataLoader): - """Data loader class accessing to forward reaction prediction task.""" def __init__( diff --git a/tdc/generation/retrosyn.py b/tdc/generation/retrosyn.py index df1232c8..2420ce26 100644 --- a/tdc/generation/retrosyn.py +++ b/tdc/generation/retrosyn.py @@ -12,7 +12,6 @@ class RetroSyn(generation_dataset.PairedDataLoader): - """Data loader class accessing to retro-synthetic prediction task.""" def __init__( @@ -66,10 +65,8 @@ def get_split( df["reaction_type"] = rt except: raise ValueError( - "Reaction Type Unavailable for " - + str(self.name) - + "! Please turn include_reaction_type to be false!" - ) + "Reaction Type Unavailable for " + str(self.name) + + "! Please turn include_reaction_type to be false!") if method == "random": return create_fold(df, seed, frac) diff --git a/tdc/generation/sbdd.py b/tdc/generation/sbdd.py index 80368335..a697049d 100644 --- a/tdc/generation/sbdd.py +++ b/tdc/generation/sbdd.py @@ -16,7 +16,6 @@ class SBDD(base_dataset.DataLoader): - """Data loader class accessing to structure-based drug design task.""" def __init__( @@ -51,7 +50,8 @@ def __init__( try: import biopandas except: - raise ImportError("Please install biopandas by 'pip install biopandas'! ") + raise ImportError( + "Please install biopandas by 'pip install biopandas'! ") protein, ligand = bi_distribution_dataset_load( name, path, @@ -126,7 +126,8 @@ def get_split(self, method="random", seed=42, frac=[0.7, 0.1, 0.2]): data = self.get_data(format="dict") protein, ligand = data["protein"], data["ligand"] - splitted_data = create_combination_generation_split(protein, ligand, seed, frac) + splitted_data = create_combination_generation_split( + protein, ligand, seed, frac) if self.save: np.savez( diff --git a/tdc/multi_pred/antibodyaff.py b/tdc/multi_pred/antibodyaff.py index 85e68774..cfca5143 100644 --- a/tdc/multi_pred/antibodyaff.py +++ b/tdc/multi_pred/antibodyaff.py @@ -13,7 +13,6 @@ class AntibodyAff(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Antibody-antigen Affinity Prediction task. More info: https://tdcommons.ai/multi_pred_tasks/antibodyaff/ diff --git a/tdc/multi_pred/bi_pred_dataset.py b/tdc/multi_pred/bi_pred_dataset.py index 8c00be44..5fc788cc 100644 --- a/tdc/multi_pred/bi_pred_dataset.py +++ b/tdc/multi_pred/bi_pred_dataset.py @@ -26,7 +26,6 @@ class DataLoader(base_dataset.DataLoader): - """A base data loader class that each bi-instance prediction task dataloader class can inherit from. Attributes: TODO @@ -52,10 +51,8 @@ def __init__(self, name, path, label_name, print_stats, dataset_names): if label_name is None: raise ValueError( "Please select a label name. " - "You can use tdc.utils.retrieve_label_name_list('" - + name.lower() - + "') to retrieve all available label names." - ) + "You can use tdc.utils.retrieve_label_name_list('" + + name.lower() + "') to retrieve all available label names.") name = fuzzy_search(name, dataset_names) if name == "bindingdb_patent": @@ -70,9 +67,11 @@ def __init__(self, name, path, label_name, print_stats, dataset_names): entity1_idx, entity2_idx, aux_column_val, - ) = interaction_dataset_load( - name, path, label_name, dataset_names, aux_column=aux_column - ) + ) = interaction_dataset_load(name, + path, + label_name, + dataset_names, + aux_column=aux_column) self.name = name self.entity1 = entity1 @@ -107,26 +106,22 @@ def get_data(self, format="df"): """ if format == "df": if self.aux_column is None: - return pd.DataFrame( - { - self.entity1_name + "_ID": self.entity1_idx, - self.entity1_name: self.entity1, - self.entity2_name + "_ID": self.entity2_idx, - self.entity2_name: self.entity2, - "Y": self.y, - } - ) + return pd.DataFrame({ + self.entity1_name + "_ID": self.entity1_idx, + self.entity1_name: self.entity1, + self.entity2_name + "_ID": self.entity2_idx, + self.entity2_name: self.entity2, + "Y": self.y, + }) else: - return pd.DataFrame( - { - self.entity1_name + "_ID": self.entity1_idx, - self.entity1_name: self.entity1, - self.entity2_name + "_ID": self.entity2_idx, - self.entity2_name: self.entity2, - "Y": self.y, - self.aux_column: self.aux_column_val, - } - ) + return pd.DataFrame({ + self.entity1_name + "_ID": self.entity1_idx, + self.entity1_name: self.entity1, + self.entity2_name + "_ID": self.entity2_idx, + self.entity2_name: self.entity2, + "Y": self.y, + self.aux_column: self.aux_column_val, + }) elif format == "DeepPurpose": return self.entity1.values, self.entity2.values, self.y.values @@ -165,12 +160,8 @@ def print_stats(self): file=sys.stderr, ) print( - str(len(self.y)) - + " " - + self.entity1_name.lower() - + "-" - + self.entity2_name.lower() - + " pairs.", + str(len(self.y)) + " " + self.entity1_name.lower() + "-" + + self.entity2_name.lower() + " pairs.", flush=True, file=sys.stderr, ) @@ -218,12 +209,10 @@ def get_split( return create_fold_setting_cold(df, seed, frac, self.entity2_name) elif method == "cold_split": if column_name is None or not all( - list(map(lambda x: x in df.columns.values, column_name)) - ): + list(map(lambda x: x in df.columns.values, column_name))): raise AttributeError( "For cold_split, please provide one or multiple column names " - "that are contained in the dataframe." - ) + "that are contained in the dataframe.") return create_fold_setting_cold(df, seed, frac, column_name) elif method == "combination": return create_combination_split(df, seed, frac) @@ -298,9 +287,8 @@ def to_graph( if len(np.unique(self.raw_y)) > 2: print( "The dataset label consists of affinity scores. " - "Binarization using threshold " - + str(threshold) - + " is conducted to construct the positive edges in the network. " + "Binarization using threshold " + str(threshold) + + " is conducted to construct the positive edges in the network. " "Adjust the threshold by to_graph(threshold = X)", flush=True, file=sys.stderr, @@ -308,29 +296,34 @@ def to_graph( if threshold is None: raise AttributeError( "Please specify the threshold to binarize the data by " - "'to_graph(threshold = N)'!" - ) - df["label_binary"] = label_transform( - self.raw_y, True, threshold, False, verbose=False, order=order - ) + "'to_graph(threshold = N)'!") + df["label_binary"] = label_transform(self.raw_y, + True, + threshold, + False, + verbose=False, + order=order) else: # already binary df["label_binary"] = df["Y"] - df[self.entity1_name + "_ID"] = df[self.entity1_name + "_ID"].astype(str) - df[self.entity2_name + "_ID"] = df[self.entity2_name + "_ID"].astype(str) + df[self.entity1_name + "_ID"] = df[self.entity1_name + + "_ID"].astype(str) + df[self.entity2_name + "_ID"] = df[self.entity2_name + + "_ID"].astype(str) df_pos = df[df.label_binary == 1] df_neg = df[df.label_binary == 0] return_dict = {} - pos_edges = df_pos[ - [self.entity1_name + "_ID", self.entity2_name + "_ID"] - ].values - neg_edges = df_neg[ - [self.entity1_name + "_ID", self.entity2_name + "_ID"] - ].values - edges = df[[self.entity1_name + "_ID", self.entity2_name + "_ID"]].values + pos_edges = df_pos[[ + self.entity1_name + "_ID", self.entity2_name + "_ID" + ]].values + neg_edges = df_neg[[ + self.entity1_name + "_ID", self.entity2_name + "_ID" + ]].values + edges = df[[self.entity1_name + "_ID", + self.entity2_name + "_ID"]].values if format == "edge_list": return_dict["edge_list"] = pos_edges @@ -364,7 +357,8 @@ def to_graph( edge_list1 = np.array([dict_[i] for i in pos_edges.T[0]]) edge_list2 = np.array([dict_[i] for i in pos_edges.T[1]]) - edge_index = torch.tensor([edge_list1, edge_list2], dtype=torch.long) + edge_index = torch.tensor([edge_list1, edge_list2], + dtype=torch.long) x = torch.tensor(np.array(index), dtype=torch.float) data = Data(x=x, edge_index=edge_index) return_dict["pyg_graph"] = data diff --git a/tdc/multi_pred/catalyst.py b/tdc/multi_pred/catalyst.py index 285a06f7..3c797a55 100644 --- a/tdc/multi_pred/catalyst.py +++ b/tdc/multi_pred/catalyst.py @@ -13,7 +13,6 @@ class Catalyst(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Catalyst Prediction task More info: https://tdcommons.ai/multi_pred_tasks/catalyst/ @@ -33,9 +32,11 @@ class Catalyst(bi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create Catalyst Prediction dataloader object""" - super().__init__( - name, path, label_name, print_stats, dataset_names=dataset_names["Catalyst"] - ) + super().__init__(name, + path, + label_name, + print_stats, + dataset_names=dataset_names["Catalyst"]) self.entity1_name = "Reactant" self.entity2_name = "Product" self.two_types = True diff --git a/tdc/multi_pred/ddi.py b/tdc/multi_pred/ddi.py index ce5b41d9..b56559c6 100644 --- a/tdc/multi_pred/ddi.py +++ b/tdc/multi_pred/ddi.py @@ -13,7 +13,6 @@ class DDI(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Drug-Drug Interaction Prediction task More info: https://tdcommons.ai/multi_pred_tasks/ddi/ @@ -32,9 +31,11 @@ class DDI(bi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create Drug-Drug Interaction (DDI) Prediction dataloader object""" - super().__init__( - name, path, label_name, print_stats, dataset_names=dataset_names["DDI"] - ) + super().__init__(name, + path, + label_name, + print_stats, + dataset_names=dataset_names["DDI"]) self.entity1_name = "Drug1" self.entity2_name = "Drug2" self.two_types = False @@ -50,9 +51,9 @@ def print_stats(self): print_sys("--- Dataset Statistics ---") print( - "There are " - + str(len(np.unique(self.entity1.tolist() + self.entity2.tolist()))) - + " unique drugs.", + "There are " + + str(len(np.unique(self.entity1.tolist() + self.entity2.tolist()))) + + " unique drugs.", flush=True, file=sys.stderr, ) diff --git a/tdc/multi_pred/drugres.py b/tdc/multi_pred/drugres.py index b526fec3..85214cf5 100644 --- a/tdc/multi_pred/drugres.py +++ b/tdc/multi_pred/drugres.py @@ -14,7 +14,6 @@ class DrugRes(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Drug Response Prediction Task. More info: https://tdcommons.ai/multi_pred_tasks/drugres/ @@ -33,9 +32,11 @@ class DrugRes(bi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create Drug Response Prediction dataloader object""" - super().__init__( - name, path, label_name, print_stats, dataset_names=dataset_names["DrugRes"] - ) + super().__init__(name, + path, + label_name, + print_stats, + dataset_names=dataset_names["DrugRes"]) self.entity1_name = "Drug" self.entity2_name = "Cell Line" self.two_types = True @@ -51,12 +52,11 @@ def get_gene_symbols(self): Retrieve the gene symbols for the cell line gene expression """ path = self.path - name = download_wrapper("gdsc_gene_symbols", path, ["gdsc_gene_symbols"]) + name = download_wrapper("gdsc_gene_symbols", path, + ["gdsc_gene_symbols"]) print_sys("Loading...") import pandas as pd import os df = pd.read_csv(os.path.join(path, name + ".tab"), sep="\t") - return df.values.reshape( - -1, - ) + return df.values.reshape(-1,) diff --git a/tdc/multi_pred/drugsyn.py b/tdc/multi_pred/drugsyn.py index 5c27673a..9a2ee2df 100644 --- a/tdc/multi_pred/drugsyn.py +++ b/tdc/multi_pred/drugsyn.py @@ -13,7 +13,6 @@ class DrugSyn(multi_pred_dataset.DataLoader): - """Data loader class to load datasets in Drug Synergy Prediction task. More info: https://tdcommons.ai/multi_pred_tasks/drugsyn/ @@ -32,9 +31,10 @@ class DrugSyn(multi_pred_dataset.DataLoader): def __init__(self, name, path="./data", print_stats=False): """Create Drug Synergy Prediction dataloader object""" - super().__init__( - name, path, print_stats, dataset_names=dataset_names["DrugSyn"] - ) + super().__init__(name, + path, + print_stats, + dataset_names=dataset_names["DrugSyn"]) if print_stats: self.print_stats() diff --git a/tdc/multi_pred/dti.py b/tdc/multi_pred/dti.py index 33655a04..aa12b736 100644 --- a/tdc/multi_pred/dti.py +++ b/tdc/multi_pred/dti.py @@ -13,7 +13,6 @@ class DTI(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Drug-Target Interaction Prediction task. More info: https://tdcommons.ai/multi_pred_tasks/dti/ @@ -34,9 +33,11 @@ class DTI(bi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create Drug-Target Interaction Prediction dataloader object""" - super().__init__( - name, path, label_name, print_stats, dataset_names=dataset_names["DTI"] - ) + super().__init__(name, + path, + label_name, + print_stats, + dataset_names=dataset_names["DTI"]) self.entity1_name = "Drug" self.entity2_name = "Target" self.two_types = True @@ -60,30 +61,21 @@ def harmonize_affinities(self, mode=None): print_sys( "The scale is converted to log scale, so we will take the maximum!" ) - df = ( - df_.groupby(["Drug_ID", "Drug", "Target_ID", "Target"]) - .Y.agg(max) - .reset_index() - ) + df = (df_.groupby(["Drug_ID", "Drug", "Target_ID", + "Target"]).Y.agg(max).reset_index()) else: print_sys( "The scale is in original affinity scale, so we will take the minimum!" ) - df = ( - df_.groupby(["Drug_ID", "Drug", "Target_ID", "Target"]) - .Y.agg(min) - .reset_index() - ) + df = (df_.groupby(["Drug_ID", "Drug", "Target_ID", + "Target"]).Y.agg(min).reset_index()) elif mode == "mean": import numpy as np df_ = self.get_data() - df = ( - df_.groupby(["Drug_ID", "Drug", "Target_ID", "Target"]) - .Y.agg(np.mean) - .reset_index() - ) + df = (df_.groupby(["Drug_ID", "Drug", "Target_ID", + "Target"]).Y.agg(np.mean).reset_index()) self.entity1_idx = df.Drug_ID.values self.entity2_idx = df.Target_ID.values @@ -92,4 +84,4 @@ def harmonize_affinities(self, mode=None): self.entity2 = df.Target.values self.y = df.Y.values print_sys("The original data has been updated!") - return df + return df \ No newline at end of file diff --git a/tdc/multi_pred/gda.py b/tdc/multi_pred/gda.py index e3609a6f..bc69477d 100644 --- a/tdc/multi_pred/gda.py +++ b/tdc/multi_pred/gda.py @@ -13,7 +13,6 @@ class GDA(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Gene-Disease Association Prediction task. More info: https://tdcommons.ai/multi_pred_tasks/gdi/ @@ -35,9 +34,11 @@ class GDA(bi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create Gene-Disease Association Prediction dataloader object""" - super().__init__( - name, path, label_name, print_stats, dataset_names=dataset_names["GDA"] - ) + super().__init__(name, + path, + label_name, + print_stats, + dataset_names=dataset_names["GDA"]) self.entity1_name = "Gene" self.entity2_name = "Disease" self.two_types = True diff --git a/tdc/multi_pred/mti.py b/tdc/multi_pred/mti.py index 2e2dafe4..f0ec169a 100644 --- a/tdc/multi_pred/mti.py +++ b/tdc/multi_pred/mti.py @@ -13,7 +13,6 @@ class MTI(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in MicroRNA-Target Interaction Prediction task. More info: https://tdcommons.ai/multi_pred_tasks/mti/ @@ -35,9 +34,11 @@ class MTI(bi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create MicroRNA-Target Interaction Prediction dataloader object""" - super().__init__( - name, path, label_name, print_stats, dataset_names=dataset_names["MTI"] - ) + super().__init__(name, + path, + label_name, + print_stats, + dataset_names=dataset_names["MTI"]) self.entity1_name = "miRNA" self.entity2_name = "Target" self.two_types = True diff --git a/tdc/multi_pred/multi_pred_dataset.py b/tdc/multi_pred/multi_pred_dataset.py index 4c5ba3ae..eabe3fe2 100644 --- a/tdc/multi_pred/multi_pred_dataset.py +++ b/tdc/multi_pred/multi_pred_dataset.py @@ -41,9 +41,8 @@ def __init__(self, name, path, print_stats, dataset_names): if label_name is None: raise ValueError( "Please select a label name. You can use tdc.utils.retrieve_label_name_list('" - + name.lower() - + "') to retrieve all available label names." - ) + + name.lower() + + "') to retrieve all available label names.") df = multi_dataset_load(name, path, dataset_names) @@ -77,9 +76,11 @@ def print_stats(self): print(str(len(self.df)) + " data points.", flush=True, file=sys.stderr) print_sys("--------------------------") - def get_split( - self, method="random", seed=42, frac=[0.7, 0.1, 0.2], column_name=None - ): + def get_split(self, + method="random", + seed=42, + frac=[0.7, 0.1, 0.2], + column_name=None): """split dataset into train/validation/test. Args: @@ -106,9 +107,8 @@ def get_split( elif method == "cold_split": if isinstance(column_name, str): column_name = [column_name] - if (column_name is None) or ( - not all([x in df.columns.values for x in column_name]) - ): + if (column_name is None) or (not all( + [x in df.columns.values for x in column_name])): raise AttributeError( "For cold_split, please provide one or multiple column names that are contained in the dataframe." ) diff --git a/tdc/multi_pred/peptidemhc.py b/tdc/multi_pred/peptidemhc.py index 21219b05..cb9c7779 100644 --- a/tdc/multi_pred/peptidemhc.py +++ b/tdc/multi_pred/peptidemhc.py @@ -13,7 +13,6 @@ class PeptideMHC(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Peptide-MHC Binding Prediction task. More info: https://tdcommons.ai/multi_pred_tasks/peptidemhc/ diff --git a/tdc/multi_pred/ppi.py b/tdc/multi_pred/ppi.py index d135d4e0..31d4a7b1 100644 --- a/tdc/multi_pred/ppi.py +++ b/tdc/multi_pred/ppi.py @@ -13,7 +13,6 @@ class PPI(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Protein-Protein Interaction Prediction task. More info: https://tdcommons.ai/multi_pred_tasks/ppi/ @@ -33,9 +32,11 @@ class PPI(bi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create Protein-Protein Interaction Prediction dataloader object""" - super().__init__( - name, path, label_name, print_stats, dataset_names=dataset_names["PPI"] - ) + super().__init__(name, + path, + label_name, + print_stats, + dataset_names=dataset_names["PPI"]) self.entity1_name = "Protein1" self.entity2_name = "Protein2" self.two_types = False @@ -51,9 +52,9 @@ def print_stats(self): print_sys("--- Dataset Statistics ---") print( - "There are " - + str(len(np.unique(self.entity1.tolist() + self.entity2.tolist()))) - + " unique proteins.", + "There are " + + str(len(np.unique(self.entity1.tolist() + self.entity2.tolist()))) + + " unique proteins.", flush=True, file=sys.stderr, ) diff --git a/tdc/multi_pred/tcr_epi.py b/tdc/multi_pred/tcr_epi.py index efa37f97..0d1aeb5a 100644 --- a/tdc/multi_pred/tcr_epi.py +++ b/tdc/multi_pred/tcr_epi.py @@ -14,7 +14,6 @@ class TCREpitopeBinding(multi_pred_dataset.DataLoader): - """Data loader class to load datasets in T cell receptor (TCR) Specificity Prediction Task. More info: @@ -31,9 +30,10 @@ class TCREpitopeBinding(multi_pred_dataset.DataLoader): def __init__(self, name, path="./data", print_stats=False): """Create TCR Specificity Prediction dataloader object""" - super().__init__( - name, path, print_stats, dataset_names=dataset_names["TCREpitopeBinding"] - ) + super().__init__(name, + path, + print_stats, + dataset_names=dataset_names["TCREpitopeBinding"]) self.entity1_name = "TCR" self.entity2_name = "Epitope" diff --git a/tdc/multi_pred/test_multi_pred.py b/tdc/multi_pred/test_multi_pred.py index 4aa01c3d..981e182a 100644 --- a/tdc/multi_pred/test_multi_pred.py +++ b/tdc/multi_pred/test_multi_pred.py @@ -13,7 +13,6 @@ class TestMultiPred(bi_pred_dataset.DataLoader): - """Summary Attributes: diff --git a/tdc/multi_pred/trialoutcome.py b/tdc/multi_pred/trialoutcome.py index 03f46192..25490e14 100644 --- a/tdc/multi_pred/trialoutcome.py +++ b/tdc/multi_pred/trialoutcome.py @@ -13,7 +13,6 @@ class TrialOutcome(multi_pred_dataset.DataLoader): - """Data loader class to load datasets in clinical trial outcome Prediction task. More info: https://tdcommons.ai/multi_pred_tasks/trialoutcome/ @@ -35,9 +34,10 @@ class TrialOutcome(multi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create Clinical Trial Outcome Prediction dataloader object""" - super().__init__( - name, path, print_stats, dataset_names=dataset_names["TrialOutcome"] - ) + super().__init__(name, + path, + print_stats, + dataset_names=dataset_names["TrialOutcome"]) self.entity1_name = "drug_molecule" self.entity2_name = "disease_code" # self.entity3_name = "eligibility_criteria" From 13acccdac6b6396917369fa772608e1196685a80 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 11:20:42 -0500 Subject: [PATCH 08/39] google lint on resource and single_pred in tdc --- tdc/resource/primekg.py | 16 ++++---- tdc/single_pred/adme.py | 34 ++++++++-------- tdc/single_pred/crispr_outcome.py | 1 - tdc/single_pred/develop.py | 25 ++++++------ tdc/single_pred/epitope.py | 1 - tdc/single_pred/hts.py | 1 - tdc/single_pred/paratope.py | 1 - tdc/single_pred/qm.py | 1 - tdc/single_pred/single_pred_dataset.py | 55 +++++++++++++------------- tdc/single_pred/test_single_pred.py | 1 - tdc/single_pred/tox.py | 1 - tdc/single_pred/yields.py | 10 ++--- 12 files changed, 67 insertions(+), 80 deletions(-) diff --git a/tdc/resource/primekg.py b/tdc/resource/primekg.py index 5cbedba9..00a3a9d4 100644 --- a/tdc/resource/primekg.py +++ b/tdc/resource/primekg.py @@ -16,7 +16,6 @@ class PrimeKG: - """PrimeKG data loader class to load the knowledge graph with additional support functions.""" def __init__(self, path="./data"): @@ -32,19 +31,18 @@ def to_nx(self): G = nx.Graph() for i in self.df.relation.unique(): - G.add_edges_from( - self.df[self.df.relation == i][["x_id", "y_id"]].values, relation=i - ) + G.add_edges_from(self.df[self.df.relation == i][["x_id", + "y_id"]].values, + relation=i) return G def get_features(self, feature_type): if feature_type not in ["drug", "disease"]: raise ValueError("feature_type only supports drug/disease!") - return general_load("primekg_" + feature_type + "_feature", self.path, "\t") + return general_load("primekg_" + feature_type + "_feature", self.path, + "\t") def get_node_list(self, node_type): df = self.df - return np.unique( - df[(df.x_type == node_type)].x_id.unique().tolist() - + df[(df.y_type == node_type)].y_id.unique().tolist() - ) + return np.unique(df[(df.x_type == node_type)].x_id.unique().tolist() + + df[(df.y_type == node_type)].y_id.unique().tolist()) diff --git a/tdc/single_pred/adme.py b/tdc/single_pred/adme.py index b9cf7b0f..02b43645 100644 --- a/tdc/single_pred/adme.py +++ b/tdc/single_pred/adme.py @@ -50,9 +50,9 @@ def __init__( import pandas as pd import os - self.ppbr_df = pd.read_csv( - os.path.join(self.path, self.name + ".tab"), sep="\t" - ) + self.ppbr_df = pd.read_csv(os.path.join(self.path, + self.name + ".tab"), + sep="\t") df = self.ppbr_df[self.ppbr_df.Species == "Homo sapiens"] self.entity1 = df.Drug.values self.y = df.Y.values @@ -66,10 +66,11 @@ def get_approved_set(self): import pandas as pd if self.name not in ["pampa_ncats"]: - raise ValueError("This function is only available for PAMPA_NCATS dataset") - entity1, y, entity1_idx = property_dataset_load( - "approved_pampa_ncats", self.path, None, dataset_names["ADME"] - ) + raise ValueError( + "This function is only available for PAMPA_NCATS dataset") + entity1, y, entity1_idx = property_dataset_load("approved_pampa_ncats", + self.path, None, + dataset_names["ADME"]) return pd.DataFrame({"Drug_ID": entity1_idx, "Drug": entity1, "Y": y}) def get_other_species(self, species=None): @@ -83,7 +84,8 @@ def get_other_species(self, species=None): return self.ppbr_df if species in self.ppbr_df.Species.unique(): - return self.ppbr_df[self.ppbr_df.Species == species].reset_index(drop=True) + return self.ppbr_df[self.ppbr_df.Species == species].reset_index( + drop=True) else: raise ValueError( "You can only specify the following set of species name: 'Canis lupus familiaris', 'Cavia porcellus', 'Homo sapiens', 'Mus musculus', 'Rattus norvegicus', 'all'" @@ -99,19 +101,15 @@ def harmonize(self, mode=None): if mode == "max": df_ = self.get_data() - df = ( - df_.sort_values("Y", ascending=True) - .drop_duplicates("Drug") - .reset_index(drop=True) - ) + df = (df_.sort_values( + "Y", + ascending=True).drop_duplicates("Drug").reset_index(drop=True)) elif mode == "min": df_ = self.get_data() - df = ( - df_.sort_values("Y", ascending=False) - .drop_duplicates("Drug") - .reset_index(drop=True) - ) + df = (df_.sort_values( + "Y", + ascending=False).drop_duplicates("Drug").reset_index(drop=True)) elif mode == "remove_all": df_ = self.get_data() diff --git a/tdc/single_pred/crispr_outcome.py b/tdc/single_pred/crispr_outcome.py index 97f5966e..7f0d686f 100644 --- a/tdc/single_pred/crispr_outcome.py +++ b/tdc/single_pred/crispr_outcome.py @@ -13,7 +13,6 @@ class CRISPROutcome(single_pred_dataset.DataLoader): - """Data loader class to load datasets in CRISPROutcome task. More info: https://tdcommons.ai/single_pred_tasks/CRISPROutcome/ Args: diff --git a/tdc/single_pred/develop.py b/tdc/single_pred/develop.py index d88dc8c7..cc953298 100644 --- a/tdc/single_pred/develop.py +++ b/tdc/single_pred/develop.py @@ -13,7 +13,6 @@ class Develop(single_pred_dataset.DataLoader): - """Data loader class to load datasets in Develop task. More info: https://tdcommons.ai/single_pred_tasks/develop/ Args: @@ -85,17 +84,20 @@ def graphein( from graphein.protein.utils import get_obsolete_mapping obs = get_obsolete_mapping() - train_obs = [t for t in split["train"]["Antibody_ID"] if t in obs.keys()] - valid_obs = [t for t in split["valid"]["Antibody_ID"] if t in obs.keys()] - test_obs = [t for t in split["test"]["Antibody_ID"] if t in obs.keys()] - - split["train"] = split["train"].loc[ - ~split["train"]["Antibody_ID"].isin(train_obs) + train_obs = [ + t for t in split["train"]["Antibody_ID"] if t in obs.keys() ] - split["test"] = split["test"].loc[~split["test"]["Antibody_ID"].isin(test_obs)] - split["valid"] = split["valid"].loc[ - ~split["valid"]["Antibody_ID"].isin(valid_obs) + valid_obs = [ + t for t in split["valid"]["Antibody_ID"] if t in obs.keys() ] + test_obs = [t for t in split["test"]["Antibody_ID"] if t in obs.keys()] + + split["train"] = split["train"].loc[~split["train"]["Antibody_ID"]. + isin(train_obs)] + split["test"] = split["test"].loc[~split["test"]["Antibody_ID"]. + isin(test_obs)] + split["valid"] = split["valid"].loc[~split["valid"]["Antibody_ID"]. + isin(valid_obs)] self.split = split @@ -104,8 +106,7 @@ def get_label_map(split_name: str) -> Dict[str, torch.Tensor]: zip( split[split_name].Antibody_ID, split[split_name].Y.apply(torch.tensor), - ) - ) + )) train_labels = get_label_map("train") valid_labels = get_label_map("valid") diff --git a/tdc/single_pred/epitope.py b/tdc/single_pred/epitope.py index ec9448fd..cb47a0ee 100644 --- a/tdc/single_pred/epitope.py +++ b/tdc/single_pred/epitope.py @@ -13,7 +13,6 @@ class Epitope(single_pred_dataset.DataLoader): - """Data loader class to load datasets in Epitope Prediction task. More info: https://tdcommons.ai/single_pred_tasks/epitope/ Args: diff --git a/tdc/single_pred/hts.py b/tdc/single_pred/hts.py index b3589cee..c203d88a 100644 --- a/tdc/single_pred/hts.py +++ b/tdc/single_pred/hts.py @@ -13,7 +13,6 @@ class HTS(single_pred_dataset.DataLoader): - """Data loader class to load datasets in HTS task. More info: https://tdcommons.ai/single_pred_tasks/hts/ Args: diff --git a/tdc/single_pred/paratope.py b/tdc/single_pred/paratope.py index 8d19060c..1866f000 100644 --- a/tdc/single_pred/paratope.py +++ b/tdc/single_pred/paratope.py @@ -13,7 +13,6 @@ class Paratope(single_pred_dataset.DataLoader): - """Data loader class to load datasets in Paratope Prediction task. More info: https://tdcommons.ai/single_pred_tasks/paratope/ Args: diff --git a/tdc/single_pred/qm.py b/tdc/single_pred/qm.py index 9e851340..75bbe0e7 100644 --- a/tdc/single_pred/qm.py +++ b/tdc/single_pred/qm.py @@ -13,7 +13,6 @@ class QM(single_pred_dataset.DataLoader): - """Data loader class to load datasets in QM (Quantum Mechanics Modeling) task. More info: https://tdcommons.ai/single_pred_tasks/qm/ Args: diff --git a/tdc/single_pred/single_pred_dataset.py b/tdc/single_pred/single_pred_dataset.py index 48523af9..8b2771ab 100644 --- a/tdc/single_pred/single_pred_dataset.py +++ b/tdc/single_pred/single_pred_dataset.py @@ -21,7 +21,6 @@ class DataLoader(base_dataset.DataLoader): - """A base data loader class. Args: @@ -65,13 +64,11 @@ def __init__( if label_name is None: raise ValueError( "Please select a label name. You can use tdc.utils.retrieve_label_name_list('" - + name.lower() - + "') to retrieve all available label names." - ) + + name.lower() + + "') to retrieve all available label names.") - entity1, y, entity1_idx = property_dataset_load( - name, path, label_name, dataset_names - ) + entity1, y, entity1_idx = property_dataset_load(name, path, label_name, + dataset_names) self.entity1 = entity1 self.y = y @@ -106,32 +103,34 @@ def get_data(self, format="df"): if format == "df": if self.convert_format is not None: - return pd.DataFrame( - { - self.entity1_name + "_ID": self.entity1_idx, - self.entity1_name: self.entity1, - self.entity1_name - + "_" - + self.convert_format: self.convert_result, - "Y": self.y, - } - ) + return pd.DataFrame({ + self.entity1_name + "_ID": + self.entity1_idx, + self.entity1_name: + self.entity1, + self.entity1_name + "_" + self.convert_format: + self.convert_result, + "Y": + self.y, + }) else: - return pd.DataFrame( - { - self.entity1_name + "_ID": self.entity1_idx, - self.entity1_name: self.entity1, - "Y": self.y, - } - ) + return pd.DataFrame({ + self.entity1_name + "_ID": self.entity1_idx, + self.entity1_name: self.entity1, + "Y": self.y, + }) elif format == "dict": if self.convert_format is not None: return { - self.entity1_name + "_ID": self.entity1_idx.values, - self.entity1_name: self.entity1.values, - self.entity1_name + "_" + self.convert_format: self.convert_result, - "Y": self.y.values, + self.entity1_name + "_ID": + self.entity1_idx.values, + self.entity1_name: + self.entity1.values, + self.entity1_name + "_" + self.convert_format: + self.convert_result, + "Y": + self.y.values, } else: return { diff --git a/tdc/single_pred/test_single_pred.py b/tdc/single_pred/test_single_pred.py index 20edc3ee..315937d9 100644 --- a/tdc/single_pred/test_single_pred.py +++ b/tdc/single_pred/test_single_pred.py @@ -13,7 +13,6 @@ class TestSinglePred(single_pred_dataset.DataLoader): - """Data loader class to test the single instance prediction data loader. Args: diff --git a/tdc/single_pred/tox.py b/tdc/single_pred/tox.py index 8ad02ca1..3d3892b5 100644 --- a/tdc/single_pred/tox.py +++ b/tdc/single_pred/tox.py @@ -13,7 +13,6 @@ class Tox(single_pred_dataset.DataLoader): - """Data loader class to load datasets in Tox (Toxicity Prediction) task. More info: https://tdcommons.ai/single_pred_tasks/tox/ Args: diff --git a/tdc/single_pred/yields.py b/tdc/single_pred/yields.py index 1091c83c..3d3892b5 100644 --- a/tdc/single_pred/yields.py +++ b/tdc/single_pred/yields.py @@ -12,9 +12,8 @@ from ..metadata import dataset_names -class Yields(single_pred_dataset.DataLoader): - - """Data loader class to load datasets in Yields (Reaction Yields Prediction) task. More info: https://tdcommons.ai/single_pred_tasks/yields/ +class Tox(single_pred_dataset.DataLoader): + """Data loader class to load datasets in Tox (Toxicity Prediction) task. More info: https://tdcommons.ai/single_pred_tasks/tox/ Args: name (str): the dataset name. @@ -36,16 +35,15 @@ def __init__( print_stats=False, convert_format=None, ): - """Create Yields (Reaction Yields Prediction) dataloader object.""" + """Create a Tox (Toxicity Prediction) dataloader object.""" super().__init__( name, path, label_name, print_stats, - dataset_names=dataset_names["Yields"], + dataset_names=dataset_names["Tox"], convert_format=convert_format, ) - self.entity1_name = "Reaction" if print_stats: self.print_stats() print("Done!", flush=True, file=sys.stderr) From d0d31e416b1d836e00768af2396f4dcde859364f Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 11:27:57 -0500 Subject: [PATCH 09/39] google lint on tdc/test --- .../chem_utils_test/test_molconverter.py | 14 ++++----- .../chem_utils_test/test_molfilter.py | 4 ++- .../dev_tests/chem_utils_test/test_oracles.py | 30 +++++++++---------- .../dev_tests/utils_tests/test_misc_utils.py | 4 ++- tdc/test/dev_tests/utils_tests/test_splits.py | 25 ++++++++++------ tdc/test/test_benchmark.py | 11 +++++-- tdc/test/test_dataloaders.py | 7 +++-- tdc/test/test_functions.py | 7 +++-- 8 files changed, 61 insertions(+), 41 deletions(-) diff --git a/tdc/test/dev_tests/chem_utils_test/test_molconverter.py b/tdc/test/dev_tests/chem_utils_test/test_molconverter.py index cd034247..09e6950e 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_molconverter.py +++ b/tdc/test/dev_tests/chem_utils_test/test_molconverter.py @@ -11,10 +11,12 @@ # temporary solution for relative imports in case TDC is not installed # if TDC is installed, no need to use the following line -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) class TestMolConvert(unittest.TestCase): + def setUp(self): print(os.getcwd()) pass @@ -24,12 +26,10 @@ def test_MolConvert(self): from tdc.chem_utils import MolConvert converter = MolConvert(src="SMILES", dst="Graph2D") - converter( - [ - "Clc1ccccc1C2C(=C(/N/C(=C2/C(=O)OCC)COCCN)C)\C(=O)OC", - "CCCOc1cc2ncnc(Nc3ccc4ncsc4c3)c2cc1S(=O)(=O)C(C)(C)C", - ] - ) + converter([ + "Clc1ccccc1C2C(=C(/N/C(=C2/C(=O)OCC)COCCN)C)\C(=O)OC", + "CCCOc1cc2ncnc(Nc3ccc4ncsc4c3)c2cc1S(=O)(=O)C(C)(C)C", + ]) from tdc.chem_utils import MolConvert diff --git a/tdc/test/dev_tests/chem_utils_test/test_molfilter.py b/tdc/test/dev_tests/chem_utils_test/test_molfilter.py index 9b0476c6..a153edcd 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_molfilter.py +++ b/tdc/test/dev_tests/chem_utils_test/test_molfilter.py @@ -11,10 +11,12 @@ # temporary solution for relative imports in case TDC is not installed # if TDC is installed, no need to use the following line -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) class TestMolFilter(unittest.TestCase): + def setUp(self): print(os.getcwd()) pass diff --git a/tdc/test/dev_tests/chem_utils_test/test_oracles.py b/tdc/test/dev_tests/chem_utils_test/test_oracles.py index 10dededf..9406f1cc 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_oracles.py +++ b/tdc/test/dev_tests/chem_utils_test/test_oracles.py @@ -11,10 +11,12 @@ # temporary solution for relative imports in case TDC is not installed # if TDC is installed, no need to use the following line -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) class TestOracle(unittest.TestCase): + def setUp(self): print(os.getcwd()) pass @@ -24,13 +26,11 @@ def test_Oracle(self): from tdc import Oracle oracle = Oracle(name="SA") - x = oracle( - [ - "CC(C)(C)[C@H]1CCc2c(sc(NC(=O)COc3ccc(Cl)cc3)c2C(N)=O)C1", - "CCNC(=O)c1ccc(NC(=O)N2CC[C@H](C)[C@H](O)C2)c(C)c1", - "C[C@@H]1CCN(C(=O)CCCc2ccccc2)C[C@@H]1O", - ] - ) + x = oracle([ + "CC(C)(C)[C@H]1CCc2c(sc(NC(=O)COc3ccc(Cl)cc3)c2C(N)=O)C1", + "CCNC(=O)c1ccc(NC(=O)N2CC[C@H](C)[C@H](O)C2)c(C)c1", + "C[C@@H]1CCN(C(=O)CCCc2ccccc2)C[C@@H]1O", + ]) oracle = Oracle(name="Hop") x = oracle(["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C=O"]) @@ -40,14 +40,12 @@ def test_distribution(self): from tdc import Evaluator evaluator = Evaluator(name="Diversity") - x = evaluator( - [ - "CC(C)(C)[C@H]1CCc2c(sc(NC(=O)COc3ccc(Cl)cc3)c2C(N)=O)C1", - "C[C@@H]1CCc2c(sc(NC(=O)c3ccco3)c2C(N)=O)C1", - "CCNC(=O)c1ccc(NC(=O)N2CC[C@H](C)[C@H](O)C2)c(C)c1", - "C[C@@H]1CCN(C(=O)CCCc2ccccc2)C[C@@H]1O", - ] - ) + x = evaluator([ + "CC(C)(C)[C@H]1CCc2c(sc(NC(=O)COc3ccc(Cl)cc3)c2C(N)=O)C1", + "C[C@@H]1CCc2c(sc(NC(=O)c3ccco3)c2C(N)=O)C1", + "CCNC(=O)c1ccc(NC(=O)N2CC[C@H](C)[C@H](O)C2)c(C)c1", + "C[C@@H]1CCN(C(=O)CCCc2ccccc2)C[C@@H]1O", + ]) def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/utils_tests/test_misc_utils.py b/tdc/test/dev_tests/utils_tests/test_misc_utils.py index 8e4ab331..5876c51b 100644 --- a/tdc/test/dev_tests/utils_tests/test_misc_utils.py +++ b/tdc/test/dev_tests/utils_tests/test_misc_utils.py @@ -11,10 +11,12 @@ # temporary solution for relative imports in case TDC is not installed # if TDC is installed, no need to use the following line -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) class TestFunctions(unittest.TestCase): + def setUp(self): print(os.getcwd()) pass diff --git a/tdc/test/dev_tests/utils_tests/test_splits.py b/tdc/test/dev_tests/utils_tests/test_splits.py index 7377a723..d5959867 100644 --- a/tdc/test/dev_tests/utils_tests/test_splits.py +++ b/tdc/test/dev_tests/utils_tests/test_splits.py @@ -11,10 +11,12 @@ # temporary solution for relative imports in case TDC is not installed # if TDC is installed, no need to use the following line -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) class TestFunctions(unittest.TestCase): + def setUp(self): print(os.getcwd()) pass @@ -42,19 +44,24 @@ def test_cold_start_split(self): split = data.get_split(method="cold_split", column_name="Drug") self.assertEqual( - 0, len(set(split["train"]["Drug"]).intersection(set(split["test"]["Drug"]))) - ) + 0, + len( + set(split["train"]["Drug"]).intersection( + set(split["test"]["Drug"])))) self.assertEqual( - 0, len(set(split["valid"]["Drug"]).intersection(set(split["test"]["Drug"]))) - ) + 0, + len( + set(split["valid"]["Drug"]).intersection( + set(split["test"]["Drug"])))) self.assertEqual( 0, - len(set(split["train"]["Drug"]).intersection(set(split["valid"]["Drug"]))), + len( + set(split["train"]["Drug"]).intersection( + set(split["valid"]["Drug"]))), ) - multi_split = data.get_split( - method="cold_split", column_name=["Drug_ID", "Target_ID"] - ) + multi_split = data.get_split(method="cold_split", + column_name=["Drug_ID", "Target_ID"]) for entity in ["Drug_ID", "Target_ID"]: train_entity = set(multi_split["train"][entity]) valid_entity = set(multi_split["valid"][entity]) diff --git a/tdc/test/test_benchmark.py b/tdc/test/test_benchmark.py index b4b78083..90e16272 100644 --- a/tdc/test/test_benchmark.py +++ b/tdc/test/test_benchmark.py @@ -5,7 +5,8 @@ import sys import os -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) from tdc.benchmark_group import admet_group @@ -17,6 +18,7 @@ def is_classification(values): class TestBenchmarkGroup(unittest.TestCase): + def setUp(self): self.group = admet_group(path="data/") @@ -52,11 +54,14 @@ def test_ADME_evaluate_many(self): for ds_name, metrics in results.items(): self.assertEqual(len(metrics), 2) u, std = metrics - self.assertTrue(u in (1, 0)) # A perfect score for all metrics is 1 or 0 + self.assertTrue(u + in (1, + 0)) # A perfect score for all metrics is 1 or 0 self.assertEqual(0, std) for my_group in self.group: self.assertTrue(my_group["name"] in results) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tdc/test/test_dataloaders.py b/tdc/test/test_dataloaders.py index c5de2fa9..6e04d92a 100644 --- a/tdc/test/test_dataloaders.py +++ b/tdc/test/test_dataloaders.py @@ -11,11 +11,13 @@ # temporary solution for relative imports in case TDC is not installed # if TDC is installed, no need to use the following line -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) # TODO: add verification for the generation other than simple integration class TestDataloader(unittest.TestCase): + def setUp(self): print(os.getcwd()) pass @@ -42,5 +44,6 @@ def tearDown(self): print(os.getcwd()) shutil.rmtree(os.path.join(os.getcwd(), "data")) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tdc/test/test_functions.py b/tdc/test/test_functions.py index e713512e..40f0b976 100644 --- a/tdc/test/test_functions.py +++ b/tdc/test/test_functions.py @@ -11,10 +11,12 @@ # temporary solution for relative imports in case TDC is not installed # if TDC is installed, no need to use the following line -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) class TestFunctions(unittest.TestCase): + def setUp(self): print(os.getcwd()) pass @@ -51,5 +53,6 @@ def tearDown(self): if os.path.exists(os.path.join(os.getcwd(), "oracle")): shutil.rmtree(os.path.join(os.getcwd(), "oracle")) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 4657198675553f81fc61bea087b395c457ceecd5 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 12:07:35 -0500 Subject: [PATCH 10/39] google lint on tdc/utils --- tdc/utils/label.py | 66 +++++++------- tdc/utils/label_name_list.py | 3 - tdc/utils/load.py | 167 ++++++++++++++++++----------------- tdc/utils/misc.py | 11 ++- tdc/utils/query.py | 24 ++--- tdc/utils/retrieve.py | 3 +- tdc/utils/split.py | 60 ++++++------- 7 files changed, 171 insertions(+), 163 deletions(-) diff --git a/tdc/utils/label.py b/tdc/utils/label.py index a73829ad..870f7801 100644 --- a/tdc/utils/label.py +++ b/tdc/utils/label.py @@ -21,7 +21,7 @@ def convert_y_unit(y, from_, to_): if from_ == "nM": y = y elif from_ == "p": - y = (10 ** (-y) - 1e-10) / 1e-9 + y = (10**(-y) - 1e-10) / 1e-9 if to_ == "p": y = -np.log10(y * 1e-9 + 1e-10) @@ -31,9 +31,12 @@ def convert_y_unit(y, from_, to_): return y -def label_transform( - y, binary, threshold, convert_to_log, verbose=True, order="descending" -): +def label_transform(y, + binary, + threshold, + convert_to_log, + verbose=True, + order="descending"): """label transformation helper function Args: @@ -62,7 +65,8 @@ def label_transform( elif order == "ascending": y = np.array([1 if i else 0 for i in np.array(y) > threshold]) else: - raise ValueError("Please select order from 'descending or ascending!") + raise ValueError( + "Please select order from 'descending or ascending!") else: if (len(np.unique(y)) > 2) and convert_to_log: if verbose: @@ -144,16 +148,16 @@ def label_dist(y, name=None): median = np.median(y) mean = np.mean(y) - f, (ax_box, ax_hist) = plt.subplots( - 2, sharex=True, gridspec_kw={"height_ratios": (0.15, 1)} - ) + f, (ax_box, + ax_hist) = plt.subplots(2, + sharex=True, + gridspec_kw={"height_ratios": (0.15, 1)}) if name is None: sns.boxplot(y, ax=ax_box).set_title("Label Distribution") else: - sns.boxplot(y, ax=ax_box).set_title( - "Label Distribution of " + str(name) + " Dataset" - ) + sns.boxplot(y, ax=ax_box).set_title("Label Distribution of " + + str(name) + " Dataset") ax_box.axvline(median, color="b", linestyle="--") ax_box.axvline(mean, color="g", linestyle="--") @@ -191,7 +195,8 @@ def NegSample(df, column_names, frac, two_types): pos_set = set([tuple([i[0], i[1]]) for i in pos]) np.random.seed(1234) samples = np.random.choice(df_unique, size=(x, 2), replace=True) - neg_set = set([tuple([i[0], i[1]]) for i in samples if i[0] != i[1]]) - pos_set + neg_set = set([tuple([i[0], i[1]]) for i in samples if i[0] != i[1] + ]) - pos_set while len(neg_set) < x: sample = np.random.choice(df_unique, 2, replace=False) @@ -208,10 +213,13 @@ def NegSample(df, column_names, frac, two_types): neg_list_val.append([i[0], id2seq[i[0]], i[1], id2seq[i[1]], 0]) df = df.append( - pd.DataFrame(neg_list_val).rename( - columns={0: id1, 1: x1, 2: id2, 3: x2, 4: "Y"} - ) - ).reset_index(drop=True) + pd.DataFrame(neg_list_val).rename(columns={ + 0: id1, + 1: x1, + 2: id2, + 3: x2, + 4: "Y" + })).reset_index(drop=True) return df else: df_unique_id1 = np.unique(df[id1].values.reshape(-1)) @@ -224,16 +232,11 @@ def NegSample(df, column_names, frac, two_types): sample_id1 = np.random.choice(df_unique_id1, size=len(df), replace=True) sample_id2 = np.random.choice(df_unique_id2, size=len(df), replace=True) - neg_set = ( - set( - [ - tuple([sample_id1[i], sample_id2[i]]) - for i in range(len(df)) - if sample_id1[i] != sample_id2[i] - ] - ) - - pos_set - ) + neg_set = (set([ + tuple([sample_id1[i], sample_id2[i]]) + for i in range(len(df)) + if sample_id1[i] != sample_id2[i] + ]) - pos_set) while len(neg_set) < len(df): sample_id1 = np.random.choice(df_unique_id1, size=1, replace=True) @@ -252,8 +255,11 @@ def NegSample(df, column_names, frac, two_types): neg_list_val.append([i[0], id2seq1[i[0]], i[1], id2seq2[i[1]], 0]) df = df.append( - pd.DataFrame(neg_list_val).rename( - columns={0: id1, 1: x1, 2: id2, 3: x2, 4: "Y"} - ) - ).reset_index(drop=True) + pd.DataFrame(neg_list_val).rename(columns={ + 0: id1, + 1: x1, + 2: id2, + 3: x2, + 4: "Y" + })).reset_index(drop=True) return df diff --git a/tdc/utils/label_name_list.py b/tdc/utils/label_name_list.py index e164d91c..24a433dc 100644 --- a/tdc/utils/label_name_list.py +++ b/tdc/utils/label_name_list.py @@ -636,11 +636,9 @@ "Tanguay_ZF_120hpf_YSE_up", ] - QM7_targets = ["Y"] # QM7_targets = ["E_PBE0", "E_max_EINDO", "I_max_ZINDO", "HOMO_ZINDO", "LUMO_ZINDO", "E_1st_ZINDO", "IP_ZINDO", "EA_ZINDO", "HOMO_PBE0", "LUMO_PBE0", "HOMO_GW", "LUMO_GW", "alpha_PBE0", "alpha_SCS"] - #### qm7b: 14 labels QM7b_targets = [ "AE_PBE0", @@ -683,7 +681,6 @@ "f1-CAM", ] - # QM9_targets = [ # "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "cv", "u0", "u298", # "h298", "g298" diff --git a/tdc/utils/load.py b/tdc/utils/load.py index 82b348e4..ea88283a 100644 --- a/tdc/utils/load.py +++ b/tdc/utils/load.py @@ -68,12 +68,16 @@ def download_wrapper(name, path, dataset_names): os.mkdir(path) if os.path.exists( - os.path.join(path, name + "-" + str(i + 1) + "." + name2type[name]) - ): + os.path.join( + path, name + "-" + str(i + 1) + "." + name2type[name])): print_sys("Found local copy...") else: print_sys("Downloading...") - dataverse_download(dataset_path, path, name, name2type, id=i + 1) + dataverse_download(dataset_path, + path, + name, + name2type, + id=i + 1) return name else: @@ -118,11 +122,16 @@ def zip_data_download_wrapper(name, path, dataset_names): ) else: print_sys(f"Downloading {i+1}/{len(name2idlist[name])} file...") - dataverse_download(dataset_path, path, name, name2type, id=i + 1) - print_sys(f"Extracting zip {i+1}/{len(name2idlist[name])} file...") + dataverse_download(dataset_path, + path, + name, + name2type, + id=i + 1) + print_sys( + f"Extracting zip {i+1}/{len(name2idlist[name])} file...") with ZipFile( - os.path.join(path, name + "-" + str(i + 1) + ".zip"), "r" - ) as zip: + os.path.join(path, name + "-" + str(i + 1) + ".zip"), + "r") as zip: zip.extractall(path=os.path.join(path)) if not os.path.exists(os.path.join(path, name)): os.mkdir(os.path.join(path, name)) @@ -196,7 +205,8 @@ def oracle_download_wrapper(name, path, oracle_names): print_sys("Found local copy...") else: print_sys("Downloading Oracle...") - dataverse_download(dataset_path, path, name, oracle2type) ## to-do to-check + dataverse_download(dataset_path, path, name, + oracle2type) ## to-do to-check print_sys("Done!") return name @@ -222,19 +232,16 @@ def receptor_download_wrapper(name, path): os.mkdir(path) if os.path.exists(os.path.join(path, name + ".pdbqt")) and os.path.exists( - os.path.join(path, name + ".pdb") - ): + os.path.join(path, name + ".pdb")): print_sys("Found local copy...") else: print_sys("Downloading receptor...") receptor2type = defaultdict(lambda: "pdbqt") - dataverse_download( - dataset_paths[0], path, name, receptor2type - ) ## to-do to-check + dataverse_download(dataset_paths[0], path, name, + receptor2type) ## to-do to-check receptor2type = defaultdict(lambda: "pdb") - dataverse_download( - dataset_paths[1], path, name, receptor2type - ) ## to-do to-check + dataverse_download(dataset_paths[1], path, name, + receptor2type) ## to-do to-check print_sys("Done!") return name @@ -284,11 +291,13 @@ def pd_load(name, path): """ try: if name2type[name] == "tab": - df = pd.read_csv(os.path.join(path, name + "." + name2type[name]), sep="\t") + df = pd.read_csv(os.path.join(path, name + "." + name2type[name]), + sep="\t") elif name2type[name] == "csv": df = pd.read_csv(os.path.join(path, name + "." + name2type[name])) elif name2type[name] == "pkl": - df = pd.read_pickle(os.path.join(path, name + "." + name2type[name])) + df = pd.read_pickle(os.path.join(path, + name + "." + name2type[name])) elif name2type[name] == "zip": df = pd.read_pickle(os.path.join(path, name + "/" + name + ".pkl")) else: @@ -328,7 +337,8 @@ def property_dataset_load(name, path, target, dataset_names): target = fuzzy_search(target, df.columns.values) # df = df.T.drop_duplicates().T ### does not work # df2 = df.loc[:,~df.T.duplicated(keep='first')] ### does not work - df2 = df.loc[:, ~df.columns.duplicated()] ### remove the duplicate columns + df2 = df.loc[:, + ~df.columns.duplicated()] ### remove the duplicate columns df = df2 df = df[df[target].notnull()].reset_index(drop=True) except: @@ -337,8 +347,8 @@ def property_dataset_load(name, path, target, dataset_names): import pickle file_content = pickle.load( - open(os.path.join(path, name + "." + name2type[name]), "rb") - ) + open(os.path.join(path, name + "." + name2type[name]), + "rb")) else: file_content = " ".join(f.readlines()) flag = "Service Unavailable" in " ".join(file_content) @@ -352,7 +362,8 @@ def property_dataset_load(name, path, target, dataset_names): else: import sys - sys.exit("Please report this error to contact@tdcommons.ai, thanks!") + sys.exit( + "Please report this error to contact@tdcommons.ai, thanks!") try: return df["X"], df[target], df["ID"] except: @@ -386,7 +397,8 @@ def interaction_dataset_load(name, path, target, dataset_names, aux_column): if aux_column is None: return df["X1"], df["X2"], df[target], df["ID1"], df["ID2"], "_" else: - return df["X1"], df["X2"], df[target], df["ID1"], df["ID2"], df[aux_column] + return df["X1"], df["X2"], df[target], df["ID1"], df["ID2"], df[ + aux_column] except: with open(os.path.join(path, name + "." + name2type[name]), "r") as f: @@ -400,7 +412,8 @@ def interaction_dataset_load(name, path, target, dataset_names, aux_column): else: import sys - sys.exit("Please report this error to cosamhkx@gmail.com, thanks!") + sys.exit( + "Please report this error to cosamhkx@gmail.com, thanks!") def multi_dataset_load(name, path, dataset_names): @@ -421,7 +434,8 @@ def multi_dataset_load(name, path, dataset_names): return df -def generation_paired_dataset_load(name, path, dataset_names, input_name, output_name): +def generation_paired_dataset_load(name, path, dataset_names, input_name, + output_name): """a wrapper to download, process and load generation-paired task datasets Args: @@ -502,26 +516,26 @@ def bi_distribution_dataset_load( if name == "pdbbind": print_sys("Processing (this may take long)...") - protein, ligand = process_pdbbind( - path, name, return_pocket, remove_protein_Hs, remove_ligand_Hs, keep_het - ) + protein, ligand = process_pdbbind(path, name, return_pocket, + remove_protein_Hs, remove_ligand_Hs, + keep_het) elif name == "dude": print_sys("Processing (this may take long)...") if return_pocket: raise ImportError("DUD-E does not support pocket extraction yet") - protein, ligand = process_dude( - path, name, return_pocket, remove_protein_Hs, remove_ligand_Hs, keep_het - ) + protein, ligand = process_dude(path, name, return_pocket, + remove_protein_Hs, remove_ligand_Hs, + keep_het) elif name == "scpdb": print_sys("Processing (this may take long)...") - protein, ligand = process_scpdb( - path, name, return_pocket, remove_protein_Hs, remove_ligand_Hs, keep_het - ) + protein, ligand = process_scpdb(path, name, return_pocket, + remove_protein_Hs, remove_ligand_Hs, + keep_het) elif name == "crossdock": print_sys("Processing (this may take long)...") - protein, ligand = process_crossdock( - path, name, return_pocket, remove_protein_Hs, remove_ligand_Hs, keep_het - ) + protein, ligand = process_crossdock(path, name, return_pocket, + remove_protein_Hs, remove_ligand_Hs, + keep_het) return protein, ligand @@ -631,15 +645,13 @@ def process_pdbbind( try: if return_pocket: protein = PandasPdb().read_pdb( - os.path.join(path, f"{file}/{file}_pocket.pdb") - ) + os.path.join(path, f"{file}/{file}_pocket.pdb")) else: protein = PandasPdb().read_pdb( - os.path.join(path, f"{file}/{file}_protein.pdb") - ) - ligand = Chem.SDMolSupplier( - os.path.join(path, f"{file}/{file}_ligand.sdf"), sanitize=False - )[0] + os.path.join(path, f"{file}/{file}_protein.pdb")) + ligand = Chem.SDMolSupplier(os.path.join( + path, f"{file}/{file}_ligand.sdf"), + sanitize=False)[0] ligand = extract_atom_from_mol(ligand, remove_ligand_Hs) # if ligand contains unallowed atoms if ligand is None: @@ -716,17 +728,16 @@ def process_crossdock( else: # full protein not stored in the preprocessed crossdock by Luo et al 2021 protein = PandasPdb().read_pdb(os.path.join(path, pocket_fn)) - ligand = Chem.SDMolSupplier(os.path.join(path, ligand_fn), sanitize=False)[ - 0 - ] + ligand = Chem.SDMolSupplier(os.path.join(path, ligand_fn), + sanitize=False)[0] ligand = extract_atom_from_mol(ligand, remove_ligand_Hs) if ligand is None: continue else: ligand_coord, ligand_atom_type = ligand protein_coord, protein_atom_type = extract_atom_from_protein( - protein.df["ATOM"], protein.df["HETATM"], remove_protein_Hs, keep_het - ) + protein.df["ATOM"], protein.df["HETATM"], remove_protein_Hs, + keep_het) protein_coords.append(protein_coord) ligand_coords.append(ligand_coord) protein_atom_types.append(protein_atom_type) @@ -778,23 +789,24 @@ def process_dude( failure = 0 total_ct = 0 for idx, file in enumerate(tqdm(files)): - protein = PandasPdb().read_pdb(os.path.join(path, f"{file}/receptor.pdb")) + protein = PandasPdb().read_pdb( + os.path.join(path, f"{file}/receptor.pdb")) if not os.path.exists(os.path.join(path, f"{file}/actives_final.sdf")): os.system(f"gzip -d {path}/{file}/actives_final.sdf.gz") - crystal_ligand = Chem.MolFromMol2File( - os.path.join(path, f"{file}/crystal_ligand.mol2"), sanitize=False - ) + crystal_ligand = Chem.MolFromMol2File(os.path.join( + path, f"{file}/crystal_ligand.mol2"), + sanitize=False) crystal_ligand = extract_atom_from_mol(crystal_ligand, remove_ligand_Hs) if crystal_ligand is None: continue else: crystal_ligand_coord, crystal_ligand_atom_type = crystal_ligand - ligands = Chem.SDMolSupplier( - os.path.join(path, f"{file}/actives_final.sdf"), sanitize=False - ) + ligands = Chem.SDMolSupplier(os.path.join(path, + f"{file}/actives_final.sdf"), + sanitize=False) protein_coord, protein_atom_type = extract_atom_from_protein( - protein.df["ATOM"], protein.df["HETATM"], remove_protein_Hs, keep_het - ) + protein.df["ATOM"], protein.df["HETATM"], remove_protein_Hs, + keep_het) protein_coords.append(protein_coord) ligand_coords.append(crystal_ligand_coord) protein_atom_types.append(protein_atom_type) @@ -863,15 +875,13 @@ def process_scpdb( try: if return_pocket: protein = PandasMol2().read_mol2( - os.path.join(path, f"{file}/site.mol2") - ) + os.path.join(path, f"{file}/site.mol2")) else: protein = PandasMol2().read_mol2( - os.path.join(path, f"{file}/protein.mol2") - ) - ligand = Chem.SDMolSupplier( - os.path.join(path, f"{file}/ligand.sdf"), sanitize=False - )[0] + os.path.join(path, f"{file}/protein.mol2")) + ligand = Chem.SDMolSupplier(os.path.join(path, + f"{file}/ligand.sdf"), + sanitize=False)[0] ligand = extract_atom_from_mol(ligand, remove_Hs=remove_ligand_Hs) # if ligand contains unallowed atoms if ligand is None: @@ -879,8 +889,7 @@ def process_scpdb( else: ligand_coord, ligand_atom_type = ligand protein_coord, protein_atom_type = extract_atom_from_protein( - protein.df, None, remove_Hs=remove_protein_Hs, keep_het=False - ) + protein.df, None, remove_Hs=remove_protein_Hs, keep_het=False) protein_coords.append(protein_coord) ligand_coords.append(ligand_coord) protein_atom_types.append(protein_atom_type) @@ -957,23 +966,15 @@ def extract_atom_from_protein(data_frame, data_frame_het, remove_Hs, keep_het): if keep_het and data_frame_het is not None: data_frame = pd.concat([data_frame, data_frame_het]) if remove_Hs: - data_frame = data_frame[data_frame["atom_name"].str.startswith("H") == False] + data_frame = data_frame[data_frame["atom_name"].str.startswith("H") == + False] data_frame.reset_index(inplace=True, drop=True) - x = ( - data_frame["x_coord"].to_numpy() - if "x_coord" in data_frame - else data_frame["x"].to_numpy() - ) - y = ( - data_frame["y_coord"].to_numpy() - if "y_coord" in data_frame - else data_frame["y"].to_numpy() - ) - z = ( - data_frame["z_coord"].to_numpy() - if "z_coord" in data_frame - else data_frame["z"].to_numpy() - ) + x = (data_frame["x_coord"].to_numpy() + if "x_coord" in data_frame else data_frame["x"].to_numpy()) + y = (data_frame["y_coord"].to_numpy() + if "y_coord" in data_frame else data_frame["y"].to_numpy()) + z = (data_frame["z_coord"].to_numpy() + if "z_coord" in data_frame else data_frame["z"].to_numpy()) x = np.expand_dims(x, axis=1) y = np.expand_dims(y, axis=1) z = np.expand_dims(z, axis=1) diff --git a/tdc/utils/misc.py b/tdc/utils/misc.py index c7f65abb..57dc9dc1 100644 --- a/tdc/utils/misc.py +++ b/tdc/utils/misc.py @@ -33,7 +33,8 @@ def fuzzy_search(name, dataset_names): return s else: raise ValueError( - s + " does not belong to this task, please refer to the correct task name!" + s + + " does not belong to this task, please refer to the correct task name!" ) @@ -56,7 +57,9 @@ def get_closet_match(predefined_tokens, test_token, threshold=0.8): for token in predefined_tokens: # print(token) - prob_list.append(fuzz.ratio(str(token).lower(), str(test_token).lower())) + prob_list.append(fuzz.ratio( + str(token).lower(), + str(test_token).lower())) assert len(prob_list) == len(predefined_tokens) @@ -67,8 +70,8 @@ def get_closet_match(predefined_tokens, test_token, threshold=0.8): if prob_max / 100 < threshold: print_sys(predefined_tokens) raise ValueError( - test_token, "does not match to available values. " "Please double check." - ) + test_token, "does not match to available values. " + "Please double check.") return token_max, prob_max / 100 diff --git a/tdc/utils/query.py b/tdc/utils/query.py index 1d450844..a4dede37 100644 --- a/tdc/utils/query.py +++ b/tdc/utils/query.py @@ -15,7 +15,8 @@ def _parse_prop(search, proplist): """Extract property value from record using the given urn search filter.""" props = [ - i for i in proplist if all(item in i["urn"].items() for item in search.items()) + i for i in proplist + if all(item in i["urn"].items() for item in search.items()) ] if len(props) > 0: return props[0]["value"][list(props[0]["value"].keys())[0]] @@ -48,18 +49,15 @@ def request( urlid, postdata = None, None if namespace == "sourceid": identifier = identifier.replace("/", ".") - if ( - namespace in ["listkey", "formula", "sourceid"] - or searchtype == "xref" - or (searchtype and namespace == "cid") - or domain == "sources" - ): + if (namespace in ["listkey", "formula", "sourceid"] or + searchtype == "xref" or (searchtype and namespace == "cid") or + domain == "sources"): urlid = quote(identifier.encode("utf8")) else: postdata = urlencode([(namespace, identifier)]).encode("utf8") comps = filter( - None, [API_BASE, domain, searchtype, namespace, urlid, operation, output] - ) + None, + [API_BASE, domain, searchtype, namespace, urlid, operation, output]) apiurl = "/".join(comps) # Make request response = urlopen(apiurl, postdata) @@ -99,8 +97,12 @@ def cid2smiles(cid): """ try: smiles = _parse_prop( - {"label": "SMILES", "name": "Canonical"}, - json.loads(request(cid).read().decode())["PC_Compounds"][0]["props"], + { + "label": "SMILES", + "name": "Canonical" + }, + json.loads( + request(cid).read().decode())["PC_Compounds"][0]["props"], ) except: print("cid " + str(cid) + " failed, use NULL string") diff --git a/tdc/utils/retrieve.py b/tdc/utils/retrieve.py index 429cfbab..a50c8bca 100644 --- a/tdc/utils/retrieve.py +++ b/tdc/utils/retrieve.py @@ -73,7 +73,8 @@ def get_reaction_type(name, path="./data", output_format="array"): elif output_format == "array": return df["category"].values else: - raise ValueError("Please use the correct output format, select from df, array.") + raise ValueError( + "Please use the correct output format, select from df, array.") def retrieve_label_name_list(name): diff --git a/tdc/utils/split.py b/tdc/utils/split.py index c8f518cc..c1eef453 100644 --- a/tdc/utils/split.py +++ b/tdc/utils/split.py @@ -21,9 +21,9 @@ def create_fold(df, fold_seed, frac): train_frac, val_frac, test_frac = frac test = df.sample(frac=test_frac, replace=False, random_state=fold_seed) train_val = df[~df.index.isin(test.index)] - val = train_val.sample( - frac=val_frac / (1 - test_frac), replace=False, random_state=1 - ) + val = train_val.sample(frac=val_frac / (1 - test_frac), + replace=False, + random_state=1) train = train_val[~train_val.index.isin(val.index)] return { @@ -54,10 +54,9 @@ def create_fold_setting_cold(df, fold_seed, frac, entities): # For each entity, sample the instances belonging to the test datasets test_entity_instances = [ - df[e] - .drop_duplicates() - .sample(frac=test_frac, replace=False, random_state=fold_seed) - .values + df[e].drop_duplicates().sample(frac=test_frac, + replace=False, + random_state=fold_seed).values for e in entities ] @@ -69,8 +68,7 @@ def create_fold_setting_cold(df, fold_seed, frac, entities): if len(test) == 0: raise ValueError( "No test samples found. Try another seed, increasing the test frac or a " - "less stringent splitting strategy." - ) + "less stringent splitting strategy.") # Proceed with validation data train_val = df.copy() @@ -78,10 +76,9 @@ def create_fold_setting_cold(df, fold_seed, frac, entities): train_val = train_val[~train_val[e].isin(test_entity_instances[i])] val_entity_instances = [ - train_val[e] - .drop_duplicates() - .sample(frac=val_frac / (1 - test_frac), replace=False, random_state=fold_seed) - .values + train_val[e].drop_duplicates().sample(frac=val_frac / (1 - test_frac), + replace=False, + random_state=fold_seed).values for e in entities ] val = train_val.copy() @@ -91,8 +88,7 @@ def create_fold_setting_cold(df, fold_seed, frac, entities): if len(val) == 0: raise ValueError( "No validation samples found. Try another seed, increasing the test frac " - "or a less stringent splitting strategy." - ) + "or a less stringent splitting strategy.") train = train_val.copy() for i, e in enumerate(entities): @@ -127,8 +123,7 @@ def create_scaffold_split(df, seed, frac, entity): RDLogger.DisableLog("rdApp.*") except: raise ImportError( - "Please install rdkit by 'conda install -c conda-forge rdkit'! " - ) + "Please install rdkit by 'conda install -c conda-forge rdkit'! ") from tqdm import tqdm from random import Random @@ -144,8 +139,7 @@ def create_scaffold_split(df, seed, frac, entity): for i, smiles in tqdm(enumerate(s), total=len(s)): try: scaffold = MurckoScaffold.MurckoScaffoldSmiles( - mol=Chem.MolFromSmiles(smiles), includeChirality=False - ) + mol=Chem.MolFromSmiles(smiles), includeChirality=False) scaffolds[scaffold].add(i) except: print_sys(smiles + " returns RDKit error and is thus omitted...") @@ -213,9 +207,9 @@ def create_combination_generation_split(dict1, dict2, seed, frac): length = len(dict1["coord"]) indices = np.random.permutation(length) train_idx, val_idx, test_idx = ( - indices[: int(length * train_frac)], - indices[int(length * train_frac) : int(length * (train_frac + val_frac))], - indices[int(length * (train_frac + val_frac)) :], + indices[:int(length * train_frac)], + indices[int(length * train_frac):int(length * (train_frac + val_frac))], + indices[int(length * (train_frac + val_frac)):], ) return { @@ -272,9 +266,10 @@ def create_combination_split(df, seed, frac): intxn = intxn.intersection(c) # Split combinations into train, val and test - test_choices = np.random.choice( - list(intxn), int(test_size / len(df["Cell_Line_ID"].unique())), replace=False - ) + test_choices = np.random.choice(list(intxn), + int(test_size / + len(df["Cell_Line_ID"].unique())), + replace=False) trainval_intxn = intxn.difference(test_choices) val_choices = np.random.choice( list(trainval_intxn), @@ -312,15 +307,18 @@ def create_fold_time(df, frac, date_column): df = df.sort_values(by=date_column).reset_index(drop=True) train_frac, val_frac, test_frac = frac[0], frac[1], frac[2] - split_date = df[: int(len(df) * (train_frac + val_frac))].iloc[-1][date_column] + split_date = df[:int(len(df) * + (train_frac + val_frac))].iloc[-1][date_column] test = df[df[date_column] >= split_date].reset_index(drop=True) train_val = df[df[date_column] < split_date] - split_date_valid = train_val[ - : int(len(train_val) * train_frac / (train_frac + val_frac)) - ].iloc[-1][date_column] - train = train_val[train_val[date_column] <= split_date_valid].reset_index(drop=True) - valid = train_val[train_val[date_column] > split_date_valid].reset_index(drop=True) + split_date_valid = train_val[:int( + len(train_val) * train_frac / + (train_frac + val_frac))].iloc[-1][date_column] + train = train_val[train_val[date_column] <= split_date_valid].reset_index( + drop=True) + valid = train_val[train_val[date_column] > split_date_valid].reset_index( + drop=True) return { "train": train, From d638516170f02ddf54535294aee0c2653aa4215f Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 12:11:41 -0500 Subject: [PATCH 11/39] google lint tdc base files --- tdc/base_dataset.py | 95 +++++++++++---------------- tdc/benchmark_deprecated.py | 125 ++++++++++++++++++++---------------- tdc/evaluator.py | 5 +- tdc/metadata.py | 125 +++++++++++++++++------------------- tdc/oracles.py | 56 +++++++--------- tdc/tdc_hf.py | 74 +++++++++++---------- 6 files changed, 227 insertions(+), 253 deletions(-) diff --git a/tdc/base_dataset.py b/tdc/base_dataset.py index 8f133d91..4e49bea3 100644 --- a/tdc/base_dataset.py +++ b/tdc/base_dataset.py @@ -16,7 +16,6 @@ class DataLoader: - """base data loader class that contains functions shared by almost all data loader classes.""" def __init__(self): @@ -35,13 +34,11 @@ def get_data(self, format="df"): AttributeError: format not supported """ if format == "df": - return pd.DataFrame( - { - self.entity1_name + "_ID": self.entity1_idx, - self.entity1_name: self.entity1, - "Y": self.y, - } - ) + return pd.DataFrame({ + self.entity1_name + "_ID": self.entity1_idx, + self.entity1_name: self.entity1, + "Y": self.y, + }) elif format == "dict": return { self.entity1_name + "_ID": self.entity1_idx, @@ -56,11 +53,8 @@ def get_data(self, format="df"): def print_stats(self): """print statistics""" print( - "There are " - + str(len(np.unique(self.entity1))) - + " unique " - + self.entity1_name.lower() - + "s", + "There are " + str(len(np.unique(self.entity1))) + " unique " + + self.entity1_name.lower() + "s", flush=True, file=sys.stderr, ) @@ -86,7 +80,8 @@ def get_split(self, method="random", seed=42, frac=[0.7, 0.1, 0.2]): if method == "random": return utils.create_fold(df, seed, frac) elif method == "cold_" + self.entity1_name.lower(): - return utils.create_fold_setting_cold(df, seed, frac, self.entity1_name) + return utils.create_fold_setting_cold(df, seed, frac, + self.entity1_name) else: raise AttributeError("Please specify the correct splitting method") @@ -110,30 +105,22 @@ def binarize(self, threshold=None, order="descending"): if threshold is None: raise AttributeError( "Please specify the threshold to binarize the data by " - "'binarize(threshold = N)'!" - ) + "'binarize(threshold = N)'!") if len(np.unique(self.y)) == 2: print("The data is already binarized!", flush=True, file=sys.stderr) else: print( - "Binariztion using threshold " - + str(threshold) - + ", default, we assume the smaller values are 1 " + "Binariztion using threshold " + str(threshold) + + ", default, we assume the smaller values are 1 " "and larger ones is 0, you can change the order " "by 'binarize(order = 'ascending')'", flush=True, file=sys.stderr, ) - if ( - np.unique(self.y) - .reshape( - -1, - ) - .shape[0] - < 2 - ): - raise AttributeError("Adjust your threshold, there is only one class.") + if (np.unique(self.y).reshape(-1,).shape[0] < 2): + raise AttributeError( + "Adjust your threshold, there is only one class.") self.y = utils.binarize(self.y, threshold, order) return self @@ -223,36 +210,26 @@ def balanced(self, oversample=False, seed=42): flush=True, file=sys.stderr, ) - val = ( - pd.concat( - [ - val[val.Y == major_class].sample( - n=len(val[val.Y == minor_class]), - replace=False, - random_state=seed, - ), - val[val.Y == minor_class], - ] - ) - .sample(frac=1, replace=False, random_state=seed) - .reset_index(drop=True) - ) + val = (pd.concat([ + val[val.Y == major_class].sample( + n=len(val[val.Y == minor_class]), + replace=False, + random_state=seed, + ), + val[val.Y == minor_class], + ]).sample(frac=1, replace=False, + random_state=seed).reset_index(drop=True)) else: - print( - " Oversample of minority class is used. ", flush=True, file=sys.stderr - ) - val = ( - pd.concat( - [ - val[val.Y == minor_class].sample( - n=len(val[val.Y == major_class]), - replace=True, - random_state=seed, - ), - val[val.Y == major_class], - ] - ) - .sample(frac=1, replace=False, random_state=seed) - .reset_index(drop=True) - ) + print(" Oversample of minority class is used. ", + flush=True, + file=sys.stderr) + val = (pd.concat([ + val[val.Y == minor_class].sample( + n=len(val[val.Y == major_class]), + replace=True, + random_state=seed, + ), + val[val.Y == major_class], + ]).sample(frac=1, replace=False, + random_state=seed).reset_index(drop=True)) return val diff --git a/tdc/benchmark_deprecated.py b/tdc/benchmark_deprecated.py index 0150b5af..16b10fac 100644 --- a/tdc/benchmark_deprecated.py +++ b/tdc/benchmark_deprecated.py @@ -25,6 +25,7 @@ class BenchmarkGroup: + def __init__( self, name, @@ -162,7 +163,8 @@ def __next__(self): ncpu=self.num_cpus, num_max_call=self.num_max_call, ) - data = pd.read_csv(os.path.join(self.path, "zinc.tab"), sep="\t") + data = pd.read_csv(os.path.join(self.path, "zinc.tab"), + sep="\t") return {"oracle": oracle, "data": data, "name": dataset} else: return {"train_val": train, "test": test, "name": dataset} @@ -200,15 +202,19 @@ def get_train_valid_split(self, seed, benchmark, split_type="default"): frac = [frac[0], frac[1], 0.0] """ if split_method == "scaffold": - out = create_scaffold_split(train_val, seed, frac=frac, entity="Drug") + out = create_scaffold_split(train_val, + seed, + frac=frac, + entity="Drug") elif split_method == "random": out = create_fold(train_val, seed, frac=frac) elif split_method == "combination": out = create_combination_split(train_val, seed, frac=frac) elif split_method == "group": - out = create_group_split( - train_val, seed, holdout_frac=0.2, group_column="Year" - ) + out = create_group_split(train_val, + seed, + holdout_frac=0.2, + group_column="Year") else: raise NotImplementedError return out["train"], out["valid"] @@ -246,7 +252,12 @@ def get(self, benchmark, num_max_call=5000): else: return {"train_val": train, "test": test, "name": dataset} - def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True): + def evaluate(self, + pred, + true=None, + benchmark=None, + m1_api=None, + save_dict=True): if self.name == "docking_group": results_all = {} @@ -284,7 +295,8 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) dataset = fuzzy_search(data_name, self.dataset_names) # docking scores for the top K smiles (K <= 100) - target_pdb_file = os.path.join(self.path, dataset + ".pdb") + target_pdb_file = os.path.join(self.path, + dataset + ".pdb") from .oracles import Oracle data_path = os.path.join(self.path, dataset) @@ -304,12 +316,14 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) docking_scores = oracle(pred_) print_sys("---- Calculating average docking scores ----") - if ( - len(np.where(np.array(list(docking_scores.values())) > 0)[0]) - > 0.7 - ): + if (len( + np.where( + np.array(list(docking_scores.values())) > 0)[0]) + > 0.7): ## check if the scores are all positive.. if so, make them all negative - docking_scores = {j: -k for j, k in docking_scores.items()} + docking_scores = { + j: -k for j, k in docking_scores.items() + } if save_dict: results["docking_scores_dict"] = docking_scores values = np.array(list(docking_scores.values())) @@ -327,23 +341,23 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) ) from .oracles import Oracle - m1 = Oracle(name="Molecule One Synthesis", api_token=m1_api) + m1 = Oracle(name="Molecule One Synthesis", + api_token=m1_api) import heapq from operator import itemgetter top10_docking_smiles = list( dict( - heapq.nsmallest( - 10, docking_scores.items(), key=itemgetter(1) - ) - ).keys() - ) + heapq.nsmallest(10, + docking_scores.items(), + key=itemgetter(1))).keys()) m1_scores = m1(top10_docking_smiles) scores_array = list(m1_scores.values()) - scores_array = np.array([float(i) for i in scores_array]) - scores_array[ - np.where(scores_array == -1.0)[0] - ] = 10 # m1 score errors are usually large complex molecules + scores_array = np.array( + [float(i) for i in scores_array]) + scores_array[np.where( + scores_array == -1.0 + )[0]] = 10 # m1 score errors are usually large complex molecules if save_dict: results["m1_dict"] = m1_scores results["m1"] = np.mean(scores_array) @@ -361,8 +375,7 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) results["pass_list"] = pred_filter results["%pass"] = float(len(pred_filter)) / 100 results["top1_%pass"] = max( - [docking_scores[i] for i in pred_filter] - ) + [docking_scores[i] for i in pred_filter]) print_sys("---- Calculating diversity ----") from .evaluator import Evaluator @@ -371,13 +384,13 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) results["diversity"] = score print_sys("---- Calculating novelty ----") evaluator = Evaluator(name="Novelty") - training = pd.read_csv( - os.path.join(self.path, "zinc.tab"), sep="\t" - ) + training = pd.read_csv(os.path.join(self.path, "zinc.tab"), + sep="\t") score = evaluator(pred_, training.smiles.values) results["novelty"] = score results["top smiles"] = [ - i[0] for i in sorted(docking_scores.items(), key=lambda x: x[1]) + i[0] for i in sorted(docking_scores.items(), + key=lambda x: x[1]) ] results_max_call[num_max_call] = results results_all[data_name] = results_max_call @@ -395,8 +408,11 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) elif self.file_format == "pkl": test = pd.read_pickle(os.path.join(data_path, "test.pkl")) y = test.Y.values - evaluator = eval("Evaluator(name = '" + metric_dict[data_name] + "')") - out[data_name] = {metric_dict[data_name]: round(evaluator(y, pred_), 3)} + evaluator = eval("Evaluator(name = '" + metric_dict[data_name] + + "')") + out[data_name] = { + metric_dict[data_name]: round(evaluator(y, pred_), 3) + } # If reporting accuracy across target classes if "target_class" in test.columns: @@ -407,13 +423,11 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) y_subset = test_subset.Y.values pred_subset = test_subset.pred.values - evaluator = eval( - "Evaluator(name = '" + metric_dict[data_name_subset] + "')" - ) + evaluator = eval("Evaluator(name = '" + + metric_dict[data_name_subset] + "')") out[data_name_subset] = { - metric_dict[data_name_subset]: round( - evaluator(y_subset, pred_subset), 3 - ) + metric_dict[data_name_subset]: + round(evaluator(y_subset, pred_subset), 3) } return out else: @@ -424,12 +438,15 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) ) data_name = fuzzy_search(benchmark, self.dataset_names) metric_dict = bm_metric_names[self.name] - evaluator = eval("Evaluator(name = '" + metric_dict[data_name] + "')") + evaluator = eval("Evaluator(name = '" + metric_dict[data_name] + + "')") return {metric_dict[data_name]: round(evaluator(true, pred), 3)} - def evaluate_many( - self, preds, save_file_name=None, m1_api=None, results_individual=None - ): + def evaluate_many(self, + preds, + save_file_name=None, + m1_api=None, + results_individual=None): """ :param preds: list of dict :return: dict os.path.getsize(file_name2): model_file, config_file = file_name1, file_name2 else: config_file, model_file = file_name1, file_name2 os.rename(model_file, save_path + 'model.pt') - os.rename(config_file, save_path + 'config.pkl') + os.rename(config_file, save_path + 'config.pkl') try: from DeepPurpose import CompoundPred except: - raise ValueError("Please install DeepPurpose package following https://github.com/kexinhuang12345/DeepPurpose#installation") - - net = CompoundPred.model_pretrained(path_dir = save_path) + raise ValueError( + "Please install DeepPurpose package following https://github.com/kexinhuang12345/DeepPurpose#installation" + ) + + net = CompoundPred.model_pretrained(path_dir=save_path) return net else: raise ValueError("This repo does not host a DeepPurpose model!") + def predict_deeppurpose(self, model, drugs): try: from DeepPurpose import utils except: - raise ValueError("Please install DeepPurpose package following https://github.com/kexinhuang12345/DeepPurpose#installation") + raise ValueError( + "Please install DeepPurpose package following https://github.com/kexinhuang12345/DeepPurpose#installation" + ) if self.model_name == 'AttentiveFP': self.model_name = 'DGL_' + self.model_name - X_pred = utils.data_process(X_drug = drugs, y = [0]*len(drugs), - drug_encoding = self.model_name, - split_method='no_split') + X_pred = utils.data_process(X_drug=drugs, + y=[0] * len(drugs), + drug_encoding=self.model_name, + split_method='no_split') y_pred = model.predict(X_pred)[0] return y_pred From 3285ae46c58674564138520a56fec0ce8bf9d991 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 12:28:15 -0500 Subject: [PATCH 12/39] fix mistake in yields.py --- tdc/single_pred/yields.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tdc/single_pred/yields.py b/tdc/single_pred/yields.py index 3d3892b5..a4894573 100644 --- a/tdc/single_pred/yields.py +++ b/tdc/single_pred/yields.py @@ -12,8 +12,8 @@ from ..metadata import dataset_names -class Tox(single_pred_dataset.DataLoader): - """Data loader class to load datasets in Tox (Toxicity Prediction) task. More info: https://tdcommons.ai/single_pred_tasks/tox/ +class Yields(single_pred_dataset.DataLoader): + """Data loader class to load datasets in Yields (Reaction Yields Prediction) task. More info: https://tdcommons.ai/single_pred_tasks/yields/ Args: name (str): the dataset name. @@ -35,15 +35,16 @@ def __init__( print_stats=False, convert_format=None, ): - """Create a Tox (Toxicity Prediction) dataloader object.""" + """Create Yields (Reaction Yields Prediction) dataloader object.""" super().__init__( name, path, label_name, print_stats, - dataset_names=dataset_names["Tox"], + dataset_names=dataset_names["Yields"], convert_format=convert_format, ) + self.entity1_name = "Reaction" if print_stats: self.print_stats() print("Done!", flush=True, file=sys.stderr) From d7502081b678a0fe8bf90b9b739fd8947543c4c2 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 13:31:00 -0500 Subject: [PATCH 13/39] completed google lint. added torch dependencies to conda. tests pass. unskipped all but one. --- environment.yml | 5 ++++ run_tests.py | 6 ++++- .../chem_utils_test/test_molconverter.py | 4 +-- .../chem_utils_test/test_molfilter.py | 4 +-- .../dev_tests/chem_utils_test/test_oracles.py | 4 +-- .../dev_tests/utils_tests/test_misc_utils.py | 14 +++++----- tdc/test/dev_tests/utils_tests/test_splits.py | 12 ++++----- tdc/utils/label.py | 26 ++++++++++++------- 8 files changed, 45 insertions(+), 30 deletions(-) diff --git a/environment.yml b/environment.yml index a9ef8915..5d460fe1 100644 --- a/environment.yml +++ b/environment.yml @@ -2,6 +2,7 @@ name: tdc-conda-env channels: - conda-forge - defaults + - pytorch dependencies: - dataclasses=0.8 - fuzzywuzzy=0.18.0 @@ -10,10 +11,14 @@ dependencies: - python=3.9.13 - pip=23.3.1 - pandas=2.1.4 + - pyg=2.5.0 + - pytorch=2.2.1 - requests=2.31.0 - scikit-learn=1.3.0 - seaborn=0.12.2 - tqdm=4.65.0 + - torchaudio=2.2.1 + - torchvision=0.17.1 - pip: - cellxgene-census==1.10.2 - gget==0.28.4 diff --git a/run_tests.py b/run_tests.py index cd250420..6e89f236 100644 --- a/run_tests.py +++ b/run_tests.py @@ -6,4 +6,8 @@ suite = loader.discover(start_dir) runner = unittest.TextTestRunner() - runner.run(suite) + res = runner.run(suite) + if res.wasSuccessful(): + print("All base tests passed") + else: + raise RuntimeError("Some base tests failed") diff --git a/tdc/test/dev_tests/chem_utils_test/test_molconverter.py b/tdc/test/dev_tests/chem_utils_test/test_molconverter.py index 09e6950e..c08b847a 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_molconverter.py +++ b/tdc/test/dev_tests/chem_utils_test/test_molconverter.py @@ -21,7 +21,7 @@ def setUp(self): print(os.getcwd()) pass - @unittest.skip("dev test") + def test_MolConvert(self): from tdc.chem_utils import MolConvert @@ -35,7 +35,7 @@ def test_MolConvert(self): MolConvert.eligible_format() - # @unittest.skip("dev test") + # def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/chem_utils_test/test_molfilter.py b/tdc/test/dev_tests/chem_utils_test/test_molfilter.py index a153edcd..95bf402c 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_molfilter.py +++ b/tdc/test/dev_tests/chem_utils_test/test_molfilter.py @@ -21,14 +21,14 @@ def setUp(self): print(os.getcwd()) pass - @unittest.skip("dev test") + def test_MolConvert(self): from tdc.chem_utils import MolFilter filters = MolFilter(filters=["PAINS"], HBD=[0, 6]) filters(["CCSc1ccccc1C(=O)Nc1onc2c1CCC2"]) - # @unittest.skip("dev test") + # def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/chem_utils_test/test_oracles.py b/tdc/test/dev_tests/chem_utils_test/test_oracles.py index 9406f1cc..884e8a55 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_oracles.py +++ b/tdc/test/dev_tests/chem_utils_test/test_oracles.py @@ -21,7 +21,7 @@ def setUp(self): print(os.getcwd()) pass - @unittest.skip("dev test") + def test_Oracle(self): from tdc import Oracle @@ -35,7 +35,7 @@ def test_Oracle(self): oracle = Oracle(name="Hop") x = oracle(["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C=O"]) - @unittest.skip("dev test") + def test_distribution(self): from tdc import Evaluator diff --git a/tdc/test/dev_tests/utils_tests/test_misc_utils.py b/tdc/test/dev_tests/utils_tests/test_misc_utils.py index 5876c51b..c28e1bdb 100644 --- a/tdc/test/dev_tests/utils_tests/test_misc_utils.py +++ b/tdc/test/dev_tests/utils_tests/test_misc_utils.py @@ -21,7 +21,7 @@ def setUp(self): print(os.getcwd()) pass - @unittest.skip("dev test") + @unittest.skip("long running test") def test_neg_sample(self): from tdc.multi_pred import PPI @@ -33,7 +33,7 @@ def test_neg_sample(self): # data = ADME(name='Caco2_Wang') # x = data.label_distribution() - @unittest.skip("dev test") + def test_get_label_map(self): from tdc.multi_pred import DDI from tdc.utils import get_label_map @@ -42,26 +42,26 @@ def test_get_label_map(self): split = data.get_split() get_label_map(name="DrugBank", task="DDI") - @unittest.skip("dev test") + def test_balanced(self): from tdc.single_pred import HTS data = HTS(name="SARSCoV2_3CLPro_Diamond") data.balanced(oversample=True, seed=42) - @unittest.skip("dev test") + def test_cid2smiles(self): from tdc.utils import cid2smiles smiles = cid2smiles(2248631) - @unittest.skip("dev test") + def test_uniprot2seq(self): from tdc.utils import uniprot2seq seq = uniprot2seq("P49122") - @unittest.skip("dev test") + def test_to_graph(self): from tdc.multi_pred import DTI @@ -95,7 +95,7 @@ def test_to_graph(self): ) # output: {'pyg_graph': the PyG graph object, 'index_to_entities': a dict map from ID in the data to node ID in the PyG object, 'split': {'train': df, 'valid': df, 'test': df}} - # @unittest.skip("dev test") + # def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/utils_tests/test_splits.py b/tdc/test/dev_tests/utils_tests/test_splits.py index d5959867..fdf22ac1 100644 --- a/tdc/test/dev_tests/utils_tests/test_splits.py +++ b/tdc/test/dev_tests/utils_tests/test_splits.py @@ -21,14 +21,14 @@ def setUp(self): print(os.getcwd()) pass - @unittest.skip("dev test") + def test_random_split(self): from tdc.single_pred import ADME data = ADME(name="Caco2_Wang") split = data.get_split(method="random") - @unittest.skip("dev test") + def test_scaffold_split(self): ## requires RDKit from tdc.single_pred import ADME @@ -36,7 +36,7 @@ def test_scaffold_split(self): data = ADME(name="Caco2_Wang") split = data.get_split(method="scaffold") - @unittest.skip("dev test") + def test_cold_start_split(self): from tdc.multi_pred import DTI @@ -70,21 +70,21 @@ def test_cold_start_split(self): self.assertEqual(0, len(train_entity.intersection(test_entity))) self.assertEqual(0, len(valid_entity.intersection(test_entity))) - @unittest.skip("dev test") + def test_combination_split(self): from tdc.multi_pred import DrugSyn data = DrugSyn(name="DrugComb") split = data.get_split(method="combination") - @unittest.skip("dev test") + def test_time_split(self): from tdc.multi_pred import DTI data = DTI(name="BindingDB_Patent") split = data.get_split(method="time", time_column="Year") - @unittest.skip("dev test") + def test_tearDown(self): print(os.getcwd()) diff --git a/tdc/utils/label.py b/tdc/utils/label.py index 870f7801..b875d882 100644 --- a/tdc/utils/label.py +++ b/tdc/utils/label.py @@ -212,14 +212,17 @@ def NegSample(df, column_names, frac, two_types): for i in neg_list: neg_list_val.append([i[0], id2seq[i[0]], i[1], id2seq[i[1]], 0]) - df = df.append( - pd.DataFrame(neg_list_val).rename(columns={ - 0: id1, - 1: x1, - 2: id2, - 3: x2, - 4: "Y" - })).reset_index(drop=True) + df = pd.concat([ + df, + pd.DataFrame(neg_list_val).rename(columns={ + 0: id1, + 1: x1, + 2: id2, + 3: x2, + 4: "Y" + }) + ], + ignore_index=True, sort=False) return df else: df_unique_id1 = np.unique(df[id1].values.reshape(-1)) @@ -254,12 +257,15 @@ def NegSample(df, column_names, frac, two_types): for i in neg_list: neg_list_val.append([i[0], id2seq1[i[0]], i[1], id2seq2[i[1]], 0]) - df = df.append( + df = pd.concat([ + df, pd.DataFrame(neg_list_val).rename(columns={ 0: id1, 1: x1, 2: id2, 3: x2, 4: "Y" - })).reset_index(drop=True) + }) + ], + ignore_index=True, sort=False) return df From 657b6aeac60d9a6ec5f73737c9a91af78d24543f Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 13:36:18 -0500 Subject: [PATCH 14/39] mend --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index 5d460fe1..8a9106bb 100644 --- a/environment.yml +++ b/environment.yml @@ -2,6 +2,7 @@ name: tdc-conda-env channels: - conda-forge - defaults + - pyg - pytorch dependencies: - dataclasses=0.8 From 732101466e75b34eb0b500483430596ad952674f Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 13:44:01 -0500 Subject: [PATCH 15/39] mend --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index 8a9106bb..3fd80fab 100644 --- a/environment.yml +++ b/environment.yml @@ -12,6 +12,7 @@ dependencies: - python=3.9.13 - pip=23.3.1 - pandas=2.1.4 + - pydantic=2.6.3 - pyg=2.5.0 - pytorch=2.2.1 - requests=2.31.0 From abfbbbbea25bff295e70aef5635d6be49fc6e5a9 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 13:46:05 -0500 Subject: [PATCH 16/39] mend --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 3fd80fab..1cffd22d 100644 --- a/environment.yml +++ b/environment.yml @@ -12,7 +12,7 @@ dependencies: - python=3.9.13 - pip=23.3.1 - pandas=2.1.4 - - pydantic=2.6.3 + - pydantic=2.0.3 - pyg=2.5.0 - pytorch=2.2.1 - requests=2.31.0 From 953720fc2b3cccc7a8daa0ff48054e15e350ae06 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 13:48:40 -0500 Subject: [PATCH 17/39] mend --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 1cffd22d..30f22eb6 100644 --- a/environment.yml +++ b/environment.yml @@ -12,7 +12,6 @@ dependencies: - python=3.9.13 - pip=23.3.1 - pandas=2.1.4 - - pydantic=2.0.3 - pyg=2.5.0 - pytorch=2.2.1 - requests=2.31.0 @@ -24,6 +23,7 @@ dependencies: - pip: - cellxgene-census==1.10.2 - gget==0.28.4 + - pydantic==2.6.3 - rdkit==2023.9.5 - tiledbsoma==1.7.2 - yapf==0.40.2 From cc4f545d5e0c31d0f968654d77f13a025306769a Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 14:11:47 -0500 Subject: [PATCH 18/39] add YAPF to GH Action --- .github/workflows/conda-tests.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/conda-tests.yml b/.github/workflows/conda-tests.yml index dc9a52b8..88d72159 100644 --- a/.github/workflows/conda-tests.yml +++ b/.github/workflows/conda-tests.yml @@ -36,10 +36,11 @@ jobs: auto-update-conda: true auto-activate-base: true - - name: Create and start Conda environment. Run tests. + - name: Create and start Conda environment. Run tests. Run YAPF run: | echo "Creating Conda Environment from environment.yml" conda env create -f environment.yml conda activate tdc-conda-env python run_tests.py + yapf --style=google -r -d tdc conda deactivate From 5b2f69a2fd9d7c904db050ad737a924f840c0e8a Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 14:27:40 -0500 Subject: [PATCH 19/39] remaining lint errors --- tdc/multi_pred/__init__.py | 2 +- tdc/test/dev_tests/chem_utils_test/test_molconverter.py | 3 +-- tdc/test/dev_tests/chem_utils_test/test_molfilter.py | 3 +-- tdc/test/dev_tests/chem_utils_test/test_oracles.py | 2 -- tdc/test/dev_tests/utils_tests/test_misc_utils.py | 7 +------ tdc/test/dev_tests/utils_tests/test_splits.py | 6 ------ tdc/utils/label.py | 6 ++++-- 7 files changed, 8 insertions(+), 21 deletions(-) diff --git a/tdc/multi_pred/__init__.py b/tdc/multi_pred/__init__.py index 6d4527d6..94c58e5d 100644 --- a/tdc/multi_pred/__init__.py +++ b/tdc/multi_pred/__init__.py @@ -10,4 +10,4 @@ from .ppi import PPI from .test_multi_pred import TestMultiPred from .tcr_epi import TCREpitopeBinding -from .trialoutcome import TrialOutcome \ No newline at end of file +from .trialoutcome import TrialOutcome diff --git a/tdc/test/dev_tests/chem_utils_test/test_molconverter.py b/tdc/test/dev_tests/chem_utils_test/test_molconverter.py index c08b847a..bb925dda 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_molconverter.py +++ b/tdc/test/dev_tests/chem_utils_test/test_molconverter.py @@ -21,7 +21,6 @@ def setUp(self): print(os.getcwd()) pass - def test_MolConvert(self): from tdc.chem_utils import MolConvert @@ -35,7 +34,7 @@ def test_MolConvert(self): MolConvert.eligible_format() - # + # def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/chem_utils_test/test_molfilter.py b/tdc/test/dev_tests/chem_utils_test/test_molfilter.py index 95bf402c..c9fbcc1b 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_molfilter.py +++ b/tdc/test/dev_tests/chem_utils_test/test_molfilter.py @@ -21,14 +21,13 @@ def setUp(self): print(os.getcwd()) pass - def test_MolConvert(self): from tdc.chem_utils import MolFilter filters = MolFilter(filters=["PAINS"], HBD=[0, 6]) filters(["CCSc1ccccc1C(=O)Nc1onc2c1CCC2"]) - # + # def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/chem_utils_test/test_oracles.py b/tdc/test/dev_tests/chem_utils_test/test_oracles.py index 884e8a55..48c4ac48 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_oracles.py +++ b/tdc/test/dev_tests/chem_utils_test/test_oracles.py @@ -21,7 +21,6 @@ def setUp(self): print(os.getcwd()) pass - def test_Oracle(self): from tdc import Oracle @@ -35,7 +34,6 @@ def test_Oracle(self): oracle = Oracle(name="Hop") x = oracle(["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C=O"]) - def test_distribution(self): from tdc import Evaluator diff --git a/tdc/test/dev_tests/utils_tests/test_misc_utils.py b/tdc/test/dev_tests/utils_tests/test_misc_utils.py index c28e1bdb..8c4b1f17 100644 --- a/tdc/test/dev_tests/utils_tests/test_misc_utils.py +++ b/tdc/test/dev_tests/utils_tests/test_misc_utils.py @@ -33,7 +33,6 @@ def test_neg_sample(self): # data = ADME(name='Caco2_Wang') # x = data.label_distribution() - def test_get_label_map(self): from tdc.multi_pred import DDI from tdc.utils import get_label_map @@ -42,26 +41,22 @@ def test_get_label_map(self): split = data.get_split() get_label_map(name="DrugBank", task="DDI") - def test_balanced(self): from tdc.single_pred import HTS data = HTS(name="SARSCoV2_3CLPro_Diamond") data.balanced(oversample=True, seed=42) - def test_cid2smiles(self): from tdc.utils import cid2smiles smiles = cid2smiles(2248631) - def test_uniprot2seq(self): from tdc.utils import uniprot2seq seq = uniprot2seq("P49122") - def test_to_graph(self): from tdc.multi_pred import DTI @@ -95,7 +90,7 @@ def test_to_graph(self): ) # output: {'pyg_graph': the PyG graph object, 'index_to_entities': a dict map from ID in the data to node ID in the PyG object, 'split': {'train': df, 'valid': df, 'test': df}} - # + # def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/utils_tests/test_splits.py b/tdc/test/dev_tests/utils_tests/test_splits.py index fdf22ac1..69759416 100644 --- a/tdc/test/dev_tests/utils_tests/test_splits.py +++ b/tdc/test/dev_tests/utils_tests/test_splits.py @@ -21,14 +21,12 @@ def setUp(self): print(os.getcwd()) pass - def test_random_split(self): from tdc.single_pred import ADME data = ADME(name="Caco2_Wang") split = data.get_split(method="random") - def test_scaffold_split(self): ## requires RDKit from tdc.single_pred import ADME @@ -36,7 +34,6 @@ def test_scaffold_split(self): data = ADME(name="Caco2_Wang") split = data.get_split(method="scaffold") - def test_cold_start_split(self): from tdc.multi_pred import DTI @@ -70,21 +67,18 @@ def test_cold_start_split(self): self.assertEqual(0, len(train_entity.intersection(test_entity))) self.assertEqual(0, len(valid_entity.intersection(test_entity))) - def test_combination_split(self): from tdc.multi_pred import DrugSyn data = DrugSyn(name="DrugComb") split = data.get_split(method="combination") - def test_time_split(self): from tdc.multi_pred import DTI data = DTI(name="BindingDB_Patent") split = data.get_split(method="time", time_column="Year") - def test_tearDown(self): print(os.getcwd()) diff --git a/tdc/utils/label.py b/tdc/utils/label.py index b875d882..6814633d 100644 --- a/tdc/utils/label.py +++ b/tdc/utils/label.py @@ -222,7 +222,8 @@ def NegSample(df, column_names, frac, two_types): 4: "Y" }) ], - ignore_index=True, sort=False) + ignore_index=True, + sort=False) return df else: df_unique_id1 = np.unique(df[id1].values.reshape(-1)) @@ -267,5 +268,6 @@ def NegSample(df, column_names, frac, two_types): 4: "Y" }) ], - ignore_index=True, sort=False) + ignore_index=True, + sort=False) return df From ef3ecae9aa78857b70121025293c2a1fe3bbfd16 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 27 Feb 2024 18:54:51 -0500 Subject: [PATCH 20/39] simple testing on cellxgene api --- tdc/cellxgene-census-loaders/__init__.py | 0 .../cellxgene-census.py | 78 +++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 tdc/cellxgene-census-loaders/__init__.py create mode 100644 tdc/cellxgene-census-loaders/cellxgene-census.py diff --git a/tdc/cellxgene-census-loaders/__init__.py b/tdc/cellxgene-census-loaders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tdc/cellxgene-census-loaders/cellxgene-census.py b/tdc/cellxgene-census-loaders/cellxgene-census.py new file mode 100644 index 00000000..7be20608 --- /dev/null +++ b/tdc/cellxgene-census-loaders/cellxgene-census.py @@ -0,0 +1,78 @@ +import cellxgene_census +from pandas import concat +import tiledbsoma + +from tdc import base_dataset +""" + +Are we only supporting memory-efficient queries? +https://chanzuckerberg.github.io/cellxgene-census/cellxgene_census_docsite_quick_start.html#memory-efficient-queries + + +""" + + +class CXGDataLoader(base_dataset.DataLoader): + + def __init__(self, + num_slices=None, + census_version="2023-12-15", + dataset="census_data", + organism="homo_sapiens", + measurement_name="RNA", + value_filter="", + column_names=None): + if column_names is None: + raise ValueError("column_names is required for this loader") + self.column_names = column_names + num_slices = num_slices if num_slices is not None else 1 + self.num_slices = num_slices + self.df = None + self.fetch_data(census_version, dataset, organism, measurement_name, + value_filter) + + def fetch_data(self, census_version, dataset, organism, measurement_name, + value_filter): + """TODO: docs + outputs a dataframe with specified query params on census data SOMA collection object + """ + if self.column_names is None: + raise ValueError( + "Column names must be provided to CXGDataLoader class") + + with cellxgene_census.open_soma( + census_version=census_version) as census: + # Reads SOMADataFrame as a slice + cell_metadata = census[dataset][organism].obs.read( + value_filter=value_filter, column_names=self.column_names) + self.df = cell_metadata.concat().to_pandas() + # TODO: not latency on memory-efficient queries is poor... + # organismCollection = census[dataset][organism] + # query = organismCollection.axis_query( + # measurement_name = measurement_name, + # obs_query = tiledbsoma.AxisQuery( + # value_filter = value_filter + # ) + # ) + # it = query.X("raw").tables() + # dfs =[] + # for _ in range(self.num_slices): + # slice = next (it) + # df_slice = slice.to_pandas() + # dfs.append(df_slice) + # self.df = concat(dfs) + + def get_dataframe(self): + if self.df is None: + raise Exception( + "Haven't instantiated a DataFrame yet. You can call self.fetch_data first." + ) + return self.df + + +if __name__ == "__main__": + # TODO: tmp, run testing suite when this file is called as main + loader = CXGDataLoader(value_filter="tissue == 'brain' and sex == 'male'", + column_names=["assay", "cell_type", "tissue"]) + df = loader.get_dataframe() + print(df.head()) From 86f567095cc22b59e5a5d630b9c10fc9f250003b Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Mon, 4 Mar 2024 20:14:43 -0500 Subject: [PATCH 21/39] implement memory-efficient retrieval of count matrix and update environment. Also support AnnData retrieval --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index 30f22eb6..e6464a30 100644 --- a/environment.yml +++ b/environment.yml @@ -24,6 +24,7 @@ dependencies: - cellxgene-census==1.10.2 - gget==0.28.4 - pydantic==2.6.3 + - gget==0.28.4 - rdkit==2023.9.5 - tiledbsoma==1.7.2 - yapf==0.40.2 From 5ce09960fd3c81b0b70ca8872a1d3d4ccdd57ec8 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 14:43:56 -0500 Subject: [PATCH 22/39] mend --- .github/workflows/conda-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/conda-tests.yml b/.github/workflows/conda-tests.yml index 88d72159..0916075b 100644 --- a/.github/workflows/conda-tests.yml +++ b/.github/workflows/conda-tests.yml @@ -8,6 +8,7 @@ on: push: branches: - main + - avelez-cellxgene-dev - avelez-dev workflow_dispatch: From 23a0ae30fda99cda72eb93f54e47587c347dbfa9 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 14:45:12 -0500 Subject: [PATCH 23/39] run conda-tests on all branches --- .github/workflows/conda-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/conda-tests.yml b/.github/workflows/conda-tests.yml index 0916075b..86c934b6 100644 --- a/.github/workflows/conda-tests.yml +++ b/.github/workflows/conda-tests.yml @@ -10,6 +10,7 @@ on: - main - avelez-cellxgene-dev - avelez-dev + - '*' workflow_dispatch: jobs: From 4c95fdef8cc438508a46cba11c036f59e566986e Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 13:31:00 -0500 Subject: [PATCH 24/39] completed google lint. added torch dependencies to conda. tests pass. unskipped all but one. --- tdc/test/dev_tests/chem_utils_test/test_molconverter.py | 3 ++- tdc/test/dev_tests/chem_utils_test/test_molfilter.py | 3 ++- tdc/test/dev_tests/chem_utils_test/test_oracles.py | 2 ++ tdc/test/dev_tests/utils_tests/test_misc_utils.py | 7 ++++++- tdc/test/dev_tests/utils_tests/test_splits.py | 6 ++++++ tdc/utils/label.py | 8 ++++---- 6 files changed, 22 insertions(+), 7 deletions(-) diff --git a/tdc/test/dev_tests/chem_utils_test/test_molconverter.py b/tdc/test/dev_tests/chem_utils_test/test_molconverter.py index bb925dda..c08b847a 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_molconverter.py +++ b/tdc/test/dev_tests/chem_utils_test/test_molconverter.py @@ -21,6 +21,7 @@ def setUp(self): print(os.getcwd()) pass + def test_MolConvert(self): from tdc.chem_utils import MolConvert @@ -34,7 +35,7 @@ def test_MolConvert(self): MolConvert.eligible_format() - # + # def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/chem_utils_test/test_molfilter.py b/tdc/test/dev_tests/chem_utils_test/test_molfilter.py index c9fbcc1b..95bf402c 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_molfilter.py +++ b/tdc/test/dev_tests/chem_utils_test/test_molfilter.py @@ -21,13 +21,14 @@ def setUp(self): print(os.getcwd()) pass + def test_MolConvert(self): from tdc.chem_utils import MolFilter filters = MolFilter(filters=["PAINS"], HBD=[0, 6]) filters(["CCSc1ccccc1C(=O)Nc1onc2c1CCC2"]) - # + # def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/chem_utils_test/test_oracles.py b/tdc/test/dev_tests/chem_utils_test/test_oracles.py index 48c4ac48..884e8a55 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_oracles.py +++ b/tdc/test/dev_tests/chem_utils_test/test_oracles.py @@ -21,6 +21,7 @@ def setUp(self): print(os.getcwd()) pass + def test_Oracle(self): from tdc import Oracle @@ -34,6 +35,7 @@ def test_Oracle(self): oracle = Oracle(name="Hop") x = oracle(["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C=O"]) + def test_distribution(self): from tdc import Evaluator diff --git a/tdc/test/dev_tests/utils_tests/test_misc_utils.py b/tdc/test/dev_tests/utils_tests/test_misc_utils.py index 8c4b1f17..c28e1bdb 100644 --- a/tdc/test/dev_tests/utils_tests/test_misc_utils.py +++ b/tdc/test/dev_tests/utils_tests/test_misc_utils.py @@ -33,6 +33,7 @@ def test_neg_sample(self): # data = ADME(name='Caco2_Wang') # x = data.label_distribution() + def test_get_label_map(self): from tdc.multi_pred import DDI from tdc.utils import get_label_map @@ -41,22 +42,26 @@ def test_get_label_map(self): split = data.get_split() get_label_map(name="DrugBank", task="DDI") + def test_balanced(self): from tdc.single_pred import HTS data = HTS(name="SARSCoV2_3CLPro_Diamond") data.balanced(oversample=True, seed=42) + def test_cid2smiles(self): from tdc.utils import cid2smiles smiles = cid2smiles(2248631) + def test_uniprot2seq(self): from tdc.utils import uniprot2seq seq = uniprot2seq("P49122") + def test_to_graph(self): from tdc.multi_pred import DTI @@ -90,7 +95,7 @@ def test_to_graph(self): ) # output: {'pyg_graph': the PyG graph object, 'index_to_entities': a dict map from ID in the data to node ID in the PyG object, 'split': {'train': df, 'valid': df, 'test': df}} - # + # def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/utils_tests/test_splits.py b/tdc/test/dev_tests/utils_tests/test_splits.py index 69759416..fdf22ac1 100644 --- a/tdc/test/dev_tests/utils_tests/test_splits.py +++ b/tdc/test/dev_tests/utils_tests/test_splits.py @@ -21,12 +21,14 @@ def setUp(self): print(os.getcwd()) pass + def test_random_split(self): from tdc.single_pred import ADME data = ADME(name="Caco2_Wang") split = data.get_split(method="random") + def test_scaffold_split(self): ## requires RDKit from tdc.single_pred import ADME @@ -34,6 +36,7 @@ def test_scaffold_split(self): data = ADME(name="Caco2_Wang") split = data.get_split(method="scaffold") + def test_cold_start_split(self): from tdc.multi_pred import DTI @@ -67,18 +70,21 @@ def test_cold_start_split(self): self.assertEqual(0, len(train_entity.intersection(test_entity))) self.assertEqual(0, len(valid_entity.intersection(test_entity))) + def test_combination_split(self): from tdc.multi_pred import DrugSyn data = DrugSyn(name="DrugComb") split = data.get_split(method="combination") + def test_time_split(self): from tdc.multi_pred import DTI data = DTI(name="BindingDB_Patent") split = data.get_split(method="time", time_column="Year") + def test_tearDown(self): print(os.getcwd()) diff --git a/tdc/utils/label.py b/tdc/utils/label.py index 6814633d..0c098694 100644 --- a/tdc/utils/label.py +++ b/tdc/utils/label.py @@ -222,8 +222,7 @@ def NegSample(df, column_names, frac, two_types): 4: "Y" }) ], - ignore_index=True, - sort=False) + ignore_index=True, sort=False) return df else: df_unique_id1 = np.unique(df[id1].values.reshape(-1)) @@ -258,6 +257,8 @@ def NegSample(df, column_names, frac, two_types): for i in neg_list: neg_list_val.append([i[0], id2seq1[i[0]], i[1], id2seq2[i[1]], 0]) + df = pd.concat([ + df, df = pd.concat([ df, pd.DataFrame(neg_list_val).rename(columns={ @@ -268,6 +269,5 @@ def NegSample(df, column_names, frac, two_types): 4: "Y" }) ], - ignore_index=True, - sort=False) + ignore_index=True, sort=False) return df From 0eed2380f1f097c9df7fdd8c31775cd47361cd92 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 13:48:40 -0500 Subject: [PATCH 25/39] mend --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index e6464a30..f8acc6f5 100644 --- a/environment.yml +++ b/environment.yml @@ -25,6 +25,7 @@ dependencies: - gget==0.28.4 - pydantic==2.6.3 - gget==0.28.4 + - pydantic==2.6.3 - rdkit==2023.9.5 - tiledbsoma==1.7.2 - yapf==0.40.2 From d04ad8aa2d441fbfb1d277c463899826c416d1b1 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Mon, 4 Mar 2024 20:14:43 -0500 Subject: [PATCH 26/39] implement memory-efficient retrieval of count matrix and update environment. Also support AnnData retrieval --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index f8acc6f5..c17a2f67 100644 --- a/environment.yml +++ b/environment.yml @@ -26,6 +26,7 @@ dependencies: - pydantic==2.6.3 - gget==0.28.4 - pydantic==2.6.3 + - gget==0.28.4 - rdkit==2023.9.5 - tiledbsoma==1.7.2 - yapf==0.40.2 From 7e0fc32eb28b49c24546138158fa04a5750b92b3 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 13:48:40 -0500 Subject: [PATCH 27/39] mend --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index c17a2f67..f4b93b24 100644 --- a/environment.yml +++ b/environment.yml @@ -27,6 +27,7 @@ dependencies: - gget==0.28.4 - pydantic==2.6.3 - gget==0.28.4 + - pydantic==2.6.3 - rdkit==2023.9.5 - tiledbsoma==1.7.2 - yapf==0.40.2 From 1ddc8289b7f9d5b0d1f432b7e5ce12f8a92e8923 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 15:22:48 -0500 Subject: [PATCH 28/39] merge and lint --- tdc/resource/cellxgene-census.py | 180 +++++++++++------- .../chem_utils_test/test_molconverter.py | 3 +- .../chem_utils_test/test_molfilter.py | 3 +- .../dev_tests/chem_utils_test/test_oracles.py | 2 - .../dev_tests/utils_tests/test_misc_utils.py | 7 +- tdc/test/dev_tests/utils_tests/test_splits.py | 6 - tdc/utils/label.py | 16 +- 7 files changed, 119 insertions(+), 98 deletions(-) diff --git a/tdc/resource/cellxgene-census.py b/tdc/resource/cellxgene-census.py index d400e608..d09eee9d 100644 --- a/tdc/resource/cellxgene-census.py +++ b/tdc/resource/cellxgene-census.py @@ -1,5 +1,6 @@ # TODO: tmp fix import os + os.environ['KMP_DUPLICATE_LIB_OK'] = "TRUE" # TODO: find better fix or encode in environment / docker ^^^ import cellxgene_census @@ -8,7 +9,7 @@ class CensusResource: - + _CENSUS_DATA = "census_data" _CENSUS_META = "census_info" _FEATURE_PRESENCE = "feature_dataset_presence_matrix" @@ -16,27 +17,28 @@ class CensusResource: _HUMAN = "homo_sapiens" class decorators: + @classmethod - def check_dataset_is_census_data(cls,func): + def check_dataset_is_census_data(cls, func): # @wraps(func) def check(*args, **kwargs): self = args[0] self.dataset = self._CENSUS_DATA return func(*args, **kwargs) + return check @classmethod - def check_dataset_is_census_info(cls,func): + def check_dataset_is_census_info(cls, func): + def check(*args, **kwargs): self = args[0] self.dataset = self._CENSUS_META return func(*args, **kwargs) + return check - def __init__(self, - census_version=None, - organism=None - ): + def __init__(self, census_version=None, organism=None): """Initialize the Census Resource. Args: @@ -45,11 +47,13 @@ def __init__(self, """ self.census_version = census_version if census_version is not None else self._LATEST_CENSUS self.organism = organism if organism is not None else self._HUMAN - self.dataset = None # variable to set target census collection to either info or data + self.dataset = None # variable to set target census collection to either info or data def fmt_cellxgene_data(self, tiledb_ptr, fmt=None): if fmt is None: - raise Exception("format not provided to fmt_cellxgene_data(), please provide fmt variable") + raise Exception( + "format not provided to fmt_cellxgene_data(), please provide fmt variable" + ) elif fmt == "pandas": return tiledb_ptr.concat().to_pandas() elif fmt == "pyarrow": @@ -57,13 +61,15 @@ def fmt_cellxgene_data(self, tiledb_ptr, fmt=None): elif fmt == "scipy": return tiledb_ptr.concat().to_scipy() else: - raise Exception("fmt not in [pandas, pyarrow, scipy] for fmt_cellxgene_data()") - + raise Exception( + "fmt not in [pandas, pyarrow, scipy] for fmt_cellxgene_data()") + @decorators.check_dataset_is_census_data def get_cell_metadata(self, value_filter=None, column_names=None, fmt=None): """Get the cell metadata (obs) data from the Census API""" if value_filter is None: - raise Exception("No value filter was provided, dataset is too large to fit in memory. \ + raise Exception( + "No value filter was provided, dataset is too large to fit in memory. \ Memory-Efficient queries are not supported yet.") fmt = fmt if fmt is not None else "pandas" with cellxgene_census.open_soma( @@ -71,33 +77,45 @@ def get_cell_metadata(self, value_filter=None, column_names=None, fmt=None): obs = census[self.dataset][self.organism].obs obsread = None if column_names: - obsread = obs.read(value_filter=value_filter, column_names=column_names) + obsread = obs.read(value_filter=value_filter, + column_names=column_names) else: obsread = obs.read(value_filter=value_filter) return self.fmt_cellxgene_data(obsread, fmt) - + @decorators.check_dataset_is_census_data - def get_gene_metadata(self, value_filter=None, column_names=None, measurement_name=None, fmt=None): + def get_gene_metadata(self, + value_filter=None, + column_names=None, + measurement_name=None, + fmt=None): """Get the gene metadata (var) data from the Census API""" if value_filter is None: - raise Exception("No value filter was provided, dataset is too large to fit in memory. \ + raise Exception( + "No value filter was provided, dataset is too large to fit in memory. \ Memory-Efficient queries are not supported yet.") elif measurement_name is None: raise ValueError("measurment_name must be provided , i.e. 'RNA'") fmt = fmt if fmt is not None else "pandas" with cellxgene_census.open_soma( - census_version=self.census_version - ) as census: - var = census[self.dataset][self.organism].ms[measurement_name].var + census_version=self.census_version) as census: + var = census[self.dataset][self.organism].ms[measurement_name].var varread = None if column_names: - varread = var.read(value_filter=value_filter, column_names=column_names) + varread = var.read(value_filter=value_filter, + column_names=column_names) else: varread = var.read(value_filter=value_filter) return self.fmt_cellxgene_data(varread, fmt) - + @decorators.check_dataset_is_census_data - def get_measurement_matrix(self, upper=None, lower=None, value_adjustment=None, measurement_name=None, fmt=None, todense=None): + def get_measurement_matrix(self, + upper=None, + lower=None, + value_adjustment=None, + measurement_name=None, + fmt=None, + todense=None): """Count matrix for an input measurement by slice Args: @@ -111,47 +129,61 @@ def get_measurement_matrix(self, upper=None, lower=None, value_adjustment=None, Exception: _description_ """ if upper is None or lower is None: - raise Exception("No upper and/or lower bound for slicing was provided. Dataset is too large to fit in memory. \ + raise Exception( + "No upper and/or lower bound for slicing was provided. Dataset is too large to fit in memory. \ Memory-Efficient queries are not supported yet.") elif measurement_name is None: raise Exception("measurement_name was not provided.") elif fmt is not None and fmt not in ["scipy", "pyarrow"]: - raise ValueError("measurement_matrix only supports 'scipy' or 'pyarrow' format") + raise ValueError( + "measurement_matrix only supports 'scipy' or 'pyarrow' format") value_adjustment = value_adjustment if value_adjustment is not None else "raw" todense = todense if todense is not None else False fmt = fmt if fmt is not None else "scipy" if todense and fmt != "scipy": - raise ValueError("dense representation only available in scipy format") + raise ValueError( + "dense representation only available in scipy format") with cellxgene_census.open_soma( - census_version=self.census_version - ) as census: + census_version=self.census_version) as census: n_obs = len(census[self.dataset][self.organism].obs) - n_var = len(census[self.dataset][self.organism].ms[measurement_name].var) - X = census[self.dataset][self.organism].ms[measurement_name].X[value_adjustment] + n_var = len( + census[self.dataset][self.organism].ms[measurement_name].var) + X = census[self.dataset][ + self.organism].ms[measurement_name].X[value_adjustment] slc = X.read((slice(lower, upper),)).coos((n_obs, n_var)) out = self.fmt_cellxgene_data(slc, fmt) return out if not todense else out.todense() - - @decorators.check_dataset_is_census_data - def get_feature_dataset_presence_matrix(self, upper=None, lower=None, measurement_name=None, fmt=None, todense=None): + + @decorators.check_dataset_is_census_data + def get_feature_dataset_presence_matrix(self, + upper=None, + lower=None, + measurement_name=None, + fmt=None, + todense=None): if upper is None or lower is None: - raise ValueError("No upper and/or lower bound for slicing was provided. Dataset is too large to fit in memory. \ + raise ValueError( + "No upper and/or lower bound for slicing was provided. Dataset is too large to fit in memory. \ Memory-Efficient queries are not supported yet.") elif measurement_name is None: raise ValueError("measurement_name was not provided") elif fmt is not None and fmt not in ["scipy", "pyarrow"]: - raise ValueError("feature dataset presence matrix only supports 'scipy' or 'pyarrow' formats") + raise ValueError( + "feature dataset presence matrix only supports 'scipy' or 'pyarrow' formats" + ) todense = todense if todense is not None else False fmt = fmt if fmt is not None else "scipy" - if todense and fmt!="scipy": - raise ValueError("dense representation only available in scipy format") + if todense and fmt != "scipy": + raise ValueError( + "dense representation only available in scipy format") with cellxgene_census.open_soma( - census_version=self.census_version - ) as census: + census_version=self.census_version) as census: n_obs = len(census[self.dataset][self.organism].obs) - n_var = len(census[self.dataset][self.organism].ms[measurement_name].var) - fMatrix = census[self.dataset][self.organism].ms[measurement_name]["feature_dataset_presence_matrix"] - slc = fMatrix.read((slice(0, 5),)).coos((n_obs,n_var)) + n_var = len( + census[self.dataset][self.organism].ms[measurement_name].var) + fMatrix = census[self.dataset][self.organism].ms[measurement_name][ + "feature_dataset_presence_matrix"] + slc = fMatrix.read((slice(0, 5),)).coos((n_obs, n_var)) out = self.fmt_cellxgene_data(slc, fmt) return out if not todense else out.todense() @@ -159,35 +191,30 @@ def get_feature_dataset_presence_matrix(self, upper=None, lower=None, measuremen def get_metadata(self): """Get the metadata for the Cell Census.""" with cellxgene_census.open_soma( - census_version=self.census_version - ) as census: + census_version=self.census_version) as census: return census[self.dataset]["summary"] @decorators.check_dataset_is_census_info def get_dataset_metadata(self): """Get the metadata for the Cell Census's datasets.""" with cellxgene_census.open_soma( - census_version=self.census_version - ) as census: + census_version=self.census_version) as census: return census[self.dataset]["datasets"] - + @decorators.check_dataset_is_census_info def get_cell_count_metadata(self): """Get the cell counts across cell metadata for the Cell Census.""" with cellxgene_census.open_soma( - census_version=self.census_version - ) as census: + census_version=self.census_version) as census: return census[self.dataset]["summary_cell_counts"] @decorators.check_dataset_is_census_data - def query_measurement_matrix( - self, - value_filter=None, - value_adjustment=None, - measurement_name=None, - fmt=None, - todense=None - ): + def query_measurement_matrix(self, + value_filter=None, + value_adjustment=None, + measurement_name=None, + fmt=None, + todense=None): """Query the Census Measurement Matrix. Function returns a Python generator. Args: @@ -207,32 +234,31 @@ def query_measurement_matrix( a slice of the output query in the specified format """ if value_filter is None: - raise ValueError("query_measurement_matrix expects a value_filter. if you don't plan to apply a filter, use get_measurement_matrix()") + raise ValueError( + "query_measurement_matrix expects a value_filter. if you don't plan to apply a filter, use get_measurement_matrix()" + ) elif measurement_name is None: raise Exception("measurement_name was not provided.") elif fmt is not None and fmt not in ["scipy", "pyarrow"]: - raise ValueError("measurement_matrix only supports 'scipy' or 'pyarrow' format") + raise ValueError( + "measurement_matrix only supports 'scipy' or 'pyarrow' format") value_adjustment = value_adjustment if value_adjustment is not None else "raw" todense = todense if todense is not None else False fmt = fmt if fmt is not None else "scipy" if todense and fmt != "scipy": - raise ValueError("dense representation only available in scipy format") + raise ValueError( + "dense representation only available in scipy format") with cellxgene_census.open_soma( - census_version=self.census_version - ) as census: + census_version=self.census_version) as census: organism = census[self.dataset][self.organism] query = organism.axis_query( - measurement_name = measurement_name, - obs_query = tiledbsoma.AxisQuery( - value_filter = value_filter - ) - ) + measurement_name=measurement_name, + obs_query=tiledbsoma.AxisQuery(value_filter=value_filter)) it = query.X(value_adjustment).tables() for slc in it: out = self.fmt_cellxgene_data(slc, fmt) out = out if not todense else out.todense() yield out - @classmethod def gget_czi_cellxgene(cls, **kwargs): @@ -298,19 +324,29 @@ def gget_czi_cellxgene(cls, **kwargs): gene_value_filter = "feature_id in ['ENSG00000161798', 'ENSG00000188229']" gene_column_names = ["feature_name", "feature_length"] print("getting cell metadata as pandas dataframe") - obsdf = resource.get_cell_metadata(value_filter=cell_value_filter, column_names=cell_column_names, fmt="pandas") + obsdf = resource.get_cell_metadata(value_filter=cell_value_filter, + column_names=cell_column_names, + fmt="pandas") print("success!") print(obsdf.head()) print("geting gene metadata as pyarrow") - varpyarrow = resource.get_gene_metadata(value_filter=gene_value_filter, column_names=gene_column_names, fmt="pyarrow", measurement_name="RNA") + varpyarrow = resource.get_gene_metadata(value_filter=gene_value_filter, + column_names=gene_column_names, + fmt="pyarrow", + measurement_name="RNA") print("success!") print(varpyarrow) print("getting sample count matrix, checking todense() and scipy") - Xslice = resource.get_measurement_matrix(upper=5, lower=0, measurement_name="RNA", fmt="scipy", todense=True) + Xslice = resource.get_measurement_matrix(upper=5, + lower=0, + measurement_name="RNA", + fmt="scipy", + todense=True) print("success") print(Xslice) print("getting feature presence matrix, checking pyarrow") - FMslice = resource.get_feature_dataset_presence_matrix(upper=5, lower=0, measurement_name="RNA", fmt="pyarrow", todense=False) + FMslice = resource.get_feature_dataset_presence_matrix( + upper=5, lower=0, measurement_name="RNA", fmt="pyarrow", todense=False) print("success") print(FMslice) - print("all tests passed") \ No newline at end of file + print("all tests passed") diff --git a/tdc/test/dev_tests/chem_utils_test/test_molconverter.py b/tdc/test/dev_tests/chem_utils_test/test_molconverter.py index c08b847a..bb925dda 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_molconverter.py +++ b/tdc/test/dev_tests/chem_utils_test/test_molconverter.py @@ -21,7 +21,6 @@ def setUp(self): print(os.getcwd()) pass - def test_MolConvert(self): from tdc.chem_utils import MolConvert @@ -35,7 +34,7 @@ def test_MolConvert(self): MolConvert.eligible_format() - # + # def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/chem_utils_test/test_molfilter.py b/tdc/test/dev_tests/chem_utils_test/test_molfilter.py index 95bf402c..c9fbcc1b 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_molfilter.py +++ b/tdc/test/dev_tests/chem_utils_test/test_molfilter.py @@ -21,14 +21,13 @@ def setUp(self): print(os.getcwd()) pass - def test_MolConvert(self): from tdc.chem_utils import MolFilter filters = MolFilter(filters=["PAINS"], HBD=[0, 6]) filters(["CCSc1ccccc1C(=O)Nc1onc2c1CCC2"]) - # + # def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/chem_utils_test/test_oracles.py b/tdc/test/dev_tests/chem_utils_test/test_oracles.py index 884e8a55..48c4ac48 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_oracles.py +++ b/tdc/test/dev_tests/chem_utils_test/test_oracles.py @@ -21,7 +21,6 @@ def setUp(self): print(os.getcwd()) pass - def test_Oracle(self): from tdc import Oracle @@ -35,7 +34,6 @@ def test_Oracle(self): oracle = Oracle(name="Hop") x = oracle(["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C=O"]) - def test_distribution(self): from tdc import Evaluator diff --git a/tdc/test/dev_tests/utils_tests/test_misc_utils.py b/tdc/test/dev_tests/utils_tests/test_misc_utils.py index c28e1bdb..8c4b1f17 100644 --- a/tdc/test/dev_tests/utils_tests/test_misc_utils.py +++ b/tdc/test/dev_tests/utils_tests/test_misc_utils.py @@ -33,7 +33,6 @@ def test_neg_sample(self): # data = ADME(name='Caco2_Wang') # x = data.label_distribution() - def test_get_label_map(self): from tdc.multi_pred import DDI from tdc.utils import get_label_map @@ -42,26 +41,22 @@ def test_get_label_map(self): split = data.get_split() get_label_map(name="DrugBank", task="DDI") - def test_balanced(self): from tdc.single_pred import HTS data = HTS(name="SARSCoV2_3CLPro_Diamond") data.balanced(oversample=True, seed=42) - def test_cid2smiles(self): from tdc.utils import cid2smiles smiles = cid2smiles(2248631) - def test_uniprot2seq(self): from tdc.utils import uniprot2seq seq = uniprot2seq("P49122") - def test_to_graph(self): from tdc.multi_pred import DTI @@ -95,7 +90,7 @@ def test_to_graph(self): ) # output: {'pyg_graph': the PyG graph object, 'index_to_entities': a dict map from ID in the data to node ID in the PyG object, 'split': {'train': df, 'valid': df, 'test': df}} - # + # def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/utils_tests/test_splits.py b/tdc/test/dev_tests/utils_tests/test_splits.py index fdf22ac1..69759416 100644 --- a/tdc/test/dev_tests/utils_tests/test_splits.py +++ b/tdc/test/dev_tests/utils_tests/test_splits.py @@ -21,14 +21,12 @@ def setUp(self): print(os.getcwd()) pass - def test_random_split(self): from tdc.single_pred import ADME data = ADME(name="Caco2_Wang") split = data.get_split(method="random") - def test_scaffold_split(self): ## requires RDKit from tdc.single_pred import ADME @@ -36,7 +34,6 @@ def test_scaffold_split(self): data = ADME(name="Caco2_Wang") split = data.get_split(method="scaffold") - def test_cold_start_split(self): from tdc.multi_pred import DTI @@ -70,21 +67,18 @@ def test_cold_start_split(self): self.assertEqual(0, len(train_entity.intersection(test_entity))) self.assertEqual(0, len(valid_entity.intersection(test_entity))) - def test_combination_split(self): from tdc.multi_pred import DrugSyn data = DrugSyn(name="DrugComb") split = data.get_split(method="combination") - def test_time_split(self): from tdc.multi_pred import DTI data = DTI(name="BindingDB_Patent") split = data.get_split(method="time", time_column="Year") - def test_tearDown(self): print(os.getcwd()) diff --git a/tdc/utils/label.py b/tdc/utils/label.py index 0c098694..d5de77f2 100644 --- a/tdc/utils/label.py +++ b/tdc/utils/label.py @@ -212,7 +212,7 @@ def NegSample(df, column_names, frac, two_types): for i in neg_list: neg_list_val.append([i[0], id2seq[i[0]], i[1], id2seq[i[1]], 0]) - df = pd.concat([ + df2 = pd.concat([ df, pd.DataFrame(neg_list_val).rename(columns={ 0: id1, @@ -222,8 +222,9 @@ def NegSample(df, column_names, frac, two_types): 4: "Y" }) ], - ignore_index=True, sort=False) - return df + ignore_index=True, + sort=False) + return df2 else: df_unique_id1 = np.unique(df[id1].values.reshape(-1)) df_unique_id2 = np.unique(df[id2].values.reshape(-1)) @@ -257,9 +258,7 @@ def NegSample(df, column_names, frac, two_types): for i in neg_list: neg_list_val.append([i[0], id2seq1[i[0]], i[1], id2seq2[i[1]], 0]) - df = pd.concat([ - df, - df = pd.concat([ + df2 = pd.concat([ df, pd.DataFrame(neg_list_val).rename(columns={ 0: id1, @@ -269,5 +268,6 @@ def NegSample(df, column_names, frac, two_types): 4: "Y" }) ], - ignore_index=True, sort=False) - return df + ignore_index=True, + sort=False) + return df2 From 2d4e4f4f2b7e6016b13b1cdea9b8862c2f523ea6 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 17:40:40 -0500 Subject: [PATCH 29/39] decorator for X and feature presence checks --- tdc/resource/cellxgene-census.py | 73 +++++++++++++------------------- 1 file changed, 29 insertions(+), 44 deletions(-) diff --git a/tdc/resource/cellxgene-census.py b/tdc/resource/cellxgene-census.py index d09eee9d..9b51bac0 100644 --- a/tdc/resource/cellxgene-census.py +++ b/tdc/resource/cellxgene-census.py @@ -13,7 +13,7 @@ class CensusResource: _CENSUS_DATA = "census_data" _CENSUS_META = "census_info" _FEATURE_PRESENCE = "feature_dataset_presence_matrix" - _LATEST_CENSUS = "2023-12-15" + _LATEST_CENSUS = "2023-12-15" # TODO: maybe change to 'latest' _HUMAN = "homo_sapiens" class decorators: @@ -38,6 +38,31 @@ def check(*args, **kwargs): return check + @classmethod + def slice_checks_X_and_FM(cls, func): + """Decorator for functions that need X and feature presence matrix apply slicing if not filtering""" + def check(*args, **kwargs): + upper, lower = kwargs.get('upper', None), kwargs.get("lower", None) + measurement_name = kwargs.get("measurement_name") + fmt = kwargs.get("fmt") + todense = kwargs.get("todense") + if upper is None or lower is None: + raise Exception( + "No upper and/or lower bound for slicing was provided. Dataset is too large to fit in memory. \ + Memory-Efficient queries are not supported yet.") + elif measurement_name is None: + raise Exception("measurement_name was not provided.") + elif fmt is not None and fmt not in ["scipy", "pyarrow"]: + raise ValueError( + "measurement_matrix only supports 'scipy' or 'pyarrow' format") + kwargs["todense"] = todense if todense is not None else False + if todense and fmt != "scipy": + raise ValueError( + "dense representation only available in scipy format") + kwargs["fmt"] = fmt if fmt is not None else "scipy" + return func(*args, **kwargs) + return check + def __init__(self, census_version=None, organism=None): """Initialize the Census Resource. @@ -108,6 +133,7 @@ def get_gene_metadata(self, varread = var.read(value_filter=value_filter) return self.fmt_cellxgene_data(varread, fmt) + @decorators.slice_checks_X_and_FM @decorators.check_dataset_is_census_data def get_measurement_matrix(self, upper=None, @@ -128,21 +154,7 @@ def get_measurement_matrix(self, Exception: _description_ Exception: _description_ """ - if upper is None or lower is None: - raise Exception( - "No upper and/or lower bound for slicing was provided. Dataset is too large to fit in memory. \ - Memory-Efficient queries are not supported yet.") - elif measurement_name is None: - raise Exception("measurement_name was not provided.") - elif fmt is not None and fmt not in ["scipy", "pyarrow"]: - raise ValueError( - "measurement_matrix only supports 'scipy' or 'pyarrow' format") value_adjustment = value_adjustment if value_adjustment is not None else "raw" - todense = todense if todense is not None else False - fmt = fmt if fmt is not None else "scipy" - if todense and fmt != "scipy": - raise ValueError( - "dense representation only available in scipy format") with cellxgene_census.open_soma( census_version=self.census_version) as census: n_obs = len(census[self.dataset][self.organism].obs) @@ -154,6 +166,7 @@ def get_measurement_matrix(self, out = self.fmt_cellxgene_data(slc, fmt) return out if not todense else out.todense() + @decorators.slice_checks_X_and_FM @decorators.check_dataset_is_census_data def get_feature_dataset_presence_matrix(self, upper=None, @@ -161,21 +174,6 @@ def get_feature_dataset_presence_matrix(self, measurement_name=None, fmt=None, todense=None): - if upper is None or lower is None: - raise ValueError( - "No upper and/or lower bound for slicing was provided. Dataset is too large to fit in memory. \ - Memory-Efficient queries are not supported yet.") - elif measurement_name is None: - raise ValueError("measurement_name was not provided") - elif fmt is not None and fmt not in ["scipy", "pyarrow"]: - raise ValueError( - "feature dataset presence matrix only supports 'scipy' or 'pyarrow' formats" - ) - todense = todense if todense is not None else False - fmt = fmt if fmt is not None else "scipy" - if todense and fmt != "scipy": - raise ValueError( - "dense representation only available in scipy format") with cellxgene_census.open_soma( census_version=self.census_version) as census: n_obs = len(census[self.dataset][self.organism].obs) @@ -208,6 +206,7 @@ def get_cell_count_metadata(self): census_version=self.census_version) as census: return census[self.dataset]["summary_cell_counts"] + @decorators.slice_checks_X_and_FM @decorators.check_dataset_is_census_data def query_measurement_matrix(self, value_filter=None, @@ -233,21 +232,7 @@ def query_measurement_matrix(self, Yields: a slice of the output query in the specified format """ - if value_filter is None: - raise ValueError( - "query_measurement_matrix expects a value_filter. if you don't plan to apply a filter, use get_measurement_matrix()" - ) - elif measurement_name is None: - raise Exception("measurement_name was not provided.") - elif fmt is not None and fmt not in ["scipy", "pyarrow"]: - raise ValueError( - "measurement_matrix only supports 'scipy' or 'pyarrow' format") value_adjustment = value_adjustment if value_adjustment is not None else "raw" - todense = todense if todense is not None else False - fmt = fmt if fmt is not None else "scipy" - if todense and fmt != "scipy": - raise ValueError( - "dense representation only available in scipy format") with cellxgene_census.open_soma( census_version=self.census_version) as census: organism = census[self.dataset][self.organism] From 87776c94d7589dc930c13c8c0aadf39a5e70052d Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Tue, 5 Mar 2024 18:49:11 -0500 Subject: [PATCH 30/39] documentation --- tdc/resource/cellxgene-census.py | 100 +++++++++++++++++++------------ 1 file changed, 63 insertions(+), 37 deletions(-) diff --git a/tdc/resource/cellxgene-census.py b/tdc/resource/cellxgene-census.py index 9b51bac0..e754be00 100644 --- a/tdc/resource/cellxgene-census.py +++ b/tdc/resource/cellxgene-census.py @@ -20,7 +20,7 @@ class decorators: @classmethod def check_dataset_is_census_data(cls, func): - # @wraps(func) + """Sets self.dataset to census_data""" def check(*args, **kwargs): self = args[0] self.dataset = self._CENSUS_DATA @@ -30,7 +30,7 @@ def check(*args, **kwargs): @classmethod def check_dataset_is_census_info(cls, func): - + """Sets self.dataset to census_data""" def check(*args, **kwargs): self = args[0] self.dataset = self._CENSUS_META @@ -40,26 +40,34 @@ def check(*args, **kwargs): @classmethod def slice_checks_X_and_FM(cls, func): - """Decorator for functions that need X and feature presence matrix apply slicing if not filtering""" + """Decorator for: + 1. functions that need X and feature presence matrix apply slicing if not filtering + 2. functions with a todense() option abide by required formatting + 3. functions requiring a measurement name provide a measurement name + 4. fmt is a valid format + asserts these requirements hold in input arguments.""" def check(*args, **kwargs): - upper, lower = kwargs.get('upper', None), kwargs.get("lower", None) - measurement_name = kwargs.get("measurement_name") + if "upper" in kwargs: + upper, lower = kwargs.get('upper', None), kwargs.get("lower", None) + if upper is None or lower is None: + raise Exception( + "No upper and/or lower bound for slicing was provided. Dataset is too large to fit in memory. \ + Memory-Efficient queries are not supported yet.") fmt = kwargs.get("fmt") - todense = kwargs.get("todense") - if upper is None or lower is None: - raise Exception( - "No upper and/or lower bound for slicing was provided. Dataset is too large to fit in memory. \ - Memory-Efficient queries are not supported yet.") - elif measurement_name is None: + fmt = fmt if fmt is not None else "pandas" + if "todense" in kwargs: + todense = kwargs.get("todense") + kwargs["todense"] = todense if todense is not None else False + if todense and fmt != "scipy": + raise ValueError( + "dense representation only available in scipy format") + measurement_name = kwargs.get("measurement_name") + if measurement_name is None: raise Exception("measurement_name was not provided.") elif fmt is not None and fmt not in ["scipy", "pyarrow"]: raise ValueError( "measurement_matrix only supports 'scipy' or 'pyarrow' format") - kwargs["todense"] = todense if todense is not None else False - if todense and fmt != "scipy": - raise ValueError( - "dense representation only available in scipy format") - kwargs["fmt"] = fmt if fmt is not None else "scipy" + kwargs["fmt"] = fmt if fmt is not None else "pandas" return func(*args, **kwargs) return check @@ -67,14 +75,27 @@ def __init__(self, census_version=None, organism=None): """Initialize the Census Resource. Args: - census_version (str): The date of the census data release in YYYY- - TODO: complete + census_version (str): The date of the census data release to use + organism (str): string for census data organism to query data for. defaults to human. """ self.census_version = census_version if census_version is not None else self._LATEST_CENSUS self.organism = organism if organism is not None else self._HUMAN self.dataset = None # variable to set target census collection to either info or data def fmt_cellxgene_data(self, tiledb_ptr, fmt=None): + """Transform TileDB DataFrame or SparseNDArray to one of the supported API formats. + + Args: + tiledb_ptr (TileDB DataFrame or SparseNDArray): pointer to the TileDB DataFrame + fmt (str, optional): deisgnates a format to transfowm TileDB data to. Defaults to None. + + Raises: + Exception: if no format is provided + Exception: if format is not a valid option + + Returns: + The dataset in selected format if it's a valid format + """ if fmt is None: raise Exception( "format not provided to fmt_cellxgene_data(), please provide fmt variable" @@ -145,14 +166,13 @@ def get_measurement_matrix(self, """Count matrix for an input measurement by slice Args: - upper (_type_, optional): _description_. Defaults to None. - lower (_type_, optional): _description_. Defaults to None. - value_adjustment (_type_, optional): _description_. Defaults to None. - measurement_name (_type_, optional): _description_. Defaults to None. + upper (int, optional): upper bound on the slice to obtain. Defaults to None. + lower (int, optional): lower bound on the slice to obtain. Defaults to None. + value_adjustment (str, optional): designates the type of count desired for this measurement. Defaults to None. + measurement_name (str, optional): name of measurement, i.e. 'raw'. Defaults to None. - Raises: - Exception: _description_ - Exception: _description_ + Returns: + A slice from the count matrix in the specified format. If `todense` is True, then a dense scipy array will be returned. """ value_adjustment = value_adjustment if value_adjustment is not None else "raw" with cellxgene_census.open_soma( @@ -174,6 +194,18 @@ def get_feature_dataset_presence_matrix(self, measurement_name=None, fmt=None, todense=None): + """Gets a slice from the feature_dataset_presence_matrix for a given measurement_name + + Args: + upper (int, optional): upper bound on the slice. Defaults to None. + lower (int, optional): lower bound on the slice. Defaults to None. + measurement_name (str, optional): measurment_name for the query i.e. 'rna'. Defaults to None. + fmt (str, optional): output format desired for the output dataset. Defaults to None. + todense (bool, optional): if True, returns scipy dense representation. Defaults to None. + + Returns: + dataset in desired format + """ with cellxgene_census.open_soma( census_version=self.census_version) as census: n_obs = len(census[self.dataset][self.organism].obs) @@ -181,7 +213,7 @@ def get_feature_dataset_presence_matrix(self, census[self.dataset][self.organism].ms[measurement_name].var) fMatrix = census[self.dataset][self.organism].ms[measurement_name][ "feature_dataset_presence_matrix"] - slc = fMatrix.read((slice(0, 5),)).coos((n_obs, n_var)) + slc = fMatrix.read((slice(lower, upper),)).coos((n_obs, n_var)) out = self.fmt_cellxgene_data(slc, fmt) return out if not todense else out.todense() @@ -217,18 +249,12 @@ def query_measurement_matrix(self, """Query the Census Measurement Matrix. Function returns a Python generator. Args: - value_filter (_type_, optional): _description_. Defaults to None. - value_adjustment (_type_, optional): _description_. Defaults to None. - measurement_name (_type_, optional): _description_. Defaults to None. - fmt (_type_, optional): _description_. Defaults to None. - todense (_type_, optional): _description_. Defaults to None. + value_filter (str, optional): a valuer filter (obs) to apply to the query. Defaults to None. + value_adjustment (str, optional): the type of count to obtain from count matricx for this measurement. Defaults to None. + measurement_name (str, optional): measurement name to query, i.e. "RNA". Defaults to None. + fmt (str, optional): output format for the output dataset. Defaults to None. + todense (bool, optional): if True, will output a dense scipy array as the representation. Defaults to None. - Raises: - ValueError: _description_ - Exception: _description_ - ValueError: _description_ - ValueError: _description_ - Yields: a slice of the output query in the specified format """ From f7bf2b17683d256acc11fd338af87c6d73273698 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Wed, 6 Mar 2024 16:53:03 -0500 Subject: [PATCH 31/39] ship cellxgene loader --- tdc/resource/__init__.py | 1 + ...ellxgene-census.py => cellxgene_census.py} | 38 +------------- .../dev_tests/utils_tests/test_misc_utils.py | 1 + tdc/test/test_resources.py | 51 +++++++++++++++++++ 4 files changed, 55 insertions(+), 36 deletions(-) rename tdc/resource/{cellxgene-census.py => cellxgene_census.py} (90%) create mode 100644 tdc/test/test_resources.py diff --git a/tdc/resource/__init__.py b/tdc/resource/__init__.py index 873d428f..a040b03c 100644 --- a/tdc/resource/__init__.py +++ b/tdc/resource/__init__.py @@ -1 +1,2 @@ from .primekg import PrimeKG +from .cellxgene_census import CensusResource diff --git a/tdc/resource/cellxgene-census.py b/tdc/resource/cellxgene_census.py similarity index 90% rename from tdc/resource/cellxgene-census.py rename to tdc/resource/cellxgene_census.py index e754be00..4b14acef 100644 --- a/tdc/resource/cellxgene-census.py +++ b/tdc/resource/cellxgene_census.py @@ -2,7 +2,7 @@ import os os.environ['KMP_DUPLICATE_LIB_OK'] = "TRUE" -# TODO: find better fix or encode in environment / docker ^^^ +# TODO: remove import cellxgene_census import gget import tiledbsoma @@ -326,38 +326,4 @@ def gget_czi_cellxgene(cls, **kwargs): if __name__ == "__main__": - # TODO: tmp, run testing suite when this file is called as main - print("running tests for census resource") - print("instantiating resource") - resource = CensusResource() - cell_value_filter = "tissue == 'brain' and sex == 'male'" - cell_column_names = ["assay", "cell_type", "tissue"] - gene_value_filter = "feature_id in ['ENSG00000161798', 'ENSG00000188229']" - gene_column_names = ["feature_name", "feature_length"] - print("getting cell metadata as pandas dataframe") - obsdf = resource.get_cell_metadata(value_filter=cell_value_filter, - column_names=cell_column_names, - fmt="pandas") - print("success!") - print(obsdf.head()) - print("geting gene metadata as pyarrow") - varpyarrow = resource.get_gene_metadata(value_filter=gene_value_filter, - column_names=gene_column_names, - fmt="pyarrow", - measurement_name="RNA") - print("success!") - print(varpyarrow) - print("getting sample count matrix, checking todense() and scipy") - Xslice = resource.get_measurement_matrix(upper=5, - lower=0, - measurement_name="RNA", - fmt="scipy", - todense=True) - print("success") - print(Xslice) - print("getting feature presence matrix, checking pyarrow") - FMslice = resource.get_feature_dataset_presence_matrix( - upper=5, lower=0, measurement_name="RNA", fmt="pyarrow", todense=False) - print("success") - print(FMslice) - print("all tests passed") + pass diff --git a/tdc/test/dev_tests/utils_tests/test_misc_utils.py b/tdc/test/dev_tests/utils_tests/test_misc_utils.py index 8c4b1f17..7b6aae87 100644 --- a/tdc/test/dev_tests/utils_tests/test_misc_utils.py +++ b/tdc/test/dev_tests/utils_tests/test_misc_utils.py @@ -57,6 +57,7 @@ def test_uniprot2seq(self): seq = uniprot2seq("P49122") + @unittest.skip("long running test") # TODO: debug def test_to_graph(self): from tdc.multi_pred import DTI diff --git a/tdc/test/test_resources.py b/tdc/test/test_resources.py new file mode 100644 index 00000000..4ae2851a --- /dev/null +++ b/tdc/test/test_resources.py @@ -0,0 +1,51 @@ +import os +import sys +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +import unittest + +from pandas import DataFrame +from pyarrow import SparseCOOTensor +from tdc.resource import cellxgene_census + +class TestResources(unittest.TestCase): + pass + +class TestCellXGene(unittest.TestCase): + + def setUp(self): + self.resource = cellxgene_census.CensusResource() + self.gene_value_filter = "feature_id in ['ENSG00000161798', 'ENSG00000188229']" + self.gene_column_names = ["feature_name", "feature_length"] + self.cell_value_filter = "tissue == 'brain' and sex == 'male'" + self.cell_column_names = ["assay", "cell_type", "tissue"] + + def test_get_cell_metadata(self): + obsdf = self.resource.get_cell_metadata(value_filter=self.cell_value_filter, + column_names=self.cell_column_names, + fmt="pandas") + assert isinstance(obsdf, DataFrame) + + def test_get_gene_metadata(self): + varpyarrow = self.resource.get_gene_metadata(value_filter=self.gene_value_filter, + column_names=self.gene_column_names, + fmt="pyarrow", + measurement_name="RNA") + print(varpyarrow) + # assert isinstance(varpyarrow, SparseCOOTensor) + + def test_get_measurement_matrix(self): + Xslice = self.resource.get_measurement_matrix(upper=5, + lower=0, + measurement_name="RNA", + fmt="scipy", + todense=True) + print("x", Xslice) + + def test_get_feature_dataset_presence_matrix(self): + FMslice = self.resource.get_feature_dataset_presence_matrix( + upper=5, lower=0, measurement_name="RNA", fmt="pyarrow", todense=False) + print("f", FMslice) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From cd65f41924e736d7edea724c27687597efa34717 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Wed, 6 Mar 2024 16:56:56 -0500 Subject: [PATCH 32/39] mend --- tdc/test/test_resources.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tdc/test/test_resources.py b/tdc/test/test_resources.py index 4ae2851a..3a5c96ee 100644 --- a/tdc/test/test_resources.py +++ b/tdc/test/test_resources.py @@ -5,7 +5,7 @@ import unittest from pandas import DataFrame -from pyarrow import SparseCOOTensor +# from pyarrow import SparseCOOTensor from tdc.resource import cellxgene_census class TestResources(unittest.TestCase): From 88d7e5996cc8d95264a178ed245282e9b5ff319b Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Wed, 6 Mar 2024 16:59:00 -0500 Subject: [PATCH 33/39] mend --- tdc/resource/cellxgene_census.py | 22 +++++++++----- tdc/test/test_resources.py | 50 +++++++++++++++++++------------- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/tdc/resource/cellxgene_census.py b/tdc/resource/cellxgene_census.py index 4b14acef..e8369788 100644 --- a/tdc/resource/cellxgene_census.py +++ b/tdc/resource/cellxgene_census.py @@ -2,7 +2,7 @@ import os os.environ['KMP_DUPLICATE_LIB_OK'] = "TRUE" -# TODO: remove +# TODO: remove import cellxgene_census import gget import tiledbsoma @@ -13,7 +13,7 @@ class CensusResource: _CENSUS_DATA = "census_data" _CENSUS_META = "census_info" _FEATURE_PRESENCE = "feature_dataset_presence_matrix" - _LATEST_CENSUS = "2023-12-15" # TODO: maybe change to 'latest' + _LATEST_CENSUS = "2023-12-15" # TODO: maybe change to 'latest' _HUMAN = "homo_sapiens" class decorators: @@ -21,6 +21,7 @@ class decorators: @classmethod def check_dataset_is_census_data(cls, func): """Sets self.dataset to census_data""" + def check(*args, **kwargs): self = args[0] self.dataset = self._CENSUS_DATA @@ -31,6 +32,7 @@ def check(*args, **kwargs): @classmethod def check_dataset_is_census_info(cls, func): """Sets self.dataset to census_data""" + def check(*args, **kwargs): self = args[0] self.dataset = self._CENSUS_META @@ -46,9 +48,11 @@ def slice_checks_X_and_FM(cls, func): 3. functions requiring a measurement name provide a measurement name 4. fmt is a valid format asserts these requirements hold in input arguments.""" + def check(*args, **kwargs): if "upper" in kwargs: - upper, lower = kwargs.get('upper', None), kwargs.get("lower", None) + upper, lower = kwargs.get('upper', + None), kwargs.get("lower", None) if upper is None or lower is None: raise Exception( "No upper and/or lower bound for slicing was provided. Dataset is too large to fit in memory. \ @@ -57,18 +61,22 @@ def check(*args, **kwargs): fmt = fmt if fmt is not None else "pandas" if "todense" in kwargs: todense = kwargs.get("todense") - kwargs["todense"] = todense if todense is not None else False + kwargs[ + "todense"] = todense if todense is not None else False if todense and fmt != "scipy": raise ValueError( - "dense representation only available in scipy format") + "dense representation only available in scipy format" + ) measurement_name = kwargs.get("measurement_name") if measurement_name is None: raise Exception("measurement_name was not provided.") elif fmt is not None and fmt not in ["scipy", "pyarrow"]: raise ValueError( - "measurement_matrix only supports 'scipy' or 'pyarrow' format") + "measurement_matrix only supports 'scipy' or 'pyarrow' format" + ) kwargs["fmt"] = fmt if fmt is not None else "pandas" return func(*args, **kwargs) + return check def __init__(self, census_version=None, organism=None): @@ -238,7 +246,7 @@ def get_cell_count_metadata(self): census_version=self.census_version) as census: return census[self.dataset]["summary_cell_counts"] - @decorators.slice_checks_X_and_FM + @decorators.slice_checks_X_and_FM @decorators.check_dataset_is_census_data def query_measurement_matrix(self, value_filter=None, diff --git a/tdc/test/test_resources.py b/tdc/test/test_resources.py index 3a5c96ee..2f82c974 100644 --- a/tdc/test/test_resources.py +++ b/tdc/test/test_resources.py @@ -1,16 +1,19 @@ import os import sys + sys.path.append( os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) import unittest from pandas import DataFrame -# from pyarrow import SparseCOOTensor +# from pyarrow import SparseCOOTensor from tdc.resource import cellxgene_census + class TestResources(unittest.TestCase): pass + class TestCellXGene(unittest.TestCase): def setUp(self): @@ -19,33 +22,40 @@ def setUp(self): self.gene_column_names = ["feature_name", "feature_length"] self.cell_value_filter = "tissue == 'brain' and sex == 'male'" self.cell_column_names = ["assay", "cell_type", "tissue"] - + def test_get_cell_metadata(self): - obsdf = self.resource.get_cell_metadata(value_filter=self.cell_value_filter, - column_names=self.cell_column_names, - fmt="pandas") + obsdf = self.resource.get_cell_metadata( + value_filter=self.cell_value_filter, + column_names=self.cell_column_names, + fmt="pandas") assert isinstance(obsdf, DataFrame) - + def test_get_gene_metadata(self): - varpyarrow = self.resource.get_gene_metadata(value_filter=self.gene_value_filter, - column_names=self.gene_column_names, - fmt="pyarrow", - measurement_name="RNA") + varpyarrow = self.resource.get_gene_metadata( + value_filter=self.gene_value_filter, + column_names=self.gene_column_names, + fmt="pyarrow", + measurement_name="RNA") print(varpyarrow) - # assert isinstance(varpyarrow, SparseCOOTensor) - + # assert isinstance(varpyarrow, SparseCOOTensor) + def test_get_measurement_matrix(self): Xslice = self.resource.get_measurement_matrix(upper=5, - lower=0, - measurement_name="RNA", - fmt="scipy", - todense=True) + lower=0, + measurement_name="RNA", + fmt="scipy", + todense=True) print("x", Xslice) - + def test_get_feature_dataset_presence_matrix(self): FMslice = self.resource.get_feature_dataset_presence_matrix( - upper=5, lower=0, measurement_name="RNA", fmt="pyarrow", todense=False) + upper=5, + lower=0, + measurement_name="RNA", + fmt="pyarrow", + todense=False) print("f", FMslice) - + + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 38c754b7cee34cdb594dc18e3ead5b8ebde18d64 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Wed, 6 Mar 2024 17:06:06 -0500 Subject: [PATCH 34/39] mend --- tdc/test/test_resources.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tdc/test/test_resources.py b/tdc/test/test_resources.py index 2f82c974..092e06d7 100644 --- a/tdc/test/test_resources.py +++ b/tdc/test/test_resources.py @@ -39,6 +39,7 @@ def test_get_gene_metadata(self): print(varpyarrow) # assert isinstance(varpyarrow, SparseCOOTensor) + @unittest.skip("this test takes up too much mem for GH worker.. skip for now") def test_get_measurement_matrix(self): Xslice = self.resource.get_measurement_matrix(upper=5, lower=0, From 0410503b5514a1a7bb06bc8b536247db0c65615d Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Wed, 6 Mar 2024 17:11:07 -0500 Subject: [PATCH 35/39] mend --- tdc/test/test_resources.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tdc/test/test_resources.py b/tdc/test/test_resources.py index 092e06d7..580985a6 100644 --- a/tdc/test/test_resources.py +++ b/tdc/test/test_resources.py @@ -39,7 +39,8 @@ def test_get_gene_metadata(self): print(varpyarrow) # assert isinstance(varpyarrow, SparseCOOTensor) - @unittest.skip("this test takes up too much mem for GH worker.. skip for now") + @unittest.skip( + "this test takes up too much mem for GH worker.. skip for now") def test_get_measurement_matrix(self): Xslice = self.resource.get_measurement_matrix(upper=5, lower=0, From 5afd1b1451ade6bb361b5dc5d147ba205e31f909 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Wed, 6 Mar 2024 17:19:09 -0500 Subject: [PATCH 36/39] circle-ci problems.. should migrate --- requirements_ci.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/requirements_ci.txt b/requirements_ci.txt index efc9feeb..f723b4f4 100644 --- a/requirements_ci.txt +++ b/requirements_ci.txt @@ -9,3 +9,6 @@ torch tqdm huggingface_hub dataclasses +cellxgene-census +gget +tiledbsoma From 95c2d1931de7a1329ec056f25afb3c8aba2b2209 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Wed, 6 Mar 2024 17:25:48 -0500 Subject: [PATCH 37/39] mend --- .circleci/config.yml | 2 +- requirements_ci.txt | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 99b1742e..eb50d951 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -50,7 +50,7 @@ jobs: no_output_timeout: 30m command: | . venv/bin/activate - pytest --ignore=tdc/test/dev_tests/ + pytest --ignore=tdc/test/dev_tests/ --ignore=tdc/test/test_resources.py - store_artifacts: path: test-reports diff --git a/requirements_ci.txt b/requirements_ci.txt index f723b4f4..da1504a8 100644 --- a/requirements_ci.txt +++ b/requirements_ci.txt @@ -8,7 +8,4 @@ scikit-learn torch tqdm huggingface_hub -dataclasses -cellxgene-census -gget -tiledbsoma +dataclasses \ No newline at end of file From 64abd2acc63b92ab8886b175dbbee95233b4938d Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Wed, 6 Mar 2024 17:37:28 -0500 Subject: [PATCH 38/39] mend --- tdc/test/dev_tests/utils_tests/test_misc_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tdc/test/dev_tests/utils_tests/test_misc_utils.py b/tdc/test/dev_tests/utils_tests/test_misc_utils.py index c0d7f11d..9fbd383a 100644 --- a/tdc/test/dev_tests/utils_tests/test_misc_utils.py +++ b/tdc/test/dev_tests/utils_tests/test_misc_utils.py @@ -19,7 +19,6 @@ class TestFunctions(unittest.TestCase): - def setUp(self): print(os.getcwd()) pass From ad00907894ae5b001d14449890f5f740c3deecd4 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Wed, 6 Mar 2024 17:40:02 -0500 Subject: [PATCH 39/39] mend --- tdc/utils/label.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tdc/utils/label.py b/tdc/utils/label.py index dce4d799..be2800bd 100644 --- a/tdc/utils/label.py +++ b/tdc/utils/label.py @@ -32,12 +32,6 @@ def convert_y_unit(y, from_, to_): return y -def label_transform(y, - binary, - threshold, - convert_to_log, - verbose=True, - order="descending"): def label_transform(y, binary, threshold,