Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] PLS supports multiple targets #694

Merged
merged 2 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions orangecontrib/spectroscopy/models/pls.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import numpy as np
import pkg_resources
import sklearn
import sklearn.cross_decomposition as skl_pls

from Orange.data import Table, Domain, Variable, \
Expand Down Expand Up @@ -59,11 +61,17 @@

@property
def coefficients(self):
return self.skl_model.coef_
coef = self.skl_model.coef_
# 1.3 has transposed coef_
if pkg_resources.parse_version(sklearn.__version__) < pkg_resources.parse_version("1.3.0"):
coef = coef.T

Check warning on line 67 in orangecontrib/spectroscopy/models/pls.py

View check run for this annotation

Codecov / codecov/patch

orangecontrib/spectroscopy/models/pls.py#L67

Added line #L67 was not covered by tests
return coef

def predict(self, X):
vals = self.skl_model.predict(X)
return vals.ravel()
if len(self.domain.class_vars) == 1:
vals = vals.ravel()
return vals

def __str__(self):
return 'PLSModel {}'.format(self.skl_model)
Expand All @@ -74,7 +82,7 @@

def project(self, data):
if not isinstance(data, Table):
raise RuntimeError("PLSModel can only project tables")

Check warning on line 85 in orangecontrib/spectroscopy/models/pls.py

View check run for this annotation

Codecov / codecov/patch

orangecontrib/spectroscopy/models/pls.py#L85

Added line #L85 was not covered by tests

transformer = _PLSCommonTransform(self)

Expand Down Expand Up @@ -117,16 +125,24 @@
components.name = 'components'
return components

def coefficients_table(self):
coeffs = self.coefficients.T
domain = Domain(
[ContinuousVariable(f"coef {i}") for i in range(coeffs.shape[1])],
metas=[StringVariable("name")]
)
waves = [[attr.name] for attr in self.domain.attributes]
coef_table = Table.from_numpy(domain, X=coeffs, metas=waves)
coef_table.name = "coefficients"
return coef_table


class PLSRegressionLearner(SklLearner, _FeatureScorerMixin):
__wraps__ = skl_pls.PLSRegression
__returns__ = PLSModel

supports_multiclass = True
preprocessors = SklLearner.preprocessors

# this learner enforces a single class because multitarget is not
# explicitly allowed

def fit(self, X, Y, W=None):
params = self.params.copy()
params["n_components"] = min(X.shape[1] - 1,
Expand All @@ -140,6 +156,15 @@
super().__init__(preprocessors=preprocessors)
self.params = vars()

def incompatibility_reason(self, domain):
reason = None
if not domain.class_vars:
reason = "Numeric targets expected."
else:
for cv in domain.class_vars:
if not cv.is_continuous:
reason = "Only numeric target variables expected."
return reason

if __name__ == '__main__':
import Orange
Expand Down
69 changes: 47 additions & 22 deletions orangecontrib/spectroscopy/tests/test_owpls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import numpy as np

import pkg_resources
import sklearn
from sklearn.cross_decomposition import PLSRegression

from Orange.data import Table, Domain, ContinuousVariable
Expand All @@ -21,28 +23,42 @@ def table(rows, attr, vars):
return Table.from_numpy(domain, X=X, Y=Y)


def coefficients(sklmodel):
coef = sklmodel.coef_
# 1.3 has transposed coef_
if pkg_resources.parse_version(sklearn.__version__) < pkg_resources.parse_version("1.3.0"):
coef = coef.T
return coef


class TestPLS(TestCase):

def test_allow_y_dim(self):
""" The current PLS version allows only a single Y dimension. """
d = table(10, 5, 1)
learner = PLSRegressionLearner(n_components=2)
learner(d)
for n_class_vars in [0, 2]:
d = table(10, 5, 0)
with self.assertRaises(ValueError):
learner(d)
for n_class_vars in [1, 2, 3]:
d = table(10, 5, n_class_vars)
with self.assertRaises(ValueError):
learner(d)
learner(d) # no exception

def test_compare_to_sklearn(self):
d = table(10, 5, 1)
with d.unlocked():
d.X = np.random.RandomState(0).rand(*d.X.shape)
d.Y = np.random.RandomState(0).rand(*d.Y.shape)
orange_model = PLSRegressionLearner()(d)
scikit_model = PLSRegression().fit(d.X, d.Y)
np.testing.assert_almost_equal(scikit_model.predict(d.X).ravel(),
orange_model(d))
np.testing.assert_almost_equal(scikit_model.coef_,
np.testing.assert_almost_equal(coefficients(scikit_model),
orange_model.coefficients)

def test_compare_to_sklearn_multid(self):
d = table(10, 5, 3)
orange_model = PLSRegressionLearner()(d)
scikit_model = PLSRegression().fit(d.X, d.Y)
np.testing.assert_almost_equal(scikit_model.predict(d.X),
orange_model(d))
np.testing.assert_almost_equal(coefficients(scikit_model),
orange_model.coefficients)

def test_too_many_components(self):
Expand All @@ -60,21 +76,30 @@ def test_too_many_components(self):
self.assertEqual(model.skl_model.n_components, 4)

def test_scores(self):
d = table(10, 5, 1)
orange_model = PLSRegressionLearner()(d)
scikit_model = PLSRegression().fit(d.X, d.Y)
scores = orange_model.project(d)
sx, sy = scikit_model.transform(d.X, d.Y)
np.testing.assert_almost_equal(sx, scores.X)
np.testing.assert_almost_equal(sy, scores.metas)
for d in [table(10, 5, 1), table(10, 5, 3)]:
orange_model = PLSRegressionLearner()(d)
scikit_model = PLSRegression().fit(d.X, d.Y)
scores = orange_model.project(d)
sx, sy = scikit_model.transform(d.X, d.Y)
np.testing.assert_almost_equal(sx, scores.X)
np.testing.assert_almost_equal(sy, scores.metas)

def test_components(self):
d = table(10, 5, 1)
orange_model = PLSRegressionLearner()(d)
scikit_model = PLSRegression().fit(d.X, d.Y)
components = orange_model.components()
np.testing.assert_almost_equal(scikit_model.x_loadings_, components.X.T)
np.testing.assert_almost_equal(scikit_model.y_loadings_, components.Y.reshape(1, -1))
def t2d(m):
return m.reshape(-1, 1) if len(m.shape) == 1 else m
for d in [table(10, 5, 1), table(10, 5, 3)]:
orange_model = PLSRegressionLearner()(d)
scikit_model = PLSRegression().fit(d.X, d.Y)
components = orange_model.components()
np.testing.assert_almost_equal(scikit_model.x_loadings_, components.X.T)
np.testing.assert_almost_equal(scikit_model.y_loadings_, t2d(components.Y).T)

def test_coefficients(self):
for d in [table(10, 5, 1), table(10, 5, 3)]:
orange_model = PLSRegressionLearner()(d)
scikit_model = PLSRegression().fit(d.X, d.Y)
coef_table = orange_model.coefficients_table()
np.testing.assert_almost_equal(coefficients(scikit_model).T, coef_table.X)


class TestOWPLS(WidgetTest, WidgetLearnerTestMixin):
Expand Down
25 changes: 17 additions & 8 deletions orangecontrib/spectroscopy/widgets/owpls.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,7 @@
projection = None
components = None
if self.model is not None:
domain = Domain(
[ContinuousVariable("coef")], metas=[StringVariable("name")])
coefs = self.model.coefficients
coefs = coefs.reshape(-1, 1)
waves = [[attr.name] for attr in self.model.domain.attributes]
coef_table = Table.from_numpy(domain, X=coefs, metas=waves)
coef_table.name = "coefficients"
coef_table = self.model.coefficients_table()
projection = self.model.project(self.data)
components = self.model.components()
self.Outputs.coefsdata.send(coef_table)
Expand All @@ -68,8 +62,23 @@

@OWBaseLearner.Inputs.data
def set_data(self, data):
# reimplemented completely because the base learner does not
# allow multiclass

self.Warning.sparse_data.clear()
super().set_data(data)

self.Error.data_error.clear()
self.data = data

if data is not None and data.domain.class_var is None and not data.domain.class_vars:
self.Error.data_error(

Check warning on line 74 in orangecontrib/spectroscopy/widgets/owpls.py

View check run for this annotation

Codecov / codecov/patch

orangecontrib/spectroscopy/widgets/owpls.py#L74

Added line #L74 was not covered by tests
"Data has no target variable.\n"
"Select one with the Select Columns widget.")
self.data = None

Check warning on line 77 in orangecontrib/spectroscopy/widgets/owpls.py

View check run for this annotation

Codecov / codecov/patch

orangecontrib/spectroscopy/widgets/owpls.py#L77

Added line #L77 was not covered by tests

# invalidate the model so that handleNewSignals will update it
self.model = None

if self.data and sp.issparse(self.data.X):
self.Warning.sparse_data()

Expand Down
Loading