NotFoundError: Graph execution error: TPU #1370

Open · innat opened this issue Mar 10, 2024 · 8 comments
Labels: bug (bug & failures with existing packages), help wanted

Comments

innat commented Mar 10, 2024

While trying to run the following code on a TPU VM, it didn't work.

tf: 2.15
keras: 3.0.5

import tensorflow as tf
import keras

tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
strategy = tf.distribute.TPUStrategy(tpu)

def get_compiled_model():
    # Make a simple 2-layer densely-connected neural network.
    inputs = keras.Input(shape=(784,))
    x = keras.layers.Dense(256, activation="relu")(inputs)
    x = keras.layers.Dense(256, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return model


def get_dataset():
    batch_size = 32
    num_val_samples = 10000

    # Return the MNIST dataset in the form of a tf.data.Dataset.
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Preprocess the data (these are Numpy arrays)
    x_train = x_train.reshape(-1, 784).astype("float32") / 255
    x_test = x_test.reshape(-1, 784).astype("float32") / 255
    y_train = y_train.astype("float32")
    y_test = y_test.astype("float32")

    # Reserve num_val_samples samples for validation
    x_val = x_train[-num_val_samples:]
    y_val = y_train[-num_val_samples:]
    x_train = x_train[:-num_val_samples]
    y_train = y_train[:-num_val_samples]
    return (
        tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
    )

with strategy.scope():
    model_ = get_compiled_model()
    train_dataset, val_dataset, test_dataset = get_dataset()

model_.fit(train_dataset, epochs=2, validation_data=val_dataset)
---------------------------------------------------------------------------
NotFoundError                             Traceback (most recent call last)
Cell In[5], line 1
----> 1 model_.fit(train_dataset, epochs=2, validation_data=val_dataset)

File /usr/local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:123, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    120     filtered_tb = _process_traceback_frames(e.__traceback__)
    121     # To get the full stack trace, call:
    122     # `keras.config.disable_traceback_filtering()`
--> 123     raise e.with_traceback(filtered_tb) from None
    124 finally:
    125     del filtered_tb

File /usr/local/lib/python3.10/site-packages/tensorflow/python/eager/execute.py:53, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     51 try:
     52   ctx.ensure_initialized()
---> 53   tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     54                                       inputs, attrs, num_outputs)
     55 except core._NotOkStatusException as e:
     56   if name is not None:

NotFoundError: Graph execution error:

Detected at node TPUReplicate/_compile/_9074053372847989778/_4 defined at (most recent call last):
<stack traces unavailable>
innat added the bug (bug & failures with existing packages) and help wanted labels on Mar 10, 2024

innat commented Mar 13, 2024

Hi @djherbis, could you please provide any information regarding this issue? Are there any blockers to using the TPU VM at the moment?

djherbis (Contributor) commented

@innat Could you share a public notebook with the complete code? That makes it a bit easier to debug, thanks!


innat commented Mar 13, 2024

@djherbis Thanks for the quick response. Here is the gist.

djherbis (Contributor) commented

Hey, have you confirmed that Keras is using TensorFlow under the hood?
I took a quick try at this: I switched to tf-cpu, removed the TPU VM and TensorFlow-related code, and switched the Keras backend to JAX, and then I think it works?
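For reference, a minimal sketch of that backend switch, assuming a Keras 3 install; the environment variable has to be set before keras is imported for the first time:

import os
os.environ["KERAS_BACKEND"] = "jax"  # must run before the first `import keras`

import keras
print(keras.backend.backend())  # expected to print "jax"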


innat commented Mar 14, 2024

I don't fully get your point. However, I was able to run Keras with all backends (TF, Torch, JAX) on CPU and GPU. But as shown in the gist above, it didn't work on the TPU VM.

I have run the above gist again with both the keras+tensorflow and keras+jax setups for TPU, and both fail to run the program.

djherbis (Contributor) commented

I meant that when I ran it with JAX, without TensorFlow, on the TPU VM, it worked:
https://www.kaggle.com/code/herbison/keras-jax-tpu-vm-model-build-test

It's not too uncommon for something to work on CPU/GPU but not on TPU, since the actual underlying systems are different.

If possible, using the JAX example might be a path forward.
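A hedged sketch of that JAX path on a TPU VM, assuming tensorflow-cpu is installed so TensorFlow never claims the TPU, and using the distribution API already shown later in this thread:

import os
os.environ["KERAS_BACKEND"] = "jax"  # select the JAX backend before importing keras

import jax
import keras

devices = jax.devices("tpu")  # should list the TPU cores
data_parallel = keras.distribution.DataParallel(devices=devices)
keras.distribution.set_distribution(data_parallel)

# The model/dataset code from the original report can then run unchanged,
# with no TPUClusterResolver or TPUStrategy:
# model_ = get_compiled_model()
# train_dataset, val_dataset, test_dataset = get_dataset()
# model_.fit(train_dataset, epochs=2, validation_data=val_dataset)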


innat commented Mar 15, 2024

Ah, I see.

I also tried the following without installing tf-cpu; it didn't work though.

import tensorflow as tf

tf.config.set_visible_devices([], "TPU")

import keras, jax
devices = jax.devices("tpu")
data_parallel = keras.distribution.DataParallel(devices=devices)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[5], line 4
      1 import keras, jax
----> 4 data_parallel = keras.distribution.DataParallel(devices=devices)
      5 keras.distribution.set_distribution(data_parallel)

File /usr/local/lib/python3.10/site-packages/keras/src/distribution/distribution_lib.py:400, in DataParallel.__init__(self, device_mesh, devices)
    398 self._batch_dim_name = self.device_mesh.axis_names[0]
    399 # Those following attributes might get convert to public methods.
--> 400 self._num_process = distribution_lib.num_processes()
    401 self._process_id = distribution_lib.process_id()
    402 self._is_multi_process = self._num_process > 1

AttributeError: module 'keras.src.backend.tensorflow.distribution_lib' has no attribute 'num_processes'

djherbis (Contributor) commented

Yeah, it's impossible to use a tensorflow (TPU) install together with JAX or PyTorch, and since Keras is calling TensorFlow here, that loads the TPU twice (once for JAX, once for TensorFlow), which breaks things.

Installing tensorflow-cpu and then using JAX (TPU) works, though.
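A hedged sketch of that working setup; the package commands are assumptions about a typical TPU VM, not a confirmed recipe from the linked notebook:

# pip uninstall -y tensorflow   # drop the TPU-enabled TensorFlow build
# pip install tensorflow-cpu    # keep tf.data and friends on CPU only

import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import jax
import tensorflow as tf
import keras

print(tf.config.list_logical_devices("TPU"))  # expected: [] with tensorflow-cpu
print(jax.devices("tpu"))                     # the TPU cores, visible to JAX only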
