diff --git a/orangecontrib/spectroscopy/preprocess/emsc.py b/orangecontrib/spectroscopy/preprocess/emsc.py index 6c7ffefb6..cd859de8f 100644 --- a/orangecontrib/spectroscopy/preprocess/emsc.py +++ b/orangecontrib/spectroscopy/preprocess/emsc.py @@ -1,37 +1,68 @@ import numpy as np import Orange +from Orange.data import Table from Orange.preprocess.preprocess import Preprocess from Orange.data.util import get_unique_names -from orangecontrib.spectroscopy.data import getx, spectra_mean +from orangecontrib.spectroscopy.data import getx from orangecontrib.spectroscopy.preprocess.utils import SelectColumn, CommonDomainOrderUnknowns, \ - interp1d_with_unknowns_numpy, nan_extend_edges_and_interpolate, MissingReferenceException + interp1d_with_unknowns_numpy, MissingReferenceException, interpolate_extend_to, \ + CommonDomainRef, table_eq_x from orangecontrib.spectroscopy.preprocess.npfunc import Function, Segments -class SelectionFunction(Segments): +class SelectionFunction(Function): """ Weighted selection function. Includes min and max. """ def __init__(self, min_, max_, w): - super().__init__((lambda x: True, - lambda x: 0), - (lambda x: np.logical_and(x >= min_, x <= max_), - lambda x: w)) + super().__init__(None) + self.min_ = min_ + self.max_ = max_ + self.w = w + def __call__(self, x): + seg = Segments((lambda x: True, lambda x: 0), + (lambda x: np.logical_and(x >= self.min_, x <= self.max_), + lambda x: self.w) + ) + return seg(x) -class SmoothedSelectionFunction(Segments): + def __eq__(self, other): + return super().__eq__(other) \ + and self.min_ == other.min_ \ + and self.max_ == other.max_ \ + and self.w == other.w + + def __hash__(self): + return hash((super().__hash__(), self.min_, self.max_, self.w)) + + +class SmoothedSelectionFunction(SelectionFunction): """ Weighted selection function. Min and max points are middle points of smoothing with hyperbolic tangent. """ def __init__(self, min_, max_, s, w): - middle = (min_ + max_) / 2 - super().__init__((lambda x: x < middle, - lambda x: (np.tanh((x - min_) / s) + 1) / 2 * w), - (lambda x: x >= middle, - lambda x: (-np.tanh((x - max_) / s) + 1) / 2 * w)) + super().__init__(min_, max_, w) + self.s = s + + def __call__(self, x): + middle = (self.min_ + self.max_) / 2 + seg = Segments((lambda x: x < middle, + lambda x: (np.tanh((x - self.min_) / self.s) + 1) / 2 * self.w), + (lambda x: x >= middle, + lambda x: (-np.tanh((x - self.max_) / self.s) + 1) / 2 * self.w) + ) + return seg(x) + + def __eq__(self, other): + return super().__eq__(other) \ + and self.s == other.s + + def __hash__(self): + return hash((super().__hash__(), self.s)) def weighted_wavenumbers(weights, wavenumbers): @@ -61,11 +92,12 @@ class EMSCModel(SelectColumn): InheritEq = True -class _EMSC(CommonDomainOrderUnknowns): +class _EMSC(CommonDomainOrderUnknowns, CommonDomainRef): def __init__(self, reference, badspectra, weights, order, scaling, domain): - super().__init__(domain) - self.reference = reference + CommonDomainOrderUnknowns.__init__(self, domain) + CommonDomainRef.__init__(self, reference, domain) + assert len(self.reference) == 1 self.badspectra = badspectra self.weights = weights self.order = order @@ -74,17 +106,7 @@ def __init__(self, reference, badspectra, weights, order, scaling, domain): def transformed(self, X, wavenumbers): # wavenumber have to be input as sorted # about 85% of time in __call__ function is spent is lstsq - # compute average spectrum from the reference - ref_X = np.atleast_2d(spectra_mean(self.reference.X)) - - def interpolate_to_data(other_xs, other_data): - # all input data needs to be interpolated (and NaNs removed) - interpolated = interp1d_with_unknowns_numpy(other_xs, other_data, wavenumbers) - # we know that X is not NaN. same handling of reference as of X - interpolated, _ = nan_extend_edges_and_interpolate(wavenumbers, interpolated) - return interpolated - - ref_X = interpolate_to_data(getx(self.reference), ref_X) + ref_X = interpolate_extend_to(self.reference, wavenumbers) wei_X = weighted_wavenumbers(self.weights, wavenumbers) N = wavenumbers.shape[0] @@ -93,7 +115,7 @@ def interpolate_to_data(other_xs, other_data): n_badspec = len(self.badspectra) if self.badspectra is not None else 0 if self.badspectra: - badspectra_X = interpolate_to_data(getx(self.badspectra), self.badspectra.X) + badspectra_X = interpolate_extend_to(self.badspectra, wavenumbers) M = [] for x in range(0, self.order+1): @@ -122,6 +144,26 @@ def interpolate_to_data(other_xs, other_data): return newspectra + def __eq__(self, other): + return CommonDomainRef.__eq__(self, other) \ + and table_eq_x(self.badspectra, other.badspectra) \ + and self.order == other.order \ + and self.scaling == other.scaling \ + and (self.weights == other.weights + if not isinstance(self.weights, Table) + else table_eq_x(self.weights, other.weights)) + + def __hash__(self): + domain = self.badspectra.domain if self.badspectra is not None else None + fv = tuple(self.badspectra.X[0][:10]) if self.badspectra is not None else None + weights = self.weights if not isinstance(self.weights, Table) else tuple(self.weights.X[0][:10]) + return hash((CommonDomainRef.__hash__(self), domain, fv, weights, self.order, self.scaling)) + + +def average_table_x(data): + return Orange.data.Table.from_numpy(Orange.data.Domain(data.domain.attributes), + X=data.X.mean(axis=0, keepdims=True)) + class EMSC(Preprocess): @@ -132,6 +174,8 @@ def __init__(self, reference=None, badspectra=None, weights=None, order=2, scali if reference is None: raise MissingReferenceException() self.reference = reference + if len(self.reference) > 1: + self.reference = average_table_x(self.reference) self.badspectra = badspectra self.weights = weights self.order = order diff --git a/orangecontrib/spectroscopy/preprocess/me_emsc.py b/orangecontrib/spectroscopy/preprocess/me_emsc.py index eb8eaa244..88b92997b 100644 --- a/orangecontrib/spectroscopy/preprocess/me_emsc.py +++ b/orangecontrib/spectroscopy/preprocess/me_emsc.py @@ -3,21 +3,14 @@ from sklearn.decomposition import TruncatedSVD import Orange +from Orange.data import Table from Orange.preprocess.preprocess import Preprocess from Orange.data.util import get_unique_names -from orangecontrib.spectroscopy.data import getx, spectra_mean +from orangecontrib.spectroscopy.data import getx from orangecontrib.spectroscopy.preprocess.utils import SelectColumn, CommonDomainOrderUnknowns, \ - interp1d_with_unknowns_numpy, nan_extend_edges_and_interpolate -from orangecontrib.spectroscopy.preprocess.emsc import weighted_wavenumbers - - -def interpolate_to_data(other_xs, other_data, wavenumbers): - # all input data needs to be interpolated (and NaNs removed) - interpolated = interp1d_with_unknowns_numpy(other_xs, other_data, wavenumbers) - # we know that X is not NaN. same handling of reference as of X - interpolated, _ = nan_extend_edges_and_interpolate(wavenumbers, interpolated) - return interpolated + interpolate_extend_to, CommonDomainRef, table_eq_x +from orangecontrib.spectroscopy.preprocess.emsc import weighted_wavenumbers, average_table_x def calculate_complex_n(ref_X,wavenumbers): @@ -95,12 +88,13 @@ class ME_EMSCModel(SelectColumn): InheritEq = True -class _ME_EMSC(CommonDomainOrderUnknowns): +class _ME_EMSC(CommonDomainOrderUnknowns, CommonDomainRef): def __init__(self, reference, weights, ncomp, alpha0, gamma, maxNiter, fixedNiter, positiveRef, domain): - super().__init__(domain) - self.reference = reference - self.weights = weights # !!! THIS SHOULD BE A NP ARRAY (or similar) with inflection points + CommonDomainOrderUnknowns.__init__(self, domain) + CommonDomainRef.__init__(self, reference, domain) + assert len(reference) == 1 + self.weights = weights self.ncomp = ncomp self.alpha0 = alpha0 self.gamma = gamma @@ -234,8 +228,7 @@ def iterate(spectra, correctedFirsIteration, residualsFirstIteration, wavenumber break return newspectra, RMSEall, numberOfIterations - ref_X = np.atleast_2d(spectra_mean(self.reference.X)) - ref_X = interpolate_to_data(getx(self.reference), ref_X, wavenumbers) + ref_X = interpolate_extend_to(self.reference, wavenumbers) ref_X = ref_X[0] wei_X = weighted_wavenumbers(self.weights, wavenumbers) @@ -284,6 +277,24 @@ def iterate(spectra, correctedFirsIteration, residualsFirstIteration, wavenumber newspectra = np.hstack((newspectra, numberOfIterations.reshape(-1, 1),RMSEall.reshape(-1, 1))) return newspectra + def __eq__(self, other): + return CommonDomainRef.__eq__(self, other) \ + and self.ncomp == other.ncomp \ + and np.array_equal(self.alpha0, other.alpha0) \ + and np.array_equal(self.gamma, other.gamma) \ + and self.maxNiter == other.maxNiter \ + and self.fixedNiter == other.fixedNiter \ + and self.positiveRef == other.positiveRef \ + and (self.weights == other.weights + if not isinstance(self.weights, Table) + else table_eq_x(self.weights, other.weights)) + + def __hash__(self): + weights = self.weights \ + if not isinstance(self.weights, Table) else tuple(self.weights.X[0][:10]) + return hash((CommonDomainRef.__hash__(self), weights, self.ncomp, tuple(self.alpha0), + tuple(self.gamma), self.maxNiter, self.fixedNiter, self.positiveRef)) + class MissingReferenceException(Exception): pass @@ -298,6 +309,8 @@ def __init__(self, reference=None, weights=None, ncomp=False, n0=np.linspace(1.1 if reference is None: raise MissingReferenceException() self.reference = reference + if len(self.reference) > 1: + self.reference = average_table_x(self.reference) self.weights = weights self.ncomp = ncomp self.output_model = output_model @@ -315,10 +328,8 @@ def __init__(self, reference=None, weights=None, ncomp=False, n0=np.linspace(1.1 self.gamma = self.h * np.log(10) / (4 * np.pi * 0.5 * np.pi * (self.n0 - 1) * self.a * 1e-6) if not self.ncomp: - ref_X = np.atleast_2d(spectra_mean(self.reference.X)) wavenumbers_ref = np.array(sorted(getx(self.reference))) - ref_X = interpolate_to_data(getx(self.reference), ref_X, wavenumbers_ref) - ref_X = ref_X[0] + ref_X = interpolate_extend_to(self.reference, wavenumbers_ref)[0] self.ncomp = cal_ncomp(ref_X, wavenumbers_ref, explainedVariance, self.alpha0, self.gamma) else: self.explainedVariance = False diff --git a/orangecontrib/spectroscopy/preprocess/npfunc.py b/orangecontrib/spectroscopy/preprocess/npfunc.py index 062d2f36c..3bf29cbac 100644 --- a/orangecontrib/spectroscopy/preprocess/npfunc.py +++ b/orangecontrib/spectroscopy/preprocess/npfunc.py @@ -1,7 +1,7 @@ import numpy as np -class Function(): +class Function: def __init__(self, fn): self.fn = fn @@ -9,21 +9,36 @@ def __init__(self, fn): def __call__(self, x): return self.fn(x) + def __eq__(self, other): + return type(self) is type(other) \ + and self.fn == other.fn + + def __hash__(self): + return hash((type(self), self.fn)) + class Constant(Function): def __init__(self, c): + super().__init__(None) self.c = c def __call__(self, x): x = np.asarray(x) return np.ones(x.shape)*self.c + def __eq__(self, other): + return super().__eq__(other) \ + and self.c == other.c + + def __hash__(self): + return hash((super().__hash__(), self.c)) + class Identity(Function): def __init__(self): - pass + super().__init__(None) def __call__(self, x): return x @@ -38,6 +53,7 @@ class Segments(Function): """ def __init__(self, *segments): + super().__init__(None) self.segments = segments def __call__(self, x): @@ -48,10 +64,18 @@ def __call__(self, x): output[ind] = fn(x[ind]) return output + def __eq__(self, other): + return super().__eq__(other) \ + and self.segments == other.segments + + def __hash__(self): + return hash((super().__hash__(), self.segments)) + class Sum(Function): def __init__(self, *elements): + super().__init__(None) self.elements = elements def __call__(self, x): @@ -63,3 +87,10 @@ def __call__(self, x): else: acc = acc + current return acc + + def __eq__(self, other): + return super().__eq__(other) \ + and self.segments == other.elements + + def __hash__(self): + return hash((super().__hash__(), self.elements)) diff --git a/orangecontrib/spectroscopy/preprocess/utils.py b/orangecontrib/spectroscopy/preprocess/utils.py index 000c55984..4f302c92a 100644 --- a/orangecontrib/spectroscopy/preprocess/utils.py +++ b/orangecontrib/spectroscopy/preprocess/utils.py @@ -80,7 +80,7 @@ def transform_domain(self, data): return data def transformed(self, data): - raise NotImplemented + raise NotImplementedError def __eq__(self, other): return type(self) is type(other) \ @@ -97,23 +97,16 @@ def __init__(self, reference: Table, domain: Domain): self.reference = reference def interpolate_extend_to(self, interpolate: Table, wavenumbers): - """ - Interpolate data to given wavenumbers and extend the possibly - nan-edges with the nearest values. - """ - # interpolate reference to the given wavenumbers - X = interp1d_with_unknowns_numpy(getx(interpolate), interpolate.X, wavenumbers) - # we know that X is not NaN. same handling of reference as of X - X, _ = nan_extend_edges_and_interpolate(wavenumbers, X) - return X + return interpolate_extend_to(interpolate, wavenumbers) def __eq__(self, other): return super().__eq__(other) \ - and reference_eq_X(self.reference, other.reference) + and table_eq_x(self.reference, other.reference) def __hash__(self): domain = self.reference.domain if self.reference is not None else None - return hash((super().__hash__(), domain)) + fv = tuple(self.reference.X[0][:10]) if self.reference is not None else None + return hash((super().__hash__(), domain, fv)) class CommonDomainOrder(CommonDomain): @@ -138,7 +131,7 @@ def _restore_order(self, X, mon, xsind, xc): return np.hstack((restored, X[:, xc:])) def transformed(self, X, wavenumbers): - raise NotImplemented + raise NotImplementedError def __eq__(self, other): # pylint: disable=useless-parent-delegation @@ -193,7 +186,7 @@ def __hash__(self): return super().__hash__() -def reference_eq_X(first: Optional[Table], second: Optional[Table]): +def table_eq_x(first: Optional[Table], second: Optional[Table]): if first is second: return True elif first is None or second is None: @@ -329,3 +322,15 @@ def replacex(data: Table, replacement: list): natts = [at.renamed(str(n)) for n, at in zip(replacement, data.domain.attributes)] ndom = Domain(natts, data.domain.class_vars, data.domain.metas) return data.transform(ndom) + + +def interpolate_extend_to(interpolate: Table, wavenumbers): + """ + Interpolate data to given wavenumbers and extend the possibly + nan-edges with the nearest values. + """ + # interpolate reference to the given wavenumbers + X = interp1d_with_unknowns_numpy(getx(interpolate), interpolate.X, wavenumbers) + # we know that X is not NaN. same handling of reference as of X + X, _ = nan_extend_edges_and_interpolate(wavenumbers, X) + return X diff --git a/orangecontrib/spectroscopy/tests/test_emsc.py b/orangecontrib/spectroscopy/tests/test_emsc.py index c588b8514..e349a07c8 100644 --- a/orangecontrib/spectroscopy/tests/test_emsc.py +++ b/orangecontrib/spectroscopy/tests/test_emsc.py @@ -138,6 +138,59 @@ def test_multiple_badspectra(self): fdata.metas, [[1.375, 1.375, 3.0, 6.0, 2.0]]) + def test_eq(self): + data = Table.from_numpy(None, [[0, 0.25, 4.5, 4.75, 1.0, 1.25, + 7.5, 7.75, 2.0, 5.25, 5.5, 2.75]]) + data_ref = Table.from_numpy(None, [[0, 0, 2, 2, 0, 0, 3, 3, 0, 0, 0, 0]]) + badspec = Table.from_numpy(None, [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0]]) + + d1 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True)(data) + d2 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True)(data) + self.assertEqual(d1.domain, d2.domain) + self.assertEqual(hash(d1.domain), hash(d2.domain)) + + d2 = EMSC(reference=Table.from_numpy(None, [[1, 0, 2, 2, 0, 0, 3, 3, 0, 0, 0, 0]]), + badspectra=badspec, order=1, output_model=True)(data) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + d2 = EMSC(reference=data_ref[0:1], + badspectra=Table.from_numpy(None, [[1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0]]), + order=1, output_model=True)(data) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + d2 = EMSC(reference=data, badspectra=badspec, order=1, output_model=True)(data) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + weight_table = spectra_table([-5e-324, 0.0, 3.0, 3.0000000000000004], + [[0, 1, 1, 0]]) + d2 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True, + weights=weight_table)(data) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + d3 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True, + weights=weight_table)(data) + self.assertEqual(d3.domain, d2.domain) + self.assertEqual(hash(d3.domain), hash(d2.domain)) + + d2 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True, + weights=SelectionFunction(0, 3, 1))(data) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + d3 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True, + weights=SelectionFunction(0, 3, 1))(data) + self.assertEqual(d3.domain, d2.domain) + self.assertEqual(hash(d3.domain), hash(d2.domain)) + + d3 = EMSC(reference=data_ref[0:1], badspectra=badspec, order=1, output_model=True, + weights=SmoothedSelectionFunction(0, 3, 1, 0.5))(data) + self.assertNotEqual(d3.domain, d2.domain) + self.assertNotEqual(hash(d3.domain), hash(d2.domain)) + class TestSelectionFuctions(unittest.TestCase): diff --git a/orangecontrib/spectroscopy/tests/test_me_emsc.py b/orangecontrib/spectroscopy/tests/test_me_emsc.py index 8722d229d..b651d5a3e 100644 --- a/orangecontrib/spectroscopy/tests/test_me_emsc.py +++ b/orangecontrib/spectroscopy/tests/test_me_emsc.py @@ -208,6 +208,39 @@ def test_short_reference(self): # it was crashing before ME_EMSC(reference=reference)(self.spectra) + def test_eq(self): + ref = self.reference + spectra = self.spectra + d1 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=1)(spectra) + d2 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=1)(spectra) + self.assertEqual(d1.domain, d2.domain) + self.assertEqual(hash(d1.domain), hash(d2.domain)) + + d2 = ME_EMSC(reference=ref, ncomp=2, weights=False, max_iter=1)(spectra) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + d2 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=1, + n0=np.linspace(1.1, 1.4, 11), + a=np.linspace(2, 7.1, 11))(spectra) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + d2 = ME_EMSC(reference=ref, ncomp=False, weights=False, max_iter=2)(spectra) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + + r2 = ref.copy() + d2 = ME_EMSC(reference=r2, ncomp=False, weights=False, max_iter=1)(spectra) + self.assertEqual(d1.domain, d2.domain) + self.assertEqual(hash(d1.domain), hash(d2.domain)) + + with r2.unlocked(): + r2[0][0] = 1 + d2 = ME_EMSC(reference=r2, ncomp=False, weights=False, max_iter=1)(spectra) + self.assertNotEqual(d1.domain, d2.domain) + self.assertNotEqual(hash(d1.domain), hash(d2.domain)) + class TestInflectionPointWeighting(unittest.TestCase): diff --git a/orangecontrib/spectroscopy/tests/test_preprocess_utils.py b/orangecontrib/spectroscopy/tests/test_preprocess_utils.py index 884eeb1bf..38f5f055a 100644 --- a/orangecontrib/spectroscopy/tests/test_preprocess_utils.py +++ b/orangecontrib/spectroscopy/tests/test_preprocess_utils.py @@ -2,7 +2,7 @@ from Orange.data import Table -from orangecontrib.spectroscopy.preprocess.utils import reference_eq_X +from orangecontrib.spectroscopy.preprocess.utils import table_eq_x class TestEq(TestCase): @@ -17,11 +17,11 @@ def setUpClass(cls): def test_reference_eq_X_none(self): data = self.iris - self.assertTrue(reference_eq_X(None, None)) - self.assertFalse(reference_eq_X(data, None)) - self.assertFalse(reference_eq_X(None, data)) + self.assertTrue(table_eq_x(None, None)) + self.assertFalse(table_eq_x(data, None)) + self.assertFalse(table_eq_x(None, data)) def test_reference_eq_X_same(self): - self.assertTrue(reference_eq_X(self.iris, self.iris)) - self.assertTrue(reference_eq_X(self.iris, self.iris2)) - self.assertFalse(reference_eq_X(self.iris, self.iris_changed)) + self.assertTrue(table_eq_x(self.iris, self.iris)) + self.assertTrue(table_eq_x(self.iris, self.iris2)) + self.assertFalse(table_eq_x(self.iris, self.iris_changed))