Skip to content

Commit

Permalink
Remove the autotuner from the GPU BlasLt GEMM runner so that it compiles successfully
Browse files Browse the repository at this point in the history
  • Loading branch information
ScXfjiang committed Oct 29, 2024
1 parent 3f7e238 commit ddfd0d5
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 103 deletions.
4 changes: 3 additions & 1 deletion tensorflow/compiler/xla/stream_executor/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@ cc_library(
"//tensorflow/compiler/xla:autotune_results_proto_cc",
# "//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/stream_executor:scratch_allocator",
"//tensorflow/compiler/xla/service/gpu:autotuner_util",
# "//tensorflow/compiler/xla/service/gpu:autotuner_util",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/core/lib/gtl:array_slice",
"//tensorflow/core/util:env_var",
":gpu_blas_lt",
]),
)
Expand Down
192 changes: 96 additions & 96 deletions tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ limitations under the License.

#include "tensorflow/core/util/env_var.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/service/gpu/autotuner_util.h"
// #include "tensorflow/compiler/xla/service/gpu/autotuner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h"
Expand All @@ -27,7 +27,7 @@ limitations under the License.
namespace stream_executor {
namespace gpu {

bool BlasLtGemmRunner::autotune_enabled_ = true;
bool BlasLtGemmRunner::autotune_enabled_ = false;

bool operator ==(const GroupedGemmConfig& rhs, const GroupedGemmConfig& lhs) {
return AsTuple(rhs) == AsTuple(lhs);
Expand All @@ -50,10 +50,10 @@ std::ostream& operator <<(std::ostream& os, const StridedGemmConfig& cfg) {
}

BlasLtGemmRunner::BlasLtGemmRunner(StreamExecutor *parent) :
mutex_(std::make_unique< absl::Mutex >()),
autotune_config_(std::make_unique< xla::gpu::AutotuneConfig >(
xla::gpu::DeviceConfig{parent, nullptr},
xla::GetDebugOptionsFromFlags()))
mutex_(std::make_unique< absl::Mutex >())
// autotune_config_(std::make_unique< xla::gpu::AutotuneConfig >(
// xla::gpu::DeviceConfig{parent, nullptr},
// xla::GetDebugOptionsFromFlags()))
{ }

BlasLtGemmRunner::~BlasLtGemmRunner() { }
Expand All @@ -67,48 +67,48 @@ BlasLtGemmRunner::~BlasLtGemmRunner() { }
size_t dev_id = stream->parent()->device_ordinal();
if (dev_id >= meta.size()) meta.resize(dev_id + 1);
auto& res = meta[dev_id];
if (!res) {
autotune_enabled_ = xla::GetDebugOptionsFromFlags().xla_gpu_autotune_level() > 0;
res.reset(new BlasLtGemmRunner(stream->parent()));
xla::gpu::AutotunerUtil::LoadAutotuneResultsFromFileOnce(*res->autotune_config_);
}
// if (!res) {
// autotune_enabled_ = xla::GetDebugOptionsFromFlags().xla_gpu_autotune_level() > 0;
// res.reset(new BlasLtGemmRunner(stream->parent()));
// xla::gpu::AutotunerUtil::LoadAutotuneResultsFromFileOnce(*res->autotune_config_);
// }
return *res;
}

template < class TuneFunc >
xla::StatusOr< gpu::BlasLt::MatmulAlgorithm > BlasLtGemmRunner::Autotune(
const std::vector< gpu::BlasLt::MatmulAlgorithm >& algorithms,
TuneFunc&& benchmark_func) {
gpu::BlasLt::MatmulAlgorithm best_algo;
float best_ms = std::numeric_limits< float >::max(), total_ms = 0;
uint32_t n_warmups = 1, n_iters = 5, n_total = n_warmups + n_iters, i = 0;
// template < class TuneFunc >
// xla::StatusOr< gpu::BlasLt::MatmulAlgorithm > BlasLtGemmRunner::Autotune(
// const std::vector< gpu::BlasLt::MatmulAlgorithm >& algorithms,
// TuneFunc&& benchmark_func) {
// gpu::BlasLt::MatmulAlgorithm best_algo;
// float best_ms = std::numeric_limits< float >::max(), total_ms = 0;
// uint32_t n_warmups = 1, n_iters = 5, n_total = n_warmups + n_iters, i = 0;

for (uint32_t j = 0; j < algorithms.size(); j++) {
const auto& algo = algorithms[j];
if (!benchmark_func(algo, nullptr).ok()) continue;
// for (uint32_t j = 0; j < algorithms.size(); j++) {
// const auto& algo = algorithms[j];
// if (!benchmark_func(algo, nullptr).ok()) continue;

blas::ProfileResult profile;
for (i = 0, total_ms = 0; i < n_total; i++) {
auto res = benchmark_func(algo, &profile);
if (!res.ok() || !profile.is_valid()) {
VLOG(1) << j << ": gemm algorithm is not valid: " /* << res.error_message() */;
break;
}
if (i >= n_warmups) total_ms += profile.elapsed_time_in_ms();
}
if (i < n_total) continue; // invalid algorithm
total_ms /= n_iters;
VLOG(2) << j << ": gemm algorithm " << profile.algorithm() << " took "
<< total_ms << "ms, workspace: " << algo.workspace_size;
if (total_ms < best_ms) {
best_ms = total_ms, best_algo = algo;
}
} // for algorithms
if (!best_algo.opaque_algo.has_value()) {
return xla::InternalError("No valid gemm algorithms found!");
}
return best_algo;
}
// blas::ProfileResult profile;
// for (i = 0, total_ms = 0; i < n_total; i++) {
// auto res = benchmark_func(algo, &profile);
// if (!res.ok() || !profile.is_valid()) {
// VLOG(1) << j << ": gemm algorithm is not valid: " /* << res.error_message() */;
// break;
// }
// if (i >= n_warmups) total_ms += profile.elapsed_time_in_ms();
// }
// if (i < n_total) continue; // invalid algorithm
// total_ms /= n_iters;
// VLOG(2) << j << ": gemm algorithm " << profile.algorithm() << " took "
// << total_ms << "ms, workspace: " << algo.workspace_size;
// if (total_ms < best_ms) {
// best_ms = total_ms, best_algo = algo;
// }
// } // for algorithms
// if (!best_algo.opaque_algo.has_value()) {
// return xla::InternalError("No valid gemm algorithms found!");
// }
// return best_algo;
// }

xla::StatusOr< std::array< uint64_t, 3 >> BlasLtGemmRunner::ContiguousStrides(
const ArraySlice<DeviceMemoryBase *>& a,
Expand Down Expand Up @@ -188,13 +188,13 @@ xla::Status BlasLtGemmRunner::RunBatchedImpl(Stream& stream,
if (algorithms.empty()) return xla::InternalError("No GG algorithms found!");
best_algo = algorithms[0]; // otherwise use default algorithm
} else {
TF_ASSIGN_OR_RETURN(auto best_algo, Autotune(algorithms,
[&](const gpu::BlasLt::MatmulAlgorithm& algo, blas::ProfileResult *profile){
if (profile == nullptr) {
return res->second->SetAlgorithm(algo, allocator);
}
return res->second->ExecuteOnStream(&stream, cfg, profile);
}));
// TF_ASSIGN_OR_RETURN(auto best_algo, Autotune(algorithms,
// [&](const gpu::BlasLt::MatmulAlgorithm& algo, blas::ProfileResult *profile){
// if (profile == nullptr) {
// return res->second->SetAlgorithm(algo, allocator);
// }
// return res->second->ExecuteOnStream(&stream, cfg, profile);
// }));
}
TF_RETURN_IF_ERROR(res->second->SetAlgorithm(best_algo, allocator));
}
Expand Down Expand Up @@ -275,53 +275,53 @@ xla::Status BlasLtGemmRunner::RunStridedBatchedImpl(Stream& stream,
break;
}

BlasLt::MatmulAlgorithm best_algo{ .id = blas::kNoAlgorithm };
xla::gpu::AutotuneCacheKey key(ToCSVString(cfg, /*full_string*/false));
auto opt_res = xla::gpu::AutotunerUtil::TryToFindInInMemoryCache(key);
if (opt_res.has_value()) {
auto id = *opt_res;
for (const auto& algo : algorithms) {
if (algo.id == id) best_algo = algo;
}
if (best_algo.id == blas::kNoAlgorithm) {
LOG(WARNING) << "Best algorithm not valid: need to autotune..";
}
}
// BlasLt::MatmulAlgorithm best_algo{ .id = blas::kNoAlgorithm };
// xla::gpu::AutotuneCacheKey key(ToCSVString(cfg, /*full_string*/false));
// auto opt_res = xla::gpu::AutotunerUtil::TryToFindInInMemoryCache(key);
// if (opt_res.has_value()) {
// auto id = *opt_res;
// for (const auto& algo : algorithms) {
// if (algo.id == id) best_algo = algo;
// }
// if (best_algo.id == blas::kNoAlgorithm) {
// LOG(WARNING) << "Best algorithm not valid: need to autotune..";
// }
// }

if (best_algo.id == blas::kNoAlgorithm) {
TF_ASSIGN_OR_RETURN(best_algo, Autotune(algorithms,
[&](const gpu::BlasLt::MatmulAlgorithm& algo, blas::ProfileResult *profile){
if (profile == nullptr) {
return res->second->SetAlgorithm(algo);
}
return res->second->ExecuteOnStream(
&stream, a, b, *c, *c,
DeviceMemoryBase{}, // bias
DeviceMemoryBase{}, // aux
DeviceMemoryBase{}, // a_scale
DeviceMemoryBase{}, // b_scale
DeviceMemoryBase{}, // c_scale
DeviceMemoryBase{}, // d_scale
DeviceMemoryBase{}, // d_amax
absl::nullopt, // workspace
allocator, // allocator
profile);
}));
xla::gpu::AutotunerUtil::CacheValue ares = best_algo.id;
// reread algorithm ID from cache again (in case some other thread has
// already added this config to the cache to be sure we use the same ID)
auto new_id = xla::gpu::AutotunerUtil::AddResultToInMemoryCache(key, ares,
*autotune_config_);

if (new_id != best_algo.id) {
for (const auto& algo : algorithms) {
if (algo.id == new_id) best_algo = algo;
}
}
} // best_algo.id == blas::kNoAlgorithm

res->second->SetAlgorithm(best_algo);
break;
// if (best_algo.id == blas::kNoAlgorithm) {
// TF_ASSIGN_OR_RETURN(best_algo, Autotune(algorithms,
// [&](const gpu::BlasLt::MatmulAlgorithm& algo, blas::ProfileResult *profile){
// if (profile == nullptr) {
// return res->second->SetAlgorithm(algo);
// }
// return res->second->ExecuteOnStream(
// &stream, a, b, *c, *c,
// DeviceMemoryBase{}, // bias
// DeviceMemoryBase{}, // aux
// DeviceMemoryBase{}, // a_scale
// DeviceMemoryBase{}, // b_scale
// DeviceMemoryBase{}, // c_scale
// DeviceMemoryBase{}, // d_scale
// DeviceMemoryBase{}, // d_amax
// absl::nullopt, // workspace
// allocator, // allocator
// profile);
// }));
// xla::gpu::AutotunerUtil::CacheValue ares = best_algo.id;
// // reread algorithm ID from cache again (in case some other thread has
// // already added this config to the cache to be sure we use the same ID)
// auto new_id = xla::gpu::AutotunerUtil::AddResultToInMemoryCache(key, ares,
// *autotune_config_);

// if (new_id != best_algo.id) {
// for (const auto& algo : algorithms) {
// if (algo.id == new_id) best_algo = algo;
// }
// }
// } // best_algo.id == blas::kNoAlgorithm

// res->second->SetAlgorithm(best_algo);
// break;
} // while
return res->second->ExecuteOnStream(
&stream, a, b, *c, *c,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,10 @@ struct BlasLtGemmRunner {
private:
explicit BlasLtGemmRunner(StreamExecutor *parent);

template < class TuneFunc >
xla::StatusOr< gpu::BlasLt::MatmulAlgorithm > Autotune(
const std::vector< gpu::BlasLt::MatmulAlgorithm >& algorithms,
TuneFunc&& benchmark_func);

// template < class TuneFunc >
// xla::StatusOr< gpu::BlasLt::MatmulAlgorithm > Autotune(
// const std::vector< gpu::BlasLt::MatmulAlgorithm >& algorithms,
// TuneFunc&& benchmark_func);

xla::Status RunBatchedImpl(Stream& stream, blas::Transpose trans_a,
blas::Transpose trans_b, int64 m, int64 n, int64 k,
Expand All @@ -248,7 +247,7 @@ struct BlasLtGemmRunner {

static bool autotune_enabled_;
std::unique_ptr< absl::Mutex > mutex_;
std::unique_ptr< xla::gpu::AutotuneConfig > autotune_config_;
// std::unique_ptr< xla::gpu::AutotuneConfig > autotune_config_;
absl::flat_hash_map<GroupedGemmConfig, BlasLt::GroupedMatmulPlanPtr> grouped_gemm_map_;
absl::flat_hash_map<StridedGemmConfig, BlasLt::MatmulPlanPtr> strided_gemm_map_;
};
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/core/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,10 @@ cc_library(
"//tensorflow/core/platform:types",
"//tensorflow/tsl/util:env_var",
],
visibility = [
"//tensorflow/core:__subpackages__",
"//tensorflow/compiler/xla:__subpackages__",
],
)

cc_library(
Expand Down

0 comments on commit ddfd0d5

Please sign in to comment.