Skip to content

Commit

Permalink
add registry
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Oct 17, 2024
1 parent c3f721d commit 6b30f99
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 26 deletions.
3 changes: 1 addition & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,12 @@ repos:
rev: v0.6.4
hooks:
- id: ruff
# Next line if for documenation cod snippets
exclude: '.*/[^_].*_\.py$'
args:
- --line-length=120
- --fix
- --exit-non-zero-on-fix
- --preview
- --exclude=docs/**/*_.py
- repo: https://github.com/sphinx-contrib/sphinx-lint
rev: v1.0.0
hooks:
Expand Down
28 changes: 28 additions & 0 deletions dev/dev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from earthkit.data import from_source

from anemoi.transform.filters import filter_factory

################
data = from_source(
"mars",
param=["u", "v", "t", "q"],
grid=[1, 1],
date="20200101/to/20200105",
levelist=[1000, 850, 500],
)
for f in data:
print(f)

################

uv_2_ddff = filter_factory("uv_2_ddff")

data = uv_2_ddff.forward(data)
for f in data:
print(f)


ddff_2_uv = filter_factory("ddff_2_uv")
data = ddff_2_uv.forward(data)
for f in data:
print(f)
71 changes: 65 additions & 6 deletions src/anemoi/transform/filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
# nor does it submit to any jurisdiction.


import importlib
import logging
import os
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
from typing import Any

import earthkit.data as ekd
import numpy as np
from earthkit.data import FieldList
from earthkit.meteo.wind.array import polar_to_xy
from earthkit.meteo.wind.array import xy_to_polar
import entrypoints

LOG = logging.getLogger(__name__)


class Filter(ABC):
Expand All @@ -34,6 +35,10 @@ def backward(self, x: ekd.FieldList) -> ekd.FieldList:
def reverse(self) -> "Filter":
return ReversedFilter(self)

@classmethod
def reversed(cls, *args, **kwargs):
return ReversedFilter(cls(*args, **kwargs))


class ReversedFilter(Filter):
"""Swap the forward and backward methods of a filter."""
Expand Down Expand Up @@ -89,7 +94,7 @@ def _transform(self, data, transform, *group_by):
def new_field_from_numpy(self, array, *, template, param):
"""Create a new field from a numpy array."""
md = template.metadata().override(shortName=param)
return FieldList.from_array(array, md)[0]
return ekd.ArrayField(array, md)

def new_fieldlits_from_list(self, fields):
from earthkit.data.indexing.fieldlist import FieldArray
Expand All @@ -105,3 +110,57 @@ def forward_transform(self, *fields):
def backward_transform(self, *fields):
"""To be implemented by subclasses."""
pass


FILTERS = {}


def register_filter(name, klass):
FILTERS[name] = klass


def _load(file):
name, _ = os.path.splitext(file)
try:
# The module is expected to register the filter
# with the register_filter function
importlib.import_module(f".{name}", package=__name__)
except Exception:
LOG.warning(f"Error loading filter '{name}'", exc_info=True)


def filter_registry(name) -> Filter:
if name in FILTERS:
return FILTERS[name]

for entry_point in entrypoints.get_group_all("anemoi.filters"):
if entry_point.name == name:
FILTERS[name] = entry_point.load()
return FILTERS[name]

here = os.path.dirname(__file__)
for file in os.listdir(here):
print(file)
if file[0] == ".":
continue

if file == "__init__.py":
continue

full = os.path.join(here, file)
if os.path.isdir(full):
if os.path.exists(os.path.join(full, "__init__.py")):
_load(file)
continue

if file.endswith(".py"):
_load(file)

if name not in FILTERS:
raise ValueError(f"Unknown filter '{name}'")

return FILTERS[name]


def filter_factory(name, *args, **kwargs) -> Filter:
return filter_registry(name)(*args, **kwargs)
21 changes: 3 additions & 18 deletions src/anemoi/transform/filters/uv_to_ddff.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from earthkit.meteo.wind.array import xy_to_polar

from anemoi.transform.filters import TransformFilter
from anemoi.transform.filters import register_filter


class WindComponents(TransformFilter):
Expand Down Expand Up @@ -83,21 +84,5 @@ def backward_transform(self, speed, direction):
yield self.new_field_from_numpy(v, template=direction, param=self.v_component)


if __name__ == "__main__":
from earthkit.data import from_source

################
data = from_source(
"mars",
param=["u", "v", "t", "q"],
grid=[1, 1],
date="20200101/to/20200105",
levelist=[1000, 850, 500],
)
for f in data:
print(f)

################
data = WindComponents().forward(data)
for f in data:
print(f)
register_filter("uv_2_ddff", WindComponents)
register_filter("ddff_2_uv", WindComponents.reversed)

0 comments on commit 6b30f99

Please sign in to comment.