
Commit 860d840

cleaned up and added commenting
ntalluri committed Jul 19, 2024
1 parent 2c6d1a5 commit 860d840
Showing 4 changed files with 46 additions and 27 deletions.
46 changes: 24 additions & 22 deletions Snakefile
@@ -28,19 +28,15 @@ hac_params = _config.config.hac_params
FRAMEWORK = _config.config.container_framework
print(f"Running {FRAMEWORK} containers")

# Return the dataset dictionary from the config file given the label
def get_dataset(_datasets, label):
# Return the dataset or goldstandard dictionary from the config file given the label
def get_dict(_datasets, label):
print(_datasets, label)
return _datasets[label]

algorithms = list(algorithm_params)
algorithms_with_params = [f'{algorithm}-params-{params_hash}' for algorithm, param_combos in algorithm_params.items() for params_hash in param_combos.keys()]
dataset_labels = list(_config.config.datasets.keys())
gold_standard_labels = list(_config.config.gold_standard.keys())

dataset_gs_pairs_tuples = [(gs_values['label'], dataset) for gs_values in _config.config.gold_standard.values() for dataset in gs_values['datasets']]
# am I able to send tuples around?
dataset_gs_pairs_formatted = [f"{dataset}-{gs_values['label']}" for gs_values in _config.config.gold_standard.values() for dataset in gs_values['datasets']]
# prefomatting makes it easier to send around but requires more functions to use
dataset_goldstandard_pairs = [f"{dataset}-{gs_values['label']}" for gs_values in _config.config.gold_standard.values() for dataset in gs_values['datasets']]

# Get algorithms that are running multiple parameter combinations
def algo_has_mult_param_combos(algo):
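
For intuition, here is a minimal sketch (with hypothetical label values, not read from this repository's config) of how the dataset_goldstandard_pairs strings defined above are composed:

# Hypothetical gold_standard dictionary; the real one is parsed from config/config.yaml.
gold_standard = {
    "gs0": {"label": "gs0", "datasets": ["data0"]},
    "gs1": {"label": "gs1", "datasets": ["data0", "data1"]},
}
pairs = [f"{dataset}-{gs_values['label']}"
         for gs_values in gold_standard.values()
         for dataset in gs_values["datasets"]]
print(pairs)  # ['data0-gs0', 'data0-gs1', 'data1-gs1']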
@@ -63,7 +59,7 @@ def write_parameter_log(algorithm, param_label, logfile):

# Log the dataset contents specified in the config file in a yaml file
def write_dataset_log(dataset, logfile):
dataset_contents = get_dataset(_config.config.datasets,dataset)
dataset_contents = get_dict(_config.config.datasets,dataset)

# safe_dump gives RepresenterError for an OrderedDict
# config file has to convert the dataset from OrderedDict to dict to avoid this
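
The comment above refers to PyYAML behavior: yaml.safe_dump has no representer for collections.OrderedDict, so converting to a plain dict first avoids the RepresenterError. A small generic sketch of the workaround (not the project's exact logging code):

from collections import OrderedDict
import yaml

contents = OrderedDict([("label", "data0"), ("node_files", ["nodes.txt"])])
# yaml.safe_dump(contents) would raise yaml.representer.RepresenterError
with open("data0-log.yaml", "w") as f:
    yaml.safe_dump(dict(contents), f)  # a plain dict dumps cleanly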
@@ -110,7 +106,7 @@ def make_final_input(wildcards):
final_input.extend(expand('{out_dir}{sep}{dataset}-ml{sep}{algorithm}-ensemble-pathway.txt',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm=algorithms_mult_param_combos,algorithm_params=algorithms_with_params))

if _config.config.analysis_include_evalution:
final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-evaluation.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gs_pairs_formatted,algorithm_params=algorithms_with_params))
final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-evaluation.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_goldstandard_pairs,algorithm_params=algorithms_with_params))

if len(final_input) == 0:
# No analysis added yet, so add reconstruction output files if they exist.
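
As a rough plain-Python approximation (not Snakemake's actual implementation), expand substitutes every combination of the supplied values into the pattern, so the evaluation targets above become one file per dataset-gold standard pair:

from itertools import product

pattern = "{out_dir}{sep}{pair}-evaluation.txt"
out_dirs, seps, pairs = ["output"], ["/"], ["data0-gs0", "data0-gs1"]  # hypothetical values
targets = [pattern.format(out_dir=o, sep=s, pair=p)
           for o, s, p in product(out_dirs, seps, pairs)]
print(targets)  # ['output/data0-gs0-evaluation.txt', 'output/data0-gs1-evaluation.txt']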
@@ -160,21 +156,23 @@ rule merge_input:
output: dataset_file = SEP.join([out_dir, '{dataset}-merged.pickle'])
run:
# Pass the dataset to PRRunner where the files will be merged and written to disk (i.e. pickled)
dataset_dict = get_dataset(_config.config.datasets, wildcards.dataset)
dataset_dict = get_dict(_config.config.datasets, wildcards.dataset)
runner.merge_input(dataset_dict, output.dataset_file)

def get_gs_dependencies(wildcards):
# Return all files used in the gold standard
def get_goldstandard_dependencies(wildcards):
gs = _config.config.gold_standard[wildcards.gold_standard]
all_files = gs["node_files"]
all_files = [gs["data_dir"] + SEP + data_file for data_file in all_files]
return all_files

# Merge all node files for a goldstandard into a single node table
rule merge_gs_input:
input: get_gs_dependencies
output: gs_file = SEP.join([out_dir, '{gold_standard}-merged.pickle'])
input: get_goldstandard_dependencies
output: goldstandard_file = SEP.join([out_dir, '{gold_standard}-merged.pickle'])
run:
gs_dict = get_dataset(_config.config.gold_standard, wildcards.gold_standard)
runner.merge_gold_standard_input(gs_dict, output.gs_file)
goldstandard_dict = get_dict(_config.config.gold_standard, wildcards.gold_standard)
runner.merge_gold_standard_input(goldstandard_dict, output.goldstandard_file)

# The checkpoint is like a rule but can be used in dynamic workflows
# The workflow directed acyclic graph is re-evaluated after the checkpoint job runs
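
For readers unfamiliar with Snakemake checkpoints, here is a generic sketch of the pattern the comment above describes (illustrative names, not this workflow's checkpoint): an input function calls checkpoints.<name>.get(...), which defers evaluation until the checkpoint job has finished and the DAG has been recomputed.

# Generic Snakemake checkpoint illustration; rule and file names are hypothetical.
checkpoint split_data:
    output: directory("splits/{dataset}")
    shell: "python split.py {wildcards.dataset} {output}"

def gather_splits(wildcards):
    # Blocks until the checkpoint finishes, then the DAG is re-evaluated.
    ckpt_dir = checkpoints.split_data.get(dataset=wildcards.dataset).output[0]
    parts = glob_wildcards(f"{ckpt_dir}/{{part}}.txt").part
    return expand("splits/{dataset}/{part}.txt", dataset=wildcards.dataset, part=parts)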
@@ -326,6 +324,7 @@ def collect_pathways_per_algo(wildcards):
filtered_algo_params = [algo_param for algo_param in algorithms_with_params if wildcards.algorithm in algo_param]
return expand('{out_dir}{sep}{{dataset}}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=filtered_algo_params)

# Cluster the output pathways per algorithm for each dataset
rule ml_analysis_aggregate_algo:
input:
pathways = collect_pathways_per_algo
@@ -345,23 +344,26 @@ rule ml_analysis_aggregate_algo:
ml.hac_horizontal(summary_df, output.hac_image_horizontal, output.hac_clusters_horizontal, **hac_params)
ml.ensemble_network(summary_df, output.ensemble_network_file)

def get_gs_pickle_file(wildcards):
parts = wildcards.dataset_gs_pairs_formatted.split('-')
# Return the gold standard pickle file for a specific gold standard
def get_goldstandard_pickle_file(wildcards):
parts = wildcards.dataset_goldstandard_pairs.split('-')
gs = parts[1]
return SEP.join([out_dir, f'{gs}-merged.pickle'])

# Returns the dataset corresponding to the gold standard pair
def get_dataset_label(wildcards):
parts = wildcards.dataset_gs_pairs_formatted.split('-')
parts = wildcards.dataset_goldstandard_pairs.split('-')
dataset = parts[0]
return dataset

# Run evaluation code for a specific dataset's pathway outputs against its paired gold standard
rule evaluation:
input:
gs_file = get_gs_pickle_file,
goldstandard_file = get_goldstandard_pickle_file,
pathways = expand('{out_dir}{sep}{dataset_label}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params, dataset_label=get_dataset_label),
output: eval_file = SEP.join([out_dir, "{dataset_gs_pairs_formatted}-evaluation.txt"])
output: eval_file = SEP.join([out_dir, "{dataset_goldstandard_pairs}-evaluation.txt"])
run:
node_table = Evaluation.from_file(input.gs_file).node_table
node_table = Evaluation.from_file(input.goldstandard_file).node_table
Evaluation.precision(input.pathways, node_table, output.eval_file)

# Remove the output directory
1 change: 1 addition & 0 deletions config/config.yaml
@@ -122,6 +122,7 @@ gold_standard:
node_files: ["gs_nodes0.txt"]
# edge_files: [] TODO: later iteration
data_dir: "input"
# Set of datasets to compare with the specific gold standard dataset
datasets: ["data0"]
-
label: gs1
3 changes: 2 additions & 1 deletion spras/config.py
@@ -70,7 +70,7 @@ def __init__(self, raw_config):
self.unpack_singularity = False
# A dictionary to store configured datasets against which SPRAS will be run
self.datasets = None
# a dictionary to store configured gold standard data against ouptut of SPRAS runs
# A dictionary to store configured gold standard data against ouptut of SPRAS runs
self.gold_standard = None
# The hash length SPRAS will use to identify parameter combinations. Default is 7
self.hash_length = DEFAULT_HASH_LENGTH
@@ -146,6 +146,7 @@ def process_config(self, raw_config):
# Convert to dicts to simplify the yaml logging
self.datasets = {dataset["label"]: dict(dataset) for dataset in raw_config["datasets"]}

# TODO: turn into try except
self.gold_standard = {goldstandard["label"]: dict(goldstandard) for goldstandard in raw_config["gold_standard"]}
for key in self.gold_standard:
pattern = r'^\w+$'
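
The loop shown above is truncated by the hunk; below is a self-contained sketch of what a label check against the r'^\w+$' pattern typically looks like (the error handling and message here are assumed, not copied from the file):

import re

# Hypothetical labels; the real ones come from raw_config["gold_standard"].
gold_standard = {"gs0": {}, "bad label!": {}}
pattern = r'^\w+$'
for key in gold_standard:
    if not re.match(pattern, key):
        # The real implementation may word this differently.
        raise ValueError(f"Gold standard label '{key}' contains invalid characters; only letters, numbers, and underscores are allowed.")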
23 changes: 19 additions & 4 deletions spras/evaluation.py
@@ -21,22 +21,30 @@ def __init__(self, gold_standard_dict):

def to_file(self, file_name):
"""
Saves dataset object to pickle file
Saves gold standard object to pickle file
"""
with open(file_name, "wb") as f:
pkl.dump(self, f)

@classmethod
def from_file(cls, file_name):
"""
Loads dataset object from a pickle file.
Usage: dataset = Dataset.from_file(pickle_file)
Loads gold standard object from a pickle file.
Usage: gold_standard = Evaluation.from_file(pickle_file)
"""
with open(file_name, "rb") as f:
return pkl.load(f)

def load_files_from_dict(self, gold_standard_dict):
"""
Loads gold standard files from gold_standard_dict, which is one gold standard dataset
dictionary from the list in the config file with the fields in the config file.
Populates node_table.
node_table is a single column of nodes pandas table.
returns: none
"""
self.label = gold_standard_dict["label"]
self.datasets = gold_standard_dict["datasets"]

@@ -53,7 +61,14 @@ def load_files_from_dict(self, gold_standard_dict):
# TODO: later iteration - chose between node and edge file, or allow both

def precision(file_paths: Iterable[Path], node_table: pd.DataFrame, output_file: str):

"""
Takes in file paths for a specific dataset and an associated gold standard node table.
Calculates precision for each pathway file
Returns output back to output_file
@param file_paths: file paths of pathway reconstruction algorithm outputs
@param node_table: the gold standard nodes
@param output_file: the filename to save the precision of each pathway
"""
y_true = node_table['NODEID'].tolist()
results = []
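
The rest of the function body is truncated by the hunk; as a set-based sketch of the metric it documents (precision = correctly predicted nodes / all predicted nodes), not the function's actual code:

# Illustrative values only.
gold_nodes = {"A", "B", "C"}        # node IDs from the gold standard node_table
predicted_nodes = {"A", "B", "D"}   # node IDs appearing in one reconstructed pathway
precision = len(predicted_nodes & gold_nodes) / len(predicted_nodes)
print(f"precision = {precision:.2f}")  # 0.67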

