From 16e2ffc44118089201a7d52cf0b42ac10530b027 Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Thu, 16 Nov 2023 12:02:02 +0100 Subject: [PATCH] Test summed variable error for overlapping subgroups --- brian2/tests/test_subgroup.py | 69 +++++++++++++++++++++++++++-------- 1 file changed, 54 insertions(+), 15 deletions(-) diff --git a/brian2/tests/test_subgroup.py b/brian2/tests/test_subgroup.py index 6ca472912..e67683559 100644 --- a/brian2/tests/test_subgroup.py +++ b/brian2/tests/test_subgroup.py @@ -648,7 +648,7 @@ def test_subgroup_summed_variable(): target = NeuronGroup( 7, """Iin : 1 - x : 1""", + x : 1""", ) source.x = 5 target.Iin = 10 @@ -668,6 +668,45 @@ def test_subgroup_summed_variable(): assert_array_equal(target.Iin, [5, 6, 10, 8, 9, 10, 11]) +@pytest.mark.codegen_independent +def test_subgroup_summed_variable_overlap(): + # Check that overlapping subgroups raise an error + source = NeuronGroup(1, "") + target = NeuronGroup(10, "Iin : 1") + target1 = target[1:3] + target2 = target[2:5] + target3 = target[[1, 6]] + target4 = target[[4, 6]] + + syn1 = Synapses(source, target1, "Iin_post = 1 : 1 (summed)") + syn1.connect(True) + + syn2 = Synapses(source, target2, "Iin_post = 2 : 1 (summed)") + syn2.connect(True) + + syn3 = Synapses(source, target3, "Iin_post = 3 : 1 (summed)") + syn3.connect(True) + + syn4 = Synapses(source, target4, "Iin_post = 4 : 1 (summed)") + syn4.connect(True) + + net1 = Network(source, target, syn1, syn2) # overlap between contiguous subgroups + with pytest.raises(NotImplementedError): + net1.run(0 * ms) + + net2 = Network( + source, target, syn1, syn3 + ) # overlap between contiguous and non-contiguous subgroups + with pytest.raises(NotImplementedError): + net2.run(0 * ms) + + net3 = Network( + source, target, syn3, syn4 + ) # overlap between non-contiguous subgroups + with pytest.raises(NotImplementedError): + net3.run(0 * ms) + + def test_subexpression_references(): """ Assure that subexpressions in targeted groups are handled correctly. @@ -675,7 +714,7 @@ def test_subexpression_references(): G = NeuronGroup( 10, """v : 1 - v2 = 2*v : 1""", + v2 = 2*v : 1""", ) G.v = np.arange(10) SG1 = G[:5] @@ -685,8 +724,8 @@ def test_subexpression_references(): SG1, SG2, """w : 1 - u = v2_post + 1 : 1 - x = v2_pre + 1 : 1""", + u = v2_post + 1 : 1 + x = v2_pre + 1 : 1""", ) S1.connect("i==(5-1-j)") assert_equal(S1.i[:], np.arange(5)) @@ -698,8 +737,8 @@ def test_subexpression_references(): G, SG2, """w : 1 - u = v2_post + 1 : 1 - x = v2_pre + 1 : 1""", + u = v2_post + 1 : 1 + x = v2_pre + 1 : 1""", ) S2.connect("i==(5-1-j)") assert_equal(S2.i[:], np.arange(5)) @@ -711,8 +750,8 @@ def test_subexpression_references(): SG1, G, """w : 1 - u = v2_post + 1 : 1 - x = v2_pre + 1 : 1""", + u = v2_post + 1 : 1 + x = v2_pre + 1 : 1""", ) S3.connect("i==(10-1-j)") assert_equal(S3.i[:], np.arange(5)) @@ -729,7 +768,7 @@ def test_subexpression_no_references(): G = NeuronGroup( 10, """v : 1 - v2 = 2*v : 1""", + v2 = 2*v : 1""", ) G.v = np.arange(10) @@ -739,8 +778,8 @@ def test_subexpression_no_references(): G[:5], G[5:], """w : 1 - u = v2_post + 1 : 1 - x = v2_pre + 1 : 1""", + u = v2_post + 1 : 1 + x = v2_pre + 1 : 1""", ) S1.connect("i==(5-1-j)") assert_equal(S1.i[:], np.arange(5)) @@ -752,8 +791,8 @@ def test_subexpression_no_references(): G, G[5:], """w : 1 - u = v2_post + 1 : 1 - x = v2_pre + 1 : 1""", + u = v2_post + 1 : 1 + x = v2_pre + 1 : 1""", ) S2.connect("i==(5-1-j)") assert_equal(S2.i[:], np.arange(5)) @@ -765,8 +804,8 @@ def test_subexpression_no_references(): G[:5], G, """w : 1 - u = v2_post + 1 : 1 - x = v2_pre + 1 : 1""", + u = v2_post + 1 : 1 + x = v2_pre + 1 : 1""", ) S3.connect("i==(10-1-j)") assert_equal(S3.i[:], np.arange(5))