From d4fb0bb96332a1cfb7cca13ba2e6bc94636ad3f0 Mon Sep 17 00:00:00 2001 From: StephenNneji <34302892+StephenNneji@users.noreply.github.com> Date: Fri, 2 Aug 2024 15:35:00 +0100 Subject: [PATCH] Added plot legend and updates tests (#59) --- .github/workflows/run_ruff.yml | 6 +----- .github/workflows/run_tests.yml | 5 ++++- RATapi/inputs.py | 1 + RATapi/utils/plotting.py | 10 +++++---- cpp/RAT | 2 +- cpp/rat.cpp | 18 ++++++++++++++--- tests/test_inputs.py | 3 +++ tests/test_plotting.py | 36 +++++++++++++++++++++++++++------ 8 files changed, 61 insertions(+), 20 deletions(-) diff --git a/.github/workflows/run_ruff.yml b/.github/workflows/run_ruff.yml index 41b9e1ac..5fc5b79f 100644 --- a/.github/workflows/run_ruff.yml +++ b/.github/workflows/run_ruff.yml @@ -1,10 +1,6 @@ name: Ruff -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] +on: workflow_call jobs: ruff: diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index bae09233..191ee87b 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -11,8 +11,11 @@ concurrency: cancel-in-progress: true jobs: - test: + run_ruff: + uses: ./.github/workflows/run_ruff.yml + test: + needs: [run_ruff] strategy: fail-fast: false matrix: diff --git a/RATapi/inputs.py b/RATapi/inputs.py index 2e92c565..07dc7087 100644 --- a/RATapi/inputs.py +++ b/RATapi/inputs.py @@ -387,6 +387,7 @@ def make_cells(project: RATapi.Project) -> Cells: ] cells.f20 = [param.name for param in project.domain_ratios] + cells.f21 = [contrast.name for contrast in project.contrasts] return cells diff --git a/RATapi/utils/plotting.py b/RATapi/utils/plotting.py index 84adf1b0..6f313142 100644 --- a/RATapi/utils/plotting.py +++ b/RATapi/utils/plotting.py @@ -74,14 +74,14 @@ def plot_ref_sld_helper(data: PlotEventData, fig: Optional[plt.figure] = None, d ref_plot.cla() sld_plot.cla() - for i, (r, sd, sld, layer) in enumerate( - zip(data.reflectivity, data.shiftedData, data.sldProfiles, data.resampledLayers), + for i, (r, sd, sld, name) in enumerate( + zip(data.reflectivity, data.shiftedData, data.sldProfiles, data.contrastNames), ): # Calculate the divisor div = 1 if i == 0 else 2 ** (4 * (i + 1)) # Plot the reflectivity on plot (1,1) - ref_plot.plot(r[:, 0], r[:, 1] / div, label=f"ref {i+1}", linewidth=2) + ref_plot.plot(r[:, 0], r[:, 1] / div, label=name, linewidth=2) color = ref_plot.get_lines()[-1].get_color() if data.dataPresent[i]: @@ -100,7 +100,8 @@ def plot_ref_sld_helper(data: PlotEventData, fig: Optional[plt.figure] = None, d # Plot the slds on plot (1,2) for j in range(len(sld)): - sld_plot.plot(sld[j][:, 0], sld[j][:, 1], label=f"sld {i+1}", linewidth=1) + label = name if len(sld) == 1 else f"{name} Domain {j+1}" + sld_plot.plot(sld[j][:, 0], sld[j][:, 1], label=label, linewidth=1) if data.resample[i] == 1 or data.modelType == "custom xy": layers = data.resampledLayers[i][0] @@ -172,6 +173,7 @@ def plot_ref_sld( data.dataPresent = RATapi.inputs.make_data_present(project) data.subRoughs = results.contrastParams.subRoughs data.resample = RATapi.inputs.make_resample(project) + data.contrastNames = [contrast.name for contrast in project.contrasts] figure = plt.subplots(1, 2)[0] diff --git a/cpp/RAT b/cpp/RAT index 3f2462a8..5e3622f7 160000 --- a/cpp/RAT +++ b/cpp/RAT @@ -1 +1 @@ -Subproject commit 3f2462a80471a2daa896553cb3ec673ed0777fee +Subproject commit 5e3622f7d2d8cb48cab153d6660fce650b868d38 diff --git a/cpp/rat.cpp b/cpp/rat.cpp index 9a135861..0a599601 100644 --- a/cpp/rat.cpp +++ b/cpp/rat.cpp @@ -171,6 +171,7 @@ struct PlotEventData py::array_t resample; py::array_t dataPresent; std::string modelType; + py::list contrastNames; }; class EventBridge @@ -271,6 +272,13 @@ class EventBridge eventData.resampledLayers = unpackDataToCell(pEvent->data->nContrast, (pEvent->data->nLayers2 == NULL) ? 1 : 2, pEvent->data->layers, pEvent->data->nLayers, pEvent->data->layers2, pEvent->data->nLayers2, true); + + int offset = 0; + for (int i = 0; i < pEvent->data->nContrast; i++){ + eventData.contrastNames.append(std::string(pEvent->data->contrastNames + offset, + pEvent->data->nContrastNames[i])); + offset += pEvent->data->nContrastNames[i]; + } this->callback(event.type, eventData); } }; @@ -445,6 +453,7 @@ struct Cells { py::list f18; py::list f19; py::list f20; + py::list f21; }; struct ProblemDefinition { @@ -835,7 +844,7 @@ RAT::cell_7 createCell7(const Cells& cells) cells_struct.f2 = customCaller("Cells.f2", pyListToRatCellWrap3, cells.f2); cells_struct.f3 = customCaller("Cells.f3", pyListToRatCellWrap2, cells.f3); cells_struct.f4 = customCaller("Cells.f4", pyListToRatCellWrap2, cells.f4); - cells_struct.f5 = customCaller("Cells.f5", pyListToRatCellWrap4, cells.f5); //improve this error + cells_struct.f5 = customCaller("Cells.f5", pyListToRatCellWrap4, cells.f5); cells_struct.f6 = customCaller("Cells.f6", pyListToRatCellWrap5, cells.f6); cells_struct.f7 = customCaller("Cells.f7", pyListToRatCellWrap6, cells.f7); cells_struct.f8 = customCaller("Cells.f8", pyListToRatCellWrap6, cells.f8); @@ -851,6 +860,7 @@ RAT::cell_7 createCell7(const Cells& cells) cells_struct.f18 = customCaller("Cells.f18", pyListToRatCellWrap2, cells.f18); cells_struct.f19 = customCaller("Cells.f19", pyListToRatCellWrap4, cells.f19); cells_struct.f20 = customCaller("Cells.f20", pyListToRatCellWrap6, cells.f20); + cells_struct.f21 = customCaller("Cells.f21", pyListToRatCellWrap6, cells.f21); return cells_struct; } @@ -1257,7 +1267,8 @@ PYBIND11_MODULE(rat_core, m) { .def_readwrite("subRoughs", &PlotEventData::subRoughs) .def_readwrite("resample", &PlotEventData::resample) .def_readwrite("dataPresent", &PlotEventData::dataPresent) - .def_readwrite("modelType", &PlotEventData::modelType); + .def_readwrite("modelType", &PlotEventData::modelType) + .def_readwrite("contrastNames", &PlotEventData::contrastNames); py::class_(m, "ProgressEventData") .def(py::init<>()) @@ -1402,7 +1413,8 @@ PYBIND11_MODULE(rat_core, m) { .def_readwrite("f17", &Cells::f17) .def_readwrite("f18", &Cells::f18) .def_readwrite("f19", &Cells::f19) - .def_readwrite("f20", &Cells::f20); + .def_readwrite("f20", &Cells::f20) + .def_readwrite("f21", &Cells::f21); py::class_(m, "Control") .def(py::init<>()) diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 04a609f1..a3ae2b7e 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -281,6 +281,7 @@ def standard_layers_cells(): cells.f18 = [] cells.f19 = [] cells.f20 = [] + cells.f21 = ["D2O"] return cells @@ -309,6 +310,7 @@ def domains_cells(): cells.f18 = [[0, 1], [0, 1]] cells.f19 = [[1], [1]] cells.f20 = ["Domain Ratio 1"] + cells.f21 = ["D2O"] return cells @@ -337,6 +339,7 @@ def custom_xy_cells(): cells.f18 = [] cells.f19 = [] cells.f20 = [] + cells.f21 = ["D2O"] return cells diff --git a/tests/test_plotting.py b/tests/test_plotting.py index c21b634f..c2d69406 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -13,7 +13,7 @@ def data() -> PlotEventData: - """Creates the fixture for the tests.""" + """Creates the data for the tests.""" data_path = os.path.join(TEST_DIR_PATH, "plotting_data.pickle") with open(data_path, "rb") as f: loaded_data = pickle.load(f) @@ -27,17 +27,29 @@ def data() -> PlotEventData: data.reflectivity = loaded_data["reflectivity"] data.shiftedData = loaded_data["shiftedData"] data.sldProfiles = loaded_data["sldProfiles"] + data.contrastNames = ["D2O", "SMW", "H2O"] + return data -@pytest.fixture -def fig() -> plt.figure: +def domains_data() -> PlotEventData: + """Creates the fake domains data for the tests.""" + domains_data = data() + for sld_list in domains_data.sldProfiles: + sld_list.append(sld_list[0]) + + return domains_data + + +@pytest.fixture(params=[False]) +def fig(request) -> plt.figure: """Creates the fixture for the tests.""" plt.close("all") figure = plt.subplots(1, 3)[0] - return plot_ref_sld_helper(fig=figure, data=data()) + return plot_ref_sld_helper(fig=figure, data=domains_data() if request.param else data()) +@pytest.mark.parametrize("fig", [False, True], indirect=True) def test_figure_axis_formating(fig: plt.figure) -> None: """Tests the axis formating of the figure.""" ref_plot = fig.axes[0] @@ -50,13 +62,24 @@ def test_figure_axis_formating(fig: plt.figure) -> None: assert ref_plot.get_xscale() == "log" assert ref_plot.get_ylabel() == "Reflectivity" assert ref_plot.get_yscale() == "log" - assert [label._text for label in ref_plot.get_legend().texts] == ["ref 1", "ref 2", "ref 3"] + assert [label._text for label in ref_plot.get_legend().texts] == ["D2O", "SMW", "H2O"] assert sld_plot.get_xlabel() == "$Z (\u00c5)$" assert sld_plot.get_xscale() == "linear" assert sld_plot.get_ylabel() == "$SLD (\u00c5^{-2})$" assert sld_plot.get_yscale() == "linear" - assert [label._text for label in sld_plot.get_legend().texts] == ["sld 1", "sld 2", "sld 3"] + labels = [label._text for label in sld_plot.get_legend().texts] + if len(labels) == 3: + assert labels == ["D2O", "SMW", "H2O"] + else: + assert labels == [ + "D2O Domain 1", + "D2O Domain 2", + "SMW Domain 1", + "SMW Domain 2", + "H2O Domain 1", + "H2O Domain 2", + ] def test_figure_color_formating(fig: plt.figure) -> None: @@ -157,3 +180,4 @@ def test_plot_ref_sld(mock: MagicMock, input_project, reflectivity_calculation_r assert data.dataPresent.size == 0 assert (data.subRoughs == reflectivity_calculation_results.contrastParams.subRoughs).all() assert data.resample.size == 0 + assert len(data.contrastNames) == 0