Merge pull request Reed-CompBio#170 from ntalluri/implement-eval
Evaluation Implementation
agitter authored Aug 23, 2024
2 parents 679e8b4 + 522e66c commit cf0b401
Showing 8 changed files with 226 additions and 10 deletions.
45 changes: 44 additions & 1 deletion Snakefile
@@ -3,6 +3,7 @@ from spras import runner
import shutil
import yaml
from spras.dataset import Dataset
from spras.evaluation import Evaluation
from spras.analysis import ml, summary, graphspace, cytoscape
import spras.config as _config

@@ -27,13 +28,14 @@ hac_params = _config.config.hac_params
FRAMEWORK = _config.config.container_framework
print(f"Running {FRAMEWORK} containers")

# Return the dataset dictionary from the config file given the label
# Return the dataset or gold_standard dictionary from the config file given the label
def get_dataset(_datasets, label):
return _datasets[label]

algorithms = list(algorithm_params)
algorithms_with_params = [f'{algorithm}-params-{params_hash}' for algorithm, param_combos in algorithm_params.items() for params_hash in param_combos.keys()]
dataset_labels = list(_config.config.datasets.keys())
dataset_gold_standard_pairs = [f"{dataset}-{gs_values['label']}" for gs_values in _config.config.gold_standards.values() for dataset in gs_values['dataset_labels']]

# Get algorithms that are running multiple parameter combinations
def algo_has_mult_param_combos(algo):
@@ -102,6 +104,9 @@ def make_final_input(wildcards):
final_input.extend(expand('{out_dir}{sep}{dataset}-ml{sep}{algorithm}-hac-clusters-horizontal.txt',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm=algorithms_mult_param_combos,algorithm_params=algorithms_with_params))
final_input.extend(expand('{out_dir}{sep}{dataset}-ml{sep}{algorithm}-ensemble-pathway.txt',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm=algorithms_mult_param_combos,algorithm_params=algorithms_with_params))

if _config.config.analysis_include_evaluation:
final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-evaluation.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs,algorithm_params=algorithms_with_params))

if len(final_input) == 0:
# No analysis added yet, so add reconstruction output files if they exist.
# (if analysis is specified, these should be implicitly run).
@@ -153,6 +158,21 @@ rule merge_input:
dataset_dict = get_dataset(_config.config.datasets, wildcards.dataset)
runner.merge_input(dataset_dict, output.dataset_file)

# Return all files used in the gold standard
def get_gold_standard_dependencies(wildcards):
gs = _config.config.gold_standards[wildcards.gold_standard]
all_files = gs["node_files"]
all_files = [gs["data_dir"] + SEP + data_file for data_file in all_files]
return all_files

# Merge all node files for a gold_standard into a single node table
rule merge_gs_input:
input: get_gold_standard_dependencies
output: gold_standard_file = SEP.join([out_dir, '{gold_standard}-merged.pickle'])
run:
gold_standard_dict = get_dataset(_config.config.gold_standards, wildcards.gold_standard)
Evaluation.merge_gold_standard_input(gold_standard_dict, output.gold_standard_file)

# The checkpoint is like a rule but can be used in dynamic workflows
# The workflow directed acyclic graph is re-evaluated after the checkpoint job runs
# If the checkpoint has not executed for the provided wildcard values, it will be run and then the rest of the
@@ -303,6 +323,7 @@ def collect_pathways_per_algo(wildcards):
filtered_algo_params = [algo_param for algo_param in algorithms_with_params if wildcards.algorithm in algo_param]
return expand('{out_dir}{sep}{{dataset}}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=filtered_algo_params)

# Cluster the output pathways per algorithm for each dataset
rule ml_analysis_aggregate_algo:
input:
pathways = collect_pathways_per_algo
@@ -322,6 +343,28 @@ rule ml_analysis_aggregate_algo:
ml.hac_horizontal(summary_df, output.hac_image_horizontal, output.hac_clusters_horizontal, **hac_params)
ml.ensemble_network(summary_df, output.ensemble_network_file)

# Return the gold standard pickle file for a specific gold standard
def get_gold_standard_pickle_file(wildcards):
parts = wildcards.dataset_gold_standard_pairs.split('-')
gs = parts[1]
return SEP.join([out_dir, f'{gs}-merged.pickle'])

# Return the dataset label corresponding to a dataset-gold standard pair
def get_dataset_label(wildcards):
parts = wildcards.dataset_gold_standard_pairs.split('-')
dataset = parts[0]
return dataset

# Run evaluation code for a specific dataset's pathway outputs against its paired gold standard
rule evaluation:
input:
gold_standard_file = get_gold_standard_pickle_file,
pathways = expand('{out_dir}{sep}{dataset_label}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params, dataset_label=get_dataset_label),
output: eval_file = SEP.join([out_dir, "{dataset_gold_standard_pairs}-evaluation.txt"])
run:
node_table = Evaluation.from_file(input.gold_standard_file).node_table
Evaluation.precision(input.pathways, node_table, output.eval_file)

# Remove the output directory
rule clean:
shell: f'rm -rf {out_dir}'
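
For context (a sketch, not part of the commit): the evaluation rule recovers the dataset and gold standard labels by splitting the wildcard on '-'. Because labels are validated in spras/config.py to contain only letters, numbers, or underscores, the separating dash is unambiguous. A minimal Python sketch, assuming the example pair data0-gs0 from the default config:

# Sketch only: mirrors get_dataset_label and get_gold_standard_pickle_file above.
pair = "data0-gs0"                         # example {dataset_gold_standard_pairs} wildcard value
dataset_label, gs_label = pair.split('-')  # labels themselves cannot contain '-'
print(dataset_label)                       # 'data0' -> selects the pathway.txt inputs to evaluate
print(gs_label)                            # 'gs0'   -> selects gs0-merged.pickle as the gold standard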
17 changes: 17 additions & 0 deletions config/config.yaml
@@ -117,6 +117,21 @@ datasets:
# Relative path from the spras directory
data_dir: "input"

gold_standards:
-
# Labels can only contain letters, numbers, or underscores
label: gs0
node_files: ["gs_nodes0.txt"]
# edge_files: [] TODO: later iteration
data_dir: "input"
# List of dataset labels to compare with the specific gold standard dataset
dataset_labels: ["data0"]
-
label: gs1
node_files: ["gs_nodes1.txt"]
data_dir: "input"
dataset_labels: ["data1", "data0"]

# If we want to reconstruct then we should set run to true.
# TODO: if include is true above but run is false here, algs are not run.
# is this the behavior we want?
@@ -157,3 +172,5 @@ analysis:
linkage: 'ward'
# 'euclidean', 'manhattan', 'cosine'
metric: 'euclidean'
evaluation:
include: true
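
For reference (a sketch, not part of the commit): with the gold_standards block above and the data0 and data1 datasets it references, the Snakefile's dataset_gold_standard_pairs comprehension yields one pair per listed dataset, so three evaluation files are requested.

# Illustrative Python equivalent of the Snakefile comprehension, using the example config above.
gold_standards = {
    "gs0": {"label": "gs0", "dataset_labels": ["data0"]},
    "gs1": {"label": "gs1", "dataset_labels": ["data1", "data0"]},
}
pairs = [f"{dataset}-{gs_values['label']}"
         for gs_values in gold_standards.values()
         for dataset in gs_values["dataset_labels"]]
print(pairs)  # ['data0-gs0', 'data1-gs1', 'data0-gs1'] -> one *-evaluation.txt per pair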
2 changes: 2 additions & 0 deletions config/egfr.yaml
@@ -89,3 +89,5 @@ analysis:
include: true
ml:
include: false
evaluation:
include: false
2 changes: 2 additions & 0 deletions input/gs_nodes0.txt
@@ -0,0 +1,2 @@
A
B
1 change: 1 addition & 0 deletions input/gs_nodes1.txt
@@ -0,0 +1 @@
C
32 changes: 30 additions & 2 deletions spras/config.py
@@ -70,6 +70,8 @@ def __init__(self, raw_config):
self.unpack_singularity = False
# A dictionary to store configured datasets against which SPRAS will be run
self.datasets = None
# A dictionary to store configured gold standard datasets to compare against the output of SPRAS runs
self.gold_standards = None
# The hash length SPRAS will use to identify parameter combinations. Default is 7
self.hash_length = DEFAULT_HASH_LENGTH
# The list of algorithms to run in the workflow. Each is a dict with 'name' as an expected key.
@@ -90,11 +92,13 @@ def __init__(self, raw_config):
# A Boolean specifying whether to run the summary analysis
self.analysis_include_summary = None
# A Boolean specifying whether to run the GraphSpace analysis
self.analysis_include_graphspace = None
# A Boolean specifying whether to run the Cytoscape analysis
self.analysis_include_cytoscape = None
# A Boolean specifying whether to run the ML analysis
self.analysis_include_ml = None
# A Boolean specifying whether to run the Evaluation analysis
self.analysis_include_evaluation = None

_raw_config = copy.deepcopy(raw_config)
self.process_config(_raw_config)
@@ -146,6 +150,25 @@ def process_config(self, raw_config):
if not bool(re.match(pattern, key)):
raise ValueError(f"Dataset label \'{key}\' contains invalid values. Dataset labels can only contain letters, numbers, or underscores.")

# parse gold standard information
try:
self.gold_standards = {gold_standard["label"]: dict(gold_standard) for gold_standard in raw_config["gold_standards"]}
except:
self.gold_standards = {}

# check that gold_standard labels are formatted correctly
for key in self.gold_standards:
pattern = r'^\w+$'
if not bool(re.match(pattern, key)):
raise ValueError(f"Gold standard label \'{key}\' contains invalid values. Gold standard labels can only contain letters, numbers, or underscores.")

# check that all the dataset labels in the gold standards are existing dataset labels
dataset_labels = set(self.datasets.keys())
gold_standard_dataset_labels = {dataset_label for value in self.gold_standards.values() for dataset_label in value['dataset_labels']}
for label in gold_standard_dataset_labels:
if label not in dataset_labels:
raise ValueError(f"Dataset label '{label}' provided in gold standards does not exist in the existing dataset labels.")

# Code snipped from Snakefile that may be useful for assigning default labels
# dataset_labels = [dataset.get('label', f'dataset{index}') for index, dataset in enumerate(datasets)]
# Maps from the dataset label to the dataset list index
@@ -225,6 +248,11 @@ def process_config(self, raw_config):
self.analysis_include_graphspace = raw_config["analysis"]["graphspace"]["include"]
self.analysis_include_cytoscape = raw_config["analysis"]["cytoscape"]["include"]
self.analysis_include_ml = raw_config["analysis"]["ml"]["include"]
self.analysis_include_evaluation = raw_config["analysis"]["evaluation"]["include"]

if self.gold_standards == {} and self.analysis_include_evaluation:
raise ValueError("Evaluation analysis cannot run as gold standard data not provided. "
"Please set evaluation include to false or provide gold standard data.")

if 'aggregate_per_algorithm' not in self.ml_params:
self.analysis_include_ml_aggregate_algo = False
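
To illustrate the validation added in process_config (a sketch with an invented config fragment, not repository code): gold standards are keyed by label, and referencing a dataset label that was never defined raises a ValueError.

# Hypothetical config fragment; mirrors the gold standard checks in process_config above.
raw_gold_standards = [
    {"label": "gs0", "node_files": ["gs_nodes0.txt"], "data_dir": "input", "dataset_labels": ["data0"]},
]
gold_standards = {gs["label"]: dict(gs) for gs in raw_gold_standards}

dataset_labels = {"data0", "data1"}  # labels parsed from the datasets section of the config
for value in gold_standards.values():
    for label in value["dataset_labels"]:
        if label not in dataset_labels:  # e.g. a typo such as "data2" would fail here
            raise ValueError(f"Dataset label '{label}' provided in gold standards does not exist "
                             "in the existing dataset labels.")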
100 changes: 100 additions & 0 deletions spras/evaluation.py
@@ -0,0 +1,100 @@
import os
import pickle as pkl
from pathlib import Path
from typing import Dict, Iterable

import pandas as pd
from sklearn.metrics import precision_score


class Evaluation:
NODE_ID = "NODEID"

def __init__(self, gold_standard_dict: Dict):
self.label = None
self.datasets = None
self.node_table = None
# self.edge_table = None TODO: later iteration
self.load_files_from_dict(gold_standard_dict)
return

@staticmethod
def merge_gold_standard_input(gs_dict, gs_file):
"""
Merge files listed for this gold standard dataset and write the dataset to disk
@param gs_dict: gold standard dataset to process
@param gs_file: output filename
"""
gs_dataset = Evaluation(gs_dict)
gs_dataset.to_file(gs_file)

def to_file(self, file_name):
"""
Saves gold standard object to pickle file
"""
with open(file_name, "wb") as f:
pkl.dump(self, f)

@staticmethod
def from_file(file_name):
"""
Loads gold standard object from a pickle file.
Usage: gold_standard = Evaluation.from_file(pickle_file)
"""
with open(file_name, "rb") as f:
return pkl.load(f)

def load_files_from_dict(self, gold_standard_dict: Dict):
"""
Loads gold standard files from gold_standard_dict, which is one gold standard dataset
dictionary from the list in the config file.
Populates node_table, a single-column pandas table of node identifiers.
returns: none
"""
self.label = gold_standard_dict["label"] # cannot be empty, will break with a NoneType exception
self.datasets = gold_standard_dict["dataset_labels"] # can be empty, snakemake will not run evaluation due to dataset_gold_standard_pairs in snakemake file

# cannot be empty, snakemake will run evaluation even if empty
node_data_files = gold_standard_dict["node_files"][0] # TODO: single file for now

data_loc = gold_standard_dict["data_dir"]

single_node_table = pd.read_table(os.path.join(data_loc, node_data_files), header=None)
single_node_table.columns = [self.NODE_ID]
self.node_table = single_node_table

# TODO: are we allowing multiple node files or single in node_files for gs
# if yes, a for loop is needed

# TODO: later iteration - choose between node and edge file, or allow both

@staticmethod
def precision(file_paths: Iterable[Path], node_table: pd.DataFrame, output_file: str):
"""
Takes in file paths for a specific dataset and an associated gold standard node table.
Calculates precision for each pathway file.
Writes the per-pathway precision values to output_file
@param file_paths: file paths of pathway reconstruction algorithm outputs
@param node_table: the gold standard nodes
@param output_file: the filename to save the precision of each pathway
"""
y_true = set(node_table['NODEID'])
results = []

for file in file_paths:
df = pd.read_table(file, sep="\t", header=0, usecols=["Node1", "Node2"])
y_pred = set(df['Node1']).union(set(df['Node2']))
all_nodes = y_true.union(y_pred)
y_true_binary = [1 if node in y_true else 0 for node in all_nodes]
y_pred_binary = [1 if node in y_pred else 0 for node in all_nodes]

# default to 0.0 if there is a divide by 0 error
precision = precision_score(y_true_binary, y_pred_binary, zero_division=0.0)

results.append({"Pathway": file, "Precision": precision})

precision_df = pd.DataFrame(results)
precision_df.to_csv(output_file, sep="\t", index=False)
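
As a worked example (a sketch, not part of the commit): with the gold standard nodes from input/gs_nodes0.txt (A and B) and a hypothetical pathway whose edges touch nodes A, C, and D, one of the three predicted nodes is a true positive, so precision is 1/3.

# Minimal sketch of the precision computation above, using hypothetical pathway nodes.
from sklearn.metrics import precision_score

y_true = {"A", "B"}       # gold standard nodes (input/gs_nodes0.txt)
y_pred = {"A", "C", "D"}  # nodes appearing in a hypothetical pathway.txt

all_nodes = y_true.union(y_pred)
y_true_binary = [1 if node in y_true else 0 for node in all_nodes]
y_pred_binary = [1 if node in y_pred else 0 for node in all_nodes]

print(precision_score(y_true_binary, y_pred_binary, zero_division=0.0))  # 0.333...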
37 changes: 30 additions & 7 deletions test/test_config.py
@@ -19,7 +19,8 @@ def get_test_config():
"reconstruction_dir": "my_dir"
}
},
"datasets": [{"label":"alg1"}, {"label":"alg2"}],
"datasets": [{"label": "alg1"}, {"label": "alg2"}],
"gold_standards": [{"label": "gs1", "dataset_labels": []}],
"algorithms": [{"params": ["param2", "param2"]}],
"analysis": {
"summary": {
@@ -34,11 +35,15 @@
"cytoscape": {
"include": False
},
"evaluation": {
"include": False
},
},
}

return test_raw_config


class TestConfig:
"""
Tests various parts of the configuration mechanism
@@ -105,17 +110,35 @@ def test_config_container_registry(self):

def test_error_dataset_label(self):
test_config = get_test_config()
error_test_dicts = [{"label":"test$"}, {"label":"@test'"}, {"label":"[test]"}, {"label":"test-test"}, {"label":"✉"}]
error_test_dicts = [{"label": "test$"}, {"label": "@test'"}, {"label": "[test]"}, {"label": "test-test"},
{"label": "✉"}]

for test_dict in error_test_dicts:
test_config["datasets"]= [test_dict]
with pytest.raises(ValueError): #raises error if any chars other than letters, numbers, or underscores are in dataset label
test_config["datasets"] = [test_dict]
with pytest.raises(ValueError): # raises error if any chars other than letters, numbers, or underscores are in dataset label
config.init_global(test_config)

def test_correct_dataset_label(self):
test_config = get_test_config()
correct_test_dicts = [{"label":"test"}, {"label":"123"}, {"label":"test123"}, {"label":"123test"}, {"label":"_"}, {"label":"test_test"}, {"label":"_test"}, {"label":"test_"}]
correct_test_dicts = [{"label": "test"}, {"label": "123"}, {"label": "test123"}, {"label": "123test"}, {"label": "_"},
{"label": "test_test"}, {"label": "_test"}, {"label": "test_"}]

for test_dict in correct_test_dicts:
test_config["datasets"]= [test_dict]
config.init_global(test_config) # no error should be raised
test_config["datasets"] = [test_dict]
config.init_global(test_config) # no error should be raised

def test_error_gs_label(self):
test_config = get_test_config()
error_labels = ["test$", "@test'"]

for test_label in error_labels:
test_config["gold_standards"][0]["label"] = test_label
with pytest.raises(ValueError): # raises error if any chars other than letters, numbers, or underscores are in gs label
config.init_global(test_config)

def test_error_gs_dataset_mismatch(self):
test_config = get_test_config()
test_config["gold_standards"] = [{"label": "gs1", "dataset_labels": ["mismatch"]}]

with pytest.raises(ValueError):
config.init_global(test_config)
