Skip to content

Commit

Permalink
Changed default shuffle behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
bill-lotter committed Aug 30, 2016
1 parent 07c3ce1 commit 8fb1b9d
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def __init__(self, data_file, source_file, nt,
curr_location += 1
self.possible_starts = possible_starts

self.possible_starts = np.random.permutation(self.possible_starts)
if shuffle:
self.possible_starts = np.random.permutation(self.possible_starts)
if N_seq is not None and len(self.possible_starts) > N_seq: # select a subset of sequences if want to
self.possible_starts = self.possible_starts[:N_seq]
self.N_sequences = len(self.possible_starts)
Expand Down
4 changes: 3 additions & 1 deletion kitti_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
layer_config['output_mode'] = 'prediction'
dim_ordering = layer_config['dim_ordering']
test_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config)
inputs = Input(shape=train_model.layers[0].batch_input_shape[1:])
input_shape = list(train_model.layers[0].batch_input_shape[1:])
input_shape[0] = nt
inputs = Input(shape=tuple(input_shape))
predictions = test_prednet(inputs)
test_model = Model(input=inputs, output=predictions)

Expand Down
2 changes: 1 addition & 1 deletion kitti_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
model = Model(input=inputs, output=final_errors)
model.compile(loss='mean_absolute_error', optimizer='adam')

train_generator = SequenceGenerator(train_file, train_sources, nt, batch_size=batch_size)
train_generator = SequenceGenerator(train_file, train_sources, nt, batch_size=batch_size, shuffle=True)
val_generator = SequenceGenerator(val_file, val_sources, nt, batch_size=batch_size, N_seq=N_seq_val)

lr_schedule = lambda epoch: 0.001 if epoch < 75 else 0.0001 # start with lr of 0.001 and then drop to 0.0001 after 75 epochs
Expand Down

0 comments on commit 8fb1b9d

Please sign in to comment.