From 072f954cf5eaea47c664879ff3487e22bafac0e4 Mon Sep 17 00:00:00 2001 From: StephenNneji <34302892+StephenNneji@users.noreply.github.com> Date: Fri, 4 Oct 2024 11:59:06 +0100 Subject: [PATCH] Modifies cells to make function handles pickleable (#82) --- RATapi/inputs.py | 58 ++++++++++++++++++++++++++++++++++-------- cpp/rat.cpp | 11 ++++---- tests/test_inputs.py | 60 ++++++++++++++++---------------------------- 3 files changed, 74 insertions(+), 55 deletions(-) diff --git a/RATapi/inputs.py b/RATapi/inputs.py index f5e2490e..90bb5895 100644 --- a/RATapi/inputs.py +++ b/RATapi/inputs.py @@ -279,6 +279,52 @@ def check_indices(problem: ProblemDefinition) -> None: ) +class FileHandles: + """Class to defer creation of custom file handles. + + Parameters + ---------- + files : ClassList[CustomFile] + A list of custom file models. + """ + + def __init__(self, files): + self.index = 0 + self.files = [*files] + + def __iter__(self): + self.index = 0 + return self + + def get_handle(self, index): + """Returns file handle for a given custom file. + + Parameters + ---------- + index : int + The index of the custom file. + + """ + custom_file = self.files[index] + full_path = os.path.join(custom_file.path, custom_file.filename) + if custom_file.language == Languages.Python: + file_handle = get_python_handle(custom_file.filename, custom_file.function_name, custom_file.path) + elif custom_file.language == Languages.Matlab: + file_handle = RATapi.wrappers.MatlabWrapper(full_path).getHandle() + elif custom_file.language == Languages.Cpp: + file_handle = RATapi.wrappers.DylibWrapper(full_path, custom_file.function_name).getHandle() + + return file_handle + + def __next__(self): + if self.index < len(self.files): + custom_file = self.get_handle(self.index) + self.index += 1 + return custom_file + else: + raise StopIteration + + def make_cells(project: RATapi.Project) -> Cells: """Constructs the cells input required for the compiled RAT code. @@ -344,16 +390,6 @@ def make_cells(project: RATapi.Project) -> Cells: else: simulation_limits.append([0.0, 0.0]) - file_handles = [] - for custom_file in project.custom_files: - full_path = os.path.join(custom_file.path, custom_file.filename) - if custom_file.language == Languages.Python: - file_handles.append(get_python_handle(custom_file.filename, custom_file.function_name, custom_file.path)) - elif custom_file.language == Languages.Matlab: - file_handles.append(RATapi.wrappers.MatlabWrapper(full_path).getHandle()) - elif custom_file.language == Languages.Cpp: - file_handles.append(RATapi.wrappers.DylibWrapper(full_path, custom_file.function_name).getHandle()) - # Populate the set of cells cells = Cells() cells.f1 = [[0, 1]] * len(project.contrasts) # This is marked as "to do" in RAT @@ -369,7 +405,7 @@ def make_cells(project: RATapi.Project) -> Cells: cells.f11 = [param.name for param in project.bulk_in] cells.f12 = [param.name for param in project.bulk_out] cells.f13 = [param.name for param in project.resolution_parameters] - cells.f14 = file_handles + cells.f14 = FileHandles(project.custom_files) cells.f15 = [param.type for param in project.backgrounds] cells.f16 = [param.type for param in project.resolutions] diff --git a/cpp/rat.cpp b/cpp/rat.cpp index 59f2bab2..0a6d129f 100644 --- a/cpp/rat.cpp +++ b/cpp/rat.cpp @@ -446,7 +446,7 @@ struct Cells { py::list f11; py::list f12; py::list f13; - py::list f14; + py::object f14; py::list f15; py::list f16; py::list f17; @@ -844,12 +844,13 @@ coder::array pyListToRatCellWrap6(py::list values) return result; } -coder::array py_function_array_to_rat_cell_wrap_6(py::list values) +coder::array py_function_array_to_rat_cell_wrap_6(py::object values) { + auto handles = py::cast(values); coder::array result; - result.set_size(1, values.size()); + result.set_size(1, handles.size()); int32_T idx {0}; - for (py::handle array: values) + for (py::handle array: handles) { auto func = py::cast(array); std::string func_ptr = convertPtr2String(new Library(func)); @@ -1585,7 +1586,7 @@ PYBIND11_MODULE(rat_core, m) { cell.f11 = t[10].cast(); cell.f12 = t[11].cast(); cell.f13 = t[12].cast(); - cell.f14 = t[13].cast(); + cell.f14 = t[13].cast(); cell.f15 = t[14].cast(); cell.f16 = t[15].cast(); cell.f17 = t[16].cast(); diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 43870487..5c831395 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -624,25 +624,7 @@ def test_make_input(test_project, test_problem, test_cells, test_limits, test_pr "domainRatio", ] - mocked_matlab_future = mock.MagicMock() - mocked_engine = mock.MagicMock() - mocked_matlab_future.result.return_value = mocked_engine - - with mock.patch.object( - RATapi.wrappers.MatlabWrapper, - "loader", - mocked_matlab_future, - ), mock.patch.object(RATapi.rat_core, "DylibEngine", mock.MagicMock()), mock.patch.object( - RATapi.inputs, - "get_python_handle", - mock.MagicMock(return_value=dummy_function), - ), mock.patch.object( - RATapi.wrappers.MatlabWrapper, - "getHandle", - mock.MagicMock(return_value=dummy_function), - ), mock.patch.object(RATapi.wrappers.DylibWrapper, "getHandle", mock.MagicMock(return_value=dummy_function)): - problem, cells, limits, priors, controls = make_input(test_project, RATapi.Controls()) - + problem, cells, limits, priors, controls = make_input(test_project, RATapi.Controls()) problem = pickle.loads(pickle.dumps(problem)) check_problem_equal(problem, test_problem) cells = pickle.loads(pickle.dumps(cells)) @@ -768,25 +750,7 @@ def test_make_cells(test_project, test_cells, request) -> None: """The cells object should be populated according to the input project object.""" test_project = request.getfixturevalue(test_project) test_cells = request.getfixturevalue(test_cells) - - mocked_matlab_future = mock.MagicMock() - mocked_engine = mock.MagicMock() - mocked_matlab_future.result.return_value = mocked_engine - with mock.patch.object( - RATapi.wrappers.MatlabWrapper, - "loader", - mocked_matlab_future, - ), mock.patch.object(RATapi.rat_core, "DylibEngine", mock.MagicMock()), mock.patch.object( - RATapi.inputs, - "get_python_handle", - mock.MagicMock(return_value=dummy_function), - ), mock.patch.object( - RATapi.wrappers.MatlabWrapper, - "getHandle", - mock.MagicMock(return_value=dummy_function), - ), mock.patch.object(RATapi.wrappers.DylibWrapper, "getHandle", mock.MagicMock(return_value=dummy_function)): - cells = make_cells(test_project) - + cells = make_cells(test_project) check_cells_equal(cells, test_cells) @@ -865,7 +829,25 @@ def check_cells_equal(actual_cells, expected_cells) -> None: "NaN" if np.isnan(el) else el for entry in actual_cells.f6 for el in entry ] == ["NaN" if np.isnan(el) else el for entry in expected_cells.f6 for el in entry] - for index in chain(range(3, 6), range(7, 21)): + mocked_matlab_future = mock.MagicMock() + mocked_engine = mock.MagicMock() + mocked_matlab_future.result.return_value = mocked_engine + with mock.patch.object( + RATapi.wrappers.MatlabWrapper, + "loader", + mocked_matlab_future, + ), mock.patch.object(RATapi.rat_core, "DylibEngine", mock.MagicMock()), mock.patch.object( + RATapi.inputs, + "get_python_handle", + mock.MagicMock(return_value=dummy_function), + ), mock.patch.object( + RATapi.wrappers.MatlabWrapper, + "getHandle", + mock.MagicMock(return_value=dummy_function), + ), mock.patch.object(RATapi.wrappers.DylibWrapper, "getHandle", mock.MagicMock(return_value=dummy_function)): + assert list(actual_cells.f14) == expected_cells.f14 + + for index in chain(range(3, 6), range(7, 14), range(15, 21)): field = f"f{index}" assert getattr(actual_cells, field) == getattr(expected_cells, field)