-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrainer.py
95 lines (82 loc) · 2.87 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import numpy as np
import tensorflow as tf
def trainer(model: tf.keras.Model,
            loss_fn: tf.keras.losses,
            X_train: np.ndarray,
            y_train: np.ndarray = None,
            optimizer: tf.keras.optimizers = None,
            loss_fn_kwargs: dict = None,
            epochs: int = 100,
            batch_size: int = 1,
            buffer_size: int = 1024,
            shuffle: bool = False,
            verbose: bool = True) -> None:
    """
    Train TensorFlow model with a custom training loop.

    Parameters
    ----------
    model
        Model to train.
    loss_fn
        Loss function used for training. Called as
        ``loss_fn(ground_truth, preds, **loss_fn_kwargs)``; if the model
        returns multiple outputs they are unpacked as positional args.
    X_train
        Training batch.
    y_train
        Training labels. If None the model is treated as unsupervised
        (e.g. an autoencoder) and ``X_train`` doubles as the target.
    optimizer
        Optimizer used for training. Defaults to a fresh
        ``Adam(learning_rate=1e-3)`` per call. (Constructed inside the
        function rather than as a default argument so optimizer slot
        state is not shared across separate ``trainer`` calls.)
    loss_fn_kwargs
        Kwargs for loss function.
    epochs
        Number of training epochs.
    batch_size
        Batch size used for training.
    buffer_size
        Maximum number of elements that will be buffered when shuffling.
    shuffle
        Whether to shuffle training data.
    verbose
        Whether to print training progress.
    """
    # Default optimizer created per call: a def-time default instance would
    # leak Adam moment estimates between unrelated training runs.
    if optimizer is None:
        optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

    # create dataset
    if y_train is None:  # unsupervised model
        train_data = X_train
    else:
        train_data = (X_train, y_train)
    train_data = tf.data.Dataset.from_tensor_slices(train_data)
    if shuffle:
        train_data = train_data.shuffle(buffer_size=buffer_size)
    # Always batch. Previously batching happened only when shuffle=True,
    # so the default shuffle=False path iterated unbatched single samples
    # while n_minibatch below still assumed batching.
    train_data = train_data.batch(batch_size)
    n_minibatch = int(np.ceil(X_train.shape[0] / batch_size))

    # iterate over epochs
    for epoch in range(epochs):
        if verbose:
            pbar = tf.keras.utils.Progbar(n_minibatch, 1)

        # iterate over the batches of the dataset
        for step, train_batch in enumerate(train_data):
            if y_train is None:
                X_train_batch = train_batch
            else:
                X_train_batch, y_train_batch = train_batch

            with tf.GradientTape() as tape:
                preds = model(X_train_batch)

                # for unsupervised models the input is its own target
                if y_train is None:
                    ground_truth = X_train_batch
                else:
                    ground_truth = y_train_batch

                # compute loss; models returning multiple outputs (e.g. a
                # tuple) have each output passed as a positional argument
                if tf.is_tensor(preds):
                    args = [ground_truth, preds]
                else:
                    args = [ground_truth] + list(preds)

                if loss_fn_kwargs:
                    loss = loss_fn(*args, **loss_fn_kwargs)
                else:
                    loss = loss_fn(*args)

                if model.losses:  # additional model losses (e.g. regularizers)
                    loss += sum(model.losses)

            grads = tape.gradient(loss, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))

            if verbose:
                loss_val = loss.numpy().mean()
                pbar_values = [('loss', loss_val)]
                pbar.add(1, values=pbar_values)