Skip to content

Commit

Permalink
Merge branch 'feature/new-checkpoints' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Nov 8, 2024
2 parents b9db2e0 + a0320a9 commit 1bd19fc
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 6 deletions.
3 changes: 0 additions & 3 deletions src/anemoi/transform/sources/mars.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,3 @@ def forward(self, data):
return this.forward(self.data)

return Input(data)


source_registry.register("mars", Mars)
7 changes: 7 additions & 0 deletions src/anemoi/transform/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,10 @@ def is_computed_forcing(self):
@property
def is_from_input(self):
pass

def similarity(self, other):
"""Compute the similarity between two variables. This is used when
encoding a variables in GRIB and we do not have a template for it.
We can then try to find the most similar variable for which we have a template.
"""
return 0
15 changes: 15 additions & 0 deletions src/anemoi/transform/variables/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,21 @@ def is_instantanous(self):
def grib_keys(self):
return self.data.get("mars", {}).copy()

def similarity(self, other):
if not isinstance(other, VariableFromMarsVocabulary):
return 0

def __similarity(a, b):
if isinstance(a, dict) and isinstance(b, dict):
return sum(__similarity(a[k], b[k]) for k in set(a.keys()) & set(b.keys()))

if isinstance(a, list) and isinstance(b, list):
return sum(__similarity(a[i], b[i]) for i in range(min(len(a), len(b))))

return 1 if a == b else 0

return __similarity(self.data, other.data)


class VariableFromDict(VariableFromMarsVocabulary):
"""A variable that is defined by a user provided dictionary."""
Expand Down
4 changes: 1 addition & 3 deletions src/anemoi/transform/workflows/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from . import workflow_registry


@workflow_registry.register("pipeline")
class Pipeline(Workflow):
"""A simple pipeline of filters"""

Expand All @@ -27,6 +28,3 @@ def backward(self, data):
for filter in reversed(self.filters):
data = filter.backward(data)
return data


workflow_registry.register("pipeline", Pipeline)

0 comments on commit 1bd19fc

Please sign in to comment.