Skip to content

Commit

Permalink
ENH: Function inputs from CSV file header.
Browse files Browse the repository at this point in the history
  • Loading branch information
phmbressan committed Jan 27, 2024
1 parent 2fee77e commit 058ab49
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
19 changes: 13 additions & 6 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def __init__(
(II) Fields in CSV files may be enclosed in double quotes. If fields are
not quoted, double quotes should not appear inside them.
"""
# Set input and output
if inputs is None:
inputs = ["Scalar"]
if outputs is None:
Expand Down Expand Up @@ -3018,7 +3017,7 @@ def _check_user_input(
if isinstance(inputs, str):
inputs = [inputs]

elif len(outputs) > 1:
if len(outputs) > 1:
raise ValueError(
"Output must either be a string or have dimension 1, "
+ f"it currently has dimension ({len(outputs)})."
Expand All @@ -3035,8 +3034,16 @@ def _check_user_input(
try:
source = np.loadtxt(source, delimiter=",", dtype=float)
except ValueError:
# Skip header
source = np.loadtxt(source, delimiter=",", dtype=float, skiprows=1)
with open(source, "r") as file:
header, *data = file.read().splitlines()
header = [
label.strip("'").strip('"') for label in header.split(",")
]
if inputs == ["Scalar"]:
inputs = header[:-1]
if outputs == ["Scalar"]:
outputs = [header[-1]]
source = np.loadtxt(data, delimiter=",", dtype=float)
except Exception as e:
raise ValueError(
"The source file is not a valid csv or txt file."
Expand All @@ -3054,7 +3061,7 @@ def _check_user_input(

## single dimension
if source_dim == 2:
# possible interpolation values: llinear, polynomial, akima and spline
# possible interpolation values: linear, polynomial, akima and spline
if interpolation is None:
interpolation = "spline"
elif interpolation.lower() not in [
Expand Down Expand Up @@ -3105,7 +3112,7 @@ def _check_user_input(
in_out_dim = len(inputs) + len(outputs)
if source_dim != in_out_dim:
raise ValueError(
"Source dimension ({source_dim}) does not match input "
f"Source dimension ({source_dim}) does not match input "
+ f"and output dimension ({in_out_dim})."
)
return inputs, outputs, interpolation, extrapolation
Expand Down
2 changes: 1 addition & 1 deletion tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_func_from_csv_with_header(csv_file):
line. It tests cases where the fields are separated by quotes and without
quotes."""
f = Function(csv_file)
assert f.__repr__() == "'Function from R1 to R1 : (Scalar) → (Scalar)'"
assert f.__repr__() == "'Function from R1 to R1 : (time) → (value)'"
assert np.isclose(f(0), 100)
assert np.isclose(f(0) + f(1), 300), "Error summing the values of the function"

Expand Down

0 comments on commit 058ab49

Please sign in to comment.