Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow including master table in cascading delete #1158

Merged
merged 15 commits into from
Aug 19, 2024
2 changes: 1 addition & 1 deletion datajoint/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ class U:
>>> dj.U().aggr(expr, n='count(*)')

The following expressions both yield one element containing the number `n` of distinct values of attribute `attr` in
query expressio `expr`.
query expression `expr`.

>>> dj.U().aggr(expr, n='count(distinct attr)')
>>> dj.U().aggr(dj.U('attr').aggr(expr), 'n=count(*)')
Expand Down
28 changes: 26 additions & 2 deletions datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ def delete(
transaction: bool = True,
safemode: Union[bool, None] = None,
force_parts: bool = False,
include_parts: bool = True,
CBroz1 marked this conversation as resolved.
Show resolved Hide resolved
) -> int:
"""
Deletes the contents of the table and its dependent tables, recursively.
Expand All @@ -497,6 +498,8 @@ def delete(
safemode: If `True`, prohibit nested transactions and prompt to confirm. Default
is `dj.config['safemode']`.
force_parts: Delete from parts even when not deleting from their masters.
include_parts: If `True`, include part/master pairs in the cascade.
CBroz1 marked this conversation as resolved.
Show resolved Hide resolved
Default is `True`.
CBroz1 marked this conversation as resolved.
Show resolved Hide resolved

Returns:
Number of deleted rows (excluding those from dependent tables).
Expand All @@ -507,6 +510,7 @@ def delete(
DataJointError: Deleting a part table before its master.
"""
deleted = set()
visited_masters = set()
ethho marked this conversation as resolved.
Show resolved Hide resolved

def cascade(table):
"""service function to perform cascading deletes recursively."""
Expand Down Expand Up @@ -565,11 +569,31 @@ def cascade(table):
)
else:
child &= table.proj()
cascade(child)

master_name = get_master(child.full_table_name)
if (
include_parts
CBroz1 marked this conversation as resolved.
Show resolved Hide resolved
and master_name
and master_name != table.full_table_name
and master_name not in visited_masters
):
master = FreeTable(table.connection, master_name)
master._restriction_attributes = set()
master._restriction = [
make_condition( # &= may cause target tables to appear in a subquery
master,
(master.proj() & child.proj()).fetch(),
master._restriction_attributes,
)
]
visited_masters.add(master_name)
cascade(master)
else:
cascade(child)
else:
deleted.add(table.full_table_name)
logger.info(
"Deleting {count} rows from {table}".format(
"Deleting: {count} rows from {table}".format(
count=delete_count, table=table.full_table_name
)
)
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def schema_simp(connection_test, prefix):
schema(schema_simple.E)
schema(schema_simple.F)
schema(schema_simple.F)
schema(schema_simple.G)
schema(schema_simple.DataA)
schema(schema_simple.DataB)
schema(schema_simple.Website)
Expand Down
40 changes: 34 additions & 6 deletions tests/schema_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,36 @@ class F(dj.Part):
-> B.C
"""

class G(dj.Part):
definition = """ # test secondary fk reference
-> E
id_g :int
---
-> L
"""

class H(dj.Part):
definition = """ # test no additional fk reference
-> E
id_h :int
"""

def make(self, key):
random.seed(str(key))
self.insert1(dict(key, **random.choice(list(L().fetch("KEY")))))
sub = E.F()
references = list((B.C() & key).fetch("KEY"))
random.shuffle(references)
sub.insert(
l_contents = list(L().fetch("KEY"))
part_f, part_g, part_h = E.F(), E.G(), E.H()
bc_references = list((B.C() & key).fetch("KEY"))
random.shuffle(bc_references)

self.insert1(dict(key, **random.choice(l_contents)))
part_f.insert(
dict(key, id_f=i, **ref)
for i, ref in enumerate(references)
for i, ref in enumerate(bc_references)
if random.getrandbits(1)
)
g_inserts = [dict(key, id_g=i, **ref) for i, ref in enumerate(l_contents)]
part_g.insert(g_inserts)
part_h.insert(dict(key, id_h=i) for i in range(4))


class F(dj.Manual):
Expand All @@ -132,6 +151,15 @@ class F(dj.Manual):
"""


class G(dj.Computed):
definition = """ # test downstream of complex master/parts
-> E
"""

def make(self, key):
self.insert1(key)


class DataA(dj.Lookup):
definition = """
idx : int
Expand Down
22 changes: 18 additions & 4 deletions tests/test_cascading_delete.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import datajoint as dj
from .schema_simple import A, B, D, E, L, Website, Profile
from .schema_simple import A, B, D, E, G, L, Website, Profile
from .schema import ComplexChild, ComplexParent


Expand All @@ -11,6 +11,7 @@ def schema_simp_pop(schema_simp):
B().populate()
D().populate()
E().populate()
G().populate()
yield schema_simp


Expand Down Expand Up @@ -96,7 +97,7 @@ def test_delete_complex_keys(schema_any):
**{
"child_id_{}".format(i + 1): (i + parent_key_count)
for i in range(child_key_count)
}
},
)
assert len(ComplexParent & restriction) == 1, "Parent record missing"
assert len(ComplexChild & restriction) == 1, "Child record missing"
Expand All @@ -110,11 +111,24 @@ def test_delete_master(schema_simp_pop):
Profile().delete()


def test_delete_parts(schema_simp_pop):
def test_delete_parts_error(schema_simp_pop):
"""test issue #151"""
with pytest.raises(dj.DataJointError):
Profile().populate_random()
Website().delete()
Website().delete(include_parts=False)
CBroz1 marked this conversation as resolved.
Show resolved Hide resolved


def test_delete_parts(schema_simp_pop):
"""test issue #151"""
Profile().populate_random()
Website().delete(include_parts=True)
CBroz1 marked this conversation as resolved.
Show resolved Hide resolved


def test_delete_parts_complex(schema_simp_pop):
"""test issue #151 with complex master/part. PR #1158."""
prev_len = len(G())
(A() & "id_a=1").delete()
CBroz1 marked this conversation as resolved.
Show resolved Hide resolved
assert prev_len - len(G()) == 16, "Failed to delete parts"


def test_drop_part(schema_simp_pop):
Expand Down
12 changes: 8 additions & 4 deletions tests/test_erd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datajoint as dj
from .schema_simple import LOCALS_SIMPLE, A, B, D, E, L, OutfitLaunch
from .schema_simple import LOCALS_SIMPLE, A, B, D, E, G, L, OutfitLaunch
from .schema_advanced import *


Expand All @@ -20,7 +20,7 @@ def test_dependencies(schema_simp):
assert set(D().parents(primary=True)) == set([A.full_table_name])
assert set(D().parents(primary=False)) == set([L.full_table_name])
assert set(deps.descendants(L.full_table_name)).issubset(
cls.full_table_name for cls in (L, D, E, E.F)
cls.full_table_name for cls in (L, D, E, E.F, E.G, E.H, G)
)


Expand All @@ -38,10 +38,14 @@ def test_erd_algebra(schema_simp):
erd3 = erd1 * erd2
erd4 = (erd0 + E).add_parts() - B - E
assert erd0.nodes_to_show == set(cls.full_table_name for cls in [B])
assert erd1.nodes_to_show == set(cls.full_table_name for cls in (B, B.C, E, E.F))
assert erd1.nodes_to_show == set(
cls.full_table_name for cls in (B, B.C, E, E.F, E.G, E.H, G)
)
assert erd2.nodes_to_show == set(cls.full_table_name for cls in (A, B, D, E, L))
assert erd3.nodes_to_show == set(cls.full_table_name for cls in (B, E))
assert erd4.nodes_to_show == set(cls.full_table_name for cls in (B.C, E.F))
assert erd4.nodes_to_show == set(
cls.full_table_name for cls in (B.C, E.F, E.G, E.H)
)


def test_repr_svg(schema_adv):
Expand Down
9 changes: 7 additions & 2 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def test_list_tables(schema_simp):
"""
https://github.com/datajoint/datajoint-python/issues/838
"""
assert set(
expected = set(
[
"reserved_word",
"#l",
Expand All @@ -194,6 +194,9 @@ def test_list_tables(schema_simp):
"__b__c",
"__e",
"__e__f",
"__e__g",
"__e__h",
"__g",
"#outfit_launch",
"#outfit_launch__outfit_piece",
"#i_j",
Expand All @@ -207,7 +210,9 @@ def test_list_tables(schema_simp):
"profile",
"profile__website",
]
) == set(schema_simp.list_tables())
)
actual = set(schema_simp.list_tables())
assert actual == expected, f"Missing from list_tables(): {expected - actual}"


def test_schema_save_any(schema_any):
Expand Down