Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor filenames handling in Field.from_netcdf #1787

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
109 changes: 83 additions & 26 deletions parcels/field.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import collections
from __future__ import annotations

import math
import warnings
from collections.abc import Iterable
from ctypes import POINTER, Structure, c_float, c_int, pointer
from glob import glob
from pathlib import Path
from typing import TYPE_CHECKING, Literal

Expand Down Expand Up @@ -50,6 +52,9 @@

from parcels.fieldset import FieldSet

# Dimension keys accepted in a `filenames` dict passed to Field.from_netcdf.
T_Dimensions = Literal["lon", "lat", "depth", "data"]
# Normalized form produced by _sanitize_field_filenames: either a flat list of
# file paths used for all dimensions, or a per-dimension mapping of path lists.
T_SanitizedFilenames = list[str] | dict[T_Dimensions, list[str]]

__all__ = ["Field", "NestedField", "VectorField"]


Expand Down Expand Up @@ -426,22 +431,8 @@

@classmethod
@deprecated_made_private  # TODO: Remove 6 months after v3.1.0
def get_dim_filenames(cls, *args, **kwargs):
    """Deprecated public shim; forwards to the module-level ``_get_dim_filenames``.

    ``cls`` must be consumed explicitly here: because of ``@classmethod``, a
    bare ``*args`` signature would capture the class object as ``args[0]`` and
    forward it as the ``filenames`` argument, breaking every caller of the
    deprecated API.
    """
    return _get_dim_filenames(*args, **kwargs)

Check warning on line 435 in parcels/field.py

View check run for this annotation

Codecov / codecov/patch

parcels/field.py#L435

Added line #L435 was not covered by tests

@staticmethod
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
Expand Down Expand Up @@ -498,7 +489,7 @@
time_periodic: TimePeriodic = False,
deferred_load: bool = True,
**kwargs,
) -> "Field":
) -> Field:
"""Create field from netCDF file.

Parameters
Expand Down Expand Up @@ -558,6 +549,8 @@
* `Timestamps <../examples/tutorial_timestamps.ipynb>`__

"""
filenames = _sanitize_field_filenames(filenames)

if kwargs.get("netcdf_decodewarning") is not None:
_deprecated_param_netcdf_decodewarning()
kwargs.pop("netcdf_decodewarning")
Expand Down Expand Up @@ -598,20 +591,20 @@
len(variable) == 2
), "The variable tuple must have length 2. Use FieldSet.from_netcdf() for multiple variables"

data_filenames = cls._get_dim_filenames(filenames, "data")
lonlat_filename = cls._get_dim_filenames(filenames, "lon")
data_filenames = _get_dim_filenames(filenames, "data")
lonlat_filename_lst = _get_dim_filenames(filenames, "lon")
if isinstance(filenames, dict):
assert len(lonlat_filename) == 1
if lonlat_filename != cls._get_dim_filenames(filenames, "lat"):
assert len(lonlat_filename_lst) == 1
if lonlat_filename_lst != _get_dim_filenames(filenames, "lat"):
raise NotImplementedError(
"longitude and latitude dimensions are currently processed together from one single file"
)
lonlat_filename = lonlat_filename[0]
lonlat_filename = lonlat_filename_lst[0]
if "depth" in dimensions:
depth_filename = cls._get_dim_filenames(filenames, "depth")
if isinstance(filenames, dict) and len(depth_filename) != 1:
depth_filename_lst = _get_dim_filenames(filenames, "depth")
if isinstance(filenames, dict) and len(depth_filename_lst) != 1:
raise NotImplementedError("Vertically adaptive meshes not implemented for from_netcdf()")
depth_filename = depth_filename[0]
depth_filename = depth_filename_lst[0]

netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4")
gridindexingtype = kwargs.get("gridindexingtype", "nemo")
Comment on lines +594 to 610
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the *_lst additions are just to make mypy happy

Expand Down Expand Up @@ -2584,3 +2577,67 @@
else:
pass
return val


def _get_dim_filenames(filenames: T_SanitizedFilenames, dim: T_Dimensions) -> list[str]:
"""Get's the relevant filenames for a given dimension."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Get's the relevant filenames for a given dimension."""
"""Get the relevant filenames for a given dimension."""

if isinstance(filenames, list):
return filenames

if isinstance(filenames, dict):
return filenames[dim]

raise ValueError("Filenames must be a string, pathlib.Path, or a dictionary")


def _sanitize_field_filenames(filenames, *, recursed: bool = False) -> T_SanitizedFilenames:
    """Normalize the ``filenames`` argument accepted by ``Field.from_netcdf``.

    Accepted input formats:

    1. a string or ``pathlib.Path`` (strings may be glob expressions)
    2. a list of (1)
    3. a dict with keys 'lon', 'lat', 'depth', 'data' and values of (1) or (2)

    This function sanitizes the input such that it returns, respectively:

    1. a sorted list of strings with the glob expression expanded
    2. a sorted list of strings with all glob expressions expanded
    3. a dict with the same keys and each value normalized as in (1)/(2)

    ``recursed`` marks the internal recursive call used to sanitize dict
    values; it forbids dicts nested inside the dimension dict.

    See tests for examples.
    """
    allowed_dimension_keys = ("lon", "lat", "depth", "data")

    # Case 1: a single string, or a non-iterable such as pathlib.Path.
    if isinstance(filenames, str) or not isinstance(filenames, Iterable):
        return sorted(_expand_filename(filenames))

    # Case 2: a list of strings/Paths, possibly containing glob expressions.
    if isinstance(filenames, list):
        expanded: list[str] = []
        for entry in filenames:
            expanded.extend(_expand_filename(entry))
        return sorted(expanded)

    # Case 3: a dimension dict; validate keys, then sanitize each value.
    if isinstance(filenames, dict):
        if recursed:
            raise ValueError("Invalid filenames format. Nested dictionary not allowed in dimension dictionary")

        for key in filenames:
            if key not in allowed_dimension_keys:
                raise ValueError(
                    f"Invalid key in filenames dimension dictionary. Must be one of {allowed_dimension_keys}"
                )
        # Build a new dict rather than assigning back into `filenames`, so the
        # caller's input dictionary is not mutated as a side effect.
        return {key: _sanitize_field_filenames(value, recursed=True) for key, value in filenames.items()}

    raise ValueError("Filenames must be a string, pathlib.Path, list, or a dictionary")


def _expand_filename(filename: str | Path) -> list[str]:
"""
Converts a filename to a list of filenames if it is a glob expression.

If a file is explicitly provided (i.e., not via glob), existence is only checked later.
"""
filename = str(filename)
if "*" in filename:
return glob(filename)
return [filename]
27 changes: 7 additions & 20 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import sys
import warnings
from copy import deepcopy
from glob import glob

import dask.array as da
import numpy as np
Expand Down Expand Up @@ -348,19 +347,9 @@
@classmethod
@deprecated_made_private  # TODO: Remove 6 months after v3.1.0
def parse_wildcards(cls, *args, **kwargs):
    # Deliberately unimplemented: the internal implementation this deprecated
    # shim used to forward to has been removed.
    msg = "parse_wildcards was removed as a function as the internal implementation was no longer used."
    raise NotImplementedError(msg)

@classmethod
def from_netcdf(
Expand Down Expand Up @@ -474,13 +463,11 @@
if "creation_log" not in kwargs.keys():
kwargs["creation_log"] = "from_netcdf"
for var, name in variables.items():
# Resolve all matching paths for the current variable
paths = filenames[var] if type(filenames) is dict and var in filenames else filenames
if type(paths) is not dict:
paths = cls._parse_wildcards(paths, filenames, var)
paths: list[str]
if isinstance(filenames, dict) and var in filenames:
paths = filenames[var]
else:
for dim, p in paths.items():
paths[dim] = cls._parse_wildcards(p, filenames, var)
paths = filenames

# Use dimensions[var] and indices[var] if either of them is a dict of dicts
dims = dimensions[var] if var in dimensions else dimensions
Expand Down
1 change: 0 additions & 1 deletion tests/test_advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def test_advection_zonal(lon, lat, depth, mode):
}
dimensions = {"lon": lon, "lat": lat}
fieldset2D = FieldSet.from_data(data2D, dimensions, mesh="spherical", transpose=True)
assert fieldset2D.U._creation_log == "from_data"
VeckoTheGecko marked this conversation as resolved.
Show resolved Hide resolved

pset2D = ParticleSet(fieldset2D, pclass=ptype[mode], lon=np.zeros(npart) + 20.0, lat=np.linspace(0, 80, npart))
pset2D.execute(AdvectionRK4, runtime=timedelta(hours=2), dt=timedelta(seconds=30))
Expand Down
4 changes: 2 additions & 2 deletions tests/test_deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_testing_action_class():
Action("Field", "c_data_chunks", "make_private" ),
Action("Field", "chunk_set", "make_private" ),
Action("Field", "cell_edge_sizes", "read_only" ),
Action("Field", "get_dim_filenames()", "make_private" ),
Action("Field", "get_dim_filenames()", "make_private" , skip_reason="Moved underlying function."),
Action("Field", "collect_timeslices()", "make_private" ),
Action("Field", "reshape()", "make_private" ),
Action("Field", "calc_cell_edge_sizes()", "make_private" ),
Expand All @@ -148,7 +148,7 @@ def test_testing_action_class():
Action("FieldSet", "particlefile", "read_only" ),
Action("FieldSet", "add_UVfield()", "make_private" ),
Action("FieldSet", "check_complete()", "make_private" ),
Action("FieldSet", "parse_wildcards()", "make_private" ),
Action("FieldSet", "parse_wildcards()", "make_private" , skip_reason="Moved underlying function."),

# 1713
Action("ParticleSet", "repeat_starttime", "make_private" ),
Expand Down
142 changes: 142 additions & 0 deletions tests/test_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from pathlib import Path

import cftime
import numpy as np
import pytest
import xarray as xr

from parcels import Field
from parcels.field import _expand_filename, _sanitize_field_filenames
from parcels.tools.converters import (
_get_cftime_calendars,
_get_cftime_datetimes,
)
from tests.utils import TEST_DATA


def test_field_from_netcdf_variables():
    filename = str(TEST_DATA / "perlinfieldsU.nc")
    dims = {"lon": "x", "lat": "y"}

    # The variable argument may be a plain name, a (name, netCDF-name) tuple,
    # or a single-entry dict; all three must yield the same field data.
    fields = [
        Field.from_netcdf(filename, "vozocrtx", dims),
        Field.from_netcdf(filename, ("U", "vozocrtx"), dims),
        Field.from_netcdf(filename, {"U": "vozocrtx"}, dims),
    ]
    reference = fields[0]
    for field in fields[1:]:
        assert np.allclose(reference.data, field.data, atol=1e-12)

    # A dict with multiple variables is rejected by Field.from_netcdf.
    with pytest.raises(AssertionError):
        Field.from_netcdf(filename, {"U": "vozocrtx", "nav_lat": "nav_lat"}, dims)


@pytest.mark.parametrize("with_timestamps", [True, False])
def test_field_from_netcdf(with_timestamps):
    mask_file = str(TEST_DATA / "mask_nemo_cross_180lon.nc")
    filenames = {
        "lon": mask_file,
        "lat": mask_file,
        "data": str(TEST_DATA / "Uu_eastward_nemo_cross_180lon.nc"),
    }
    dimensions = {"lon": "glamf", "lat": "gphif"}

    if with_timestamps:
        # Both plain-number and datetime64 timestamps must be accepted.
        for timestamps in ([[2]], [[np.datetime64("2000-01-01")]]):
            Field.from_netcdf(filenames, "U", dimensions, interp_method="cgrid_velocity", timestamps=timestamps)
    else:
        Field.from_netcdf(filenames, "U", dimensions, interp_method="cgrid_velocity")


@pytest.mark.parametrize(
    "f",
    [
        pytest.param(lambda x: x, id="Path"),
        pytest.param(lambda x: str(x), id="str"),
    ],
)
def test_from_netcdf_path_object(f):
    # `f` either leaves the Path untouched or converts it to str; Field must
    # accept both forms in a filenames dict.
    mask = f(TEST_DATA / "mask_nemo_cross_180lon.nc")
    filenames = {
        "lon": mask,
        "lat": mask,
        "data": f(TEST_DATA / "Uu_eastward_nemo_cross_180lon.nc"),
    }
    dimensions = {"lon": "glamf", "lat": "gphif"}

    Field.from_netcdf(filenames, "U", dimensions, interp_method="cgrid_velocity")


@pytest.mark.parametrize(
    "calendar, cftime_datetime", zip(_get_cftime_calendars(), _get_cftime_datetimes(), strict=True)
)
def test_field_nonstandardtime(calendar, cftime_datetime, tmpdir):
    xdim, ydim = 4, 6
    filepath = tmpdir.join("test_nonstandardtime.nc")
    dates = [getattr(cftime, cftime_datetime)(1, month, 1) for month in range(1, 13)]
    da = xr.DataArray(
        np.random.rand(12, xdim, ydim), coords=[dates, range(xdim), range(ydim)], dims=["time", "lon", "lat"], name="U"
    )
    da.to_netcdf(str(filepath))

    dims = {"lon": "lon", "lat": "lat", "time": "time"}
    # Some non-standard calendars may be unsupported; only check the calendar
    # round-trip when field creation succeeds.
    field = None
    try:
        field = Field.from_netcdf(filepath, "U", dims)
    except NotImplementedError:
        pass

    if field is not None:
        assert field.grid.time_origin.calendar == calendar


@pytest.mark.parametrize(
    "input_,expected",
    [
        pytest.param("file1.nc", ["file1.nc"], id="str"),
        pytest.param(["file1.nc", "file2.nc"], ["file1.nc", "file2.nc"], id="list"),
        pytest.param(["file2.nc", "file1.nc"], ["file1.nc", "file2.nc"], id="list-unsorted"),
        pytest.param([Path("file1.nc"), Path("file2.nc")], ["file1.nc", "file2.nc"], id="list-Path"),
        pytest.param(
            {
                "lon": "lon_file.nc",
                "lat": ["lat_file1.nc", Path("lat_file2.nc")],
                "depth": Path("depth_file.nc"),
                "data": ["data_file1.nc", "data_file2.nc"],
            },
            {
                "lon": ["lon_file.nc"],
                "lat": ["lat_file1.nc", "lat_file2.nc"],
                "depth": ["depth_file.nc"],
                "data": ["data_file1.nc", "data_file2.nc"],
            },
            id="dict-mix",
        ),
    ],
)
def test_sanitize_field_filenames_cases(input_, expected):
    # Sanitization must normalize str/Path/list/dict inputs to sorted lists.
    result = _sanitize_field_filenames(input_)
    assert result == expected


@pytest.mark.parametrize(
    "input_,expected",
    [
        pytest.param("file*.nc", [], id="glob-no-match"),
    ],
)
def test_sanitize_field_filenames_glob(input_, expected):
    # A glob expression with no matches sanitizes to an empty list.
    result = _sanitize_field_filenames(input_)
    assert result == expected


@pytest.mark.parametrize(
    "input_,expected",
    [
        pytest.param("test", ["test"], id="str"),
        pytest.param(Path("test"), ["test"], id="Path"),
        pytest.param("file*.nc", [], id="glob-no-match"),
    ],
)
def test_expand_filename(input_, expected):
    # str and Path expand to one-element lists; unmatched globs expand to [].
    result = _expand_filename(input_)
    assert result == expected
Loading
Loading