-
Notifications
You must be signed in to change notification settings - Fork 0
/
getters.py
78 lines (66 loc) · 2.03 KB
/
getters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
Objects intended for public use:
AbstractGetter
get_sigma_stdev
getter_kyeq0
getter_kxeq0
"""
import numpy as np
import abc
import warnings
from .utils import stdev_central
class AbstractGetter(abc.ABC):
    """
    Base class for "getters": callables that fetch data near a target and an error estimate.

    Subclasses implement ``get_data`` and ``get_sigma``. Calling a concrete subclass does NOT
    produce an instance of it: ``__new__`` builds a throwaway instance (so that ``get_data`` /
    ``get_sigma`` can be ordinary methods), runs them, and returns a plain dict with the keys
    ``'data_near_target'``, ``'omt_near_target'`` and ``'sigma'``. Because the return value is
    not an instance of ``cls``, ``__init__`` is never run.
    """

    def __new__(cls, *args, **kwargs):
        """
        Fetch the data via ``get_data`` and its error estimate via ``get_sigma``.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to ``get_data``; also forwarded to ``get_sigma`` together with
            the extra keywords ``data`` and ``omt``. The first positional argument (or, if
            there are no positional arguments, the keyword ``dr``) is stored as ``self.dr``
            so the methods can reach it.

        Returns
        -------
        dict
            ``'data_near_target'``: the data returned by ``get_data``.
            ``'omt_near_target'``: the ``omt`` values returned by ``get_data``.
            ``'sigma'``: the estimate from ``get_sigma``. If it is zero everywhere it is
            replaced by ones; if it is zero only in some entries, those entries are floored
            at the smallest nonzero value and a warning is emitted.

        Raises
        ------
        KeyError
            If no positional arguments are given and ``dr`` is not passed as a keyword.
        TypeError
            If ``cls`` still has abstract methods (raised by ``object.__new__``).
        """
        self = object.__new__(cls)
        self.dr = args[0] if args else kwargs['dr']

        data, omt = self.get_data(*args, **kwargs)
        sigma = self.get_sigma(*args, **kwargs, data=data, omt=omt)

        if np.all(sigma == 0):
            # Every error estimate is zero (for the stdev-based getters this presumably
            # means the data themselves were all zero). Use 1 everywhere to prevent
            # divide-by-zero errors downstream. np.ones_like also handles scalar / 0-d
            # sigma, where an in-place ``sigma[:] = 1`` would raise, and it leaves the
            # object returned by get_sigma unmodified.
            sigma = np.ones_like(sigma)
        elif np.any(sigma == 0):
            warnings.warn(
                "Estimated error was zero in some bins. Applying floor to estimated error.",
                stacklevel=2,
            )
            # Floor the zero entries at the smallest nonzero estimate.
            min_sigma = np.min(np.compress(sigma != 0, sigma))
            sigma = np.where(sigma == 0, min_sigma, sigma)

        return {
            'data_near_target': data,
            'omt_near_target': omt,
            'sigma': sigma,
        }

    @abc.abstractmethod
    def get_data(self, dr, k_tilde, z, om_tilde_min, om_tilde_max):
        """
        Return a pair ``(data, omt)`` for the requested target.

        ``__new__`` unpacks the result into exactly two values. Concrete getters in this module
        obtain them from ``dr.get_slice``; the exact array shapes are defined by that method.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_sigma(self, *args, **kwargs):
        """
        Return the error estimate for the data.

        Called with the same arguments as ``get_data`` plus the keywords ``data`` and ``omt``
        (the values ``get_data`` returned).
        """
        raise NotImplementedError
class get_sigma_stdev:
    """
    Mixin providing a frequency-independent estimate of sigma.

    The estimate is derived from the standard deviation of the data (through
    ``stdev_central``), so the ``omt`` coordinates are not consulted.
    """

    def get_sigma(self, *_positional, data, omt, **_extra):
        """
        Return ``stdev_central(data, 0.05, adjust=True)``.

        Extra positional and keyword arguments (whatever the caller forwards) are accepted
        and ignored, as is ``omt``.
        """
        estimate = stdev_central(data, 0.05, adjust=True)
        return estimate
class getter_kyeq0(get_sigma_stdev, AbstractGetter):
    """
    Getter for the line ``ky_tilde = 0``, ``kx_tilde = k_tilde``.

    The data come from ``dr.get_slice``; sigma comes from ``get_sigma_stdev``.
    """

    def get_data(self, dr, k_tilde, z, om_tilde_min, om_tilde_max):
        """
        Return ``(data, omt)`` from ``dr.get_slice`` at kx_tilde=k_tilde, ky_tilde=0, height z.

        ``(om_tilde_min, om_tilde_max)`` is passed as ``omega_tilde`` (presumably a range
        selector; confirm against ``get_slice``). Only the first element of the returned
        coordinate sequence is kept, as ``omt``.
        """
        omega_bounds = (om_tilde_min, om_tilde_max)
        values, coords = dr.get_slice(
            omega_tilde=omega_bounds,
            kx_tilde=k_tilde,
            ky_tilde=0,
            z=z,
            compress=True,
        )
        omega_coords, *_unused = coords
        return values, omega_coords
class getter_kxeq0(get_sigma_stdev, AbstractGetter):
    """
    Getter for the line ``kx_tilde = 0``, ``ky_tilde = k_tilde``.

    The data come from ``dr.get_slice``; sigma comes from ``get_sigma_stdev``.
    """

    def get_data(self, dr, k_tilde, z, om_tilde_min, om_tilde_max):
        """
        Return ``(data, omt)`` from ``dr.get_slice`` at kx_tilde=0, ky_tilde=k_tilde, height z.

        ``(om_tilde_min, om_tilde_max)`` is passed as ``omega_tilde`` (presumably a range
        selector; confirm against ``get_slice``). Only the first element of the returned
        coordinate sequence is kept, as ``omt``.
        """
        omega_bounds = (om_tilde_min, om_tilde_max)
        values, coords = dr.get_slice(
            omega_tilde=omega_bounds,
            kx_tilde=0,
            ky_tilde=k_tilde,
            z=z,
            compress=True,
        )
        omega_coords, *_unused = coords
        return values, omega_coords