Skip to content

Commit

Permalink
Test summed variable error for overlapping subgroups
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Nov 16, 2023
1 parent ce5bb3b commit 16e2ffc
Showing 1 changed file with 54 additions and 15 deletions.
69 changes: 54 additions & 15 deletions brian2/tests/test_subgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ def test_subgroup_summed_variable():
target = NeuronGroup(
7,
"""Iin : 1
x : 1""",
x : 1""",
)
source.x = 5
target.Iin = 10
Expand All @@ -668,14 +668,53 @@ def test_subgroup_summed_variable():
assert_array_equal(target.Iin, [5, 6, 10, 8, 9, 10, 11])


@pytest.mark.codegen_independent
def test_subgroup_summed_variable_overlap():
# Check that overlapping subgroups raise an error
source = NeuronGroup(1, "")
target = NeuronGroup(10, "Iin : 1")
target1 = target[1:3]
target2 = target[2:5]
target3 = target[[1, 6]]
target4 = target[[4, 6]]

syn1 = Synapses(source, target1, "Iin_post = 1 : 1 (summed)")
syn1.connect(True)

syn2 = Synapses(source, target2, "Iin_post = 2 : 1 (summed)")
syn2.connect(True)

syn3 = Synapses(source, target3, "Iin_post = 3 : 1 (summed)")
syn3.connect(True)

syn4 = Synapses(source, target4, "Iin_post = 4 : 1 (summed)")
syn4.connect(True)

net1 = Network(source, target, syn1, syn2) # overlap between contiguous subgroups
with pytest.raises(NotImplementedError):
net1.run(0 * ms)

net2 = Network(
source, target, syn1, syn3
) # overlap between contiguous and non-contiguous subgroups
with pytest.raises(NotImplementedError):
net2.run(0 * ms)

net3 = Network(
source, target, syn3, syn4
) # overlap between non-contiguous subgroups
with pytest.raises(NotImplementedError):
net3.run(0 * ms)


def test_subexpression_references():
"""
Assure that subexpressions in targeted groups are handled correctly.
"""
G = NeuronGroup(
10,
"""v : 1
v2 = 2*v : 1""",
v2 = 2*v : 1""",
)
G.v = np.arange(10)
SG1 = G[:5]
Expand All @@ -685,8 +724,8 @@ def test_subexpression_references():
SG1,
SG2,
"""w : 1
u = v2_post + 1 : 1
x = v2_pre + 1 : 1""",
u = v2_post + 1 : 1
x = v2_pre + 1 : 1""",
)
S1.connect("i==(5-1-j)")
assert_equal(S1.i[:], np.arange(5))
Expand All @@ -698,8 +737,8 @@ def test_subexpression_references():
G,
SG2,
"""w : 1
u = v2_post + 1 : 1
x = v2_pre + 1 : 1""",
u = v2_post + 1 : 1
x = v2_pre + 1 : 1""",
)
S2.connect("i==(5-1-j)")
assert_equal(S2.i[:], np.arange(5))
Expand All @@ -711,8 +750,8 @@ def test_subexpression_references():
SG1,
G,
"""w : 1
u = v2_post + 1 : 1
x = v2_pre + 1 : 1""",
u = v2_post + 1 : 1
x = v2_pre + 1 : 1""",
)
S3.connect("i==(10-1-j)")
assert_equal(S3.i[:], np.arange(5))
Expand All @@ -729,7 +768,7 @@ def test_subexpression_no_references():
G = NeuronGroup(
10,
"""v : 1
v2 = 2*v : 1""",
v2 = 2*v : 1""",
)
G.v = np.arange(10)

Expand All @@ -739,8 +778,8 @@ def test_subexpression_no_references():
G[:5],
G[5:],
"""w : 1
u = v2_post + 1 : 1
x = v2_pre + 1 : 1""",
u = v2_post + 1 : 1
x = v2_pre + 1 : 1""",
)
S1.connect("i==(5-1-j)")
assert_equal(S1.i[:], np.arange(5))
Expand All @@ -752,8 +791,8 @@ def test_subexpression_no_references():
G,
G[5:],
"""w : 1
u = v2_post + 1 : 1
x = v2_pre + 1 : 1""",
u = v2_post + 1 : 1
x = v2_pre + 1 : 1""",
)
S2.connect("i==(5-1-j)")
assert_equal(S2.i[:], np.arange(5))
Expand All @@ -765,8 +804,8 @@ def test_subexpression_no_references():
G[:5],
G,
"""w : 1
u = v2_post + 1 : 1
x = v2_pre + 1 : 1""",
u = v2_post + 1 : 1
x = v2_pre + 1 : 1""",
)
S3.connect("i==(10-1-j)")
assert_equal(S3.i[:], np.arange(5))
Expand Down

0 comments on commit 16e2ffc

Please sign in to comment.