diff --git a/tests/unit/test_function.py b/tests/unit/test_function.py index e41d30f04..c1a1adc9d 100644 --- a/tests/unit/test_function.py +++ b/tests/unit/test_function.py @@ -158,3 +158,55 @@ def test_get_value_opt(x, y, z): func = Function(source, interpolation="shepard", extrapolation="natural") assert isinstance(func.get_value_opt(x, y), float) assert np.isclose(func.get_value_opt(x, y), z, atol=1e-6) + + +@pytest.mark.parametrize("samples", [2, 50, 1000]) +def test_set_discrete_mutator(samples): + """Tests the set_discrete method of the Function class.""" + func = Function(lambda x: x**3) + discretized_func = func.set_discrete(-10, 10, samples, mutate_self=True) + + assert isinstance(discretized_func, Function) + assert isinstance(func, Function) + assert discretized_func.source.shape == (samples, 2) + assert func.source.shape == (samples, 2) + + +@pytest.mark.parametrize("samples", [2, 50, 1000]) +def test_set_discrete_non_mutator(samples): + """Tests the set_discrete method of the Function class. + The mutator argument is set to False. + """ + func = Function(lambda x: x**3) + discretized_func = func.set_discrete(-10, 10, samples, mutate_self=False) + + assert isinstance(discretized_func, Function) + assert isinstance(func, Function) + assert discretized_func.source.shape == (samples, 2) + assert callable(func.source) + + +def test_set_discrete_based_on_model_mutator(linear_func): + """Tests the set_discrete_based_on_model method of the Function class. + The mutator argument is set to True. + """ + func = Function(lambda x: x**3) + discretized_func = func.set_discrete_based_on_model(linear_func, mutate_self=True) + + assert isinstance(discretized_func, Function) + assert isinstance(func, Function) + assert discretized_func.source.shape == (4, 2) + assert func.source.shape == (4, 2) + + +def test_set_discrete_based_on_model_non_mutator(linear_func): + """Tests the set_discrete_based_on_model method of the Function class. + The mutator argument is set to False. + """ + func = Function(lambda x: x**3) + discretized_func = func.set_discrete_based_on_model(linear_func, mutate_self=False) + + assert isinstance(discretized_func, Function) + assert isinstance(func, Function) + assert discretized_func.source.shape == (4, 2) + assert callable(func.source)