diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 7981bbbfa..2ae7e077f 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -98,7 +98,6 @@ def __init__( (II) Fields in CSV files may be enclosed in double quotes. If fields are not quoted, double quotes should not appear inside them. """ - # Set input and output if inputs is None: inputs = ["Scalar"] if outputs is None: @@ -3018,7 +3017,7 @@ def _check_user_input( if isinstance(inputs, str): inputs = [inputs] - elif len(outputs) > 1: + if len(outputs) > 1: raise ValueError( "Output must either be a string or have dimension 1, " + f"it currently has dimension ({len(outputs)})." @@ -3035,8 +3034,16 @@ def _check_user_input( try: source = np.loadtxt(source, delimiter=",", dtype=float) except ValueError: - # Skip header - source = np.loadtxt(source, delimiter=",", dtype=float, skiprows=1) + with open(source, "r") as file: + header, *data = file.read().splitlines() + header = [ + label.strip("'").strip('"') for label in header.split(",") + ] + if inputs == ["Scalar"]: + inputs = header[:-1] + if outputs == ["Scalar"]: + outputs = [header[-1]] + source = np.loadtxt(data, delimiter=",", dtype=float) except Exception as e: raise ValueError( "The source file is not a valid csv or txt file." @@ -3054,7 +3061,7 @@ def _check_user_input( ## single dimension if source_dim == 2: - # possible interpolation values: llinear, polynomial, akima and spline + # possible interpolation values: linear, polynomial, akima and spline if interpolation is None: interpolation = "spline" elif interpolation.lower() not in [ @@ -3105,7 +3112,7 @@ def _check_user_input( in_out_dim = len(inputs) + len(outputs) if source_dim != in_out_dim: raise ValueError( - "Source dimension ({source_dim}) does not match input " + f"Source dimension ({source_dim}) does not match input " + f"and output dimension ({in_out_dim})." ) return inputs, outputs, interpolation, extrapolation diff --git a/tests/test_function.py b/tests/test_function.py index 6896b01af..2ce94f691 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -49,7 +49,7 @@ def test_func_from_csv_with_header(csv_file): line. It tests cases where the fields are separated by quotes and without quotes.""" f = Function(csv_file) - assert f.__repr__() == "'Function from R1 to R1 : (Scalar) → (Scalar)'" + assert f.__repr__() == "'Function from R1 to R1 : (time) → (value)'" assert np.isclose(f(0), 100) assert np.isclose(f(0) + f(1), 300), "Error summing the values of the function"