Skip to content

Commit

Permalink
[xla:cpu] NFC: Remove deprecated XLA:CPU mlir based codegen part #4
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 630139768
  • Loading branch information
ezhulenev authored and copybara-github committed May 2, 2024
1 parent 566edab commit f6613e2
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 308 deletions.
13 changes: 0 additions & 13 deletions xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -245,19 +245,10 @@ cc_library(
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_module_group",
"//xla/mlir/framework/ir:xla_framework",
"//xla/mlir/runtime/ir:rt",
"//xla/mlir/runtime/transforms:calling_convention",
"//xla/mlir/runtime/transforms:compilation_pipeline_cpu",
"//xla/mlir/runtime/transforms:compiler",
"//xla/mlir/runtime/transforms:jit_compiler",
"//xla/mlir_hlo",
"//xla/mlir_hlo:all_passes",
"//xla/mlir_hlo:mhlo_passes",
"//xla/mlir_hlo:transforms_passes",
"//xla/runtime:custom_call_registry",
"//xla/runtime:executable",
"//xla/runtime:jit_executable",
"//xla/service:algebraic_simplifier",
"//xla/service:all_reduce_promotion",
"//xla/service:all_to_all_decomposer",
Expand Down Expand Up @@ -564,9 +555,6 @@ cc_library(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/mlir/runtime/transforms:compiler",
"//xla/runtime:executable",
"//xla/runtime:jit_executable",
"//xla/service:buffer_assignment",
"//xla/service:computation_layout",
"//xla/service:custom_call_status_internal",
Expand Down Expand Up @@ -808,7 +796,6 @@ cc_library(
"//xla:statusor",
"//xla:types",
"//xla:util",
"//xla/runtime:execution_engine",
"//xla/service:llvm_compiler",
"//xla/service/llvm_ir:llvm_util",
"@com_google_absl//absl/functional:any_invocable",
Expand Down
15 changes: 0 additions & 15 deletions xla/service/cpu/compiler_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Instrumentation/DataFlowSanitizer.h"
#include "xla/runtime/execution_engine.h"
#include "xla/service/cpu/cpu_runtime.h"
#include "xla/service/cpu/llvm_ir_runtime.h"
#include "xla/service/llvm_ir/llvm_util.h"
Expand Down Expand Up @@ -160,20 +159,6 @@ llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> CompilerFunctor::operator()(

llvm::ModulePassManager pm;

for (const auto& func_name : convert_to_xla_runtime_abi_) {
llvm::Function* func = module.getFunction(func_name);
// Create a new function with the XLA Runtime ABI and inline the original
// (i.e. with ctx + memref args) into it.
std::string inlined_func_name =
absl::StrCat(func_name, "__orig_xla_runtime_abi");
func->setName(inlined_func_name);
absl::Status status = xla::runtime::ExportWithXlaRuntimeAbi(
module, inlined_func_name, func_name);
if (!status.ok()) {
LOG(FATAL) << status.message();
}
}

if (dfsan_enabled_) {
pm.addPass(llvm::DataFlowSanitizerPass(dfsan_abi_list_files_));
}
Expand Down
7 changes: 2 additions & 5 deletions xla/service/cpu/compiler_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ class CompilerFunctor : public llvm::orc::IRCompileLayer::IRCompiler {
absl::AnyInvocable<void(const llvm::object::ObjectFile&)>
post_codegen_hook = nullptr,
bool dfsan_enabled = false,
const std::vector<std::string>& dfsan_abi_list_files = {},
const std::vector<std::string>& convert_to_xla_runtime_abi = {})
const std::vector<std::string>& dfsan_abi_list_files = {})
: IRCompiler(llvm::orc::IRSymbolMapper::ManglingOptions()),
target_machine_(target_machine),
opt_level_(opt_level),
Expand All @@ -59,8 +58,7 @@ class CompilerFunctor : public llvm::orc::IRCompileLayer::IRCompiler {
post_optimization_hook_(std::move(post_optimization_hook)),
post_codegen_hook_(std::move(post_codegen_hook)),
dfsan_enabled_(dfsan_enabled),
dfsan_abi_list_files_(dfsan_abi_list_files),
convert_to_xla_runtime_abi_(convert_to_xla_runtime_abi) {}
dfsan_abi_list_files_(dfsan_abi_list_files) {}

// Compile a Module to an ObjectFile.
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> operator()(
Expand All @@ -78,7 +76,6 @@ class CompilerFunctor : public llvm::orc::IRCompileLayer::IRCompiler {
absl::AnyInvocable<void(const llvm::object::ObjectFile&)> post_codegen_hook_;
const bool dfsan_enabled_ = false;
const std::vector<std::string> dfsan_abi_list_files_;
const std::vector<std::string> convert_to_xla_runtime_abi_;
};

} // namespace cpu
Expand Down
8 changes: 1 addition & 7 deletions xla/service/cpu/cpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1412,11 +1412,6 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
obj_file.getData().size()));
};

std::vector<std::string> xla_runtime_abi_conversions;
if (options.use_mlir_hlo_lowering()) {
xla_runtime_abi_conversions.push_back(options.entry_point_name());
}

CompilerFunctor compiler_functor(
target_machine.get(), static_cast<int>(opt_level),
options::OptimizeForSizeRequested(module->config()),
Expand All @@ -1425,8 +1420,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
llvm_ir::GetCpuFastMathFlags(module->config()),
pre_optimization_ir_hook, post_optimization_ir_hook,
post_codegen_hook, aot_options.sanitize_dataflow(),
aot_options.sanitize_abilists_dataflow(),
xla_runtime_abi_conversions);
aot_options.sanitize_abilists_dataflow());
std::unique_ptr<llvm::MemoryBuffer> object_file =
cantFail(compiler_functor(*llvm_module));
ObjectFileData object_file_data(object_file->getBufferStart(),
Expand Down
206 changes: 12 additions & 194 deletions xla/service/cpu/cpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ limitations under the License.
#include "mlir/Parser/Parser.h" // from @llvm-project
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/mlir/runtime/transforms/compiler.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/computation_layout.h"
#include "xla/service/logical_buffer.h"
Expand All @@ -56,8 +55,6 @@ limitations under the License.
namespace xla {
namespace cpu {

namespace runtime = ::xla::runtime;

absl::StatusOr<std::unique_ptr<CpuExecutable>> CpuExecutable::Create(
std::unique_ptr<SimpleOrcJIT> jit,
std::unique_ptr<const BufferAssignment> assignment,
Expand Down Expand Up @@ -95,15 +92,11 @@ absl::StatusOr<std::unique_ptr<CpuExecutable>> CpuExecutable::Create(
std::unique_ptr<HloModule> hlo_module,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<XlaRuntimeCpuExecutable> xla_runtime_executable) {
std::unique_ptr<const BufferAssignment> assignment) {
std::unique_ptr<CpuExecutable> executable(new CpuExecutable(
std::move(hlo_module), std::move(hlo_profile_printer_data),
std::move(hlo_profile_index_map), std::move(assignment)));
executable->set_ir_module_string(
xla_runtime_executable->GetExecutable().take_ir_module_string());
executable->module_name_ = "main";
executable->xla_runtime_executable_ = std::move(xla_runtime_executable);
return executable;
}

Expand Down Expand Up @@ -237,33 +230,17 @@ Status CpuExecutable::ExecuteComputeFunction(
}
};

if (IsXlaRuntime()) {
std::vector<BufferDesc> descriptor_table;
descriptor_table.reserve(buffers.size());
for (const auto& buffer : buffers) {
const tensorflow::se::DeviceMemoryBase& base =
buffer.AsDeviceMemoryBase();
BufferDesc desc(const_cast<void*>(base.opaque()), base.size());
descriptor_table.push_back(std::move(desc));
}
Status status = ExecuteXlaRuntime(descriptor_table, run_options);
record_profile();
if (!status.ok()) {
return status;
}
} else {
XlaCustomCallStatus status;
// For the entry computation (like all global computations), all inputs and
// outputs are in the buffer table, and both the result pointer and args
// array pointers are unused (so we set them to 'nullptr').
compute_function_(nullptr, run_options, nullptr, buffer_pointers.data(),
&status, profile_counters);
record_profile();
std::optional<absl::string_view> error_message =
CustomCallStatusGetMessage(&status);
if (error_message) {
return Internal("CustomCall failed: %s", *error_message);
}
XlaCustomCallStatus status;
// For the entry computation (like all global computations), all inputs and
// outputs are in the buffer table, and both the result pointer and args
// array pointers are unused (so we set them to 'nullptr').
compute_function_(nullptr, run_options, nullptr, buffer_pointers.data(),
&status, profile_counters);
record_profile();
std::optional<absl::string_view> error_message =
CustomCallStatusGetMessage(&status);
if (error_message) {
return Internal("CustomCall failed: %s", *error_message);
}

return OkStatus();
Expand Down Expand Up @@ -369,162 +346,6 @@ absl::StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
return std::move(result);
}

// Converts a BufferDesc to a MemrefDesc according to the given 'operand_type',
// which should point to a runtime::MemrefType.
// Note: 'descriptor_index' and 'operand_index' are just used for error
// reporting.
static absl::StatusOr<runtime::MemrefDesc> BufferToMemref(
const BufferDesc& descriptor, const runtime::Type& operand_type,
size_t descriptor_index, size_t operand_index) {
auto* memref = llvm::dyn_cast<runtime::MemrefType>(&operand_type);
if (!memref) {
return Internal(
"Cannot convert descriptor %zu (operand_index %zu): "
"the corresponding type in the signature is a %s, "
"not a MemrefType.",
descriptor_index, operand_index, operand_type.ToString());
}

absl::Span<const int64_t> dims = memref->sizes();

// Verify that the provided descriptor size matches that of the memref.
size_t n_elem = absl::c_accumulate(dims, size_t{1}, std::multiplies<>());
size_t expected_size =
primitive_util::ByteWidth(memref->element_type()) * n_elem;
if (LLVM_UNLIKELY(expected_size != descriptor.size())) {
return InvalidArgument(
"Cannot convert descriptor %zu (operand_index %zu): "
"buffer size is not equal to that expected from the element type: "
"got %zu vs expected %zu.",
descriptor_index, operand_index, descriptor.size(), expected_size);
}

auto fill_sizes_and_strides = [&](auto sizes, auto strides) {
size_t multiplier = 1;
for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
size_t size = dims[i];
sizes[i] = size;
strides[i] = multiplier;
multiplier *= size;
}
};
return runtime::MemrefDesc(memref->rank(), memref->element_type(),
descriptor.data(), /*offset=*/0,
fill_sizes_and_strides);
}

// Executes from an XLA Runtime CPU executable, given a buffer descriptor table.
// Relevant elements of the descriptor table (i.e. arguments and results) are
// converted to MemrefDesc's according to the corresponding operands in the
// runtime signature.
Status XlaRuntimeCpuExecutable::Execute(
const std::vector<BufferDesc>& descriptor_table,
const ExecutableRunOptions* run_options) {
const runtime::FunctionType& signature = GetExecutable().runtime_signature();

size_t num_arguments = xla_framework_mapping_.inputs.size();
if (xla_framework_mapping_.output_is_tuple) {
num_arguments += xla_framework_mapping_.flattened_outputs.size();
} else if (xla_framework_mapping_.result != -1) {
num_arguments += 1;
}

// Verify that the number of arguments in the mapping matches the signature.
// Add one to num_arguments to account for the signature's execution context.
if (num_arguments + 1 != signature.num_operands()) {
return Internal(
"Wrong number of arguments: got %zu via XLA FrameworkMapping, expected "
"%d.",
num_arguments, static_cast<int>(signature.num_operands()) - 1);
}

std::vector<runtime::MemrefDesc> arguments;
arguments.reserve(num_arguments);

auto append_converted_buffer = [&](size_t descriptor_index) -> Status {
const BufferDesc& descriptor = descriptor_table[descriptor_index];

// Use 1-based index to account for the execution context.
size_t operand_index = arguments.size() + 1;
const runtime::Type* operand_type = signature.operand(operand_index);

absl::StatusOr<runtime::MemrefDesc> memref = BufferToMemref(
descriptor, *operand_type, descriptor_index, operand_index);
if (!memref.ok()) {
return memref.status();
}
arguments.push_back(std::move(*memref));
return OkStatus();
};

// Inputs come first; results come last.
for (int64_t index : xla_framework_mapping_.inputs) {
TF_RETURN_IF_ERROR(append_converted_buffer(index));
}

int64_t result_index = xla_framework_mapping_.result;
if (xla_framework_mapping_.output_is_tuple) {
size_t num_outputs = xla_framework_mapping_.flattened_outputs.size();
for (size_t i = 0; i < num_outputs; ++i) {
int64_t output_index = xla_framework_mapping_.flattened_outputs[i];

TF_RETURN_IF_ERROR(append_converted_buffer(output_index));

// Populate the output tuple with a pointer to this result.
// TODO(b/249078472): make this work with nested tuples, if needed.
assert(result_index != -1);
void** results =
static_cast<void**>(descriptor_table[result_index].data());
results[i] = descriptor_table[output_index].data();
}
} else if (result_index != -1) {
TF_RETURN_IF_ERROR(append_converted_buffer(result_index));
}

runtime::Executable::CallFrame call_frame;
// Skip verification. The MemrefDesc's we created above come from the runtime
// signature; verifying them against the same signature would be redundant.
if (auto status =
GetExecutable().InitializeCallFrame(arguments, &call_frame,
/*verify_arguments=*/false);
!status.ok()) {
return Internal("Failed to initialize call frame: %s.",
status.message());
}

// No results to return; they are returned via out params.
runtime::NoResultConverter converter;

// Collect all emitted diagnostic messages.
std::string diagnostic;
runtime::DiagnosticEngine diagnostic_engine;
diagnostic_engine.AddHandler([&](runtime::Diagnostic& d) {
absl::StrAppend(&diagnostic, d.status().message());
return runtime::success();
});

runtime::CustomCall::UserData user_data(run_options);

runtime::Executable::ExecuteOpts opts;
opts.custom_call_data = &user_data;
opts.diagnostic_engine = &diagnostic_engine;
opts.custom_call_registry = &dynamic_custom_calls_;

// We don't expect to see any async tasks in the XLA Runtime executable.
opts.async_task_runner =
reinterpret_cast<runtime::AsyncTaskRunner*>(0xdeadbeef);

// Execute with the prepared call frame.
GetExecutable().Execute(call_frame, opts);
if (auto status = GetExecutable().ReturnResults(converter, &call_frame);
!status.ok()) {
return Internal("Failed to execute XLA Runtime executable: %s%s%s.",
status.message(), diagnostic.empty() ? "" : ": ",
diagnostic);
}
return OkStatus();
}

absl::StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ExecutionInput> arguments,
Expand Down Expand Up @@ -612,9 +433,6 @@ const InstructionValueSet& CpuExecutable::GetRootValueSet() const {
}

int64_t CpuExecutable::SizeOfGeneratedCodeInBytes() const {
// TODO(b/233850967): support profiling in XLA:CPU-Next, instead of
// punting on it as we are doing here.
if (IsXlaRuntime()) return 0;
return jit_->SizeOfGeneratedCodeInBytes();
}

Expand Down
Loading

0 comments on commit f6613e2

Please sign in to comment.