
Conv2DKAN giving graph execution error for non-sequential models. #19

Open
MrigankMIDAS opened this issue Nov 4, 2024 · 7 comments


@MrigankMIDAS

Sir/ma'am
When fitting a non-sequential model (as in the case of multi-branch models), a graph execution error is raised.
In particular, no error occurs with the model declaration:
model = tf.keras.Sequential([
...
TimeDistributed(Conv2DKAN(16, (3, 3), padding='same')),
...
])

However, using the form
x=TimeDistributed(Conv2DKAN(32, (3, 3), padding='same'))(x)
raises the error.

Please look into the matter at your earliest convenience.
Thank you in advance for the help.
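For readers less familiar with the two Keras styles mentioned above, here is a minimal sketch of both declarations. It assumes tf.keras and swaps Conv2DKAN for the built-in Conv2D so that it runs without tfkan installed; INPUT_SHAPE, the helper names and the second branch are illustrative only. The functional form (calling each layer on a tensor) is what makes multi-branch models possible.

```python
import tensorflow as tf
from tensorflow.keras import layers

# Illustrative input: (time_steps, height, width, channels).
INPUT_SHAPE = (10, 32, 32, 3)

def sequential_form():
    # Layers are listed in order; Keras wires them together implicitly.
    return tf.keras.Sequential([
        layers.Input(shape=INPUT_SHAPE),
        layers.TimeDistributed(layers.Conv2D(16, (3, 3), padding="same")),
    ])

def functional_form():
    # Each layer is called on a tensor, so the graph can branch and merge.
    inputs = layers.Input(shape=INPUT_SHAPE)
    a = layers.TimeDistributed(layers.Conv2D(16, (3, 3), padding="same"))(inputs)
    b = layers.TimeDistributed(layers.Conv2D(16, (5, 5), padding="same"))(inputs)  # hypothetical second branch
    x = layers.Concatenate(axis=-1)([a, b])
    return tf.keras.Model(inputs=inputs, outputs=x)

x = tf.random.normal((2,) + INPUT_SHAPE)
print(tuple(sequential_form()(x).shape))  # (2, 10, 32, 32, 16)
print(tuple(functional_form()(x).shape))  # (2, 10, 32, 32, 32)
```

Both styles build the same kind of graph; the difference is only in how layers are connected, which is why an error that appears in one style but not the other often points to something else (here, as the thread later finds, memory).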

@ZPZhou-lab
Owner

I will check this error 😊

Could you please provide more information about this error (the full traceback) and explain what the TimeDistributed wrapper is for (i.e., the behavior of this module)?

@MrigankMIDAS
Author

MrigankMIDAS commented Nov 5, 2024

Please find below additional details about the error.
The model compilation code:

import tensorflow as tf
from tensorflow.keras.metrics import AUC

METRICS = ['accuracy', AUC(name='auroc')]
clip_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(label_smoothing=0.1),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=METRICS
)

The model fitting code:

early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', 
    patience=3, 
    restore_best_weights=True)

reduce_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    patience=2,
    factor=0.1,
    verbose=1)
CALLBACKS = [early_stopping_callback, reduce_lr_callback]
clip_model_history = clip_model.fit(
    train_ds,
    epochs = CFG.EPOCHS,
    validation_data=val_ds,
    callbacks=CALLBACKS)

Plotting and compiling the model complete without any error. However, the model-fitting cell raises the following error:

---------------------------------------------------------------------------
InternalError                             Traceback (most recent call last)
Cell In[44], line 6
      3 print(f'Train on {len(train_binary_df)} samples, validate on {len(val_binary_df)} samples.')
      4 print('----------------------------------')
----> 6 clip_model_history = clip_model.fit(
      7     train_ds,
      8     epochs = CFG.EPOCHS,
      9     validation_data=val_ds,
     10     callbacks=CALLBACKS)

File /opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/execute.py:53, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     51 try:
     52   ctx.ensure_initialized()
---> 53   tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     54                                       inputs, attrs, num_outputs)
     55 except core._NotOkStatusException as e:
     56   if name is not None:

InternalError: Graph execution error:

Detected at node model/time_distributed/conv2dkan/dense_kan/Less defined at (most recent call last):
  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main

  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code

  File "/opt/conda/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>

  File "/opt/conda/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance

  File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start

  File "/opt/conda/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start

  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever

  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once

  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run

  File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue

  File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 534, in process_one

  File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell

  File "/opt/conda/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 362, in execute_request

  File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 778, in execute_request

  File "/opt/conda/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 449, in do_execute

  File "/opt/conda/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell

  File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3051, in run_cell

  File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3106, in _run_cell

  File "/opt/conda/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner

  File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3311, in run_cell_async

  File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3493, in run_ast_nodes

  File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code

  File "/tmp/ipykernel_31/529747391.py", line 6, in <module>

  File "/opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/opt/conda/lib/python3.10/site-packages/keras/src/engine/training.py", line 1807, in fit

  File "/opt/conda/lib/python3.10/site-packages/keras/src/engine/training.py", line 1401, in train_function

  File "/opt/conda/lib/python3.10/site-packages/keras/src/engine/training.py", line 1384, in step_function

  File "/opt/conda/lib/python3.10/site-packages/keras/src/engine/training.py", line 1373, in run_step

  File "/opt/conda/lib/python3.10/site-packages/keras/src/engine/training.py", line 1150, in train_step

  File "/opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/opt/conda/lib/python3.10/site-packages/keras/src/engine/training.py", line 590, in __call__

  File "/opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/opt/conda/lib/python3.10/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/opt/conda/lib/python3.10/site-packages/keras/src/engine/functional.py", line 515, in call

  File "/opt/conda/lib/python3.10/site-packages/keras/src/engine/functional.py", line 672, in _run_internal_graph

  File "/opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/opt/conda/lib/python3.10/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/opt/conda/lib/python3.10/site-packages/keras/src/layers/rnn/time_distributed.py", line 246, in call

  File "/opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/opt/conda/lib/python3.10/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/opt/conda/lib/python3.10/site-packages/tfkan/layers/convolution.py", line 114, in call

  File "/opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/opt/conda/lib/python3.10/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/opt/conda/lib/python3.10/site-packages/tfkan/layers/dense.py", line 105, in call

  File "/opt/conda/lib/python3.10/site-packages/tfkan/layers/base.py", line 30, in calc_spline_output

  File "/opt/conda/lib/python3.10/site-packages/tfkan/ops/spline.py", line 26, in calc_spline_values

'cuLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, blockZ, 0, reinterpret_cast<CUstream>(stream), params, nullptr)' failed with 'CUDA_ERROR_INVALID_HANDLE'
	 [[{{node model/time_distributed/conv2dkan/dense_kan/Less}}]] [Op:__inference_train_function_27597]

Furthermore, training runs smoothly when the Conv2DKAN layers are replaced by Conv2D layers. Using the DenseKAN layer also does not cause an error in either Sequential or non-sequential models.

The TimeDistributed layer in Keras is a wrapper that applies a specified layer independently to each time step of a sequence of inputs. In my case, it lets me apply the same processing to every frame of the video sequence.
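A minimal sketch of that behavior, assuming tf.keras and using a plain Conv2D in place of Conv2DKAN: wrapping a layer in TimeDistributed gives the same result as folding the time axis into the batch axis, applying the same layer (same weights) once, and unfolding the result. The shapes are arbitrary. One consequence worth keeping in mind for this thread is that the inner layer's activations are produced for batch_size × time_steps frames.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

tf.random.set_seed(0)
conv = layers.Conv2D(8, (3, 3), padding="same")
wrapped = layers.TimeDistributed(conv)

# (batch, time_steps, height, width, channels)
frames = tf.random.normal((2, 5, 16, 16, 3))
out_wrapped = wrapped(frames)  # builds `conv`; its weights are shared across time steps

# Manual equivalent: merge batch and time axes, apply the same conv, split them back.
b, t = frames.shape[0], frames.shape[1]
merged = tf.reshape(frames, (b * t,) + tuple(frames.shape[2:]))
out_manual = tf.reshape(conv(merged), (b, t) + tuple(out_wrapped.shape[2:]))

print(tuple(out_wrapped.shape))  # (2, 5, 16, 16, 8)
print(np.allclose(out_wrapped.numpy(), out_manual.numpy(), atol=1e-5))  # True
```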

Once again, thank you for your support 😊

@ZPZhou-lab
Owner

Let's check the input shape together.

From my understanding of your description, your input should be a 5-dimensional tensor of shape (batch_size, time_steps, width, height, num_channels). I tested this case, and the code works fine:

import tensorflow as tf
from tfkan.layers import Conv2DKAN
from keras.layers import TimeDistributed

layer = TimeDistributed(Conv2DKAN(64, (3, 3), padding='same'))
x = tf.random.normal((2, 10, 32, 32, 3)) # x is a 5D tensor
y = layer(x)
print(y.shape) # (2, 10, 32, 32, 64)

Letting me know which input tensor you use would help us find the problem!

@MrigankMIDAS
Author

MrigankMIDAS commented Nov 5, 2024

The input shape for the example code is as follows:

(batch_size, time_steps, width, height, num_channels) = (16, 10, 224, 224, 3)

The following is the condensed representation of the model:

import tensorflow as tf
from tensorflow.keras.layers import (ActivityRegularization, Dense, Dropout, Flatten,
                                     Input, LSTM, MaxPooling2D, TimeDistributed)
from tensorflow.keras.models import Model
from tfkan.layers import Conv2DKAN, DenseKAN

def kan_single_branch_non_sequential(do=0.005, rl1=0, rl2=0):
    # input for RGB frames
    input_rgb = Input(shape=(10, 224, 224, 3), dtype=tf.float32, name='rgb_input')

    x_rgb = TimeDistributed(Conv2DKAN(16, (3, 3), padding='same'))(input_rgb)
    x_rgb = TimeDistributed(MaxPooling2D((4, 4)))(x_rgb)
    x_rgb = TimeDistributed(Dropout(do))(x_rgb)
    #...
    x_rgb = TimeDistributed(Flatten())(x_rgb)
    x = x_rgb

    x = LSTM(32)(x)

    x = DenseKAN(32)(x)
    x = ActivityRegularization(l1=rl1, l2=rl2)(x)
    output = Dense(6, activation='softmax')(x)

    # model definition
    model = Model(inputs=[input_rgb], outputs=output)
    return model

@ZPZhou-lab
Owner

I test your code and it works fine in my env as below:

import tensorflow as tf
from tfkan.layers import Conv2DKAN, DenseKAN
from keras.layers import TimeDistributed, Input, MaxPooling2D, Dropout, Flatten, LSTM, Dense
from tensorflow.keras.models import Model

def kan_single_branch_non_sequential(do=0.005, rl1=0, rl2=0):
    # input for RGB frames
    input_rgb = Input(shape=(10, 224, 224, 3),dtype=tf.float32, name='rgb_input')
    
    x_rgb = TimeDistributed(Conv2DKAN(16, (3, 3), padding='same'))(input_rgb)
    x_rgb = TimeDistributed(MaxPooling2D((4, 4)))(x_rgb)
    x_rgb = TimeDistributed(Dropout(do))(x_rgb)
    
    # ...
    
    x_rgb = TimeDistributed(Flatten())(x_rgb)
    x = x_rgb
                               
    x = LSTM(32)(x)               
    x = DenseKAN(32)(x)
    # x=ActivityRegularization(l1=rl1,l2=rl2)(x)
    output = Dense(6, activation='softmax')(x)
    
    # model definition
    model = Model(inputs=[input_rgb], outputs=output)
    return model

# build and compile
model = kan_single_branch_non_sequential()
model.build(input_shape=(None, 10, 224, 224, 3))
model.compile(
    optimizer='adam', 
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# generate mock data
bz = 1
x = tf.random.normal((bz, 10, 224, 224, 3))
y = tf.random.uniform(shape=(bz,), maxval=6, dtype=tf.int32)

# fit the model
model.fit(x, y, batch_size=bz, epochs=1)

and the model trains successfully:

1/1 [==============================] - 1s 1s/step - loss: 1.8257 - accuracy: 0.0000e+00

Maybe the graph execution error is caused by running out of GPU memory (OOM)? Or by the TensorFlow and Keras versions?

You might try reducing batch_size or (time_steps, width, height) to test whether the code runs successfully on a tiny dataset. Hope this helps! 🤗
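Following that suggestion, here is a minimal sketch of a smoke test that trains one epoch on random data at increasing frame sizes. build_model and try_fit are hypothetical helpers, and the model is a small stand-in (Conv2D and GlobalAveragePooling2D instead of Conv2DKAN and Flatten) so that it runs without tfkan; you would substitute your own layers. It catches InternalError as well as ResourceExhaustedError because the failure in this thread surfaced as an InternalError rather than an explicit out-of-memory error.

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_model(input_shape, num_classes=6):
    # Hypothetical stand-in model; swap Conv2D for Conv2DKAN to test the real layer.
    inputs = layers.Input(shape=input_shape)
    x = layers.TimeDistributed(layers.Conv2D(16, (3, 3), padding="same"))(inputs)
    x = layers.TimeDistributed(layers.GlobalAveragePooling2D())(x)
    x = layers.LSTM(32)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs)

def try_fit(batch_size, time_steps, side, channels=3, num_classes=6):
    """Train one epoch on random data; return the loss, or None if the step fails."""
    model = build_model((time_steps, side, side, channels), num_classes)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
    x = tf.random.normal((batch_size, time_steps, side, side, channels))
    y = tf.random.uniform((batch_size,), maxval=num_classes, dtype=tf.int32)
    try:
        history = model.fit(x, y, batch_size=batch_size, epochs=1, verbose=0)
    except (tf.errors.ResourceExhaustedError, tf.errors.InternalError) as err:
        # OOM on GPU does not always surface as ResourceExhaustedError.
        print(f"side={side}: failed with {type(err).__name__}")
        return None
    return history.history["loss"][0]

# Start tiny and grow the frame size until something breaks.
for side in (16, 32, 64):
    loss = try_fit(batch_size=2, time_steps=4, side=side)
    print(side, "ok" if loss is not None else "failed")
```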

@MrigankMIDAS
Copy link
Author

Thank you very much for your valuable input! The graph execution error was indeed due to OOM. I was able to fix it by reducing the frame size from 224x224 to 64x64. I had been struggling to pinpoint the issue for weeks. Thank you very much for your time, effort and interest 😊.
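A rough back-of-envelope estimate of why reducing the frame size helped. It assumes float32 activations, the batch size of 16, 10 time steps and 16 filters from this thread, and (as the conv2dkan/dense_kan node in the traceback suggests, but not confirmed here) that Conv2DKAN feeds one 3×3×3 patch per pixel into a DenseKAN. tensor_mib and first_layer_mib are hypothetical helpers. Only two tensors of the first layer are counted; spline intermediates, gradients and optimizer state are ignored, so the real peak is higher. The point is the scaling: these tensors grow with the square of the frame side, so going from 224 to 64 shrinks them by (224/64)² = 12.25×.

```python
def tensor_mib(shape, bytes_per_elem=4):
    """Size in MiB of a dense tensor (float32 by default) with the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_elem / 2**20

def first_layer_mib(batch, steps, side, in_ch=3, kernel=3, filters=16):
    frames = batch * steps  # TimeDistributed processes batch * time_steps frames
    # Assumption: one kernel*kernel*in_ch patch vector per pixel is fed to the dense KAN.
    patches = tensor_mib((frames, side, side, kernel * kernel * in_ch))
    output = tensor_mib((frames, side, side, filters))
    return patches, output

for side in (224, 64):
    patches, output = first_layer_mib(batch=16, steps=10, side=side)
    print(f"{side}x{side}: patches {patches:.1f} MiB, output {output:.1f} MiB")
# 224x224: patches 826.9 MiB, output 490.0 MiB
# 64x64: patches 67.5 MiB, output 40.0 MiB
```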

@ZPZhou-lab
Owner

You're welcome, best 😄
