Skip to content

Commit

Permalink
[StableHLO] Pin StableHLOv0.19.0 for older PJRT plugins.
Browse files Browse the repository at this point in the history
This is a temporary measure to allow plugins to update to latest jaxlib.

PiperOrigin-RevId: 681151171
  • Loading branch information
GleasonK authored and Google-ML-Automation committed Oct 1, 2024
1 parent e31ec46 commit cef616b
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 13 deletions.
10 changes: 8 additions & 2 deletions xla/pjrt/mlir_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,15 +244,21 @@ absl::Status UpgradeVersionedStablehlo(mlir::ModuleOp mlir_module) {
return absl::OkStatus();
}

std::string GetDefaultStablehloVersion() {
std::string GetDefaultStablehloVersion(std::optional<int64_t> plugin_version) {
// TODO: (b/370803410) Use WEEK_12 in PJRT, some plugins were not up to date,
// so temporarily using 1.0.0 to allow them time for a new release.
// PJRT v54 released Jun 10, so most plugins should use WEEK_12 by default.
if (plugin_version.has_value() && plugin_version.value() < 54) {
return "0.19.0";
}

// This version must be >=12w old.
return mlir::vhlo::Version::fromCompatibilityRequirement(
mlir::vhlo::Version::CompatibilityRequirement::WEEK_12)
.toString();
}

absl::StatusOr<std::string> Serialize(mlir::ModuleOp module,
std::optional<int64_t> /*plugin_version*/,
absl::string_view target, bool inplace) {
// Current PJRT users expect 12 weeks forward compat, VHLO provides this
// compat.
Expand Down
4 changes: 2 additions & 2 deletions xla/pjrt/mlir_to_hlo.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ absl::Status ParseMlirModuleStringAndConvertToXlaComputation(

// Returns a version of StableHLO ~12w old, for forward compatibility with PJRT
// plugins on a quarterly update cycle.
std::string GetDefaultStablehloVersion();
std::string GetDefaultStablehloVersion(
std::optional<int64_t> plugin_version = std::nullopt);

// Serialize using MLIR Bytecode Format which does not guarantee forward or
// backward compatiblity of the dialects used. If passing StableHLO with forward
Expand All @@ -52,7 +53,6 @@ std::string GetDefaultStablehloVersion();
// For plugin_version < 41, returns `SerializeUsingNativeBytecode`.
// For plugin_version >= 41, returns `SerializeUsingVersionedStablehlo`.
absl::StatusOr<std::string> Serialize(mlir::ModuleOp mlir_module,
std::optional<int64_t> plugin_version,
absl::string_view target,
bool inplace = false);

Expand Down
8 changes: 4 additions & 4 deletions xla/pjrt/mlir_to_hlo_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ TEST(MlirToHloTest, StablehloTest) {
mlir::MLIRContext context;
TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef<mlir::ModuleOp> module,
ParseMlirModuleString(kProgram, context));
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, 47, "1.0.0"));
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "1.0.0"));

// StableHLO uses VHLO for PJRT serialization.
EXPECT_THAT(blob, IsVhloArtifact("1.0.0"));
Expand All @@ -69,7 +69,7 @@ TEST(MlirToHloTest, ChloTest) {
mlir::MLIRContext context;
TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef<mlir::ModuleOp> module,
ParseMlirModuleString(kProgram, context));
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, 47, "1.0.0"));
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "1.0.0"));

// CHLO decomposes to StableHLO, so uses VHLO serialization.
EXPECT_THAT(blob, IsVhloArtifact("1.0.0"));
Expand All @@ -86,7 +86,7 @@ TEST(MlirToHloTest, ChloTanOpTest) {
mlir::MLIRContext context;
TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef<mlir::ModuleOp> module,
ParseMlirModuleString(kProgram, context));
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, 47, "1.0.0"));
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "1.0.0"));

// CHLO decomposes to StableHLO, so uses VHLO serialization.
EXPECT_THAT(blob, IsVhloArtifact("1.0.0"));
Expand All @@ -104,7 +104,7 @@ TEST(MlirToHloTest, MhloTest) {
mlir::MLIRContext context;
TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef<mlir::ModuleOp> module,
ParseMlirModuleString(kProgram, context));
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, 47, "1.0.0"));
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "1.0.0"));

// MHLO and other dialects use native MLIR bytecode, not VHLO.
EXPECT_THAT(blob, Not(IsVhloArtifact("1.0.0")));
Expand Down
11 changes: 6 additions & 5 deletions xla/pjrt/pjrt_c_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,9 @@ absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> PjRtCApiClient::Compile(
if (!pjrt_c_api()) llvm::report_fatal_error("pjrt_c_api is null");
TF_ASSIGN_OR_RETURN(
std::string serialized,
xla::Serialize(module, plugin_attributes()->pjrt_c_api_minor_version,
xla::GetDefaultStablehloVersion()));
xla::Serialize(module,
xla::GetDefaultStablehloVersion(
plugin_attributes()->pjrt_c_api_minor_version)));
std::string format(pjrt::kMlirFormat);
return InitializeArgsAndCompile(this, c_api_, c_client_.get(), options,
serialized, format);
Expand Down Expand Up @@ -2311,9 +2312,9 @@ absl::StatusOr<std::unique_ptr<PjRtExecutable>> PjRtCApiCompiler::Compile(
if (client) {
plugin_version = client->plugin_attributes()->pjrt_c_api_minor_version;
}
TF_ASSIGN_OR_RETURN(std::string serialized,
xla::Serialize(module, plugin_version,
xla::GetDefaultStablehloVersion()));
TF_ASSIGN_OR_RETURN(
std::string serialized,
xla::Serialize(module, xla::GetDefaultStablehloVersion(plugin_version)));
std::string format(pjrt::kMlirFormat);
return InitializeArgsAndCompileAot(c_api_, client, options, topology,
serialized, format);
Expand Down

0 comments on commit cef616b

Please sign in to comment.