Skip to content

Commit

Permalink
add hack
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Nov 22, 2024
1 parent 54fb291 commit 62f8322
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/anemoi/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ class Config:
such as `eccodes`. In certain cases, the variables mey be set too late, if the package for which they are intended
is already loaded when the runner is configured."""

development_hacks: dict = {}
"""A dictionary of development hacks to apply to the runner. This is used to test new features or to work around"""


def load_config(path, overrides, Configuration=Configuration):

Expand Down
1 change: 1 addition & 0 deletions src/anemoi/inference/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Context(ABC):
allow_nans = (None,) # can be True of False
use_grib_paramid = False
verbosity = 0
development_hacks = {} # For testing purposes, don't use in production

@property
@abstractmethod
Expand Down
13 changes: 13 additions & 0 deletions src/anemoi/inference/grib/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def grib_keys(
step,
type,
keys,
quiet,
grib1_keys={},
grib2_keys={},
):
Expand All @@ -50,6 +51,18 @@ def grib_keys(
if edition is None and template is not None:
edition = template.metadata("edition")
# centre = template.metadata("centre")
if edition == 2:
productDefinitionTemplateNumber = template.metadata("productDefinitionTemplateNumber")
if productDefinitionTemplateNumber in (8, 11) and not accumulation:
if f"{param}-accumulation" not in quiet:
LOG.warning(
"%s: Template %s is accumulation but `accumulation` was not specified",
param,
productDefinitionTemplateNumber,
)
LOG.warning("%s: Setting `accumulation` to `True`", param)
quiet.add(f"{param}-accumulation")
accumulation = True

if edition is None:
edition = 1
Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/inference/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ def set_private_attributes(self, state, value):
"""
pass

def template(self, variables, dates, **kwargs):
def template(self, variable, date, **kwargs):
"""Used for fetching GRIB templates."""
raise NotImplementedError(f"{self.__class__.__name__}.template() not implemented")
7 changes: 7 additions & 0 deletions src/anemoi/inference/inputs/ekd.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,13 @@ def _name(field, _, original_metadata):

return data

def _find_variable(self, data, name, **kwargs):
def _name(field, _, original_metadata):
return self._namer(field, original_metadata)

data = FieldArray([f.clone(name=_name) for f in data])
return data.sel(name=name, **kwargs)

def _load_forcings(self, fields, variables, dates):
data = self._filter_and_sort(fields, variables=variables, dates=dates, title="Load forcings")
return data.to_numpy(dtype=np.float32, flatten=True).reshape(len(variables), len(dates), -1)
7 changes: 7 additions & 0 deletions src/anemoi/inference/inputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,10 @@ def create_input_state(self, *, date):

def load_forcings(self, *, variables, dates):
return self._load_forcings(ekd.from_source("file", self.path), variables=variables, dates=dates)

def template(self, variable, date, **kwargs):
fields = ekd.from_source("file", self.path)
data = self._find_variable(fields, variable)
if len(data) == 0:
return None
return data[0]
12 changes: 10 additions & 2 deletions src/anemoi/inference/outputs/grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, g
self._template_source = None
self._template_date = None
self._template_reuse = None
self.use_closest_template = False # Off for now

def write_initial_state(self, state):
# We trust the GribInput class to provide the templates
Expand Down Expand Up @@ -73,6 +74,7 @@ def write_initial_state(self, state):
keys=self.encoding,
grib1_keys=self.grib1_keys,
grib2_keys=self.grib2_keys,
quiet=self.quiet,
)

# LOG.info("Step 0 GRIB %s\n%s", template, json.dumps(keys, indent=4))
Expand Down Expand Up @@ -127,6 +129,7 @@ def write_state(self, state):
keys=keys,
grib1_keys=self.grib1_keys,
grib2_keys=self.grib2_keys,
quiet=self.quiet,
)

if LOG.isEnabledFor(logging.DEBUG):
Expand Down Expand Up @@ -161,7 +164,7 @@ def template(self, state, name):
if None in self._template_cache:
return self._template_cache[None]

if False: #
if self.use_closest_template: #
template, name2 = self._clostest_template(self._template_cache, name)

if name not in self.quiet:
Expand All @@ -187,8 +190,13 @@ def template(self, state, name):

self._template_reuse = self.templates.get("reuse", False)

LOG.info("Loading template for %s from %s", name, self._template_source)

date = self._template_date if self._template_date is not None else state["date"]
field = self._template_source.template(variable=name, date=date, edition=self.edition)
field = self._template_source.template(variable=name, date=date)

if field is None:
LOG.warning("No template found for `%s`", name)

if self._template_reuse:
self._template_cache[None] = field
Expand Down
9 changes: 9 additions & 0 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import datetime
import logging
import warnings
from functools import cached_property

import numpy as np
Expand Down Expand Up @@ -61,6 +62,7 @@ def __init__(
use_grib_paramid=False,
verbosity=0,
inference_options=None,
development_hacks={}, # For testing purposes, don't use in production
):
self._checkpoint = Checkpoint(checkpoint)

Expand All @@ -72,6 +74,8 @@ def __init__(
self.verbosity = verbosity
self.allow_nans = allow_nans
self.use_grib_paramid = use_grib_paramid
self.development_hacks = development_hacks
self.hacks = bool(development_hacks)

# This could also be passed as an argument

Expand Down Expand Up @@ -335,6 +339,11 @@ def copy_prognostic_fields_to_input_tensor(self, input_tensor_torch, y_pred, che

def add_dynamic_forcings_to_input_tensor(self, input_tensor_torch, state, date, check):

if self.hacks:
if "dynamic_forcings_date" in self.development_hacks:
date = self.development_hacks["dynamic_forcings_date"]
warnings.warn(f"🧑‍💻 Using `dynamic_forcings_date` hack: {date} 🧑‍💻")

# TODO: check if there were not already loaded as part of the input state

# input_tensor_torch is shape: (batch, multi_step_input, values, variables)
Expand Down
1 change: 1 addition & 0 deletions src/anemoi/inference/runners/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self, config):
verbosity=config.verbosity,
report_error=config.report_error,
use_grib_paramid=config.use_grib_paramid,
development_hacks=config.development_hacks,
)

def create_input(self):
Expand Down

0 comments on commit 62f8322

Please sign in to comment.