diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc
index 9aa011c1d0ec4..914dc02a9eda4 100644
--- a/onnxruntime/core/providers/cuda/cudnn_common.cc
+++ b/onnxruntime/core/providers/cuda/cudnn_common.cc
@@ -174,7 +174,11 @@ cudnnDataType_t CudnnTensor::GetDataType<half>() {
 
 template <>
 cudnnDataType_t CudnnTensor::GetDataType<BFloat16>() {
+#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8200
+  return CUDNN_DATA_BFLOAT16;
+#else
   ORT_THROW("cuDNN doesn't support BFloat16.");
+#endif
 }
 
 template <>
diff --git a/onnxruntime/test/common/random_generator.h b/onnxruntime/test/common/random_generator.h
index cb1ce885d2d45..336e0f197fcc9 100644
--- a/onnxruntime/test/common/random_generator.h
+++ b/onnxruntime/test/common/random_generator.h
@@ -70,7 +70,7 @@ class RandomValueGenerator {
   // Random values generated are in the range [min, max).
   template <typename TFloat16>
   typename std::enable_if<
-      std::is_same_v<TFloat16, MLFloat16>,
+      std::is_same_v<TFloat16, MLFloat16> || std::is_same_v<TFloat16, BFloat16>,
       std::vector<TFloat16>>::type
   Uniform(gsl::span<const int64_t> dims, float min, float max) {
     std::vector<TFloat16> val(detail::SizeFromDims(dims));
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
index df0b5f195f0b9..88735ff18515e 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
@@ -146,6 +146,38 @@ def test_onnx_ops(self):
                 device = torch.device(device_name)
                 self.gradient_correctness(name, device)
 
+    @unittest.skipIf(not torch.cuda.is_bf16_supported(), "Test requires CUDA and BF16 support")
+    def test_softmax_bf16_large(self):
+        if not torch.cuda.is_available():
+            # only test bf16 on cuda
+            return
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                out = torch.softmax(input, dim=-1)
+                return out
+
+        device = "cuda:0"
+        input_shape = [2, 4096]
+        # run torch to get the expected result
+        data_torch = torch.randn(size=input_shape, device=device, dtype=torch.bfloat16) + 10
+        data_torch.requires_grad = True
+        torch_model = Model()
+        torch_res = torch_model(input=data_torch)
+        init_grad = torch.ones_like(torch_res)
+        torch_res.backward(gradient=init_grad)
+        # run ort
+        ort_model = ORTModule(torch_model)
+        data_ort = data_torch.detach().clone()
+        data_ort.requires_grad = True
+        ort_res = ort_model(input=data_ort)
+        ort_res.backward(gradient=init_grad)
+        # compare result
+        torch.testing.assert_close(data_torch.grad, data_ort.grad, rtol=1e-5, atol=1e-4)
+
 
 if __name__ == "__main__":
     unittest.main()