Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Nov 24, 2024
1 parent 24f4b40 commit df4eee9
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 49 deletions.
15 changes: 14 additions & 1 deletion src/anemoi/inference/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ def parse_model_args(self, args):

@cached_property
def runner(self):
return PluginRunner(self._checkpoint, device=self.device)
return PluginRunner(
self._checkpoint,
device=self.device,
pre_processors=self.pre_processors(),
post_processors=self.post_processors(),
)

def run(self):
if self.deterministic:
Expand All @@ -111,6 +116,14 @@ def run(self):

output.close()

def pre_processors(self):
# To override in subclasses
return []

def post_processors(self):
# To override in subclasses
return []

# Below are methods forwarded to the checkpoint

@property
Expand Down
42 changes: 0 additions & 42 deletions src/anemoi/inference/postprocess.py

This file was deleted.

17 changes: 13 additions & 4 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

from .checkpoint import Checkpoint
from .context import Context
from .postprocess import Accumulator
from .postprocess import Noop
from .precisions import PRECISIONS
from .processors import Accumulator
from .processors import Chain

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -63,6 +63,8 @@ def __init__(
verbosity=0,
inference_options=None,
development_hacks={}, # For testing purposes, don't use in production
pre_processors=None,
post_processors=None,
):
self._checkpoint = Checkpoint(checkpoint)

Expand All @@ -79,14 +81,15 @@ def __init__(

# This could also be passed as an argument

self.postprocess = Noop()
self.preprocess = Chain("pre-processors", pre_processors)
self.postprocess = Chain("post-processors", post_processors)

if accumulations is True:
# Get accumulations from the checkpoint
accumulations = self.checkpoint.accumulations

if accumulations:
self.postprocess = Accumulator(accumulations)
self.postprocess.append(Accumulator(accumulations))

self._input_kinds = {}
self._input_tensor_by_name = []
Expand All @@ -108,6 +111,12 @@ def checkpoint(self):

def run(self, *, input_state, lead_time):

# Shallow copy to avoid modifying the user's input state
input_state = input_state.copy()
input_state["fields"] = input_state["fields"].copy()

input_state = self.preprocess(input_state)

self.constant_forcings_inputs = self.checkpoint.constant_forcings_inputs(self, input_state)
self.dynamic_forcings_inputs = self.checkpoint.dynamic_forcings_inputs(self, input_state)
self.boundary_forcings_inputs = self.checkpoint.boundary_forcings_inputs(self, input_state)
Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/inference/runners/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
class PluginRunner(Runner):
"""A runner implementing the ai-models plugin API."""

def __init__(self, checkpoint: str, *, device: str):
super().__init__(checkpoint, device=device)
def __init__(self, checkpoint: str, *, device: str, pre_processors=None, post_processors=None):
super().__init__(checkpoint, device=device, pre_processors=pre_processors, post_processors=post_processors)

# Compatibility with the ai_models API

Expand Down

0 comments on commit df4eee9

Please sign in to comment.