diff --git a/CHANGELOG.md b/CHANGELOG.md index 771e8ca36..602739494 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ straightforward as possible. ### Added +- ENH: Function Support for CSV Header Inputs [#542](https://github.com/RocketPy-Team/RocketPy/pull/542) - ENH: Shepard Optimized Interpolation - Multiple Inputs Support [#515](https://github.com/RocketPy-Team/RocketPy/pull/515) - ENH: adds new Function.savetxt method [#514](https://github.com/RocketPy-Team/RocketPy/pull/514) - ENH: Argument for Optional Mutation on Function Discretize [#519](https://github.com/RocketPy-Team/RocketPy/pull/519) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index c899f3cad..048b125e2 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -59,7 +59,8 @@ def __init__( and 'z' is the output. - string: Path to a CSV file. The file is read and converted into an - ndarray. The file can optionally contain a single header line. + ndarray. The file can optionally contain a single header line, see + notes below for more information. - Function: Copies the source of the provided Function object, creating a new Function with adjusted inputs and outputs. @@ -94,12 +95,19 @@ def __init__( Notes ----- - (I) CSV files can optionally contain a single header line. If present, - the header is ignored during processing. - (II) Fields in CSV files may be enclosed in double quotes. If fields are - not quoted, double quotes should not appear inside them. + (I) CSV files may include an optional single header line. If this + header line is present and contains names for each data column, those + names will be used to label the inputs and outputs unless specified + otherwise by the `inputs` and `outputs` arguments. + If the header is specified for only a few columns, it is ignored. + + Commas in a header will be interpreted as a delimiter, which may cause + undesired input or output labeling. To avoid this, specify each input + and output name using the `inputs` and `outputs` arguments. + + (II) Fields in CSV files may be enclosed in double quotes. If fields + are not quoted, double quotes should not appear inside them. """ - # Set input and output if inputs is None: inputs = ["Scalar"] if outputs is None: @@ -184,10 +192,18 @@ def set_source(self, source): Notes ----- - (I) CSV files can optionally contain a single header line. If present, - the header is ignored during processing. - (II) Fields in CSV files may be enclosed in double quotes. If fields are - not quoted, double quotes should not appear inside them. + (I) CSV files may include an optional single header line. If this + header line is present and contains names for each data column, those + names will be used to label the inputs and outputs unless specified + otherwise. If the header is specified for only a few columns, it is + ignored. + + Commas in a header will be interpreted as a delimiter, which may cause + undesired input or output labeling. To avoid this, specify each input + and output name using the `inputs` and `outputs` arguments. + + (II) Fields in CSV files may be enclosed in double quotes. If fields + are not quoted, double quotes should not appear inside them. Returns ------- @@ -3019,7 +3035,7 @@ def _check_user_input( if isinstance(inputs, str): inputs = [inputs] - elif len(outputs) > 1: + if len(outputs) > 1: raise ValueError( "Output must either be a string or have dimension 1, " + f"it currently has dimension ({len(outputs)})." @@ -3036,8 +3052,19 @@ def _check_user_input( try: source = np.loadtxt(source, delimiter=",", dtype=float) except ValueError: - # Skip header - source = np.loadtxt(source, delimiter=",", dtype=float, skiprows=1) + with open(source, "r") as file: + header, *data = file.read().splitlines() + + header = [ + label.strip("'").strip('"') for label in header.split(",") + ] + source = np.loadtxt(data, delimiter=",", dtype=float) + + if len(source[0]) == len(header): + if inputs == ["Scalar"]: + inputs = header[:-1] + if outputs == ["Scalar"]: + outputs = [header[-1]] except Exception as e: raise ValueError( "The source file is not a valid csv or txt file." @@ -3055,7 +3082,7 @@ def _check_user_input( ## single dimension if source_dim == 2: - # possible interpolation values: llinear, polynomial, akima and spline + # possible interpolation values: linear, polynomial, akima and spline if interpolation is None: interpolation = "spline" elif interpolation.lower() not in [ @@ -3106,7 +3133,7 @@ def _check_user_input( in_out_dim = len(inputs) + len(outputs) if source_dim != in_out_dim: raise ValueError( - "Source dimension ({source_dim}) does not match input " + f"Source dimension ({source_dim}) does not match input " + f"and output dimension ({in_out_dim})." ) return inputs, outputs, interpolation, extrapolation diff --git a/tests/test_function.py b/tests/test_function.py index 6896b01af..2ce94f691 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -49,7 +49,7 @@ def test_func_from_csv_with_header(csv_file): line. It tests cases where the fields are separated by quotes and without quotes.""" f = Function(csv_file) - assert f.__repr__() == "'Function from R1 to R1 : (Scalar) → (Scalar)'" + assert f.__repr__() == "'Function from R1 to R1 : (time) → (value)'" assert np.isclose(f(0), 100) assert np.isclose(f(0) + f(1), 300), "Error summing the values of the function"