Skip to content

Commit

Permalink
use xla flags to enable hipblaslt instead of env vars
Browse files Browse the repository at this point in the history
  • Loading branch information
ScXfjiang committed Nov 22, 2024
1 parent b7d31bd commit 7630bad
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/stream_executor/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ cc_library(
#"//tensorflow/core/platform:env",
"//tensorflow/tsl/util:env_var",
"@com_google_absl//absl/types:any",
"//tensorflow/compiler/xla:debug_options_flags",
]),
)

Expand Down
6 changes: 2 additions & 4 deletions tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/stream_executor/stream_executor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/tsl/util/env_var.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"

namespace stream_executor {

Expand All @@ -33,10 +34,7 @@ using xla::PrimitiveType;

bool GpuBlasLtEnabled() {
static std::atomic_bool result{[] {
bool value = false;
tsl::ReadBoolFromEnvVar("TF_ENABLE_GPU_BLASLT",
/*default_value=*/false, &value);
return value;
return xla::GetDebugOptionsFromFlags().xla_gpu_enable_cublaslt();
}()};
return result;
}
Expand Down

0 comments on commit 7630bad

Please sign in to comment.