Skip to content

Commit

Permalink
Merge pull request #5536 from janezd/distribution-deepcopy
Browse files Browse the repository at this point in the history
[FIX] Fix deepcopy and pickle for classes derived from `np.ndarray`
  • Loading branch information
VesnaT authored Aug 6, 2021
2 parents 77f53fe + 41413ad commit 376530d
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 3 deletions.
13 changes: 12 additions & 1 deletion Orange/statistics/contingency.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,20 @@ def __reduce__(self):
return (
_create_discrete,
(Discrete, np.copy(self), self.col_variable, self.row_variable,
self.col_unknowns, self.row_unknowns)
self.col_unknowns, self.row_unknowns, self.unknowns)
)

def __array_finalize__(self, obj):
# defined in __new__, pylint: disable=attribute-defined-outside-init
"""See http://docs.scipy.org/doc/numpy/user/basics.subclassing.html"""
if obj is None:
return
self.col_variable = getattr(obj, 'col_variable', None)
self.row_variable = getattr(obj, 'row_variable', None)
self.col_unknowns = getattr(obj, 'col_unknowns', None)
self.row_unknowns = getattr(obj, 'row_unknowns', None)
self.unknowns = getattr(obj, 'unknowns', None)


class Continuous:
def __init__(self, dat, col_variable=None, row_variable=None,
Expand Down
20 changes: 20 additions & 0 deletions Orange/statistics/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,26 @@ def _get_variable(dat, variable, expected_type=None, expected_name=""):


class Distribution(np.ndarray):
def __array_finalize__(self, obj):
# defined in derived classes,
# pylint: disable=attribute-defined-outside-init
"""See http://docs.scipy.org/doc/numpy/user/basics.subclassing.html"""
if obj is None:
return
self.variable = getattr(obj, 'variable', None)
self.unknowns = getattr(obj, 'unknowns', 0)

def __reduce__(self):
state = super().__reduce__()
newstate = state[2] + (self.variable, self.unknowns)
return state[0], state[1], newstate

def __setstate__(self, state):
# defined in derived classes,
# pylint: disable=attribute-defined-outside-init
super().__setstate__(state[:-2])
self.variable, self.unknowns = state[-2:]

def __eq__(self, other):
return (
np.array_equal(self, other) and
Expand Down
9 changes: 8 additions & 1 deletion Orange/tests/test_contingency.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring

import copy
import unittest
from unittest.mock import Mock

Expand Down Expand Up @@ -71,6 +71,13 @@ def test_discrete_missing(self):
np.testing.assert_almost_equal(cont.row_unknowns, [0, 0])
self.assertEqual(1, cont.unknowns)

def test_deepcopy(self):
cont = contingency.Discrete(self.zoo, 0)
dc = copy.deepcopy(cont)
self.assertEqual(dc, cont)
self.assertEqual(dc.col_variable, cont.col_variable)
self.assertEqual(dc.row_variable, cont.row_variable)

def test_array_with_unknowns(self):
d = data.Table("zoo")
d.Y[2] = float("nan")
Expand Down
53 changes: 52 additions & 1 deletion Orange/tests/test_distribution.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Test methods with long descriptive names can omit docstrings
# Test internal methods
# pylint: disable=missing-docstring, protected-access

import copy
import pickle
import unittest
from unittest.mock import Mock
import warnings
Expand Down Expand Up @@ -110,6 +111,32 @@ def test_fallback_with_weights_and_nan(self):
np.asarray(fallback), np.asarray(default))
np.testing.assert_almost_equal(fallback.unknowns, default.unknowns)

def test_pickle(self):
d = data.Table("zoo")
d1 = distribution.Discrete(d, 0)
dc = pickle.loads(pickle.dumps(d1))
# This always worked because `other` wasn't required to have `unknowns`
self.assertEqual(d1, dc)
# This failed before implementing `__reduce__`
self.assertEqual(dc, d1)
self.assertEqual(hash(d1), hash(dc))
# Test that `dc` has the required attributes
self.assertEqual(dc.variable, d1.variable)
self.assertEqual(dc.unknowns, d1.unknowns)

def test_deepcopy(self):
d = data.Table("zoo")
d1 = distribution.Discrete(d, 0)
dc = copy.deepcopy(d1)
# This always worked because `other` wasn't required to have `unknowns`
self.assertEqual(d1, dc)
# This failed before implementing `__deepcopy__`
self.assertEqual(dc, d1)
self.assertEqual(hash(d1), hash(dc))
# Test that `dc` has the required attributes
self.assertEqual(dc.variable, d1.variable)
self.assertEqual(dc.unknowns, d1.unknowns)

def test_equality(self):
d = data.Table("zoo")
d1 = distribution.Discrete(d, 0)
Expand Down Expand Up @@ -285,6 +312,30 @@ def test_construction(self):
self.assertEqual(disc2.unknowns, 0)
assert_dist_equal(disc2, dd)

def test_pickle(self):
d1 = distribution.Continuous(self.iris, 0)
dc = pickle.loads(pickle.dumps(d1))
# This always worked because `other` wasn't required to have `unknowns`
self.assertEqual(d1, dc)
# This failed before implementing `__reduce__`
self.assertEqual(dc, d1)
self.assertEqual(hash(d1), hash(dc))
# Test that `dc` has the required attributes
self.assertEqual(dc.variable, d1.variable)
self.assertEqual(dc.unknowns, d1.unknowns)

def test_deepcopy(self):
d1 = distribution.Continuous(self.iris, 0)
dc = copy.deepcopy(d1)
# This always worked because `other` wasn't required to have `unknowns`
self.assertEqual(d1, dc)
# This failed before implementing `__deepcopy__`
self.assertEqual(dc, d1)
self.assertEqual(hash(d1), hash(dc))
# Test that `dc` has the required attributes
self.assertEqual(dc.variable, d1.variable)
self.assertEqual(dc.unknowns, d1.unknowns)

def test_hash(self):
d = self.iris
petal_length = d.columns.petal_length
Expand Down

0 comments on commit 376530d

Please sign in to comment.