Skip to content

Commit

Permalink
Added plot legend and updates tests (RascalSoftware#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenNneji authored Aug 2, 2024
1 parent 70bcb3f commit d4fb0bb
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 20 deletions.
6 changes: 1 addition & 5 deletions .github/workflows/run_ruff.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
name: Ruff

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
on: workflow_call

jobs:
ruff:
Expand Down
5 changes: 4 additions & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ concurrency:
cancel-in-progress: true

jobs:
test:
run_ruff:
uses: ./.github/workflows/run_ruff.yml

test:
needs: [run_ruff]
strategy:
fail-fast: false
matrix:
Expand Down
1 change: 1 addition & 0 deletions RATapi/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def make_cells(project: RATapi.Project) -> Cells:
]

cells.f20 = [param.name for param in project.domain_ratios]
cells.f21 = [contrast.name for contrast in project.contrasts]

return cells

Expand Down
10 changes: 6 additions & 4 deletions RATapi/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ def plot_ref_sld_helper(data: PlotEventData, fig: Optional[plt.figure] = None, d
ref_plot.cla()
sld_plot.cla()

for i, (r, sd, sld, layer) in enumerate(
zip(data.reflectivity, data.shiftedData, data.sldProfiles, data.resampledLayers),
for i, (r, sd, sld, name) in enumerate(
zip(data.reflectivity, data.shiftedData, data.sldProfiles, data.contrastNames),
):
# Calculate the divisor
div = 1 if i == 0 else 2 ** (4 * (i + 1))

# Plot the reflectivity on plot (1,1)
ref_plot.plot(r[:, 0], r[:, 1] / div, label=f"ref {i+1}", linewidth=2)
ref_plot.plot(r[:, 0], r[:, 1] / div, label=name, linewidth=2)
color = ref_plot.get_lines()[-1].get_color()

if data.dataPresent[i]:
Expand All @@ -100,7 +100,8 @@ def plot_ref_sld_helper(data: PlotEventData, fig: Optional[plt.figure] = None, d

# Plot the slds on plot (1,2)
for j in range(len(sld)):
sld_plot.plot(sld[j][:, 0], sld[j][:, 1], label=f"sld {i+1}", linewidth=1)
label = name if len(sld) == 1 else f"{name} Domain {j+1}"
sld_plot.plot(sld[j][:, 0], sld[j][:, 1], label=label, linewidth=1)

if data.resample[i] == 1 or data.modelType == "custom xy":
layers = data.resampledLayers[i][0]
Expand Down Expand Up @@ -172,6 +173,7 @@ def plot_ref_sld(
data.dataPresent = RATapi.inputs.make_data_present(project)
data.subRoughs = results.contrastParams.subRoughs
data.resample = RATapi.inputs.make_resample(project)
data.contrastNames = [contrast.name for contrast in project.contrasts]

figure = plt.subplots(1, 2)[0]

Expand Down
18 changes: 15 additions & 3 deletions cpp/rat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ struct PlotEventData
py::array_t<double> resample;
py::array_t<double> dataPresent;
std::string modelType;
py::list contrastNames;
};

class EventBridge
Expand Down Expand Up @@ -271,6 +272,13 @@ class EventBridge
eventData.resampledLayers = unpackDataToCell(pEvent->data->nContrast, (pEvent->data->nLayers2 == NULL) ? 1 : 2,
pEvent->data->layers, pEvent->data->nLayers,
pEvent->data->layers2, pEvent->data->nLayers2, true);

int offset = 0;
for (int i = 0; i < pEvent->data->nContrast; i++){
eventData.contrastNames.append(std::string(pEvent->data->contrastNames + offset,
pEvent->data->nContrastNames[i]));
offset += pEvent->data->nContrastNames[i];
}
this->callback(event.type, eventData);
}
};
Expand Down Expand Up @@ -445,6 +453,7 @@ struct Cells {
py::list f18;
py::list f19;
py::list f20;
py::list f21;
};

struct ProblemDefinition {
Expand Down Expand Up @@ -835,7 +844,7 @@ RAT::cell_7 createCell7(const Cells& cells)
cells_struct.f2 = customCaller("Cells.f2", pyListToRatCellWrap3, cells.f2);
cells_struct.f3 = customCaller("Cells.f3", pyListToRatCellWrap2, cells.f3);
cells_struct.f4 = customCaller("Cells.f4", pyListToRatCellWrap2, cells.f4);
cells_struct.f5 = customCaller("Cells.f5", pyListToRatCellWrap4, cells.f5); //improve this error
cells_struct.f5 = customCaller("Cells.f5", pyListToRatCellWrap4, cells.f5);
cells_struct.f6 = customCaller("Cells.f6", pyListToRatCellWrap5, cells.f6);
cells_struct.f7 = customCaller("Cells.f7", pyListToRatCellWrap6, cells.f7);
cells_struct.f8 = customCaller("Cells.f8", pyListToRatCellWrap6, cells.f8);
Expand All @@ -851,6 +860,7 @@ RAT::cell_7 createCell7(const Cells& cells)
cells_struct.f18 = customCaller("Cells.f18", pyListToRatCellWrap2, cells.f18);
cells_struct.f19 = customCaller("Cells.f19", pyListToRatCellWrap4, cells.f19);
cells_struct.f20 = customCaller("Cells.f20", pyListToRatCellWrap6, cells.f20);
cells_struct.f21 = customCaller("Cells.f21", pyListToRatCellWrap6, cells.f21);

return cells_struct;
}
Expand Down Expand Up @@ -1257,7 +1267,8 @@ PYBIND11_MODULE(rat_core, m) {
.def_readwrite("subRoughs", &PlotEventData::subRoughs)
.def_readwrite("resample", &PlotEventData::resample)
.def_readwrite("dataPresent", &PlotEventData::dataPresent)
.def_readwrite("modelType", &PlotEventData::modelType);
.def_readwrite("modelType", &PlotEventData::modelType)
.def_readwrite("contrastNames", &PlotEventData::contrastNames);

py::class_<ProgressEventData>(m, "ProgressEventData")
.def(py::init<>())
Expand Down Expand Up @@ -1402,7 +1413,8 @@ PYBIND11_MODULE(rat_core, m) {
.def_readwrite("f17", &Cells::f17)
.def_readwrite("f18", &Cells::f18)
.def_readwrite("f19", &Cells::f19)
.def_readwrite("f20", &Cells::f20);
.def_readwrite("f20", &Cells::f20)
.def_readwrite("f21", &Cells::f21);

py::class_<Control>(m, "Control")
.def(py::init<>())
Expand Down
3 changes: 3 additions & 0 deletions tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def standard_layers_cells():
cells.f18 = []
cells.f19 = []
cells.f20 = []
cells.f21 = ["D2O"]

return cells

Expand Down Expand Up @@ -309,6 +310,7 @@ def domains_cells():
cells.f18 = [[0, 1], [0, 1]]
cells.f19 = [[1], [1]]
cells.f20 = ["Domain Ratio 1"]
cells.f21 = ["D2O"]

return cells

Expand Down Expand Up @@ -337,6 +339,7 @@ def custom_xy_cells():
cells.f18 = []
cells.f19 = []
cells.f20 = []
cells.f21 = ["D2O"]

return cells

Expand Down
36 changes: 30 additions & 6 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


def data() -> PlotEventData:
"""Creates the fixture for the tests."""
"""Creates the data for the tests."""
data_path = os.path.join(TEST_DIR_PATH, "plotting_data.pickle")
with open(data_path, "rb") as f:
loaded_data = pickle.load(f)
Expand All @@ -27,17 +27,29 @@ def data() -> PlotEventData:
data.reflectivity = loaded_data["reflectivity"]
data.shiftedData = loaded_data["shiftedData"]
data.sldProfiles = loaded_data["sldProfiles"]
data.contrastNames = ["D2O", "SMW", "H2O"]

return data


@pytest.fixture
def fig() -> plt.figure:
def domains_data() -> PlotEventData:
"""Creates the fake domains data for the tests."""
domains_data = data()
for sld_list in domains_data.sldProfiles:
sld_list.append(sld_list[0])

return domains_data


@pytest.fixture(params=[False])
def fig(request) -> plt.figure:
"""Creates the fixture for the tests."""
plt.close("all")
figure = plt.subplots(1, 3)[0]
return plot_ref_sld_helper(fig=figure, data=data())
return plot_ref_sld_helper(fig=figure, data=domains_data() if request.param else data())


@pytest.mark.parametrize("fig", [False, True], indirect=True)
def test_figure_axis_formating(fig: plt.figure) -> None:
"""Tests the axis formating of the figure."""
ref_plot = fig.axes[0]
Expand All @@ -50,13 +62,24 @@ def test_figure_axis_formating(fig: plt.figure) -> None:
assert ref_plot.get_xscale() == "log"
assert ref_plot.get_ylabel() == "Reflectivity"
assert ref_plot.get_yscale() == "log"
assert [label._text for label in ref_plot.get_legend().texts] == ["ref 1", "ref 2", "ref 3"]
assert [label._text for label in ref_plot.get_legend().texts] == ["D2O", "SMW", "H2O"]

assert sld_plot.get_xlabel() == "$Z (\u00c5)$"
assert sld_plot.get_xscale() == "linear"
assert sld_plot.get_ylabel() == "$SLD (\u00c5^{-2})$"
assert sld_plot.get_yscale() == "linear"
assert [label._text for label in sld_plot.get_legend().texts] == ["sld 1", "sld 2", "sld 3"]
labels = [label._text for label in sld_plot.get_legend().texts]
if len(labels) == 3:
assert labels == ["D2O", "SMW", "H2O"]
else:
assert labels == [
"D2O Domain 1",
"D2O Domain 2",
"SMW Domain 1",
"SMW Domain 2",
"H2O Domain 1",
"H2O Domain 2",
]


def test_figure_color_formating(fig: plt.figure) -> None:
Expand Down Expand Up @@ -157,3 +180,4 @@ def test_plot_ref_sld(mock: MagicMock, input_project, reflectivity_calculation_r
assert data.dataPresent.size == 0
assert (data.subRoughs == reflectivity_calculation_results.contrastParams.subRoughs).all()
assert data.resample.size == 0
assert len(data.contrastNames) == 0

0 comments on commit d4fb0bb

Please sign in to comment.