
Commit

updated code to work optionally, updated code based on PR comments
ntalluri committed Jul 22, 2024
1 parent f635e99 commit 0a1e305
Showing 6 changed files with 46 additions and 35 deletions.
39 changes: 19 additions & 20 deletions Snakefile
@@ -28,15 +28,14 @@ hac_params = _config.config.hac_params
FRAMEWORK = _config.config.container_framework
print(f"Running {FRAMEWORK} containers")

# Return the dataset or goldstandard dictionary from the config file given the label
def get_dict(_datasets, label):
print(_datasets, label)
# Return the dataset or gold_standard dictionary from the config file given the label
def get_dataset(_datasets, label):
return _datasets[label]

algorithms = list(algorithm_params)
algorithms_with_params = [f'{algorithm}-params-{params_hash}' for algorithm, param_combos in algorithm_params.items() for params_hash in param_combos.keys()]
dataset_labels = list(_config.config.datasets.keys())
dataset_goldstandard_pairs = [f"{dataset}-{gs_values['label']}" for gs_values in _config.config.gold_standard.values() for dataset in gs_values['datasets']]
dataset_gold_standard_pairs = [f"{dataset}-{gs_values['label']}" for gs_values in _config.config.gold_standards.values() for dataset in gs_values['datasets']]

# Get algorithms that are running multiple parameter combinations
def algo_has_mult_param_combos(algo):
@@ -59,7 +58,7 @@ def write_parameter_log(algorithm, param_label, logfile):

# Log the dataset contents specified in the config file in a yaml file
def write_dataset_log(dataset, logfile):
dataset_contents = get_dict(_config.config.datasets,dataset)
dataset_contents = get_dataset(_config.config.datasets,dataset)

# safe_dump gives RepresenterError for an OrderedDict
# config file has to convert the dataset from OrderedDict to dict to avoid this
@@ -106,7 +105,7 @@ def make_final_input(wildcards):
final_input.extend(expand('{out_dir}{sep}{dataset}-ml{sep}{algorithm}-ensemble-pathway.txt',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm=algorithms_mult_param_combos,algorithm_params=algorithms_with_params))

if _config.config.analysis_include_evalution:
final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-evaluation.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_goldstandard_pairs,algorithm_params=algorithms_with_params))
final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-evaluation.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs,algorithm_params=algorithms_with_params))

if len(final_input) == 0:
# No analysis added yet, so add reconstruction output files if they exist.
@@ -156,23 +155,23 @@ rule merge_input:
output: dataset_file = SEP.join([out_dir, '{dataset}-merged.pickle'])
run:
# Pass the dataset to PRRunner where the files will be merged and written to disk (i.e. pickled)
dataset_dict = get_dict(_config.config.datasets, wildcards.dataset)
dataset_dict = get_dataset(_config.config.datasets, wildcards.dataset)
runner.merge_input(dataset_dict, output.dataset_file)

# Return all files used in the gold standard
def get_goldstandard_dependencies(wildcards):
gs = _config.config.gold_standard[wildcards.gold_standard]
def get_gold_standard_dependencies(wildcards):
gs = _config.config.gold_standards[wildcards.gold_standard]
all_files = gs["node_files"]
all_files = [gs["data_dir"] + SEP + data_file for data_file in all_files]
return all_files

# Merge all node files for a goldstandard into a single node table
# Merge all node files for a gold_standard into a single node table
rule merge_gs_input:
input: get_goldstandard_dependencies
output: goldstandard_file = SEP.join([out_dir, '{gold_standard}-merged.pickle'])
input: get_gold_standard_dependencies
output: gold_standard_file = SEP.join([out_dir, '{gold_standard}-merged.pickle'])
run:
goldstandard_dict = get_dict(_config.config.gold_standard, wildcards.gold_standard)
runner.merge_gold_standard_input(goldstandard_dict, output.goldstandard_file)
gold_standard_dict = get_dataset(_config.config.gold_standards, wildcards.gold_standard)
Evaluation.merge_gold_standard_input(gold_standard_dict, output.gold_standard_file)

# The checkpoint is like a rule but can be used in dynamic workflows
# The workflow directed acyclic graph is re-evaluated after the checkpoint job runs
@@ -345,25 +344,25 @@ rule ml_analysis_aggregate_algo:
ml.ensemble_network(summary_df, output.ensemble_network_file)

# Return the gold standard pickle file for a specific gold standard
def get_goldstandard_pickle_file(wildcards):
parts = wildcards.dataset_goldstandard_pairs.split('-')
def get_gold_standard_pickle_file(wildcards):
parts = wildcards.dataset_gold_standard_pairs.split('-')
gs = parts[1]
return SEP.join([out_dir, f'{gs}-merged.pickle'])

# Returns the dataset corresponding to the gold standard pair
def get_dataset_label(wildcards):
parts = wildcards.dataset_goldstandard_pairs.split('-')
parts = wildcards.dataset_gold_standard_pairs.split('-')
dataset = parts[0]
return dataset

# Run evaluation code for a specific dataset's pathway outputs against its paired gold standard
rule evaluation:
input:
goldstandard_file = get_goldstandard_pickle_file,
gold_standard_file = get_gold_standard_pickle_file,
pathways = expand('{out_dir}{sep}{dataset_label}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params, dataset_label=get_dataset_label),
output: eval_file = SEP.join([out_dir, "{dataset_goldstandard_pairs}-evaluation.txt"])
output: eval_file = SEP.join([out_dir, "{dataset_gold_standard_pairs}-evaluation.txt"])
run:
node_table = Evaluation.from_file(input.goldstandard_file).node_table
node_table = Evaluation.from_file(input.gold_standard_file).node_table
Evaluation.precision(input.pathways, node_table, output.eval_file)

# Remove the output directory
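For context on the evaluation rule above: a minimal sketch of the kind of node-level precision that Evaluation.precision could compute. The real implementation lives in spras/evaluation.py and is not shown in this diff; the helper name, the NODEID column, and the assumption that each pathway file is a tab-separated edge list with node identifiers in its first two columns are all illustrative.

    from pathlib import Path
    from typing import Iterable

    import pandas as pd

    def node_precision_sketch(file_paths: Iterable[Path], node_table: pd.DataFrame, output_file: str):
        # Illustrative only: score each pathway by the fraction of its nodes found in the gold standard
        gold_nodes = set(node_table['NODEID'])
        rows = []
        for path in file_paths:
            edges = pd.read_csv(path, sep='\t', header=None)
            pathway_nodes = set(edges[0]) | set(edges[1])  # first two columns assumed to be node IDs
            precision = len(pathway_nodes & gold_nodes) / len(pathway_nodes) if pathway_nodes else 0.0
            rows.append({'pathway': str(path), 'precision': precision})
        pd.DataFrame(rows).to_csv(output_file, sep='\t', index=False)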
2 changes: 1 addition & 1 deletion config/config.yaml
@@ -122,7 +122,7 @@ gold_standard:
node_files: ["gs_nodes0.txt"]
# edge_files: [] TODO: later iteration
data_dir: "input"
# Set of datasets to compare with the specific gold standard dataset
# Set of datasets (dataset labels) to compare with the specific gold standard dataset
datasets: ["data0"]
-
label: gs1
2 changes: 2 additions & 0 deletions config/egfr.yaml
@@ -89,3 +89,5 @@ analysis:
include: true
ml:
include: false
evaluation:
include: false
18 changes: 14 additions & 4 deletions spras/config.py
@@ -71,7 +71,7 @@ def __init__(self, raw_config):
# A dictionary to store configured datasets against which SPRAS will be run
self.datasets = None
# A dictionary to store configured gold standard data against ouptut of SPRAS runs
self.gold_standard = None
self.gold_standards = None
# The hash length SPRAS will use to identify parameter combinations. Default is 7
self.hash_length = DEFAULT_HASH_LENGTH
# The list of algorithms to run in the workflow. Each is a dict with 'name' as an expected key.
@@ -147,11 +147,15 @@ def process_config(self, raw_config):
self.datasets = {dataset["label"]: dict(dataset) for dataset in raw_config["datasets"]}

# TODO: turn into try except
self.gold_standard = {goldstandard["label"]: dict(goldstandard) for goldstandard in raw_config["gold_standard"]}
for key in self.gold_standard:
try:
self.gold_standards = {gold_standard["label"]: dict(gold_standard) for gold_standard in raw_config["gold_standard"]}
except:
self.gold_standards = {}

for key in self.gold_standards:
pattern = r'^\w+$'
if not bool(re.match(pattern, key)):
raise ValueError(f"Gold Standard label \'{key}\' contains invalid values. Gold Standard labels can only contain letters, numbers, or underscores.")
raise ValueError(f"Gold standard label \'{key}\' contains invalid values. Gold standard labels can only contain letters, numbers, or underscores.")

# Code snipped from Snakefile that may be useful for assigning default labels
# dataset_labels = [dataset.get('label', f'dataset{index}') for index, dataset in enumerate(datasets)]
@@ -234,6 +238,12 @@ def process_config(self, raw_config):
self.analysis_include_ml = raw_config["analysis"]["ml"]["include"]
self.analysis_include_evalution = raw_config["analysis"]["evaluation"]["include"]

# the code will run correctly without this section below
# TODO: decide if this part is needed
if self.gold_standards == {} and self.analysis_include_evalution == True:
print("Gold standard data not provided. Evaluation analysis cannot run.")
self.analysis_include_evalution = False

if 'aggregate_per_algorithm' not in self.ml_params:
self.analysis_include_ml_aggregate_algo = False
else:
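The try/except added above is what makes the gold standard section optional: a missing or malformed gold_standard key yields an empty dictionary, and the later check turns evaluation off with a warning. A standalone sketch of the same pattern (the parse_gold_standards helper is hypothetical, not part of spras.config, and it narrows the bare except to KeyError/TypeError):

    def parse_gold_standards(raw_config: dict) -> dict:
        # Index gold standards by label; an absent or malformed section yields an empty dict
        try:
            return {gs["label"]: dict(gs) for gs in raw_config["gold_standard"]}
        except (KeyError, TypeError):
            return {}

    config_with = {"gold_standard": [{"label": "gs0", "node_files": ["gs_nodes0.txt"],
                                      "data_dir": "input", "datasets": ["data0"]}]}
    config_without = {}  # no gold_standard section in the YAML at all

    print(parse_gold_standards(config_with))     # {'gs0': {...}}
    print(parse_gold_standards(config_without))  # {} -> evaluation analysis would be disabled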
10 changes: 10 additions & 0 deletions spras/evaluation.py
@@ -19,6 +19,15 @@ def __init__(self, gold_standard_dict):
self.datasets = None
return

def merge_gold_standard_input(gs_dict, gs_file):
"""
Merge files listed for this gold standard dataset and write the dataset to disk
@param gs_dict: gold standard dataset to process
@param gs_file: output filename
"""
gs_dataset = Evaluation(gs_dict)
gs_dataset.to_file(gs_file)

def to_file(self, file_name):
"""
Saves gold standard object to pickle file
@@ -60,6 +69,7 @@ def load_files_from_dict(self, gold_standard_dict):

# TODO: later iteration - chose between node and edge file, or allow both

# TODO: move outside of Evaluation class?
def precision(file_paths: Iterable[Path], node_table: pd.DataFrame, output_file: str):
"""
Takes in file paths for a specific dataset and an associated gold standard node table.
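Taken together with the Snakefile changes, gold standards now follow the same pickle round trip as datasets: merge_gs_input writes the merged Evaluation object to disk, and the evaluation rule reads it back before scoring. A rough, self-contained sketch of that flow, assuming to_file and from_file are thin pickle wrappers (their bodies are not part of this diff; EvaluationSketch is a hypothetical stand-in for spras.evaluation.Evaluation):

    import pickle

    class EvaluationSketch:
        def __init__(self, gs_dict):
            self.label = gs_dict["label"]
            self.node_table = None  # the real class loads gs_dict["node_files"] into a node table

        def to_file(self, file_name):
            with open(file_name, "wb") as f:
                pickle.dump(self, f)

        @classmethod
        def from_file(cls, file_name):
            with open(file_name, "rb") as f:
                return pickle.load(f)

    # merge_gs_input rule: pickle the merged gold standard
    EvaluationSketch({"label": "gs0", "node_files": ["gs_nodes0.txt"]}).to_file("gs0-merged.pickle")
    # evaluation rule: load it back and use its node table for precision
    node_table = EvaluationSketch.from_file("gs0-merged.pickle").node_table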
10 changes: 0 additions & 10 deletions spras/runner.py
@@ -43,16 +43,6 @@ def merge_input(dataset_dict, dataset_file):
dataset = Dataset(dataset_dict)
dataset.to_file(dataset_file)

def merge_gold_standard_input(gs_dict, gs_file):
"""
Merge files listed for this gold standard dataset and write the dataset to disk
@param gs_dict: gold standard dataset to process
@param gs_file: output filename
"""
gs_dataset = Evaluation(gs_dict)
gs_dataset.to_file(gs_file)


def prepare_inputs(algorithm, data_file, filename_map):
"""
Prepare general dataset files for this algorithm
