Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds live plot context manager and update events #49

Merged
merged 2 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions RATapi/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,26 @@ def register(event_type: EventTypes, callback: Callable[[Union[str, PlotEventDat
__event_callbacks[event_type].add(callback)


def clear() -> None:
"""Clears all event callbacks."""
__event_impl.clear()
for key in __event_callbacks:
__event_callbacks[key] = set()
def clear(key=None, callback=None) -> None:
"""Clears all event callbacks or specific callback.

Parameters
----------
callback : Callable[[Union[str, PlotEventData, ProgressEventData]], None]
The callback for when the event is triggered.

"""
if key is None and callback is None:
for key in __event_callbacks:
__event_callbacks[key] = set()
elif key is not None and callback is not None:
__event_callbacks[key].remove(callback)

for value in __event_callbacks.values():
if value:
break
else:
__event_impl.clear()


dir_path = os.path.dirname(os.path.realpath(__file__))
Expand Down
125 changes: 69 additions & 56 deletions RATapi/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,45 +12,6 @@
from RATapi.rat_core import PlotEventData, makeSLDProfileXY


class Figure:
"""Creates a plotting figure."""

def __init__(self, row: int = 1, col: int = 1):
"""Initializes the figure and the subplots.

Parameters
----------
row : int, default: 1
The number of rows in subplot
col : int, default: 1
The number of columns in subplot

"""
self._fig, self._ax = plt.subplots(row, col, num="Reflectivity Algorithms Toolbox (RAT)")
plt.show(block=False)
self._esc_pressed = False
self._close_clicked = False
self._fig.canvas.mpl_connect("key_press_event", self._process_button_press)
self._fig.canvas.mpl_connect("close_event", self._close)

def wait_for_close(self):
"""Waits for the user to close the figure
using the esc key.
"""
while not (self._esc_pressed or self._close_clicked):
plt.waitforbuttonpress(timeout=0.005)
plt.close(self._fig)

def _process_button_press(self, event):
"""Process the key_press_event."""
if event.key == "escape":
self._esc_pressed = True

def _close(self, _):
"""Process the close_event."""
self._close_clicked = True


def plot_errorbars(ax: Axes, x: np.ndarray, y: np.ndarray, err: np.ndarray, one_sided: bool, color: str):
"""Plots the error bars.

Expand All @@ -75,33 +36,39 @@ def plot_errorbars(ax: Axes, x: np.ndarray, y: np.ndarray, err: np.ndarray, one_
ax.scatter(x=x, y=y, s=3, marker="o", color=color)


def plot_ref_sld_helper(data: PlotEventData, fig: Optional[Figure] = None, delay: bool = True):
def plot_ref_sld_helper(data: PlotEventData, fig: Optional[plt.figure] = None, delay: bool = True):
"""Clears the previous plots and updates the ref and SLD plots.

Parameters
----------
data : PlotEventData
The plot event data that contains all the information
to generate the ref and sld plots
fig : Figure, optional
fig : matplotlib.pyplot.figure, optional
The figure class that has two subplots
delay : bool, default: True
Controls whether to delay 0.005s after plot is created

Returns
-------
fig : Figure
fig : matplotlib.pyplot.figure
The figure class that has two subplots

"""
if fig is None:
fig = Figure(1, 2)
elif fig._ax.shape != (2,):
fig._fig.clf()
fig._ax = fig._fig.subplots(1, 2)
preserve_zoom = False

ref_plot = fig._ax[0]
sld_plot = fig._ax[1]
if fig is None:
fig = plt.subplots(1, 2)[0]
elif len(fig.axes) != 2:
fig.clf()
fig.subplots(1, 2)
fig.subplots_adjust(wspace=0.3)

ref_plot = fig.axes[0]
sld_plot = fig.axes[1]
if ref_plot.lines and fig.canvas.toolbar is not None:
preserve_zoom = True
fig.canvas.toolbar.push_current()

# Clears the previous plots
ref_plot.cla()
Expand Down Expand Up @@ -160,16 +127,18 @@ def plot_ref_sld_helper(data: PlotEventData, fig: Optional[Figure] = None, delay
# Format the axis
ref_plot.set_yscale("log")
ref_plot.set_xscale("log")
ref_plot.set_xlabel("Qz")
ref_plot.set_ylabel("Ref")
ref_plot.set_xlabel("$Q_{z} (\u00c5^{-1})$")
ref_plot.set_ylabel("Reflectivity")
ref_plot.legend()
ref_plot.grid()

sld_plot.set_xlabel("Z")
sld_plot.set_ylabel("SLD")
sld_plot.set_xlabel("$Z (\u00c5)$")
sld_plot.set_ylabel("$SLD (\u00c5^{-2})$")
sld_plot.legend()
sld_plot.grid()

if preserve_zoom:
fig.canvas.toolbar.back()
if delay:
plt.pause(0.005)

Expand Down Expand Up @@ -204,8 +173,52 @@ def plot_ref_sld(
data.subRoughs = results.contrastParams.subRoughs
data.resample = RATapi.inputs.make_resample(project)

figure = Figure(1, 2)
figure = plt.subplots(1, 2)[0]

plot_ref_sld_helper(data, figure)
if block:
figure.wait_for_close()

plt.show(block=block)


class LivePlot:
"""Creates a plot that gets updates from the plot event during a
calculation

Parameters
----------
block : bool, default: False
Indicates the plot should block until it is closed

"""

def __init__(self, block=False):
self.block = block
self.closed = False

def __enter__(self):
self.figure = plt.subplots(1, 2)[0]
self.figure.canvas.mpl_connect("close_event", self._setCloseState)
self.figure.show()
RATapi.events.register(RATapi.events.EventTypes.Plot, self.plotEvent)

return self.figure

def _setCloseState(self, _):
"""Close event handler"""
self.closed = True

def plotEvent(self, event):
"""Callback for the plot event.

Parameters
----------
event: PlotEventData
The plot event data.
"""
if not self.closed and self.figure.number in plt.get_fignums():
plot_ref_sld_helper(event, self.figure)

def __exit__(self, _exc_type, _exc_val, _traceback):
RATapi.events.clear(RATapi.events.EventTypes.Plot, self.plotEvent)
if not self.closed and self.figure.number in plt.get_fignums():
plt.show(block=self.block)
2 changes: 1 addition & 1 deletion cpp/RAT
Submodule RAT updated 1 files
+180 −112 triggerEvent.cpp
33 changes: 20 additions & 13 deletions cpp/rat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,27 +192,34 @@ class EventBridge
};

py::list unpackDataToCell(int rows, int cols, double* data, double* nData,
double* data2, double* nData2, int dataCol)
double* data2, double* nData2, bool isOutput2D=false)
{
py::list allResults;
int dims[2] = {0, dataCol};
int dims[2] = {0, 0};
int offset = 0;
for (int i = 0; i < rows; i++){
py::list rowList;
dims[0] = (int)nData[i] / dataCol;
dims[0] = (int)nData[2*i];
dims[1] = (int)nData[2*i+1];
auto result = py::array_t<double, py::array::f_style>({dims[0], dims[1]});
std::memcpy(result.request().ptr, data + offset, result.nbytes());
offset += result.size();
rowList.append(result);
allResults.append(rowList);
if (isOutput2D){
py::list rowList;
rowList.append(result);
allResults.append(rowList);
}
else{
allResults.append(result);
}
}

if (data2 != NULL && nData2 != NULL)
{
// This is used to unpack the domains data into the second column
offset = 0;
for ( int i = 0; i < rows; i++){
dims[0] = (int)nData2[i] / dataCol;
dims[0] = (int)nData2[2*i];
dims[1] = (int)nData2[2*i+1];
auto result = py::array_t<double, py::array::f_style>({dims[0], dims[1]});
std::memcpy(result.request().ptr, data2 + offset, result.nbytes());
offset += result.size();
Expand Down Expand Up @@ -252,18 +259,18 @@ class EventBridge
std::memcpy(eventData.dataPresent.request().ptr, pEvent->data->dataPresent, eventData.dataPresent.nbytes());

eventData.reflectivity = unpackDataToCell(pEvent->data->nContrast, 1,
pEvent->data->reflect, pEvent->data->nReflect, NULL, NULL, 2);
pEvent->data->reflect, pEvent->data->nReflect, NULL, NULL);

eventData.shiftedData = unpackDataToCell(pEvent->data->nContrast, 1,
pEvent->data->shiftedData, pEvent->data->nShiftedData, NULL, NULL, 3);
pEvent->data->shiftedData, pEvent->data->nShiftedData, NULL, NULL);

eventData.sldProfiles = unpackDataToCell(pEvent->data->nContrast, (pEvent->data->nSldProfiles2 == NULL) ? 1 : 2,
pEvent->data->sldProfiles, pEvent->data->nSldProfiles,
pEvent->data->sldProfiles2, pEvent->data->nSldProfiles2, 2);

pEvent->data->sldProfiles2, pEvent->data->nSldProfiles2, true);
eventData.resampledLayers = unpackDataToCell(pEvent->data->nContrast, (pEvent->data->nLayers2 == NULL) ? 1 : 2,
pEvent->data->layers, pEvent->data->nLayers,
pEvent->data->layers2, pEvent->data->nLayers, 2);
pEvent->data->layers, pEvent->data->nLayers,
pEvent->data->layers2, pEvent->data->nLayers2, true);
this->callback(event.type, eventData);
}
};
Expand Down
Loading
Loading