From 5e9f2d1200279fd183b2f988a5dd5168946ddd7c Mon Sep 17 00:00:00 2001 From: jacobdweightman Date: Wed, 4 Sep 2024 20:23:51 -0500 Subject: [PATCH] Parallelize passes that prepare IR to be run in interpreter --- zirgen/Dialect/ZStruct/Transforms/Passes.h | 2 +- zirgen/Dialect/ZStruct/Transforms/Passes.td | 2 +- zirgen/Dialect/ZStruct/Transforms/Unroll.cpp | 2 +- zirgen/dsl/driver.cpp | 10 ++++++---- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/zirgen/Dialect/ZStruct/Transforms/Passes.h b/zirgen/Dialect/ZStruct/Transforms/Passes.h index e1cc2c31..2e7e1c4e 100644 --- a/zirgen/Dialect/ZStruct/Transforms/Passes.h +++ b/zirgen/Dialect/ZStruct/Transforms/Passes.h @@ -22,7 +22,7 @@ namespace zirgen::ZStruct { // Pass constructors std::unique_ptr> createOptimizeLayoutPass(); -std::unique_ptr> createUnrollPass(); +std::unique_ptr createUnrollPass(); std::unique_ptr> createExpandLayoutPass(); std::unique_ptr createInlineLayoutPass(); diff --git a/zirgen/Dialect/ZStruct/Transforms/Passes.td b/zirgen/Dialect/ZStruct/Transforms/Passes.td index 4e3538dc..a4426071 100644 --- a/zirgen/Dialect/ZStruct/Transforms/Passes.td +++ b/zirgen/Dialect/ZStruct/Transforms/Passes.td @@ -23,7 +23,7 @@ def OptimizeLayout : Pass<"optimize-layout", "mlir::ModuleOp"> { let constructor = "zirgen::ZStruct::createOptimizeLayoutPass()"; } -def Unroll : Pass<"unroll", "mlir::ModuleOp"> { +def Unroll : Pass<"unroll"> { let summary = "Unroll zhlt map and reduce"; let description = [{ Removes zhlt.map and zhlt.reduce instructions by unrolling the loops. diff --git a/zirgen/Dialect/ZStruct/Transforms/Unroll.cpp b/zirgen/Dialect/ZStruct/Transforms/Unroll.cpp index 7c51c11d..94ecd5de 100644 --- a/zirgen/Dialect/ZStruct/Transforms/Unroll.cpp +++ b/zirgen/Dialect/ZStruct/Transforms/Unroll.cpp @@ -36,7 +36,7 @@ struct UnrollPass : public UnrollBase { } // End namespace -std::unique_ptr> createUnrollPass() { +std::unique_ptr createUnrollPass() { return std::make_unique(); } diff --git a/zirgen/dsl/driver.cpp b/zirgen/dsl/driver.cpp index da582c2a..29cd7ce0 100644 --- a/zirgen/dsl/driver.cpp +++ b/zirgen/dsl/driver.cpp @@ -287,15 +287,17 @@ int runTests(mlir::ModuleOp& module) { mlir::MLIRContext& context = *module.getContext(); // Set all the symbols to private mlir::PassManager pm(&context); + applyDefaultTimingPassManagerCLOptions(pm); if (failed(applyPassManagerCLOptions(pm))) { llvm::errs() << "Pass manager does not agree with command line options.\n"; return 1; } pm.enableVerifier(true); pm.addPass(mlir::createInlinerPass()); - pm.addPass(zirgen::ZStruct::createUnrollPass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createCSEPass()); + mlir::OpPassManager& opm = pm.nest(); + opm.addPass(zirgen::ZStruct::createUnrollPass()); + opm.addPass(mlir::createCanonicalizerPass()); + opm.addPass(mlir::createCSEPass()); if (failed(pm.run(module))) { llvm::errs() << "an internal compiler error occurred while inlining the tests:\n"; module.print(llvm::errs()); @@ -455,7 +457,6 @@ int main(int argc, char* argv[]) { llvm::SourceMgr sourceManager; sourceManager.setIncludeDirs(includeDirs); - context.disableMultithreading(); mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceManager, &context); openMainFile(sourceManager, inputFilename); @@ -498,6 +499,7 @@ int main(int argc, char* argv[]) { } mlir::PassManager pm(&context); + applyDefaultTimingPassManagerCLOptions(pm); if (failed(applyPassManagerCLOptions(pm))) { llvm::errs() << "Pass manager does not agree with command line options.\n"; return 1;