Skip to content

Commit

Permalink
Merge pull request #6552 from VesnaT/pickle_ids
Browse files Browse the repository at this point in the history
Table: Assure unique table.ids when unpickling
  • Loading branch information
PrimozGodec authored Sep 1, 2023
2 parents 4ddff41 + 9fb7554 commit 67a7ee9
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 4 deletions.
24 changes: 24 additions & 0 deletions Orange/data/sql/table.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Support for example tables wrapping data stored on a PostgreSQL server.
"""
import contextlib
import functools
import logging
import threading
Expand Down Expand Up @@ -669,3 +670,26 @@ def get_nan_frequency_attribute(self):

# Compute the NaN frequency for the domain's class variables by delegating
# to the private __get_nan_frequency helper (defined outside this hunk).
def get_nan_frequency_class(self):
return self.__get_nan_frequency(self.domain.class_vars)

def __getstate__(self):
# Pickle the instance dictionary as-is rather than going through
# Table.__getstate__, which avoids that method's locking logic.
# Note that the live __dict__ is returned, not a copy.
return self.__dict__

def __setstate__(self, state):
# Restore the attributes directly, bypassing the locking logic in
# Table.__setstate__.
self.__dict__.update(state)

# If _X is set, the data had already been downloaded before pickling,
# so ids exist too; regenerate them so the unpickled table gets new
# ids instead of reusing the pickled ones.
if self._X is not None:
self._init_ids(self)

# pylint: disable=unused-argument
def _update_locks(self, *args, **kwargs):
# Deliberate no-op: overrides the locking inherited from Table, so any
# arguments are accepted and ignored.
return

# pylint: disable=unused-argument
def unlocked(self, *parts):
# The locking inherited from Table is disabled for this class, so hand
# back a do-nothing context manager; `parts` is ignored. Callers can
# still write `with table.unlocked(): ...`.
return contextlib.nullcontext()
8 changes: 6 additions & 2 deletions Orange/data/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,8 @@ def no_view(x):
setattr(self, "Y", no_view(state.pop("_Y"))) # state["_Y"] is a 2d array
self.__dict__.update(state)

self._init_ids(self)

def __getstate__(self):
# Compatibility with pickles before table locking:
# return the same state as before table lock
Expand Down Expand Up @@ -1007,9 +1009,11 @@ def from_list(cls, domain, rows, weights=None):

@classmethod
def _init_ids(cls, obj):
    """Assign a fresh, contiguous block of instance ids to ``obj``.

    The number of ids equals the number of rows in ``obj.X``. The range is
    reserved from the class-wide counter ``cls._next_instance_id``, which is
    advanced exactly once per call, so consecutive calls hand out adjacent,
    non-overlapping id ranges.

    Parameters
    ----------
    obj : object
        Table-like object exposing ``X`` with a ``shape`` attribute; its
        ``ids`` attribute is (re)written.
    """
    length = int(obj.X.shape[0])
    # Only the reservation of the range needs the lock; building the array
    # afterwards works on the local copy of the start value.
    with cls._next_instance_lock:
        nid = cls._next_instance_id
        cls._next_instance_id += length
    obj.ids = np.arange(nid, nid + length, dtype=int)

@classmethod
def new_id(cls):
Expand Down
20 changes: 20 additions & 0 deletions Orange/tests/sql/test_sql_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,26 @@ def test_pickling_restores_connection_pool(self):

self.assertEqual(iris[0], iris2[0])

@dbt.run_on(["postgres"])
def test_pickling_respects_downloaded_state(self):
# A pickle round-trip must keep the "not yet downloaded" state, and once
# the data has been downloaded the unpickled table must get its own ids.
iris = SqlTable(self.conn, self.iris, inspect_values=True)
iris2 = pickle.loads(pickle.dumps(iris))
# pylint: disable=protected-access
# Before any download, neither the original nor the copy holds X or ids.
self.assertIsNone(iris._X)
self.assertIsNone(iris2._X)
self.assertIsNone(iris._ids)
self.assertIsNone(iris2._ids)

# trigger download into X, Y, metas
iris.X.shape[0] # pylint: disable=pointless-statement
self.assertIsNotNone(iris._X)
self.assertIsNotNone(iris._ids)
# Pickle again, this time with the downloaded data included.
iris2 = pickle.loads(pickle.dumps(iris))
self.assertIsNotNone(iris2._X)
self.assertIsNotNone(iris2._ids)
np.testing.assert_equal(iris.X, iris2.X)
# The union of both id sets has 300 elements, i.e. the two tables share
# no id (presumably 2 x 150 iris rows -- confirm against the fixture).
self.assertEqual(len(set(iris.ids) | set(iris2.ids)), 300)

@dbt.run_on(["postgres"])
def test_list_tables_with_schema(self):
with self.backend.execute_sql_query("DROP SCHEMA IF EXISTS orange_tests CASCADE") as cur:
Expand Down
10 changes: 10 additions & 0 deletions Orange/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,16 @@ def test_save_pickle(self):
finally:
os.remove("iris.pickle")

def test_read_pickle_ids(self):
# Loading the same pickle file twice must produce tables with disjoint ids.
table = data.Table("iris")
try:
table.save("iris.pickle")
table1 = data.Table.from_file("iris.pickle")
table2 = data.Table.from_file("iris.pickle")
# 300 distinct ids in the union means no id is shared between the two
# loads (presumably 2 x 150 iris rows -- confirm against the dataset).
self.assertEqual(len(set(table1.ids) | set(table2.ids)), 300)
finally:
# Always remove the temporary pickle, even when an assertion fails.
os.remove("iris.pickle")

def test_from_numpy(self):
a = np.arange(20, dtype="d").reshape((4, 5)).copy()
m = np.arange(4, dtype="d").reshape((4, 1)).copy()
Expand Down
3 changes: 1 addition & 2 deletions Orange/widgets/visualize/tests/test_owvenndiagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import unittest
from unittest.mock import patch
from copy import deepcopy

import numpy as np

Expand Down Expand Up @@ -36,7 +35,7 @@ def _select_data(self):

def test_rows_id(self):
data = Table('zoo')
data1 = deepcopy(data)
data1 = data.copy()
with data1.unlocked():
data1[:, 1] = 1
self.widget.rowwise = True
Expand Down

0 comments on commit 67a7ee9

Please sign in to comment.