Deepspeech CUDNN_STATUS_EXECUTION_FAILED #523
Comments
Padding gotchas are common with DeepSpeech. Will debug this with @sourabh2k15 tomorrow.
This error happens only when evaluating the padded final batch of the test set. That batch has a batch size of 168 before padding; it is padded to 256 and sharded over 8 devices.
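A quick check of the arithmetic, assuming padding rows are appended at the end of the batch and each device receives a contiguous slice of 256 / 8 = 32 examples: the 168 real examples fill five devices, a sixth device gets 8 real examples plus 24 padding rows, and the last two devices receive nothing but padding. The helper `real_examples_per_device` is hypothetical and not part of the codebase.

```python
def real_examples_per_device(num_real: int, padded_size: int, num_devices: int) -> list[int]:
    """Count non-padding examples on each device after contiguous sharding.

    Assumes (hypothetically) that padding rows are appended after the real
    examples and that the padded batch is split into equal contiguous chunks.
    """
    if padded_size % num_devices != 0:
        raise ValueError("padded_size must be divisible by num_devices")
    per_device = padded_size // num_devices
    counts = []
    for d in range(num_devices):
        start = d * per_device
        # Real examples occupy indices [0, num_real); clip to this device's slice.
        counts.append(max(0, min(num_real - start, per_device)))
    return counts


print(real_examples_per_device(168, 256, 8))  # [32, 32, 32, 32, 32, 8, 0, 0]
```

Devices 6 and 7 end up with zero real examples, which is the all-padding case discussed in this thread.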
Remove the global_batch_size argument from the call to shard_and_maybe_pad_np. As a result, the final batch of the LibriSpeech validation and test sets is padded only enough to split evenly across the devices, so no device batch consists entirely of padding. Workaround for #523.
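A minimal NumPy sketch contrasting the two padding strategies. It uses a hypothetical `pad_and_shard` helper, not the repository's actual shard_and_maybe_pad_np (whose signature and padding values are not shown here), and assumes zero-padding along the batch axis with a 1/0 mask marking real vs. padded examples. Padding to a global batch size of 256 leaves two of eight devices with only padding; padding to the next multiple of the device count adds nothing here, since 168 is already divisible by 8.

```python
import math
from typing import Optional

import numpy as np


def pad_and_shard(batch: np.ndarray, num_devices: int,
                  global_batch_size: Optional[int] = None):
    """Pad a batch along axis 0 and split it evenly across devices.

    If global_batch_size is given, pad up to it; otherwise pad only up to the
    next multiple of num_devices. Returns (sharded_batch, sharded_mask), where
    the mask is 1.0 for real examples and 0.0 for padding.
    """
    n = batch.shape[0]
    if global_batch_size is None:
        target = math.ceil(n / num_devices) * num_devices
    else:
        target = global_batch_size
    if target < n or target % num_devices != 0:
        raise ValueError("target size must be >= n and divisible by num_devices")
    pad = target - n
    # Zero-pad only the leading (batch) axis.
    pad_width = [(0, pad)] + [(0, 0)] * (batch.ndim - 1)
    padded = np.pad(batch, pad_width)
    mask = np.concatenate([np.ones(n), np.zeros(pad)])
    # Reshape to (num_devices, per_device_batch, ...).
    sharded = padded.reshape((num_devices, -1) + batch.shape[1:])
    return sharded, mask.reshape(num_devices, -1)


batch = np.ones((168, 10), dtype=np.float32)
old, old_mask = pad_and_shard(batch, 8, global_batch_size=256)
new, new_mask = pad_and_shard(batch, 8)
print(old.shape, old_mask.sum(axis=1).astype(int).tolist())  # (8, 32, 10) [32, 32, 32, 32, 32, 8, 0, 0]
print(new.shape, new_mask.sum(axis=1).astype(int).tolist())  # (8, 21, 10) [21, 21, 21, 21, 21, 21, 21, 21]
```

The trade-off is that per-device batch shapes now depend on the size of the last batch, so the final eval step may trigger an extra compilation, in exchange for never producing an all-padding shard.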
The issue is that cuDNN cannot handle zero-length sequences. The JAX team filed jax-ml/jax#17966. In the meantime, we will change the padding for the LibriSpeech eval and test sets.
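To illustrate how zero-length sequences arise, assume each batch carries a padding mask of shape (batch, time) with 1.0 on padded frames, from which per-example lengths are derived before the RNN call. This mask convention and the helper names below are assumptions for illustration. An example that is entirely padding, as every example in an all-padding device batch is, then has length 0. A hypothetical pre-flight check could look like this:

```python
import numpy as np


def sequence_lengths(paddings: np.ndarray) -> np.ndarray:
    """Per-example sequence lengths from a padding mask.

    Assumes paddings has shape (batch, time) with 1.0 marking padded frames
    and 0.0 marking real frames.
    """
    return (1.0 - paddings).sum(axis=-1).astype(np.int32)


def has_zero_length(paddings: np.ndarray) -> bool:
    """True if any example is entirely padding (length 0)."""
    return bool((sequence_lengths(paddings) == 0).any())


# Two real examples of lengths 3 and 5, plus one all-padding example.
paddings = np.array([
    [0, 0, 0, 1, 1],
    [0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1],
], dtype=np.float32)
print(sequence_lengths(paddings).tolist())  # [3, 5, 0]
print(has_zero_length(paddings))            # True
```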
Closing this, since the workaround makes it obsolete, although it is unclear whether the underlying JAX/cuDNN issue has been resolved.
Description
DeepSpeech fails with a CUDNN_STATUS_EXECUTION_FAILED error when calling cudnnRNNForward.
Traceback:
Steps to Reproduce
Git commit: ae3587d
Source or Possible Fix
The DeepSpeech regression test failed in #511; I wrongly assumed this was a transient issue.
I traced the change in behavior to fixes to our shard_and_maybe_pad_np function in #515.