diff --git a/tests/nodes/test_cutout_nodes.py b/tests/nodes/test_cutout_nodes.py index a18424da..e6484ae9 100644 --- a/tests/nodes/test_cutout_nodes.py +++ b/tests/nodes/test_cutout_nodes.py @@ -1,18 +1,19 @@ import pytest import torch +from omegaconf import OmegaConf from torch_geometric.data import HeteroData from anemoi.graphs.nodes.attributes import AreaWeights from anemoi.graphs.nodes.attributes import UniformWeights from anemoi.graphs.nodes.builders import from_file +dataset_cfg = OmegaConf.create({"cutout": ["lam.zarr", "global.zarr"]}) + def test_init(mocker, mock_zarr_dataset_cutout): """Test CutOutZarrDatasetNodes initialization.""" mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) - node_builder = from_file.CutOutZarrDatasetNodes( - forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes" - ) + node_builder = from_file.CutOutZarrDatasetNodes(dataset_cfg, name="test_nodes") assert isinstance(node_builder, from_file.BaseNodeBuilder) assert isinstance(node_builder, from_file.CutOutZarrDatasetNodes) @@ -20,16 +21,14 @@ def test_init(mocker, mock_zarr_dataset_cutout): def test_fail_init(): """Test CutOutZarrDatasetNodes initialization with invalid resolution.""" - with pytest.raises(TypeError): + with pytest.raises(AssertionError): from_file.CutOutZarrDatasetNodes("global_dataset.zarr", name="test_nodes") def test_register_nodes(mocker, mock_zarr_dataset_cutout): """Test CutOutZarrDatasetNodes register correctly the nodes.""" mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) - node_builder = from_file.CutOutZarrDatasetNodes( - forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes" - ) + node_builder = from_file.CutOutZarrDatasetNodes(dataset_cfg, name="test_nodes") graph = HeteroData() graph = node_builder.register_nodes(graph) @@ -44,9 +43,7 @@ def test_register_nodes(mocker, mock_zarr_dataset_cutout): def test_register_attributes(mocker, mock_zarr_dataset_cutout, graph_with_nodes: HeteroData, attr_class): """Test CutOutZarrDatasetNodes register correctly the weights.""" mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) - node_builder = from_file.CutOutZarrDatasetNodes( - forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes" - ) + node_builder = from_file.CutOutZarrDatasetNodes(dataset_cfg, name="test_nodes") config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} graph = node_builder.register_attributes(graph_with_nodes, config)