Skip to content

Commit

Permalink
Adds live plot context manager and update events
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenNneji committed Jul 18, 2024
1 parent 7dcfc1d commit 94483fb
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 143 deletions.
25 changes: 20 additions & 5 deletions RATapi/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,26 @@ def register(event_type: EventTypes, callback: Callable[[Union[str, PlotEventDat
__event_callbacks[event_type].add(callback)


def clear() -> None:
"""Clears all event callbacks."""
__event_impl.clear()
for key in __event_callbacks:
__event_callbacks[key] = set()
def clear(key=None, callback=None) -> None:
"""Clears all event callbacks or specific callback.
Parameters
----------
callback : Callable[[Union[str, PlotEventData, ProgressEventData]], None]
The callback for when the event is triggered.
"""
if key is None and callback is None:
for key in __event_callbacks:
__event_callbacks[key] = set()
elif key is not None and callback is not None:
__event_callbacks[key].remove(callback)

for value in __event_callbacks.values():
if value:
break
else:
__event_impl.clear()


dir_path = os.path.dirname(os.path.realpath(__file__))
Expand Down
2 changes: 1 addition & 1 deletion RATapi/examples/domains/alloy_domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def alloy_domains(params, bulkIn, bulkOut, contrast, domain):
gold = [goldThick, goldSLD, goldRough]

# Make the model depending on which domain we are looking at
if domain == 1:
if domain == 0:
output = [alloyUp, gold]
else:
output = [alloyDn, gold]
Expand Down
2 changes: 1 addition & 1 deletion RATapi/examples/domains/domains_XY_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def domains_XY_model(params, bulk_in, bulk_out, contrast, domain):
oxSLD = vfOxide * 3.41e-6

# Layer SLD depends on whether we are calculating the domain or not
if domain == 1:
if domain == 0:
laySLD = vfLayer * layerSLD
else:
laySLD = vfLayer * domainSLD
Expand Down
115 changes: 64 additions & 51 deletions RATapi/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Optional, Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes._axes import Axes
Expand All @@ -12,45 +13,6 @@
from RATapi.rat_core import PlotEventData, makeSLDProfileXY


class Figure:
"""Creates a plotting figure."""

def __init__(self, row: int = 1, col: int = 1):
"""Initializes the figure and the subplots.
Parameters
----------
row : int, default: 1
The number of rows in subplot
col : int, default: 1
The number of columns in subplot
"""
self._fig, self._ax = plt.subplots(row, col, num="Reflectivity Algorithms Toolbox (RAT)")
plt.show(block=False)
self._esc_pressed = False
self._close_clicked = False
self._fig.canvas.mpl_connect("key_press_event", self._process_button_press)
self._fig.canvas.mpl_connect("close_event", self._close)

def wait_for_close(self):
"""Waits for the user to close the figure
using the esc key.
"""
while not (self._esc_pressed or self._close_clicked):
plt.waitforbuttonpress(timeout=0.005)
plt.close(self._fig)

def _process_button_press(self, event):
"""Process the key_press_event."""
if event.key == "escape":
self._esc_pressed = True

def _close(self, _):
"""Process the close_event."""
self._close_clicked = True


def plot_errorbars(ax: Axes, x: np.ndarray, y: np.ndarray, err: np.ndarray, one_sided: bool, color: str):
"""Plots the error bars.
Expand All @@ -75,33 +37,38 @@ def plot_errorbars(ax: Axes, x: np.ndarray, y: np.ndarray, err: np.ndarray, one_
ax.scatter(x=x, y=y, s=3, marker="o", color=color)


def plot_ref_sld_helper(data: PlotEventData, fig: Optional[Figure] = None, delay: bool = True):
def plot_ref_sld_helper(data: PlotEventData, fig: Optional[matplotlib.pyplot.figure] = None, delay: bool = True):
"""Clears the previous plots and updates the ref and SLD plots.
Parameters
----------
data : PlotEventData
The plot event data that contains all the information
to generate the ref and sld plots
fig : Figure, optional
fig : matplotlib.pyplot.figure, optional
The figure class that has two subplots
delay : bool, default: True
Controls whether to delay 0.005s after plot is created
Returns
-------
fig : Figure
fig : matplotlib.pyplot.figure
The figure class that has two subplots
"""
preserve_zoom = False

if fig is None:
fig = Figure(1, 2)
elif fig._ax.shape != (2,):
fig._fig.clf()
fig._ax = fig._fig.subplots(1, 2)
fig = plt.subplots(1, 2)[0]
elif len(fig.axes) != 2:
fig.clf()
fig.subplots(1, 2)

ref_plot = fig._ax[0]
sld_plot = fig._ax[1]
ref_plot = fig.axes[0]
sld_plot = fig.axes[1]
if ref_plot.lines and fig.canvas.toolbar is not None:
preserve_zoom = True
fig.canvas.toolbar.push_current()

# Clears the previous plots
ref_plot.cla()
Expand Down Expand Up @@ -170,6 +137,8 @@ def plot_ref_sld_helper(data: PlotEventData, fig: Optional[Figure] = None, delay
sld_plot.legend()
sld_plot.grid()

if preserve_zoom:
fig.canvas.toolbar.back()
if delay:
plt.pause(0.005)

Expand Down Expand Up @@ -204,8 +173,52 @@ def plot_ref_sld(
data.subRoughs = results.contrastParams.subRoughs
data.resample = RATapi.inputs.make_resample(project)

figure = Figure(1, 2)
figure = plt.subplots(1, 2)[0]

plot_ref_sld_helper(data, figure)
if block:
figure.wait_for_close()

plt.show(block=block)


class LivePlot:
"""Creates a plot that gets updates from the plot event during a
calculation
Parameters
----------
block : bool, default: False
Indicates the plot should block until it is closed
"""

def __init__(self, block=False):
self.block = block
self.closed = False

def __enter__(self):
self.figure = plt.subplots(1, 2)[0]
self.figure.canvas.mpl_connect("close_event", self._setCloseState)
self.figure.show()
RATapi.events.register(RATapi.events.EventTypes.Plot, self.plotEvent)

return self.figure

def _setCloseState(self, _):
"""Close event handler"""
self.closed = True

def plotEvent(self, event):
"""Callback for the plot event.
Parameters
----------
event: PlotEventData
The plot event data.
"""
if not self.closed and self.figure.number in plt.get_fignums():
plot_ref_sld_helper(event, self.figure)

def __exit__(self, _exc_type, _exc_val, _traceback):
RATapi.events.clear(RATapi.events.EventTypes.Plot, self.plotEvent)
if not self.closed and self.figure.number in plt.get_fignums():
plt.show(block=self.block)
2 changes: 1 addition & 1 deletion cpp/RAT
Submodule RAT updated 1 files
+180 −112 triggerEvent.cpp
33 changes: 20 additions & 13 deletions cpp/rat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,27 +192,34 @@ class EventBridge
};

py::list unpackDataToCell(int rows, int cols, double* data, double* nData,
double* data2, double* nData2, int dataCol)
double* data2, double* nData2, bool isOutput2D=false)
{
py::list allResults;
int dims[2] = {0, dataCol};
int dims[2] = {0, 0};
int offset = 0;
for (int i = 0; i < rows; i++){
py::list rowList;
dims[0] = (int)nData[i] / dataCol;
dims[0] = (int)nData[2*i];
dims[1] = (int)nData[2*i+1];
auto result = py::array_t<double, py::array::f_style>({dims[0], dims[1]});
std::memcpy(result.request().ptr, data + offset, result.nbytes());
offset += result.size();
rowList.append(result);
allResults.append(rowList);
if (isOutput2D){
py::list rowList;
rowList.append(result);
allResults.append(rowList);
}
else{
allResults.append(result);
}
}

if (data2 != NULL && nData2 != NULL)
{
// This is used to unpack the domains data into the second column
offset = 0;
for ( int i = 0; i < rows; i++){
dims[0] = (int)nData2[i] / dataCol;
dims[0] = (int)nData2[2*i];
dims[1] = (int)nData2[2*i+1];
auto result = py::array_t<double, py::array::f_style>({dims[0], dims[1]});
std::memcpy(result.request().ptr, data2 + offset, result.nbytes());
offset += result.size();
Expand Down Expand Up @@ -252,18 +259,18 @@ class EventBridge
std::memcpy(eventData.dataPresent.request().ptr, pEvent->data->dataPresent, eventData.dataPresent.nbytes());

eventData.reflectivity = unpackDataToCell(pEvent->data->nContrast, 1,
pEvent->data->reflect, pEvent->data->nReflect, NULL, NULL, 2);
pEvent->data->reflect, pEvent->data->nReflect, NULL, NULL);

eventData.shiftedData = unpackDataToCell(pEvent->data->nContrast, 1,
pEvent->data->shiftedData, pEvent->data->nShiftedData, NULL, NULL, 3);
pEvent->data->shiftedData, pEvent->data->nShiftedData, NULL, NULL);

eventData.sldProfiles = unpackDataToCell(pEvent->data->nContrast, (pEvent->data->nSldProfiles2 == NULL) ? 1 : 2,
pEvent->data->sldProfiles, pEvent->data->nSldProfiles,
pEvent->data->sldProfiles2, pEvent->data->nSldProfiles2, 2);

pEvent->data->sldProfiles2, pEvent->data->nSldProfiles2, true);
eventData.resampledLayers = unpackDataToCell(pEvent->data->nContrast, (pEvent->data->nLayers2 == NULL) ? 1 : 2,
pEvent->data->layers, pEvent->data->nLayers,
pEvent->data->layers2, pEvent->data->nLayers, 2);
pEvent->data->layers, pEvent->data->nLayers,
pEvent->data->layers2, pEvent->data->nLayers2, true);
this->callback(event.type, eventData);
}
};
Expand Down
5 changes: 5 additions & 0 deletions tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def test_event_register() -> None:
[first_callback, second_callback]
)

RATapi.events.clear(RATapi.events.EventTypes.Message, second_callback)
result = RATapi.events.get_event_callback(RATapi.events.EventTypes.Message)
assert result == [first_callback]
result = RATapi.events.get_event_callback(RATapi.events.EventTypes.Plot)
assert result == [second_callback]
RATapi.events.clear()
assert RATapi.events.get_event_callback(RATapi.events.EventTypes.Plot) == []
assert RATapi.events.get_event_callback(RATapi.events.EventTypes.Message) == []
Expand Down
Loading

0 comments on commit 94483fb

Please sign in to comment.