From b7e4eac8546658f21821005b94899ab9c1d39cee Mon Sep 17 00:00:00 2001 From: micchu Date: Wed, 14 Aug 2024 12:57:36 +0900 Subject: [PATCH] add lock function to kvs.py --- bdpy/dataform/kvs.py | 175 ++++++++++++++++++++++++------------- tests/dataform/test_kvs.py | 16 +++- 2 files changed, 127 insertions(+), 64 deletions(-) diff --git a/bdpy/dataform/kvs.py b/bdpy/dataform/kvs.py index ea593f9..ac533d4 100644 --- a/bdpy/dataform/kvs.py +++ b/bdpy/dataform/kvs.py @@ -19,7 +19,7 @@ class BaseKeyValueStore(object): def get(self, **kwargs) -> _array_t: raise NotImplementedError("get should be implemented in the subclass.") - + def set(self, value: _array_t, **kwargs) -> None: raise NotImplementedError("set should be implemented in the subclass.") @@ -27,7 +27,7 @@ def set(self, value: _array_t, **kwargs) -> None: class SQLite3KeyValueStore(BaseKeyValueStore): """Key-value store implemented on SQLite3.""" - def __init__(self, path: _path_t, keys: Optional[List[str]] = None): + def __init__(self, path: _path_t, timeout: int = 60, keys: Optional[List[str]] = None): """Initialize the SQLite3KeyValueStore object. Parameters @@ -46,7 +46,7 @@ def __init__(self, path: _path_t, keys: Optional[List[str]] = None): new_db = not os.path.exists(self._path) # Connect to DB - self._conn = sqlite3.connect(self._path, isolation_level="EXCLUSIVE", timeout=60) + self._conn = sqlite3.connect(self._path, isolation_level='EXCLUSIVE', timeout=timeout) # Enable foreign key cursor = self._conn.cursor() @@ -66,7 +66,10 @@ def __init__(self, path: _path_t, keys: Optional[List[str]] = None): self._keys = self._get_keys() def set(self, value: _array_t, **kwargs) -> None: - """Set value for the given keys.""" + """ + Set value for the given keys. 
+ Transaction mode: DEFERRED + """ # Check if the keys are valid for key in kwargs.keys(): if key not in self._keys: @@ -76,67 +79,80 @@ def set(self, value: _array_t, **kwargs) -> None: if len(kwargs) != len(self._keys): raise ValueError("All keys must be given.") - # Set transaction - self._conn.execute("BEGIN TRANSACTION;") - - _v = value.astype(float).tobytes(order='C') - - where = self._generate_where(**kwargs) - - self._conn.execute( - f""" - CREATE TABLE tmp AS - WITH hit AS ( - SELECT kgm.key_value_store_id FROM key_group_members AS kgm - JOIN key_instances AS ki ON kgm.key_instance_id = ki.id - JOIN key_names AS kn ON ki.key_name_id = kn.id - WHERE {where} - GROUP BY kgm.key_value_store_id - ) - SELECT * FROM hit; - """ - ) - self._conn.execute("CREATE TABLE kvs_last_inserted_rowid (rowid INTEGER);") - self._conn.execute( - """ - CREATE TRIGGER kvs_insert AFTER INSERT ON key_value_store - BEGIN - DELETE FROM kvs_last_inserted_rowid; - INSERT INTO kvs_last_inserted_rowid (rowid) VALUES (new.rowid); - END; + # Check if the given keys already exist + key_group_id = self._get_key_group_id(**kwargs) + cursor = self._conn.cursor() + cursor.execute("BEGIN DEFERRED;") + if key_group_id is None: + # Add new key-value pair + sql = "INSERT INTO key_value_store (value) VALUES (?)" + cursor.execute(sql, (value.astype(float).tobytes(order='C'),)) + key_value_store_id = cursor.lastrowid + self._add_key_group_id(key_value_store_id, **kwargs) + else: + # Update existing key-value pair + sql = f""" + UPDATE key_value_store + SET value = ? + WHERE id = {key_group_id} """ - ) - - sql_update = "UPDATE key_value_store SET value = ? 
WHERE id = (SELECT key_value_store_id FROM tmp LIMIT 1) AND (SELECT COUNT(*) FROM tmp) = 1;" - self._conn.execute(sql_update, (_v,)) - - insert_instances = ', '.join([ - f"('{inst}', (SELECT id FROM key_names WHERE name = '{key}'))" - for key, inst in kwargs.items() - ]) - sql_insert_inst = f"INSERT OR IGNORE INTO key_instances (name, key_name_id) VALUES {insert_instances};" - self._conn.execute(sql_insert_inst) + cursor.execute(sql, (value.astype(float).tobytes(order='C'),)) + self._conn.commit() + cursor.close() - sql_insert_kvs = "INSERT INTO key_value_store (value) SELECT ? WHERE (SELECT COUNT(*) FROM tmp) = 0;" - self._conn.execute(sql_insert_kvs, (_v,)) + return None - for key, inst in kwargs.items(): - sql_insert_kgm = f""" - INSERT OR IGNORE INTO key_group_members (key_value_store_id, key_instance_id) - SELECT - (SELECT id FROM key_value_store WHERE rowid = (SELECT * FROM kvs_last_inserted_rowid)), - (SELECT ki.id FROM key_instances AS ki JOIN key_names AS kn ON ki.key_name_id = kn.id WHERE kn.name = '{key}' AND ki.name = '{inst}') - WHERE (SELECT COUNT(*) FROM tmp) = 0; - """ - self._conn.execute(sql_insert_kgm) + def lock(self, **kwargs) -> bool: + """ + If a record with the specified condition does not exist, insert a record with a null value and return True. + If a record exists, return False. 
+ Transaction mode: EXCLUSIVE + """ + # Check if the keys are valid + for key in kwargs.keys(): + if key not in self._keys: + raise ValueError(f"Key '{key}' is not defined.") - self._conn.execute("DROP TABLE tmp;") - self._conn.execute("DROP TABLE kvs_last_inserted_rowid;") - self._conn.execute("DROP TRIGGER kvs_insert;") + # Check if all keys are given + if len(kwargs) != len(self._keys): + raise ValueError("All keys must be given.") + # Start EXCLUSIVE transaction + cursor = self._conn.cursor() + cursor.execute("BEGIN EXCLUSIVE;") + + # Check if a record with the specified condition already exists + try: + key_value_store_id = self._get_key_group_id(**kwargs) + except ValueError: + # Close transaction + self._conn.commit() + cursor.close() + raise + + # If the condition already exists, + # It is determined that it is impossible to obtain the lock, + # close the cursor, return False. + if key_value_store_id is not None: + # Close transaction + self._conn.commit() + cursor.close() + return False + + # If no record with the specified condition exists, + # take a lock. 
+ # Add new record to key-value pair and get the last key_value_store_id + sql = "INSERT INTO key_value_store (value) VALUES (?)" + cursor.execute(sql, (np.array([[]]).astype(float).tobytes(order='C'),)) + key_value_store_id = cursor.lastrowid + self._add_key_group_id(key_value_store_id, **kwargs) + + # Close transaction self._conn.commit() + cursor.close() - return None + # The lock was acquired successfully, so return True + return True def get(self, **kwargs) -> Optional[_array_t]: """Get value for the given keys.""" @@ -177,13 +193,44 @@ def delete(self, **kwargs) -> None: DELETE FROM key_value_store WHERE id = {key_group_id} """, ] - self._conn.execute("BEGIN TRANSACTION;") + # Start transaction + cursor = self._conn.cursor() + cursor.execute("BEGIN EXCLUSIVE;") for sql in sqls: - self._conn.execute(sql) + cursor.execute(sql) + # Close transaction self._conn.commit() - + cursor.close() return None + def _add_key_group_id(self, key_value_store_id: int, **kwargs) -> int: + """Add key group ID.""" + # Open cursor + cursor = self._conn.cursor() + + for key, inst in kwargs.items(): + # Add key instance if not exists + key_instance_id = self._get_key_instance_id(key, inst) + if key_instance_id is not None: + continue + key_name_id = self._get_key_name_id(key) + sql = f""" + INSERT OR IGNORE INTO key_instances (name, key_name_id) VALUES ('{inst}', {key_name_id}) + """ + cursor.execute(sql) + + inst_ids = [self._get_key_instance_id(key, inst) for key, inst in kwargs.items()] + sqls = [ + f"INSERT INTO key_group_members (key_value_store_id, key_instance_id) VALUES ({key_value_store_id}, {inst_id})" + for inst_id in inst_ids + ] + for sql in sqls: + cursor.execute(sql) + + # Close cursor + cursor.close() + return key_value_store_id + def _get_key_group_id(self, **kwargs) -> Optional[int]: """Get key group ID.""" where = self._generate_where(**kwargs) @@ -280,10 +327,12 @@ def _init_empty_db(self) -> None: ) """, ] - self._conn.execute("BEGIN TRANSACTION;") + cursor = self._conn.cursor() + 
cursor.execute("BEGIN EXCLUSIVE;") for sql in sqls: - self._conn.execute(sql) + cursor.execute(sql) self._conn.commit() + cursor.close() return None def _validate_db(self, keys: List[str]) -> None: diff --git a/tests/dataform/test_kvs.py b/tests/dataform/test_kvs.py index d33afbb..1d85c69 100644 --- a/tests/dataform/test_kvs.py +++ b/tests/dataform/test_kvs.py @@ -188,9 +188,23 @@ def test_delete(self): kvs.set(np.array([np.nan]), layer="conv1", subject="sub03", roi="PPA", metric="accuracy") kvs.delete(layer="conv1", subject="sub03", roi="PPA", metric="accuracy") - np.testing.assert_(kvs.exists(layer="conv1", subject="sub03", roi="LOC", metric="accuracy")) np.testing.assert_(~kvs.exists(layer="conv1", subject="sub03", roi="PPA", metric="accuracy"), 'AssertionError: Failed to delete the record.') + def test_lock(self): + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test_3304.db") + self._init_test_db(db_path) + + kvs = SQLite3KeyValueStore(db_path) + + kvs.set(np.array([ 1, 2, 3, 4]), layer="conv1", subject="sub03", roi="LOC", metric="accuracy") + kvs.set(np.array([ 5, 6, 7, 8]), layer="conv1", subject="sub03", roi="FFA", metric="accuracy") + np.testing.assert_(kvs.lock(layer="conv1", subject="sub03", roi="PPA", metric="accuracy"), + 'AssertionError: Failed to lock the specified condition.') + np.testing.assert_(~kvs.lock(layer="conv1", subject="sub03", roi="PPA", metric="accuracy"), + 'AssertionError: A condition that was already locked has been newly locked.') + + if __name__ == "__main__": unittest.main()