Skip to content

Commit

Permalink
[RF] Add RooFuncWrapper::writeDebugMacro() method
Browse files Browse the repository at this point in the history
This new method finally makes it easier to produce standalone versions
of the generated likelihood code. It was already used a lot for
debugging the ATLAS and CMS likelihoods.
  • Loading branch information
guitargeek committed Jul 9, 2024
1 parent cc9d9b1 commit b848606
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 19 deletions.
12 changes: 7 additions & 5 deletions roofit/roofitcore/inc/RooFuncWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <map>
#include <memory>
#include <string>
#include <sstream>

class RooSimultaneous;

Expand Down Expand Up @@ -48,10 +49,6 @@ class RooFuncWrapper final : public RooAbsReal {

std::size_t getNumParams() const { return _params.size(); }

void dumpCode();

void dumpGradient();

/// No constant term optimization is possible in code-generation mode.
void constOptimizeTestStatistic(ConstOpCode /*opcode*/, bool /*doAlsoTrackingOpt*/) override {}

Expand All @@ -61,13 +58,15 @@ class RooFuncWrapper final : public RooAbsReal {

void disableEvaluator() { _useEvaluator = false; }

void writeDebugMacro(std::string const &) const;

protected:
double evaluate() const override;

private:
std::string buildCode(RooAbsReal const &head);

static std::string declareFunction(std::string const &funcBody);
std::string declareFunction(std::string const &funcBody);

void updateGradientVarBuffer() const;

Expand All @@ -76,6 +75,8 @@ class RooFuncWrapper final : public RooAbsReal {

void buildFuncAndGradFunctors();

bool declareToInterpreter(std::string const &code);

using Func = double (*)(double *, double const *, double const *);
using Grad = void (*)(double *, double const *, double const *, double *);

Expand All @@ -97,6 +98,7 @@ class RooFuncWrapper final : public RooAbsReal {
std::map<RooFit::Detail::DataKey, ObsInfo> _obsInfos;
std::map<RooFit::Detail::DataKey, std::size_t> _nodeOutputSizes;
std::vector<double> _xlArr;
std::stringstream _allCode;
};

} // namespace Experimental
Expand Down
71 changes: 57 additions & 14 deletions roofit/roofitcore/src/RooFuncWrapper.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <TROOT.h>
#include <TSystem.h>

#include <fstream>

namespace RooFit {

namespace Experimental {
Expand Down Expand Up @@ -55,6 +57,8 @@ RooFuncWrapper::RooFuncWrapper(const char *name, const char *title, RooAbsReal &

func = buildCode(obj);

declareToInterpreter("#pragma cling optimize(2)");

// Declare the function and create its derivative.
_funcName = declareFunction(func);
_func = reinterpret_cast<Func>(gInterpreter->ProcessLine((_funcName + ";").c_str()));
Expand Down Expand Up @@ -123,14 +127,11 @@ std::string RooFuncWrapper::declareFunction(std::string const &funcBody)
static int iFuncWrapper = 0;
auto funcName = "roo_func_wrapper_" + std::to_string(iFuncWrapper++);

gInterpreter->Declare("#pragma cling optimize(2)");

// Declare the function
std::stringstream bodyWithSigStrm;
bodyWithSigStrm << "double " << funcName << "(double* params, double const* obs, double const* xlArr) {\n"
<< funcBody << "\n}";
bool comp = gInterpreter->Declare(bodyWithSigStrm.str().c_str());
if (!comp) {
if (!declareToInterpreter(bodyWithSigStrm.str())) {
std::stringstream errorMsg;
errorMsg << "Function " << funcName << " could not be compiled. See above for details.";
oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
Expand All @@ -146,7 +147,7 @@ void RooFuncWrapper::createGradient()
std::string wrapperName = _funcName + "_derivativeWrapper";

// Calculate gradient
gInterpreter->ProcessLine("#include <Math/CladDerivator.h>");
declareToInterpreter("#include <Math/CladDerivator.h>\n");
// disable clang-format for making the following code unreadable.
// clang-format off
std::stringstream requestFuncStrm;
Expand All @@ -156,8 +157,7 @@ void RooFuncWrapper::createGradient()
"}\n"
"#pragma clad OFF";
// clang-format on
auto comp = gInterpreter->Declare(requestFuncStrm.str().c_str());
if (!comp) {
if (!declareToInterpreter(requestFuncStrm.str())) {
std::stringstream errorMsg;
errorMsg << "Function " << GetName() << " could not be differentiated. See above for details.";
oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
Expand All @@ -173,7 +173,7 @@ void RooFuncWrapper::createGradient()
" " << gradName << "(params, obs, xlArr, cladOut);\n"
"}";
// clang-format on
gInterpreter->Declare(dWrapperStrm.str().c_str());
declareToInterpreter(dWrapperStrm.str());
_grad = reinterpret_cast<Grad>(gInterpreter->ProcessLine((wrapperName + ";").c_str()));
_hasGradient = true;
}
Expand Down Expand Up @@ -235,16 +235,59 @@ std::string RooFuncWrapper::buildCode(RooAbsReal const &head)
return ctx.assembleCode(ctx.getResult(head));
}

/// @brief Prints the squashed code body to console.
void RooFuncWrapper::dumpCode()
/// @brief Declare code to the interpreter and keep track of all declared code in this RooFuncWrapper.
bool RooFuncWrapper::declareToInterpreter(std::string const &code)
{
gInterpreter->ProcessLine(_funcName.c_str());
_allCode << code << std::endl;
return gInterpreter->Declare(code.c_str());
}

/// @brief Prints the derivative code body to console.
void RooFuncWrapper::dumpGradient()
/// @brief Dumps a macro "filename.C" that can be used to test and debug the generated code and gradient.
void RooFuncWrapper::writeDebugMacro(std::string const &filename) const
{
std::ofstream outFile;
outFile.open(filename + ".C");
outFile << "#include <RooFit/Detail/MathFuncs.h>" << std::endl;
outFile << std::endl;
outFile << _allCode.str();
outFile << std::endl;

updateGradientVarBuffer();

auto writeVector = [&](std::string const &name, std::span<const double> vec) {
outFile << "std::vector<double> " << name << " = {";
for (std::size_t i = 0; i < vec.size(); ++i) {
if (i % 10 == 0)
outFile << "\n ";
outFile << vec[i];
if (i < vec.size() - 1)
outFile << ", ";
}
outFile << "\n};\n";
};

outFile << "// clang-format off\n" << std::endl;
writeVector("parametersVec", _gradientVarBuffer);
outFile << std::endl;
writeVector("observablesVec", _observables);
outFile << std::endl;
writeVector("auxConstantsVec", _xlArr);
outFile << std::endl;
outFile << "// clang-format on\n" << std::endl;

outFile << R"(
// To run as a ROOT macro
void )" << filename
<< R"(()
{
gInterpreter->ProcessLine((_funcName + "_grad_0").c_str());
std::vector<double> gradientVec(parametersVec.size());
)" << _funcName
<< R"((parametersVec.data(), observablesVec.data(), auxConstantsVec.data());
)" << _funcName
<< R"(_grad_0(parametersVec.data(), observablesVec.data(), auxConstantsVec.data(), gradientVec.data());
}
)";
}

} // namespace Experimental
Expand Down

0 comments on commit b848606

Please sign in to comment.