Skip to content

Commit

Permalink
meemsc: __eq__ and __hash__
Browse files Browse the repository at this point in the history
  • Loading branch information
markotoplak committed Nov 16, 2023
1 parent ce9571a commit 0636def
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 5 deletions.
29 changes: 24 additions & 5 deletions orangecontrib/spectroscopy/preprocess/me_emsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from sklearn.decomposition import TruncatedSVD

import Orange
from Orange.data import Table
from Orange.preprocess.preprocess import Preprocess
from Orange.data.util import get_unique_names

from orangecontrib.spectroscopy.data import getx
from orangecontrib.spectroscopy.preprocess.utils import SelectColumn, CommonDomainOrderUnknowns, \
interpolate_extend_to
interpolate_extend_to, CommonDomainRef, table_eq_x
from orangecontrib.spectroscopy.preprocess.emsc import weighted_wavenumbers, average_table_x


Expand Down Expand Up @@ -87,13 +88,13 @@ class ME_EMSCModel(SelectColumn):
InheritEq = True


class _ME_EMSC(CommonDomainOrderUnknowns):
class _ME_EMSC(CommonDomainOrderUnknowns, CommonDomainRef):

def __init__(self, reference, weights, ncomp, alpha0, gamma, maxNiter, fixedNiter, positiveRef, domain):
super().__init__(domain)
CommonDomainOrderUnknowns.__init__(self, domain)
CommonDomainRef.__init__(self, reference, domain)
assert len(reference) == 1
self.reference = reference
self.weights = weights # !!! THIS SHOULD BE A NP ARRAY (or similar) with inflection points
self.weights = weights
self.ncomp = ncomp
self.alpha0 = alpha0
self.gamma = gamma
Expand Down Expand Up @@ -276,6 +277,24 @@ def iterate(spectra, correctedFirsIteration, residualsFirstIteration, wavenumber
newspectra = np.hstack((newspectra, numberOfIterations.reshape(-1, 1),RMSEall.reshape(-1, 1)))
return newspectra

def __eq__(self, other):
return CommonDomainRef.__eq__(self, other) \
and self.ncomp == other.ncomp \
and np.array_equal(self.alpha0, other.alpha0) \
and np.array_equal(self.gamma, other.gamma) \
and self.maxNiter == other.maxNiter \
and self.fixedNiter == other.fixedNiter \
and self.positiveRef == other.positiveRef \
and (self.weights == other.weights
if not isinstance(self.weights, Table)
else table_eq_x(self.weights, other.weights))

def __hash__(self):
weights = self.weights \
if not isinstance(self.weights, Table) else tuple(self.weights.X[0][:10])
return hash((CommonDomainRef.__hash__(self), weights, self.ncomp, tuple(self.alpha0),
tuple(self.gamma), self.maxNiter, self.fixedNiter, self.positiveRef))


class MissingReferenceException(Exception):
pass
Expand Down
33 changes: 33 additions & 0 deletions orangecontrib/spectroscopy/tests/test_me_emsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,39 @@ def test_short_reference(self):
# it was crashing before
ME_EMSC(reference=reference)(self.spectra)

def test_eq(self):
ref = self.reference
spectra = self.spectra
d1 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=1)(spectra)
d2 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=1)(spectra)
self.assertEqual(d1.domain, d2.domain)
self.assertEqual(hash(d1.domain), hash(d2.domain))

d2 = ME_EMSC(reference=ref, ncomp=2, weights=False, max_iter=1)(spectra)
self.assertNotEqual(d1.domain, d2.domain)
self.assertNotEqual(hash(d1.domain), hash(d2.domain))

d2 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=1,
n0=np.linspace(1.1, 1.4, 11),
a=np.linspace(2, 7.1, 11))(spectra)
self.assertNotEqual(d1.domain, d2.domain)
self.assertNotEqual(hash(d1.domain), hash(d2.domain))

d2 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=2)(spectra)
self.assertNotEqual(d1.domain, d2.domain)
self.assertNotEqual(hash(d1.domain), hash(d2.domain))

r2 = ref.copy()
d2 = ME_EMSC(reference=r2, ncomp=False, weights=False, max_iter=1)(spectra)
self.assertEqual(d1.domain, d2.domain)
self.assertEqual(hash(d1.domain), hash(d2.domain))

with r2.unlocked():
r2[0][0] = 1
d2 = ME_EMSC(reference=r2, ncomp=False, weights=False, max_iter=1)(spectra)
self.assertNotEqual(d1.domain, d2.domain)
self.assertNotEqual(hash(d1.domain), hash(d2.domain))


class TestInflectionPointWeighting(unittest.TestCase):

Expand Down

0 comments on commit 0636def

Please sign in to comment.