From 09ae7953a58221e423dd164df2944cbb2c370767 Mon Sep 17 00:00:00 2001 From: Marko Toplak Date: Thu, 16 Nov 2023 07:43:25 +0100 Subject: [PATCH 1/5] emsc, me_emsc: refactor reference compat --- orangecontrib/spectroscopy/preprocess/emsc.py | 26 +++++++++---------- .../spectroscopy/preprocess/me_emsc.py | 24 ++++++----------- .../spectroscopy/preprocess/utils.py | 22 +++++++++------- 3 files changed, 33 insertions(+), 39 deletions(-) diff --git a/orangecontrib/spectroscopy/preprocess/emsc.py b/orangecontrib/spectroscopy/preprocess/emsc.py index 6c7ffefb6..1ee701dcf 100644 --- a/orangecontrib/spectroscopy/preprocess/emsc.py +++ b/orangecontrib/spectroscopy/preprocess/emsc.py @@ -4,9 +4,9 @@ from Orange.preprocess.preprocess import Preprocess from Orange.data.util import get_unique_names -from orangecontrib.spectroscopy.data import getx, spectra_mean +from orangecontrib.spectroscopy.data import getx from orangecontrib.spectroscopy.preprocess.utils import SelectColumn, CommonDomainOrderUnknowns, \ - interp1d_with_unknowns_numpy, nan_extend_edges_and_interpolate, MissingReferenceException + interp1d_with_unknowns_numpy, MissingReferenceException, interpolate_extend_to from orangecontrib.spectroscopy.preprocess.npfunc import Function, Segments @@ -65,6 +65,7 @@ class _EMSC(CommonDomainOrderUnknowns): def __init__(self, reference, badspectra, weights, order, scaling, domain): super().__init__(domain) + assert len(reference) == 1 self.reference = reference self.badspectra = badspectra self.weights = weights @@ -74,17 +75,7 @@ def __init__(self, reference, badspectra, weights, order, scaling, domain): def transformed(self, X, wavenumbers): # wavenumber have to be input as sorted # about 85% of time in __call__ function is spent is lstsq - # compute average spectrum from the reference - ref_X = np.atleast_2d(spectra_mean(self.reference.X)) - - def interpolate_to_data(other_xs, other_data): - # all input data needs to be interpolated (and NaNs removed) - interpolated = 
interp1d_with_unknowns_numpy(other_xs, other_data, wavenumbers) - # we know that X is not NaN. same handling of reference as of X - interpolated, _ = nan_extend_edges_and_interpolate(wavenumbers, interpolated) - return interpolated - - ref_X = interpolate_to_data(getx(self.reference), ref_X) + ref_X = interpolate_extend_to(self.reference, wavenumbers) wei_X = weighted_wavenumbers(self.weights, wavenumbers) N = wavenumbers.shape[0] @@ -93,7 +84,7 @@ def interpolate_to_data(other_xs, other_data): n_badspec = len(self.badspectra) if self.badspectra is not None else 0 if self.badspectra: - badspectra_X = interpolate_to_data(getx(self.badspectra), self.badspectra.X) + badspectra_X = interpolate_extend_to(self.badspectra, wavenumbers) M = [] for x in range(0, self.order+1): @@ -123,6 +114,11 @@ def interpolate_to_data(other_xs, other_data): return newspectra +def average_table_x(data): + return Orange.data.Table.from_numpy(Orange.data.Domain(data.domain.attributes), + X=data.X.mean(axis=0, keepdims=True)) + + class EMSC(Preprocess): def __init__(self, reference=None, badspectra=None, weights=None, order=2, scaling=True, @@ -132,6 +128,8 @@ def __init__(self, reference=None, badspectra=None, weights=None, order=2, scali if reference is None: raise MissingReferenceException() self.reference = reference + if len(self.reference) > 1: + self.reference = average_table_x(self.reference) self.badspectra = badspectra self.weights = weights self.order = order diff --git a/orangecontrib/spectroscopy/preprocess/me_emsc.py b/orangecontrib/spectroscopy/preprocess/me_emsc.py index eb8eaa244..68a26d3cc 100644 --- a/orangecontrib/spectroscopy/preprocess/me_emsc.py +++ b/orangecontrib/spectroscopy/preprocess/me_emsc.py @@ -6,18 +6,10 @@ from Orange.preprocess.preprocess import Preprocess from Orange.data.util import get_unique_names -from orangecontrib.spectroscopy.data import getx, spectra_mean +from orangecontrib.spectroscopy.data import getx from 
orangecontrib.spectroscopy.preprocess.utils import SelectColumn, CommonDomainOrderUnknowns, \ - interp1d_with_unknowns_numpy, nan_extend_edges_and_interpolate -from orangecontrib.spectroscopy.preprocess.emsc import weighted_wavenumbers - - -def interpolate_to_data(other_xs, other_data, wavenumbers): - # all input data needs to be interpolated (and NaNs removed) - interpolated = interp1d_with_unknowns_numpy(other_xs, other_data, wavenumbers) - # we know that X is not NaN. same handling of reference as of X - interpolated, _ = nan_extend_edges_and_interpolate(wavenumbers, interpolated) - return interpolated + interpolate_extend_to +from orangecontrib.spectroscopy.preprocess.emsc import weighted_wavenumbers, average_table_x def calculate_complex_n(ref_X,wavenumbers): @@ -99,6 +91,7 @@ class _ME_EMSC(CommonDomainOrderUnknowns): def __init__(self, reference, weights, ncomp, alpha0, gamma, maxNiter, fixedNiter, positiveRef, domain): super().__init__(domain) + assert len(reference) == 1 self.reference = reference self.weights = weights # !!! 
THIS SHOULD BE A NP ARRAY (or similar) with inflection points self.ncomp = ncomp @@ -234,8 +227,7 @@ def iterate(spectra, correctedFirsIteration, residualsFirstIteration, wavenumber break return newspectra, RMSEall, numberOfIterations - ref_X = np.atleast_2d(spectra_mean(self.reference.X)) - ref_X = interpolate_to_data(getx(self.reference), ref_X, wavenumbers) + ref_X = interpolate_extend_to(self.reference, wavenumbers) ref_X = ref_X[0] wei_X = weighted_wavenumbers(self.weights, wavenumbers) @@ -298,6 +290,8 @@ def __init__(self, reference=None, weights=None, ncomp=False, n0=np.linspace(1.1 if reference is None: raise MissingReferenceException() self.reference = reference + if len(self.reference) > 1: + self.reference = average_table_x(self.reference) self.weights = weights self.ncomp = ncomp self.output_model = output_model @@ -315,10 +309,8 @@ def __init__(self, reference=None, weights=None, ncomp=False, n0=np.linspace(1.1 self.gamma = self.h * np.log(10) / (4 * np.pi * 0.5 * np.pi * (self.n0 - 1) * self.a * 1e-6) if not self.ncomp: - ref_X = np.atleast_2d(spectra_mean(self.reference.X)) wavenumbers_ref = np.array(sorted(getx(self.reference))) - ref_X = interpolate_to_data(getx(self.reference), ref_X, wavenumbers_ref) - ref_X = ref_X[0] + ref_X = interpolate_extend_to(self.reference, wavenumbers_ref)[0] self.ncomp = cal_ncomp(ref_X, wavenumbers_ref, explainedVariance, self.alpha0, self.gamma) else: self.explainedVariance = False diff --git a/orangecontrib/spectroscopy/preprocess/utils.py b/orangecontrib/spectroscopy/preprocess/utils.py index 000c55984..f4f832c85 100644 --- a/orangecontrib/spectroscopy/preprocess/utils.py +++ b/orangecontrib/spectroscopy/preprocess/utils.py @@ -97,15 +97,7 @@ def __init__(self, reference: Table, domain: Domain): self.reference = reference def interpolate_extend_to(self, interpolate: Table, wavenumbers): - """ - Interpolate data to given wavenumbers and extend the possibly - nan-edges with the nearest values. 
- """ - # interpolate reference to the given wavenumbers - X = interp1d_with_unknowns_numpy(getx(interpolate), interpolate.X, wavenumbers) - # we know that X is not NaN. same handling of reference as of X - X, _ = nan_extend_edges_and_interpolate(wavenumbers, X) - return X + return interpolate_extend_to(interpolate, wavenumbers) def __eq__(self, other): return super().__eq__(other) \ @@ -329,3 +321,15 @@ def replacex(data: Table, replacement: list): natts = [at.renamed(str(n)) for n, at in zip(replacement, data.domain.attributes)] ndom = Domain(natts, data.domain.class_vars, data.domain.metas) return data.transform(ndom) + + +def interpolate_extend_to(interpolate: Table, wavenumbers): + """ + Interpolate data to given wavenumbers and extend the possibly + nan-edges with the nearest values. + """ + # interpolate reference to the given wavenumbers + X = interp1d_with_unknowns_numpy(getx(interpolate), interpolate.X, wavenumbers) + # we know that X is not NaN. same handling of reference as of X + X, _ = nan_extend_edges_and_interpolate(wavenumbers, X) + return X From 4dc6019d64391dc8f4cb9c16a65475e54e5090d8 Mon Sep 17 00:00:00 2001 From: Marko Toplak Date: Thu, 16 Nov 2023 11:55:45 +0100 Subject: [PATCH 2/5] emsc: __eq__ and __hash__ --- orangecontrib/spectroscopy/preprocess/emsc.py | 27 ++++++++-- .../spectroscopy/preprocess/utils.py | 7 +-- orangecontrib/spectroscopy/tests/test_emsc.py | 49 +++++++++++++++++++ .../tests/test_preprocess_utils.py | 14 +++--- 4 files changed, 82 insertions(+), 15 deletions(-) diff --git a/orangecontrib/spectroscopy/preprocess/emsc.py b/orangecontrib/spectroscopy/preprocess/emsc.py index 1ee701dcf..eb884f6eb 100644 --- a/orangecontrib/spectroscopy/preprocess/emsc.py +++ b/orangecontrib/spectroscopy/preprocess/emsc.py @@ -1,12 +1,14 @@ import numpy as np import Orange +from Orange.data import Table from Orange.preprocess.preprocess import Preprocess from Orange.data.util import get_unique_names from orangecontrib.spectroscopy.data import 
getx from orangecontrib.spectroscopy.preprocess.utils import SelectColumn, CommonDomainOrderUnknowns, \ - interp1d_with_unknowns_numpy, MissingReferenceException, interpolate_extend_to + interp1d_with_unknowns_numpy, MissingReferenceException, interpolate_extend_to, \ + CommonDomainRef, table_eq_x from orangecontrib.spectroscopy.preprocess.npfunc import Function, Segments @@ -61,12 +63,12 @@ class EMSCModel(SelectColumn): InheritEq = True -class _EMSC(CommonDomainOrderUnknowns): +class _EMSC(CommonDomainOrderUnknowns, CommonDomainRef): def __init__(self, reference, badspectra, weights, order, scaling, domain): - super().__init__(domain) - assert len(reference) == 1 - self.reference = reference + CommonDomainOrderUnknowns.__init__(self, domain) + CommonDomainRef.__init__(self, reference, domain) + assert len(self.reference) == 1 self.badspectra = badspectra self.weights = weights self.order = order @@ -113,6 +115,21 @@ def transformed(self, X, wavenumbers): return newspectra + def __eq__(self, other): + return CommonDomainRef.__eq__(self, other) \ + and table_eq_x(self.badspectra, other.badspectra) \ + and self.order == other.order \ + and self.scaling == other.scaling \ + and (self.weights == other.weights + if not isinstance(self.weights, Table) + else table_eq_x(self.weights, other.weights)) + + def __hash__(self): + domain = self.badspectra.domain if self.badspectra is not None else None + fv = tuple(self.badspectra.X[0][:10]) if self.badspectra is not None else None + weights = self.weights if not isinstance(self.weights, Table) else tuple(self.weights.X[0][:10]) + return hash((CommonDomainRef.__hash__(self), domain, fv, weights, self.order, self.scaling)) + def average_table_x(data): return Orange.data.Table.from_numpy(Orange.data.Domain(data.domain.attributes), diff --git a/orangecontrib/spectroscopy/preprocess/utils.py b/orangecontrib/spectroscopy/preprocess/utils.py index f4f832c85..54befe265 100644 --- a/orangecontrib/spectroscopy/preprocess/utils.py +++ 
b/orangecontrib/spectroscopy/preprocess/utils.py @@ -101,11 +101,12 @@ def interpolate_extend_to(self, interpolate: Table, wavenumbers): def __eq__(self, other): return super().__eq__(other) \ - and reference_eq_X(self.reference, other.reference) + and table_eq_x(self.reference, other.reference) def __hash__(self): domain = self.reference.domain if self.reference is not None else None - return hash((super().__hash__(), domain)) + fv = tuple(self.reference.X[0][:10]) if self.reference is not None else None + return hash((super().__hash__(), domain, fv)) class CommonDomainOrder(CommonDomain): @@ -185,7 +186,7 @@ def __hash__(self): return super().__hash__() -def reference_eq_X(first: Optional[Table], second: Optional[Table]): +def table_eq_x(first: Optional[Table], second: Optional[Table]): if first is second: return True elif first is None or second is None: diff --git a/orangecontrib/spectroscopy/tests/test_emsc.py b/orangecontrib/spectroscopy/tests/test_emsc.py index c588b8514..35b8fae02 100644 --- a/orangecontrib/spectroscopy/tests/test_emsc.py +++ b/orangecontrib/spectroscopy/tests/test_emsc.py @@ -138,6 +138,55 @@ def test_multiple_badspectra(self): fdata.metas, [[1.375, 1.375, 3.0, 6.0, 2.0]]) + def test_eq(self): + data = Table.from_numpy(None, [[0, 0.25, 4.5, 4.75, 1.0, 1.25, + 7.5, 7.75, 2.0, 5.25, 5.5, 2.75]]) + data_ref = Table.from_numpy(None, [[0, 0, 2, 2, 0, 0, 3, 3, 0, 0, 0, 0]]) + badspec = Table.from_numpy(None, [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0]]) + + d1 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True)(data) + d2 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True)(data) + self.assertEqual(d1.domain, d2.domain) + self.assertEqual(hash(d1.domain), hash(d2.domain)) + + d2 = EMSC(reference=Table.from_numpy(None, [[1, 0, 2, 2, 0, 0, 3, 3, 0, 0, 0, 0]]), + badspectra=badspec, order=1, output_model=True)(data) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), 
hash(d2.domain)) + + d2 = EMSC(reference=data_ref[0:1], + badspectra=Table.from_numpy(None, [[1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0]]), + order=1, output_model=True)(data) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + d2 = EMSC(reference=data, badspectra=badspec, order=1, output_model=True)(data) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + weight_table = spectra_table([-5e-324, 0.0, 3.0, 3.0000000000000004], + [[0, 1, 1, 0]]) + d2 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True, + weights=weight_table)(data) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + d3 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True, + weights=weight_table)(data) + self.assertEqual(d3.domain, d2.domain) + self.assertEqual(hash(d3.domain), hash(d2.domain)) + + d2 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True, + weights=SelectionFunction(0, 3, 1))(data) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + # these two are not the same because SelectionFunction does not define __eq__ and __hash__ + d3 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True, + weights=SelectionFunction(0, 3, 1))(data) + self.assertNotEqual(d3.domain, d2.domain) + self.assertNotEqual(hash(d3.domain), hash(d2.domain)) + class TestSelectionFuctions(unittest.TestCase): diff --git a/orangecontrib/spectroscopy/tests/test_preprocess_utils.py b/orangecontrib/spectroscopy/tests/test_preprocess_utils.py index 884eeb1bf..38f5f055a 100644 --- a/orangecontrib/spectroscopy/tests/test_preprocess_utils.py +++ b/orangecontrib/spectroscopy/tests/test_preprocess_utils.py @@ -2,7 +2,7 @@ from Orange.data import Table -from orangecontrib.spectroscopy.preprocess.utils import reference_eq_X +from 
orangecontrib.spectroscopy.preprocess.utils import table_eq_x class TestEq(TestCase): @@ -17,11 +17,11 @@ def setUpClass(cls): def test_reference_eq_X_none(self): data = self.iris - self.assertTrue(reference_eq_X(None, None)) - self.assertFalse(reference_eq_X(data, None)) - self.assertFalse(reference_eq_X(None, data)) + self.assertTrue(table_eq_x(None, None)) + self.assertFalse(table_eq_x(data, None)) + self.assertFalse(table_eq_x(None, data)) def test_reference_eq_X_same(self): - self.assertTrue(reference_eq_X(self.iris, self.iris)) - self.assertTrue(reference_eq_X(self.iris, self.iris2)) - self.assertFalse(reference_eq_X(self.iris, self.iris_changed)) + self.assertTrue(table_eq_x(self.iris, self.iris)) + self.assertTrue(table_eq_x(self.iris, self.iris2)) + self.assertFalse(table_eq_x(self.iris, self.iris_changed)) From baf315631679e08417828fed542648bb3cfc4298 Mon Sep 17 00:00:00 2001 From: Marko Toplak Date: Thu, 16 Nov 2023 12:01:52 +0100 Subject: [PATCH 3/5] lint --- orangecontrib/spectroscopy/preprocess/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/orangecontrib/spectroscopy/preprocess/utils.py b/orangecontrib/spectroscopy/preprocess/utils.py index 54befe265..4f302c92a 100644 --- a/orangecontrib/spectroscopy/preprocess/utils.py +++ b/orangecontrib/spectroscopy/preprocess/utils.py @@ -80,7 +80,7 @@ def transform_domain(self, data): return data def transformed(self, data): - raise NotImplemented + raise NotImplementedError def __eq__(self, other): return type(self) is type(other) \ @@ -131,7 +131,7 @@ def _restore_order(self, X, mon, xsind, xc): return np.hstack((restored, X[:, xc:])) def transformed(self, X, wavenumbers): - raise NotImplemented + raise NotImplementedError def __eq__(self, other): # pylint: disable=useless-parent-delegation From ce9571a1268fccf520e14458c80226e00df04c60 Mon Sep 17 00:00:00 2001 From: Marko Toplak Date: Thu, 16 Nov 2023 13:17:57 +0100 Subject: [PATCH 4/5] SelectionFunction and 
SmoothedSelectionFunction: __eq__ and __hash__ --- orangecontrib/spectroscopy/preprocess/emsc.py | 51 +++++++++++++++---- .../spectroscopy/preprocess/npfunc.py | 35 ++++++++++++- orangecontrib/spectroscopy/tests/test_emsc.py | 6 ++- 3 files changed, 78 insertions(+), 14 deletions(-) diff --git a/orangecontrib/spectroscopy/preprocess/emsc.py b/orangecontrib/spectroscopy/preprocess/emsc.py index eb884f6eb..cd859de8f 100644 --- a/orangecontrib/spectroscopy/preprocess/emsc.py +++ b/orangecontrib/spectroscopy/preprocess/emsc.py @@ -12,28 +12,57 @@ from orangecontrib.spectroscopy.preprocess.npfunc import Function, Segments -class SelectionFunction(Segments): +class SelectionFunction(Function): """ Weighted selection function. Includes min and max. """ def __init__(self, min_, max_, w): - super().__init__((lambda x: True, - lambda x: 0), - (lambda x: np.logical_and(x >= min_, x <= max_), - lambda x: w)) + super().__init__(None) + self.min_ = min_ + self.max_ = max_ + self.w = w + + def __call__(self, x): + seg = Segments((lambda x: True, lambda x: 0), + (lambda x: np.logical_and(x >= self.min_, x <= self.max_), + lambda x: self.w) + ) + return seg(x) + + def __eq__(self, other): + return super().__eq__(other) \ + and self.min_ == other.min_ \ + and self.max_ == other.max_ \ + and self.w == other.w + + def __hash__(self): + return hash((super().__hash__(), self.min_, self.max_, self.w)) -class SmoothedSelectionFunction(Segments): +class SmoothedSelectionFunction(SelectionFunction): """ Weighted selection function. Min and max points are middle points of smoothing with hyperbolic tangent. 
""" def __init__(self, min_, max_, s, w): - middle = (min_ + max_) / 2 - super().__init__((lambda x: x < middle, - lambda x: (np.tanh((x - min_) / s) + 1) / 2 * w), - (lambda x: x >= middle, - lambda x: (-np.tanh((x - max_) / s) + 1) / 2 * w)) + super().__init__(min_, max_, w) + self.s = s + + def __call__(self, x): + middle = (self.min_ + self.max_) / 2 + seg = Segments((lambda x: x < middle, + lambda x: (np.tanh((x - self.min_) / self.s) + 1) / 2 * self.w), + (lambda x: x >= middle, + lambda x: (-np.tanh((x - self.max_) / self.s) + 1) / 2 * self.w) + ) + return seg(x) + + def __eq__(self, other): + return super().__eq__(other) \ + and self.s == other.s + + def __hash__(self): + return hash((super().__hash__(), self.s)) def weighted_wavenumbers(weights, wavenumbers): diff --git a/orangecontrib/spectroscopy/preprocess/npfunc.py b/orangecontrib/spectroscopy/preprocess/npfunc.py index 062d2f36c..3bf29cbac 100644 --- a/orangecontrib/spectroscopy/preprocess/npfunc.py +++ b/orangecontrib/spectroscopy/preprocess/npfunc.py @@ -1,7 +1,7 @@ import numpy as np -class Function(): +class Function: def __init__(self, fn): self.fn = fn @@ -9,21 +9,36 @@ def __init__(self, fn): def __call__(self, x): return self.fn(x) + def __eq__(self, other): + return type(self) is type(other) \ + and self.fn == other.fn + + def __hash__(self): + return hash((type(self), self.fn)) + class Constant(Function): def __init__(self, c): + super().__init__(None) self.c = c def __call__(self, x): x = np.asarray(x) return np.ones(x.shape)*self.c + def __eq__(self, other): + return super().__eq__(other) \ + and self.c == other.c + + def __hash__(self): + return hash((super().__hash__(), self.c)) + class Identity(Function): def __init__(self): - pass + super().__init__(None) def __call__(self, x): return x @@ -38,6 +53,7 @@ class Segments(Function): """ def __init__(self, *segments): + super().__init__(None) self.segments = segments def __call__(self, x): @@ -48,10 +64,18 @@ def __call__(self, x): 
output[ind] = fn(x[ind]) return output + def __eq__(self, other): + return super().__eq__(other) \ + and self.segments == other.segments + + def __hash__(self): + return hash((super().__hash__(), self.segments)) + class Sum(Function): def __init__(self, *elements): + super().__init__(None) self.elements = elements def __call__(self, x): @@ -63,3 +87,10 @@ def __call__(self, x): else: acc = acc + current return acc + + def __eq__(self, other): + return super().__eq__(other) \ + and self.elements == other.elements + + def __hash__(self): + return hash((super().__hash__(), self.elements)) diff --git a/orangecontrib/spectroscopy/tests/test_emsc.py b/orangecontrib/spectroscopy/tests/test_emsc.py index 35b8fae02..e349a07c8 100644 --- a/orangecontrib/spectroscopy/tests/test_emsc.py +++ b/orangecontrib/spectroscopy/tests/test_emsc.py @@ -181,9 +181,13 @@ def test_eq(self): self.assertNotEqual(d1.domain, d2.domain) self.assertNotEqual(hash(d1.domain), hash(d2.domain)) - # these two are not the same because SelectionFunction does not define __eq__ and __hash__ d3 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True, weights=SelectionFunction(0, 3, 1))(data) + self.assertEqual(d3.domain, d2.domain) + self.assertEqual(hash(d3.domain), hash(d2.domain)) + + d3 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True, + weights=SmoothedSelectionFunction(0, 3, 1, 0.5))(data) self.assertNotEqual(d3.domain, d2.domain) self.assertNotEqual(hash(d3.domain), hash(d2.domain)) From 0636def776f558dec57c88aa0a8a5a196cc1be40 Mon Sep 17 00:00:00 2001 From: Marko Toplak Date: Thu, 16 Nov 2023 14:06:39 +0100 Subject: [PATCH 5/5] meemsc: __eq__ and __hash__ --- .../spectroscopy/preprocess/me_emsc.py | 29 +++++++++++++--- .../spectroscopy/tests/test_me_emsc.py | 33 +++++++++++++++++++ 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/orangecontrib/spectroscopy/preprocess/me_emsc.py b/orangecontrib/spectroscopy/preprocess/me_emsc.py 
index 68a26d3cc..88b92997b 100644 --- a/orangecontrib/spectroscopy/preprocess/me_emsc.py +++ b/orangecontrib/spectroscopy/preprocess/me_emsc.py @@ -3,12 +3,13 @@ from sklearn.decomposition import TruncatedSVD import Orange +from Orange.data import Table from Orange.preprocess.preprocess import Preprocess from Orange.data.util import get_unique_names from orangecontrib.spectroscopy.data import getx from orangecontrib.spectroscopy.preprocess.utils import SelectColumn, CommonDomainOrderUnknowns, \ - interpolate_extend_to + interpolate_extend_to, CommonDomainRef, table_eq_x from orangecontrib.spectroscopy.preprocess.emsc import weighted_wavenumbers, average_table_x @@ -87,13 +88,13 @@ class ME_EMSCModel(SelectColumn): InheritEq = True -class _ME_EMSC(CommonDomainOrderUnknowns): +class _ME_EMSC(CommonDomainOrderUnknowns, CommonDomainRef): def __init__(self, reference, weights, ncomp, alpha0, gamma, maxNiter, fixedNiter, positiveRef, domain): - super().__init__(domain) + CommonDomainOrderUnknowns.__init__(self, domain) + CommonDomainRef.__init__(self, reference, domain) assert len(reference) == 1 - self.reference = reference - self.weights = weights # !!! 
THIS SHOULD BE A NP ARRAY (or similar) with inflection points + self.weights = weights self.ncomp = ncomp self.alpha0 = alpha0 self.gamma = gamma @@ -276,6 +277,24 @@ def iterate(spectra, correctedFirsIteration, residualsFirstIteration, wavenumber newspectra = np.hstack((newspectra, numberOfIterations.reshape(-1, 1),RMSEall.reshape(-1, 1))) return newspectra + def __eq__(self, other): + return CommonDomainRef.__eq__(self, other) \ + and self.ncomp == other.ncomp \ + and np.array_equal(self.alpha0, other.alpha0) \ + and np.array_equal(self.gamma, other.gamma) \ + and self.maxNiter == other.maxNiter \ + and self.fixedNiter == other.fixedNiter \ + and self.positiveRef == other.positiveRef \ + and (self.weights == other.weights + if not isinstance(self.weights, Table) + else table_eq_x(self.weights, other.weights)) + + def __hash__(self): + weights = self.weights \ + if not isinstance(self.weights, Table) else tuple(self.weights.X[0][:10]) + return hash((CommonDomainRef.__hash__(self), weights, self.ncomp, tuple(self.alpha0), + tuple(self.gamma), self.maxNiter, self.fixedNiter, self.positiveRef)) + class MissingReferenceException(Exception): pass diff --git a/orangecontrib/spectroscopy/tests/test_me_emsc.py b/orangecontrib/spectroscopy/tests/test_me_emsc.py index 8722d229d..b651d5a3e 100644 --- a/orangecontrib/spectroscopy/tests/test_me_emsc.py +++ b/orangecontrib/spectroscopy/tests/test_me_emsc.py @@ -208,6 +208,39 @@ def test_short_reference(self): # it was crashing before ME_EMSC(reference=reference)(self.spectra) + def test_eq(self): + ref = self.reference + spectra = self.spectra + d1 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=1)(spectra) + d2 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=1)(spectra) + self.assertEqual(d1.domain, d2.domain) + self.assertEqual(hash(d1.domain), hash(d2.domain)) + + d2 = ME_EMSC(reference=ref, ncomp=2, weights=False, max_iter=1)(spectra) + self.assertNotEqual(d1.domain, d2.domain) + 
self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + d2 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=1, + n0=np.linspace(1.1, 1.4, 11), + a=np.linspace(2, 7.1, 11))(spectra) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + d2 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=2)(spectra) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + r2 = ref.copy() + d2 = ME_EMSC(reference=r2, ncomp=False, weights=False, max_iter=1)(spectra) + self.assertEqual(d1.domain, d2.domain) + self.assertEqual(hash(d1.domain), hash(d2.domain)) + + with r2.unlocked(): + r2[0][0] = 1 + d2 = ME_EMSC(reference=r2, ncomp=False, weights=False, max_iter=1)(spectra) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + class TestInflectionPointWeighting(unittest.TestCase):