Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added plotting feature using matplotlib #24

Merged
merged 26 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 190 additions & 0 deletions RAT/utils/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
"""
Plots using the matplotlib library
"""
import matplotlib.pyplot as plt
import numpy as np
from RAT.rat_core import PlotEventData, makeSLDProfileXY


class Figure:
"""
Creates a plotting figure.
"""

def __init__(self, row: int = 1, col: int = 2):
"""
Initializes the figure and the subplots.

Parameters
----------
row : int
The number of rows in subplot
col : int
The number of columns in subplot
"""
self._fig, self._ax = \
plt.subplots(row, col, num="Reflectivity Algorithms Toolbox (RAT)")
plt.show(block=False)
self._esc_pressed = False
self._close_clicked = False
self._fig.canvas.mpl_connect("key_press_event",
self._process_button_press)
self._fig.canvas.mpl_connect('close_event',
self._close)

def wait_for_close(self):
"""
Waits for the user to close the figure
using the esc key.
"""
while not (self._esc_pressed or self._close_clicked):
plt.waitforbuttonpress(timeout=0.005)
plt.close(self._fig)

def _process_button_press(self, event):
"""
Process the key_press_event.
"""
if event.key == 'escape':
self._esc_pressed = True

def _close(self, _):
"""
Process the close_event.
"""
self._close_clicked = True


def plot_errorbars(ax, x, y, err, onesided, color):
"""
Plots the error bars.

Parameters
----------
ax : matplotlib.axes._axes.Axes
The axis on which to draw errorbars
x : np.ndarray
The shifted data x axis data
y : np.ndarray
The shifted data y axis data
err : np.ndarray
The shifted data e data
onesided : bool
A boolean to indicate whether to draw one sided errorbars
color : str
The hex representing the color of the errorbars
"""
y_error = [[0]*len(err), err] if onesided else err
ax.errorbar(x=x,
y=y,
yerr=y_error,
fmt='none',
ecolor=color,
elinewidth=1,
capsize=0)
ax.scatter(x=x, y=y, s=3, marker="o", color=color)


def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True):
"""
Clears the previous plots and updates the ref and SLD plots.

Parameters
----------
data : PlotEventData
The plot event data that contains all the information
to generate the ref and sld plots
fig : Figure
The figure class that has two subplots
delay : bool
Controls whether to delay 0.005s after plot is created

Returns
-------
fig : Figure
The figure class that has two subplots
"""
if fig is None:
fig = Figure()
RabiyaF marked this conversation as resolved.
Show resolved Hide resolved

ref_plot = fig._ax[0]
sld_plot = fig._ax[1]

# Clears the previous plots
ref_plot.cla()
sld_plot.cla()

for i, (r, sd, sld, layer) in enumerate(zip(data.reflectivity,
data.shiftedData,
data.sldProfiles,
data.resampledLayers)):

r, sd, sld, layer = map(lambda x: x[0], (r, sd, sld, layer))

# Calculate the divisor
div = 1 if i == 0 else 2**(4*(i+1))

# Plot the reflectivity on plot (1,1)
ref_plot.plot(r[:, 0],
r[:, 1]/div,
label=f'ref {i+1}',
linewidth=2)
color = ref_plot.get_lines()[-1].get_color()

if data.dataPresent[i]:
sd_x = sd[:, 0]
sd_y, sd_e = map(lambda x: x/div, (sd[:, 1], sd[:, 2]))

# Plot the errorbars
indices_removed = np.flip(np.nonzero(sd_y - sd_e < 0)[0])
sd_x_r, sd_y_r, sd_e_r = map(lambda x:
np.delete(x, indices_removed),
(sd_x, sd_y, sd_e))
plot_errorbars(ref_plot, sd_x_r, sd_y_r, sd_e_r, False, color)

# Plot one sided errorbars
indices_selected = [x for x in indices_removed
if x not in np.nonzero(sd_y < 0)[0]]
sd_x_s, sd_y_s, sd_e_s = map(lambda x:
[x[i] for i in indices_selected],
(sd_x, sd_y, sd_e))
plot_errorbars(ref_plot, sd_x_s, sd_y_s, sd_e_s, True, color)

# Plot the slds on plot (1,2)
for j in range(1, sld.shape[1]):
sld_plot.plot(sld[:, 0],
sld[:, j],
label=f'sld {i+1}',
color=color,
linewidth=2)

if data.resample[i] == 1 or data.modelType == 'custom xy':
new = makeSLDProfileXY(layer[0, 1],
layer[-1, 1],
data.subRoughs[i],
layer,
len(layer),
1.0)

sld_plot.plot([row[0]-49 for row in new],
[row[1] for row in new],
color=color,
linewidth=1)

# Format the axis
ref_plot.set_yscale('log')
ref_plot.set_xscale('log')
ref_plot.set_xlabel('Qz')
ref_plot.set_ylabel('Ref')
ref_plot.legend()
ref_plot.grid()

sld_plot.set_xlabel('Z')
sld_plot.set_ylabel('SLD')
sld_plot.legend()
sld_plot.grid()

if delay:
plt.pause(0.005)

return fig
26 changes: 25 additions & 1 deletion cpp/rat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ setup_pybind11(cfg)
#include "RAT/RATMain_initialize.h"
#include "RAT/RATMain_terminate.h"
#include "RAT/RATMain_types.h"
#include "RAT/makeSLDProfileXY.h"
#include "RAT/classHandle.hpp"
#include "RAT/dylib.hpp"
#include "RAT/events/eventManager.h"
Expand Down Expand Up @@ -1165,6 +1166,27 @@ py::tuple RATMain(const ProblemDefinition& problem_def, const Cells& cells, cons
bayesResultsFromStruct8T(bayesResults));
}

py::array_t<real_T> makeSLDProfileXY(real_T bulk_in,
real_T bulk_out,
real_T ssub,
const py::array_t<real_T> &layers,
real_T number_of_layers,
real_T repeats)
{
coder::array<real_T, 2U> out;
coder::array<real_T, 2U> layers_array = pyArrayToRatArray2d(layers);
RAT::makeSLDProfileXY(bulk_in,
bulk_out,
ssub,
layers_array,
number_of_layers,
repeats,
out);

return pyArrayFromRatArray2d(out);

}

class Module
{
public:
Expand Down Expand Up @@ -1434,5 +1456,7 @@ PYBIND11_MODULE(rat_core, m) {
.def_readwrite("fitLimits", &ProblemDefinition::fitLimits)
.def_readwrite("otherLimits", &ProblemDefinition::otherLimits);

m.def("RATMain", &RATMain, "Entry point for the main reflectivity computation.");
m.def("RATMain", &RATMain, "Entry point for the main reflectivity computation.");

m.def("makeSLDProfileXY", &makeSLDProfileXY, "Creates the profiles for the SLD plots");
}
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ pybind11 >= 2.4
pydantic >= 2.4.2, <= 2.6.4
pytest >= 7.4.0
pytest-cov >= 4.1.0
matplotlib >= 3.8.3
StrEnum >= 0.4.15; python_version < '3.11'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def build_libraries(self, libraries):
libraries = [libevent],
ext_modules = ext_modules,
python_requires = '>=3.9',
install_requires = ['numpy >= 1.20', 'prettytable >= 3.9.0', 'pydantic >= 2.4.2, <= 2.6.4'],
install_requires = ['numpy >= 1.20', 'prettytable >= 3.9.0', 'pydantic >= 2.4.2, <= 2.6.4', 'matplotlib >= 3.8.3'],
extras_require = {':python_version < "3.11"': ['StrEnum >= 0.4.15'],
'Dev': ['pytest>=7.4.0', 'pytest-cov>=4.1.0'],
'Matlab_latest': ['matlabengine'],
Expand Down
Binary file added tests/test_data/plotting_data.pickle
Binary file not shown.
Loading
Loading