Skip to content

Commit

Permalink
Added the figure class
Browse files Browse the repository at this point in the history
  • Loading branch information
RabiyaF committed Apr 25, 2024
1 parent dcd8db1 commit ed8f06e
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 190 deletions.
283 changes: 141 additions & 142 deletions RAT/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,161 +6,32 @@
from RAT.rat_core import PlotEventData, makeSLDProfileXY


class RATPlots:
class Figure:
"""
Creates the RAT reflectivity and Scattering
Lenght Density (SLD) plots
Creates a plotting figure.
"""

def __init__(self, delay: bool = True):
def __init__(self, row: int = 1, col: int = 2):
"""
Initializes the figure and the subplots
Initializes the figure and the subplots.
Parameters
----------
row : int
The number of rows in subplot
col : int
The number of columns in subplot
"""
self._fig, (self._ref_plot, self._sld_plot) = \
plt.subplots(1, 2, num="Reflectivity Algorithms Toolbox (RAT)")
self._fig, self._ax = \
plt.subplots(row, col, num="Reflectivity Algorithms Toolbox (RAT)")
plt.show(block=False)
self._delay = delay
self._esc_pressed = False
self._close_clicked = False
self._fig.canvas.mpl_connect("key_press_event",
self._process_button_press)
self._fig.canvas.mpl_connect('close_event',
self._close)

def _plot(self, data: PlotEventData):
"""
Clears the previous plots and updates the ref and SLD plots.
Parameters
----------
data : PlotEventData
The plot event data that contains all the information
to generate the ref and sld plots
"""
# Clears the previous plots
self._ref_plot.cla()
self._sld_plot.cla()

for i, (r, sd, sld, layer) in enumerate(zip(data.reflectivity,
data.shiftedData,
data.sldProfiles,
data.resampledLayers)):

r, sd, sld, layer = map(lambda x: x[0], (r, sd, sld, layer))

# Calculate the divisor
div = 1 if i == 0 else 2**(4*(i+1))

# Plot the reflectivity on plot (1,1)
self._ref_plot.plot(r[:, 0],
r[:, 1]/div,
label=f'ref {i+1}',
linewidth=2)
self.color = self._ref_plot.get_lines()[-1].get_color()
if data.dataPresent[i]:
self._plot_shifted_data(sd, div)

# Plot the slds on plot (1,2)
for j in range(1, sld.shape[1]):
self._sld_plot.plot(sld[:, 0],
sld[:, j],
label=f'sld {i+1}',
color=self.color,
linewidth=2)

if data.resample[i] == 1 or data.modelType == 'custom xy':
new = makeSLDProfileXY(layer[0, 1],
layer[-1, 1],
data.subRoughs[i],
layer,
len(layer),
1.0)

self._sld_plot.plot([row[0]-49 for row in new],
[row[1] for row in new],
color=self.color,
linewidth=1)

self._format_plots()

if self._delay:
plt.pause(0.005)

def _plot_shifted_data(self, sd, div):
"""
Plots the shifted data.
Parameters
----------
sd : np.ndarray
The shifted data containing the x, y, e data
div : int
The divisor for the data
"""
sd_x = sd[:, 0]
sd_y, sd_e = map(lambda x: x/div, (sd[:, 1], sd[:, 2]))

# Plot the errorbars
indices_removed = np.flip(np.nonzero(sd_y - sd_e < 0)[0])
sd_x_r, sd_y_r, sd_e_r = map(lambda x:
np.delete(x, indices_removed),
(sd_x, sd_y, sd_e))
self._plot_errorbars(sd_x_r, sd_y_r, sd_e_r, False)

# Plot one sided errorbars
indices_selected = [x for x in indices_removed
if x not in np.nonzero(sd_y < 0)[0]]
sd_x_s, sd_y_s, sd_e_s = map(lambda x:
[x[i] for i in indices_selected],
(sd_x, sd_y, sd_e))
self._plot_errorbars(sd_x_s, sd_y_s, sd_e_s, True)

def _plot_errorbars(self, x, y, err, onesided):
"""
Plots the error bars.
Parameters
----------
x : np.ndarray
The shifted data x axis data
y : np.ndarray
The shifted data y axis data
err : np.ndarray
The shifted data e data
onesided : bool
A boolean to indicate whether
to draw one sided errorbars
"""
y_error = [[0]*len(err), err] if onesided else err
self._ref_plot.errorbar(x=x,
y=y,
yerr=y_error,
fmt='none',
ecolor=self.color,
elinewidth=1,
capsize=0)
self._ref_plot.scatter(x=x,
y=y,
s=3,
marker="o",
color=self.color)

def _format_plots(self):
"""
Formats the ref and sld subplots.
"""
self._ref_plot.set_yscale('log')
self._ref_plot.set_xscale('log')
self._ref_plot.set_xlabel('Qz')
self._ref_plot.set_ylabel('Ref')
self._ref_plot.legend()
self._ref_plot.grid()

self._sld_plot.set_xlabel('Z')
self._sld_plot.set_ylabel('SLD')
self._sld_plot.legend()
self._sld_plot.grid()

def wait_for_close(self):
"""
Waits for the user to close the figure
Expand All @@ -182,3 +53,131 @@ def _close(self, _):
Process the close_event.
"""
self._close_clicked = True


def plot_errorbars(ax, x, y, err, onesided, color):
"""
Plots the error bars.
Parameters
----------
ax : matplotlib.axes._axes.Axes
The axis on which to draw errorbars
x : np.ndarray
The shifted data x axis data
y : np.ndarray
The shifted data y axis data
err : np.ndarray
The shifted data e data
onesided : bool
A boolean to indicate whether to draw one sided errorbars
color : str
The hex representing the color of the errorbars
"""
y_error = [[0]*len(err), err] if onesided else err
ax.errorbar(x=x,
y=y,
yerr=y_error,
fmt='none',
ecolor=color,
elinewidth=1,
capsize=0)
ax.scatter(x=x, y=y, s=3, marker="o", color=color)


def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True):
"""
Clears the previous plots and updates the ref and SLD plots.
Parameters
----------
data : PlotEventData
The plot event data that contains all the information
to generate the ref and sld plots
fig : Figure
The figure class that has two subplots
delay : bool
Controls whether to delay 0.005s after plot is created
"""
if fig is None:
fig = Figure()

ref_plot = fig._ax[0]
sld_plot = fig._ax[1]

# Clears the previous plots
ref_plot.cla()
sld_plot.cla()

for i, (r, sd, sld, layer) in enumerate(zip(data.reflectivity,
data.shiftedData,
data.sldProfiles,
data.resampledLayers)):

r, sd, sld, layer = map(lambda x: x[0], (r, sd, sld, layer))

# Calculate the divisor
div = 1 if i == 0 else 2**(4*(i+1))

# Plot the reflectivity on plot (1,1)
ref_plot.plot(r[:, 0],
r[:, 1]/div,
label=f'ref {i+1}',
linewidth=2)
color = ref_plot.get_lines()[-1].get_color()

if data.dataPresent[i]:
sd_x = sd[:, 0]
sd_y, sd_e = map(lambda x: x/div, (sd[:, 1], sd[:, 2]))

# Plot the errorbars
indices_removed = np.flip(np.nonzero(sd_y - sd_e < 0)[0])
sd_x_r, sd_y_r, sd_e_r = map(lambda x:
np.delete(x, indices_removed),
(sd_x, sd_y, sd_e))
plot_errorbars(ref_plot, sd_x_r, sd_y_r, sd_e_r, False, color)

# Plot one sided errorbars
indices_selected = [x for x in indices_removed
if x not in np.nonzero(sd_y < 0)[0]]
sd_x_s, sd_y_s, sd_e_s = map(lambda x:
[x[i] for i in indices_selected],
(sd_x, sd_y, sd_e))
plot_errorbars(ref_plot, sd_x_s, sd_y_s, sd_e_s, True, color)

# Plot the slds on plot (1,2)
for j in range(1, sld.shape[1]):
sld_plot.plot(sld[:, 0],
sld[:, j],
label=f'sld {i+1}',
color=color,
linewidth=2)

if data.resample[i] == 1 or data.modelType == 'custom xy':
new = makeSLDProfileXY(layer[0, 1],
layer[-1, 1],
data.subRoughs[i],
layer,
len(layer),
1.0)

sld_plot.plot([row[0]-49 for row in new],
[row[1] for row in new],
color=color,
linewidth=1)

# Format the axis
ref_plot.set_yscale('log')
ref_plot.set_xscale('log')
ref_plot.set_xlabel('Qz')
ref_plot.set_ylabel('Ref')
ref_plot.legend()
ref_plot.grid()

sld_plot.set_xlabel('Z')
sld_plot.set_ylabel('SLD')
sld_plot.legend()
sld_plot.grid()

if delay:
plt.pause(0.005)
Loading

0 comments on commit ed8f06e

Please sign in to comment.