Skip to content

Commit

Permalink
SelectionFunction and SmoothedSelectionFunction: __eq__ and __hash__
Browse files Browse the repository at this point in the history
  • Loading branch information
markotoplak committed Nov 16, 2023
1 parent baf3156 commit ce9571a
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 14 deletions.
51 changes: 40 additions & 11 deletions orangecontrib/spectroscopy/preprocess/emsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,57 @@
from orangecontrib.spectroscopy.preprocess.npfunc import Function, Segments


class SelectionFunction(Segments):
class SelectionFunction(Function):
"""
Weighted selection function. Includes min and max.
"""
def __init__(self, min_, max_, w):
super().__init__((lambda x: True,
lambda x: 0),
(lambda x: np.logical_and(x >= min_, x <= max_),
lambda x: w))
super().__init__(None)
self.min_ = min_
self.max_ = max_
self.w = w

def __call__(self, x):
seg = Segments((lambda x: True, lambda x: 0),
(lambda x: np.logical_and(x >= self.min_, x <= self.max_),
lambda x: self.w)
)
return seg(x)

def __eq__(self, other):
return super().__eq__(other) \
and self.min_ == other.min_ \
and self.max_ == other.max_ \
and self.w == other.w

def __hash__(self):
return hash((super().__hash__(), self.min_, self.max_, self.w))


class SmoothedSelectionFunction(Segments):
class SmoothedSelectionFunction(SelectionFunction):
"""
Weighted selection function. Min and max points are middle
points of smoothing with hyperbolic tangent.
"""
def __init__(self, min_, max_, s, w):
middle = (min_ + max_) / 2
super().__init__((lambda x: x < middle,
lambda x: (np.tanh((x - min_) / s) + 1) / 2 * w),
(lambda x: x >= middle,
lambda x: (-np.tanh((x - max_) / s) + 1) / 2 * w))
super().__init__(min_, max_, w)
self.s = s

def __call__(self, x):
middle = (self.min_ + self.max_) / 2
seg = Segments((lambda x: x < middle,
lambda x: (np.tanh((x - self.min_) / self.s) + 1) / 2 * self.w),
(lambda x: x >= middle,
lambda x: (-np.tanh((x - self.max_) / self.s) + 1) / 2 * self.w)
)
return seg(x)

def __eq__(self, other):
return super().__eq__(other) \
and self.s == other.s

def __hash__(self):
return hash((super().__hash__(), self.s))


def weighted_wavenumbers(weights, wavenumbers):
Expand Down
35 changes: 33 additions & 2 deletions orangecontrib/spectroscopy/preprocess/npfunc.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,44 @@
import numpy as np


class Function():
class Function:

def __init__(self, fn):
self.fn = fn

def __call__(self, x):
return self.fn(x)

def __eq__(self, other):
return type(self) is type(other) \
and self.fn == other.fn

def __hash__(self):
return hash((type(self), self.fn))


class Constant(Function):

def __init__(self, c):
super().__init__(None)
self.c = c

def __call__(self, x):
x = np.asarray(x)
return np.ones(x.shape)*self.c

def __eq__(self, other):
return super().__eq__(other) \
and self.c == other.c

def __hash__(self):
return hash((super().__hash__(), self.c))


class Identity(Function):

def __init__(self):
pass
super().__init__(None)

def __call__(self, x):
return x
Expand All @@ -38,6 +53,7 @@ class Segments(Function):
"""

def __init__(self, *segments):
super().__init__(None)
self.segments = segments

def __call__(self, x):
Expand All @@ -48,10 +64,18 @@ def __call__(self, x):
output[ind] = fn(x[ind])
return output

def __eq__(self, other):
return super().__eq__(other) \
and self.segments == other.segments

def __hash__(self):
return hash((super().__hash__(), self.segments))


class Sum(Function):

def __init__(self, *elements):
super().__init__(None)
self.elements = elements

def __call__(self, x):
Expand All @@ -63,3 +87,10 @@ def __call__(self, x):
else:
acc = acc + current
return acc

def __eq__(self, other):
return super().__eq__(other) \
and self.segments == other.elements

def __hash__(self):
return hash((super().__hash__(), self.elements))
6 changes: 5 additions & 1 deletion orangecontrib/spectroscopy/tests/test_emsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,13 @@ def test_eq(self):
self.assertNotEqual(d1.domain, d2.domain)
self.assertNotEqual(hash(d1.domain), hash(d2.domain))

# these two are not the same because SelectionFunction does not define __eq__ and __hash__
d3 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True,
weights=SelectionFunction(0, 3, 1))(data)
self.assertEqual(d3.domain, d2.domain)
self.assertEqual(hash(d3.domain), hash(d2.domain))

d3 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True,
weights=SmoothedSelectionFunction(0, 3, 1, 0.5))(data)
self.assertNotEqual(d3.domain, d2.domain)
self.assertNotEqual(hash(d3.domain), hash(d2.domain))

Expand Down

0 comments on commit ce9571a

Please sign in to comment.