diff --git a/tests/conftest.py b/tests/conftest.py index 0b72cd03d..d11652540 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1166,3 +1166,18 @@ def func_2d_from_csv(): source="tests/fixtures/function/2d.csv", ) return func + + +@pytest.fixture +def lambda_quad_func(): + """Create a lambda function based on a string. + + Returns + ------- + Function + A lambda function based on a string. + """ + func = lambda x: x**2 + return Function( + source=func, + ) diff --git a/tests/unit/test_function.py b/tests/unit/test_function.py index e41d30f04..de8a1f4ac 100644 --- a/tests/unit/test_function.py +++ b/tests/unit/test_function.py @@ -1,3 +1,9 @@ +"""Unit tests for the Function class. Each method in tis module tests an +individual method of the Function class. The tests are made on both the +expected behaviour and the return instances.""" + +import os + import numpy as np import pytest @@ -158,3 +164,63 @@ def test_get_value_opt(x, y, z): func = Function(source, interpolation="shepard", extrapolation="natural") assert isinstance(func.get_value_opt(x, y), float) assert np.isclose(func.get_value_opt(x, y), z, atol=1e-6) + + +@pytest.mark.parametrize( + "func", + [ + "linearly_interpolated_func", + "spline_interpolated_func", + "func_2d_from_csv", + "lambda_quad_func", + ], +) +def test_savetxt(request, func): + """Test the savetxt method of various Function objects. + + This test function verifies that the `savetxt` method correctly writes the + function's data to a CSV file and that a new function object created from + this file has the same data as the original function object. + + Notes + ----- + The test performs the following steps: + 1. It invokes the `savetxt` method of the given function object. + 2. It then reads this file to create a new function object. + 3. The test asserts that the data of the new function matches the original. + 4. Finally, the test cleans up by removing the created CSV file. + + Raises + ------ + AssertionError + If the `savetxt` method fails to save the file, or if the data of the + newly read function does not match the data of the original function. + """ + func = request.getfixturevalue(func) + assert ( + func.savetxt( + filename="test_func.csv", + lower=0, + upper=9, + samples=10, + fmt="%.6f", + delimiter=",", + newline="\n", + encoding=None, + ) + is None + ), "Couldn't save the file using the Function.savetxt method." + + read_func = Function( + "test_func.csv", interpolation="linear", extrapolation="natural" + ) + if callable(func.source): + source = np.column_stack( + (np.linspace(0, 9, 10), func.source(np.linspace(0, 9, 10))) + ) + assert np.allclose(source, read_func.source) + else: + assert np.allclose(func.source, read_func.source) + + # clean up the file + os.remove("test_func.csv")