From fcb3c276e0e1af94303013cbe809cac5905b1327 Mon Sep 17 00:00:00 2001 From: Nicolas Date: Sat, 26 Oct 2024 19:33:05 +0200 Subject: [PATCH] Jax tracing fix (#20412) * `JAXTrainer`: refactoring and fixes Fix for https://github.com/keras-team/keras/issues/20402 Fix for https://github.com/keras-team/keras/issues/20411 * CI setup * Fix tests * Revert CI branch to master * `function` -> `iterator_step` --- keras/src/backend/jax/trainer.py | 366 +++++++++--------- .../data_adapters/generator_data_adapter.py | 10 +- keras/src/trainers/trainer.py | 7 +- keras/src/trainers/trainer_test.py | 173 ++++++++- 4 files changed, 372 insertions(+), 184 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index f1d1cad2dee..7b9523ef5fb 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -1,6 +1,5 @@ import collections import itertools -from functools import partial import jax import numpy as np @@ -224,108 +223,109 @@ def predict_step(self, state, data): ) return outputs, non_trainable_variables - def make_train_function(self, force=False): - if self.train_function is not None and not force: - return + def _make_function(self, step_function, concatenate_outputs=False): + if self.steps_per_execution > 1: + if concatenate_outputs: + + def concatenate(outputs): + output = outputs[0] + for next_output in outputs[1:]: + output = tree.map_structure( + lambda t1, t2: jax.numpy.concatenate([t1, t2]), + output, + next_output, + ) + return output + + if not self.run_eagerly and self.jit_compile: + concatenate = jax.jit(concatenate) + + def iterator_step(state, iterator): + data = next(iterator) + outputs, state = step_function(state, data) + outputs = [outputs] + try: + for _ in range(self.steps_per_execution - 1): + data = next(iterator) + _outputs, state = step_function(state, data) + outputs.append(_outputs) + except StopIteration: + pass + outputs = concatenate(outputs) + return outputs, state - def one_train_step(state, data): - data = data[0] - return self.train_step(state, data) + else: - def multi_train_steps(state, data): - for single_step_data in data: - logs, state = one_train_step(state, [single_step_data]) - return logs, state + def iterator_step(state, iterator): + data = next(iterator) + outputs, state = step_function(state, data) + try: + for _ in range(self.steps_per_execution - 1): + data = next(iterator) + outputs, state = step_function(state, data) + except StopIteration: + pass + return outputs, state - if self.steps_per_execution > 1: - train_step = multi_train_steps else: - train_step = one_train_step + def iterator_step(state, iterator): + return step_function(state, next(iterator)) + + return iterator_step + + def make_train_function(self, force=False): + if self.train_function is not None and not force: + return if not self.run_eagerly and self.jit_compile: - # Note that we mark the state and data to be donated to jax, + # Note that we mark the state to be donated to jax, # so that jax will reuse the memory buffer for outputs. # This will reduce the memory usage of the training function by # half. - @partial(jax.jit, donate_argnames="state") - def compiled_train_step(state, data): - return train_step(state, data) + train_step = jax.jit(self.train_step, donate_argnums=0) + else: + train_step = self.train_step - self.train_function = compiled_train_step + step_function = self._make_function(train_step) - else: - self.train_function = train_step + self.train_function = step_function def make_test_function(self, force=False): if self.test_function is not None and not force: return - - def one_test_step(state, data): - data = data[0] - return self.test_step(state, data) - - def multi_test_steps(state, data): - for single_step_data in data: - logs, state = one_test_step(state, [single_step_data]) - return logs, state - - if self.steps_per_execution > 1: - test_step = multi_test_steps - else: - test_step = one_test_step - if not self.run_eagerly and self.jit_compile: - # Note that we mark the state and data to be donated to jax, + # Note that we mark the state to be donated to jax, # so that jax will reuse the memory buffer for outputs. # This will reduce the memory usage of the training function by # half. - @partial(jax.jit, donate_argnames="state") - def compiled_test_step(state, data): - return test_step(state, data) + test_step = jax.jit(self.test_step, donate_argnums=0) + else: + test_step = self.test_step - self.test_function = compiled_test_step + step_function = self._make_function(test_step) - else: - self.test_function = test_step + self.test_function = step_function def make_predict_function(self, force=False): if self.predict_function is not None and not force: return self.predict_function - def one_predict_step(state, data): - data = data[0] - return self.predict_step(state, data) - - def multi_predict_steps(state, data): - outputs, trainable_variables = one_predict_step(state, data[:1]) - - for single_step_data in data[1:]: - step_outputs, trainable_variables = one_predict_step( - state, - [single_step_data], - ) - outputs = tree.map_structure( - lambda t1, t2: jax.numpy.concatenate([t1, t2]), - outputs, - step_outputs, - ) - return outputs, trainable_variables - - if self.steps_per_execution > 1: - predict_step = multi_predict_steps - else: - predict_step = one_predict_step + def predict_step(state, data): + outputs, non_trainable_variables = self.predict_step(state, data) + return outputs, (state[0], non_trainable_variables) if not self.run_eagerly and self.jit_compile: + predict_step = jax.jit(predict_step) - @jax.jit - def compiled_predict_step(state, data): - return predict_step(state, data) + _step_function = self._make_function( + predict_step, concatenate_outputs=True + ) - self.predict_function = compiled_predict_step + def step_function(state, iterator): + outputs, state = _step_function(state, iterator) + return outputs, state[1] - else: - self.predict_function = predict_step + self.predict_function = step_function @traceback_utils.filter_traceback def fit( @@ -406,45 +406,46 @@ def fit( callbacks.on_epoch_begin(epoch) self._jax_state_synced = True - for step, data in epoch_iterator: - # Callbacks - callbacks.on_train_batch_begin(step) - - # Train step - if self._jax_state_synced: - # The state may have been synced by a callback. - state = self._get_jax_state( - trainable_variables=True, - non_trainable_variables=True, - optimizer_variables=True, - metrics_variables=True, - purge_model_variables=True, - ) - self._jax_state_synced = False - - logs, state = self.train_function(state, data) - ( - trainable_variables, - non_trainable_variables, - optimizer_variables, - metrics_variables, - ) = state - - # Setting _jax_state enables callbacks to force a state sync - # if they need to. - self._jax_state = { - "trainable_variables": trainable_variables, - "non_trainable_variables": non_trainable_variables, - "optimizer_variables": optimizer_variables, - "metrics_variables": metrics_variables, - } - # Dispatch callbacks. This takes care of async dispatch. - callbacks.on_train_batch_end(step, logs) - - if self.stop_training: - # Stop training if a callback has set - # this flag in on_(train_)batch_end. - break + with epoch_iterator.catch_stop_iteration(): + for step, iterator in epoch_iterator: + # Callbacks + callbacks.on_train_batch_begin(step) + + # Train step + if self._jax_state_synced: + # The state may have been synced by a callback. + state = self._get_jax_state( + trainable_variables=True, + non_trainable_variables=True, + optimizer_variables=True, + metrics_variables=True, + purge_model_variables=True, + ) + self._jax_state_synced = False + + logs, state = self.train_function(state, iterator) + ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + metrics_variables, + ) = state + + # Setting _jax_state enables callbacks to force a state sync + # if they need to. + self._jax_state = { + "trainable_variables": trainable_variables, + "non_trainable_variables": non_trainable_variables, + "optimizer_variables": optimizer_variables, + "metrics_variables": metrics_variables, + } + # Dispatch callbacks. This takes care of async dispatch. + callbacks.on_train_batch_end(step, logs) + + if self.stop_training: + # Stop training if a callback has set + # this flag in on_(train_)batch_end. + break # Reattach state to the model (if not already done by a callback). # NOTE: doing this after each step would be a big performance @@ -562,41 +563,42 @@ def evaluate( self.reset_metrics() self._jax_state_synced = True - for step, data in epoch_iterator: - callbacks.on_test_batch_begin(step) - - if self._jax_state_synced: - # The state may have been synced by a callback. - state = self._get_jax_state( - trainable_variables=True, - non_trainable_variables=True, - metrics_variables=True, - purge_model_variables=True, - ) - self._jax_state_synced = False + with epoch_iterator.catch_stop_iteration(): + for step, iterator in epoch_iterator: + callbacks.on_test_batch_begin(step) - logs, state = self.test_function(state, data) - ( - trainable_variables, - non_trainable_variables, - metrics_variables, - ) = state - - # Setting _jax_state enables callbacks to force a state sync - # if they need to. - self._jax_state = { - # I wouldn't recommend modifying non-trainable model state - # during evaluate(), but it's allowed. - "trainable_variables": trainable_variables, - "non_trainable_variables": non_trainable_variables, - "metrics_variables": metrics_variables, - } - - # Dispatch callbacks. This takes care of async dispatch. - callbacks.on_test_batch_end(step, logs) - - if self.stop_evaluating: - break + if self._jax_state_synced: + # The state may have been synced by a callback. + state = self._get_jax_state( + trainable_variables=True, + non_trainable_variables=True, + metrics_variables=True, + purge_model_variables=True, + ) + self._jax_state_synced = False + + logs, state = self.test_function(state, iterator) + ( + trainable_variables, + non_trainable_variables, + metrics_variables, + ) = state + + # Setting _jax_state enables callbacks to force a state sync + # if they need to. + self._jax_state = { + # I wouldn't recommend modifying non-trainable model state + # during evaluate(), but it's allowed. + "trainable_variables": trainable_variables, + "non_trainable_variables": non_trainable_variables, + "metrics_variables": metrics_variables, + } + + # Dispatch callbacks. This takes care of async dispatch. + callbacks.on_test_batch_end(step, logs) + + if self.stop_evaluating: + break # Reattach state back to model (if not already done by a callback). self.jax_state_sync() @@ -627,13 +629,15 @@ def predict( if not all(layer.built for layer in self._flatten_layers()): # Build the model on one batch of data. - for _, data in epoch_iterator: + for _, iterator in epoch_iterator: # Build model - x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data[0]) + x, _, _ = data_adapter_utils.unpack_x_y_sample_weight( + next(iterator) + ) with backend.StatelessScope(): self(x) break - + epoch_iterator.reset() # Container that configures and calls callbacks. if not isinstance(callbacks, callbacks_module.CallbackList): callbacks = callbacks_module.CallbackList( @@ -669,28 +673,29 @@ def append_to_outputs(batch_outputs, outputs): self._jax_state_synced = True outputs = None non_trainable_variables = None - for step, x in epoch_iterator: - callbacks.on_predict_batch_begin(step) - if self._jax_state_synced: - # The state may have been synced by a callback. - state = self._get_jax_state( - trainable_variables=True, - non_trainable_variables=True, + with epoch_iterator.catch_stop_iteration(): + for step, iterator in epoch_iterator: + callbacks.on_predict_batch_begin(step) + if self._jax_state_synced: + # The state may have been synced by a callback. + state = self._get_jax_state( + trainable_variables=True, + non_trainable_variables=True, + ) + self._purge_model_variables(non_trainable_variables=True) + self._jax_state_synced = False + else: + state = (state[0], non_trainable_variables) + batch_outputs, non_trainable_variables = self.predict_function( + state, iterator ) - self._purge_model_variables(non_trainable_variables=True) - self._jax_state_synced = False - else: - state = (state[0], non_trainable_variables) - batch_outputs, non_trainable_variables = self.predict_function( - state, x - ) - outputs = append_to_outputs(batch_outputs, outputs) + outputs = append_to_outputs(batch_outputs, outputs) - # Dispatch callbacks. This takes care of async dispatch. - callbacks.on_predict_batch_end(step, {"outputs": batch_outputs}) + # Dispatch callbacks. This takes care of async dispatch. + callbacks.on_predict_batch_end(step, {"outputs": batch_outputs}) - if self.stop_predicting: - break + if self.stop_predicting: + break self._jax_state = { # I wouldn't recommend modifying non-trainable model state @@ -722,11 +727,12 @@ def train_on_batch( sample_weight = data_adapter_utils.class_weight_to_sample_weights( y, class_weight ) - data = (x, y, sample_weight) - data = _distribute_data(data) + + def data(): + yield _distribute_data((x, y, sample_weight)) # Maybe build model - self._symbolic_build(data_batch=data) + self._symbolic_build(data_batch=next(data())) self._record_training_state_sharding_spec() self.make_train_function() @@ -739,7 +745,7 @@ def train_on_batch( purge_model_variables=False, ) self._jax_state_synced = False - logs, state = self.train_function(state, [data]) + logs, state = self.train_function(state, data()) # State sync ( @@ -771,10 +777,11 @@ def test_on_batch( ): self._assert_compile_called("test_on_batch") - data = (x, y, sample_weight) - data = _distribute_data(data) + def data(): + yield _distribute_data((x, y, sample_weight)) + # Maybe build model - self._symbolic_build(data_batch=data) + self._symbolic_build(data_batch=next(data())) self._record_training_state_sharding_spec() self.make_test_function() @@ -786,7 +793,7 @@ def test_on_batch( purge_model_variables=False, ) self._jax_state_synced = False - logs, state = self.test_function(state, [data]) + logs, state = self.test_function(state, data()) # State sync trainable_variables, non_trainable_variables, metrics_variables = state @@ -818,8 +825,12 @@ def predict_on_batch(self, x): purge_model_variables=False, ) self._jax_state_synced = False + + def data(): + yield (x,) + batch_outputs, non_trainable_variables = self.predict_function( - state, [(x,)] + state, data() ) self._jax_state = { "non_trainable_variables": non_trainable_variables, @@ -992,6 +1003,9 @@ def _distribute_data(data, layouts=None): class JAXEpochIterator(EpochIterator): + def __next__(self): + return next(self._epoch_iterator) + def _get_iterator(self): distribution = distribution_lib.distribution() if distribution is not None: diff --git a/keras/src/trainers/data_adapters/generator_data_adapter.py b/keras/src/trainers/data_adapters/generator_data_adapter.py index 7f241838842..50603e99c7d 100644 --- a/keras/src/trainers/data_adapters/generator_data_adapter.py +++ b/keras/src/trainers/data_adapters/generator_data_adapter.py @@ -23,10 +23,10 @@ def __init__(self, generator): ) def get_numpy_iterator(self): - return data_adapter_utils.get_numpy_iterator(self.generator) + return data_adapter_utils.get_numpy_iterator(self.generator()) def get_jax_iterator(self): - return data_adapter_utils.get_jax_iterator(self.generator) + return data_adapter_utils.get_jax_iterator(self.generator()) def get_tf_dataset(self): from keras.src.utils.module_utils import tensorflow as tf @@ -49,7 +49,7 @@ def convert_to_tf(x, spec): return x def get_tf_iterator(): - for batch in self.generator: + for batch in self.generator(): batch = tree.map_structure( convert_to_tf, batch, self._output_signature ) @@ -67,7 +67,7 @@ def get_tf_iterator(): return ds def get_torch_dataloader(self): - return data_adapter_utils.get_torch_dataloader(self.generator) + return data_adapter_utils.get_torch_dataloader(self.generator()) @property def num_batches(self): @@ -84,4 +84,4 @@ def peek_and_restore(generator): generator, data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC ) ) - return batches, itertools.chain(batches, generator) + return batches, lambda: itertools.chain(batches, generator) diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py index 0775c4f8628..ad0bc14b9d7 100644 --- a/keras/src/trainers/trainer.py +++ b/keras/src/trainers/trainer.py @@ -1050,8 +1050,11 @@ def to_symbolic_input(v): ) if data_batch is None: - for _, data in iterator: - data_batch = data[0] + for _, data_or_iterator in iterator: + if isinstance(data_or_iterator, (list, tuple)): + data_batch = data_or_iterator[0] + else: + data_batch = next(data_or_iterator) break data_batch = tree.map_structure(to_symbolic_input, data_batch) ( diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index 02e7ac365fc..d1ad16a4c57 100644 --- a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -655,7 +655,11 @@ def test_predict_sparse(self, generator_type, mode): jit_compile=False, ) dataset = sparse_generator(generator_type) - model.predict(dataset) + dataset_size = sum( + [batch[1].shape[0] for batch in sparse_generator(generator_type)] + ) + y = model.predict(dataset) + self.assertEqual(len(y), dataset_size) @pytest.mark.skipif( backend.backend() != "jax", @@ -782,6 +786,100 @@ def test_steps_per_execution_steps_count(self, steps_per_execution, mode): ) self.assertAllClose(model.evaluate(x, y), model_2.evaluate(x, y)) + @parameterized.named_parameters( + named_product( + steps_per_execution=[1, 50], mode=["eager", "non_jit", "jit"] + ) + ) + def test_predict_preserve_order(self, steps_per_execution, mode): + if steps_per_execution > 1 and backend.backend() == "torch": + self.skipTest("`steps_per_execution` not implemented for torch yet") + + def generate_uneven_batches(): + batch_sizes = [2, 3, 4] + + def gen_i(): + for i in range(100): + yield i + + iterator = iter(gen_i()) + j = 0 + while True: + batch_size = batch_sizes[j % len(batch_sizes)] + try: + batch = np.array( + [next(iterator) for _ in range(batch_size)] + ) + except StopIteration: + break + j += 1 + yield batch + + from keras.src.utils.module_utils import tensorflow as tf + + dataset = tf.data.Dataset.from_generator( + generate_uneven_batches, + output_signature=tf.TensorSpec((None,), dtype=tf.int32), + ) + x = keras.layers.Input(shape=()) + y = keras.layers.Identity()(x) + model = keras.Model(x, y) + model.compile( + loss="mse", + optimizer="sgd", + steps_per_execution=steps_per_execution, + run_eagerly=(mode == "eager"), + jit_compile=(mode == "jit"), + ) + + preds = model.predict(x=dataset, verbose=0) + + self.assertAllEqual(preds, np.arange(len(preds), dtype=np.float32)) + + @parameterized.named_parameters( + named_product( + steps_per_execution=[1, 50], mode=["eager", "non_jit", "jit"] + ) + ) + def test_predict_generator(self, steps_per_execution, mode): + if steps_per_execution > 1 and backend.backend() == "torch": + self.skipTest("`steps_per_execution` not implemented for torch yet") + + batch_size = 2 + + def generate_batches(): + def gen_i(): + for i in range(10): + yield i + + iterator = iter(gen_i()) + j = 0 + while True: + try: + batch = np.array( + [next(iterator) for _ in range(batch_size)] + ) + except StopIteration: + break + j += 1 + yield (batch,) + + model = keras.Sequential( + [keras.layers.InputLayer(shape=()), keras.layers.Identity()] + ) + model.compile( + loss="mse", + optimizer="sgd", + steps_per_execution=steps_per_execution, + run_eagerly=(mode == "eager"), + jit_compile=(mode == "jit"), + ) + + preds = model.predict(x=generate_batches(), verbose=0) + self.assertAllEqual( + preds, np.concatenate(list(generate_batches()), axis=1)[0] + ) + @parameterized.named_parameters( named_product( steps_per_execution=[3, 101], mode=["eager", "non_jit", "jit"] @@ -2277,6 +2375,79 @@ def test_jit_compile_with_tf_determinism(self): self.assertFalse(model.jit_compile) disable_op_determinism() + @pytest.mark.requires_trainable_backend + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`steps_per_execution` not implemented for torch yet", + ) + def test_retracing(self): + x = np.ones((100, 4)) + y = np.ones((100, 1)) + + input = keras.Input(shape=[4]) + output = keras.layers.Dense(1, activation="relu")(input) + + tracing_count = [0] + + class TracingCounterModel(keras.Model): + def train_step(self, *args): + tracing_count[0] = tracing_count[0] + 1 + return super().train_step(*args) + + model = TracingCounterModel(inputs=input, outputs=output) + model.compile( + loss="mse", + optimizer="adam", + steps_per_execution=20, + ) + + epochs = 1 + model.fit( + x=x, + y=y, + batch_size=1, + epochs=epochs, + verbose=0, + ) + self.assertLessEqual(tracing_count[0], 2) + + @pytest.mark.requires_trainable_backend + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`steps_per_execution` not implemented for torch yet", + ) + @pytest.mark.skipif( + backend.backend() == "tensorflow", + reason="`predict_function` with `steps_per_execution` is not " + "optimized for tensorflow yet", + ) + def test_retracing_predict(self): + x = np.ones((100, 4)) + + input = keras.Input(shape=[4]) + output = keras.layers.Dense(1, activation="relu")(input) + + tracing_count = [0] + + class TracingCounterModel(keras.Model): + def predict_step(self, *args): + tracing_count[0] = tracing_count[0] + 1 + return super().predict_step(*args) + + model = TracingCounterModel(inputs=input, outputs=output) + model.compile( + loss="mse", + optimizer="adam", + steps_per_execution=20, + ) + + model.predict( + x=x, + batch_size=1, + verbose=0, + ) + self.assertLessEqual(tracing_count[0], 2) + class TrainerDistributeTest(testing.TestCase): @pytest.mark.skipif(