forked from Reed-CompBio/spras
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request Reed-CompBio#170 from ntalluri/implement-eval
Evaluation Implementation
- Loading branch information
Showing
8 changed files
with
226 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -89,3 +89,5 @@ analysis: | |
include: true | ||
ml: | ||
include: false | ||
evaluation: | ||
include: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
A | ||
B |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
C |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import os | ||
import pickle as pkl | ||
from pathlib import Path | ||
from typing import Dict, Iterable | ||
|
||
import pandas as pd | ||
from sklearn.metrics import precision_score | ||
|
||
|
||
class Evaluation:
    """
    A gold standard dataset used to evaluate reconstructed pathways.
    Holds the gold standard label, the labels of the datasets it is paired with,
    and a single-column pandas table of gold standard nodes.
    """
    # Column name given to the single column of node_table
    NODE_ID = "NODEID"

    def __init__(self, gold_standard_dict: Dict):
        """
        Creates the gold standard object and immediately loads its node file
        @param gold_standard_dict: one gold standard dataset dictionary with the keys
        "label", "dataset_labels", "node_files" and "data_dir"
        """
        self.label = None
        self.datasets = None
        self.node_table = None
        # self.edge_table = None TODO: later iteration
        self.load_files_from_dict(gold_standard_dict)

    @staticmethod
    def merge_gold_standard_input(gs_dict, gs_file):
        """
        Merge files listed for this gold standard dataset and write the dataset to disk
        @param gs_dict: gold standard dataset to process
        @param gs_file: output filename
        """
        gs_dataset = Evaluation(gs_dict)
        gs_dataset.to_file(gs_file)

    def to_file(self, file_name):
        """
        Saves gold standard object to pickle file
        @param file_name: the filename of the pickle file to write
        """
        with open(file_name, "wb") as f:
            pkl.dump(self, f)

    @staticmethod
    def from_file(file_name):
        """
        Loads gold standard object from a pickle file.
        Usage: gold_standard = Evaluation.from_file(pickle_file)
        @param file_name: the filename of the pickle file to read
        returns: the unpickled object
        """
        # NOTE: unpickling can execute arbitrary code, so only load pickle files
        # that were written by to_file from a trusted source
        with open(file_name, "rb") as f:
            return pkl.load(f)

    def load_files_from_dict(self, gold_standard_dict: Dict):
        """
        Loads gold standard files from gold_standard_dict, which is one gold standard dataset
        dictionary from the list in the config file with the fields in the config file.
        Populates label, datasets and node_table.
        node_table is a single column of nodes pandas table.
        @param gold_standard_dict: dictionary with the keys "label", "dataset_labels",
        "node_files" and "data_dir"
        returns: none
        """
        self.label = gold_standard_dict["label"]  # cannot be empty, will break with a NoneType exception
        # can be empty, snakemake will not run evaluation due to dataset_gold_standard_pairs in snakemake file
        self.datasets = gold_standard_dict["dataset_labels"]

        # cannot be empty, snakemake will run evaluation even if empty
        node_data_files = gold_standard_dict["node_files"][0]  # TODO: single file for now

        data_loc = gold_standard_dict["data_dir"]

        # The node file has no header row; its only column is named NODE_ID after loading
        single_node_table = pd.read_table(os.path.join(data_loc, node_data_files), header=None)
        single_node_table.columns = [self.NODE_ID]
        self.node_table = single_node_table

        # TODO: are we allowing multiple node files or single in node_files for gs
        # if yes, a for loop is needed

        # TODO: later iteration - chose between node and edge file, or allow both

    @staticmethod
    def precision(file_paths: Iterable[Path], node_table: pd.DataFrame, output_file: str):
        """
        Takes in file paths for a specific dataset and an associated gold standard node table.
        Calculates precision for each pathway file
        Writes a tab-separated table with the columns Pathway and Precision to output_file
        @param file_paths: file paths of pathway reconstruction algorithm outputs
        @param node_table: the gold standard nodes
        @param output_file: the filename to save the precision of each pathway
        """
        y_true = set(node_table[Evaluation.NODE_ID])
        results = []

        for file in file_paths:
            df = pd.read_table(file, sep="\t", header=0, usecols=["Node1", "Node2"])
            # every node that appears at either end of an edge is a predicted node
            y_pred = set(df["Node1"]).union(df["Node2"])
            all_nodes = y_true.union(y_pred)
            # both label lists iterate the same set, so their entries line up node by node
            y_true_binary = [1 if node in y_true else 0 for node in all_nodes]
            y_pred_binary = [1 if node in y_pred else 0 for node in all_nodes]

            # default to 0.0 if there is a divide by 0 error
            precision = precision_score(y_true_binary, y_pred_binary, zero_division=0.0)

            results.append({"Pathway": file, "Precision": precision})

        # explicit columns keep the header row in the output even when there are no pathway files
        precision_df = pd.DataFrame(results, columns=["Pathway", "Precision"])
        precision_df.to_csv(output_file, sep="\t", index=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters