diff --git a/tensorflow/compiler/xla/stream_executor/stream.cc b/tensorflow/compiler/xla/stream_executor/stream.cc
index 097f6203c79b89..d269590b1a7696 100644
--- a/tensorflow/compiler/xla/stream_executor/stream.cc
+++ b/tensorflow/compiler/xla/stream_executor/stream.cc
@@ -1477,13 +1477,6 @@ tsl::Status Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
                                  ConstantType beta, DeviceMemory<InputType> *c,
                                  int ldc, blas::ComputePrecision precision,
                                  blas::CallContext context) {
-  if(gpu::GpuBlasLtEnabled()) {
-    auto& r = gpu::BlasLtGemmRunner::i(this);
-    CheckStatus(r.Run(*this, transa, transb, m, n, k,
-          alpha, a, lda, b, ldb, beta, c, ldc,
-          /* allocator */nullptr)); //! NOTE: allocator is not available!!
-    return ::tsl::OkStatus();
-  }
   static_assert(
       detail::is_any_of<InputType, float, Eigen::half, Eigen::bfloat16, double,
                         std::complex<float>, std::complex<double>>(),
@@ -1500,6 +1493,15 @@ tsl::Status Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
   static_assert(
       detail::is_any_of<InputType, Eigen::half, ConstantType>(),
       "If input is not Eigen::half, constant and input types have to match");
+
+  if(gpu::GpuBlasLtEnabled()) {
+    auto& r = gpu::BlasLtGemmRunner::i(this);
+    CheckStatus(r.Run(*this, transa, transb, m, n, k,
+          alpha, a, lda, b, ldb, beta, c, ldc,
+          /* allocator */nullptr)); //! NOTE: allocator is not available!!
+    return ::tsl::OkStatus();
+  }
+
   blas::BlasSupport *blas = parent()->AsBlas();
   if (!blas) {
     return tsl::errors::Internal(
         "Attempting to perform BLAS operation using "
         "StreamExecutor without BLAS support");
@@ -1540,6 +1542,72 @@ INSTANTIATE_THEN_BLAS_GEMM(Eigen::bfloat16, float)
 
 #undef INSTANTIATE_THEN_BLAS_GEMM
 
+template <typename InputType, typename ConstantType>
+tsl::Status Stream::ThenBlasGemmStridedBatched(
+    blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
+    uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
+    int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
+    int64_t stride_b, ConstantType beta, DeviceMemory<InputType> *c, int ldc,
+    int64_t stride_c, int batch_count, blas::ComputePrecision precision,
+    blas::CallContext context) {
+  static_assert(
+      detail::is_any_of<InputType, float, Eigen::half, Eigen::bfloat16, double,
+                        std::complex<float>, std::complex<double>>(),
+      "Unsupported input type");
+  static_assert(
+      std::is_same_v<ConstantType, InputType> ||
+          (detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16>() &&
+           std::is_same_v<ConstantType, float>),
+      "Mismatched input and alpha/beta types");
+
+  if (gpu::GpuBlasLtEnabled()) {
+    auto &r = gpu::BlasLtGemmRunner::i(this);
+    CheckStatus(r.RunStridedBatched(
+        *this, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb,
+        stride_b, beta, c, ldc, stride_c, batch_count,
+        /* allocator */ nullptr)); //! NOTE: allocator is not available!!
+    return ::tsl::OkStatus();
+  }
+
+  blas::BlasSupport *blas = parent()->AsBlas();
+  if (!blas) {
+    return tsl::errors::Internal(
+        "Attempting to perform BLAS operation using "
+        "StreamExecutor without BLAS support");
+  }
+
+  void *alpha_ptr = &alpha;
+  void *beta_ptr = &beta;
+  float alpha_storage, beta_storage;
+  UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
+                                  &beta_storage);
+
+  return blas->DoBlasGemmStridedBatched(
+      this, transa, transb, m, n, k, blas::ToDataType<InputType>::value,
+      alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc,
+      stride_c, batch_count, precision, context);
+}
+
+#define INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(INPUT_TYPE, CONSTANT_TYPE) \
+  template tsl::Status Stream::ThenBlasGemmStridedBatched<INPUT_TYPE,         \
+                                                          CONSTANT_TYPE>(     \
+      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,   \
+      uint64_t k, CONSTANT_TYPE alpha, const DeviceMemory<INPUT_TYPE> &a,     \
+      int lda, int64_t stride_a, const DeviceMemory<INPUT_TYPE> &b, int ldb,  \
+      int64_t stride_b, CONSTANT_TYPE beta, DeviceMemory<INPUT_TYPE> *c,      \
+      int ldc, int64_t stride_c, int batch_count,                             \
+      blas::ComputePrecision precision, blas::CallContext context);
+
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(float, float)
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(double, double)
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(Eigen::half, Eigen::half)
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(Eigen::bfloat16, Eigen::bfloat16)
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(std::complex<float>, std::complex<float>)
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(std::complex<double>, std::complex<double>)
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(Eigen::half, float)
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(Eigen::bfloat16, float)
+
+#undef INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED
+
 namespace {
 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
 // blas::ProfileResult*. This functor doesn't put the stream into an error
diff --git a/tensorflow/compiler/xla/stream_executor/stream.h b/tensorflow/compiler/xla/stream_executor/stream.h
index d3d30bc0257752..7b9175b53d854c 100644
--- a/tensorflow/compiler/xla/stream_executor/stream.h
+++ b/tensorflow/compiler/xla/stream_executor/stream.h
@@ -1093,34 +1093,7 @@ class Stream {
       int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
       int64_t stride_b, ConstantType beta, DeviceMemory<InputType> *c, int ldc,
       int64_t stride_c, int batch_count, blas::ComputePrecision precision,
-      blas::CallContext context) {
-    static_assert(
-        detail::is_any_of<InputType, float, Eigen::half, Eigen::bfloat16,
-                          double, std::complex<float>, std::complex<double>>(),
-        "Unsupported input type");
-    static_assert(
-        std::is_same_v<ConstantType, InputType> ||
-            (detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16>() &&
-             std::is_same_v<ConstantType, float>),
-        "Mismatched input and alpha/beta types");
-    blas::BlasSupport *blas = parent()->AsBlas();
-    if (!blas) {
-      return tsl::errors::Internal(
-          "Attempting to perform BLAS operation using "
-          "StreamExecutor without BLAS support");
-    }
-
-    void *alpha_ptr = &alpha;
-    void *beta_ptr = &beta;
-    float alpha_storage, beta_storage;
-    UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
-                                    &beta_storage);
-
-    return blas->DoBlasGemmStridedBatched(
-        this, transa, transb, m, n, k, blas::ToDataType<InputType>::value,
-        alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc,
-        stride_c, batch_count, precision, context);
-  }
+      blas::CallContext context);
 
   // See BlasSupport::DoBlasTrsm.
   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,