Skip to content

Commit

Permalink
Generator: various speed improvements.
Browse files Browse the repository at this point in the history
  • Loading branch information
agarny committed Nov 5, 2024
1 parent 66ef714 commit be1c5a0
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 72 deletions.
105 changes: 50 additions & 55 deletions src/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,18 @@ void Generator::GeneratorImpl::reset()
mCode = {};
}

std::string Generator::GeneratorImpl::doVariableIndexString(const AnalyserModelPtr &model,
const AnalyserVariablePtr &variable,
const std::vector<AnalyserVariablePtr> &variables)
std::string Generator::GeneratorImpl::variableIndexString(const AnalyserModelPtr &model,
const AnalyserVariablePtr &variable)
{
// Determine the actual index of the variable in the list of variables by accounting for the fact that some
// variables may be untracked.

auto variables = libcellml::variables(variable);

if (variables.empty()) {
return convertToString(variable->index());
}

size_t i = MAX_SIZE_T;
size_t res = MAX_SIZE_T;

Expand All @@ -69,36 +74,21 @@ std::string Generator::GeneratorImpl::doVariableIndexString(const AnalyserModelP
return convertToString(res);
}

std::string Generator::GeneratorImpl::variableIndexString(const AnalyserModelPtr &model,
const AnalyserVariablePtr &variable)
{
switch (variable->type()) {
case AnalyserVariable::Type::CONSTANT:
return doVariableIndexString(model, variable, model->constants());
case AnalyserVariable::Type::COMPUTED_CONSTANT:
return doVariableIndexString(model, variable, model->computedConstants());
case AnalyserVariable::Type::ALGEBRAIC:
return doVariableIndexString(model, variable, model->algebraic());
default:
break;
}

return convertToString(variable->index());
}

bool Generator::GeneratorImpl::doIsTrackedEquation(const AnalyserEquationPtr &equation, bool tracked)
{
switch (equation->type()) {
case AnalyserEquation::Type::COMPUTED_CONSTANT:
return isTrackedVariable(equation->computedConstants().front()) == tracked;
case AnalyserEquation::Type::NLA:
return true;
case AnalyserEquation::Type::ALGEBRAIC:
return isTrackedVariable(equation->algebraic().front()) == tracked;
case AnalyserEquation::Type::EXTERNAL:
return isTrackedVariable(equation->externals().front()) == tracked;
case AnalyserEquation::Type::COMPUTED_CONSTANT: {
auto variable = equation->computedConstants().front();

return doIsTrackedVariable(variable->model(), variable) == tracked;
}
case AnalyserEquation::Type::ALGEBRAIC: {
auto variable = equation->algebraic().front();

return doIsTrackedVariable(variable->model(), variable) == tracked;
}
default:
return false;
return true;
}
}

Expand Down Expand Up @@ -130,15 +120,7 @@ bool Generator::GeneratorImpl::doIsTrackedVariable(const AnalyserVariablePtr &va
return false;
}

auto model = variable->model();

for (const auto &modelVariable : variables(model, false)) {
if ((variable == modelVariable) && trackableVariable(modelVariable, tracked, false)) {
return doIsTrackedVariable(model, modelVariable, tracked);
}
}

return tracked;
return doIsTrackedVariable(variable->model(), variable, tracked);
}

bool Generator::GeneratorImpl::isTrackedVariable(const AnalyserVariablePtr &variable)
Expand Down Expand Up @@ -309,7 +291,7 @@ void Generator::GeneratorImpl::doTrackVariable(const AnalyserVariablePtr &variab

auto model = variable->model();

for (const auto &modelVariable : variables(model, false)) {
for (const auto &modelVariable : variables(variable)) {
if (variable == modelVariable) {
if (trackableVariable(variable, tracked)) {
mTrackedVariables[modelVariable->model()][modelVariable] = tracked;
Expand Down Expand Up @@ -425,17 +407,29 @@ void Generator::GeneratorImpl::untrackAllAlgebraic(const AnalyserModelPtr &model
}
}

std::vector<AnalyserVariablePtr> Generator::GeneratorImpl::trackableVariables(const AnalyserModelPtr &model) const
{
auto res = model->constants();
auto computedConstants = model->computedConstants();
auto algebraic = model->algebraic();

res.insert(res.end(), computedConstants.begin(), computedConstants.end());
res.insert(res.end(), algebraic.begin(), algebraic.end());

return res;
}

void Generator::GeneratorImpl::trackAllVariables(const AnalyserModelPtr &model)
{
if (validModel(model)) {
doTrackVariables(variables(model, false), true);
doTrackVariables(trackableVariables(model), true);
}
}

void Generator::GeneratorImpl::untrackAllVariables(const AnalyserModelPtr &model)
{
if (validModel(model)) {
doTrackVariables(variables(model, false), false);
doTrackVariables(trackableVariables(model), false);
}
}

Expand Down Expand Up @@ -519,7 +513,7 @@ size_t Generator::GeneratorImpl::trackedVariableCount(const AnalyserModelPtr &mo
return 0;
}

return doTrackedVariableCount(model, variables(model, false), true);
return doTrackedVariableCount(model, trackableVariables(model), true);
}

size_t Generator::GeneratorImpl::untrackedVariableCount(const AnalyserModelPtr &model)
Expand All @@ -528,7 +522,7 @@ size_t Generator::GeneratorImpl::untrackedVariableCount(const AnalyserModelPtr &
return 0;
}

return doTrackedVariableCount(model, variables(model, false), false);
return doTrackedVariableCount(model, trackableVariables(model), false);
}

bool Generator::GeneratorImpl::modelHasOdes(const AnalyserModelPtr &model) const
Expand Down Expand Up @@ -783,23 +777,23 @@ void Generator::GeneratorImpl::addStateAndVariableCountCode(const AnalyserModelP
code += interface ?
mProfile->interfaceConstantCountString() :
replace(mProfile->implementationConstantCountString(),
"[CONSTANT_COUNT]", std::to_string(trackedConstantCount(model)));
"[CONSTANT_COUNT]", std::to_string(doTrackedVariableCount(model, model->constants(), true)));
}

if ((interface && !mProfile->interfaceComputedConstantCountString().empty())
|| (!interface && !mProfile->implementationComputedConstantCountString().empty())) {
code += interface ?
mProfile->interfaceComputedConstantCountString() :
replace(mProfile->implementationComputedConstantCountString(),
"[COMPUTED_CONSTANT_COUNT]", std::to_string(trackedComputedConstantCount(model)));
"[COMPUTED_CONSTANT_COUNT]", std::to_string(doTrackedVariableCount(model, model->computedConstants(), true)));
}

if ((interface && !mProfile->interfaceAlgebraicCountString().empty())
|| (!interface && !mProfile->implementationAlgebraicCountString().empty())) {
code += interface ?
mProfile->interfaceAlgebraicCountString() :
replace(mProfile->implementationAlgebraicCountString(),
"[ALGEBRAIC_COUNT]", std::to_string(trackedAlgebraicCount(model)));
"[ALGEBRAIC_COUNT]", std::to_string(doTrackedVariableCount(model, model->algebraic(), true)));
}

if ((model->externalCount() != 0)
Expand All @@ -825,7 +819,7 @@ std::string Generator::GeneratorImpl::generateVariableInfoObjectCode(const Analy
size_t unitsSize = 0;

for (const auto &variable : variables(model)) {
if (isTrackedVariable(variable)) {
if (doIsTrackedVariable(model, variable)) {
updateVariableInfoSizes(componentSize, nameSize, unitsSize, variable);
}
}
Expand Down Expand Up @@ -898,10 +892,11 @@ void Generator::GeneratorImpl::doAddImplementationVariableInfoCode(const std::st
if (!variableInfoString.empty()
&& !mProfile->variableInfoEntryString().empty()
&& !mProfile->arrayElementSeparatorString().empty()) {
auto model = variables.empty() ? nullptr : variables.front()->model();
std::string infoElementsCode;

for (const auto &variable : variables) {
if (isTrackedVariable(variable)) {
if (doIsTrackedVariable(model, variable)) {
if (!infoElementsCode.empty()) {
infoElementsCode += mProfile->arrayElementSeparatorString() + "\n";
}
Expand Down Expand Up @@ -1233,7 +1228,7 @@ void Generator::GeneratorImpl::addNlaSystemsCode(const AnalyserModelPtr &model)
auto methodBodySize = methodBody.size();

for (const auto &constantDependency : equation->mPimpl->mConstantDependencies) {
if (isUntrackedVariable(constantDependency)) {
if (doIsTrackedVariable(model, constantDependency, false)) {
methodBody += generateInitialisationCode(model, constantDependency, true);
}
}
Expand Down Expand Up @@ -1395,7 +1390,7 @@ std::string Generator::GeneratorImpl::generateVariableNameCode(const AnalyserMod
return mProfile->voiString();
}

if (isUntrackedVariable(analyserVariable)) {
if (doIsTrackedVariable(model, analyserVariable, false)) {
return owningComponent(analyserVariable->variable())->name() + "_" + analyserVariable->variable()->name();
}

Expand Down Expand Up @@ -2105,7 +2100,7 @@ std::string Generator::GeneratorImpl::generateCode(const AnalyserModelPtr &model
if ((model != nullptr)
&& (astParent->type() == AnalyserEquationAst::Type::EQUALITY)
&& (astParent->leftChild() == ast)
&& isUntrackedVariable(model->variable(ast->variable()))) {
&& doIsTrackedVariable(model, model->variable(ast->variable()), false)) {
// Note: we want this AST to be its parent's left child since a declaration is always of the form x = RHS,
// not LHS = x.

Expand Down Expand Up @@ -2166,7 +2161,7 @@ bool Generator::GeneratorImpl::isToBeComputedAgain(const AnalyserEquationPtr &eq
case AnalyserEquation::Type::ALGEBRAIC:
if (equation->isStateRateBased()) {
for (const auto &variable : variables(equation)) {
if (isTrackedVariable(variable)) {
if (doIsTrackedVariable(variable->model(), variable)) {
return true;
}
}
Expand Down Expand Up @@ -2202,7 +2197,7 @@ std::string Generator::GeneratorImpl::generateZeroInitialisationCode(const Analy
std::string Generator::GeneratorImpl::generateInitialisationCode(const AnalyserModelPtr &model,
const AnalyserVariablePtr &variable, bool force)
{
if (!force && isUntrackedVariable(variable)) {
if (!force && doIsTrackedVariable(model, variable, false)) {
return {};
}

Expand All @@ -2219,7 +2214,7 @@ std::string Generator::GeneratorImpl::generateInitialisationCode(const AnalyserM
+ scalingFactorCode + generateDoubleOrConstantVariableNameCode(model, initialisingVariable)
+ mProfile->commandSeparatorString() + "\n";

if (isUntrackedVariable(variable)) {
if (doIsTrackedVariable(model, variable, false)) {
code = replace(mProfile->variableDeclarationString(), "[CODE]", code);
}

Expand Down Expand Up @@ -2253,7 +2248,7 @@ std::string Generator::GeneratorImpl::generateEquationCode(const AnalyserModelPt

for (const auto &constantDependency : equation->mPimpl->mConstantDependencies) {
if ((equation->type() != AnalyserEquation::Type::NLA)
&& isUntrackedVariable(constantDependency)
&& doIsTrackedVariable(model, constantDependency, false)
&& (std::find(generatedConstantDependencies.begin(), generatedConstantDependencies.end(), constantDependency) == generatedConstantDependencies.end())) {
res += generateInitialisationCode(model, constantDependency, true);

Expand Down
4 changes: 2 additions & 2 deletions src/generator_p.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ struct Generator::GeneratorImpl: public Logger::LoggerImpl

void reset();

std::string doVariableIndexString(const AnalyserModelPtr &model, const AnalyserVariablePtr &variable,
const std::vector<AnalyserVariablePtr> &variables);
std::string variableIndexString(const AnalyserModelPtr &model, const AnalyserVariablePtr &variable);

bool doIsTrackedEquation(const AnalyserEquationPtr &equation, bool tracked);
Expand Down Expand Up @@ -82,6 +80,8 @@ struct Generator::GeneratorImpl: public Logger::LoggerImpl
void trackAllAlgebraic(const AnalyserModelPtr &model);
void untrackAllAlgebraic(const AnalyserModelPtr &model);

std::vector<AnalyserVariablePtr> trackableVariables(const AnalyserModelPtr &model) const;

void trackAllVariables(const AnalyserModelPtr &model);
void untrackAllVariables(const AnalyserModelPtr &model);

Expand Down
45 changes: 32 additions & 13 deletions src/utilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.

#include "libcellml/analyserequation.h"
#include "libcellml/analysermodel.h"
#include "libcellml/analyservariable.h"
#include "libcellml/component.h"
#include "libcellml/importsource.h"
#include "libcellml/model.h"
Expand Down Expand Up @@ -1317,33 +1318,51 @@ XmlNodePtr mathmlChildNode(const XmlNodePtr &node, size_t index)
return res;
}

std::vector<AnalyserVariablePtr> variables(const AnalyserModelPtr &model, bool allVariables)
std::vector<AnalyserVariablePtr> variables(const AnalyserVariablePtr &variable)
{
std::vector<AnalyserVariablePtr> res;

if (allVariables) {
if (model->voi() != nullptr) {
res.push_back(model->voi());
}
switch (variable->type()) {
case AnalyserVariable::Type::CONSTANT:
return variable->model()->constants();

break;
case AnalyserVariable::Type::COMPUTED_CONSTANT:
return variable->model()->computedConstants();

auto states = model->states();
break;
case AnalyserVariable::Type::ALGEBRAIC:
return variable->model()->algebraic();

res.insert(res.end(), states.begin(), states.end());
break;
default:
break;
}

return {};
}

std::vector<AnalyserVariablePtr> variables(const AnalyserModelPtr &model)
{
std::vector<AnalyserVariablePtr> res;

if (model->voi() != nullptr) {
res.push_back(model->voi());
}

auto states = model->states();

res.insert(res.end(), states.begin(), states.end());

auto constants = model->constants();
auto computedConstants = model->computedConstants();
auto algebraic = model->algebraic();
auto externals = model->externals();

res.insert(res.end(), constants.begin(), constants.end());
res.insert(res.end(), computedConstants.begin(), computedConstants.end());
res.insert(res.end(), algebraic.begin(), algebraic.end());

if (allVariables) {
auto externals = model->externals();

res.insert(res.end(), externals.begin(), externals.end());
}
res.insert(res.end(), externals.begin(), externals.end());

return res;
}
Expand Down
14 changes: 12 additions & 2 deletions src/utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -870,17 +870,27 @@ size_t mathmlChildCount(const XmlNodePtr &node);
*/
XmlNodePtr mathmlChildNode(const XmlNodePtr &node, size_t index);

/**
* @brief Return the variables of the same type as the given variable.
*
* Return the variables of the same type as the given variable.
*
* @param variable The variable for which we want the variables of the same type.
*
* @return The variables of the same type as the given variable.
*/
std::vector<AnalyserVariablePtr> variables(const AnalyserVariablePtr &variable);

/**
* @brief Return the variables in the given model.
*
* Return the variables in the given model.
*
* @param model The model for which we want the variables.
* @param allVariables Whether to return all variables or just the ones that can be untracked.
*
* @return The variables in the given model.
*/
std::vector<AnalyserVariablePtr> variables(const AnalyserModelPtr &model, bool allVariables = true);
std::vector<AnalyserVariablePtr> variables(const AnalyserModelPtr &model);

/**
* @brief Return the variables in the given equation.
Expand Down

0 comments on commit be1c5a0

Please sign in to comment.