-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
265 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import numpy as np | ||
import pytest | ||
import torch | ||
from torch_geometric.data import HeteroData | ||
|
||
lats = [-0.15, 0, 0.15] | ||
lons = [0, 0.25, 0.5, 0.75] | ||
|
||
|
||
class MockZarrDataset: | ||
"""Mock Zarr dataset with latitudes and longitudes attributes.""" | ||
|
||
def __init__(self, latitudes, longitudes): | ||
self.latitudes = latitudes | ||
self.longitudes = longitudes | ||
self.num_nodes = len(latitudes) | ||
|
||
|
||
@pytest.fixture | ||
def mock_zarr_dataset() -> MockZarrDataset: | ||
"""Mock zarr dataset with nodes.""" | ||
coords = 2 * torch.pi * np.array([[lat, lon] for lat in lats for lon in lons]) | ||
return MockZarrDataset(latitudes=coords[:, 0], longitudes=coords[:, 1]) | ||
|
||
|
||
@pytest.fixture | ||
def mock_grids_path(tmp_path) -> tuple[str, int]: | ||
"""Mock grid_definition_path with files for 3 resolutions.""" | ||
num_nodes = len(lats) * len(lons) | ||
for resolution in ["o16", "o48", "5km5"]: | ||
file_path = tmp_path / f"grid-{resolution}.npz" | ||
np.savez(file_path, latitudes=np.random.rand(num_nodes), longitudes=np.random.rand(num_nodes)) | ||
return str(tmp_path), num_nodes | ||
|
||
|
||
@pytest.fixture | ||
def graph_with_nodes() -> HeteroData: | ||
"""Graph with 12 nodes over the globe, stored in \"test_nodes\".""" | ||
coords = np.array([[lat, lon] for lat in lats for lon in lons]) | ||
graph = HeteroData() | ||
graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) | ||
return graph | ||
|
||
|
||
@pytest.fixture | ||
def graph_nodes_and_edges() -> HeteroData: | ||
"""Graph with 1 set of nodes and edges.""" | ||
coords = np.array([[lat, lon] for lat in lats for lon in lons]) | ||
graph = HeteroData() | ||
graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) | ||
graph[("test_nodes", "to", "test_nodes")].edge_index = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]]) | ||
return graph |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import pytest | ||
import torch | ||
|
||
from anemoi.graphs.edges.attributes import DirectionalFeatures | ||
|
||
|
||
@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-sum", "unit-std"]) | ||
@pytest.mark.parametrize("luse_rotated_features", [True, False]) | ||
def test_directional_features(graph_nodes_and_edges, norm, luse_rotated_features: bool): | ||
"""Test DirectionalFeatures compute method.""" | ||
edge_attr_builder = DirectionalFeatures(norm=norm, luse_rotated_features=luse_rotated_features) | ||
edge_attr = edge_attr_builder(graph_nodes_and_edges, "test_nodes", "test_nodes") | ||
assert isinstance(edge_attr, torch.Tensor) | ||
|
||
|
||
def test_fail_directional_features(graph_nodes_and_edges): | ||
"""Test DirectionalFeatures compute method.""" | ||
edge_attr_builder = DirectionalFeatures() | ||
with pytest.raises(AttributeError): | ||
edge_attr_builder(graph_nodes_and_edges, "test_nodes", "unknown_nodes") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import pytest | ||
|
||
from anemoi.graphs.edges import CutOffEdgeBuilder | ||
|
||
|
||
def test_init(): | ||
"""Test CutOffEdgeBuilder initialization.""" | ||
CutOffEdgeBuilder("test_nodes1", "test_nodes2", 0.5) | ||
|
||
|
||
@pytest.mark.parametrize("cutoff_factor", [-0.5, "hello", None]) | ||
def test_fail_init(cutoff_factor: str): | ||
"""Test CutOffEdgeBuilder initialization with invalid cutoff.""" | ||
with pytest.raises(AssertionError): | ||
CutOffEdgeBuilder("test_nodes1", "test_nodes2", cutoff_factor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import pytest | ||
|
||
from anemoi.graphs.edges import KNNEdgeBuilder | ||
|
||
|
||
def test_init(): | ||
"""Test CutOffEdgeBuilder initialization.""" | ||
KNNEdgeBuilder("test_nodes1", "test_nodes2", 3) | ||
|
||
|
||
@pytest.mark.parametrize("num_nearest_neighbours", [-1, 2.6, "hello", None]) | ||
def test_fail_init(num_nearest_neighbours: str): | ||
"""Test KNNEdgeBuilder initialization with invalid number of nearest neighbours.""" | ||
with pytest.raises(AssertionError): | ||
KNNEdgeBuilder("test_nodes1", "test_nodes2", num_nearest_neighbours) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import pytest | ||
import torch | ||
from torch_geometric.data import HeteroData | ||
|
||
from anemoi.graphs.nodes.nodes import NPZNodes | ||
from anemoi.graphs.nodes.weights import AreaWeights | ||
from anemoi.graphs.nodes.weights import UniformWeights | ||
|
||
|
||
@pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) | ||
def test_init(mock_grids_path: tuple[str, int], resolution: str): | ||
"""Test NPZNodes initialization.""" | ||
grid_definition_path, _ = mock_grids_path | ||
node_builder = NPZNodes(resolution, grid_definition_path=grid_definition_path) | ||
assert isinstance(node_builder, NPZNodes) | ||
|
||
|
||
@pytest.mark.parametrize("resolution", ["o17", 13, "ajsnb", None]) | ||
def test_fail_init_wrong_resolution(mock_grids_path: tuple[str, int], resolution: str): | ||
"""Test NPZNodes initialization with invalid resolution.""" | ||
grid_definition_path, _ = mock_grids_path | ||
with pytest.raises(FileNotFoundError): | ||
NPZNodes(resolution, grid_definition_path=grid_definition_path) | ||
|
||
|
||
def test_fail_init_wrong_path(): | ||
"""Test NPZNodes initialization with invalid path.""" | ||
with pytest.raises(FileNotFoundError): | ||
NPZNodes("o16", "invalid_path") | ||
|
||
|
||
@pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) | ||
def test_register_nodes(mock_grids_path: str, resolution: str): | ||
"""Test NPZNodes register correctly the nodes.""" | ||
graph = HeteroData() | ||
grid_definition_path, num_nodes = mock_grids_path | ||
node_builder = NPZNodes(resolution, grid_definition_path=grid_definition_path) | ||
|
||
graph = node_builder.register_nodes(graph, "test_nodes") | ||
|
||
assert graph["test_nodes"].x is not None | ||
assert isinstance(graph["test_nodes"].x, torch.Tensor) | ||
assert graph["test_nodes"].x.shape == (num_nodes, 2) | ||
assert graph["test_nodes"].node_type == "NPZNodes" | ||
|
||
|
||
@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) | ||
def test_register_weights(graph_with_nodes: HeteroData, mock_grids_path: tuple[str, int], attr_class): | ||
"""Test NPZNodes register correctly the weights.""" | ||
grid_definition_path, _ = mock_grids_path | ||
node_builder = NPZNodes("o16", grid_definition_path=grid_definition_path) | ||
config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.weights.{attr_class.__name__}"}} | ||
|
||
graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) | ||
|
||
assert graph["test_nodes"]["test_attr"] is not None | ||
assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor) | ||
assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import numpy as np | ||
import pytest | ||
import torch | ||
from hydra.utils import instantiate | ||
from torch_geometric.data import HeteroData | ||
|
||
|
||
@pytest.mark.parametrize("norm", [None, "l1", "l2", "unit-max", "unit-sum", "unit-std"]) | ||
def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): | ||
"""Test NPZNodes register correctly the weights.""" | ||
config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm} | ||
|
||
weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"]) | ||
|
||
assert weights is not None | ||
assert isinstance(weights, torch.Tensor) | ||
assert weights.shape[0] == graph_with_nodes["test_nodes"].x.shape[0] | ||
|
||
|
||
@pytest.mark.parametrize("norm", ["l3", "invalide"]) | ||
def test_uniform_weights_fail(graph_with_nodes: HeteroData, norm: str): | ||
"""Test NPZNodes register correctly the weights.""" | ||
config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm} | ||
|
||
with pytest.raises(ValueError): | ||
instantiate(config).get_weights(graph_with_nodes["test_nodes"]) | ||
|
||
|
||
def test_area_weights(graph_with_nodes: HeteroData): | ||
"""Test NPZNodes register correctly the weights.""" | ||
config = { | ||
"_target_": "anemoi.graphs.nodes.weights.AreaWeights", | ||
"radius": 1.0, | ||
"centre": np.array([0, 0, 0]), | ||
} | ||
|
||
weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"]) | ||
|
||
assert weights is not None | ||
assert isinstance(weights, torch.Tensor) | ||
assert weights.shape[0] == graph_with_nodes["test_nodes"].x.shape[0] | ||
|
||
|
||
@pytest.mark.parametrize("radius", [-1.0, "hello", None]) | ||
def test_area_weights_fail(graph_with_nodes: HeteroData, radius: float): | ||
config = { | ||
"_target_": "anemoi.graphs.nodes.weights.AreaWeights", | ||
"radius": radius, | ||
"centre": np.array([0, 0, 0]), | ||
} | ||
|
||
with pytest.raises(ValueError): | ||
instantiate(config).get_weights(graph_with_nodes["test_nodes"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import pytest | ||
import torch | ||
import zarr | ||
from torch_geometric.data import HeteroData | ||
|
||
from anemoi.graphs.nodes import nodes | ||
from anemoi.graphs.nodes.weights import AreaWeights | ||
from anemoi.graphs.nodes.weights import UniformWeights | ||
|
||
|
||
def test_init(mocker, mock_zarr_dataset): | ||
"""Test ZarrNodes initialization.""" | ||
mocker.patch.object(nodes, "open_dataset", return_value=mock_zarr_dataset) | ||
node_builder = nodes.ZarrNodes("dataset.zarr") | ||
assert isinstance(node_builder, nodes.BaseNodeBuilder) | ||
assert isinstance(node_builder, nodes.ZarrNodes) | ||
|
||
|
||
def test_fail_init(): | ||
"""Test ZarrNodes initialization with invalid resolution.""" | ||
with pytest.raises(zarr.errors.PathNotFoundError): | ||
nodes.ZarrNodes("invalid_path.zarr") | ||
|
||
|
||
def test_register_nodes(mocker, mock_zarr_dataset): | ||
"""Test ZarrNodes register correctly the nodes.""" | ||
mocker.patch.object(nodes, "open_dataset", return_value=mock_zarr_dataset) | ||
node_builder = nodes.ZarrNodes("dataset.zarr") | ||
graph = HeteroData() | ||
|
||
graph = node_builder.register_nodes(graph, "test_nodes") | ||
|
||
assert graph["test_nodes"].x is not None | ||
assert isinstance(graph["test_nodes"].x, torch.Tensor) | ||
assert graph["test_nodes"].x.shape == (node_builder.ds.num_nodes, 2) | ||
assert graph["test_nodes"].node_type == "ZarrNodes" | ||
|
||
|
||
@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) | ||
def test_register_weights(mocker, graph_with_nodes: HeteroData, attr_class): | ||
"""Test ZarrNodes register correctly the weights.""" | ||
mocker.patch.object(nodes, "open_dataset", return_value=None) | ||
node_builder = nodes.ZarrNodes("dataset.zarr") | ||
config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.weights.{attr_class.__name__}"}} | ||
|
||
graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) | ||
|
||
assert graph["test_nodes"]["test_attr"] is not None | ||
assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor) | ||
assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0] |