Skip to content

Commit

Permalink
Refactors plot_ref_sld to plot_ref_sld_helper, ensures plots are exte…
Browse files Browse the repository at this point in the history
…nded similar to matlab
  • Loading branch information
StephenNneji committed May 15, 2024
1 parent a708a2f commit 378ff0b
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 32 deletions.
87 changes: 60 additions & 27 deletions RAT/utils/plotting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Plots using the matplotlib library
"""
from typing import Optional
import matplotlib.pyplot as plt
import numpy as np
from RAT.rat_core import PlotEventData, makeSLDProfileXY
Expand All @@ -17,9 +18,9 @@ def __init__(self, row: int = 1, col: int = 2):
Parameters
----------
row : int
row : int, default: 1
The number of rows in subplot
col : int
col : int, default: 2
The number of columns in subplot
"""
self._fig, self._ax = \
Expand Down Expand Up @@ -55,7 +56,8 @@ def _close(self, _):
self._close_clicked = True


def plot_errorbars(ax, x, y, err, onesided, color):
def plot_errorbars(ax: 'matplotlib.axes._axes.Axes', x: np.ndarray, y: np.ndarray, err: np.ndarray,
one_sided: bool, color: str):
"""
Plots the error bars.
Expand All @@ -69,12 +71,12 @@ def plot_errorbars(ax, x, y, err, onesided, color):
The shifted data y axis data
err : np.ndarray
The shifted data e data
onesided : bool
one_sided : bool
A boolean to indicate whether to draw one sided errorbars
color : str
The hex representing the color of the errorbars
"""
y_error = [[0]*len(err), err] if onesided else err
y_error = [[0]*len(err), err] if one_sided else err
ax.errorbar(x=x,
y=y,
yerr=y_error,
Expand All @@ -85,7 +87,7 @@ def plot_errorbars(ax, x, y, err, onesided, color):
ax.scatter(x=x, y=y, s=3, marker="o", color=color)


def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True):
def plot_ref_sld_helper(data: PlotEventData, fig: Optional[Figure] = None, delay: bool = True):
"""
Clears the previous plots and updates the ref and SLD plots.
Expand All @@ -94,9 +96,9 @@ def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True):
data : PlotEventData
The plot event data that contains all the information
to generate the ref and sld plots
fig : Figure
fig : Figure, optional
The figure class that has two subplots
delay : bool
delay : bool, default: True
Controls whether to delay 0.005s after plot is created
Returns
Expand All @@ -121,9 +123,6 @@ def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True):
data.shiftedData,
data.sldProfiles,
data.resampledLayers)):

sld, layer = map(lambda x: x[0], (sld, layer))

# Calculate the divisor
div = 1 if i == 0 else 2**(4*(i+1))

Expand Down Expand Up @@ -154,25 +153,29 @@ def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True):
plot_errorbars(ref_plot, sd_x_s, sd_y_s, sd_e_s, True, color)

# Plot the slds on plot (1,2)
for j in range(1, sld.shape[1]):
sld_plot.plot(sld[:, 0],
sld[:, j],
for j in range(len(sld)):
sld_plot.plot(sld[j][:, 0],
sld[j][:, 1],
label=f'sld {i+1}',
color=color,
linewidth=2)
linewidth=1)

if data.resample[i] == 1 or data.modelType == 'custom xy':
new = makeSLDProfileXY(layer[0, 1],
layer[-1, 1],
data.subRoughs[i],
layer,
len(layer),
1.0)

sld_plot.plot([row[0]-49 for row in new],
[row[1] for row in new],
color=color,
linewidth=1)
layers = data.resampledLayers[i][0]
for j in range(len(data.resampledLayers[i])):
layer = data.resampledLayers[i][j]
if layers.shape[1] == 4:
layer = np.delete(layer, 2, 1)
new_profile = makeSLDProfileXY(layers[0, 1], # Bulk In
layers[-1, 1], # Bulk Out
data.subRoughs[i], # roughness
layer,
len(layer),
1.0)

sld_plot.plot([row[0]-49 for row in new_profile],
[row[1] for row in new_profile],
color=color,
linewidth=1)

# Format the axis
ref_plot.set_yscale('log')
Expand All @@ -191,3 +194,33 @@ def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True):
plt.pause(0.005)

return fig


def plot_ref_sld(problem, results, block: bool = False):
"""
Plots the reflectivity and SLD profiles.
Parameters
----------
problem : Project
An instance of the Project class
results : Result
The result from the calculation
block : bool, default: False
Indicates the plot should block until it is closed
"""
data = PlotEventData()

data.reflectivity = results.reflectivity
data.shiftedData = results.shiftedData
data.sldProfiles = results.sldProfiles
data.resampledLayers = results.resampledLayers
data.dataPresent = problem.dataPresent
data.subRoughs = results.contrastParams.subRoughs
data.resample = problem.resample

figure = Figure()

plot_ref_sld_helper(data, figure)
if block:
figure.wait_for_close()
6 changes: 3 additions & 3 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from unittest.mock import MagicMock
import matplotlib.pyplot as plt
from RAT.rat_core import PlotEventData
from RAT.utils.plotting import Figure, plot_ref_sld
from RAT.utils.plotting import Figure, plot_ref_sld_helper


TEST_DIR_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)),
Expand Down Expand Up @@ -40,7 +40,7 @@ def fig() -> Figure:
"""
plt.close('all')
figure = Figure(1, 3)
fig = plot_ref_sld(fig=figure, data=data())
fig = plot_ref_sld_helper(fig=figure, data=data())
return fig


Expand Down Expand Up @@ -155,7 +155,7 @@ def test_sld_profile_function_call(mock: MagicMock) -> None:
Tests the makeSLDProfileXY function called with
correct args.
"""
plot_ref_sld(data())
plot_ref_sld_helper(data())

assert mock.call_count == 3
assert mock.call_args_list[0].args[0] == 2.07e-06
Expand Down
5 changes: 3 additions & 2 deletions tests/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@


def test_matlab_wrapper() -> None:
with pytest.raises(ImportError):
RAT.wrappers.MatlabWrapper('demo.m')
with mock.patch.dict('sys.modules', {'matlab': mock.MagicMock(side_effect=ImportError)}):
with pytest.raises(ImportError):
RAT.wrappers.MatlabWrapper('demo.m')
mocked_matlab_module = mock.MagicMock()
mocked_engine = mock.MagicMock()
mocked_matlab_module.engine.start_matlab.return_value = mocked_engine
Expand Down

0 comments on commit 378ff0b

Please sign in to comment.