Skip to content

Commit

Permalink
Updates C++ extension and add context manager to print message events (
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenNneji authored Sep 2, 2024
1 parent b53fe23 commit 06ecccb
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 74 deletions.
2 changes: 2 additions & 0 deletions RATapi/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,5 +468,7 @@ def make_controls(input_controls: RATapi.Controls, checks: Checks) -> Control:
controls.adaptPCR = input_controls.adaptPCR
# Checks
controls.checks = checks
# IPC
controls.IPCFilePath = ""

return controls
35 changes: 34 additions & 1 deletion RATapi/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,39 @@ def __exit__(self, _exc_type, _exc_val, _traceback):
RATapi.events.clear(RATapi.events.EventTypes.Progress, self.updateProgress)


class TextOutput:
"""Pipes the message event to stdout
Parameters
----------
display : bool, default: True
Indicates if displaying is allowed
"""

def __init__(self, display=True):
self.display = display

def __enter__(self):
if self.display:
RATapi.events.register(RATapi.events.EventTypes.Message, self.printMessage)

return self

def printMessage(self, msg):
"""Callback for the message event.
Parameters
----------
msg: str
The event message.
"""
print(msg, end="")

def __exit__(self, _exc_type, _exc_val, _traceback):
if self.display:
RATapi.events.clear(RATapi.events.EventTypes.Message, self.printMessage)


def run(project, controls):
"""Run RAT for the given project and controls inputs."""
parameter_field = {
Expand All @@ -77,7 +110,7 @@ def run(project, controls):
print("Starting RAT " + horizontal_line)

start = time.time()
with ProgressBar(display=display_on):
with ProgressBar(display=display_on), TextOutput(display=display_on):
problem_definition, output_results, bayes_results = RATapi.rat_core.RATMain(
problem_definition,
cells,
Expand Down
2 changes: 1 addition & 1 deletion cpp/RAT
Submodule RAT updated 115 files
170 changes: 99 additions & 71 deletions cpp/rat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,7 @@ struct Control {
std::string boundHandling {};
boolean_T adaptPCR;
Checks checks {};
std::string IPCFilePath {};
};


Expand All @@ -541,7 +542,27 @@ void stringToRatCharArray(std::string value, coder::array<char_T, 2U>& result)
}
}

coder::array<real_T, 2U> pyArrayToRatArray1d(py::array_t<real_T> value)
coder::array<real_T, 1U> pyArrayToRatRowArray1d(py::array_t<real_T> value)
{
coder::array<real_T, 1U> result;

py::buffer_info buffer_info = value.request();

if (buffer_info.size == 0)
return result;

if (buffer_info.ndim != 1)
throw std::runtime_error("Expects a 1D numeric array");

result.set_size(buffer_info.shape[0]);
for (int32_T idx0{0}; idx0 < buffer_info.shape[0]; idx0++) {
result[idx0] = value.at(idx0);
}

return result;
}

coder::array<real_T, 2U> pyArrayToRatColArray1d(py::array_t<real_T> value)
{
coder::array<real_T, 2U> result;

Expand Down Expand Up @@ -619,8 +640,9 @@ coder::array<RAT::cell_0, 1U> pyListToUnboundedCell0(py::list values)
!py::isinstance<py::float_>(value[2]) || !py::isinstance<py::float_>(value[3]))
throw std::runtime_error("Expects a 2D list where each row must contain 4 elements. "
"Columns 1 and 2 must be strings and Columns 3 and 4 must be numeric arrays");
stringToRatCharArray(value[0].cast<std::string>(), result[idx].f1);
stringToRatCharArray(value[1].cast<std::string>(), result[idx].f2);

stringToRatArray(value[0].cast<std::string>(), result[idx].f1.data, result[idx].f1.size);
stringToRatArray(value[1].cast<std::string>(), result[idx].f2.data, result[idx].f2.size);
result[idx].f3 = value[2].cast<real_T>();
result[idx].f4 = value[3].cast<real_T>();
idx++;
Expand Down Expand Up @@ -662,30 +684,30 @@ RAT::struct0_T createStruct0(const ProblemDefinition& problem)
stringToRatArray(problem.geometry, problem_struct.geometry.data, problem_struct.geometry.size);
stringToRatArray(problem.TF, problem_struct.TF.data, problem_struct.TF.size);

problem_struct.contrastBackgroundParams = customCaller("Problem.contrastBackgroundParams", pyArrayToRatArray1d, problem.contrastBackgroundParams);
problem_struct.contrastBackgroundActions = customCaller("Problem.contrastBackgroundActions", pyArrayToRatArray1d, problem.contrastBackgroundActions);
problem_struct.resample = customCaller("Problem.resample", pyArrayToRatArray1d, problem.resample);
problem_struct.dataPresent = customCaller("Problem.dataPresent", pyArrayToRatArray1d, problem.dataPresent);
problem_struct.oilChiDataPresent = customCaller("Problem.oilChiDataPresent", pyArrayToRatArray1d, problem.oilChiDataPresent);
problem_struct.contrastQzshifts = customCaller("Problem.contrastQzshifts", pyArrayToRatArray1d, problem.contrastQzshifts);
problem_struct.contrastScalefactors = customCaller("Problem.contrastScalefactors", pyArrayToRatArray1d, problem.contrastScalefactors);
problem_struct.contrastBulkIns = customCaller("Problem.contrastBulkIns", pyArrayToRatArray1d, problem.contrastBulkIns);
problem_struct.contrastBulkOuts = customCaller("Problem.contrastBulkOuts", pyArrayToRatArray1d, problem.contrastBulkOuts);
problem_struct.contrastResolutionParams = customCaller("Problem.contrastResolutionParams", pyArrayToRatArray1d, problem.contrastResolutionParams);
problem_struct.backgroundParams = customCaller("Problem.backgroundParams", pyArrayToRatArray1d, problem.backgroundParams);
problem_struct.qzshifts = customCaller("Problem.qzshifts", pyArrayToRatArray1d, problem.qzshifts);
problem_struct.scalefactors = customCaller("Problem.scalefactors", pyArrayToRatArray1d, problem.scalefactors);
problem_struct.bulkIn = customCaller("Problem.bulkIn", pyArrayToRatArray1d, problem.bulkIn);
problem_struct.bulkOut = customCaller("Problem.bulkOut", pyArrayToRatArray1d, problem.bulkOut);
problem_struct.resolutionParams = customCaller("Problem.resolutionParams", pyArrayToRatArray1d, problem.resolutionParams);
problem_struct.params = customCaller("Problem.params", pyArrayToRatArray1d, problem.params);

problem_struct.contrastCustomFiles = customCaller("Problem.contrastCustomFiles", pyArrayToRatArray1d, problem.contrastCustomFiles);
problem_struct.contrastDomainRatios = customCaller("Problem.contrastDomainRatios", pyArrayToRatArray1d, problem.contrastDomainRatios);
problem_struct.domainRatio = customCaller("Problem.domainRatio", pyArrayToRatArray1d, problem.domainRatio);

problem_struct.fitParams = customCaller("Problem.fitParams", pyArrayToRatArray1d, problem.fitParams);
problem_struct.otherParams = customCaller("Problem.otherParams", pyArrayToRatArray1d, problem.otherParams);
problem_struct.contrastBackgroundParams = customCaller("Problem.contrastBackgroundParams", pyArrayToRatColArray1d, problem.contrastBackgroundParams);
problem_struct.contrastBackgroundActions = customCaller("Problem.contrastBackgroundActions", pyArrayToRatColArray1d, problem.contrastBackgroundActions);
problem_struct.resample = customCaller("Problem.resample", pyArrayToRatColArray1d, problem.resample);
problem_struct.dataPresent = customCaller("Problem.dataPresent", pyArrayToRatColArray1d, problem.dataPresent);
problem_struct.oilChiDataPresent = customCaller("Problem.oilChiDataPresent", pyArrayToRatColArray1d, problem.oilChiDataPresent);
problem_struct.contrastQzshifts = customCaller("Problem.contrastQzshifts", pyArrayToRatColArray1d, problem.contrastQzshifts);
problem_struct.contrastScalefactors = customCaller("Problem.contrastScalefactors", pyArrayToRatColArray1d, problem.contrastScalefactors);
problem_struct.contrastBulkIns = customCaller("Problem.contrastBulkIns", pyArrayToRatColArray1d, problem.contrastBulkIns);
problem_struct.contrastBulkOuts = customCaller("Problem.contrastBulkOuts", pyArrayToRatColArray1d, problem.contrastBulkOuts);
problem_struct.contrastResolutionParams = customCaller("Problem.contrastResolutionParams", pyArrayToRatColArray1d, problem.contrastResolutionParams);
problem_struct.backgroundParams = customCaller("Problem.backgroundParams", pyArrayToRatColArray1d, problem.backgroundParams);
problem_struct.qzshifts = customCaller("Problem.qzshifts", pyArrayToRatColArray1d, problem.qzshifts);
problem_struct.scalefactors = customCaller("Problem.scalefactors", pyArrayToRatColArray1d, problem.scalefactors);
problem_struct.bulkIn = customCaller("Problem.bulkIn", pyArrayToRatColArray1d, problem.bulkIn);
problem_struct.bulkOut = customCaller("Problem.bulkOut", pyArrayToRatColArray1d, problem.bulkOut);
problem_struct.resolutionParams = customCaller("Problem.resolutionParams", pyArrayToRatColArray1d, problem.resolutionParams);
problem_struct.params = customCaller("Problem.params", pyArrayToRatColArray1d, problem.params);

problem_struct.contrastCustomFiles = customCaller("Problem.contrastCustomFiles", pyArrayToRatColArray1d, problem.contrastCustomFiles);
problem_struct.contrastDomainRatios = customCaller("Problem.contrastDomainRatios", pyArrayToRatColArray1d, problem.contrastDomainRatios);
problem_struct.domainRatio = customCaller("Problem.domainRatio", pyArrayToRatColArray1d, problem.domainRatio);

problem_struct.fitParams = customCaller("Problem.fitParams", pyArrayToRatRowArray1d, problem.fitParams);
problem_struct.otherParams = customCaller("Problem.otherParams", pyArrayToRatRowArray1d, problem.otherParams);
problem_struct.fitLimits = customCaller("Problem.fitLimits", pyArrayToRatArray2d, problem.fitLimits);
problem_struct.otherLimits = customCaller("Problem.otherLimits", pyArrayToRatArray2d, problem.otherLimits);

Expand All @@ -710,14 +732,14 @@ RAT::struct1_T createStruct1(const Limits& limits)
RAT::struct3_T createStruct3(const Checks& checks)
{
RAT::struct3_T checks_struct;
checks_struct.fitParam = customCaller("Checks.fitParam", pyArrayToRatArray1d, checks.fitParam);
checks_struct.fitBackgroundParam = customCaller("Checks.fitBackgroundParam", pyArrayToRatArray1d, checks.fitBackgroundParam);
checks_struct.fitQzshift = customCaller("Checks.fitQzshift", pyArrayToRatArray1d, checks.fitQzshift);
checks_struct.fitScalefactor = customCaller("Checks.fitScalefactor", pyArrayToRatArray1d, checks.fitScalefactor);
checks_struct.fitBulkIn = customCaller("Checks.fitBulkIn", pyArrayToRatArray1d, checks.fitBulkIn);
checks_struct.fitBulkOut = customCaller("Checks.fitBulkOut", pyArrayToRatArray1d, checks.fitBulkOut);
checks_struct.fitResolutionParam = customCaller("Checks.fitResolutionParam", pyArrayToRatArray1d, checks.fitResolutionParam);
checks_struct.fitDomainRatio = customCaller("Checks.fitDomainRatio", pyArrayToRatArray1d, checks.fitDomainRatio);
checks_struct.fitParam = customCaller("Checks.fitParam", pyArrayToRatColArray1d, checks.fitParam);
checks_struct.fitBackgroundParam = customCaller("Checks.fitBackgroundParam", pyArrayToRatColArray1d, checks.fitBackgroundParam);
checks_struct.fitQzshift = customCaller("Checks.fitQzshift", pyArrayToRatColArray1d, checks.fitQzshift);
checks_struct.fitScalefactor = customCaller("Checks.fitScalefactor", pyArrayToRatColArray1d, checks.fitScalefactor);
checks_struct.fitBulkIn = customCaller("Checks.fitBulkIn", pyArrayToRatColArray1d, checks.fitBulkIn);
checks_struct.fitBulkOut = customCaller("Checks.fitBulkOut", pyArrayToRatColArray1d, checks.fitBulkOut);
checks_struct.fitResolutionParam = customCaller("Checks.fitResolutionParam", pyArrayToRatColArray1d, checks.fitResolutionParam);
checks_struct.fitDomainRatio = customCaller("Checks.fitDomainRatio", pyArrayToRatColArray1d, checks.fitDomainRatio);

return checks_struct;
}
Expand Down Expand Up @@ -780,7 +802,7 @@ coder::array<RAT::cell_wrap_4, 2U> pyListToRatCellWrap4(py::list values)
for (py::handle array: values)
{
py::array_t<real_T> casted_array = py::cast<py::array>(array);
result[idx].f1 = customCaller("$id[" + std::to_string(idx) +"]", pyArrayToRatArray1d, casted_array);
result[idx].f1 = customCaller("$id[" + std::to_string(idx) +"]", pyArrayToRatColArray1d, casted_array);
idx++;
}

Expand Down Expand Up @@ -897,13 +919,16 @@ RAT::struct2_T createStruct2T(const Control& control)
stringToRatArray(control.boundHandling, control_struct.boundHandling.data, control_struct.boundHandling.size);
control_struct.adaptPCR = control.adaptPCR;
control_struct.checks = createStruct3(control.checks);
stringToRatArray(control.IPCFilePath, control_struct.IPCFilePath.data, control_struct.IPCFilePath.size);

return control_struct;
}

py::array_t<real_T> pyArrayFromRatArray1d(coder::array<real_T, 2U> array)

template <typename T>
py::array_t<real_T> pyArrayFromRatArray1d(T array, bool isCol=true)
{
auto size = (array.size(0) > 1) ? array.size(0) : array.size(1);
auto size = isCol ? array.size(1) : array.size(0);
auto result_array = py::array_t<real_T>(size);
std::memcpy(result_array.request().ptr, array.data(), result_array.nbytes());

Expand Down Expand Up @@ -1040,37 +1065,38 @@ ProblemDefinition problemDefinitionFromStruct0T(const RAT::struct0_T problem)
problem_def.TF.resize(problem.TF.size[1]);
memcpy(&problem_def.TF[0], problem.TF.data, problem.TF.size[1]);

problem_def.contrastBackgroundParams = pyArrayFromRatArray1d(problem.contrastBackgroundParams);
problem_def.contrastBackgroundActions = pyArrayFromRatArray1d(problem.contrastBackgroundActions);
problem_def.resample = pyArrayFromRatArray1d(problem.resample);
problem_def.dataPresent = pyArrayFromRatArray1d(problem.dataPresent);
problem_def.oilChiDataPresent = pyArrayFromRatArray1d(problem.oilChiDataPresent);
problem_def.contrastQzshifts = pyArrayFromRatArray1d(problem.contrastQzshifts);
problem_def.contrastScalefactors = pyArrayFromRatArray1d(problem.contrastScalefactors);
problem_def.contrastBulkIns = pyArrayFromRatArray1d(problem.contrastBulkIns);
problem_def.contrastBulkOuts = pyArrayFromRatArray1d(problem.contrastBulkOuts);
problem_def.contrastResolutionParams = pyArrayFromRatArray1d(problem.contrastResolutionParams);
problem_def.backgroundParams = pyArrayFromRatArray1d(problem.backgroundParams);
problem_def.qzshifts = pyArrayFromRatArray1d(problem.qzshifts);
problem_def.scalefactors = pyArrayFromRatArray1d(problem.scalefactors);
problem_def.bulkIn = pyArrayFromRatArray1d(problem.bulkIn);
problem_def.bulkOut = pyArrayFromRatArray1d(problem.bulkOut);
problem_def.resolutionParams = pyArrayFromRatArray1d(problem.resolutionParams);
problem_def.params = pyArrayFromRatArray1d(problem.params);

problem_def.contrastCustomFiles = pyArrayFromRatArray1d(problem.contrastCustomFiles);
problem_def.contrastDomainRatios = pyArrayFromRatArray1d(problem.contrastDomainRatios);
problem_def.domainRatio = pyArrayFromRatArray1d(problem.domainRatio);

problem_def.fitParams = pyArrayFromRatArray1d(problem.fitParams);
problem_def.otherParams = pyArrayFromRatArray1d(problem.otherParams);
problem_def.fitLimits = pyArrayFromRatArray2d(problem.fitLimits);
problem_def.otherLimits = pyArrayFromRatArray2d(problem.otherLimits);
problem_def.contrastBackgroundParams = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.contrastBackgroundParams);
problem_def.contrastBackgroundActions = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.contrastBackgroundActions);
problem_def.resample = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.resample);
problem_def.dataPresent = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.dataPresent);
problem_def.oilChiDataPresent = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.oilChiDataPresent);
problem_def.contrastQzshifts = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.contrastQzshifts);
problem_def.contrastScalefactors = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.contrastScalefactors);
problem_def.contrastBulkIns = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.contrastBulkIns);
problem_def.contrastBulkOuts = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.contrastBulkOuts);
problem_def.contrastResolutionParams = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.contrastResolutionParams);
problem_def.backgroundParams = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.backgroundParams);
problem_def.qzshifts = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.qzshifts);
problem_def.scalefactors = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.scalefactors);
problem_def.bulkIn = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.bulkIn);
problem_def.bulkOut = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.bulkOut);
problem_def.resolutionParams = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.resolutionParams);
problem_def.params = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.params);

problem_def.contrastCustomFiles = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.contrastCustomFiles);
problem_def.contrastDomainRatios = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.contrastDomainRatios);
problem_def.domainRatio = pyArrayFromRatArray1d<coder::array<real_T, 2U>>(problem.domainRatio);

problem_def.fitParams = pyArrayFromRatArray1d<coder::array<real_T, 1U>>(problem.fitParams, false);
problem_def.otherParams = pyArrayFromRatArray1d<coder::array<real_T, 1U>>(problem.otherParams, false);
problem_def.fitLimits = pyArrayFromRatArray2d(problem.fitLimits);
problem_def.otherLimits = pyArrayFromRatArray2d(problem.otherLimits);

return problem_def;
}

py::list pyList1DFromRatCellWrap(const coder::array<RAT::cell_wrap_10, 1U>& values)
template <typename T>
py::list pyList1DFromRatCellWrap(const T& values)
{
py::list result;

Expand All @@ -1081,7 +1107,8 @@ py::list pyList1DFromRatCellWrap(const coder::array<RAT::cell_wrap_10, 1U>& valu
return result;
}

py::list pyList2dFromRatCellWrap(const coder::array<RAT::cell_wrap_10, 2U>& values)
template <typename T>
py::list pyList2dFromRatCellWrap(const T& values)
{
py::list result;
int32_T idx {0};
Expand Down Expand Up @@ -1129,10 +1156,10 @@ BayesResults bayesResultsFromStruct8T(const RAT::struct8_T results)

bayesResults.chain = pyArrayFromRatArray2d(results.chain);

bayesResults.predictionIntervals.reflectivity = pyList1DFromRatCellWrap(results.predictionIntervals.reflectivity);
bayesResults.predictionIntervals.sld = pyList2dFromRatCellWrap(results.predictionIntervals.sld);
bayesResults.predictionIntervals.reflectivityXData = pyList1DFromRatCellWrap(results.predictionIntervals.reflectivityXData);
bayesResults.predictionIntervals.sldXData = pyList2dFromRatCellWrap(results.predictionIntervals.sldXData);
bayesResults.predictionIntervals.reflectivity = pyList1DFromRatCellWrap<coder::array<RAT::cell_wrap_11, 1U>>(results.predictionIntervals.reflectivity);
bayesResults.predictionIntervals.sld = pyList2dFromRatCellWrap<coder::array<RAT::cell_wrap_11, 2U>>(results.predictionIntervals.sld);
bayesResults.predictionIntervals.reflectivityXData = pyList1DFromRatCellWrap<coder::array<RAT::cell_wrap_12, 1U>>(results.predictionIntervals.reflectivityXData);
bayesResults.predictionIntervals.sldXData = pyList2dFromRatCellWrap<coder::array<RAT::cell_wrap_12, 2U>>(results.predictionIntervals.sldXData);
bayesResults.predictionIntervals.sampleChi = pyArray1dFromBoundedArray<coder::bounded_array<real_T, 1000U, 1U>>(results.predictionIntervals.sampleChi);

bayesResults.confidenceIntervals.percentile95 = pyArrayFromRatArray2d(results.confidenceIntervals.percentile95);
Expand Down Expand Up @@ -1445,7 +1472,8 @@ PYBIND11_MODULE(rat_core, m) {
.def_readwrite("pUnitGamma", &Control::pUnitGamma)
.def_readwrite("boundHandling", &Control::boundHandling)
.def_readwrite("adaptPCR", &Control::adaptPCR)
.def_readwrite("checks", &Control::checks);
.def_readwrite("checks", &Control::checks)
.def_readwrite("IPCFilePath", &Control::IPCFilePath);

py::class_<ProblemDefinition>(m, "ProblemDefinition")
.def(py::init<>())
Expand Down
Loading

0 comments on commit 06ecccb

Please sign in to comment.