keras_words_subtoken_metrics.py
import abc
from functools import reduce
from typing import Optional, Callable, List

import tensorflow as tf
import tensorflow.keras.backend as K

from common import common
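

# A minimal sketch of the kind of index-to-word lookup table that
# `WordsSubtokenMetricBase` accepts (hypothetical three-word vocabulary):
# it maps integer target-word indices back to '|'-delimited subtoken strings.
def build_example_index_to_word_table() -> tf.lookup.StaticHashTable:
    keys = tf.constant([0, 1, 2], dtype=tf.int64)
    values = tf.constant(['<PAD>', 'get|name', 'set|value'])
    return tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys, values), default_value='<UNK>')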


class WordsSubtokenMetricBase(tf.metrics.Metric):
    # A filter maps the (batch, top_k) tensor of predicted word strings to a
    # boolean mask of the same shape marking legal prediction candidates.
    FilterType = Callable[[tf.Tensor], tf.Tensor]

    def __init__(self,
                 index_to_word_table: Optional[tf.lookup.StaticHashTable] = None,
                 topk_predicted_words=None,
                 predicted_words_filters: Optional[List[FilterType]] = None,
                 subtokens_delimiter: str = '|', name=None, dtype=None):
        super(WordsSubtokenMetricBase, self).__init__(name=name, dtype=dtype)
        self.tp = self.add_weight('true_positives', shape=(), initializer=tf.zeros_initializer)
        self.fp = self.add_weight('false_positives', shape=(), initializer=tf.zeros_initializer)
        self.fn = self.add_weight('false_negatives', shape=(), initializer=tf.zeros_initializer)
        self.index_to_word_table = index_to_word_table
        self.topk_predicted_words = topk_predicted_words
        self.predicted_words_filters = predicted_words_filters
        self.subtokens_delimiter = subtokens_delimiter
    def _get_true_target_word_string(self, true_target_word):
        if self.index_to_word_table is None:
            return true_target_word
        true_target_word_index = tf.cast(true_target_word, dtype=self.index_to_word_table.key_dtype)
        return self.index_to_word_table.lookup(true_target_word_index)
    def update_state(self, true_target_word, predictions, sample_weight=None):
        """Accumulates true positive, false positive and false negative statistics."""
        if sample_weight is not None:
            raise NotImplementedError("WordsSubtokenMetricBase with non-None `sample_weight` is not implemented.")

        # For each example in the batch we have:
        #   (i)  one ground-truth target word;
        #   (ii) one predicted word (the argmax of y_hat).
        topk_predicted_words = predictions if self.topk_predicted_words is None else self.topk_predicted_words
        assert topk_predicted_words is not None
        predicted_word = self._get_prediction_from_topk(topk_predicted_words)

        true_target_word_string = self._get_true_target_word_string(true_target_word)
        true_target_word_string = tf.reshape(true_target_word_string, [-1])

        # Split each word into its subtokens.
        true_target_subwords = tf.compat.v1.string_split(true_target_word_string, sep=self.subtokens_delimiter)
        prediction_subwords = tf.compat.v1.string_split(predicted_word, sep=self.subtokens_delimiter)
        true_target_subwords = tf.sparse.to_dense(true_target_subwords, default_value='<PAD>')
        prediction_subwords = tf.sparse.to_dense(prediction_subwords, default_value='<PAD>')
        true_target_subwords_mask = tf.not_equal(true_target_subwords, '<PAD>')
        prediction_subwords_mask = tf.not_equal(prediction_subwords, '<PAD>')

        # Here the shapes of true_target_subwords & prediction_subwords are (batch, subtokens).
        # We use broadcasting to compute the membership of each subtoken of one
        # list in the other list, preserving duplicates.
        true_target_subwords = tf.expand_dims(true_target_subwords, -1)
        prediction_subwords = tf.expand_dims(prediction_subwords, -1)
        # Now the shapes of true_target_subwords & prediction_subwords are (batch, subtokens, 1).
        true_target_subwords__in__prediction_subwords = \
            tf.reduce_any(tf.equal(true_target_subwords, tf.transpose(prediction_subwords, perm=[0, 2, 1])), axis=2)
        prediction_subwords__in__true_target_subwords = \
            tf.reduce_any(tf.equal(prediction_subwords, tf.transpose(true_target_subwords, perm=[0, 2, 1])), axis=2)
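
        # For intuition, a hypothetical pair (not from real data): true word
        # 'get|name' vs. predicted word 'get|full|name'. The prediction
        # subwords are ['get', 'full', 'name']; 'get' and 'name' appear in the
        # true word (2 true positives) while 'full' does not (1 false
        # positive), and every true subword appears in the prediction
        # (0 false negatives).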
        # Count predicted subwords that exist in the ground-truth word (true positives).
        batch_true_positive = tf.reduce_sum(tf.cast(
            tf.logical_and(prediction_subwords__in__true_target_subwords, prediction_subwords_mask), tf.float32))
        # Count predicted subwords that don't exist in the ground-truth word (false positives).
        batch_false_positive = tf.reduce_sum(tf.cast(
            tf.logical_and(~prediction_subwords__in__true_target_subwords, prediction_subwords_mask), tf.float32))
        # Count ground-truth subwords that don't exist in the predicted word (false negatives).
        batch_false_negative = tf.reduce_sum(tf.cast(
            tf.logical_and(~true_target_subwords__in__prediction_subwords, true_target_subwords_mask), tf.float32))

        self.tp.assign_add(batch_true_positive)
        self.fp.assign_add(batch_false_positive)
        self.fn.assign_add(batch_false_negative)
    def _get_prediction_from_topk(self, topk_predicted_words):
        # Apply the given filters.
        masks = []
        if self.predicted_words_filters is not None:
            masks = [fltr(topk_predicted_words) for fltr in self.predicted_words_filters]
        if masks:
            # assert all(mask.shape.assert_is_compatible_with(top_k_pred_indices) for mask in masks)
            legal_predicted_target_words_mask = reduce(tf.logical_and, masks)
        else:
            legal_predicted_target_words_mask = tf.cast(tf.ones_like(topk_predicted_words), dtype=tf.bool)

        # The first legal predicted word is our prediction.
        first_legal_predicted_target_word_mask = common.tf_get_first_true(legal_predicted_target_words_mask)
        first_legal_predicted_target_word_idx = tf.where(first_legal_predicted_target_word_mask)
        first_legal_predicted_word_string = tf.gather_nd(topk_predicted_words,
                                                         first_legal_predicted_target_word_idx)
        prediction = tf.reshape(first_legal_predicted_word_string, [-1])
        return prediction
    @abc.abstractmethod
    def result(self):
        ...

    def reset_states(self):
        for v in self.variables:
            K.set_value(v, 0)
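

# A minimal sketch of a `predicted_words_filters` entry (hypothetical example):
# each filter receives the (batch, top_k) tensor of predicted word strings and
# returns a boolean mask of the same shape marking legal prediction candidates.
def example_non_special_word_filter(topk_predicted_words: tf.Tensor) -> tf.Tensor:
    # Never pick the padding / out-of-vocabulary placeholders as the prediction.
    return tf.logical_and(
        tf.not_equal(topk_predicted_words, '<PAD>'),
        tf.not_equal(topk_predicted_words, '<UNK>'))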


class WordsSubtokenPrecisionMetric(WordsSubtokenMetricBase):
    def result(self):
        precision = tf.math.divide_no_nan(self.tp, self.tp + self.fp)
        return precision


class WordsSubtokenRecallMetric(WordsSubtokenMetricBase):
    def result(self):
        recall = tf.math.divide_no_nan(self.tp, self.tp + self.fn)
        return recall


class WordsSubtokenF1Metric(WordsSubtokenMetricBase):
    def result(self):
        recall = tf.math.divide_no_nan(self.tp, self.tp + self.fn)
        precision = tf.math.divide_no_nan(self.tp, self.tp + self.fp)
        f1 = tf.math.divide_no_nan(2 * precision * recall, precision + recall + K.epsilon())
        return f1
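

# A minimal usage sketch with made-up data (assumes this repository's `common`
# module is importable, as imported above). Each word is a '|'-delimited
# subtoken string; predictions are the top-k candidate words per example.
if __name__ == '__main__':
    metric = WordsSubtokenF1Metric(
        index_to_word_table=build_example_index_to_word_table(),
        predicted_words_filters=[example_non_special_word_filter])
    true_word_indices = tf.constant([1, 2], dtype=tf.int64)  # 'get|name', 'set|value'
    predicted_topk = tf.constant([['get|full|name'], ['set|value']])  # shape (batch, k=1)
    metric.update_state(true_word_indices, predicted_topk)
    # tp=4, fp=1, fn=0 -> precision 0.8, recall 1.0, F1 close to 0.889
    print(float(metric.result()))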