trainer.py
import tensorflow as tf
import tensorflow_transform as tft
from absl import logging
from tensorflow.keras import layers, Model, optimizers, metrics
from tfx import v1 as tfx
from tfx_bsl.public import tfxio
from typing import List, Text

import constants

LABEL = constants.LABEL
BATCH_SIZE = 32
EPOCHS = 50


def _input_fn(
    file_pattern: List[Text],
    data_accessor: tfx.components.DataAccessor,
    tf_transform_output: tft.TFTransformOutput,
    batch_size: int,
) -> tf.data.Dataset:
    """Generates a dataset of features that can be used to train
    and evaluate the model.

    Args:
        file_pattern: List of paths or patterns of input data files.
        data_accessor: An instance of DataAccessor that we can use to
            convert the input to a RecordBatch.
        tf_transform_output: The transformation output.
        batch_size: The number of consecutive elements that we should
            combine in a single batch.

    Returns:
        A dataset that contains a tuple of (features, indices) where
        features is a dictionary of Tensors, and indices is a single
        Tensor of label indices.
    """
    dataset = data_accessor.tf_dataset_factory(
        file_pattern,
        tfxio.TensorFlowDatasetOptions(batch_size=batch_size),
        schema=tf_transform_output.raw_metadata.schema,
    )

    tft_layer = tf_transform_output.transform_features_layer()

    def apply_transform(raw_features):
        transformed_features = tft_layer(raw_features)
        transformed_label = transformed_features.pop(LABEL)
        return transformed_features, transformed_label

    return dataset.map(apply_transform).repeat()
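

# The helper below is a hedged illustration, not part of the original module:
# it shows the structure `_input_fn` yields, a (features dict, label tensor)
# tuple, without consuming more than one batch. It is defined but never called.
def _peek_batch_sketch(dataset: tf.data.Dataset) -> None:
    """Logs the structure of one batch produced by `_input_fn`. Because
    `_input_fn` ends with `.repeat()`, the dataset is infinite, so training
    must be bounded by `steps_per_epoch` rather than by dataset exhaustion.
    """
    features, labels = next(iter(dataset))
    for name, tensor in features.items():
        logging.info(
            "feature %s: shape=%s dtype=%s", name, tensor.shape, tensor.dtype
        )
    logging.info("labels: shape=%s dtype=%s", labels.shape, labels.dtype)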


def _get_serve_tf_examples_fn(model, tf_transform_output):
    """Returns a function that parses a serialized tf.Example and applies
    the transformations during inference.

    Args:
        model: The model that we are serving.
        tf_transform_output: The transformation output that we want to
            include with the model.
    """
    # Attach the transform layer to the model so it is tracked by the
    # SavedModel and exported together with the model's weights.
    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples")
        ]
    )
    def serve_tf_examples_fn(serialized_tf_examples):
        # Parse every raw feature except the label, which is not available
        # at serving time.
        feature_spec = tf_transform_output.raw_feature_spec()
        required_feature_spec = {
            k: v for k, v in feature_spec.items() if k != LABEL
        }
        parsed_features = tf.io.parse_example(
            serialized_tf_examples, required_feature_spec
        )
        transformed_features = model.tft_layer(parsed_features)
        return model(transformed_features)

    return serve_tf_examples_fn
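

# A hedged illustration, not part of the original module: how the exported
# "serving_default" signature could be invoked on the saved model. The feature
# dtypes and values below are assumptions; the real ones come from the raw
# feature spec produced by the Transform component. Defined but never called.
def _call_serving_signature_sketch(serving_model_dir: Text):
    loaded = tf.saved_model.load(serving_model_dir)
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "Age": tf.train.Feature(
                    float_list=tf.train.FloatList(value=[42.0])),
                "EstimatedSalary": tf.train.Feature(
                    float_list=tf.train.FloatList(value=[50000.0])),
                "Gender": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[b"Male"])),
            }
        )
    )
    serialized = tf.constant([example.SerializeToString()])
    # The keyword `examples` matches the TensorSpec name in the signature.
    return loaded.signatures["serving_default"](examples=serialized)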


def _model() -> tf.keras.Model:
    inputs = [
        layers.Input(shape=(1,), name="Age"),
        layers.Input(shape=(1,), name="EstimatedSalary"),
        layers.Input(shape=(1,), name="Gender"),
    ]

    x = layers.concatenate(inputs)
    x = layers.Dense(8, activation="relu")(x)
    x = layers.Dense(8, activation="relu")(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=optimizers.Adam(1e-2),
        loss="binary_crossentropy",
        metrics=[metrics.BinaryAccuracy()],
    )
    model.summary(print_fn=logging.info)
    return model


def run_fn(fn_args: tfx.components.FnArgs):
    """The callback function that will be called by the Trainer component
    to train the model using the supplied arguments.

    Args:
        fn_args: A collection of name/value pairs representing the
            arguments to train the model.
    """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = _input_fn(
        fn_args.train_files,
        fn_args.data_accessor,
        tf_transform_output,
        batch_size=BATCH_SIZE,
    )
    eval_dataset = _input_fn(
        fn_args.eval_files,
        fn_args.data_accessor,
        tf_transform_output,
        batch_size=BATCH_SIZE,
    )

    model = _model()
    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        epochs=EPOCHS,
    )

    # We need to modify the default signature to include the transform layer
    # in the computational graph.
    signatures = {
        "serving_default": _get_serve_tf_examples_fn(model, tf_transform_output),
    }
    model.save(fn_args.serving_model_dir, save_format="tf", signatures=signatures)
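

# A hedged sketch of how this module file is typically wired into a TFX
# pipeline. The upstream components (`example_gen`, `transform`) and the
# step counts are assumptions about the surrounding pipeline, not part of
# this file; it is kept as a comment so TFX can import the module cleanly.
#
#     trainer = tfx.components.Trainer(
#         module_file="trainer.py",
#         examples=example_gen.outputs["examples"],
#         transform_graph=transform.outputs["transform_graph"],
#         train_args=tfx.proto.TrainArgs(num_steps=100),
#         eval_args=tfx.proto.EvalArgs(num_steps=50),
#     )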