From 2e95843f9748608b11a89d9673b81750f107317d Mon Sep 17 00:00:00 2001
From: Tom Bannink
Date: Mon, 17 Jun 2024 14:28:46 +0200
Subject: [PATCH] Update MLIR passes

---
 larq_compute_engine/mlir/BUILD                |  3 +
 larq_compute_engine/mlir/tf_tfl_passes.cc     |  6 +-
 .../mlir/tf_to_tfl_flatbuffer.cc              | 93 +++++++++----------
 .../mlir/tf_to_tfl_flatbuffer.h               |  2 +-
 4 files changed, 53 insertions(+), 51 deletions(-)

diff --git a/larq_compute_engine/mlir/BUILD b/larq_compute_engine/mlir/BUILD
index 8471d561..c74cc190 100644
--- a/larq_compute_engine/mlir/BUILD
+++ b/larq_compute_engine/mlir/BUILD
@@ -490,8 +490,11 @@ cc_library(
         "@local_tsl//tsl/platform:statusor",
         "@org_tensorflow//tensorflow/compiler/mlir:op_or_arg_name_mapper",
         "@org_tensorflow//tensorflow/compiler/mlir/lite:flatbuffer_export",
+        "@org_tensorflow//tensorflow/compiler/mlir/lite/debug",
         "@org_tensorflow//tensorflow/compiler/mlir/lite/metrics:error_collector",
         "@org_tensorflow//tensorflow/compiler/mlir/lite/quantization:quantization_config",
+        "@org_tensorflow//tensorflow/compiler/mlir/lite/stablehlo:op_stat_pass",
+        "@org_tensorflow//tensorflow/compiler/mlir/lite/stablehlo:stablehlo_util",
        "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:error_util",
        "@org_tensorflow//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_freeze_variables",
     ],
diff --git a/larq_compute_engine/mlir/tf_tfl_passes.cc b/larq_compute_engine/mlir/tf_tfl_passes.cc
index 5cc1dbf5..7139955a 100644
--- a/larq_compute_engine/mlir/tf_tfl_passes.cc
+++ b/larq_compute_engine/mlir/tf_tfl_passes.cc
@@ -108,7 +108,7 @@ void AddPreVariableFreezingTFToLCETFLConversionPasses(
 
   // This decomposes resource ops like ResourceGather into read-variable op
   // followed by gather. This is used when the saved model import path is used
-  // during which resources dont get frozen in the python layer.
+  // during which resources don't get frozen in the python layer.
   pass_manager->addNestedPass<mlir::func::FuncOp>(
       mlir::TFDevice::CreateDecomposeResourceOpsPass());
 
@@ -257,7 +257,9 @@ void AddPostVariableFreezingTFToLCETFLConversionPasses(
 
   // Run quantization after all the floating point model conversion is
   // completed.
-  if (quant_specs.RunPropagationAndRewriteQuantizationPasses()) {
+  if (quant_specs.RunPropagationAndRewriteQuantizationPasses() ||
+      quant_specs.qdq_conversion_mode !=
+          mlir::quant::QDQConversionMode::kQDQNone) {
     AddQuantizationPasses(quant_specs, *pass_manager);
     // Remove unnecessary QDQs while handling QAT models.
     pass_manager->addNestedPass<mlir::func::FuncOp>(
diff --git a/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.cc b/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.cc
index c5ea6957..be080cef 100644
--- a/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.cc
+++ b/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.cc
@@ -3,10 +3,14 @@
 #include "larq_compute_engine/mlir/tf_tfl_passes.h"
 #include "larq_compute_engine/mlir/transforms/passes.h"
 #include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/PassManager.h"
+#include "tensorflow/compiler/mlir/lite/debug/debug.h"
 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
 #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
@@ -55,7 +59,7 @@ class TruncateOpOrArgLocNameMapper : public OpOrArgLocNameMapper {
 };
 }  // namespace
 
-Status ConvertTFExecutorToTFLOrFlatbuffer(
+absl::Status ConvertTFExecutorToTFLOrFlatbuffer(
     mlir::ModuleOp module, bool export_to_mlir, const LCETarget target,
     mlir::quant::QuantizationSpecs quant_specs,
     const std::unordered_set<std::string>& saved_model_tags,
@@ -64,70 +68,59 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
   // Explicitly disable dumping Op details on failures.
   module.getContext()->printOpOnDiagnostic(false);
 
-  // Register a warning handler only log to std out.
-  mlir::ScopedDiagnosticHandler s(
-      module.getContext(), [](mlir::Diagnostic& diag) {
-        if (diag.getSeverity() == mlir::DiagnosticSeverity::Warning) {
-          for (auto& note : diag.getNotes()) {
-            std::cout << note.str() << "\n";
-            LOG(WARNING) << note.str() << "\n";
-          }
-        }
-        return mlir::failure();
-      });
+  mlir::DialectRegistry registry;
+  mlir::func::registerAllExtensions(registry);
+  module.getContext()->appendDialectRegistry(registry);
 
   mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
                                                     /*propagate=*/true);
 
-  if (failed(IsValidGraph(module))) {
-    return statusHandler.ConsumeStatus();
-  }
-
   mlir::PassManager pass_manager(module.getContext());
+  mlir::registerPassManagerCLOptions();
   if (mlir::failed(mlir::applyPassManagerCLOptions(pass_manager))) {
-    // We don't return here as in the normal TF converter, since apparently this
-    // actually fails in our case, but the failure isn't terminal.
-    // return tensorflow::FromAbslStatus(
-    //     absl::UnknownError("failed to apply MLIR pass manager CL options"));
+    return absl::InternalError("Failed to apply MLIR pass manager CL options.");
   }
 
+  // DebugOptions::ir_dump_dir can be set for debugging
+  converter::DebugOptions debug_options;
+  InitPassManager(pass_manager, debug_options);
+
   pass_manager.addInstrumentation(
       std::make_unique<mlir::TFL::ErrorCollectorInstrumentation>(
           pass_manager.getContext()));
 
+  if (mlir::failed(IsValidGraph(module))) {
+    return statusHandler.ConsumeStatus();
+  }
+
   tensorflow::AddPreVariableFreezingTFToLCETFLConversionPasses(&pass_manager);
-  if (failed(pass_manager.run(module))) {
+  if (mlir::failed(pass_manager.run(module))) {
     return statusHandler.ConsumeStatus();
   }
 
   // Freeze variables if a session is provided.
-  if (session.has_value()) {
-    mlir::TFL::ErrorCollectorInstrumentation collector(module.getContext());
-    if (mlir::failed(
-            mlir::tf_saved_model::FreezeVariables(module, session.value()))) {
-      auto status = statusHandler.ConsumeStatus();
-      mlir::TFL::ErrorCollector* collector =
-          mlir::TFL::ErrorCollector::GetErrorCollector();
-      if (!collector->CollectedErrors().empty()) {
-        return errors::InvalidArgument("Variable constant folding has failed.");
-      }
-      return status;
-    }
+  if (session.has_value() && mlir::failed(mlir::tf_saved_model::FreezeVariables(
+                                 module, session.value_or(nullptr)))) {
+    return statusHandler.Combine(
+        absl::InvalidArgumentError("Variable constant folding is failed."));
   }
 
+  pass_manager.clear();
+
   tensorflow::AddPostVariableFreezingTFToLCETFLConversionPasses(
       saved_model_dir, quant_specs, &pass_manager, target);
-  if (failed(pass_manager.run(module))) {
-    auto status = statusHandler.ConsumeStatus();
-    mlir::TFL::ErrorCollector* collector =
-        mlir::TFL::ErrorCollector::GetErrorCollector();
-    for (const auto& error_data : collector->CollectedErrors()) {
-      if (error_data.subcomponent() == "FreezeGlobalTensorsPass") {
-        return errors::InvalidArgument("Variable constant folding is failed.");
-      }
-    }
-    return status;
+  if (mlir::failed(pass_manager.run(module))) {
+    return statusHandler.Combine(
+        absl::InvalidArgumentError("Variable constant folding failed."));
   }
 
   if (export_to_mlir) {
+    pass_manager.clear();
+    // Print out a detailed report of ops that are not converted to TFL ops.
+    pass_manager.addPass(mlir::odml::createPrintOpStatsPass(
+        mlir::odml::GetAcceptedTFLiteDialects()));
+    if (mlir::failed(pass_manager.run(module))) {
+      return statusHandler.ConsumeStatus();
+    }
+
     llvm::raw_string_ostream os(*result);
     module.print(os);
     return statusHandler.ConsumeStatus();
@@ -142,14 +135,18 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
   options.toco_flags = toco_flags;
   options.saved_model_tags = saved_model_tags;
   options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
-  if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) {
-    return statusHandler.ConsumeStatus();
+  const bool serialize_stablehlo_ops = false;
+  if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result,
+                                                 serialize_stablehlo_ops)) {
+    return statusHandler.Combine(
+        absl::InternalError("Could not translate MLIR to FlatBuffer."));
   }
 
-  if (mlir::failed(module.verify())) {
-    return tensorflow::errors::Unknown("Final module is invalid");
+  if (mlir::failed(module.verifyInvariants())) {
+    return statusHandler.Combine(
+        absl::InternalError("Final module is invalid."));
   }
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 }  // namespace tensorflow
diff --git a/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h b/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h
index 7fee937f..e40eec8b 100644
--- a/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h
+++ b/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h
@@ -15,7 +15,7 @@ namespace tensorflow {
 // This is a fork of ConvertTFExecutorToTFLOrFlatbuffer to enable custom
 // OpOrArgLocNameMapper
 // https://github.com/tensorflow/tensorflow/blob/v2.8.0/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h#L60-L78
-Status ConvertTFExecutorToTFLOrFlatbuffer(
+absl::Status ConvertTFExecutorToTFLOrFlatbuffer(
     mlir::ModuleOp module, bool export_to_mlir, const LCETarget target,
     mlir::quant::QuantizationSpecs quant_specs,
     const std::unordered_set<std::string>& saved_model_tags,