-
Notifications
You must be signed in to change notification settings - Fork 7
/
wwzplotter.py
437 lines (364 loc) · 14.3 KB
/
wwzplotter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
# -*- coding: utf-8 -*-
#!/usr/bin/env python
"""A class for plotting results of the weighted wavelet z-transform analysis.
"""
import contextlib
import os
import sys

import matplotlib.gridspec as gs
import matplotlib.pyplot as plt
from matplotlib.ticker import LogLocator
import numpy as np

__author__ = "Sebastian Kiehlmann"
__credits__ = ["Sebastian Kiehlmann"]
__license__ = "BSD"
__maintainer__ = "Sebastian Kiehlmann"
__email__ = "[email protected]"
__status__ = "Production"
#==============================================================================
# CLASSES
#==============================================================================
class WWZPlotter:
    """A class for plotting WWZ results."""

    #--------------------------------------------------------------------------
    def __init__(self, wwz, tunit=None):
        """A class for plotting WWZ results.

        Parameters
        ----------
        wwz : wwz.WWZ
            A WWZ instance that is used for plotting.
        tunit : str, optional
            Time unit used in the axis labels, e.g. 'd' gives 'Period [d]'
            and 'Frequency [1/d]'. If not a string, the labels carry no unit.
            The default is None.

        Returns
        -------
        None.
        """

        self.wwz = wwz
        # set to False whenever something required for plotting is missing:
        self.okay = True

        if wwz.wwz is None:
            print('Note: There is no WWZ transform data stored in this WWZ '
                  'instance. There will be nothing to plot.')
            self.okay = False

        if wwz.freq is not None:
            # check if frequencies are linearly scaled:
            freq = self.wwz.freq
            df = np.diff(freq)
            self.linear_freq = np.all(np.isclose(df, df.mean()))
            self.fmin = freq.min()
            self.fmax = freq.max()

            # get periods and check if linearly scaled:
            period = 1. / freq
            dp = np.diff(period)
            self.linear_period = np.all(np.isclose(dp, dp.mean()))
            self.pmin = period.min()
            self.pmax = period.max()
            self.n_ybins = freq.size
        else:
            self.okay = False

        if wwz.tau is not None:
            self.tmin = wwz.tau.min()
            self.tmax = wwz.tau.max()
        else:
            self.okay = False

        if not self.okay:
            return

        has_unit = isinstance(tunit, str)
        freq_label = f'Frequency [1/{tunit}]' if has_unit else 'Frequency'
        period_label = f'Period [{tunit}]' if has_unit else 'Period'

        # (ymin, ymax) become the (bottom, top) of the imshow() extent; in
        # both linear cases the longest period ends up at the top:
        if self.linear_freq:
            self.ymin = self.fmax
            self.ymax = self.fmin
            self.ymin_alt = self.pmax
            self.ymax_alt = self.pmin
            self.ylabel = freq_label
            self.ylabel_alt = period_label
        elif self.linear_period:
            self.ymin = self.pmin
            self.ymax = self.pmax
            self.ymin_alt = self.fmin
            self.ymax_alt = self.fmax
            self.ylabel = period_label
            self.ylabel_alt = freq_label
        else:
            # neither scale is linear, so only a unitless axis is shown:
            self.ymin = 0
            self.ymax = 1
            self.ymin_alt = 1
            self.ymax_alt = 0
            self.ylabel = 'Non-linear scale'
            self.ylabel_alt = 'Non-linear scale'

    #--------------------------------------------------------------------------
    def _select_map(self, select):
        """Helper method to select a map from a WWZ instance.

        Parameters
        ----------
        select : str
            Select either 'wwz' or 'wwa' (case-insensitive).

        Raises
        ------
        ValueError
            Raised if 'select' is not one of the allowed options.

        Returns
        -------
        result : numpy.ndarray or None
            The transposed WWZ or WWA array, or None if the selected map is
            not stored in the WWZ instance.
        """

        # check that selection is allowed:
        if select.lower() not in ('wwz', 'wwa'):
            raise ValueError(f"'{select}' is not a valid selection.")
        select = select.lower()

        # plain attribute lookup; no need for eval():
        result = getattr(self.wwz, select)

        # check if result map is available:
        if result is None:
            print(f'No {select.upper()} transform available.')
            return None

        # imshow() maps rows to the y-axis (frequency/period) and columns to
        # the x-axis (tau); assumes the stored map is indexed [tau, freq]:
        return result.transpose()

    #--------------------------------------------------------------------------
    def plot_map(
            self, select, ax=None, xlabel=None, **kwargs):
        """Plot the resulting map from a WWZ instance.

        Parameters
        ----------
        select : str
            Select either 'wwz' or 'wwa'.
        ax : matplotlib.pyplot.axis, optional
            The axis to plot to. If None is given a new axis is created. The
            default is None.
        xlabel : str, optional
            The x-axis label. If None is provided no label is placed. The
            default is None.
        kwargs : dict, optional
            Keyword arguments forwarded to the matplotlib.pyplot.imshow()
            function.

        Returns
        -------
        matplotlib.pyplot.axis
            The axis to which the map was plotted.
        matplotlib.image.AxesImage
            The image.
        Both are None if there is nothing to plot.
        """

        if not self.okay:
            return None, None

        # select result:
        result = self._select_map(select)
        if result is None:
            return None, None

        # create figure if needed:
        if ax is None:
            __, ax = plt.subplots(1)

        # plot:
        extent = [self.tmin, self.tmax, self.ymin, self.ymax]
        im = ax.imshow(
                result, origin='upper', aspect='auto', extent=extent,
                **kwargs)

        # add labels:
        if xlabel:
            ax.set_xlabel(xlabel)
        ax.set_ylabel(self.ylabel)

        return ax, im

    #--------------------------------------------------------------------------
    def plot_map_avg(
            self, select, statistic='mean', ax=None, ylabel=False, **kwargs):
        """Vertically plot an average along the time axis of the transform map.

        Parameters
        ----------
        select : str
            Select either 'wwz' or 'wwa'.
        statistic : str, optional
            Choose either 'mean' or 'median'. The default is 'mean'.
        ax : matplotlib.pyplot.axis, optional
            The axis to plot to. If None is given a new axis is created. The
            default is None.
        ylabel : bool, optional
            If True a label is added to the y-axis. The default is False.
        **kwargs : dict
            Keyword arguments forwarded to the matplotlib.pyplot.plot()
            function.

        Raises
        ------
        ValueError
            Raised if 'statistic' is not one of the allowed options.

        Returns
        -------
        matplotlib.pyplot.axis or None
            The axis to which the data was plotted, or None if there is
            nothing to plot.
        """

        if not self.okay:
            return None

        # select result:
        result = self._select_map(select)
        if result is None:
            return None

        # calculate statistic:
        if statistic not in ['mean', 'median']:
            raise ValueError(f"'{statistic}' is not a valid statistic.")
        elif statistic == 'median':
            result_avg = np.median(result, axis=1)
        else:
            result_avg = np.mean(result, axis=1)

        # create figure if needed:
        if ax is None:
            __, ax = plt.subplots(1)

        # plot; the average is reversed because the first map row is drawn at
        # ymax (imshow origin='upper') while y runs from ymin to ymax:
        y = np.linspace(self.ymin, self.ymax, result_avg.size)
        ax.plot(result_avg[::-1], y, **kwargs)

        # add labels:
        if ylabel:
            ax.set_ylabel(self.ylabel)
        ax.set_xlabel(f'{statistic.capitalize()} {select.upper()}')

        return ax

    #--------------------------------------------------------------------------
    def plot_data(
            self, ax=None, errorbars=True, xlabel=None, ylabel=None, **kwargs):
        """Plot the data stored in a WWZ instance.

        Parameters
        ----------
        ax : matplotlib.pyplot.axis, optional
            The axis to plot to. If None is given a new axis is created. The
            default is None.
        errorbars : bool, optional
            If True errorbars are shown, if uncertainties were stored in the
            WWZ instance. The default is True.
        xlabel : str, optional
            The x-axis description. If None is provided no label is printed.
            The default is None.
        ylabel : str, optional
            The y-axis description. If None is provided no label is printed.
            The default is None.
        **kwargs : dict
            Keyword arguments forwarded to the matplotlib.pyplot.errorbar()
            function (or matplotlib.pyplot.plot() without errorbars).

        Returns
        -------
        matplotlib.pyplot.axis or None
            The axis to which the data was plotted, or None if no data is
            stored.
        """

        # check if data is available:
        if self.wwz.t is None:
            print('No data available.')
            return None

        # create figure if needed:
        if ax is None:
            __, ax = plt.subplots(1)

        # plot:
        if errorbars and self.wwz.s_x is not None:
            ax.errorbar(self.wwz.t, self.wwz.x, self.wwz.s_x, **kwargs)
        else:
            ax.plot(self.wwz.t, self.wwz.x, **kwargs)

        # add labels:
        if isinstance(xlabel, str):
            ax.set_xlabel(xlabel)
        if isinstance(ylabel, str):
            ax.set_ylabel(ylabel)

        return ax

    #--------------------------------------------------------------------------
    def add_right_labels(self, ax):
        """Add ticks and labels to the right side of a plot showing the
        alternative unit, i.e. frequency if period is used on the left side and
        vice versa.

        Parameters
        ----------
        ax : matplotlib.pyplot.axis
            The axis whose right side receives the alternative-unit ticks and
            label (via a twin axis).

        Returns
        -------
        ax2 : matplotlib.pyplot.axis
            The new axis to which the labels were added.
        """

        ax2 = ax.twinx()
        plt.setp(ax2.get_xticklabels(), visible=False)
        ax2.yaxis.set_label_position("right")
        ax2.yaxis.tick_right()
        ax2.set_ylim(self.ymin_alt, self.ymax_alt)

        def conversion(x):
            # frequency <-> period; 1/x is its own inverse
            return 1 / x

        # silence stderr to suppress a warning emitted while setting the
        # scale; the context managers close the null device and restore the
        # previous stderr even if set_yscale() raises:
        with open(os.devnull, "w") as devnull, \
                contextlib.redirect_stderr(devnull):
            ax2.set_yscale('function', functions=(conversion, conversion))

        ax2.yaxis.set_major_locator(LogLocator(subs='all'))
        ax2.set_ylabel(self.ylabel_alt)

        return ax2

    #--------------------------------------------------------------------------
    def plot(self, select, statistic='mean', errorbars=True,
             peaks_quantile=None, xlabel=None, ylabel=None, figsize=None,
             height_ratios=(2, 1), width_ratios=(5, 1), kwargs_map=None,
             kwargs_map_avg=None, kwargs_data=None, kwargs_peaks=None):
        """Plot the WWZ map, average, and data.

        Parameters
        ----------
        select : str
            Select either 'wwz' or 'wwa'.
        statistic : str, optional
            Choose either 'mean' or 'median'. The default is 'mean'.
        errorbars : bool, optional
            If True errorbars are shown, if uncertainties were stored in the
            WWZ instance. The default is True.
        peaks_quantile : float, optional
            If not None, a ridge line along the peak position is shown.
            peaks_quantile needs to be a float between 0 and 1. Only peaks in
            the quantile above this threshold are shown. The default is None.
        xlabel : str, optional
            The x-axis description. If None is provided no label is printed.
            The default is None.
        ylabel : str, optional
            The y-axis description. If None is provided no label is printed.
            The default is None.
        figsize : tuple, optional
            Set the figure size. The default is None.
        height_ratios : tuple, optional
            Set the size ratio between the top and bottom panel with two values
            in a tuple. The default is (2, 1).
        width_ratios : tuple, optional
            Set the size ratio between the left and right panel with two values
            in a tuple. The default is (5, 1).
        kwargs_map : dict, optional
            Keyword arguments forwarded to plotting the map. The default is
            None (no extra arguments).
        kwargs_map_avg : dict, optional
            Keyword arguments forwarded to plotting the map average. The
            default is None (no extra arguments).
        kwargs_data : dict, optional
            Keyword arguments forwarded to plotting the data. The default is
            None (no extra arguments).
        kwargs_peaks : dict, optional
            Keyword arguments forwarded to plotting the peak ridge lines. The
            default is None (no extra arguments).

        Returns
        -------
        ax_map : matplotlib.pyplot.axis
            The map axis.
        ax_map_avg : matplotlib.pyplot.axis
            The map average axis.
        ax_data : matplotlib.pyplot.axis
            The data axis.
        All three are None (and no figure is created) if the WWZ instance
        lacks the transform, frequencies, or time steps needed for plotting.
        """

        # the limits and bin count used below only exist when okay is True:
        if not self.okay:
            return None, None, None

        # fresh dicts per call instead of shared mutable defaults:
        kwargs_map = {} if kwargs_map is None else kwargs_map
        kwargs_map_avg = {} if kwargs_map_avg is None else kwargs_map_avg
        kwargs_data = {} if kwargs_data is None else kwargs_data
        kwargs_peaks = {} if kwargs_peaks is None else kwargs_peaks

        # create figure:
        plt.figure(figsize=figsize)
        grid = gs.GridSpec(
                2, 2, hspace=0, wspace=0, height_ratios=height_ratios,
                width_ratios=width_ratios)
        ax_map = plt.subplot(grid[0, 0])
        ax_map_avg = plt.subplot(grid[0, 1])
        ax_data = plt.subplot(grid[1, 0])

        # plot map:
        self.plot_map(
                select, ax=ax_map, **kwargs_map)

        # plot map average:
        self.plot_map_avg(
                select, statistic=statistic, ax=ax_map_avg, **kwargs_map_avg)
        # pad the limits by half a bin on either side so that the average
        # curve lines up with the pixel centres of the map; a single bin has
        # no spacing to pad by:
        if self.n_ybins > 1:
            extend = (self.ymax - self.ymin) / (self.n_ybins - 1) / 2.
        else:
            extend = 0.
        ax_map_avg.set_ylim(self.ymin-extend, self.ymax+extend)

        # plot data:
        self.plot_data(
                ax=ax_data, errorbars=errorbars, xlabel=xlabel, ylabel=ylabel,
                **kwargs_data)
        ax_data.set_xlim(self.tmin, self.tmax)

        # plot peaks (a quantile of 0 is valid, hence the explicit None test):
        if peaks_quantile is not None:
            peak_tau, peak_pos, __ = self.wwz.find_peaks(
                    select, peaks_quantile)
            ax_map.plot(peak_tau, peak_pos, **kwargs_peaks)

        # add right axis labels:
        self.add_right_labels(ax_map_avg)

        # add data axis labels:
        ax_data.set_xlabel(xlabel)
        ax_data.set_ylabel(ylabel)
        plt.setp(ax_map_avg.get_yticklabels(), visible=False)
        plt.setp(ax_map.get_xticklabels(), visible=False)

        return ax_map, ax_map_avg, ax_data
#==============================================================================