diff --git a/lib/Dialect/QCS/Utils/ParameterInitialValueAnalysis.cpp b/lib/Dialect/QCS/Utils/ParameterInitialValueAnalysis.cpp index c60d1e24e..bc7053964 100644 --- a/lib/Dialect/QCS/Utils/ParameterInitialValueAnalysis.cpp +++ b/lib/Dialect/QCS/Utils/ParameterInitialValueAnalysis.cpp @@ -28,59 +28,81 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "ParameterInitialValueAnalysis" using namespace mlir::qcs; +static llvm::cl::opt printAnalysisEntries( + "qcs-parameter-initial-value-analysis-print", + llvm::cl::desc("Print ParameterInitialValueAnalysis entries"), + llvm::cl::init(false)); + ParameterInitialValueAnalysis::ParameterInitialValueAnalysis( mlir::Operation *moduleOp) { - // ParameterInitialValueAnalysis should only process the top level - // module where parameters are defined - // find the top level module - auto parentOp = moduleOp->getParentOfType(); - while (parentOp) { - moduleOp = parentOp; - parentOp = moduleOp->getParentOfType(); - } - if (not invalid_) return; - // process the module top level to cache declareParameterOp initial_values - // this does not use a walk method so that submodule (if present) are not - // processed in order to limit processing time - - for (auto ®ion : moduleOp->getRegions()) - for (auto &block : region.getBlocks()) - for (auto &op : block.getOperations()) { - auto declareParameterOp = dyn_cast(op); - if (!declareParameterOp) - continue; - - // moduleOp->walk([&](DeclareParameterOp declareParameterOp) { - double initial_value = 0.0; - if (declareParameterOp.getInitialValue().has_value()) { - auto angleAttr = declareParameterOp.getInitialValue() - .value() - .dyn_cast(); - auto floatAttr = declareParameterOp.getInitialValue() - .value() - .dyn_cast(); - if (!(angleAttr || floatAttr)) - declareParameterOp.emitError("Parameters are currently limited to " - "angles or float[64] only."); - - if (angleAttr) - initial_value = angleAttr.getValue().convertToDouble(); - - if (floatAttr) - initial_value = floatAttr.getValue().convertToDouble(); + bool foundParameters = false; + + // search for the parameters in the current module + // if not found search parent module + // TODO - determine if there is a faster way to do this + do { + + // process the module top level to cache declareParameterOp initial_values + // this does not use a walk method so that submodule (if present) are not + // processed in order to limit processing time + + for (auto ®ion : moduleOp->getRegions()) + for (auto &block : region.getBlocks()) + for (auto &op : block.getOperations()) { + auto declareParameterOp = dyn_cast(op); + if (!declareParameterOp) + continue; + + double initial_value = 0.0; + if (declareParameterOp.getInitialValue().has_value()) { + auto angleAttr = declareParameterOp.getInitialValue() + .value() + .dyn_cast(); + auto floatAttr = declareParameterOp.getInitialValue() + .value() + .dyn_cast(); + if (!(angleAttr || floatAttr)) + declareParameterOp.emitError( + "Parameters are currently limited to " + "angles or float[64] only."); + + if (angleAttr) + initial_value = angleAttr.getValue().convertToDouble(); + + if (floatAttr) + initial_value = floatAttr.getValue().convertToDouble(); + } + initial_values_[declareParameterOp.getSymName()] = initial_value; + foundParameters = true; } - initial_values_[declareParameterOp.getSymName()] = initial_value; - } + if (!foundParameters) { + auto parentOp = moduleOp->getParentOfType(); + if (parentOp) + moduleOp = parentOp; + else + break; + } + } while (!foundParameters); invalid_ = false; + + // debugging / test print out + if (printAnalysisEntries) { + for (auto &initial_value : initial_values_) { + llvm::outs() << initial_value.first() << " = " + << std::get(initial_value.second) << "\n"; + } + } } void ParameterInitialValueAnalysisPass::runOnOperation() { diff --git a/test/Dialect/QCS/Utils/parameter-initial-value-analysis.mlir b/test/Dialect/QCS/Utils/parameter-initial-value-analysis.mlir new file mode 100644 index 000000000..e1f76ffe8 --- /dev/null +++ b/test/Dialect/QCS/Utils/parameter-initial-value-analysis.mlir @@ -0,0 +1,36 @@ +// RUN: qss-opt --qcs-parameter-initial-value-analysis-print --pass-pipeline='builtin.module(builtin.module(qcs-parameter-initial-value-analysis))' --mlir-disable-threading %s | FileCheck %s --check-prefixes CHECK,NESTED +// RUN: qss-opt --qcs-parameter-initial-value-analysis-print --qcs-parameter-initial-value-analysis %s | FileCheck %s + +// +// This code is part of Qiskit. +// +// (C) Copyright IBM 2024. +// +// This code is licensed under the Apache License, Version 2.0 with LLVM +// Exceptions. You may obtain a copy of this license in the LICENSE.txt +// file in the root directory of this source tree. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +module { + // without nested pass manager should only find + // alpha and beta + qcs.declare_parameter @alpha : !quir.angle<64> = #quir.angle<1.000000e+00> : !quir.angle<64> + qcs.declare_parameter @beta : f64 = 2.000000e+00 : f64 + module @first { + // nested test should find alpha and beta + qcs.declare_parameter @theta : !quir.angle<64> = #quir.angle<3.140000e+00> : !quir.angle<64> + qcs.declare_parameter @phi : f64 = 1.500000e+00 : f64 + } + module @second { + // test module without declare_parameter + // should find alpha and beta when nested + } +} + +// CHECK-DAG: alpha = 1.000000e+00 +// CHECK-DAG: beta = 2.000000e+00 +// NESTED-DAG: theta = 3.140000e+00 +// NESTED-DAG: phi = 1.500000e+00