Skip to content

Commit

Permalink
Zhijxu/fix softmax cudnn bf16 (#21045)
Browse files Browse the repository at this point in the history
When the sequence length exceeds 2048, ORT falls back to the cuDNN softmax implementation; with dtype bf16
that path threw an exception. This PR fixes it.
  • Loading branch information
zhijxu-MS authored Jun 24, 2024
1 parent 5b5ce0b commit 269d9b0
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 1 deletion.
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/cuda/cudnn_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ cudnnDataType_t CudnnTensor::GetDataType<half>() {

// BFloat16 specialization: yields CUDNN_DATA_BFLOAT16 when the build sees
// CUDNN_VERSION >= 8200; in every other configuration (macro undefined or an
// older version) it raises via ORT_THROW instead of returning a value.
template <>
cudnnDataType_t CudnnTensor::GetDataType<BFloat16>() {
#if !defined(CUDNN_VERSION) || CUDNN_VERSION < 8200
  ORT_THROW("cuDNN doesn't support BFloat16.");
#else
  return CUDNN_DATA_BFLOAT16;
#endif
}

template <>
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/common/random_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class RandomValueGenerator {
// Random values generated are in the range [min, max).
template <typename TFloat16>
typename std::enable_if<
std::is_same_v<TFloat16, MLFloat16>,
std::is_same_v<TFloat16, MLFloat16> || std::is_same_v<TFloat16, BFloat16>,
std::vector<TFloat16>>::type
Uniform(gsl::span<const int64_t> dims, float min, float max) {
std::vector<TFloat16> val(detail::SizeFromDims(dims));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,38 @@ def test_onnx_ops(self):
device = torch.device(device_name)
self.gradient_correctness(name, device)

@unittest.skipIf(
    # Check is_available() first: this condition runs at class-definition time,
    # and is_bf16_supported() may query the current CUDA device, which can raise
    # on CPU-only hosts with some PyTorch versions and break test collection.
    not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()),
    "Test requires CUDA and BF16 support",
)
def test_softmax_bf16_large(self):
    """Compare ORTModule and PyTorch input gradients for a bf16 softmax over 4096 elements.

    The last dimension (4096) is above the 2048 threshold that, per the commit
    description, makes ORT fall back to its cuDNN softmax path.
    """

    class Model(torch.nn.Module):
        def forward(self, input):
            return torch.softmax(input, dim=-1)

    device = "cuda:0"
    input_shape = [2, 4096]

    # Run torch to get the expected result.
    data_torch = torch.randn(size=input_shape, device=device, dtype=torch.bfloat16) + 10
    data_torch.requires_grad = True
    torch_model = Model()
    torch_res = torch_model(input=data_torch)
    init_grad = torch.ones_like(torch_res)
    torch_res.backward(gradient=init_grad)

    # Run ORT on an independent copy of the same input.
    ort_model = ORTModule(torch_model)
    data_ort = data_torch.detach().clone()
    data_ort.requires_grad = True
    ort_res = ort_model(input=data_ort)
    ort_res.backward(gradient=init_grad)

    # Compare results.
    # NOTE(review): with an all-ones upstream gradient the analytic softmax input
    # gradient is zero, so this mostly verifies that the bf16 path runs without
    # raising; consider also comparing the forward outputs.
    torch.testing.assert_close(data_torch.grad, data_ort.grad, rtol=1e-5, atol=1e-4)


# When this file is executed directly as a script, hand control to unittest's runner.
if __name__ == "__main__":
unittest.main()

0 comments on commit 269d9b0

Please sign in to comment.