Skip to content

Commit

Permalink
add lock funciont to kvs.py
Browse files Browse the repository at this point in the history
  • Loading branch information
micchu committed Aug 14, 2024
1 parent acf46fd commit b7e4eac
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 64 deletions.
175 changes: 112 additions & 63 deletions bdpy/dataform/kvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ class BaseKeyValueStore(object):

def get(self, **kwargs) -> _array_t:
raise NotImplementedError("get should be implemented in the subclass.")

def set(self, value: _array_t, **kwargs) -> None:
raise NotImplementedError("set should be implemented in the subclass.")


class SQLite3KeyValueStore(BaseKeyValueStore):
"""Key-value store implemented on SQLite3."""

def __init__(self, path: _path_t, keys: Optional[List[str]] = None):
def __init__(self, path: _path_t, timeout: int = 60, keys: Optional[List[str]] = None):
"""Initialize the SQLite3KeyValueStore object.
Parameters
Expand All @@ -46,7 +46,7 @@ def __init__(self, path: _path_t, keys: Optional[List[str]] = None):
new_db = not os.path.exists(self._path)

# Connect to DB
self._conn = sqlite3.connect(self._path, isolation_level="EXCLUSIVE", timeout=60)
self._conn = sqlite3.connect(self._path, isolation_level='EXCLUSIVE', timeout=timeout)

# Enable foreign key
cursor = self._conn.cursor()
Expand All @@ -66,7 +66,10 @@ def __init__(self, path: _path_t, keys: Optional[List[str]] = None):
self._keys = self._get_keys()

def set(self, value: _array_t, **kwargs) -> None:
"""Set value for the given keys."""
"""
Set value for the given keys.
Transaction mode: DEFERRED
"""
# Check if the keys are valid
for key in kwargs.keys():
if key not in self._keys:
Expand All @@ -76,67 +79,80 @@ def set(self, value: _array_t, **kwargs) -> None:
if len(kwargs) != len(self._keys):
raise ValueError("All keys must be given.")

# Set transaction
self._conn.execute("BEGIN TRANSACTION;")

_v = value.astype(float).tobytes(order='C')

where = self._generate_where(**kwargs)

self._conn.execute(
f"""
CREATE TABLE tmp AS
WITH hit AS (
SELECT kgm.key_value_store_id FROM key_group_members AS kgm
JOIN key_instances AS ki ON kgm.key_instance_id = ki.id
JOIN key_names AS kn ON ki.key_name_id = kn.id
WHERE {where}
GROUP BY kgm.key_value_store_id
)
SELECT * FROM hit;
"""
)
self._conn.execute("CREATE TABLE kvs_last_inserted_rowid (rowid INTEGER);")
self._conn.execute(
"""
CREATE TRIGGER kvs_insert AFTER INSERT ON key_value_store
BEGIN
DELETE FROM kvs_last_inserted_rowid;
INSERT INTO kvs_last_inserted_rowid (rowid) VALUES (new.rowid);
END;
# Check if the given keys already exist
key_group_id = self._get_key_group_id(**kwargs)
cursor = self._conn.cursor()
cursor.execute("BEGIN DEFERRED;")
if key_group_id is None:
# Add new key-value pair
sql = "INSERT INTO key_value_store (value) VALUES (?)"
cursor.execute(sql, (value.astype(float).tobytes(order='C'),))
key_value_store_id = cursor.lastrowid
self._add_key_group_id(key_value_store_id, **kwargs)
else:
# Update existing key-value pair
sql = f"""
UPDATE key_value_store
SET value = ?
WHERE id = {key_group_id}
"""
)

sql_update = "UPDATE key_value_store SET value = ? WHERE id = (SELECT key_value_store_id FROM tmp LIMIT 1) AND (SELECT COUNT(*) FROM tmp) = 1;"
self._conn.execute(sql_update, (_v,))

insert_instances = ', '.join([
f"('{inst}', (SELECT id FROM key_names WHERE name = '{key}'))"
for key, inst in kwargs.items()
])
sql_insert_inst = f"INSERT OR IGNORE INTO key_instances (name, key_name_id) VALUES {insert_instances};"
self._conn.execute(sql_insert_inst)
cursor.execute(sql, (value.astype(float).tobytes(order='C'),))
self._conn.commit()
cursor.close()

sql_insert_kvs = "INSERT INTO key_value_store (value) SELECT ? WHERE (SELECT COUNT(*) FROM tmp) = 0;"
self._conn.execute(sql_insert_kvs, (_v,))
return None

for key, inst in kwargs.items():
sql_insert_kgm = f"""
INSERT OR IGNORE INTO key_group_members (key_value_store_id, key_instance_id)
SELECT
(SELECT id FROM key_value_store WHERE rowid = (SELECT * FROM kvs_last_inserted_rowid)),
(SELECT ki.id FROM key_instances AS ki JOIN key_names AS kn ON ki.key_name_id = kn.id WHERE kn.name = '{key}' AND ki.name = '{inst}')
WHERE (SELECT COUNT(*) FROM tmp) = 0;
"""
self._conn.execute(sql_insert_kgm)
def lock(self, **kwargs) -> bool:
"""
If a record with the specified condition does not exist, insert a record with a null value and return True.
If a record exists, return False.
Transaction mode: EXCLUSIVE
"""
# Check if the keys are valid
for key in kwargs.keys():
if key not in self._keys:
raise ValueError(f"Key '{key}' is not defined.")

self._conn.execute("DROP TABLE tmp;")
self._conn.execute("DROP TABLE kvs_last_inserted_rowid;")
self._conn.execute("DROP TRIGGER kvs_insert;")
# Check if all keys are given
if len(kwargs) != len(self._keys):
raise ValueError("All keys must be given.")

# Start EXCLUSIVE transaction
cursor = self._conn.cursor()
cursor.execute("BEGIN EXCLUSIVE;")

# Check if a record with the specified condition already exists
try:
key_value_store_id = self._get_key_group_id(**kwargs)
except ValueError:
# Close transaction
self._conn.commit()
cursor.close()
raise

# If the condition already exists,
# It is determined that it is impossible to obtain the lock,
# close the cursor, return False.
if key_value_store_id is not None:
# Close transaction
self._conn.commit()
cursor.close()
return False

# If no record with the specified condition exists,
# take a lock.
# Add new record to key-value pair and get the last key_value_store_id
sql = "INSERT INTO key_value_store (value) VALUES (?)"
cursor.execute(sql, (np.array([[]]).astype(float).tobytes(order='C'),))
key_value_store_id = cursor.lastrowid
self._add_key_group_id(key_value_store_id, **kwargs)

# Close transaction
self._conn.commit()
cursor.close()

return None
# Lock をとることに成功したので True を返す
return True

def get(self, **kwargs) -> Optional[_array_t]:
"""Get value for the given keys."""
Expand Down Expand Up @@ -177,13 +193,44 @@ def delete(self, **kwargs) -> None:
DELETE FROM key_value_store WHERE id = {key_group_id}
""",
]
self._conn.execute("BEGIN TRANSACTION;")
# Start transaction
cursor = self._conn.cursor()
cursor.execute("BEGIN EXCLUSIVE;")
for sql in sqls:
self._conn.execute(sql)
cursor.execute(sql)
# Close transaction
self._conn.commit()

cursor.close()
return None

def _add_key_group_id(self, key_value_store_id: int, **kwargs) -> int:
"""Add key group ID."""
# Open cursor
cursor = self._conn.cursor()

for key, inst in kwargs.items():
# Add key instance if not exists
key_instance_id = self._get_key_instance_id(key, inst)
if key_instance_id is not None:
continue
key_name_id = self._get_key_name_id(key)
sql = f"""
INSERT OR IGNORE INTO key_instances (name, key_name_id) VALUES ('{inst}', {key_name_id})
"""
cursor.execute(sql)

inst_ids = [self._get_key_instance_id(key, inst) for key, inst in kwargs.items()]
sqls = [
f"INSERT INTO key_group_members (key_value_store_id, key_instance_id) VALUES ({key_value_store_id}, {inst_id})"
for inst_id in inst_ids
]
for sql in sqls:
cursor.execute(sql)

# Close cursor
cursor.close()
return key_value_store_id

def _get_key_group_id(self, **kwargs) -> Optional[int]:
"""Get key group ID."""
where = self._generate_where(**kwargs)
Expand Down Expand Up @@ -280,10 +327,12 @@ def _init_empty_db(self) -> None:
)
""",
]
self._conn.execute("BEGIN TRANSACTION;")
cursor = self._conn.cursor()
cursor.execute("BEGIN EXCLUSIVE;")
for sql in sqls:
self._conn.execute(sql)
cursor.execute(sql)
self._conn.commit()
cursor.close()
return None

def _validate_db(self, keys: List[str]) -> None:
Expand Down
16 changes: 15 additions & 1 deletion tests/dataform/test_kvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,23 @@ def test_delete(self):
kvs.set(np.array([np.nan]), layer="conv1", subject="sub03", roi="PPA", metric="accuracy")

kvs.delete(layer="conv1", subject="sub03", roi="PPA", metric="accuracy")
np.testing.assert_(kvs.exists(layer="conv1", subject="sub03", roi="LOC", metric="accuracy"))
np.testing.assert_(~kvs.exists(layer="conv1", subject="sub03", roi="PPA", metric="accuracy"),
'AssertionError: Failed to delete the record.')

def test_lock(self):
with tempfile.TemporaryDirectory() as tmpdir:
db_path = os.path.join(tmpdir, "test_3304.db")
self._init_test_db(db_path)

kvs = SQLite3KeyValueStore(db_path)

kvs.set(np.array([ 1, 2, 3, 4]), layer="conv1", subject="sub03", roi="LOC", metric="accuracy")
kvs.set(np.array([ 5, 6, 7, 8]), layer="conv1", subject="sub03", roi="FFA", metric="accuracy")
np.testing.assert_(kvs.lock(layer="conv1", subject="sub03", roi="PPA", metric="accuracy"),
'AssertionError: Failed to lock the specified condition.')
np.testing.assert_(~kvs.lock(layer="conv1", subject="sub03", roi="PPA", metric="accuracy"),
'AssertionError: A condition that was already locked has been newly locked.')


if __name__ == "__main__":
unittest.main()

0 comments on commit b7e4eac

Please sign in to comment.