diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc index d0df85842ed2c3..195e1161a3aa3a 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -96,7 +96,8 @@ xla::StatusOr GetBlasComputationType( case DataType::kFloat: // fall-through if (lhs_dtype == DataType::kHalf) return f16_comp; if (lhs_dtype == DataType::kBF16) return bf16_comp; - return ComputationType::kF32; + return TF32_Enabled() ? ComputationType::kTF32AsF32 + : ComputationType::kF32; case DataType::kComplexFloat: return ComputationType::kF32; case DataType::kDouble: // fall-through