diff --git a/tests/test_dimension.py b/tests/test_dimension.py index da2d98252a..5ed0bd08a5 100644 --- a/tests/test_dimension.py +++ b/tests/test_dimension.py @@ -1288,6 +1288,20 @@ def test_no_index_symbolic(self): op = Operator(eq) op.cfunction + @pytest.mark.parametrize('value', [0, 1]) + def test_constant_as_condition(self, value): + x = Dimension('x') + + c = Constant(name="c", dtype=np.int8, value=value) + cd = ConditionalDimension(name="cd", parent=x, condition=c) + + f = Function(name='f', dimensions=(x,), shape=(11,), dtype=np.int32) + + op = Operator(Eq(f, 1, implicit_dims=cd)) + op.apply() + + assert np.all(f.data == value) + def test_symbolic_factor(self): """ Test ConditionalDimension with symbolic factor (provided as a Constant).