diff --git a/tdc/metadata.py b/tdc/metadata.py index abac5c01..24019db2 100644 --- a/tdc/metadata.py +++ b/tdc/metadata.py @@ -927,6 +927,16 @@ def get_task2category(): "pinnacle_protein_embed": "pth", "pinnacle_labels_dict": "txt", "panpep": "tab", + "pinnacle_output1": "zip", + "pinnacle_output2": "zip", + "pinnacle_output3": "zip", + "pinnacle_output4": "zip", + "pinnacle_output5": "zip", + "pinnacle_output6": "zip", + "pinnacle_output7": "zip", + "pinnacle_output8": "zip", + "pinnacle_output9": "zip", + "pinnacle_output10": "zip", } name2id = { @@ -1104,6 +1114,16 @@ def get_task2category(): "pinnacle_protein_embed": 10407128, "pinnacle_labels_dict": 10409635, "panpep": 10428565, + "pinnacle_output1": 10431072, + "pinnacle_output2": 10431073, + "pinnacle_output3": 10431078, + "pinnacle_output4": 10431080, + "pinnacle_output5": 10431077, + "pinnacle_output6": 10431076, + "pinnacle_output7": 10431079, + "pinnacle_output8": 10431074, + "pinnacle_output9": 10431075, + "pinnacle_output10": 10431081, } oracle2type = { diff --git a/tdc/resource/pinnacle.py b/tdc/resource/pinnacle.py index fefcab94..f5034d05 100644 --- a/tdc/resource/pinnacle.py +++ b/tdc/resource/pinnacle.py @@ -1,5 +1,5 @@ from ..utils import general_load -from ..utils.load import download_wrapper, load_json_from_txt_file +from ..utils.load import download_wrapper, load_json_from_txt_file, zip_data_download_wrapper import pandas as pd import os @@ -9,15 +9,6 @@ class PINNACLE: """ PINNACLE is a class for loading and manipulating the PINNACLE networks and embeddings. 
def get_exp_data(self, seed=1, split="train"):
    """Load one split of the PINNACLE output for a given seed as a DataFrame.

    The archive ``pinnacle_output{seed}`` is fetched with
    ``zip_data_download_wrapper`` into a private temporary directory. Every
    extracted CSV whose file name contains ``_{split}_`` is read and tagged
    with two extra columns derived from the file name:

    * ``cell_type_label`` -- the text between the last ``_`` and the first
      following ``.`` of the file name.
    * ``disease`` -- ``"IBD"`` when the file name contains ``"3767"``,
      otherwise ``"RA"``.

    Unlike the previous implementation, nothing in ``./data`` is deleted:
    all intermediate files live in a temporary directory that is removed
    when the method returns (or raises).

    Args:
        seed: Suffix of the archive to load; archives 1 through 10 exist.
        split: One of ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        A single ``pandas.DataFrame`` with the rows of all matching CSV
        files concatenated (index reset), in sorted file-name order.

    Raises:
        ValueError: If ``split`` or ``seed`` is not recognised, or if no
            extracted CSV matches ``split``.
        Exception: If the archive yields no CSV files at all.
    """
    # Local import keeps the block self-contained; tempfile is stdlib.
    import tempfile

    # Validate everything before touching the network or the file system.
    if split not in ("train", "val", "test"):
        raise ValueError("{} not a valid split".format(split))
    dataset_names = ["pinnacle_output{}".format(x) for x in range(1, 11)]
    filename = "pinnacle_output{}".format(seed)
    if filename not in dataset_names:
        raise ValueError(
            "{} not a valid seed; expected an integer from 1 to 10".format(
                seed))

    split_tag = "_{}_".format(split)
    frames = []
    # NOTE(review): assumes zip_data_download_wrapper extracts the CSV files
    # directly into the directory it is given, as the original code assumed
    # for "./data" -- confirm against tdc.utils.load.
    with tempfile.TemporaryDirectory() as workdir:
        print("downloading pinancle zip data...")
        zip_data_download_wrapper(filename, workdir, dataset_names)
        print("success!")
        # os.listdir order is arbitrary; sort so row order is reproducible.
        csv_files = sorted(
            f for f in os.listdir(workdir) if f.endswith(".csv"))
        if not csv_files:
            raise Exception("no csv")
        print("iterating over csv files...")
        for file in csv_files:
            print("got file {}".format(file))
            if split_tag not in file:
                continue
            print("reading into pandas...")
            df = pd.read_csv(os.path.join(workdir, file))
            df["cell_type_label"] = file.split("_")[-1].split(".")[0]
            df["disease"] = "IBD" if "3767" in file else "RA"
            frames.append(df)

    if not frames:
        # pd.concat([]) would raise an opaque "No objects to concatenate".
        raise ValueError(
            "no csv files found for split {} in {}".format(split, filename))
    return pd.concat(frames, axis=0, ignore_index=True)