Skip to content

Commit

Permalink
move code from datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Nov 14, 2024
1 parent 13883f5 commit a124fc4
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 11 deletions.
106 changes: 106 additions & 0 deletions src/anemoi/transform/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging

from earthkit.data.indexing.fieldlist import FieldArray

LOG = logging.getLogger(__name__)


def new_fieldlist_from_list(fields):
return FieldArray(fields)


def new_empty_fieldlist():
return FieldArray([])


class WrappedField:
"""A wrapper around a earthkit-data field object."""

def __init__(self, field):
self._field = field

def __getattr__(self, name):
if name not in (
"mars_area",
"mars_grid",
"to_numpy",
"metadata",
):
LOG.warning(f"NewField: forwarding `{name}`")
return getattr(self._field, name)

def __repr__(self) -> str:
return repr(self._field)


class NewDataField(WrappedField):
"""Change the data of a field."""

def __init__(self, field, data):
super().__init__(field)
self._data = data
self.shape = data.shape

def to_numpy(self, flatten=False, dtype=None, index=None):
data = self._data
if dtype is not None:
data = data.astype(dtype)
if flatten:
data = data.flatten()
if index is not None:
data = data[index]
return data


class NewMetadataField(WrappedField):
"""Change the metadata of a field."""

def __init__(self, field, **kwargs):
super().__init__(field)
self._metadata = kwargs

def metadata(self, *args, **kwargs):

if kwargs.get("namespace"):
assert kwargs.get("namespace") == "mars", kwargs
assert len(args) == 0, (args, kwargs)
mars = self._field.metadata(**kwargs).copy()
for k in list(mars.keys()):
if k in self._metadata:
mars[k] = self._metadata[k]
return mars

if len(args) == 1 and args[0] in self._metadata:
return self._metadata[args[0]]

return self._field.metadata(*args, **kwargs)


class NewValidDateTimeField(NewMetadataField):
"""Change the valid_datetime of a field."""

def __init__(self, field, valid_datetime):
date = int(valid_datetime.date().strftime("%Y%m%d"))
assert valid_datetime.time().minute == 0, valid_datetime.time()
time = valid_datetime.time().hour

self.valid_datetime = valid_datetime

super().__init__(field, date=date, time=time, step=0, valid_datetime=valid_datetime.isoformat())


def new_field_from_numpy(array, *, template, **metadata):
return NewMetadataField(NewDataField(template, array), **metadata)


def new_field_with_valid_datetime(template, date):
return NewValidDateTimeField(template, date)
18 changes: 7 additions & 11 deletions src/anemoi/transform/filters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
# nor does it submit to any jurisdiction.


import logging
from abc import abstractmethod

import earthkit.data as ekd

from ..fields import new_field_from_numpy
from ..fields import new_fieldlist_from_list
from ..filter import Filter
from ..grouping import GroupByMarsParam

LOG = logging.getLogger(__name__)


class SimpleFilter(Filter):
"""A filter to convert only some fields.
Expand All @@ -35,17 +38,10 @@ def _transform(self, data, transform, *group_by):

def new_field_from_numpy(self, array, *, template, param):
"""Create a new field from a numpy array."""
if isinstance(param, int):
md = template.metadata().override(paramId=param)
else:
md = template.metadata().override(shortName=param)
# return ekd.ArrayField(array, md)
return ekd.FieldList.from_array(array, md)[0]
return new_field_from_numpy(array, template=template, param=param)

def new_fieldlist_from_list(self, fields):
from earthkit.data.indexing.fieldlist import FieldArray

return FieldArray(fields)
return new_fieldlist_from_list(fields)

@abstractmethod
def forward_transform(self, *fields):
Expand Down
2 changes: 2 additions & 0 deletions src/anemoi/transform/filters/land_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ class LandParameters(SimpleFilter):
def __init__(
self,
*,
# Input parameters
high_veg_type="tvh",
low_veg_type="tvl",
soil_type="slt",
# Output parameters
hveg_rsmin="hveg_rsmin",
hveg_cov="hveg_cov",
hveg_z0m="hveg_z0m",
Expand Down
2 changes: 2 additions & 0 deletions src/anemoi/transform/grouping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def iterate(self, data, *, other=_lost):

for _, group in groups.items():
if len(group) != len(self.params):
for p in data:
print(p)
raise ValueError(f"Missing component. Want {sorted(self.params)}, got {sorted(group.keys())}")

yield tuple(group[p] for p in self.params)

0 comments on commit a124fc4

Please sign in to comment.