diff --git a/src/anemoi/inference/plugin.py b/src/anemoi/inference/plugin.py index 7aad9ec..1a47e01 100644 --- a/src/anemoi/inference/plugin.py +++ b/src/anemoi/inference/plugin.py @@ -93,7 +93,12 @@ def parse_model_args(self, args): @cached_property def runner(self): - return PluginRunner(self._checkpoint, device=self.device) + return PluginRunner( + self._checkpoint, + device=self.device, + pre_processors=self.pre_processors(), + post_processors=self.post_processors(), + ) def run(self): if self.deterministic: @@ -111,6 +116,14 @@ def run(self): output.close() + def pre_processors(self): + # To override in subclasses + return [] + + def post_processors(self): + # To override in subclasses + return [] + # Below are methods forwarded to the checkpoint @property diff --git a/src/anemoi/inference/postprocess.py b/src/anemoi/inference/postprocess.py deleted file mode 100644 index 7740cdf..0000000 --- a/src/anemoi/inference/postprocess.py +++ /dev/null @@ -1,42 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging - -import numpy as np - -LOG = logging.getLogger(__name__) - - -class Noop: - - def __call__(self, source): - yield from source - - -class Accumulator: - """Accumulate fields from zero and return the accumulated fields""" - - def __init__(self, accumulations): - self.accumulations = accumulations - LOG.info("Accumulating fields %s", self.accumulations) - - self.accumulators = {} - - def __call__(self, source): - for state in source: - for accumulation in self.accumulations: - if accumulation in state["fields"]: - if accumulation not in self.accumulators: - self.accumulators[accumulation] = np.zeros_like(state["fields"][accumulation]) - self.accumulators[accumulation] += np.maximum(0, state["fields"][accumulation]) - state["fields"][accumulation] = self.accumulators[accumulation] - - yield state diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 2856763..6a5f3d7 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -21,9 +21,9 @@ from .checkpoint import Checkpoint from .context import Context -from .postprocess import Accumulator -from .postprocess import Noop from .precisions import PRECISIONS +from .processors import Accumulator +from .processors import Chain LOG = logging.getLogger(__name__) @@ -63,6 +63,8 @@ def __init__( verbosity=0, inference_options=None, development_hacks={}, # For testing purposes, don't use in production + pre_processors=None, + post_processors=None, ): self._checkpoint = Checkpoint(checkpoint) @@ -79,14 +81,15 @@ def __init__( # This could also be passed as an argument - self.postprocess = Noop() + self.preprocess = Chain("pre-processors", pre_processors) + self.postprocess = Chain("post-processors", post_processors) if accumulations is True: # Get accumulations from the checkpoint accumulations = self.checkpoint.accumulations if accumulations: - self.postprocess = Accumulator(accumulations) + self.postprocess.append(Accumulator(accumulations)) self._input_kinds = {} self._input_tensor_by_name = [] @@ -108,6 +111,12 @@ def checkpoint(self): def run(self, *, input_state, lead_time): + # Shallow copy to avoid modifying the user's input state + input_state = input_state.copy() + input_state["fields"] = input_state["fields"].copy() + + input_state = self.preprocess(input_state) + self.constant_forcings_inputs = self.checkpoint.constant_forcings_inputs(self, input_state) self.dynamic_forcings_inputs = self.checkpoint.dynamic_forcings_inputs(self, input_state) self.boundary_forcings_inputs = self.checkpoint.boundary_forcings_inputs(self, input_state) diff --git a/src/anemoi/inference/runners/plugin.py b/src/anemoi/inference/runners/plugin.py index d218282..1a15dc6 100644 --- a/src/anemoi/inference/runners/plugin.py +++ b/src/anemoi/inference/runners/plugin.py @@ -21,8 +21,8 @@ class PluginRunner(Runner): """A runner implementing the ai-models plugin API.""" - def __init__(self, checkpoint: str, *, device: str): - super().__init__(checkpoint, device=device) + def __init__(self, checkpoint: str, *, device: str, pre_processors=None, post_processors=None): + super().__init__(checkpoint, device=device, pre_processors=pre_processors, post_processors=post_processors) # Compatibility with the ai_models API