Refactor filenames handling in Field.from_netcdf #1787
base: main
Changes from all commits
8a9b2c0
9b6720d
4e2710f
4aedbdc
7470597
f4adcb5
c795ba4
dc6d4f8
a6b46d4
8a7de2a
ed60b4b
132183e
d8713c8
9013062
508006d
@@ -1,8 +1,10 @@
-import collections
+from __future__ import annotations

 import math
 import warnings
+from collections.abc import Iterable
 from ctypes import POINTER, Structure, c_float, c_int, pointer
 from glob import glob
 from pathlib import Path
 from typing import TYPE_CHECKING, Literal

@@ -50,6 +52,9 @@

     from parcels.fieldset import FieldSet

+T_Dimensions = Literal["lon", "lat", "depth", "data"]
+T_SanitizedFilenames = list[str] | dict[T_Dimensions, list[str]]
+
 __all__ = ["Field", "NestedField", "VectorField"]

@@ -426,22 +431,8 @@

     @classmethod
     @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
-    def get_dim_filenames(cls, *args, **kwargs):
-        return cls._get_dim_filenames(*args, **kwargs)
-
-    @classmethod
-    def _get_dim_filenames(cls, filenames, dim):
-        if isinstance(filenames, str) or not isinstance(filenames, collections.abc.Iterable):
-            return [filenames]
-        elif isinstance(filenames, dict):
-            assert dim in filenames.keys(), "filename dimension keys must be lon, lat, depth or data"
-            filename = filenames[dim]
-            if isinstance(filename, str):
-                return [filename]
-            else:
-                return filename
-        else:
-            return filenames
+    def get_dim_filenames(*args, **kwargs):
+        return _get_dim_filenames(*args, **kwargs)

     @staticmethod
     @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
@@ -498,7 +489,7 @@
         time_periodic: TimePeriodic = False,
         deferred_load: bool = True,
         **kwargs,
-    ) -> "Field":
+    ) -> Field:
         """Create field from netCDF file.

         Parameters
@@ -558,6 +549,8 @@
         * `Timestamps <../examples/tutorial_timestamps.ipynb>`__

         """
+        filenames = _sanitize_field_filenames(filenames)
+
         if kwargs.get("netcdf_decodewarning") is not None:
             _deprecated_param_netcdf_decodewarning()
             kwargs.pop("netcdf_decodewarning")
@@ -598,20 +591,20 @@
             len(variable) == 2
         ), "The variable tuple must have length 2. Use FieldSet.from_netcdf() for multiple variables"

-        data_filenames = cls._get_dim_filenames(filenames, "data")
-        lonlat_filename = cls._get_dim_filenames(filenames, "lon")
+        data_filenames = _get_dim_filenames(filenames, "data")
+        lonlat_filename_lst = _get_dim_filenames(filenames, "lon")
         if isinstance(filenames, dict):
-            assert len(lonlat_filename) == 1
-            if lonlat_filename != cls._get_dim_filenames(filenames, "lat"):
+            assert len(lonlat_filename_lst) == 1
+            if lonlat_filename_lst != _get_dim_filenames(filenames, "lat"):
                 raise NotImplementedError(
                     "longitude and latitude dimensions are currently processed together from one single file"
                 )
-        lonlat_filename = lonlat_filename[0]
+        lonlat_filename = lonlat_filename_lst[0]
         if "depth" in dimensions:
-            depth_filename = cls._get_dim_filenames(filenames, "depth")
-            if isinstance(filenames, dict) and len(depth_filename) != 1:
+            depth_filename_lst = _get_dim_filenames(filenames, "depth")
+            if isinstance(filenames, dict) and len(depth_filename_lst) != 1:
                 raise NotImplementedError("Vertically adaptive meshes not implemented for from_netcdf()")
-            depth_filename = depth_filename[0]
+            depth_filename = depth_filename_lst[0]

         netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4")
         gridindexingtype = kwargs.get("gridindexingtype", "nemo")
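Aside (not part of the diff): after the `_sanitize_field_filenames` call added above, the `filenames` reaching `_get_dim_filenames` are always either a `list[str]` or a dict of `list[str]`, which is why the new module-level helper further down no longer needs the string/`Iterable` special cases of the old classmethod. A minimal sketch of the resulting types, with invented file names and assuming the helpers behave as written in this PR:

```python
# Sketch only: mirrors the calls made inside Field.from_netcdf in this PR (file names invented).
from parcels.field import _get_dim_filenames, _sanitize_field_filenames

filenames = {"lon": "mask.nc", "lat": "mask.nc", "data": ["U_2000.nc", "U_1999.nc"]}

filenames = _sanitize_field_filenames(filenames)
# {"lon": ["mask.nc"], "lat": ["mask.nc"], "data": ["U_1999.nc", "U_2000.nc"]}  <- values sorted

data_filenames = _get_dim_filenames(filenames, "data")      # list[str]
lonlat_filename_lst = _get_dim_filenames(filenames, "lon")  # list[str] of length 1 for dict input
lonlat_filename = lonlat_filename_lst[0]                    # str
```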
@@ -2584,3 +2577,67 @@
         else:
             pass
         return val
+
+
+def _get_dim_filenames(filenames: T_SanitizedFilenames, dim: T_Dimensions) -> list[str]:
+    """Gets the relevant filenames for a given dimension."""
+    if isinstance(filenames, list):
+        return filenames
+
+    if isinstance(filenames, dict):
+        return filenames[dim]
+
+    raise ValueError("Filenames must be a string, pathlib.Path, or a dictionary")
+
+
+def _sanitize_field_filenames(filenames, *, recursed=False) -> T_SanitizedFilenames:
+    """The Field initializer can take `filenames` to be of various formats including:
+
+    1. a string or Path object. String can be a glob expression.
+    2. a list of (1)
+    3. a dictionary mapping with keys 'lon', 'lat', 'depth', 'data' and values of (1) or (2)
+
+    This function sanitizes the inputs such that it returns, in the case of:
+    1. A sorted list of strings with the expanded glob expression
+    2. A sorted list of strings with the expanded glob expressions

Comment on lines +2601 to +2602: What's the difference between 1 and 2? Only whether there are multiple glob expressions?
Reply: 1 and 2 are the same. Just that in 1 the input was outside of a list. Perhaps it would be easier to illustrate these with examples actually rather than explain it here.

+    3. A dictionary with same keys but values as in (1) or (2).
+
+    See tests for examples.
+    """
+    allowed_dimension_keys = ("lon", "lat", "depth", "data")
+
+    if isinstance(filenames, str) or not isinstance(filenames, Iterable):
+        return sorted(_expand_filename(filenames))
+
+    if isinstance(filenames, list):
+        files = []
+        for f in filenames:
+            files.extend(_expand_filename(f))
+        return sorted(files)

Comment: So does it mean we will always sort a list? What if users provide a non-sorted list by intention?

+
+    if isinstance(filenames, dict):
+        if recursed:
+            raise ValueError("Invalid filenames format. Nested dictionary not allowed in dimension dictionary")
+
+        for key in filenames:
+            if key not in allowed_dimension_keys:
+                raise ValueError(
+                    f"Invalid key in filenames dimension dictionary. Must be one of {allowed_dimension_keys}"
+                )
+            filenames[key] = _sanitize_field_filenames(filenames[key], recursed=True)
+
+        return filenames
+
+    raise ValueError("Filenames must be a string, pathlib.Path, list, or a dictionary")
+
+
+def _expand_filename(filename: str | Path) -> list[str]:
+    """
+    Converts a filename to a list of filenames if it is a glob expression.
+
+    If a file is explicitly provided (i.e., not via glob), existence is only checked later.
+    """
+    filename = str(filename)
+    if "*" in filename:
+        return glob(filename)
+    return [filename]
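The review thread above asks what distinguishes cases 1 and 2 and suggests illustrating them with examples; the new tests below do that, and the following sketch (invented file names, assuming the implementation shown above) spells out the mapping for each accepted format:

```python
# Hedged illustration of _sanitize_field_filenames; file names are invented and, except for
# glob patterns, do not need to exist on disk at this point.
from pathlib import Path

from parcels.field import _sanitize_field_filenames

# 1. a single string or Path (a string may be a glob expression)
_sanitize_field_filenames("U.nc")         # ["U.nc"]
_sanitize_field_filenames(Path("U.nc"))   # ["U.nc"]

# 2. a list of the above; the result is flattened into a sorted list of strings
_sanitize_field_filenames(["b.nc", Path("a.nc")])  # ["a.nc", "b.nc"]

# 3. a dimension dictionary; each value is normalised as in 1/2
_sanitize_field_filenames({"lon": "grid.nc", "lat": "grid.nc", "data": ["b.nc", "a.nc"]})
# {"lon": ["grid.nc"], "lat": ["grid.nc"], "data": ["a.nc", "b.nc"]}
```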
@@ -0,0 +1,142 @@
from pathlib import Path

import cftime
import numpy as np
import pytest
import xarray as xr

from parcels import Field
from parcels.field import _expand_filename, _sanitize_field_filenames
from parcels.tools.converters import (
    _get_cftime_calendars,
    _get_cftime_datetimes,
)
from tests.utils import TEST_DATA


def test_field_from_netcdf_variables():
    filename = str(TEST_DATA / "perlinfieldsU.nc")
    dims = {"lon": "x", "lat": "y"}

    variable = "vozocrtx"
    f1 = Field.from_netcdf(filename, variable, dims)
    variable = ("U", "vozocrtx")
    f2 = Field.from_netcdf(filename, variable, dims)
    variable = {"U": "vozocrtx"}
    f3 = Field.from_netcdf(filename, variable, dims)

    assert np.allclose(f1.data, f2.data, atol=1e-12)
    assert np.allclose(f1.data, f3.data, atol=1e-12)

    with pytest.raises(AssertionError):
        variable = {"U": "vozocrtx", "nav_lat": "nav_lat"}  # multiple variables will fail
        f3 = Field.from_netcdf(filename, variable, dims)


@pytest.mark.parametrize("with_timestamps", [True, False])
def test_field_from_netcdf(with_timestamps):
    filenames = {
        "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"),
        "lat": str(TEST_DATA / "mask_nemo_cross_180lon.nc"),
        "data": str(TEST_DATA / "Uu_eastward_nemo_cross_180lon.nc"),
    }
    variable = "U"
    dimensions = {"lon": "glamf", "lat": "gphif"}
    if with_timestamps:
        timestamp_types = [[[2]], [[np.datetime64("2000-01-01")]]]
        for timestamps in timestamp_types:
            Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity", timestamps=timestamps)
    else:
        Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity")


@pytest.mark.parametrize(
    "f",
    [
        pytest.param(lambda x: x, id="Path"),
        pytest.param(lambda x: str(x), id="str"),
    ],
)
def test_from_netcdf_path_object(f):
    filenames = {
        "lon": f(TEST_DATA / "mask_nemo_cross_180lon.nc"),
        "lat": f(TEST_DATA / "mask_nemo_cross_180lon.nc"),
        "data": f(TEST_DATA / "Uu_eastward_nemo_cross_180lon.nc"),
    }
    variable = "U"
    dimensions = {"lon": "glamf", "lat": "gphif"}

    Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity")


@pytest.mark.parametrize(
    "calendar, cftime_datetime", zip(_get_cftime_calendars(), _get_cftime_datetimes(), strict=True)
)
def test_field_nonstandardtime(calendar, cftime_datetime, tmpdir):
    xdim = 4
    ydim = 6
    filepath = tmpdir.join("test_nonstandardtime.nc")
    dates = [getattr(cftime, cftime_datetime)(1, m, 1) for m in range(1, 13)]
    da = xr.DataArray(
        np.random.rand(12, xdim, ydim), coords=[dates, range(xdim), range(ydim)], dims=["time", "lon", "lat"], name="U"
    )
    da.to_netcdf(str(filepath))

    dims = {"lon": "lon", "lat": "lat", "time": "time"}
    try:
        field = Field.from_netcdf(filepath, "U", dims)
    except NotImplementedError:
        field = None

    if field is not None:
        assert field.grid.time_origin.calendar == calendar


@pytest.mark.parametrize(
    "input_,expected",
    [
        pytest.param("file1.nc", ["file1.nc"], id="str"),
        pytest.param(["file1.nc", "file2.nc"], ["file1.nc", "file2.nc"], id="list"),
        pytest.param(["file2.nc", "file1.nc"], ["file1.nc", "file2.nc"], id="list-unsorted"),
        pytest.param([Path("file1.nc"), Path("file2.nc")], ["file1.nc", "file2.nc"], id="list-Path"),
        pytest.param(
            {
                "lon": "lon_file.nc",
                "lat": ["lat_file1.nc", Path("lat_file2.nc")],
                "depth": Path("depth_file.nc"),
                "data": ["data_file1.nc", "data_file2.nc"],
            },
            {
                "lon": ["lon_file.nc"],
                "lat": ["lat_file1.nc", "lat_file2.nc"],
                "depth": ["depth_file.nc"],
                "data": ["data_file1.nc", "data_file2.nc"],
            },
            id="dict-mix",
        ),
    ],
)
def test_sanitize_field_filenames_cases(input_, expected):
    assert _sanitize_field_filenames(input_) == expected


@pytest.mark.parametrize(
    "input_,expected",
    [
        pytest.param("file*.nc", [], id="glob-no-match"),
    ],
)
def test_sanitize_field_filenames_glob(input_, expected):
    assert _sanitize_field_filenames(input_) == expected


@pytest.mark.parametrize(
    "input_,expected",
    [
        pytest.param("test", ["test"], id="str"),
        pytest.param(Path("test"), ["test"], id="Path"),
        pytest.param("file*.nc", [], id="glob-no-match"),
    ],
)
def test_expand_filename(input_, expected):
    assert _expand_filename(input_) == expected
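The glob cases above only cover a pattern with no matches. A possible companion test (hypothetical, not part of this PR; it relies on pytest's `tmp_path` fixture and the imports already present in this module) that would also pin down the expansion-plus-sorting behaviour for a matching pattern:

```python
# Hypothetical extra test: a glob pattern that matches files created on the fly.
def test_sanitize_field_filenames_glob_match(tmp_path):
    for name in ("field_02.nc", "field_01.nc"):
        (tmp_path / name).touch()

    result = _sanitize_field_filenames(str(tmp_path / "field_*.nc"))

    # glob() returns matches in arbitrary order, so the sorted output is what callers rely on
    assert result == [str(tmp_path / "field_01.nc"), str(tmp_path / "field_02.nc")]
```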
Comment: the `*_lst` additions are just to make `mypy` happy.
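For context on that remark, a sketch (not code from the PR, and the exact mypy wording may differ): reusing one name for both the list and its unpacked element gives the variable two incompatible types, which mypy flags on the reassignment, whereas the `*_lst` spelling keeps every name at a single type:

```python
# Sketch of the typing issue the *_lst rename avoids (function names are illustrative).
from parcels.field import _get_dim_filenames


def before(filenames) -> None:
    lonlat_filename = _get_dim_filenames(filenames, "lon")  # inferred as list[str]
    lonlat_filename = lonlat_filename[0]  # mypy: incompatible types in assignment (str vs list[str])


def after(filenames) -> None:
    lonlat_filename_lst = _get_dim_filenames(filenames, "lon")  # stays list[str]
    lonlat_filename = lonlat_filename_lst[0]  # str, bound to a separate name
```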