
Deepspeech CUDNN_STATUS_EXECUTION_FAILED #523

Closed
priyakasimbeg opened this issue Sep 28, 2023 · 4 comments
Labels
🚀 Launch Blocker Issues that are blocking launch of benchmark P1 Launch 2023 High priority issues for October 2023 AlgoPerf Launch

Comments

@priyakasimbeg
Contributor

Deepspeech returns a CUDNN_STATUS_EXECUTION_FAILED error when calling cudnnRNNForward.

Description

Traceback:

I0922 23:02:03.070308 140024268789568 spec.py:333] Evaluating on the validation split.
I0922 23:02:03.270553 140024268789568 input_pipeline.py:20] Loading split = dev-clean
I0922 23:02:03.274869 140024268789568 input_pipeline.py:20] Loading split = dev-other
I0922 23:03:08.161050 140024268789568 spec.py:349] Evaluating on the test split.
I0922 23:03:08.366027 140024268789568 input_pipeline.py:20] Loading split = test-clean
2023-09-22 23:03:16.187352: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2469] Execution of replica 6 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: jaxlib/gpu/rnn_kernels.cc:256: operation cudnnRNNForward( handle.get(), rnn_desc, fwdMode, (const int32_t*)seq_lengths_buf, input_data_desc, input_buf, output_data_desc, output_buf, h_desc, h_0_buf, h_n_buf, c_desc, c_0_buf, c_n_buf, weight_space_size, weights_buf, d.workspace_size, workspace_buf, d.reserve_space_size, reserve_space_buf) failed: CUDNN_STATUS_EXECUTION_FAILED.

Steps to Reproduce

Git commit: ae3587d

python3 submission_runner.py --framework=jax --workload=librispeech_deepspeech --submission_path=baselines/adamw/jax/submission.py --tuning_search_space=baselines/adamw/tuning_search_space.json --data_dir=/data/librispeech --num_tuning_trials=1 --experiment_dir=/experiment_runs --experiment_name=tests/regression_tests/adamw --overwrite=True --save_checkpoints=False --max_global_steps=10 --librispeech_tokenizer_vocab_path=/data/librispeech/spm_model.vocab 2

Source or Possible Fix

The deepspeech regression test failed in #511; I initially assumed this was a transient issue.
I traced the change in behavior to the fixes to our shard_and_maybe_pad_np function in #515.
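
For context, a minimal sketch of the pad-to-global-batch-size behavior under discussion. The name shard_and_maybe_pad_np comes from the repo, but the body below is an illustrative assumption, not the actual implementation:

```python
import numpy as np

def shard_and_maybe_pad_np(inputs, global_batch_size, num_devices=8):
  """Illustrative sketch: pad up to global_batch_size, then shard per device."""
  pad_amount = global_batch_size - inputs.shape[0]
  if pad_amount > 0:
    # Padding examples are all zeros, so their sequence lengths are 0 as well.
    padding = np.zeros((pad_amount,) + inputs.shape[1:], dtype=inputs.dtype)
    inputs = np.concatenate([inputs, padding], axis=0)
  # Reshape to (num_devices, per_device_batch, ...) for pmap-style sharding.
  return inputs.reshape((num_devices, -1) + inputs.shape[1:])
```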

@priyakasimbeg
Contributor Author

Padding gotchas are common with deepspeech. Will debug this with @sourabh2k15 tomorrow.

@priyakasimbeg priyakasimbeg added the 🚀 Launch Blocker Issues that are blocking launch of benchmark label Sep 28, 2023
@priyakasimbeg
Contributor Author

This error only happens during evaluation of the test set, on its final padded batch. That batch has a batch size of 168 (pre-padding) and gets padded to 256 and sharded over 8 devices.
The final batch of the validation set has a batch size of 228 (pre-padding) and does not trigger the error.
The error also only occurs with the CuDNNLSTM layer, not with the legacy jax LSTM layer.
I suspect there is a numerical issue in the CuDNNLSTM layer when a sharded device batch consists entirely of constant padding values.
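
To make the sharding arithmetic concrete (a worked example, assuming the padded examples carry sequence length 0, as noted in a later comment):

```python
# Final test batch: 168 real examples, padded to a global batch size of 256
# and sharded over 8 devices.
real, padded, num_devices = 168, 256, 8
per_device = padded // num_devices                  # 32 examples per device
all_padding_shards = (padded - real) // per_device  # 88 // 32 = 2 shards
# Two device shards contain nothing but padding (length-0 sequences), which is
# what trips CUDNN_STATUS_EXECUTION_FAILED inside cudnnRNNForward.
# Validation: 228 real examples padded to 256 leaves only 28 padding examples,
# so every shard of 32 keeps some real sequences and no error occurs.
```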

@priyakasimbeg priyakasimbeg added the P1 Launch 2023 High priority issues for October 2023 AlgoPerf Launch label Oct 3, 2023
priyakasimbeg added a commit that referenced this issue Oct 5, 2023
Remove the global_batch_size arg in the call to shard_and_maybe_pad_np. This results in the final batch of the librispeech validation and test sets being padded just enough to be split evenly across the devices, so we will not have device batches containing only padding.
Workaround for #523.
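
A hedged sketch of what the commit's workaround amounts to: pad the final evaluation batch only up to the next multiple of the device count rather than to the full global batch size (helper name and body are illustrative assumptions, not the merged code):

```python
import numpy as np

def pad_to_device_multiple(inputs, num_devices=8):
  """Pad just enough that the batch splits evenly across devices."""
  remainder = inputs.shape[0] % num_devices
  if remainder:
    pad_amount = num_devices - remainder
    padding = np.zeros((pad_amount,) + inputs.shape[1:], dtype=inputs.dtype)
    inputs = np.concatenate([inputs, padding], axis=0)
  return inputs.reshape((num_devices, -1) + inputs.shape[1:])

# The 168-example final test batch is already divisible by 8, so no padding is
# added and every device shard of 21 examples contains only real sequences.
```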
@priyakasimbeg
Contributor Author

The issue is that cuDNN cannot handle length-0 sequences. The JAX team filed jax-ml/jax#17966, and in the meantime we will change the padding for the librispeech validation and test sets.
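
Until jax-ml/jax#17966 is resolved, another possible guard (an assumption on my part, not the merged fix, which changes the padding instead) would be to clamp sequence lengths so nothing of length 0 ever reaches the cuDNN kernel, relying on the eval metrics to mask out the padded examples downstream:

```python
import jax.numpy as jnp

def clamp_zero_lengths(seq_lengths):
  """Give all-padding examples a dummy length of 1 before the cuDNN RNN call.

  Assumes their outputs are masked out by the eval metrics afterwards, so the
  clamp would not change the reported numbers.
  """
  return jnp.maximum(seq_lengths, 1)
```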

@priyakasimbeg
Contributor Author

Closing this since it is obsolete with the workaround in place, although it is unclear whether the underlying jax / cuDNN issue has been resolved.
