Skip to content

Commit

Permalink
fix: tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Oct 16, 2024
1 parent 24891b9 commit 820ca2f
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions tests/nodes/test_cutout_nodes.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,34 @@
import pytest
import torch
from omegaconf import OmegaConf
from torch_geometric.data import HeteroData

from anemoi.graphs.nodes.attributes import AreaWeights
from anemoi.graphs.nodes.attributes import UniformWeights
from anemoi.graphs.nodes.builders import from_file

dataset_cfg = OmegaConf.create({"cutout": ["lam.zarr", "global.zarr"]})


def test_init(mocker, mock_zarr_dataset_cutout):
"""Test CutOutZarrDatasetNodes initialization."""
mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout)
node_builder = from_file.CutOutZarrDatasetNodes(
forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes"
)
node_builder = from_file.CutOutZarrDatasetNodes(dataset_cfg, name="test_nodes")

assert isinstance(node_builder, from_file.BaseNodeBuilder)
assert isinstance(node_builder, from_file.CutOutZarrDatasetNodes)


def test_fail_init():
"""Test CutOutZarrDatasetNodes initialization with invalid resolution."""
with pytest.raises(TypeError):
with pytest.raises(AssertionError):
from_file.CutOutZarrDatasetNodes("global_dataset.zarr", name="test_nodes")


def test_register_nodes(mocker, mock_zarr_dataset_cutout):
"""Test CutOutZarrDatasetNodes register correctly the nodes."""
mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout)
node_builder = from_file.CutOutZarrDatasetNodes(
forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes"
)
node_builder = from_file.CutOutZarrDatasetNodes(dataset_cfg, name="test_nodes")
graph = HeteroData()

graph = node_builder.register_nodes(graph)
Expand All @@ -44,9 +43,7 @@ def test_register_nodes(mocker, mock_zarr_dataset_cutout):
def test_register_attributes(mocker, mock_zarr_dataset_cutout, graph_with_nodes: HeteroData, attr_class):
"""Test CutOutZarrDatasetNodes register correctly the weights."""
mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout)
node_builder = from_file.CutOutZarrDatasetNodes(
forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes"
)
node_builder = from_file.CutOutZarrDatasetNodes(dataset_cfg, name="test_nodes")
config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}}

graph = node_builder.register_attributes(graph_with_nodes, config)
Expand Down

0 comments on commit 820ca2f

Please sign in to comment.