Skip to content

Commit

Permalink
Update: RE evaluation (ChemProt only)
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjininfo committed Feb 14, 2019
1 parent b702357 commit c09a6bb
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 16 deletions.
46 changes: 32 additions & 14 deletions biocodes/re_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,45 @@
# Evaluate relation-extraction (RE) predictions against the gold answers.
#
# --task binary   : P/R/F1 for the positive class (label index 1), plus
#                   specificity (recall of class 0).
# --task chemprot : micro-averaged P/R/F1 over the 5 ChemProt target classes;
#                   see "Potent pairing: ensemble of long short-term memory
#                   networks and support vector machine for chemical-protein
#                   relation extraction" (Mehryary, 2018) for details.
parser = argparse.ArgumentParser(description='')
parser.add_argument('--output_path', type=str, help='')
parser.add_argument('--answer_path', type=str, help='')
parser.add_argument('--task', type=str, default="binary",
                    help='default:binary, possible other options:{chemprot}')
args = parser.parse_args()

testdf = pd.read_csv(args.answer_path, sep="\t", index_col=0)
preddf = pd.read_csv(args.output_path, sep="\t", header=None)

# Each row of the prediction file holds per-class scores; the predicted
# class is the argmax. (Computed once here instead of once per task branch.)
pred = [preddf.iloc[i].tolist() for i in preddf.index]
pred_class = [np.argmax(v) for v in pred]

if args.task == "binary":
    p, r, f, s = sklearn.metrics.precision_recall_fscore_support(
        y_pred=pred_class, y_true=testdf["label"])
    # Index 1 is the positive class; recall of class 0 is the specificity.
    results = {
        "f1 score": f[1],
        "recall": r[1],
        "precision": p[1],
        "specificity": r[0],
    }
elif args.task == "chemprot":
    # Gold labels are strings; map them to contiguous integer ids in sorted
    # order, which is assumed to match the classifier's output ordering.
    str_to_int_mapper = {v: i
                         for i, v in enumerate(sorted(testdf["label"].unique()))}
    test_answer = [str_to_int_mapper[v] for v in testdf["label"]]
    # Micro-average over the 5 target classes only; the remaining
    # (non-relation) class is deliberately excluded from `labels`.
    p, r, f, s = sklearn.metrics.precision_recall_fscore_support(
        y_pred=pred_class, y_true=test_answer,
        labels=[0, 1, 2, 3, 4], average="micro")
    results = {"f1 score": f, "recall": r, "precision": p}
else:
    # Previously an unrecognized task fell through silently; fail loudly.
    raise ValueError(
        "unknown --task: {} (expected 'binary' or 'chemprot')".format(args.task))

for k, v in results.items():
    print("{:11s} : {:.2%}".format(k, v))

87 changes: 85 additions & 2 deletions run_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,89 @@ def _create_examples(self, lines, set_type):
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples

class BioBERTDDIProcessor(DataProcessor):
    """Processor for the BioBERT data set (GLUE version)."""

    def __init__(self):
        # Placeholder: this processor is not usable yet and refuses to
        # instantiate.
        raise NotImplementedError

    def get_train_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(rows, "test")

    def get_labels(self):
        """See base class."""
        return ["advise", "mechanism", "int", "false"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for i, line in enumerate(lines):
            # Only the test set has a header
            if i == 0 and set_type == "test":
                continue
            if set_type == "test":
                # Test rows: sentence in column 1, dummy "false" label.
                text_a = tokenization.convert_to_unicode(line[1])
                label = "false"
            else:
                text_a = tokenization.convert_to_unicode(line[0])
                label = tokenization.convert_to_unicode(line[1])
            guid = "%s-%s" % (set_type, i)
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples


class BioBERTChemprotProcessor(DataProcessor):
    """Processor for the BioBERT data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        data = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(data, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        data = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(data, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        data = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(data, "test")

    def get_labels(self):
        """See base class."""
        return ["cpr:3", "cpr:4", "cpr:5", "cpr:6", "cpr:9", "false"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        is_test = set_type == "test"
        examples = []
        for i, line in enumerate(lines):
            # Only the test set has a header
            if is_test and i == 0:
                continue
            if is_test:
                # Test rows: sentence in column 1, dummy "false" label.
                text_a = tokenization.convert_to_unicode(line[1])
                label = "false"
            else:
                text_a = tokenization.convert_to_unicode(line[0])
                label = tokenization.convert_to_unicode(line[1])
            examples.append(InputExample(
                guid="%s-%s" % (set_type, i),
                text_a=text_a,
                text_b=None,
                label=label))
        return examples


class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""
Expand Down Expand Up @@ -832,8 +915,8 @@ def main(_):
"polysearch": BioBERTProcessor,
"mirnadisease": BioBERTProcessor,
"euadr": BioBERTProcessor,
"chemprot": BioBERTProcessor,
"ddi13": BioBERTProcessor,
"chemprot": BioBERTChemprotProcessor,
"ddi13": BioBERTDDIProcessor,
}

tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
Expand Down

0 comments on commit c09a6bb

Please sign in to comment.