diff --git a/orangecontrib/spectroscopy/preprocess/me_emsc.py b/orangecontrib/spectroscopy/preprocess/me_emsc.py index 68a26d3cc..88b92997b 100644 --- a/orangecontrib/spectroscopy/preprocess/me_emsc.py +++ b/orangecontrib/spectroscopy/preprocess/me_emsc.py @@ -3,12 +3,13 @@ from sklearn.decomposition import TruncatedSVD import Orange +from Orange.data import Table from Orange.preprocess.preprocess import Preprocess from Orange.data.util import get_unique_names from orangecontrib.spectroscopy.data import getx from orangecontrib.spectroscopy.preprocess.utils import SelectColumn, CommonDomainOrderUnknowns, \ - interpolate_extend_to + interpolate_extend_to, CommonDomainRef, table_eq_x from orangecontrib.spectroscopy.preprocess.emsc import weighted_wavenumbers, average_table_x @@ -87,13 +88,13 @@ class ME_EMSCModel(SelectColumn): InheritEq = True -class _ME_EMSC(CommonDomainOrderUnknowns): +class _ME_EMSC(CommonDomainOrderUnknowns, CommonDomainRef): def __init__(self, reference, weights, ncomp, alpha0, gamma, maxNiter, fixedNiter, positiveRef, domain): - super().__init__(domain) + CommonDomainOrderUnknowns.__init__(self, domain) + CommonDomainRef.__init__(self, reference, domain) assert len(reference) == 1 - self.reference = reference - self.weights = weights # !!! THIS SHOULD BE A NP ARRAY (or similar) with inflection points + self.weights = weights self.ncomp = ncomp self.alpha0 = alpha0 self.gamma = gamma @@ -276,6 +277,24 @@ def iterate(spectra, correctedFirsIteration, residualsFirstIteration, wavenumber newspectra = np.hstack((newspectra, numberOfIterations.reshape(-1, 1),RMSEall.reshape(-1, 1))) return newspectra + def __eq__(self, other): + return CommonDomainRef.__eq__(self, other) \ + and self.ncomp == other.ncomp \ + and np.array_equal(self.alpha0, other.alpha0) \ + and np.array_equal(self.gamma, other.gamma) \ + and self.maxNiter == other.maxNiter \ + and self.fixedNiter == other.fixedNiter \ + and self.positiveRef == other.positiveRef \ + and (self.weights == other.weights + if not isinstance(self.weights, Table) + else table_eq_x(self.weights, other.weights)) + + def __hash__(self): + weights = self.weights \ + if not isinstance(self.weights, Table) else tuple(self.weights.X[0][:10]) + return hash((CommonDomainRef.__hash__(self), weights, self.ncomp, tuple(self.alpha0), + tuple(self.gamma), self.maxNiter, self.fixedNiter, self.positiveRef)) + class MissingReferenceException(Exception): pass diff --git a/orangecontrib/spectroscopy/tests/test_me_emsc.py b/orangecontrib/spectroscopy/tests/test_me_emsc.py index 8722d229d..b651d5a3e 100644 --- a/orangecontrib/spectroscopy/tests/test_me_emsc.py +++ b/orangecontrib/spectroscopy/tests/test_me_emsc.py @@ -208,6 +208,39 @@ def test_short_reference(self): # it was crashing before ME_EMSC(reference=reference)(self.spectra) + def test_eq(self): + ref = self.reference + spectra = self.spectra + d1 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=1)(spectra) + d2 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=1)(spectra) + self.assertEqual(d1.domain, d2.domain) + self.assertEqual(hash(d1.domain), hash(d2.domain)) + + d2 = ME_EMSC(reference=ref, ncomp=2, weights=False, max_iter=1)(spectra) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + d2 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=1, + n0=np.linspace(1.1, 1.4, 11), + a=np.linspace(2, 7.1, 11))(spectra) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + d2 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=2)(spectra) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + r2 = ref.copy() + d2 = ME_EMSC(reference=r2, ncomp=False, weights=False, max_iter=1)(spectra) + self.assertEqual(d1.domain, d2.domain) + self.assertEqual(hash(d1.domain), hash(d2.domain)) + + with r2.unlocked(): + r2[0][0] = 1 + d2 = ME_EMSC(reference=r2, ncomp=False, weights=False, max_iter=1)(spectra) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + class TestInflectionPointWeighting(unittest.TestCase):