forked from aspamers/siamese
-
Notifications
You must be signed in to change notification settings - Fork 0
/
siamese.py
266 lines (222 loc) · 11.6 KB
/
siamese.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
"""
Siamese neural network module.
"""
import random
import numpy as np
from keras.layers import Input
from keras.models import Model
class SiameseNetwork:
    """
    A simple and lightweight siamese neural network implementation.

    The SiameseNetwork class requires the base and head model to be defined via the constructor. The class exposes
    public methods that allow it to behave similarly to a regular Keras model by passing kwargs through to the
    underlying Keras model object where possible. This allows Keras features like callbacks and metrics to be used.

    Data is supplied as raw samples plus integer class labels; training/evaluation pairs are generated on the fly,
    with each batch split into positive (same class, label 1.0) and negative (different class, label 0.0) pairs.
    """

    def __init__(self, base_model, head_model):
        """
        Construct the siamese model class with the following structure.

        -------------------------------------------------------------------
        input1 -> base_model |
                             --> embedding --> head_model --> binary output
        input2 -> base_model |
        -------------------------------------------------------------------

        :param base_model: The embedding model. The same model instance (and therefore the same weights) is
            applied to both inputs.
            * Input shape must be equal to that of data.
        :param head_model: The discriminator model. It is called with a list of the two embeddings.
            * Input shapes must be equal to that of the embedding.
            * Output shape must be equal to 1.
        """
        # Set essential parameters
        self.base_model = base_model
        self.head_model = head_model

        # Get input shape from base model (drop the leading batch dimension)
        self.input_shape = self.base_model.input_shape[1:]

        # Initialize siamese model
        self.siamese_model = None
        self.__initialize_siamese_model()

    def compile(self, *args, **kwargs):
        """
        Configures the model for training.

        Passes all arguments to the underlying Keras model compile function.
        """
        self.siamese_model.compile(*args, **kwargs)

    def fit(self, *args, **kwargs):
        """
        Trains the model on data generated batch-by-batch using the siamese network generator function.

        Usage mirrors Keras: ``fit(x, y, batch_size=..., validation_data=(x_val, y_val), **keras_kwargs)``.
        ``x`` and ``y`` may also be passed as keywords. ``batch_size`` defaults to 32 and ``validation_data`` is
        optional. All other kwargs (epochs, callbacks, ...) are forwarded to the Keras training call.
        Positional arguments beyond ``x`` and ``y`` are ignored.

        :return: Whatever the underlying Keras training call returns (for standard Keras models, a History).
        """
        x_train = args[0] if len(args) > 0 else kwargs.pop('x')
        y_train = args[1] if len(args) > 1 else kwargs.pop('y')
        validation_data = kwargs.pop('validation_data', None)
        batch_size = kwargs.pop('batch_size', 32)
        return self.__train(x_train, y_train, batch_size, validation_data, **kwargs)

    def fit_generator(self, x_train, y_train, x_test, y_test, batch_size, *args, **kwargs):
        """
        Trains the model on data generated batch-by-batch using the siamese network generator function.

        :param x_train: Training input data.
        :param y_train: Training output data (integer class labels).
        :param x_test: Validation input data, or None to train without validation.
        :param y_test: Validation output data (integer class labels).
        :param batch_size: Number of pairs to generate per batch.
        :param kwargs: Forwarded to the Keras training call (epochs, callbacks, ...).
        :return: Whatever the underlying Keras training call returns.
        :raises TypeError: If extra positional arguments are given; they cannot be forwarded because the
            generator and step count already occupy Keras' leading positional parameters.
        """
        if args:
            raise TypeError('fit_generator() forwards only keyword arguments to Keras; '
                            'got extra positional arguments %r' % (args,))
        validation_data = None if x_test is None else (x_test, y_test)
        return self.__train(x_train, y_train, batch_size, validation_data, **kwargs)

    def load_weights(self, checkpoint_path):
        """
        Load siamese model weights. This also affects the reference to the base and head models.

        :param checkpoint_path: Path to the checkpoint file.
        """
        self.siamese_model.load_weights(checkpoint_path)

    def evaluate(self, *args, **kwargs):
        """
        Evaluate the siamese network with the same generator that is used to train it. Passes arguments through to the
        underlying Keras function so that callbacks etc can be used.

        Usage mirrors Keras: ``evaluate(x, y, batch_size=..., **keras_kwargs)``. ``x`` and ``y`` may also be passed
        as keywords; ``batch_size`` defaults to 32. Positional arguments beyond ``x`` and ``y`` are ignored.

        :return: A tuple of scores
        """
        x = args[0] if len(args) > 0 else kwargs.pop('x')
        y = args[1] if len(args) > 1 else kwargs.pop('y')
        batch_size = kwargs.pop('batch_size', 32)
        return self.evaluate_generator(x, y, batch_size, **kwargs)

    def evaluate_generator(self, x, y, batch_size, *args, **kwargs):
        """
        Evaluate the siamese network with the same generator that is used to train it. Passes arguments through to the
        underlying Keras function so that callbacks etc can be used.

        :param x: Input data
        :param y: Class labels
        :param batch_size: Number of pairs to generate per batch.
        :return: A tuple of scores
        :raises TypeError: If extra positional arguments are given (they would collide with ``steps``).
        """
        if args:
            raise TypeError('evaluate_generator() forwards only keyword arguments to Keras; '
                            'got extra positional arguments %r' % (args,))
        generator = self.__pair_generator(x, y, batch_size=batch_size)
        kwargs.setdefault('steps', self.__num_steps(len(x), batch_size))
        evaluate = self.__model_method('evaluate_generator', 'evaluate')
        return evaluate(generator, **kwargs)

    def __train(self, x_train, y_train, batch_size, validation_data, **kwargs):
        """
        Shared training routine behind ``fit`` and ``fit_generator``.

        :param validation_data: ``(x_val, y_val)`` tuple, or None to train without validation.
        :return: Whatever the underlying Keras training call returns.
        """
        train_generator = self.__pair_generator(x_train, y_train, batch_size)
        # setdefault lets callers override the computed step counts instead of hitting a duplicate-kwarg error.
        kwargs.setdefault('steps_per_epoch', self.__num_steps(len(x_train), batch_size))
        if validation_data is not None:
            x_test, y_test = validation_data
            kwargs['validation_data'] = self.__pair_generator(x_test, y_test, batch_size)
            kwargs.setdefault('validation_steps', self.__num_steps(len(x_test), batch_size))
        train = self.__model_method('fit_generator', 'fit')
        return train(train_generator, **kwargs)

    def __model_method(self, legacy_name, modern_name):
        """
        Look up a training/evaluation method on the underlying Keras model.

        The ``*_generator`` methods are deprecated in tf.keras and absent in Keras 3, where ``fit``/``evaluate``
        accept generators directly. The legacy name is preferred when present so older Keras versions behave as
        before.
        """
        method = getattr(self.siamese_model, legacy_name, None)
        return method if method is not None else getattr(self.siamese_model, modern_name)

    @staticmethod
    def __num_steps(num_samples, batch_size):
        """
        Number of generator batches per pass over ``num_samples`` samples: rounded up to an int, at least 1.
        """
        return max(int(np.ceil(num_samples / batch_size)), 1)

    def __initialize_siamese_model(self):
        """
        Create the siamese model structure using the supplied base and head model.
        """
        input_a = Input(shape=self.input_shape)
        input_b = Input(shape=self.input_shape)

        # The same base_model instance is applied to both inputs, so the two branches share weights.
        processed_a = self.base_model(input_a)
        processed_b = self.base_model(input_b)

        head = self.head_model([processed_a, processed_b])
        self.siamese_model = Model([input_a, input_b], head)

    def __create_pairs(self, x, class_indices, batch_size, num_classes):
        """
        Create a numpy array of positive and negative pairs and their associated labels.

        :param x: Input data
        :param class_indices: A python list of lists that contains each of the indices in the input data that belong
            to each class. It is used to find and access elements in the input data that belong to a desired class.
            * Example usage:
            * element_index = class_indices[class][index]
            * element = x[element_index]
        :param batch_size: The number of pair samples to create.
        :param num_classes: number of classes in the supplied input data
        :return: A tuple of (Numpy array of pairs, Numpy array of labels)
        """
        # Negatives take the remainder so an odd batch_size still yields exactly batch_size pairs.
        num_positive_pairs = int(batch_size) // 2
        num_negative_pairs = int(batch_size) - num_positive_pairs
        positive_pairs, positive_labels = self.__create_positive_pairs(x, class_indices, num_positive_pairs,
                                                                       num_classes)
        negative_pairs, negative_labels = self.__create_negative_pairs(x, class_indices, num_negative_pairs,
                                                                       num_classes)
        return np.array(positive_pairs + negative_pairs), np.array(positive_labels + negative_labels)

    def __create_positive_pairs(self, x, class_indices, num_positive_pairs, num_classes):
        """
        Create a list of positive pairs and labels. A positive pair is defined as two input samples of the same class.

        :param x: Input data
        :param class_indices: A python list of lists that contains each of the indices in the input data that belong
            to each class. It is used to find and access elements in the input data that belong to a desired class.
            * Example usage:
            * element_index = class_indices[class][index]
            * element = x[element_index]
        :param num_positive_pairs: The number of positive pair samples to create.
        :param num_classes: number of classes in the supplied input data
        :return: A tuple of (python list of positive pairs, python list of positive labels)
        :raises ValueError: If pairs are requested but no class has at least two samples.
        """
        positive_pairs = []
        positive_labels = []

        # Only classes with two or more samples can supply two distinct samples.
        eligible_classes = [c for c in range(num_classes) if len(class_indices[c]) >= 2]
        if int(num_positive_pairs) > 0 and not eligible_classes:
            raise ValueError('Cannot create positive pairs: no class has at least two samples.')

        for _ in range(int(num_positive_pairs)):
            members = class_indices[random.choice(eligible_classes)]
            index_1, index_2 = self.__randint_unequal(0, len(members) - 1)
            positive_pairs.append([x[members[index_1]], x[members[index_2]]])
            positive_labels.append([1.0])

        return positive_pairs, positive_labels

    def __create_negative_pairs(self, x, class_indices, num_negative_pairs, num_classes):
        """
        Create a list of negative pairs and labels. A negative pair is defined as two input samples of different class.

        :param x: Input data
        :param class_indices: A python list of lists that contains each of the indices in the input data that belong
            to each class. It is used to find and access elements in the input data that belong to a desired class.
            * Example usage:
            * element_index = class_indices[class][index]
            * element = x[element_index]
        :param num_negative_pairs: The number of negative pair samples to create.
        :param num_classes: number of classes in the supplied input data
        :return: A tuple of (python list of negative pairs, python list of negative labels)
        :raises ValueError: If pairs are requested but fewer than two classes have any samples.
        """
        negative_pairs = []
        negative_labels = []

        # Label ids with no samples (gaps in the labels) cannot be sampled from.
        populated_classes = [c for c in range(num_classes) if len(class_indices[c]) > 0]
        if int(num_negative_pairs) > 0 and len(populated_classes) < 2:
            raise ValueError('Cannot create negative pairs: at least two classes with samples are required.')

        for _ in range(int(num_negative_pairs)):
            cls_1, cls_2 = random.sample(populated_classes, 2)
            index_1 = random.randint(0, len(class_indices[cls_1]) - 1)
            index_2 = random.randint(0, len(class_indices[cls_2]) - 1)
            element_index_1, element_index_2 = class_indices[cls_1][index_1], class_indices[cls_2][index_2]
            negative_pairs.append([x[element_index_1], x[element_index_2]])
            negative_labels.append([0.0])

        return negative_pairs, negative_labels

    def __pair_generator(self, x, y, batch_size):
        """
        Creates a python generator that produces pairs from the original input data.

        :param x: Input data
        :param y: Integer class labels
        :param batch_size: The number of pair samples to create per batch.
        :return: An endless generator of ``([first_elements, second_elements], labels)`` batches.
        """
        class_indices, num_classes = self.__get_class_indices(y)
        while True:
            pairs, labels = self.__create_pairs(x, class_indices, batch_size, num_classes)
            # The siamese network expects two inputs and one output. Split the pairs into a list of inputs.
            yield [pairs[:, 0], pairs[:, 1]], labels

    def __get_class_indices(self, y):
        """
        Create a python list of lists that contains each of the indices in the input data that belong
        to each class. It is used to find and access elements in the input data that belong to a desired class.
            * Example usage:
            * element_index = class_indices[class][index]
            * element = x[element_index]

        :param y: Integer class labels (numpy array or any sequence; float-valued integer labels are accepted).
        :return: A tuple of (list of per-class index arrays, number of classes)
        """
        # asarray makes element-wise comparison work for plain lists; int() keeps range() happy for float labels.
        y = np.asarray(y)
        num_classes = int(np.max(y)) + 1
        return [np.where(y == i)[0] for i in range(num_classes)], num_classes

    @staticmethod
    def __randint_unequal(lower, upper):
        """
        Get two distinct random integers from the inclusive range [lower, upper].

        :param lower: Lower limit inclusive of the random integer.
        :param upper: Upper limit inclusive of the random integer. Need to use -1 for random indices.
        :return: Tuple of (integer, integer)
        :raises ValueError: If the range contains fewer than two integers (e.g. a class with a single sample).
        """
        int_1, int_2 = random.sample(range(lower, upper + 1), 2)
        return int_1, int_2