Skip to content

Commit

Permalink
Generator: generated some wrong indices when untracking some variables.
Browse files Browse the repository at this point in the history
  • Loading branch information
agarny committed Nov 4, 2024
1 parent b4936d8 commit 66ef714
Show file tree
Hide file tree
Showing 12 changed files with 173 additions and 127 deletions.
56 changes: 49 additions & 7 deletions src/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,48 @@ void Generator::GeneratorImpl::reset()
mCode = {};
}

std::string Generator::GeneratorImpl::doVariableIndexString(const AnalyserModelPtr &model,
const AnalyserVariablePtr &variable,
const std::vector<AnalyserVariablePtr> &variables)
{
// Determine the actual index of the variable in the list of variables by accounting for the fact that some
// variables may be untracked.

size_t i = MAX_SIZE_T;
size_t res = MAX_SIZE_T;

for (;;) {
auto var = variables[++i];

if (doIsTrackedVariable(model, var)) {
++res;
}

if (variable == var) {
break;
}
}

return convertToString(res);
}

std::string Generator::GeneratorImpl::variableIndexString(const AnalyserModelPtr &model,
const AnalyserVariablePtr &variable)
{
switch (variable->type()) {
case AnalyserVariable::Type::CONSTANT:
return doVariableIndexString(model, variable, model->constants());
case AnalyserVariable::Type::COMPUTED_CONSTANT:
return doVariableIndexString(model, variable, model->computedConstants());
case AnalyserVariable::Type::ALGEBRAIC:
return doVariableIndexString(model, variable, model->algebraic());
default:
break;
}

return convertToString(variable->index());
}

bool Generator::GeneratorImpl::doIsTrackedEquation(const AnalyserEquationPtr &equation, bool tracked)
{
switch (equation->type()) {
Expand Down Expand Up @@ -1177,7 +1219,7 @@ void Generator::GeneratorImpl::addNlaSystemsCode(const AnalyserModelPtr &model)
mProfile->algebraicArrayString();

methodBody += mProfile->indentString()
+ arrayString + mProfile->openArrayString() + convertToString(variable->index()) + mProfile->closeArrayString()
+ arrayString + mProfile->openArrayString() + variableIndexString(model, variable) + mProfile->closeArrayString()
+ mProfile->equalityString()
+ mProfile->uArrayString() + mProfile->openArrayString() + convertToString(++i) + mProfile->closeArrayString()
+ mProfile->commandSeparatorString() + "\n";
Expand Down Expand Up @@ -1255,7 +1297,7 @@ void Generator::GeneratorImpl::addNlaSystemsCode(const AnalyserModelPtr &model)
methodBody += mProfile->indentString()
+ mProfile->uArrayString() + mProfile->openArrayString() + convertToString(++i) + mProfile->closeArrayString()
+ mProfile->equalityString()
+ arrayString + mProfile->openArrayString() + convertToString(variable->index()) + mProfile->closeArrayString()
+ arrayString + mProfile->openArrayString() + variableIndexString(model, variable) + mProfile->closeArrayString()
+ mProfile->commandSeparatorString() + "\n";
}

Expand All @@ -1281,7 +1323,7 @@ void Generator::GeneratorImpl::addNlaSystemsCode(const AnalyserModelPtr &model)
mProfile->algebraicArrayString();

methodBody += mProfile->indentString()
+ arrayString + mProfile->openArrayString() + convertToString(variable->index()) + mProfile->closeArrayString()
+ arrayString + mProfile->openArrayString() + variableIndexString(model, variable) + mProfile->closeArrayString()
+ mProfile->equalityString()
+ mProfile->uArrayString() + mProfile->openArrayString() + convertToString(++i) + mProfile->closeArrayString()
+ mProfile->commandSeparatorString() + "\n";
Expand Down Expand Up @@ -1322,7 +1364,7 @@ std::string generateDoubleCode(const std::string &value)
}

std::string Generator::GeneratorImpl::generateDoubleOrConstantVariableNameCode(const AnalyserModelPtr &model,
const VariablePtr &variable) const
const VariablePtr &variable)
{
if (isCellMLReal(variable->initialValue())) {
return generateDoubleCode(variable->initialValue());
Expand All @@ -1331,7 +1373,7 @@ std::string Generator::GeneratorImpl::generateDoubleOrConstantVariableNameCode(c
auto initialValueVariable = owningComponent(variable)->variable(variable->initialValue());
auto analyserInitialValueVariable = model->variable(initialValueVariable);

return mProfile->constantsArrayString() + mProfile->openArrayString() + convertToString(analyserInitialValueVariable->index()) + mProfile->closeArrayString();
return mProfile->constantsArrayString() + mProfile->openArrayString() + variableIndexString(model, analyserInitialValueVariable) + mProfile->closeArrayString();
}

std::string Generator::GeneratorImpl::generateVariableNameCode(const AnalyserModelPtr &model,
Expand Down Expand Up @@ -1373,7 +1415,7 @@ std::string Generator::GeneratorImpl::generateVariableNameCode(const AnalyserMod
arrayName = mProfile->externalArrayString();
}

return arrayName + mProfile->openArrayString() + convertToString(analyserVariable->index()) + mProfile->closeArrayString();
return arrayName + mProfile->openArrayString() + variableIndexString(model, analyserVariable) + mProfile->closeArrayString();
}

std::string Generator::GeneratorImpl::generateOperatorCode(const AnalyserModelPtr &model, const std::string &op,
Expand Down Expand Up @@ -2252,7 +2294,7 @@ std::string Generator::GeneratorImpl::generateEquationCode(const AnalyserModelPt
+ generateVariableNameCode(model, variable->variable())
+ mProfile->equalityString()
+ replace(mProfile->externalVariableMethodCallString(modelHasOdes(model)),
"[INDEX]", convertToString(variable->index()))
"[INDEX]", variableIndexString(model, variable))
+ mProfile->commandSeparatorString() + "\n";
}

Expand Down
8 changes: 6 additions & 2 deletions src/generator_p.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,16 @@ struct Generator::GeneratorImpl: public Logger::LoggerImpl

void reset();

std::string doVariableIndexString(const AnalyserModelPtr &model, const AnalyserVariablePtr &variable,
const std::vector<AnalyserVariablePtr> &variables);
std::string variableIndexString(const AnalyserModelPtr &model, const AnalyserVariablePtr &variable);

bool doIsTrackedEquation(const AnalyserEquationPtr &equation, bool tracked);

bool isTrackedEquation(const AnalyserEquationPtr &equation);
bool isUntrackedEquation(const AnalyserEquationPtr &equation);

bool doIsTrackedVariable(const AnalyserModelPtr &model, const AnalyserVariablePtr &variable, bool tracked);
bool doIsTrackedVariable(const AnalyserModelPtr &model, const AnalyserVariablePtr &variable, bool tracked = true);
bool doIsTrackedVariable(const AnalyserVariablePtr &variable, bool tracked);

bool isTrackedVariable(const AnalyserVariablePtr &variable);
Expand Down Expand Up @@ -161,7 +165,7 @@ struct Generator::GeneratorImpl: public Logger::LoggerImpl
std::string generateMethodBodyCode(const std::string &methodBody) const;

std::string generateDoubleOrConstantVariableNameCode(const AnalyserModelPtr &model,
const VariablePtr &variable) const;
const VariablePtr &variable);
std::string generateVariableNameCode(const AnalyserModelPtr &model, const VariablePtr &variable,
bool state = true);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,21 +274,21 @@ void objectiveFunction6(double *u, double *f, void *data)
double *computedConstants = ((RootFindingInfo *) data)->computedConstants;
double *algebraic = ((RootFindingInfo *) data)->algebraic;

algebraic[7] = u[0];
algebraic[6] = u[0];

f[0] = algebraic[7]-4.0*exp(states[0]/18.0)-0.0;
f[0] = algebraic[6]-4.0*exp(states[0]/18.0)-0.0;
}

void findRoot6(double voi, double *states, double *rates, double *constants, double *computedConstants, double *algebraic)
{
RootFindingInfo rfi = { voi, states, rates, constants, computedConstants, algebraic };
double u[1];

u[0] = algebraic[7];
u[0] = algebraic[6];

nlaSolve(objectiveFunction6, u, 1, &rfi);

algebraic[7] = u[0];
algebraic[6] = u[0];
}

void objectiveFunction7(double *u, double *f, void *data)
Expand All @@ -304,7 +304,7 @@ void objectiveFunction7(double *u, double *f, void *data)

double sodium_channel_m_gate_alpha_m = 0.1*(states[0]+25.0)/(exp((states[0]+25.0)/10.0)-1.0);

f[0] = rates[2]-(sodium_channel_m_gate_alpha_m*(1.0-states[2])-algebraic[7]*states[2])-0.0;
f[0] = rates[2]-(sodium_channel_m_gate_alpha_m*(1.0-states[2])-algebraic[6]*states[2])-0.0;
}

void findRoot7(double voi, double *states, double *rates, double *constants, double *computedConstants, double *algebraic)
Expand All @@ -328,21 +328,21 @@ void objectiveFunction8(double *u, double *f, void *data)
double *computedConstants = ((RootFindingInfo *) data)->computedConstants;
double *algebraic = ((RootFindingInfo *) data)->algebraic;

algebraic[8] = u[0];
algebraic[7] = u[0];

f[0] = algebraic[8]-0.07*exp(states[0]/20.0)-0.0;
f[0] = algebraic[7]-0.07*exp(states[0]/20.0)-0.0;
}

void findRoot8(double voi, double *states, double *rates, double *constants, double *computedConstants, double *algebraic)
{
RootFindingInfo rfi = { voi, states, rates, constants, computedConstants, algebraic };
double u[1];

u[0] = algebraic[8];
u[0] = algebraic[7];

nlaSolve(objectiveFunction8, u, 1, &rfi);

algebraic[8] = u[0];
algebraic[7] = u[0];
}

void objectiveFunction9(double *u, double *f, void *data)
Expand All @@ -354,21 +354,21 @@ void objectiveFunction9(double *u, double *f, void *data)
double *computedConstants = ((RootFindingInfo *) data)->computedConstants;
double *algebraic = ((RootFindingInfo *) data)->algebraic;

algebraic[9] = u[0];
algebraic[8] = u[0];

f[0] = algebraic[9]-1.0/(exp((states[0]+30.0)/10.0)+1.0)-0.0;
f[0] = algebraic[8]-1.0/(exp((states[0]+30.0)/10.0)+1.0)-0.0;
}

void findRoot9(double voi, double *states, double *rates, double *constants, double *computedConstants, double *algebraic)
{
RootFindingInfo rfi = { voi, states, rates, constants, computedConstants, algebraic };
double u[1];

u[0] = algebraic[9];
u[0] = algebraic[8];

nlaSolve(objectiveFunction9, u, 1, &rfi);

algebraic[9] = u[0];
algebraic[8] = u[0];
}

void objectiveFunction10(double *u, double *f, void *data)
Expand All @@ -382,7 +382,7 @@ void objectiveFunction10(double *u, double *f, void *data)

rates[1] = u[0];

f[0] = rates[1]-(algebraic[8]*(1.0-states[1])-algebraic[9]*states[1])-0.0;
f[0] = rates[1]-(algebraic[7]*(1.0-states[1])-algebraic[8]*states[1])-0.0;
}

void findRoot10(double voi, double *states, double *rates, double *constants, double *computedConstants, double *algebraic)
Expand Down Expand Up @@ -432,21 +432,21 @@ void objectiveFunction12(double *u, double *f, void *data)
double *computedConstants = ((RootFindingInfo *) data)->computedConstants;
double *algebraic = ((RootFindingInfo *) data)->algebraic;

algebraic[10] = u[0];
algebraic[9] = u[0];

f[0] = algebraic[10]-0.01*(states[0]+10.0)/(exp((states[0]+10.0)/10.0)-1.0)-0.0;
f[0] = algebraic[9]-0.01*(states[0]+10.0)/(exp((states[0]+10.0)/10.0)-1.0)-0.0;
}

void findRoot12(double voi, double *states, double *rates, double *constants, double *computedConstants, double *algebraic)
{
RootFindingInfo rfi = { voi, states, rates, constants, computedConstants, algebraic };
double u[1];

u[0] = algebraic[10];
u[0] = algebraic[9];

nlaSolve(objectiveFunction12, u, 1, &rfi);

algebraic[10] = u[0];
algebraic[9] = u[0];
}

void objectiveFunction13(double *u, double *f, void *data)
Expand All @@ -458,21 +458,21 @@ void objectiveFunction13(double *u, double *f, void *data)
double *computedConstants = ((RootFindingInfo *) data)->computedConstants;
double *algebraic = ((RootFindingInfo *) data)->algebraic;

algebraic[11] = u[0];
algebraic[10] = u[0];

f[0] = algebraic[11]-0.125*exp(states[0]/80.0)-0.0;
f[0] = algebraic[10]-0.125*exp(states[0]/80.0)-0.0;
}

void findRoot13(double voi, double *states, double *rates, double *constants, double *computedConstants, double *algebraic)
{
RootFindingInfo rfi = { voi, states, rates, constants, computedConstants, algebraic };
double u[1];

u[0] = algebraic[11];
u[0] = algebraic[10];

nlaSolve(objectiveFunction13, u, 1, &rfi);

algebraic[11] = u[0];
algebraic[10] = u[0];
}

void objectiveFunction14(double *u, double *f, void *data)
Expand All @@ -486,7 +486,7 @@ void objectiveFunction14(double *u, double *f, void *data)

rates[3] = u[0];

f[0] = rates[3]-(algebraic[10]*(1.0-states[3])-algebraic[11]*states[3])-0.0;
f[0] = rates[3]-(algebraic[9]*(1.0-states[3])-algebraic[10]*states[3])-0.0;
}

void findRoot14(double voi, double *states, double *rates, double *constants, double *computedConstants, double *algebraic)
Expand Down Expand Up @@ -522,11 +522,11 @@ void initialiseVariables(double *states, double *rates, double *constants, doubl
algebraic[3] = 0.0;
algebraic[4] = 0.0;
algebraic[5] = 0.0;
algebraic[6] = 0.0;
algebraic[7] = 0.0;
algebraic[8] = 0.0;
algebraic[9] = 0.0;
algebraic[10] = 0.0;
algebraic[11] = 0.0;
}

void computeComputedConstants(double *constants, double *computedConstants)
Expand Down
Loading

0 comments on commit 66ef714

Please sign in to comment.