Skip to content

Commit

Permalink
initial tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Jun 26, 2024
1 parent d5f67fd commit b12272d
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ optional-dependencies.dev = [
"nbsphinx",
"pandoc",
"pytest",
"pytest-mock",
"requests",
"sphinx",
"sphinx-argparse",
Expand All @@ -83,6 +84,7 @@ optional-dependencies.docs = [

optional-dependencies.tests = [
"pytest",
"pytest-mock",
]

urls.Documentation = "https://anemoi-graphs.readthedocs.io/"
Expand Down
52 changes: 52 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import numpy as np
import pytest
import torch
from torch_geometric.data import HeteroData

lats = [-0.15, 0, 0.15]
lons = [0, 0.25, 0.5, 0.75]


class MockZarrDataset:
"""Mock Zarr dataset with latitudes and longitudes attributes."""

def __init__(self, latitudes, longitudes):
self.latitudes = latitudes
self.longitudes = longitudes
self.num_nodes = len(latitudes)


@pytest.fixture
def mock_zarr_dataset() -> MockZarrDataset:
"""Mock zarr dataset with nodes."""
coords = 2 * torch.pi * np.array([[lat, lon] for lat in lats for lon in lons])
return MockZarrDataset(latitudes=coords[:, 0], longitudes=coords[:, 1])


@pytest.fixture
def mock_grids_path(tmp_path) -> tuple[str, int]:
"""Mock grid_definition_path with files for 3 resolutions."""
num_nodes = len(lats) * len(lons)
for resolution in ["o16", "o48", "5km5"]:
file_path = tmp_path / f"grid-{resolution}.npz"
np.savez(file_path, latitudes=np.random.rand(num_nodes), longitudes=np.random.rand(num_nodes))
return str(tmp_path), num_nodes


@pytest.fixture
def graph_with_nodes() -> HeteroData:
"""Graph with 12 nodes over the globe, stored in \"test_nodes\"."""
coords = np.array([[lat, lon] for lat in lats for lon in lons])
graph = HeteroData()
graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords)
return graph


@pytest.fixture
def graph_nodes_and_edges() -> HeteroData:
"""Graph with 1 set of nodes and edges."""
coords = np.array([[lat, lon] for lat in lats for lon in lons])
graph = HeteroData()
graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords)
graph[("test_nodes", "to", "test_nodes")].edge_index = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]])
return graph
20 changes: 20 additions & 0 deletions tests/edges/test_attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest
import torch

from anemoi.graphs.edges.attributes import DirectionalFeatures


@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-sum", "unit-std"])
@pytest.mark.parametrize("luse_rotated_features", [True, False])
def test_directional_features(graph_nodes_and_edges, norm, luse_rotated_features: bool):
"""Test DirectionalFeatures compute method."""
edge_attr_builder = DirectionalFeatures(norm=norm, luse_rotated_features=luse_rotated_features)
edge_attr = edge_attr_builder(graph_nodes_and_edges, "test_nodes", "test_nodes")
assert isinstance(edge_attr, torch.Tensor)


def test_fail_directional_features(graph_nodes_and_edges):
"""Test DirectionalFeatures compute method."""
edge_attr_builder = DirectionalFeatures()
with pytest.raises(AttributeError):
edge_attr_builder(graph_nodes_and_edges, "test_nodes", "unknown_nodes")
15 changes: 15 additions & 0 deletions tests/edges/test_cutoff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest

from anemoi.graphs.edges import CutOffEdgeBuilder


def test_init():
"""Test CutOffEdgeBuilder initialization."""
CutOffEdgeBuilder("test_nodes1", "test_nodes2", 0.5)


@pytest.mark.parametrize("cutoff_factor", [-0.5, "hello", None])
def test_fail_init(cutoff_factor: str):
"""Test CutOffEdgeBuilder initialization with invalid cutoff."""
with pytest.raises(AssertionError):
CutOffEdgeBuilder("test_nodes1", "test_nodes2", cutoff_factor)
15 changes: 15 additions & 0 deletions tests/edges/test_knn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest

from anemoi.graphs.edges import KNNEdgeBuilder


def test_init():
"""Test CutOffEdgeBuilder initialization."""
KNNEdgeBuilder("test_nodes1", "test_nodes2", 3)


@pytest.mark.parametrize("num_nearest_neighbours", [-1, 2.6, "hello", None])
def test_fail_init(num_nearest_neighbours: str):
"""Test KNNEdgeBuilder initialization with invalid number of nearest neighbours."""
with pytest.raises(AssertionError):
KNNEdgeBuilder("test_nodes1", "test_nodes2", num_nearest_neighbours)
58 changes: 58 additions & 0 deletions tests/nodes/test_npz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest
import torch
from torch_geometric.data import HeteroData

from anemoi.graphs.nodes.nodes import NPZNodes
from anemoi.graphs.nodes.weights import AreaWeights
from anemoi.graphs.nodes.weights import UniformWeights


@pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"])
def test_init(mock_grids_path: tuple[str, int], resolution: str):
"""Test NPZNodes initialization."""
grid_definition_path, _ = mock_grids_path
node_builder = NPZNodes(resolution, grid_definition_path=grid_definition_path)
assert isinstance(node_builder, NPZNodes)


@pytest.mark.parametrize("resolution", ["o17", 13, "ajsnb", None])
def test_fail_init_wrong_resolution(mock_grids_path: tuple[str, int], resolution: str):
"""Test NPZNodes initialization with invalid resolution."""
grid_definition_path, _ = mock_grids_path
with pytest.raises(FileNotFoundError):
NPZNodes(resolution, grid_definition_path=grid_definition_path)


def test_fail_init_wrong_path():
"""Test NPZNodes initialization with invalid path."""
with pytest.raises(FileNotFoundError):
NPZNodes("o16", "invalid_path")


@pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"])
def test_register_nodes(mock_grids_path: str, resolution: str):
"""Test NPZNodes register correctly the nodes."""
graph = HeteroData()
grid_definition_path, num_nodes = mock_grids_path
node_builder = NPZNodes(resolution, grid_definition_path=grid_definition_path)

graph = node_builder.register_nodes(graph, "test_nodes")

assert graph["test_nodes"].x is not None
assert isinstance(graph["test_nodes"].x, torch.Tensor)
assert graph["test_nodes"].x.shape == (num_nodes, 2)
assert graph["test_nodes"].node_type == "NPZNodes"


@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights])
def test_register_weights(graph_with_nodes: HeteroData, mock_grids_path: tuple[str, int], attr_class):
"""Test NPZNodes register correctly the weights."""
grid_definition_path, _ = mock_grids_path
node_builder = NPZNodes("o16", grid_definition_path=grid_definition_path)
config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.weights.{attr_class.__name__}"}}

graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config)

assert graph["test_nodes"]["test_attr"] is not None
assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor)
assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0]
53 changes: 53 additions & 0 deletions tests/nodes/test_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np
import pytest
import torch
from hydra.utils import instantiate
from torch_geometric.data import HeteroData


@pytest.mark.parametrize("norm", [None, "l1", "l2", "unit-max", "unit-sum", "unit-std"])
def test_uniform_weights(graph_with_nodes: HeteroData, norm: str):
"""Test NPZNodes register correctly the weights."""
config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm}

weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"])

assert weights is not None
assert isinstance(weights, torch.Tensor)
assert weights.shape[0] == graph_with_nodes["test_nodes"].x.shape[0]


@pytest.mark.parametrize("norm", ["l3", "invalide"])
def test_uniform_weights_fail(graph_with_nodes: HeteroData, norm: str):
"""Test NPZNodes register correctly the weights."""
config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm}

with pytest.raises(ValueError):
instantiate(config).get_weights(graph_with_nodes["test_nodes"])


def test_area_weights(graph_with_nodes: HeteroData):
"""Test NPZNodes register correctly the weights."""
config = {
"_target_": "anemoi.graphs.nodes.weights.AreaWeights",
"radius": 1.0,
"centre": np.array([0, 0, 0]),
}

weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"])

assert weights is not None
assert isinstance(weights, torch.Tensor)
assert weights.shape[0] == graph_with_nodes["test_nodes"].x.shape[0]


@pytest.mark.parametrize("radius", [-1.0, "hello", None])
def test_area_weights_fail(graph_with_nodes: HeteroData, radius: float):
config = {
"_target_": "anemoi.graphs.nodes.weights.AreaWeights",
"radius": radius,
"centre": np.array([0, 0, 0]),
}

with pytest.raises(ValueError):
instantiate(config).get_weights(graph_with_nodes["test_nodes"])
50 changes: 50 additions & 0 deletions tests/nodes/test_zarr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest
import torch
import zarr
from torch_geometric.data import HeteroData

from anemoi.graphs.nodes import nodes
from anemoi.graphs.nodes.weights import AreaWeights
from anemoi.graphs.nodes.weights import UniformWeights


def test_init(mocker, mock_zarr_dataset):
"""Test ZarrNodes initialization."""
mocker.patch.object(nodes, "open_dataset", return_value=mock_zarr_dataset)
node_builder = nodes.ZarrNodes("dataset.zarr")
assert isinstance(node_builder, nodes.BaseNodeBuilder)
assert isinstance(node_builder, nodes.ZarrNodes)


def test_fail_init():
"""Test ZarrNodes initialization with invalid resolution."""
with pytest.raises(zarr.errors.PathNotFoundError):
nodes.ZarrNodes("invalid_path.zarr")


def test_register_nodes(mocker, mock_zarr_dataset):
"""Test ZarrNodes register correctly the nodes."""
mocker.patch.object(nodes, "open_dataset", return_value=mock_zarr_dataset)
node_builder = nodes.ZarrNodes("dataset.zarr")
graph = HeteroData()

graph = node_builder.register_nodes(graph, "test_nodes")

assert graph["test_nodes"].x is not None
assert isinstance(graph["test_nodes"].x, torch.Tensor)
assert graph["test_nodes"].x.shape == (node_builder.ds.num_nodes, 2)
assert graph["test_nodes"].node_type == "ZarrNodes"


@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights])
def test_register_weights(mocker, graph_with_nodes: HeteroData, attr_class):
"""Test ZarrNodes register correctly the weights."""
mocker.patch.object(nodes, "open_dataset", return_value=None)
node_builder = nodes.ZarrNodes("dataset.zarr")
config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.weights.{attr_class.__name__}"}}

graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config)

assert graph["test_nodes"]["test_attr"] is not None
assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor)
assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0]

0 comments on commit b12272d

Please sign in to comment.