diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..a19ade0 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +CHANGELOG.md merge=union diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..74bdac0 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,6 @@ +# CODEOWNERS file + +# Protect workflow files +/.github/ @theissenhelen @jesperdramsch @gmertes @b8raoult @floriankrb @anaprietonem @HCookie @JPXKQX @mchantry +/.pre-commit-config.yaml @theissenhelen @jesperdramsch @gmertes @b8raoult @floriankrb @anaprietonem @HCookie @JPXKQX @mchantry +/pyproject.toml @theissenhelen @jesperdramsch @gmertes @b8raoult @floriankrb @anaprietonem @HCookie @JPXKQX @mchantry diff --git a/.github/ci-config.yml b/.github/ci-config.yml new file mode 100644 index 0000000..6138e63 --- /dev/null +++ b/.github/ci-config.yml @@ -0,0 +1,9 @@ +dependencies: | + ecmwf/ecbuild + MathisRosenhauer/libaec@master + ecmwf/eccodes + ecmwf/eckit + ecmwf/odc +dependency_branch: develop +parallelism_factor: 8 +self_build: false # Only for python packages diff --git a/.github/ci-hpc-config.yml b/.github/ci-hpc-config.yml new file mode 100644 index 0000000..bbe6ef6 --- /dev/null +++ b/.github/ci-hpc-config.yml @@ -0,0 +1,16 @@ +build: + modules: + - ninja + dependencies: + - ecmwf/ecbuild@develop + - ecmwf/eccodes@develop + - ecmwf/eckit@develop + - ecmwf/odc@develop + python_dependencies: + - ecmwf/anemoi-utils@develop + - ecmwf/anemoi-datasets@develop + parallel: 64 + + pytest_cmd: | + python -m pytest -vv -m 'not notebook and not no_cache_init' --cov=. --cov-report=xml + python -m coverage report diff --git a/.github/workflows/changelog-pr-update.yml b/.github/workflows/changelog-pr-update.yml new file mode 100644 index 0000000..73cb1eb --- /dev/null +++ b/.github/workflows/changelog-pr-update.yml @@ -0,0 +1,18 @@ +name: Check Changelog Update on PR +on: + pull_request: + types: [assigned, opened, synchronize, reopened, labeled, unlabeled] + branches: + - main + - develop + paths-ignore: + - .pre-commit-config.yaml + - .readthedocs.yaml +jobs: + Check-Changelog: + name: Check Changelog Action + runs-on: ubuntu-20.04 + steps: + - uses: tarides/changelog-check-action@v2 + with: + changelog: CHANGELOG.md diff --git a/.github/workflows/changelog-release-update.yml b/.github/workflows/changelog-release-update.yml new file mode 100644 index 0000000..17d9525 --- /dev/null +++ b/.github/workflows/changelog-release-update.yml @@ -0,0 +1,35 @@ +# .github/workflows/update-changelog.yaml +name: "Update Changelog" + +on: + release: + types: [released] + workflow_dispatch: ~ + +permissions: + pull-requests: write + contents: write + +jobs: + update: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ github.event.release.target_commitish }} + + - name: Update Changelog + uses: stefanzweifel/changelog-updater-action@v1 + with: + latest-version: ${{ github.event.release.tag_name }} + heading-text: ${{ github.event.release.name }} + + - name: Create Pull Request + uses: peter-evans/create-pull-request@v6 + with: + branch: docs/changelog-update-${{ github.event.release.tag_name }} + title: '[Changelog] Update to ${{ github.event.release.tag_name }}' + add-paths: | + CHANGELOG.md diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..1844abc --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,52 @@ +name: ci + +on: + # Trigger the workflow on push to master or develop, except tag 
creation + push: + branches: + - 'main' + - 'develop' + tags-ignore: + - '**' + paths-ignore: + - "docs/**" + - "CHANGELOG.md" + - "README.md" + + # Trigger the workflow on pull request + pull_request: + paths-ignore: + - "docs/**" + - "CHANGELOG.md" + - "README.md" + + # Trigger the workflow manually + workflow_dispatch: ~ + + # Trigger after public PR approved for CI + pull_request_target: + types: [labeled] + paths-ignore: + - "docs/**" + - "CHANGELOG.md" + - "README.md" + +jobs: + # Run CI including downstream packages on self-hosted runners + downstream-ci: + name: downstream-ci + if: ${{ !github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci' }} + uses: ecmwf-actions/downstream-ci/.github/workflows/downstream-ci.yml@main + with: + anemoi-graphs: ecmwf/anemoi-graphs@${{ github.event.pull_request.head.sha || github.sha }} + codecov_upload: true + secrets: inherit + + # Build downstream packages on HPC + downstream-ci-hpc: + name: downstream-ci-hpc + if: ${{ !github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci' }} + uses: ecmwf-actions/downstream-ci/.github/workflows/downstream-ci-hpc.yml@main + with: + anemoi-graphs: ecmwf/anemoi-graphs@${{ github.event.pull_request.head.sha || github.sha }} + secrets: inherit diff --git a/.github/workflows/label-public-pr.yml b/.github/workflows/label-public-pr.yml new file mode 100644 index 0000000..59b2bfa --- /dev/null +++ b/.github/workflows/label-public-pr.yml @@ -0,0 +1,10 @@ +# Manage labels of pull requests that originate from forks +name: label-public-pr + +on: + pull_request_target: + types: [opened, synchronize] + +jobs: + label: + uses: ecmwf-actions/reusable-workflows/.github/workflows/label-pr.yml@v2 diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 666f65d..9046e94 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -4,71 +4,24 @@ name: Upload Python Package on: - - push: {} - release: types: [created] jobs: quality: - name: Code QA - runs-on: ubuntu-latest - steps: - - run: sudo apt-get install -y pandoc # Needed by sphinx for notebooks - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: 3.x - - uses: pre-commit/action@v3.0.1 + uses: ecmwf-actions/reusable-workflows/.github/workflows/qa-precommit-run.yml@v2 + with: + skip-hooks: "no-commit-to-branch" checks: strategy: - fail-fast: false matrix: - platform: ["ubuntu-latest", "macos-latest"] - python-version: ["3.10"] - - name: Python ${{ matrix.python-version }} on ${{ matrix.platform }} - runs-on: ${{ matrix.platform }} - - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - - name: Install - run: | - pip install -e .[all,tests] - pip freeze - - - name: Tests - run: pytest + python-version: ["3.9", "3.10", "3.11", "3.12"] + uses: ecmwf-actions/reusable-workflows/.github/workflows/qa-pytest-pyproject.yml@v2 + with: + python-version: ${{ matrix.python-version }} deploy: - - if: ${{ github.event_name == 'release' }} - runs-on: ubuntu-latest needs: [checks, quality] - - steps: - - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: 3.x - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install build wheel twine - - name: Build and publish - env: - 
TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} - run: | - python -m build - twine upload dist/* + uses: ecmwf-actions/reusable-workflows/.github/workflows/cd-pypi.yml@v2 + secrets: inherit diff --git a/.github/workflows/python-pull-request.yml b/.github/workflows/python-pull-request.yml new file mode 100644 index 0000000..c2be6a4 --- /dev/null +++ b/.github/workflows/python-pull-request.yml @@ -0,0 +1,23 @@ +# This workflow will upload a Python Package using Twine when a release is created +# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries + +name: Code Quality checks for PRs + +on: + push: + pull_request: + types: [opened, synchronize, reopened] + +jobs: + quality: + uses: ecmwf-actions/reusable-workflows/.github/workflows/qa-precommit-run.yml@v2 + with: + skip-hooks: "no-commit-to-branch" + + checks: + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] + uses: ecmwf-actions/reusable-workflows/.github/workflows/qa-pytest-pyproject.yml@v2 + with: + python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/readthedocs-pr-update.yml b/.github/workflows/readthedocs-pr-update.yml new file mode 100644 index 0000000..ebf07e7 --- /dev/null +++ b/.github/workflows/readthedocs-pr-update.yml @@ -0,0 +1,22 @@ +name: Read the Docs PR Preview +on: + pull_request_target: + types: + - opened + - synchronize + - reopened + # Execute this action only on PRs that touch + # documentation files. + paths: + - "docs/**" + +permissions: + pull-requests: write + +jobs: + documentation-links: + runs-on: ubuntu-latest + steps: + - uses: readthedocs/actions/preview@v1 + with: + project-slug: "anemoi-graphs" diff --git a/.gitignore b/.gitignore index 1b49006..c610ac1 100644 --- a/.gitignore +++ b/.gitignore @@ -120,6 +120,7 @@ celerybeat.pid *.sage.py # Environments +.envrc .env .venv env/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4b6367..4de5932 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,12 +5,12 @@ repos: - id: clear-notebooks-output name: clear-notebooks-output files: tools/.*\.ipynb$ - stages: [commit] + stages: [pre-commit] language: python entry: jupyter nbconvert --ClearOutputPreprocessor.enabled=True --inplace additional_dependencies: [jupyter] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-yaml # Check YAML files for syntax errors only args: [--unsafe, --allow-multiple-documents] @@ -20,8 +20,14 @@ repos: - id: no-commit-to-branch # Prevent committing to main / master - id: check-added-large-files # Check for large files added to git - id: check-merge-conflict # Check for files that contain merge conflict +- repo: https://github.com/pre-commit/pygrep-hooks + rev: v1.10.0 # Use the ref you want to point at + hooks: + - id: python-use-type-annotations # Check for missing type annotations + - id: python-check-blanket-noqa # Check for # noqa: all + - id: python-no-log-warn # Check for log.warn - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.4.2 + rev: 24.8.0 hooks: - id: black args: [--line-length=120] @@ -34,18 +40,17 @@ repos: - --force-single-line-imports - --profile black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.6 + rev: v0.6.9 hooks: - id: ruff - # Next line if for documenation cod snippets - exclude: '^[^_].*_\.py$' args: - --line-length=120 - --fix - --exit-non-zero-on-fix - --preview 
+ - --exclude=docs/**/*_.py - repo: https://github.com/sphinx-contrib/sphinx-lint - rev: v0.9.1 + rev: v1.0.0 hooks: - id: sphinx-lint # For now, we use it. But it does not support a lot of sphinx features @@ -59,12 +64,21 @@ repos: hooks: - id: docconvert args: ["numpy"] -- repo: https://github.com/b8raoult/optional-dependencies-all - rev: "0.0.6" - hooks: - - id: optional-dependencies-all - args: ["--inplace", "--exclude-keys=dev,docs,tests", "--group=dev=all,docs,tests"] - repo: https://github.com/tox-dev/pyproject-fmt - rev: "2.1.3" + rev: "2.2.4" hooks: - id: pyproject-fmt +- repo: https://github.com/jshwi/docsig # Check docstrings against function sig + rev: v0.64.0 + hooks: + - id: docsig + args: + - --ignore-no-params # Allow docstrings without parameters + - --check-dunders # Check dunder methods + - --check-overridden # Check overridden methods + - --check-protected # Check protected methods + - --check-class # Check class docstrings + - --disable=E113 # Disable empty docstrings + - --summary # Print a summary +ci: + autoupdate_schedule: monthly diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..f474881 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,109 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +Please add your functional changes to the appropriate section in the PR. +Keep it human-readable, your future self will thank you! + +## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.3.0...HEAD) + +### Added + +- ci: hpc-config, CODEOWNERS (#49) +- feat: New node builder class, CutOutZarrDatasetNodes, to create nodes from 2 datasets. (#30) +- feat: New class, KNNAreaMaskBuilder, to specify Area of Interest (AOI) based on a set of nodes. (#30) +- feat: New node builder classes, LimitedAreaXXXXXNodes, to create nodes within an Area of Interest (AOI). (#30) +- feat: Expanded MultiScaleEdges to support multi-scale connections in limited area graphs. (#30) +- feat: New method update_graph(graph) in the GraphCreator class. (#60) +- feat: New class StretchedTriNodes to create a stretched mesh. (#51) +- feat: Expanded MultiScaleEdges to support multi-scale connections in stretched graphs. (#51) +- fix: bug in color when plotting isolated nodes (#63) +- Add anemoi-transform link to documentation (#59) +- Added `CutOutMask` class to create a mask for a cutout. (#68) +- Added `MissingZarrVariable` and `NotMissingZarrVariable` classes to create a mask for missing zarr variables. (#68) +- feat: Add CONTRIBUTORS.md file. (#72) + +### Changed +- ci: small fixes and updates pre-commit, downsteam-ci (#49) +- Update CODEOWNERS (#61) +- ci: extened python versions to include 3.11 and 3.12 (#66) +- Update copyright notice (#67) + +### Removed +- Remove `CutOutZarrDatasetNodes` class. (#68) +- Update CODEOWNERS +- Fix pre-commit regex +- ci: extened python versions to include 3.11 and 3.12 +- Update copyright notice +- Fix `__version__` import in init + +## [0.3.0 Anemoi-graphs, minor release](https://github.com/ecmwf/anemoi-graphs/compare/0.2.1...0.3.0) - 2024-09-03 + +### Added + +- HEALPixNodes - nodebuilder based on Hierarchical Equal Area isoLatitude Pixelation of a sphere + +- Inspection tools: interactive plots, and distribution plots of edge & node attributes. + +- Graph description print in the console. 
+ +- CLI entry point: 'anemoi-graphs inspect ...'. + +- added downstream-ci pipeline and cd-pypi reusable workflow + +- Changelog release updater + +- Create package documentation. + + +### Changed + +- fix: added support for Python3.9. +- fix: bug in graph cleaning method +- fix: `anemoi-graphs create` CLI argument is casted to a Path. +- ci: fix missing binary dependency in ci-config.yaml +- fix: Updated `get_raw_values` method in `AreaWeights` to ensure compatibility with `scipy.spatial.SphericalVoronoi` by converting `latitudes` and `longitudes` to NumPy arrays before passing them to the `latlon_rad_to_cartesian` function. This resolves an issue where the function would fail if passed Torch Tensors directly. +- ci: Reusable workflows for push, PR, and releases +- ci: ignore docs for downstream ci +- ci: changed Changelog action to create PR +- ci: fixes and permissions on changelog updater + +### Removed + +## [0.2.1](https://github.com/ecmwf/anemoi-graphs/compare/0.2.0...0.2.1) - Anemoi-graph Release, bug fix release + +### Added + +### Changed + +- Fix The 'save_path' argument of the GraphCreator class is optional, allowing users to create graphs without saving them. + +### Removed + +## [0.2.0](https://github.com/ecmwf/anemoi-graphs/compare/0.1.0...0.2.0) - Anemoi-graph Release, Icosahedral graph building + +### Added + +- New node builders by iteratively refining an icosahedron: TriNodes, HexNodes. +- New edge builders for building multi-scale connections. +- Added Changelog + +### Changed + +### Removed + +## [0.1.0](https://github.com/ecmwf/anemoi-graphs/releases/tag/0.1.0) - Initial Release, Global graph building + +### Added + +- Documentation +- Initial implementation for global graph building on the fly from Zarr and NPZ datasets + +### Changed + +### Removed + + diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md new file mode 100644 index 0000000..b226907 --- /dev/null +++ b/CONTRIBUTORS.md @@ -0,0 +1,13 @@ +## How to Contribute + +Please see the [read the docs](https://anemoi-training.readthedocs.io/en/latest/dev/contributing.html). + + +## Contributors + +Thank you to all the wonderful people who have contributed to Anemoi. Contributions can come in many forms, including code, documentation, bug reports, feature suggestions, design, and more. A list of code-based contributors can be found [here](https://github.com/ecmwf/anemoi-graphs/graphs/contributors). + + +## Contributing Organisations + +Significant contributions have been made by the following organisations: [DWD](https://www.dwd.de/), [MET Norway](https://www.met.no/), [MeteoSwiss](https://www.meteoswiss.admin.ch/), [RMI](https://www.meteo.be/) & [ECMWF](https://www.ecmwf.int/) diff --git a/README.md b/README.md index 607c2d3..21b3248 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ $ anemoi-graphs create recipe.yaml my_graph.pt ## License ``` -Copyright 2022, European Centre for Medium Range Weather Forecasts. +Copyright 2024, Anemoi contributors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/docs/_static/cutoff.jpg b/docs/_static/cutoff.jpg new file mode 100644 index 0000000..90f20dd Binary files /dev/null and b/docs/_static/cutoff.jpg differ diff --git a/docs/_static/hetero_data_graph.txt b/docs/_static/hetero_data_graph.txt new file mode 100644 index 0000000..b5b02f1 --- /dev/null +++ b/docs/_static/hetero_data_graph.txt @@ -0,0 +1,30 @@ +HeteroData( + data={ + x=[40320, 2], # coordinates in radians (lat in [-pi/2, pi/2], lon in [0, 2pi]) + node_type='ZarrDatasetNodes', + area_weight=[40320, 1], + }, + hidden={ + x=[10242, 2], # coordinates in radians (lat in [-pi/2, pi/2], lon in [0, 2pi]) + node_type='TriNodes', + area_weight=[10242, 1], + }, + (data, to, hidden)={ + edge_index=[2, 62980], + edge_type='CutOffEdges', + edge_length=[62980, 1], + edge_dirs=[62980, 2], + }, + (hidden, to, hidden)={ + edge_index=[2, 81900], + edge_type='MultiScaleEdges', + edge_length=[81900, 1], + edge_dirs=[81900, 2], + }, + (hidden, to, data)={ + edge_index=[2, 120960], + edge_type='KNNEdges', + edge_length=[120960, 1], + edge_dirs=[120960, 2], + } +) diff --git a/docs/cli/create.rst b/docs/cli/create.rst new file mode 100644 index 0000000..4f02101 --- /dev/null +++ b/docs/cli/create.rst @@ -0,0 +1,15 @@ +.. _cli-create: + +====== +create +====== + +Use this command to create a graph from a recipe file. + +The syntax of the recipe file is described in :doc:`building graphs <../graphs/introduction>`. + +.. argparse:: + :module: anemoi.graphs.__main__ + :func: create_parser + :prog: anemoi-graphs + :path: create diff --git a/docs/cli/describe.rst b/docs/cli/describe.rst new file mode 100644 index 0000000..18bc73b --- /dev/null +++ b/docs/cli/describe.rst @@ -0,0 +1,15 @@ +.. _cli-describe: + +======== +describe +======== + +Use this command to describe a graph stored in your filesystem. It will print graph information to the console. + +The syntax of the recipe file is described in :doc:`building graphs <../graphs/introduction>`. + +.. argparse:: + :module: anemoi.graphs.__main__ + :func: create_parser + :prog: anemoi-graphs + :path: describe diff --git a/docs/cli/inspect.rst b/docs/cli/inspect.rst new file mode 100644 index 0000000..394d529 --- /dev/null +++ b/docs/cli/inspect.rst @@ -0,0 +1,15 @@ +.. _cli-inspect: + +======== +inspect +======== + +Use this command to inspect a graph stored in your filesystem. + +The syntax of the recipe file is described in :doc:`building graphs <../graphs/introduction>`. + +.. argparse:: + :module: anemoi.graphs.__main__ + :func: create_parser + :prog: anemoi-graphs + :path: inspect diff --git a/docs/cli/introduction.rst b/docs/cli/introduction.rst new file mode 100644 index 0000000..4fa01db --- /dev/null +++ b/docs/cli/introduction.rst @@ -0,0 +1,29 @@ +.. _cli-introduction: + +============= +Introduction +============= + +When you install the `anemoi-graphs` package, this will also install command line tool +called ``anemoi-graphs`` which can be used to design and inspect weather graphs. + +The tool can provide help with the ``--help`` options: + +.. code-block:: bash + + % anemoi-graphs --help + +The commands are: + +.. toctree:: + :maxdepth: 1 + + create + describe + inspect + +.. 
argparse:: + :module: anemoi.graphs.__main__ + :func: create_parser + :prog: anemoi-graphs + :nosubcommands: diff --git a/docs/conf.py b/docs/conf.py index aff04af..8e507b7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -101,6 +101,10 @@ "https://anemoi-registry.readthedocs.io/en/latest/", ("../../anemoi-registry/docs/_build/html/objects.inv", None), ), + "anemoi-transform": ( + "https://anemoi-transform.readthedocs.io/en/latest/", + ("../../anemoi-transform/docs/_build/html/objects.inv", None), + ), } diff --git a/docs/dev/code_structure.rst b/docs/dev/code_structure.rst new file mode 100644 index 0000000..3eed014 --- /dev/null +++ b/docs/dev/code_structure.rst @@ -0,0 +1,81 @@ +.. _dev-code_structure: + +################ + Code Structure +################ + +Understanding and maintaining the code structure is crucial for +sustainable development of Anemoi Graphs. This guide outlines best +practices for contributing to the codebase. + +****************************** + Subclassing for New Features +****************************** + +When creating a new feature, the recommended practice is to subclass +existing base classes rather than modifying them directly. This approach +preserves functionality for other users while allowing for +customization. + +Example: +======== + +In `anemoi/graphs/nodes/builder.py`, the `BaseNodeBuilder` class serves +as a foundation to define new sets of nodes. New node builders should +subclass this base class. + +******************* + File Organization +******************* + +When developing multiple new functions for a feature: + +#. Create a new file in the folder (e.g., + `edges/builder/.py`) to avoid confusion with base + functions. + +#. Group related functionality together for better organization and + maintainability. + +******************************** + Version Control Best Practices +******************************** + +#. Always use pre-commit hooks to ensure code quality and consistency. +#. Never commit directly to the `develop` branch. +#. Create a new branch for your feature or bug fix, e.g., + `feature/` or `bugfix/`. +#. Submit a Pull Request from your branch to `develop` for peer review + and testing. + +****************************** + Code Style and Documentation +****************************** + +#. Follow PEP 8 guidelines for Python code style, the pre-commit hooks + will help enforce this. +#. Write clear, concise docstrings for all classes and functions using + the Numpy style. +#. Use type hints to improve code readability and catch potential + errors. +#. Add inline comments for complex logic or algorithms. + +********* + Testing +********* + +#. Write unit tests for new features using pytest. +#. Ensure all existing tests pass before submitting a Pull Request. +#. Aim for high test coverage, especially for critical functionality. + +**************************** + Performance Considerations +**************************** + +#. Profile your code to identify performance bottlenecks. +#. Optimize critical paths and frequently called functions. +#. Consider using vectorized operations when working with large + datasets. + +By following these guidelines, you'll contribute to a maintainable and +robust codebase for Anemoi Graphs. diff --git a/docs/dev/contributing.rst b/docs/dev/contributing.rst new file mode 100644 index 0000000..d04b5fd --- /dev/null +++ b/docs/dev/contributing.rst @@ -0,0 +1,151 @@ +.. _dev-contributing: + +############## + Contributing +############## + +Thank you for your interest in contributing to Anemoi Graphs! 
This guide +will help you get started with the development process. + +**************************************** + Setting Up the Development Environment +**************************************** + +#. Clone the repository: + + .. code:: bash + + git clone https://github.com/ecmwf/anemoi-graphs/ + cd anemoi-graphs + +#. Install dependencies: + + .. code:: bash + + # For all dependencies + pip install -e . + + # For development dependencies + pip install -e '.[dev]' + +#. (macOS only) Install pandoc for documentation building: + + .. code:: bash + + brew install pandoc + +****************** + Pre-Commit Hooks +****************** + +We use pre-commit hooks to ensure code quality and consistency. To set +them up: + +#. Install pre-commit hooks: + + .. code:: bash + + pre-commit install + +#. Run hooks on all files to verify installation: + + .. code:: bash + + pre-commit run --all-files + +******************* + Commit Guidelines +******************* + +Ideally, open an issue for the feature or bug fix you're working on +before starting development, to discuss the approach with maintainers. + +When committing code changes: + +#. Make small, focused commits with clear and concise messages. + +#. Follow the `Conventional Commits guidelines + `_, e.g., "feat:", "fix:", + "docs:", etc. + +#. Use present tense and imperative mood in commit messages (e.g., "Add + feature" not "Added feature"). + +#. Reference relevant issue numbers in commit messages when applicable. + +#. Update the ``CHANGELOG.md`` file with a human-friendly summary of + your changes. + +********************** + Pull Request Process +********************** + +#. Create a new branch for your feature or bug fix. +#. Make your changes and commit them using the guidelines above. +#. Push your branch to your fork on GitHub. +#. Open a Pull Request against the `develop` branch of the main + repository. +#. Ensure all tests pass and the code adheres to the project's style + guidelines. +#. Request a review from maintainers or other contributors. + +*************** + Running Tests +*************** + +We use pytest for our test suite. To run tests: + +.. code:: bash + + # Run all tests + pytest + + # Run tests in a specific file + pytest tests/test_.py + +Note: Some tests, like `test_gnn.py`, may run slower on CPU and are +better suited for GPU execution. + +************************ + Building Documentation +************************ + +You can build the documentation locally to preview changes before +submitting a Pull Request. We use Sphinx for documentation. + +You can install the dependencies for building the documentation with: + +.. code:: bash + + pip install '.[docs]' + +To build the documentation locally: + +.. code:: bash + + cd docs + make html + +The generated documentation will be in `docs/_build/html/index.html`. + +********************* + Code Review Process +********************* + +#. All code changes must be reviewed before merging. +#. Address any feedback or comments from reviewers promptly. +#. Once approved, a maintainer will merge your Pull Request. + +****************** + Reporting Issues +****************** + +If you encounter a bug or have a feature request: + +#. Check the existing issues to avoid duplicates. +#. If it's a new issue, create a detailed bug report or feature request. +#. Use clear, descriptive titles and provide as much relevant + information as possible. + +Thank you for contributing to Anemoi Graphs! Your efforts help improve +the project for everyone. 
diff --git a/docs/dev/testing.rst b/docs/dev/testing.rst new file mode 100644 index 0000000..68129b1 --- /dev/null +++ b/docs/dev/testing.rst @@ -0,0 +1,192 @@ +.. _dev-testing: + +######### + Testing +######### + +Comprehensive testing is crucial for maintaining the reliability and +stability of Anemoi Graphs. This guide outlines our testing strategy and +best practices for contributing tests. + +******************* + Testing Framework +******************* + +We use pytest as our primary testing framework. Pytest offers a simple +and powerful way to write and run tests. + +*************** + Writing Tests +*************** + +General Guidelines +================== + +#. Write tests for all new features and bug fixes. +#. Aim for high test coverage, especially for critical components. +#. Keep tests simple, focused, and independent of each other. +#. Use descriptive names for test functions, following the pattern + `test__`. + +Example Test Structure +====================== + +.. code:: python + + import pytest + from anemoi.graphs import SomeFeature + + + def test_some_feature_normal_input(): + feature = SomeFeature() + result = feature.process(normal_input) + assert result == expected_output + + + def test_some_feature_edge_case(): + feature = SomeFeature() + with pytest.raises(ValueError): + feature.process(invalid_input) + +**************** + Types of Tests +**************** + +1. Unit Tests +============= + +Test individual components in isolation. These should be the majority of +your tests. + +2. Integration Tests +==================== + +Test how different components work together. These are particularly +important for graph creation workflows. + +3. Functional Tests +=================== + +Test entire features or workflows from start to finish. These ensure +that the system works as expected from a user's perspective. + +4. Parametrized Tests +===================== + +Use pytest's parametrize decorator to run the same test with different +inputs: + +.. code:: python + + @pytest.mark.parametrize( + "input,expected", + [ + (2, 4), + (3, 9), + (4, 16), + ], + ) + def test_square(input, expected): + assert square(input) == expected + +You can also consider ``hypothesis`` for property-based testing. + +5. Fixtures +=========== + +Use fixtures to set up common test data or objects: + +.. code:: python + + @pytest.fixture + def sample_dataset(): + # Create and return a sample dataset + pass + + + def test_data_loading(sample_dataset): + # Use the sample_dataset fixture in your test + pass + +*************** + Running Tests +*************** + +To run all tests: + +.. code:: bash + + pytest + +To run tests in a specific file: + +.. code:: bash + + pytest tests/test_specific_feature.py + +To run tests with a specific mark: + +.. code:: bash + + pytest -m slow + +*************** + Test Coverage +*************** + +We use pytest-cov to measure test coverage. To run tests with coverage: + +.. code:: bash + + pytest --cov=anemoi_graphs + +Aim for at least 80% coverage for new features, and strive to maintain +or improve overall project coverage. + +************************ + Continuous Integration +************************ + +All tests are run automatically on our CI/CD pipeline for every pull +request. Ensure all tests pass before submitting your PR. + +********************* + Performance Testing +********************* + +For performance-critical components: + +#. Write benchmarks. +#. Compare performance before and after changes. +#. Set up performance regression tests in CI. 
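+
+For instance, a micro-benchmark could be written with the
+``pytest-benchmark`` plugin (a sketch only; the plugin is not currently a
+project dependency, and ``build_graph_stub`` is a hypothetical stand-in
+for the code under test):
+
+.. code:: python
+
+   def build_graph_stub():
+       # Replace with the graph-building call you want to profile.
+       return sum(i * i for i in range(10_000))
+
+
+   def test_build_graph_benchmark(benchmark):
+       # The ``benchmark`` fixture runs the callable repeatedly and records timings.
+       result = benchmark(build_graph_stub)
+       assert result > 0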
+ +********************** + Mocking and Patching +********************** + +Use unittest.mock or pytest-mock for mocking external dependencies or +complex objects: + +.. code:: python + + def test_api_call(mocker): + mock_response = mocker.Mock() + mock_response.json.return_value = {"data": "mocked"} + mocker.patch("requests.get", return_value=mock_response) + + result = my_api_function() + assert result == "mocked" + +**************** + Best Practices +**************** + +#. Keep tests fast: Optimize slow tests or mark them for separate + execution. +#. Use appropriate assertions: pytest provides a rich set of assertions. +#. Test edge cases and error conditions, not just the happy path. +#. Regularly review and update tests as the codebase evolves. +#. Document complex test setups or scenarios. + +By following these guidelines and continuously improving our test suite, +we can ensure the reliability and maintainability of Anemoi Graphs. diff --git a/docs/graphs/edge_attributes.rst b/docs/graphs/edge_attributes.rst new file mode 100644 index 0000000..c7d713b --- /dev/null +++ b/docs/graphs/edge_attributes.rst @@ -0,0 +1,44 @@ +.. _edge-attributes: + +#################### + Edges - Attributes +#################### + +There are 2 main edge attributes implemented in the `anemoi-graphs` +package: + +************* + Edge length +************* + +The `edge length` is a scalar value representing the distance between +the source and target nodes. This attribute is calculated using the +Haversine formula, which is a method of calculating the distance between +two points on the Earth's surface given their latitude and longitude +coordinates. + +.. code:: yaml + + edges: + - ... + edge_builder: ... + attributes: + edge_length: + _target_: anemoi.graphs.edges.attributes.EdgeLength + +**************** + Edge direction +**************** + +The `edge direction` is a 2D vector representing the direction of the +edge. This attribute is calculated from the difference between the +latitude and longitude coordinates of the source and target nodes. + +.. code:: yaml + + edges: + - ... + edge_builder: ... + attributes: + edge_length: + _target_: anemoi.graphs.edges.attributes.EdgeDirection diff --git a/docs/graphs/edges.rst b/docs/graphs/edges.rst new file mode 100644 index 0000000..57aebd2 --- /dev/null +++ b/docs/graphs/edges.rst @@ -0,0 +1,28 @@ +.. _graphs-edges: + +##################### + Edges - Connections +##################### + +Once the `nodes`, :math:`V`, are defined, you can create the `edges`, +:math:`E`, that will connect them. These connections are listed in the +``edges`` section of the recipe file, and they are created independently +for each (`source name`, `target name`) pair specified. + +.. code:: yaml + + edges: + - source_name: data + target_name: hidden + edge_builder: + _target_: anemoi.graphs.edges.CutOff + cutoff_factor: 0.7 + +Below are the available methods for defining the edges: + +.. toctree:: + :maxdepth: 1 + + edges/cutoff + edges/knn + edges/multi_scale diff --git a/docs/graphs/edges/cutoff.rst b/docs/graphs/edges/cutoff.rst new file mode 100644 index 0000000..ba999e6 --- /dev/null +++ b/docs/graphs/edges/cutoff.rst @@ -0,0 +1,52 @@ +################ + Cut-off radius +################ + +The cut-off method is a method for establishing connections between two +sets of nodes. Given two sets of nodes, (`source`, `target`), the +cut-off method connects all source nodes, :math:`V_{source}`, in a +neighbourhood of the target nodes, :math:`V_{target}`. + +.. 
+   :alt: Cut-off radius image
+   :align: center
+
+The neighbourhood is defined by a `cut-off radius`, which is computed
+as,
+
+.. math::
+
+   cutoff\_radius = cutoff\_factor \times nodes\_reference\_dist
+
+where :math:`nodes\_reference\_dist` is the maximum distance between a
+target node and its nearest source node.
+
+.. math::
+
+   nodes\_reference\_dist = \max_{x \in V_{target}} \left\{ \min_{y \in V_{source}, y \neq x} \left\{ d(x, y) \right\} \right\}
+
+where :math:`d(x, y)` is the `Haversine distance
+<https://en.wikipedia.org/wiki/Haversine_formula>`_ between nodes
+:math:`x` and :math:`y`. The ``cutoff_factor`` is a parameter that can
+be adjusted to increase or decrease the size of the neighbourhood, and
+consequently the number of connections in the graph.
+
+To use this method to create your connections, you can use the following
+YAML configuration:
+
+.. code:: yaml
+
+   edges:
+     - source_name: source
+       target_name: destination
+       edge_builder:
+         _target_: anemoi.graphs.edges.CutOffEdges
+         cutoff_factor: 0.6
+
+.. note::
+
+   The cut-off method is recommended for the encoder edges, to connect
+   all data nodes to hidden nodes. The optimal ``cutoff_factor`` value
+   will be the lowest value without orphan nodes. This optimal value
+   depends on the node distribution, so it is recommended to tune it for
+   each case.
diff --git a/docs/graphs/edges/knn.rst b/docs/graphs/edges/knn.rst
new file mode 100644
index 0000000..b90c048
--- /dev/null
+++ b/docs/graphs/edges/knn.rst
@@ -0,0 +1,25 @@
+#####################
+ K-Nearest Neighbors
+#####################
+
+The knn method is a method for establishing connections between two sets
+of nodes. Given two sets of nodes, (`source`, `target`), the knn method
+connects all destination nodes to their ``num_nearest_neighbours``
+nearest source nodes.
+
+To use this method to build your connections, you can use the following
+YAML configuration:
+
+.. code:: yaml
+
+   edges:
+     - source_name: source
+       target_name: destination
+       edge_builder:
+         _target_: anemoi.graphs.edges.KNNEdges
+         num_nearest_neighbours: 3
+
+.. note::
+
+   The knn method is recommended for the decoder edges, to connect all
+   data nodes with the surrounding hidden nodes.
diff --git a/docs/graphs/edges/multi_scale.rst b/docs/graphs/edges/multi_scale.rst
new file mode 100644
index 0000000..7f5878a
--- /dev/null
+++ b/docs/graphs/edges/multi_scale.rst
@@ -0,0 +1,35 @@
+################################################
+ Multi-scale connections at refined icosahedron
+################################################
+
+The multi-scale connections can only be defined with the same source and
+target nodes. Edges of different scales are defined based on the
+refinement level of an icosahedron. The higher the refinement level, the
+shorter the length of the edges. By default, all possible refinement
+levels are considered.
+
+To use this method to build your connections, you can use the following
+YAML configuration:
+
+.. code:: yaml
+
+   edges:
+     - source_name: source
+       target_name: source
+       edge_builder:
+         _target_: anemoi.graphs.edges.MultiScaleEdges
+         x_hops: 1
+
+where `x_hops` is the number of hops between two nodes of the same
+refinement level to be considered neighbours, and then connected.
+
+.. note::
+
+   This method is used by data-driven weather models like GraphCast to
+   process the latent/hidden state.
+
+.. warning::
+
+   This connection method is only supported for building the connections
+   within a set of nodes defined with the ``TriNodes`` or ``HexNodes``
+   classes.
diff --git a/docs/graphs/introduction.rst b/docs/graphs/introduction.rst new file mode 100644 index 0000000..b0d6383 --- /dev/null +++ b/docs/graphs/introduction.rst @@ -0,0 +1,78 @@ +.. _graphs-introduction: + +############## + Introduction +############## + +The `anemoi-graphs` package allows you to design custom graphs for +training data-driven weather models. The graphs are built using a +`recipe`, which is a YAML file that specifies the nodes and edges of the +graph. + +********** + Concepts +********** + +nodes + A `node` represents a location (2D) on the earth's surface which may + contain additional `attributes`. + +data nodes + A set of nodes representing one or multiple datasets. The `data + nodes` may correspond to the input/output of our data-driven model. + They can be defined from Zarr datasets and this method supports all + :ref:`anemoi-datasets ` operations such + as `cutout` or `thinning`. + +hidden nodes + The `hidden nodes` capture intermediate representations of the model, + which are used to learn the dynamics of the system considered + (atmosphere, ocean, etc, ...). These nodes can be generated from + existing locations (Zarr datasets or NPZ files) or algorithmically + from iterative refinements of polygons over the globe. + +isolated nodes + A set of nodes that are not connected to any other nodes in the + graph. These nodes can be used to store additional information that + is not directly used in the training process. + +edges + An `edge` represents a connection between two nodes. The `edges` can + be used to define the flow of information between the nodes. Edges + may also contain `attributes` related to their length, direction or + other properties. + +***************** + Data structures +***************** + +The nodes :math:`V` correspond to locations on the earth's surface, and +they can be classified into 2 categories: + +- **Data nodes**: The `data nodes` represent the input/output of the + data-driven model, i.e. they are linked to existing datasets. +- **Hidden nodes**: These `hidden nodes` represent the latent space, + where the internal dynamics are learned. + +Several methods are currently supported to create your nodes. You can +use indistinctly any of these to create your `data` or `hidden` nodes. + +The `nodes` are defined in the ``nodes`` section of the recipe file. The +keys are the names of the sets of `nodes` that will later be used to +build the connections. Each `nodes` configuration must include a +``node_builder`` section describing how to define the `nodes`. The +following classes define different behaviour: + +- :doc:`node_coordinates/zarr_dataset` +- :doc:`node_coordinates/npz_file` +- :doc:`node_coordinates/tri_refined_icosahedron` +- :doc:`node_coordinates/hex_refined_icosahedron` +- :doc:`node_coordinates/healpix` + +In addition to the ``node_builder`` section, the `nodes` configuration +can contain an optional ``attributes`` section to define additional node +attributes (weights, mask, ...). For example, the weights can be used to +define the importance of each node in the loss function, or the masks +can be used to build connections only between subsets of nodes. + +- :doc:`node_attributes/weights` diff --git a/docs/graphs/node_attributes.rst b/docs/graphs/node_attributes.rst new file mode 100644 index 0000000..9856625 --- /dev/null +++ b/docs/graphs/node_attributes.rst @@ -0,0 +1,23 @@ +.. _graphs-node_attributes: + +#################### + Nodes - Attributes +#################### + +.. warning:: + + This is still a work in progress. 
More classes will be added in the
+   future.
+
+The nodes :math:`V` correspond to locations on the earth's surface. As
+well as defining their locations, the `nodes` can contain additional
+attributes, which should be defined in the ``attributes`` section of the
+`nodes` configuration. For example, a `weights` attribute can be used to
+define the importance of each node in the loss function, or a `masks`
+attribute can be used to build connections only between subsets of
+nodes.
+
+.. toctree::
+   :maxdepth: 1
+
+   node_attributes/weights
diff --git a/docs/graphs/node_attributes/weights.rst b/docs/graphs/node_attributes/weights.rst
new file mode 100644
index 0000000..b3cfccd
--- /dev/null
+++ b/docs/graphs/node_attributes/weights.rst
@@ -0,0 +1,10 @@
+#########
+ Weights
+#########
+
+The `weights` are a node attribute useful for defining the importance of
+a node in the loss function. You can set the weights to follow a
+uniform distribution or to match the area associated with that node.
+
+.. literalinclude:: ../yaml/attributes_weights.yaml
+   :language: yaml
diff --git a/docs/graphs/node_coordinates.rst b/docs/graphs/node_coordinates.rst
new file mode 100644
index 0000000..aff5b70
--- /dev/null
+++ b/docs/graphs/node_coordinates.rst
@@ -0,0 +1,37 @@
+.. _graphs-node_coordinates:
+
+#####################
+ Nodes - Coordinates
+#####################
+
+.. warning::
+
+   This is still a work in progress. More classes will be added in the
+   future.
+
+The `nodes` :math:`V` correspond to locations on the earth's surface.
+
+The `nodes` are defined in the ``nodes`` section of the recipe file. The
+keys are the names of the sets of `nodes` that will later be used to
+build the connections. Each `nodes` configuration must include a
+``node_builder`` section describing how to define the `nodes`.
+
+The `nodes` can be defined based on the coordinates already available in
+a file:
+
+.. toctree::
+   :maxdepth: 1
+
+   node_coordinates/zarr_dataset
+   node_coordinates/npz_file
+
+or based on other algorithms. A common approach is to use an
+icosahedron to project the earth's surface, and refine it iteratively to
+reach the desired resolution.
+
+.. toctree::
+   :maxdepth: 1
+
+   node_coordinates/tri_refined_icosahedron
+   node_coordinates/hex_refined_icosahedron
+   node_coordinates/healpix
diff --git a/docs/graphs/node_coordinates/healpix.csv b/docs/graphs/node_coordinates/healpix.csv
new file mode 100644
index 0000000..1783471
--- /dev/null
+++ b/docs/graphs/node_coordinates/healpix.csv
@@ -0,0 +1,11 @@
+Refinement level,Number of nodes,Resolution (km),Resolution (degrees)
+0,12,6371,57.296
+1,48,3185.5,28.648
+2,192,1592.75,14.324
+3,768,796.375,7.162
+4,3072,398.187,3.581
+5,12288,199.094,1.790
+6,49152,99.547,0.895
+7,196608,49.773,0.448
+8,786432,24.887,0.224
+9,3145728,12.443,0.112
diff --git a/docs/graphs/node_coordinates/healpix.rst b/docs/graphs/node_coordinates/healpix.rst
new file mode 100644
index 0000000..2f5fc17
--- /dev/null
+++ b/docs/graphs/node_coordinates/healpix.rst
@@ -0,0 +1,30 @@
+###############
+ HEALPix Nodes
+###############
+
+This method allows us to define nodes based on the Hierarchical Equal
+Area isoLatitude Pixelation of a sphere (HEALPix). The resolution of the
+HEALPix grid is defined by the `resolution` parameter, which corresponds
+to the number of refinements of the sphere.
+
+.. code:: yaml
+
+   nodes:
+     data:
+       node_builder:
+         _target_: anemoi.graphs.nodes.HEALPixNodes
+         resolution: 3
+       attributes: ...
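+
+As the table below shows, each refinement step multiplies the number of
+nodes by four:
+
+.. math::
+
+   N_{nodes} = 12 \times 4^{L},
+
+where :math:`L` is the refinement level.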
+
+For reference, the following table shows the number of nodes and
+resolution for each refinement level:
+
+.. csv-table:: HEALPix refinements specifications
+   :file: ./healpix.csv
+   :header-rows: 1
+
+.. warning::
+
+   This class will require the `healpy
+   <https://healpy.readthedocs.io/>`_ package to be installed. You can
+   install it with `pip install healpy`.
diff --git a/docs/graphs/node_coordinates/hex_refined.csv b/docs/graphs/node_coordinates/hex_refined.csv
new file mode 100644
index 0000000..dc78f58
--- /dev/null
+++ b/docs/graphs/node_coordinates/hex_refined.csv
@@ -0,0 +1,10 @@
+Refinement Level,Number of nodes,Avg. Hexagon Area (sq km)
+0,122,4250546
+1,842,607220
+2,5882,86745
+3,41162,12392
+4,288122,1770
+5,2016842,252
+6,14117882,36
+7,98825162,5.1
+8,691776122,0.7
diff --git a/docs/graphs/node_coordinates/hex_refined_icosahedron.rst b/docs/graphs/node_coordinates/hex_refined_icosahedron.rst
new file mode 100644
index 0000000..7f665cd
--- /dev/null
+++ b/docs/graphs/node_coordinates/hex_refined_icosahedron.rst
@@ -0,0 +1,36 @@
+###############################
+ Hexagonal refined Icosahedron
+###############################
+
+This method allows us to define the nodes based on the Hexagonal
+Hierarchical Geospatial Indexing System, which uses hexagons to divide
+the sphere. Each refinement level divides each hexagon into seven
+smaller hexagons.
+
+To define the `node coordinates` based on the hexagonal refinements of
+an icosahedron, you can use the following YAML configuration:
+
+.. code:: yaml
+
+   nodes:
+     data:
+       node_builder:
+         _target_: anemoi.graphs.nodes.HexNodes
+         resolution: 4
+       attributes: ...
+
+where resolution is the number of refinements to be applied.
+
+.. csv-table:: Hexagonal Hierarchical refinements specifications
+   :file: ./hex_refined.csv
+   :header-rows: 1
+
+Note that the refinement level is the parameter used to control the
+resolution of the nodes, but the resolution also depends on the
+refinement method. Thus, for the same refinement level, ``HexNodes``
+will have a higher resolution than ``TriNodes``.
+
+.. warning::
+
+   This class will require the `h3 <https://h3geo.org/>`_ package to be
+   installed. You can install it with `pip install h3`.
diff --git a/docs/graphs/node_coordinates/npz_file.rst b/docs/graphs/node_coordinates/npz_file.rst
new file mode 100644
index 0000000..266687d
--- /dev/null
+++ b/docs/graphs/node_coordinates/npz_file.rst
@@ -0,0 +1,30 @@
+###############
+ From NPZ file
+###############
+
+To define the `node coordinates` based on an NPZ file, you can use the
+following YAML configuration:
+
+.. code:: yaml
+
+   nodes:
+     data:
+       node_builder:
+         _target_: anemoi.graphs.nodes.NPZFileNodes
+         grids_definition_path: /path/to/folder/with/grids/
+         resolution: o48
+
+where `grids_definition_path` is the path to the folder containing the
+grid definition files and `resolution` is the resolution of the grid to
+be used.
+
+By default, the grid files are supposed to be in the `grids` folder in
+the same directory as the recipe file. The grid definition files are
+expected to be named `"grid_{resolution}.npz"`.
+
+.. note::
+
+   The NPZ file should contain the following keys:
+
+   - `longitudes`: The longitudes of the grid.
+   - `latitudes`: The latitudes of the grid.
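+
+For illustration only, a grid definition file with the expected keys
+could be generated as follows (the coordinates form a toy regular grid
+and the ``o48`` file name is just an example, not a real grid
+definition):
+
+.. code:: python
+
+   import numpy as np
+
+   # Toy 10 x 20 regular latitude-longitude grid, flattened to 1D arrays.
+   latitudes, longitudes = np.meshgrid(
+       np.linspace(-90, 90, 10),
+       np.linspace(0, 360, 20, endpoint=False),
+       indexing="ij",
+   )
+
+   np.savez(
+       "grids/grid_o48.npz",
+       latitudes=latitudes.ravel(),
+       longitudes=longitudes.ravel(),
+   )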
diff --git a/docs/graphs/node_coordinates/tri_refined_icosahedron.rst b/docs/graphs/node_coordinates/tri_refined_icosahedron.rst
new file mode 100644
index 0000000..44b3e44
--- /dev/null
+++ b/docs/graphs/node_coordinates/tri_refined_icosahedron.rst
@@ -0,0 +1,31 @@
+################################
+ Triangular refined Icosahedron
+################################
+
+This class allows us to define nodes based on iterative refinements of
+an icosahedron with triangles.
+
+To define the `node coordinates` based on triangular refinements of an
+icosahedron, you can use the following YAML configuration:
+
+.. code:: yaml
+
+   nodes:
+     data:
+       node_builder:
+         _target_: anemoi.graphs.nodes.TriNodes
+         resolution: 4
+       attributes: ...
+
+where resolution is the number of refinements to be applied to the
+icosahedron.
+
+Note that the refinement level is the parameter used to control the
+resolution of the nodes, but the resolution also depends on the
+refinement method. Thus, for the same refinement level, ``HexNodes``
+will have a higher resolution than ``TriNodes``.
+
+.. warning::
+
+   This class will require the `trimesh <https://trimsh.org/>`_ package
+   to be installed. You can install it with `pip install trimesh`.
diff --git a/docs/graphs/node_coordinates/zarr_dataset.rst b/docs/graphs/node_coordinates/zarr_dataset.rst
new file mode 100644
index 0000000..3723c0e
--- /dev/null
+++ b/docs/graphs/node_coordinates/zarr_dataset.rst
@@ -0,0 +1,39 @@
+###################
+ From Zarr dataset
+###################
+
+This class builds a set of nodes from a Zarr dataset. The nodes are
+defined by the coordinates of the dataset. The ZarrDataset class
+supports operations compatible with :ref:`anemoi-datasets `.
+
+To define the `node coordinates` based on a Zarr dataset, you can use
+the following YAML configuration:
+
+.. code:: yaml
+
+   nodes:
+     data:
+       node_builder:
+         _target_: anemoi.graphs.nodes.ZarrDatasetNodes
+         dataset: /path/to/dataset.zarr
+       attributes: ...
+
+where `dataset` is the path to the Zarr dataset. The
+``ZarrDatasetNodes`` class supports operations compatible with
+:ref:`anemoi-datasets `, such as "cutout". Below is an example of how
+to use the "cutout" operation directly within :ref:`anemoi-graphs `.
+
+.. code:: yaml
+
+   nodes:
+     data:
+       node_builder:
+         _target_: anemoi.graphs.nodes.ZarrDatasetNodes
+         dataset:
+           cutout:
+             - dataset: /path/to/lam_dataset.zarr
+             - dataset: /path/to/boundary_forcing.zarr
+           adjust: "all"
+       attributes: ...
diff --git a/docs/graphs/yaml/attributes_weights.yaml b/docs/graphs/yaml/attributes_weights.yaml
new file mode 100644
index 0000000..f889ef7
--- /dev/null
+++ b/docs/graphs/yaml/attributes_weights.yaml
@@ -0,0 +1,10 @@
+nodes:
+  data:
+    node_builder:
+      _target_: anemoi.graphs.nodes.ZarrDatasetNodes
+      dataset: /path/to/dataset.zarr
+    attributes:
+      weights:
+        _target_: anemoi.graphs.nodes.attributes.AreaWeights
+        norm: unit-max
+  hidden: ...
diff --git a/docs/index.rst b/docs/index.rst index 46aaf9a..3baa098 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -22,29 +22,113 @@ models from existing recipes but with their own data. This package provides a series of utility functions for used by the rest of the *Anemoi* packages. -- :doc:`installing` .. toctree:: :maxdepth: 1 :hidden: - installing + overview + +***************** + Building graphs +***************** + +- :doc:`graphs/introduction` +- :doc:`graphs/node_coordinates` +- :doc:`graphs/node_attributes` +- :doc:`graphs/edges` +- :doc:`graphs/edge_attributes` + +.. 
toctree:: + :maxdepth: 1 + :hidden: + :caption: Building graphs + + graphs/introduction + graphs/node_coordinates + graphs/node_attributes + graphs/edges + graphs/edge_attributes ********* Modules ********* +- :doc:`modules/node_builder` +- :doc:`modules/edge_builder` +- :doc:`modules/node_attributes` +- :doc:`modules/edge_attributes` +- :doc:`modules/graph_creator` +- :doc:`modules/graph_inspector` + +.. toctree:: + :maxdepth: 1 + :hidden: + :caption: Modules + + modules/node_builder + modules/edge_builder + modules/node_attributes + modules/edge_attributes + modules/graph_creator + modules/graph_inspector + +******************* + Command line tool +******************* + +- :doc:`cli/introduction` +- :doc:`cli/create` +- :doc:`cli/describe` +- :doc:`cli/inspect` + +.. toctree:: + :maxdepth: 1 + :hidden: + :caption: Command line tool + + cli/introduction + cli/create + cli/describe + cli/inspect + +************************** + Developing Anemoi Graphs +************************** + +- :doc:`dev/contributing` +- :doc:`dev/code_structure` +- :doc:`dev/testing` + .. toctree:: :maxdepth: 1 - :glob: + :hidden: + :caption: Developing Anemoi Graphs + + dev/contributing + dev/code_structure + dev/testing + +*********** + Tutorials +*********** + +- :doc:`usage/getting_started` + +.. toctree:: + :maxdepth: 1 + :hidden: + :caption: Usage - modules/* + usage/getting_started ***************** Anemoi packages ***************** - :ref:`anemoi-utils ` +- :ref:`anemoi-transform ` - :ref:`anemoi-datasets ` - :ref:`anemoi-models ` - :ref:`anemoi-graphs ` diff --git a/docs/installing.rst b/docs/installing.rst deleted file mode 100644 index e452a31..0000000 --- a/docs/installing.rst +++ /dev/null @@ -1,31 +0,0 @@ -############ - Installing -############ - -To install the package, you can use the following command: - -.. code:: bash - - pip install anemoi-graphs[...options...] - -The options are: - -- ``dev``: install the development dependencies -- ``all``: install all the dependencies - -************** - Contributing -************** - -.. code:: bash - - git clone ... - cd anemoi-graphs - pip install .[dev] - pip install -r docs/requirements.txt - -You may also have to install pandoc on MacOS: - -.. code:: bash - - brew install pandoc diff --git a/docs/modules/dates.rst b/docs/modules/dates.rst deleted file mode 100644 index 94713af..0000000 --- a/docs/modules/dates.rst +++ /dev/null @@ -1,8 +0,0 @@ -####### - dates -####### - -.. automodule:: anemoi.graphs.dates - :members: - :no-undoc-members: - :show-inheritance: diff --git a/docs/modules/edge_attributes.rst b/docs/modules/edge_attributes.rst new file mode 100644 index 0000000..9abb5e7 --- /dev/null +++ b/docs/modules/edge_attributes.rst @@ -0,0 +1,11 @@ +.. _modules-edge_attributes: + +################# + Edge attributes +################# + +.. automodule:: anemoi.graphs.edges.attributes + :members: + :exclude-members: BaseEdgeAttribute + :no-undoc-members: + :show-inheritance: diff --git a/docs/modules/edge_builder.rst b/docs/modules/edge_builder.rst new file mode 100644 index 0000000..1fa555b --- /dev/null +++ b/docs/modules/edge_builder.rst @@ -0,0 +1,11 @@ +.. _modules-edge_builder: + +############## + Edge builder +############## + +.. 
automodule:: anemoi.graphs.edges.builder
+   :members:
+   :exclude-members: BaseEdgeBuilder
+   :no-undoc-members:
+   :show-inheritance:
diff --git a/docs/modules/graph_creator.rst b/docs/modules/graph_creator.rst
new file mode 100644
index 0000000..7221c05
--- /dev/null
+++ b/docs/modules/graph_creator.rst
@@ -0,0 +1,14 @@
+.. _modules-graph_creator:
+
+###############
+ Graph Creator
+###############
+
+This module is used to create custom graphs for data-driven weather
+models. The graphs are built using a `recipe` that defines the structure
+of the graph.
+
+.. automodule:: anemoi.graphs.create
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
diff --git a/docs/modules/graph_inspector.rst b/docs/modules/graph_inspector.rst
new file mode 100644
index 0000000..e7f1ede
--- /dev/null
+++ b/docs/modules/graph_inspector.rst
@@ -0,0 +1,17 @@
+.. _modules-graph_inspector:
+
+#################
+ Graph Inspector
+#################
+
+This module is used to inspect graphs. This inspection includes:
+
+- Distribution plots of node & edge attributes.
+- Interactive plot of each subgraph.
+- Interactive plot of isolated nodes.
+- Description of the graph in the console.
+
+.. automodule:: anemoi.graphs.inspector
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
diff --git a/docs/modules/node_attributes.rst b/docs/modules/node_attributes.rst
new file mode 100644
index 0000000..3193409
--- /dev/null
+++ b/docs/modules/node_attributes.rst
@@ -0,0 +1,11 @@
+.. _modules-node_attributes:
+
+#################
+ Node attributes
+#################
+
+.. automodule:: anemoi.graphs.nodes.attributes
+   :members:
+   :exclude-members: BaseWeights
+   :no-undoc-members:
+   :show-inheritance:
diff --git a/docs/modules/node_builder.rst b/docs/modules/node_builder.rst
new file mode 100644
index 0000000..9a83de5
--- /dev/null
+++ b/docs/modules/node_builder.rst
@@ -0,0 +1,11 @@
+.. _modules-node_builder:
+
+##############
+ Node builder
+##############
+
+.. automodule:: anemoi.graphs.nodes.builder
+   :members:
+   :exclude-members: BaseNodeBuilder,IcosahedralNodes
+   :no-undoc-members:
+   :show-inheritance:
diff --git a/docs/overview.rst b/docs/overview.rst
new file mode 100644
index 0000000..603709b
--- /dev/null
+++ b/docs/overview.rst
@@ -0,0 +1,127 @@
+.. _overview:
+
+##########
+ Overview
+##########
+
+A graph :math:`G = (V, E)` is a collection of nodes/vertices :math:`V`
+and edges :math:`E` that connect the nodes. The nodes can represent
+locations on the globe.
+
+In weather models, the nodes :math:`V` can generally be classified into
+2 categories:
+
+- **Data nodes**: The `data nodes` represent the input/output of the
+  data-driven model, so they are linked to existing datasets.
+- **Hidden nodes**: These `hidden nodes` represent the latent space,
+  where the internal dynamics are learned.
+
+Similarly, the edges :math:`E` can be classified into 3 categories:
+
+- **Encoder edges**: These `encoder edges` connect the `data` nodes
+  with the `hidden` nodes to encode the input data into the latent
+  space.
+
+- **Processor edges**: These `processor edges` connect the `hidden`
+  nodes with the `hidden` nodes to process the latent space.
+
+- **Decoder edges**: These `decoder edges` connect the `hidden` nodes
+  with the `data` nodes to decode the latent space into the output
+  data.
+
+When building the graph with `anemoi-graphs`, there is no difference
+between these categories.
However, it is important to keep this +distinction in mind when designing a weather graph to be used in a +data-driven model with :ref:`anemoi-training +`. + +******************* + Design principles +******************* + +In particular, when designing a graph for a weather model, the following +guidelines should be followed: + +- Use a coarser resolution for the `hidden nodes`. This will reduce the + computational cost of training and inference. +- All input nodes should be connected to the `hidden nodes`. This will + ensure that all available information can be used. +- In the encoder edges, minimise the number of connections to the + `hidden nodes`. This will reduce the computational cost. +- All output nodes should have incoming connections from a few + surrounding `hidden nodes`. +- The number of incoming connections in each set of nodes should be + similar to make the training more stable. +- Consider whether your use case requires long-range connections + between the `hidden nodes`. + +**************** + Data structure +**************** + +The graphs generated by :ref:`anemoi-graphs ` +are represented as a `pytorch_geometric.data.HeteroData +`_ +object. They include all the attributes specified in the recipe file and +the node/edge type. The node/edge type represents the node/edge builder +used to create the set of nodes/edges. + +.. literalinclude:: _static/hetero_data_graph.txt + :language: console + +The `HeteroData` object contains some useful attributes such as +`node_types` and `edge_types`, which list the node and edge types +defined in the graph. + +.. code:: console + + >>> graph.node_types + ['data', 'hidden'] + + >>> graph.edge_types + [("data", "to", "hidden"), ("hidden", "to", "hidden"), ("hidden", "to", "data")] + +In addition, you can inspect the attributes of the nodes and edges using +the `node_attrs` and `edge_attrs` methods. + +.. code:: console + + >>> graph["data"].node_attrs() + ["x", "area_weight"] + + >>> graph[("data", "to", "hidden")].edge_attrs() + ['edge_index', 'edge_length', 'edge_dirs'] + +************ + Installing +************ + +To install the package, you can use the following command: + +.. code:: bash + + pip install anemoi-graphs[...options...] + +The options are: + +- ``dev``: install the development dependencies +- ``docs``: install the dependencies for the documentation +- ``tests``: install the dependencies for testing +- ``all``: install all the dependencies + +************** + Contributing +************** + +.. code:: bash + + git clone ... + cd anemoi-graphs + pip install .[dev] + pip install -r docs/requirements.txt + +You may also have to install pandoc on MacOS: + +.. code:: bash + + brew install pandoc diff --git a/docs/usage/getting_started.rst b/docs/usage/getting_started.rst new file mode 100644 index 0000000..4c322a2 --- /dev/null +++ b/docs/usage/getting_started.rst @@ -0,0 +1,117 @@ +.. _usage-getting-started: + +################# + Getting started +################# + +************** + First recipe +************** + +The simplest use case is to build an encoder-processor-decoder graph for +a global weather model. In this case, the recipe must contain a +``nodes`` section where the keys will be the names of the sets of +`nodes`, which will later be used to build the connections. Each `nodes` +configuration must include a ``node_builder`` section describing how to +generate the `nodes`, and it may include an optional ``attributes`` +section to define additional attributes (weights, mask, ...). + +..
literalinclude:: yaml/nodes.yaml + :language: yaml + +Once the `nodes` have been defined, you need to create the edges between +them through which information will flow. To this end, the recipe file +must contain an ``edges`` section. These connections are defined between +pairs of `nodes` (source and target, specified by `source_name` and +`target_name`). + +There are several methods to build these edges, such as cut-off +(`CutOffEdges`) or nearest neighbours (`KNNEdges`). For an +encoder-processor-decoder graph, you will first need to build two sets of +`edges`. The first set of edges will connect the `data` nodes with the +`hidden` nodes to encode the input data into the latent space, normally +referred to as the `encoder edges` and represented here by the first +element of the ``edges`` section. The second set of `edges` will connect +the `hidden` nodes with the `data` nodes to decode the latent space into +the output data, normally referred to as `decoder edges` and represented +here by the second element of the ``edges`` section. + +.. literalinclude:: yaml/global_wo-proc.yaml + :language: yaml + +.. figure:: schemas/global_wo-proc.png + :alt: Schema of global graph (without processor connections) + :align: center + +To create the graph, run the following command: + +.. code:: console + + $ anemoi-graphs create recipe.yaml graph.pt + +Once the build is complete, you can inspect the resulting graph using the +following command: + +.. code:: console + + $ anemoi-graphs inspect graph.pt output_plots + +This will generate the following graph: + +.. literalinclude:: yaml/global_wo-proc.txt + :language: console + +.. note:: + + Note that the resulting graph will only work with a Transformer + processor because there are no connections between the `hidden + nodes`. + +****************************** + Adding processor connections +****************************** + +To add connections between the ``hidden`` nodes, to be used in the +processor, you need to add a new set of `edges` to the recipe file. +These connections are normally referred to as `processor edges` and are +represented here by the third element of the ``edges`` section. + +.. literalinclude:: yaml/global.yaml + :language: yaml + +.. figure:: schemas/global.png + :alt: Schema of global graph + :align: center + +This will generate the following graph: + +.. literalinclude:: yaml/global.txt + :language: console + +******************* + Adding attributes +******************* + +When training a data-driven weather model, it is common to add +attributes to the nodes or edges. For example, you may want to add node +attributes to weight the loss function, or add edge attributes to +represent the direction of the edges. + +To add attributes to the `nodes`, you must include the `attributes` +section in the `nodes` configuration. The attributes are defined as a +mapping, where each key is the name of an attribute and its value is the +configuration used to build it. + +.. literalinclude:: yaml/nodes_with-attrs.yaml + :language: yaml + +To add the extra features to the edges of the graph, you need to set +them in the ``attributes`` section. + +.. literalinclude:: yaml/global_with-attrs.yaml + :language: yaml + +This will generate the following graph: + +..
literalinclude:: yaml/global_with-attrs.txt + :language: console diff --git a/docs/usage/schemas/global.excalidraw b/docs/usage/schemas/global.excalidraw new file mode 100644 index 0000000..eca62b2 --- /dev/null +++ b/docs/usage/schemas/global.excalidraw @@ -0,0 +1,479 @@ +{ + "type": "excalidraw", + "version": 2, + "source": "https://excalidraw.com", + "elements": [ + { + "type": "diamond", + "version": 1574, + "versionNonce": 1952312215, + "index": "akG", + "isDeleted": false, + "id": "Znt9M8pxS0HpC9GiKp7f2", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": -686.275390625, + "y": 8.697265625, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "width": 189.40234375, + "height": 168.08203124999994, + "seed": 216165446, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "boundElements": [ + { + "id": "9zgm6lmwdP5Lhkwjvggxr", + "type": "text" + }, + { + "id": "X87zkD0RgTBG4qEdx6-6a", + "type": "arrow" + } + ], + "updated": 1718866666194, + "link": null, + "locked": false + }, + { + "type": "text", + "version": 1514, + "versionNonce": 994799639, + "index": "akV", + "isDeleted": false, + "id": "9zgm6lmwdP5Lhkwjvggxr", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": -620.644775390625, + "y": 67.7177734375, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "width": 58.43994140625, + "height": 50, + "seed": 395899782, + "groupIds": [], + "frameId": null, + "roundness": null, + "boundElements": [], + "updated": 1718864334184, + "link": null, + "locked": false, + "fontSize": 20, + "fontFamily": 1, + "text": "Hidden\nnodes", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "Znt9M8pxS0HpC9GiKp7f2", + "originalText": "Hidden\nnodes", + "autoResize": true, + "lineHeight": 1.25 + }, + { + "type": "diamond", + "version": 1473, + "versionNonce": 1197881228, + "index": "al", + "isDeleted": false, + "id": "Qo-9em1mLQnX3Epd04wAU", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": -680.701171875, + "y": 298.513671875, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "width": 189.40234375, + "height": 168.08203124999994, + "seed": 500381062, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "boundElements": [ + { + "type": "text", + "id": "bjYlghPRGa9M149mWesMW" + }, + { + "id": "X87zkD0RgTBG4qEdx6-6a", + "type": "arrow" + } + ], + "updated": 1718794669385, + "link": null, + "locked": false + }, + { + "type": "text", + "version": 1430, + "versionNonce": 91521241, + "index": "am", + "isDeleted": false, + "id": "bjYlghPRGa9M149mWesMW", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": -617.6405715942383, + "y": 357.5341796875, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "width": 63.57997131347656, + "height": 50, + "seed": 2048824518, + "groupIds": [], + "frameId": null, + "roundness": null, + "boundElements": [], + "updated": 1718864334184, + "link": null, + "locked": false, + "fontSize": 20, + "fontFamily": 1, + "text": "Data \nnodes", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "Qo-9em1mLQnX3Epd04wAU", + "originalText": "Data nodes", + "autoResize": true, + "lineHeight": 1.25 + }, + { + "type": "arrow", + "version": 265, + "versionNonce": 
1943685428, + "index": "b0M", + "isDeleted": false, + "id": "X87zkD0RgTBG4qEdx6-6a", + "fillStyle": "solid", + "strokeWidth": 4, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0.014815383044963326, + "x": -634.25390625, + "y": 337.84375, + "strokeColor": "#1e1e1e", + "backgroundColor": "#ffec99", + "width": 64.01953125, + "height": 191.93359375, + "seed": 980493452, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "boundElements": [], + "updated": 1718794678935, + "link": null, + "locked": false, + "startBinding": { + "elementId": "Qo-9em1mLQnX3Epd04wAU", + "focus": -0.22730420345522181, + "gap": 3.2174693078751844 + }, + "endBinding": { + "elementId": "Znt9M8pxS0HpC9GiKp7f2", + "focus": 0.06334851014220298, + "gap": 5.576512041887561 + }, + "lastCommittedPoint": null, + "startArrowhead": null, + "endArrowhead": "arrow", + "points": [ + [ + 0, + 0 + ], + [ + -64.01953125, + -97.4453125 + ], + [ + -1.421875, + -191.93359375 + ] + ] + }, + { + "type": "arrow", + "version": 547, + "versionNonce": 1330253748, + "index": "b0N", + "isDeleted": false, + "id": "o1-27eSnLA7FWSCeI8_hf", + "fillStyle": "solid", + "strokeWidth": 4, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 3.1473660224451887, + "x": -480.7864185018202, + "y": 335.0726012960115, + "strokeColor": "#1e1e1e", + "backgroundColor": "#ffec99", + "width": 64.01953125, + "height": 191.93359375, + "seed": 69264820, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "boundElements": [], + "updated": 1718794699487, + "link": null, + "locked": false, + "startBinding": null, + "endBinding": null, + "lastCommittedPoint": null, + "startArrowhead": null, + "endArrowhead": "arrow", + "points": [ + [ + 0, + 0 + ], + [ + -64.01953125, + -97.4453125 + ], + [ + -1.421875, + -191.93359375 + ] + ] + }, + { + "type": "text", + "version": 104, + "versionNonce": 447960375, + "index": "b0O", + "isDeleted": false, + "id": "w4_mrYxHlHgpbd01zu-1J", + "fillStyle": "solid", + "strokeWidth": 4, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": -747.6875, + "y": 229.17578125, + "strokeColor": "#1e1e1e", + "backgroundColor": "#ffec99", + "width": 36.3828125, + "height": 25, + "seed": 778583092, + "groupIds": [], + "frameId": null, + "roundness": null, + "boundElements": [], + "updated": 1718864334184, + "link": null, + "locked": false, + "fontSize": 20, + "fontFamily": 1, + "text": "A)", + "textAlign": "left", + "verticalAlign": "top", + "containerId": null, + "originalText": "A)", + "autoResize": false, + "lineHeight": 1.25 + }, + { + "type": "text", + "version": 34, + "versionNonce": 482409913, + "index": "b0P", + "isDeleted": false, + "id": "8jI1ivvTgDU6LdGAtBLst", + "fillStyle": "solid", + "strokeWidth": 4, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": -463.83203125, + "y": 219.03125, + "strokeColor": "#1e1e1e", + "backgroundColor": "#ffec99", + "width": 21.29998779296875, + "height": 25, + "seed": 335072396, + "groupIds": [], + "frameId": null, + "roundness": null, + "boundElements": [], + "updated": 1718864334184, + "link": null, + "locked": false, + "fontSize": 20, + "fontFamily": 1, + "text": "B)", + "textAlign": "left", + "verticalAlign": "top", + "containerId": null, + "originalText": "B)", + "autoResize": true, + "lineHeight": 1.25 + }, + { + "id": "EGBSbXlhXBORmnsVFvEts", + "type": "arrow", + "x": -678.5792230220354, + "y": 13.811080021389312, + "width": 87.03339052302621, + "height": 
48.768129202436036, + "angle": 5.148276390964529, + "strokeColor": "#1e1e1e", + "backgroundColor": "#ffec99", + "fillStyle": "solid", + "strokeWidth": 4, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "index": "b0Q", + "roundness": { + "type": 2 + }, + "seed": 1974783065, + "version": 1072, + "versionNonce": 1886357911, + "isDeleted": false, + "boundElements": null, + "updated": 1718866748897, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + 39.15328310546727, + -48.768129202436036 + ], + [ + 87.03339052302621, + -1.8211131319121323 + ] + ], + "lastCommittedPoint": null, + "startBinding": null, + "endBinding": null, + "startArrowhead": null, + "endArrowhead": null + }, + { + "type": "arrow", + "version": 1384, + "versionNonce": 918275127, + "index": "b0R", + "isDeleted": false, + "id": "uEVsPEcX9f8SeEA9Q79-X", + "fillStyle": "solid", + "strokeWidth": 4, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 1.194283386269774, + "x": -597.5325644796139, + "y": 13.995674352769438, + "strokeColor": "#1e1e1e", + "backgroundColor": "#ffec99", + "width": 87.03339052302621, + "height": 48.768129202436036, + "seed": 1101268537, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "boundElements": [], + "updated": 1718866755747, + "link": null, + "locked": false, + "startBinding": null, + "endBinding": null, + "lastCommittedPoint": null, + "startArrowhead": null, + "endArrowhead": "arrow", + "points": [ + [ + 0, + 0 + ], + [ + 39.15328310546727, + -48.768129202436036 + ], + [ + 87.03339052302621, + -1.8211131319121323 + ] + ] + }, + { + "type": "text", + "version": 113, + "versionNonce": 1667036473, + "index": "b0S", + "isDeleted": false, + "id": "VQ-NFIzfci1S3-cTTm0_G", + "fillStyle": "solid", + "strokeWidth": 4, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": -606.0249938964844, + "y": -82.9609375, + "strokeColor": "#1e1e1e", + "backgroundColor": "#ffec99", + "width": 19.639984130859375, + "height": 25, + "seed": 1604276025, + "groupIds": [], + "frameId": null, + "roundness": null, + "boundElements": [], + "updated": 1718866732266, + "link": null, + "locked": false, + "fontSize": 20, + "fontFamily": 1, + "text": "C)", + "textAlign": "left", + "verticalAlign": "top", + "containerId": null, + "originalText": "C)", + "autoResize": true, + "lineHeight": 1.25 + } + ], + "appState": { + "gridSize": null, + "viewBackgroundColor": "#ffffff" + }, + "files": {} +} diff --git a/docs/usage/schemas/global.png b/docs/usage/schemas/global.png new file mode 100644 index 0000000..2c2a429 Binary files /dev/null and b/docs/usage/schemas/global.png differ diff --git a/docs/usage/schemas/global_wo-proc.excalidraw b/docs/usage/schemas/global_wo-proc.excalidraw new file mode 100644 index 0000000..7b28c79 --- /dev/null +++ b/docs/usage/schemas/global_wo-proc.excalidraw @@ -0,0 +1,344 @@ +{ + "type": "excalidraw", + "version": 2, + "source": "https://excalidraw.com", + "elements": [ + { + "type": "diamond", + "version": 1572, + "versionNonce": 1863238796, + "index": "akG", + "isDeleted": false, + "id": "Znt9M8pxS0HpC9GiKp7f2", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": -686.275390625, + "y": 8.697265625, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "width": 189.40234375, + "height": 168.08203124999994, + "seed": 216165446, + "groupIds": [], + "frameId": null, + 
"roundness": { + "type": 2 + }, + "boundElements": [ + { + "id": "9zgm6lmwdP5Lhkwjvggxr", + "type": "text" + }, + { + "id": "X87zkD0RgTBG4qEdx6-6a", + "type": "arrow" + } + ], + "updated": 1718794669385, + "link": null, + "locked": false + }, + { + "type": "text", + "version": 1508, + "versionNonce": 358396172, + "index": "akV", + "isDeleted": false, + "id": "9zgm6lmwdP5Lhkwjvggxr", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": -620.644775390625, + "y": 67.7177734375, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "width": 58.43994140625, + "height": 50, + "seed": 395899782, + "groupIds": [], + "frameId": null, + "roundness": null, + "boundElements": [], + "updated": 1718794797984, + "link": null, + "locked": false, + "fontSize": 20, + "fontFamily": 1, + "text": "Hidden\nnodes", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "Znt9M8pxS0HpC9GiKp7f2", + "originalText": "Hidden\nnodes", + "autoResize": true, + "lineHeight": 1.25 + }, + { + "type": "diamond", + "version": 1473, + "versionNonce": 1197881228, + "index": "al", + "isDeleted": false, + "id": "Qo-9em1mLQnX3Epd04wAU", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": -680.701171875, + "y": 298.513671875, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "width": 189.40234375, + "height": 168.08203124999994, + "seed": 500381062, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "boundElements": [ + { + "type": "text", + "id": "bjYlghPRGa9M149mWesMW" + }, + { + "id": "X87zkD0RgTBG4qEdx6-6a", + "type": "arrow" + } + ], + "updated": 1718794669385, + "link": null, + "locked": false + }, + { + "type": "text", + "version": 1424, + "versionNonce": 13350324, + "index": "am", + "isDeleted": false, + "id": "bjYlghPRGa9M149mWesMW", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": -617.6405715942383, + "y": 357.5341796875, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "width": 63.57997131347656, + "height": 50, + "seed": 2048824518, + "groupIds": [], + "frameId": null, + "roundness": null, + "boundElements": [], + "updated": 1718794797984, + "link": null, + "locked": false, + "fontSize": 20, + "fontFamily": 1, + "text": "Data \nnodes", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "Qo-9em1mLQnX3Epd04wAU", + "originalText": "Data nodes", + "autoResize": true, + "lineHeight": 1.25 + }, + { + "id": "X87zkD0RgTBG4qEdx6-6a", + "type": "arrow", + "x": -634.25390625, + "y": 337.84375, + "width": 64.01953125, + "height": 191.93359375, + "angle": 0.014815383044963326, + "strokeColor": "#1e1e1e", + "backgroundColor": "#ffec99", + "fillStyle": "solid", + "strokeWidth": 4, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "index": "b0M", + "roundness": { + "type": 2 + }, + "seed": 980493452, + "version": 265, + "versionNonce": 1943685428, + "isDeleted": false, + "boundElements": null, + "updated": 1718794678935, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + -64.01953125, + -97.4453125 + ], + [ + -1.421875, + -191.93359375 + ] + ], + "lastCommittedPoint": null, + "startBinding": { + "elementId": "Qo-9em1mLQnX3Epd04wAU", + "focus": -0.22730420345522181, + "gap": 3.2174693078751844 + }, + "endBinding": { + "elementId": 
"Znt9M8pxS0HpC9GiKp7f2", + "focus": 0.06334851014220298, + "gap": 5.576512041887561 + }, + "startArrowhead": null, + "endArrowhead": "arrow" + }, + { + "type": "arrow", + "version": 547, + "versionNonce": 1330253748, + "index": "b0N", + "isDeleted": false, + "id": "o1-27eSnLA7FWSCeI8_hf", + "fillStyle": "solid", + "strokeWidth": 4, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 3.1473660224451887, + "x": -480.7864185018202, + "y": 335.0726012960115, + "strokeColor": "#1e1e1e", + "backgroundColor": "#ffec99", + "width": 64.01953125, + "height": 191.93359375, + "seed": 69264820, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "boundElements": [], + "updated": 1718794699487, + "link": null, + "locked": false, + "startBinding": null, + "endBinding": null, + "lastCommittedPoint": null, + "startArrowhead": null, + "endArrowhead": "arrow", + "points": [ + [ + 0, + 0 + ], + [ + -64.01953125, + -97.4453125 + ], + [ + -1.421875, + -191.93359375 + ] + ] + }, + { + "id": "w4_mrYxHlHgpbd01zu-1J", + "type": "text", + "x": -747.6875, + "y": 229.17578125, + "width": 36.3828125, + "height": 25, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "#ffec99", + "fillStyle": "solid", + "strokeWidth": 4, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "index": "b0O", + "roundness": null, + "seed": 778583092, + "version": 98, + "versionNonce": 841301132, + "isDeleted": false, + "boundElements": null, + "updated": 1718794797984, + "link": null, + "locked": false, + "text": "A)", + "fontSize": 20, + "fontFamily": 1, + "textAlign": "left", + "verticalAlign": "top", + "containerId": null, + "originalText": "A)", + "autoResize": false, + "lineHeight": 1.25 + }, + { + "id": "8jI1ivvTgDU6LdGAtBLst", + "type": "text", + "x": -463.83203125, + "y": 219.03125, + "width": 21.29998779296875, + "height": 25, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "#ffec99", + "fillStyle": "solid", + "strokeWidth": 4, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "index": "b0P", + "roundness": null, + "seed": 335072396, + "version": 28, + "versionNonce": 1531159092, + "isDeleted": false, + "boundElements": null, + "updated": 1718794797984, + "link": null, + "locked": false, + "text": "B)", + "fontSize": 20, + "fontFamily": 1, + "textAlign": "left", + "verticalAlign": "top", + "containerId": null, + "originalText": "B)", + "autoResize": true, + "lineHeight": 1.25 + } + ], + "appState": { + "gridSize": null, + "viewBackgroundColor": "#ffffff" + }, + "files": {} +} diff --git a/docs/usage/schemas/global_wo-proc.png b/docs/usage/schemas/global_wo-proc.png new file mode 100644 index 0000000..71ffdde Binary files /dev/null and b/docs/usage/schemas/global_wo-proc.png differ diff --git a/docs/usage/yaml/global.txt b/docs/usage/yaml/global.txt new file mode 100644 index 0000000..96da1e1 --- /dev/null +++ b/docs/usage/yaml/global.txt @@ -0,0 +1,21 @@ +┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈ +📦 Path : graph.pt +🔢 Format version: 0.0.1 + +💽 Size : 3.1 MiB (3,283,650) + + Nodes name │ Num. nodes | Attribute dim | Min. latitude | Max. latitude | Min. longitude | Max. 
longitude + ─────────────┼────────────┼───────────────┼───────────────┼───────────────┼────────────────┼──────────────── + data | 10,840 | 0 | -3.135 | 3.140 | 0.02 | 6.13 + hidden | 6,200 | 0 | -3.141 | 3.137 | 0.01 | 6.14 + ─────────────┴────────────┴───────────────┴───────────────┴───────────────┴────────────────┴──────────────── + + + Source │ Destination │ Num. edges │ Attribute dim | Min. length │ Max. length │ Mean length │ Std dev + ─────────────┼──────────────┼─────────────┼───────────────┼─────────────┼─────────────┼─────────────┼───────── + data │ hidden │ 13508 │ 1 | 0.3116 │ 25.79 │ 11.059531 │ 5.5856 + hidden │ data │ 40910 │ 1 | 0.2397 │ 21.851 │ 12.270924 │ 4.2347 + hidden │ hidden │ 32010 │ 1 | 0.2397 │ 21.851 │ 10.270924 │ 4.2347 + ─────────────┴──────────────┴─────────────┴───────────────|─────────────┴─────────────┴─────────────┴───────── +🔋 Graph ready, last update 7 seconds ago. +📊 Statistics ready. diff --git a/docs/usage/yaml/global.yaml b/docs/usage/yaml/global.yaml new file mode 100644 index 0000000..a0bc3bb --- /dev/null +++ b/docs/usage/yaml/global.yaml @@ -0,0 +1,23 @@ +nodes: + data: ... + hidden: ... + +edges: + # A) Encoder connections + - source_name: data + target_name: hidden + edge_builder: + _target_: anemoi.graphs.edges.CutOffEdges + cutoff_factor: 0.7 + # B) Decoder connections + - source_name: hidden + target_name: hidden + edge_builder: + _target_: anemoi.graphs.edges.KNNEdges + nearest_neighbours: 3 + # C) Processor connections + - source_name: hidden + target_name: data + edge_builder: + _target_: anemoi.graphs.edges.KNNEdges + nearest_neighbours: 3 diff --git a/docs/usage/yaml/global_with-attrs.txt b/docs/usage/yaml/global_with-attrs.txt new file mode 100644 index 0000000..57d15a7 --- /dev/null +++ b/docs/usage/yaml/global_with-attrs.txt @@ -0,0 +1,21 @@ +┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈ +📦 Path : graph.pt +🔢 Format version: 0.0.1 + +💽 Size : 3.1 MiB (3,283,650) + + Nodes name │ Num. nodes | Attribute dim | Min. latitude | Max. latitude | Min. longitude | Max. longitude + ─────────────┼────────────┼───────────────┼───────────────┼───────────────┼────────────────┼──────────────── + data | 10,840 | 1 | -3.135 | 3.140 | 0.02 | 6.13 + hidden | 6,200 | 1 | -3.141 | 3.137 | 0.01 | 6.14 + ─────────────┴────────────┴───────────────┴───────────────┴───────────────┴────────────────┴──────────────── + + + Source │ Destination │ Num. edges │ Attribute dim | Min. length │ Max. length │ Mean length │ Std dev + ─────────────┼──────────────┼─────────────┼───────────────┼─────────────┼─────────────┼─────────────┼───────── + data │ hidden │ 13508 │ 3 | 0.3116 │ 25.79 │ 11.059531 │ 5.5856 + hidden │ data │ 40910 │ 3 | 0.2397 │ 21.851 │ 12.270924 │ 4.2347 + hidden │ hidden │ 32010 │ 3 | 0.2397 │ 21.851 │ 10.270924 │ 4.2347 + ─────────────┴──────────────┴─────────────┴───────────────|─────────────┴─────────────┴─────────────┴───────── +🔋 Graph ready, last update 7 seconds ago. +📊 Statistics ready. diff --git a/docs/usage/yaml/global_with-attrs.yaml b/docs/usage/yaml/global_with-attrs.yaml new file mode 100644 index 0000000..fc292d0 --- /dev/null +++ b/docs/usage/yaml/global_with-attrs.yaml @@ -0,0 +1,32 @@ +nodes: + data: ... + hidden: ... 
+ +edges: + # A) Encoder connections + - source_name: data + target_name: hidden + edge_builder: + _target_: anemoi.graphs.edges.CutOffEdges + cutoff_factor: 0.7 + attributes: + edge_length: + _target_: anemoi.graphs.edges.attributes.EdgeLength + # B) Decoder connections + - source_name: hidden + target_name: hidden + edge_builder: + _target_: anemoi.graphs.edges.KNNEdges + nearest_neighbours: 3 + attributes: + edge_length: + _target_: anemoi.graphs.edges.attributes.EdgeLength + # C) Processor connections + - source_name: hidden + target_name: data + edge_builder: + _target_: anemoi.graphs.edges.KNNEdges + nearest_neighbours: 3 + attributes: + edge_length: + _target_: anemoi.graphs.edges.attributes.EdgeLength diff --git a/docs/usage/yaml/global_wo-proc.png b/docs/usage/yaml/global_wo-proc.png new file mode 100644 index 0000000..71ffdde Binary files /dev/null and b/docs/usage/yaml/global_wo-proc.png differ diff --git a/docs/usage/yaml/global_wo-proc.txt b/docs/usage/yaml/global_wo-proc.txt new file mode 100644 index 0000000..c27dafc --- /dev/null +++ b/docs/usage/yaml/global_wo-proc.txt @@ -0,0 +1,20 @@ +┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈ +📦 Path : graph.pt +🔢 Format version: 0.0.1 + +💽 Size : 3.1 MiB (3,283,650) + + Nodes name │ Num. nodes | Attribute dim | Min. latitude | Max. latitude | Min. longitude | Max. longitude + ─────────────┼────────────┼───────────────┼───────────────┼───────────────┼────────────────┼──────────────── + data | 10,840 | 4 | -3.135 | 3.140 | 0.02 | 6.13 + hidden | 6,200 | 4 | -3.141 | 3.137 | 0.01 | 6.14 + ─────────────┴────────────┴───────────────┴───────────────┴───────────────┴────────────────┴──────────────── + + + Source │ Destination │ Num. edges │ Attribute dim | Min. length │ Max. length │ Mean length │ Std dev + ─────────────┼──────────────┼─────────────┼───────────────┼─────────────┼─────────────┼─────────────┼───────── + data │ hidden │ 13508 │ 1 | 0.3116 │ 25.79 │ 11.059531 │ 5.5856 + hidden │ data │ 40910 │ 1 | 0.2397 │ 21.851 │ 12.270924 │ 4.2347 + ─────────────┴──────────────┴─────────────┴───────────────|─────────────┴─────────────┴─────────────┴───────── +🔋 Graph ready, last update 17 seconds ago. +📊 Statistics ready. diff --git a/docs/usage/yaml/global_wo-proc.yaml b/docs/usage/yaml/global_wo-proc.yaml new file mode 100644 index 0000000..c1e3ad3 --- /dev/null +++ b/docs/usage/yaml/global_wo-proc.yaml @@ -0,0 +1,17 @@ +nodes: + data: ... + hidden: ... 
+ +edges: + # A) Encoder connections + - source_name: data + target_name: hidden + edge_builder: + _target_: anemoi.graphs.edges.CutOffEdges + cutoff_factor: 0.7 + # B) Decoder connections + - source_name: hidden + target_name: hidden + edge_builder: + _target_: anemoi.graphs.edges.KNNEdges + nearest_neighbours: 3 diff --git a/docs/usage/yaml/nodes.yaml b/docs/usage/yaml/nodes.yaml new file mode 100644 index 0000000..7c97851 --- /dev/null +++ b/docs/usage/yaml/nodes.yaml @@ -0,0 +1,10 @@ +nodes: + data: + node_builder: + _target_: anemoi.graphs.nodes.ZarrDatasetNodes + dataset: /path/to/dataset.zarr + hidden: + node_builder: + _target_: anemoi.graphs.nodes.NPZFileNodes + grid_definition_path: /path/to/grids/ + resolution: o48 diff --git a/docs/usage/yaml/nodes_with-attrs.yaml b/docs/usage/yaml/nodes_with-attrs.yaml new file mode 100644 index 0000000..44ded27 --- /dev/null +++ b/docs/usage/yaml/nodes_with-attrs.yaml @@ -0,0 +1,10 @@ +nodes: + data: + node_builder: + _target_: anemoi.graphs.nodes.ZarrDatasetNodes + dataset: /path/to/dataset.zarr + attributes: + weights: + _target_: anemoi.graphs.nodes.attributes.AreaWeights + norm: unit-max + hidden: ... diff --git a/pyproject.toml b/pyproject.toml index cb5bb7f..a6148ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,20 +10,13 @@ # https://packaging.python.org/en/latest/guides/writing-pyproject-toml/ [build-system] -requires = [ - "setuptools>=60", - "setuptools-scm>=8", -] +requires = [ "setuptools>=60", "setuptools-scm>=8" ] [project] name = "anemoi-graphs" description = "A package to build graphs for data-driven forecasts." -keywords = [ - "ai", - "graphs", - "tools", -] +keywords = [ "ai", "graphs", "tools" ] license = { file = "LICENSE" } authors = [ @@ -46,31 +39,23 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", ] -dynamic = [ - "version", -] +dynamic = [ "version" ] dependencies = [ "anemoi-datasets[data]>=0.3.3", "anemoi-utils>=0.3.6", + "h3>=3.7.6,<4", + "healpy>=1.17", "hydra-core>=1.3", + "matplotlib>=3.4", + "networkx>=3.1", + "plotly>=5.19", "torch>=2.2", "torch-geometric>=2.3.1,<2.5", + "trimesh>=4.1", ] -optional-dependencies.all = [ -] -optional-dependencies.dev = [ - "nbsphinx", - "pandoc", - "pytest", - "pytest-mock", - "requests", - "sphinx", - "sphinx-argparse", - "sphinx-rtd-theme", - "termcolor", - "tomli", -] +optional-dependencies.all = [ ] +optional-dependencies.dev = [ "anemoi-graphs[docs,tests]" ] optional-dependencies.docs = [ "nbsphinx", @@ -83,10 +68,7 @@ optional-dependencies.docs = [ "tomli", ] -optional-dependencies.tests = [ - "pytest", - "pytest-mock", -] +optional-dependencies.tests = [ "pytest", "pytest-mock" ] urls.Documentation = "https://anemoi-graphs.readthedocs.io/" urls.Homepage = "https://github.com/ecmwf/anemoi-graphs/" diff --git a/src/anemoi/graphs/__init__.py b/src/anemoi/graphs/__init__.py index 715b8a4..55bad43 100644 --- a/src/anemoi/graphs/__init__.py +++ b/src/anemoi/graphs/__init__.py @@ -5,6 +5,11 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-from ._version import __version__ - EARTH_RADIUS = 6371.0 # km +try: + # NOTE: the `_version.py` file must not be present in the git repository + # as it is generated by setuptools at install time + from ._version import __version__ # type: ignore +except ImportError: # pragma: no cover + # Local copy or not installed with setuptools + __version__ = "999" diff --git a/src/anemoi/graphs/commands/create.py b/src/anemoi/graphs/commands/create.py index 18b3127..81ac397 100644 --- a/src/anemoi/graphs/commands/create.py +++ b/src/anemoi/graphs/commands/create.py @@ -1,7 +1,23 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import argparse +import logging +from pathlib import Path + from anemoi.graphs.create import GraphCreator +from anemoi.graphs.describe import GraphDescriptor from . import Command +LOGGER = logging.getLogger(__name__) + class Create(Command): """Create a graph.""" @@ -15,14 +31,26 @@ def add_arguments(self, command_parser): action="store_true", help="Overwrite existing files. This will delete the target graph if it already exists.", ) - command_parser.add_argument("config", help="Configuration yaml file defining the recipe to create the graph.") - command_parser.add_argument("path", help="Path to store the created graph.") + command_parser.add_argument( + "--description", + action=argparse.BooleanOptionalAction, + default=True, + help="Show the description of the graph.", + ) + command_parser.add_argument( + "config", type=Path, help="Configuration yaml file path defining the recipe to create the graph." + ) + command_parser.add_argument("save_path", type=Path, help="Path to store the created graph.") def run(self, args): - kwargs = vars(args) - - c = GraphCreator(**kwargs) - c.create() + graph_creator = GraphCreator(config=args.config) + graph_creator.create(save_path=args.save_path, overwrite=args.overwrite) + + if args.description: + if args.save_path.exists(): + GraphDescriptor(args.save_path).describe() + else: + print("Graph description is not shown if the graph is not saved.") command = Create diff --git a/src/anemoi/graphs/commands/describe.py b/src/anemoi/graphs/commands/describe.py new file mode 100644 index 0000000..36ee137 --- /dev/null +++ b/src/anemoi/graphs/commands/describe.py @@ -0,0 +1,30 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from anemoi.graphs.describe import GraphDescriptor + +from . 
import Command + + +class Describe(Command): + """Describe a graph.""" + + internal = True + timestamp = True + + def add_arguments(self, command_parser): + command_parser.add_argument("graph_file", help="Path to the graph (a .PT file).") + + def run(self, args): + kwargs = vars(args) + + GraphDescriptor(kwargs["graph_file"]).describe() + + +command = Describe diff --git a/src/anemoi/graphs/commands/inspect.py b/src/anemoi/graphs/commands/inspect.py new file mode 100644 index 0000000..34e91e2 --- /dev/null +++ b/src/anemoi/graphs/commands/inspect.py @@ -0,0 +1,56 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import argparse + +from anemoi.graphs.describe import GraphDescriptor +from anemoi.graphs.inspect import GraphInspector + +from . import Command + + +class Inspect(Command): + """Inspect a graph.""" + + internal = True + timestamp = True + + def add_arguments(self, command_parser): + command_parser.add_argument( + "--show_attribute_distributions", + action=argparse.BooleanOptionalAction, + default=True, + help="Hide distribution plots of edge/node attributes.", + ) + command_parser.add_argument( + "--show_nodes", + action=argparse.BooleanOptionalAction, + default=False, + help="Show the nodes of the graph.", + ) + command_parser.add_argument( + "--description", + action=argparse.BooleanOptionalAction, + default=True, + help="Hide the description of the graph.", + ) + command_parser.add_argument("path", help="Path to the graph (a .PT file).") + command_parser.add_argument("output_path", help="Path to store the inspection results.") + + def run(self, args): + kwargs = vars(args) + + if kwargs.get("description", False): + GraphDescriptor(kwargs["path"]).describe() + + inspector = GraphInspector(**kwargs) + inspector.inspect() + + +command = Inspect diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index bf3cb0c..c111cc3 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -1,5 +1,17 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + import logging -import os +from itertools import chain +from pathlib import Path import torch from anemoi.utils.config import DotDict @@ -14,70 +26,110 @@ class GraphCreator: def __init__( self, - path, - config=None, - cache=None, - print=print, - overwrite=False, - **kwargs, + config: str | Path | DotDict, ): - if isinstance(config, str) or isinstance(config, os.PathLike): + if isinstance(config, Path) or isinstance(config, str): self.config = DotDict.from_file(config) else: self.config = config - self.path = path # Output path - self.cache = cache - self.print = print - self.overwrite = overwrite - - def init(self): - if self._path_readable() and not self.overwrite: - raise Exception(f"{self.path} already exists. 
Use overwrite=True to overwrite.") - - def generate_graph(self) -> HeteroData: - """Generate the graph. + def update_graph(self, graph: HeteroData) -> HeteroData: + """Update the graph. It instantiates the node builders and edge builders defined in the configuration file and applies them to the graph. + Parameters + ---------- + graph : HeteroData + The input graph to be updated. + Returns ------- - HeteroData: The generated graph. + HeteroData + The updated graph with new nodes and edges added based on the configuration. """ - graph = HeteroData() - - for nodes_cfg in self.config.nodes: - graph = instantiate(nodes_cfg.node_builder, name=nodes_cfg.name).update_graph( + for nodes_name, nodes_cfg in self.config.get("nodes", {}).items(): + graph = instantiate(nodes_cfg.node_builder, name=nodes_name).update_graph( graph, nodes_cfg.get("attributes", {}) ) - for edges_cfg in self.config.edges: - graph = instantiate(edges_cfg.edge_builder, edges_cfg.source_name, edges_cfg.target_name).update_graph( - graph, edges_cfg.get("attributes", {}) - ) + for edges_cfg in self.config.get("edges", {}): + graph = instantiate( + edges_cfg.edge_builder, + edges_cfg.source_name, + edges_cfg.target_name, + source_mask_attr_name=edges_cfg.get("source_mask_attr_name", None), + target_mask_attr_name=edges_cfg.get("target_mask_attr_name", None), + ).update_graph(graph, edges_cfg.get("attributes", {})) return graph - def save(self, graph: HeteroData) -> None: - """Save the graph to the output path.""" - if not os.path.exists(self.path) or self.overwrite: - torch.save(graph, self.path) - self.print(f"Graph saved at {self.path}.") - - def create(self) -> HeteroData: - """Create the graph and save it to the output path.""" - self.init() - graph = self.generate_graph() - self.save(graph) + def clean(self, graph: HeteroData) -> HeteroData: + """Remove private attributes used during creation from the graph. + + Parameters + ---------- + graph : HeteroData + Generated graph + + Returns + ------- + HeteroData + Cleaned graph + """ + LOGGER.info("Cleaning graph.") + for type_name in chain(graph.node_types, graph.edge_types): + attr_names_to_remove = [attr_name for attr_name in graph[type_name] if attr_name.startswith("_")] + for attr_name in attr_names_to_remove: + del graph[type_name][attr_name] + LOGGER.info(f"{attr_name} deleted from graph.") + return graph - def _path_readable(self) -> bool: - """Check if the output path is readable.""" - import torch + def save(self, graph: HeteroData, save_path: Path, overwrite: bool = False) -> None: + """Save the generated graph to the output path. + + Parameters + ---------- + graph : HeteroData + generated graph + save_path : Path + location to save the graph + overwrite : bool, optional + whether to overwrite existing graph file, by default False + """ + save_path = Path(save_path) + + if not save_path.exists() or overwrite: + save_path.parent.mkdir(parents=True, exist_ok=True) + torch.save(graph, save_path) + LOGGER.info(f"Graph saved at {save_path}.") + else: + LOGGER.info("Graph already exists. Use overwrite=True to overwrite.") + + def create(self, save_path: Path | None = None, overwrite: bool = False) -> HeteroData: + """Create the graph and save it to the output path. 
- try: - torch.load(self.path) - return True - except FileNotFoundError: - return False + Parameters + ---------- + save_path : Path, optional + location to save the graph, by default None + overwrite : bool, optional + whether to overwrite existing graph file, by default False + + Returns + ------- + HeteroData + created graph object + """ + graph = HeteroData() + graph = self.update_graph(graph) + graph = self.clean(graph) + + if save_path is None: + LOGGER.warning("No output path specified. The graph will not be saved.") + else: + self.save(graph, save_path, overwrite) + + return graph diff --git a/src/anemoi/graphs/describe.py b/src/anemoi/graphs/describe.py new file mode 100644 index 0000000..95eba5a --- /dev/null +++ b/src/anemoi/graphs/describe.py @@ -0,0 +1,225 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import math +from itertools import chain +from pathlib import Path +from typing import Optional +from typing import Union + +import torch +from anemoi.utils.humanize import bytes +from anemoi.utils.text import table + + +class GraphDescriptor: + """Class for descripting the graph.""" + + def __init__(self, path: Union[str, Path], **kwargs): + self.path = path + self.graph = torch.load(self.path) + + @property + def total_size(self): + """Total size of the tensors in the graph (in bytes).""" + total_size = 0 + + for store in chain(self.graph.node_stores, self.graph.edge_stores): + for value in store.values(): + if isinstance(value, torch.Tensor): + total_size += value.numel() * value.element_size() + + return total_size + + def get_node_summary(self) -> list[list]: + """Summary of the nodes in the graph. + + Returns + ------- + list[list] + Returns a list for each subgraph with the following information: + - Node name. + - Number of nodes. + - List of attribute names. + - Total dimension of the attributes. + - Min. latitude. + - Max. latitude. + - Min. longitude. + - Max. longitude. + """ + node_summary = [] + for name, nodes in self.graph.node_items(): + attributes = nodes.node_attrs() + attributes.remove("x") + + node_summary.append( + [ + name, + nodes.num_nodes, + ", ".join(attributes), + sum(nodes[attr].shape[1] for attr in attributes if isinstance(nodes[attr], torch.Tensor)), + nodes.x[:, 0].min().item() / 2 / math.pi * 360, + nodes.x[:, 0].max().item() / 2 / math.pi * 360, + nodes.x[:, 1].min().item() / 2 / math.pi * 360, + nodes.x[:, 1].max().item() / 2 / math.pi * 360, + ] + ) + return node_summary + + def get_edge_summary(self) -> list[list]: + """Summary of the edges in the graph. + + Returns + ------- + list[list] + Returns a list for each subgraph with the following information: + - Source node name. + - Destination node name. + - Number of edges. + - Number of isolated source nodes. + - Number of isolated target nodes. + - Total dimension of the attributes. + - List of attribute names. 
+ """ + edge_summary = [] + for (src_name, _, dst_name), edges in self.graph.edge_items(): + attributes = edges.edge_attrs() + attributes.remove("edge_index") + + edge_summary.append( + [ + src_name, + dst_name, + edges.num_edges, + self.graph[src_name].num_nodes - len(torch.unique(edges.edge_index[0])), + self.graph[dst_name].num_nodes - len(torch.unique(edges.edge_index[1])), + sum(edges[attr].shape[1] for attr in attributes), + ", ".join([f"{attr}({edges[attr].shape[1]}D)" for attr in attributes]), + ] + ) + return edge_summary + + def get_node_attribute_table(self) -> list[list]: + node_attributes = [] + for node_name, node_store in self.graph.node_items(): + node_attr_names = node_store.node_attrs() + node_attr_names.remove("x") # Remove the coordinates from statistics table + for node_attr_name in node_attr_names: + node_attributes.append( + [ + "Node", + node_name, + node_attr_name, + node_store[node_attr_name].dtype, + node_store[node_attr_name].float().min().item(), + node_store[node_attr_name].float().mean().item(), + node_store[node_attr_name].float().max().item(), + node_store[node_attr_name].float().std().item(), + ] + ) + return node_attributes + + def get_edge_attribute_table(self) -> list[list]: + edge_attributes = [] + for (source_name, _, target_name), edge_store in self.graph.edge_items(): + edge_attr_names = edge_store.edge_attrs() + edge_attr_names.remove("edge_index") # Remove the edge index from statistics table + for edge_attr_name in edge_attr_names: + edge_attributes.append( + [ + "Edge", + f"{source_name}-->{target_name}", + edge_attr_name, + edge_store[edge_attr_name].dtype, + edge_store[edge_attr_name].float().min().item(), + edge_store[edge_attr_name].float().mean().item(), + edge_store[edge_attr_name].float().max().item(), + edge_store[edge_attr_name].float().std().item(), + ] + ) + + return edge_attributes + + def get_attribute_table(self) -> list[list]: + """Get a table with the attributes of the graph.""" + attribute_table = [] + attribute_table.extend(self.get_node_attribute_table()) + attribute_table.extend(self.get_edge_attribute_table()) + return attribute_table + + def describe(self, show_attribute_distributions: Optional[bool] = True) -> None: + """Describe the graph.""" + print() + print(f"📦 Path : {self.path}") + print(f"💽 Size : {bytes(self.total_size)} ({self.total_size})") + print() + print("🪩 Nodes summary") + print() + print( + table( + self.get_node_summary(), + header=[ + "Nodes name", + "Num. nodes", + "Attributes", + "Attribute dim", + "Min. latitude", + "Max. latitude", + "Min. longitude", + "Max. longitude", + ], + align=["<", ">", ">", ">", ">", ">", ">", ">"], + margin=3, + ) + ) + print() + print() + print("🌐 Edges summary") + print() + print( + table( + self.get_edge_summary(), + header=[ + "Source", + "Target", + "Num. edges", + "Isolated Source", + "Isolated Target", + "Attribute dim", + "Attributes", + ], + align=["<", "<", ">", ">", ">", ">", ">"], + margin=3, + ) + ) + print() + if show_attribute_distributions: + print() + print("📊 Attribute distributions") + print() + print( + table( + self.get_attribute_table(), + header=[ + "Type", + "Source", + "Name", + "Dtype", + "Min.", + "Mean", + "Max.", + "Std. 
dev.", + ], + align=["<", "<", ">", ">", ">", ">", ">", ">"], + margin=3, + ) + ) + print() + print("🔋 Graph ready.") + print() diff --git a/src/anemoi/graphs/edges/__init__.py b/src/anemoi/graphs/edges/__init__.py index 53b9c74..281860d 100644 --- a/src/anemoi/graphs/edges/__init__.py +++ b/src/anemoi/graphs/edges/__init__.py @@ -1,4 +1,5 @@ from .builder import CutOffEdges from .builder import KNNEdges +from .builder import MultiScaleEdges -__all__ = ["KNNEdges", "CutOffEdges"] +__all__ = ["KNNEdges", "CutOffEdges", "MultiScaleEdges"] diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 9a8d6d8..3560fb2 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -1,23 +1,30 @@ -import logging +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + from abc import ABC from abc import abstractmethod -from typing import Optional import numpy as np import torch from torch_geometric.data import HeteroData from anemoi.graphs.edges.directional import directional_edge_features -from anemoi.graphs.normalizer import NormalizerMixin +from anemoi.graphs.normalise import NormaliserMixin from anemoi.graphs.utils import haversine_distance -LOGGER = logging.getLogger(__name__) - -class BaseEdgeAttribute(ABC, NormalizerMixin): +class BaseEdgeAttribute(ABC, NormaliserMixin): """Base class for edge attributes.""" - def __init__(self, norm: Optional[str] = None) -> None: + def __init__(self, norm: str | None = None) -> None: self.norm = norm @abstractmethod @@ -28,7 +35,7 @@ def post_process(self, values: np.ndarray) -> torch.Tensor: if values.ndim == 1: values = values[:, np.newaxis] - normed_values = self.normalize(values) + normed_values = self.normalise(values) return torch.tensor(normed_values, dtype=torch.float32) @@ -49,27 +56,27 @@ def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, ** class EdgeDirection(BaseEdgeAttribute): """Edge direction feature. - If using the rotated features, the direction of the edge is computed - rotating the target nodes to the north pole. If not, it is computed - as the diference in latitude and longitude between the source and - target nodes. + This class calculates the direction of an edge using either: + 1. Rotated features: The target nodes are rotated to the north pole to compute the edge direction. + 2. Non-rotated features: The direction is computed as the difference in latitude and longitude between the source + and target nodes. + + The resulting direction is represented as a unit vector starting at (0, 0), with X and Y components. Attributes ---------- norm : Optional[str] - Normalization method. + Normalisation method. luse_rotated_features : bool Whether to use rotated features. Methods ------- - get_raw_values(graph, source_name, target_name) - Compute directions between nodes connected by edges. compute(graph, source_name, target_name) - Compute directional attributes. + Compute direction of all edges. 
""" - def __init__(self, norm: Optional[str] = None, luse_rotated_features: bool = True) -> None: + def __init__(self, norm: str | None = None, luse_rotated_features: bool = True) -> None: super().__init__(norm) self.luse_rotated_features = luse_rotated_features @@ -103,19 +110,17 @@ class EdgeLength(BaseEdgeAttribute): Attributes ---------- norm : str - Normalization method. + Normalisation method. invert : bool Whether to invert the edge lengths, i.e. 1 - edge_length. Methods ------- - get_raw_values(graph, source_name, target_name) - Compute haversine distance between nodes connected by edges. compute(graph, source_name, target_name) Compute edge lengths attributes. """ - def __init__(self, norm: Optional[str] = None, invert: bool = False) -> None: + def __init__(self, norm: str | None = None, invert: bool = False) -> None: super().__init__(norm) self.invert = invert @@ -144,6 +149,9 @@ def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) def post_process(self, values: np.ndarray) -> torch.Tensor: """Post-process edge lengths.""" + values = super().post_process(values) + if self.invert: values = 1 - values - return super().post_process(values) + + return values diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 17ba4fc..b5a5df8 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -1,17 +1,37 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +from __future__ import annotations + import logging from abc import ABC from abc import abstractmethod -from typing import Optional +import networkx as nx import numpy as np import torch from anemoi.utils.config import DotDict from hydra.utils import instantiate +from scipy.sparse import coo_matrix from sklearn.neighbors import NearestNeighbors from torch_geometric.data import HeteroData from torch_geometric.data.storage import NodeStorage from anemoi.graphs import EARTH_RADIUS +from anemoi.graphs.generate import hex_icosahedron +from anemoi.graphs.generate import tri_icosahedron +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder +from anemoi.graphs.nodes.builders.from_refined_icosahedron import HexNodes +from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaHexNodes +from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaTriNodes +from anemoi.graphs.nodes.builders.from_refined_icosahedron import StretchedTriNodes +from anemoi.graphs.nodes.builders.from_refined_icosahedron import TriNodes from anemoi.graphs.utils import get_grid_reference_distance LOGGER = logging.getLogger(__name__) @@ -20,9 +40,17 @@ class BaseEdgeBuilder(ABC): """Base class for edge builders.""" - def __init__(self, source_name: str, target_name: str): + def __init__( + self, + source_name: str, + target_name: str, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ): self.source_name = source_name self.target_name = target_name + self.source_mask_attr_name = source_mask_attr_name + self.target_mask_attr_name = target_mask_attr_name @property def name(self) -> tuple[str, str, str]: @@ -33,7 +61,7 @@ def name(self) -> tuple[str, str, str]: def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): ... def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]: - """Prepare nodes information.""" + """Prepare node information and get source and target nodes.""" return graph[self.source_name], graph[self.target_name] def get_edge_index(self, graph: HeteroData) -> torch.Tensor: @@ -94,7 +122,7 @@ def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name) return graph - def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: + def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData: """Update the graph with the edges. 
Parameters @@ -111,15 +139,48 @@ def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None """ graph = self.register_edges(graph) - if attrs_config is None: - return graph - - graph = self.register_attributes(graph, attrs_config) + if attrs_config is not None: + graph = self.register_attributes(graph, attrs_config) return graph -class KNNEdges(BaseEdgeBuilder): +class NodeMaskingMixin: + """Mixin class for masking source/target nodes when building edges.""" + + def get_node_coordinates( + self, source_nodes: NodeStorage, target_nodes: NodeStorage + ) -> tuple[np.ndarray, np.ndarray]: + """Get the node coordinates.""" + source_coords, target_coords = source_nodes.x.numpy(), target_nodes.x.numpy() + + if self.source_mask_attr_name is not None: + source_coords = source_coords[source_nodes[self.source_mask_attr_name].squeeze()] + + if self.target_mask_attr_name is not None: + target_coords = target_coords[target_nodes[self.target_mask_attr_name].squeeze()] + + return source_coords, target_coords + + def undo_masking(self, adj_matrix, source_nodes: NodeStorage, target_nodes: NodeStorage): + if self.target_mask_attr_name is not None: + target_mask = target_nodes[self.target_mask_attr_name].squeeze() + target_mapper = dict(zip(list(range(len(adj_matrix.row))), np.where(target_mask)[0])) + adj_matrix.row = np.vectorize(target_mapper.get)(adj_matrix.row) + + if self.source_mask_attr_name is not None: + source_mask = source_nodes[self.source_mask_attr_name].squeeze() + source_mapper = dict(zip(list(range(len(adj_matrix.col))), np.where(source_mask)[0])) + adj_matrix.col = np.vectorize(source_mapper.get)(adj_matrix.col) + + if self.source_mask_attr_name is not None or self.target_mask_attr_name is not None: + true_shape = target_nodes.x.shape[0], source_nodes.x.shape[0] + adj_matrix = coo_matrix((adj_matrix.data, (adj_matrix.row, adj_matrix.col)), shape=true_shape) + + return adj_matrix + + +class KNNEdges(BaseEdgeBuilder, NodeMaskingMixin): """Computes KNN based edges and adds them to the graph. Attributes @@ -130,11 +191,13 @@ class KNNEdges(BaseEdgeBuilder): The name of the target nodes. num_nearest_neighbours : int Number of nearest neighbours. + source_mask_attr_name : str | None + The name of the source mask attribute to filter edge connections. + target_mask_attr_name : str | None + The name of the target mask attribute to filter edge connections. Methods ------- - get_adjacency_matrix(source_nodes, target_nodes) - Compute the adjacency matrix for the KNN method. register_edges(graph) Register the edges in the graph. register_attributes(graph, config) @@ -143,22 +206,35 @@ class KNNEdges(BaseEdgeBuilder): Update the graph with the edges. 
""" - def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: int): - super().__init__(source_name, target_name) + def __init__( + self, + source_name: str, + target_name: str, + num_nearest_neighbours: int, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ) -> None: + super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name) assert isinstance(num_nearest_neighbours, int), "Number of nearest neighbours must be an integer" assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive" self.num_nearest_neighbours = num_nearest_neighbours - def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarray): + def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage) -> np.ndarray: """Compute the adjacency matrix for the KNN method. Parameters ---------- - source_nodes : np.ndarray + source_nodes : NodeStorage The source nodes. - target_nodes : np.ndarray + target_nodes : NodeStorage The target nodes. + + Returns + ------- + np.ndarray + The adjacency matrix. """ + source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes) assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder" LOGGER.info( "Using KNN-Edges (with %d nearest neighbours) between %s and %s.", @@ -168,16 +244,20 @@ def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarra ) nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) - nearest_neighbour.fit(source_nodes.x.numpy()) + nearest_neighbour.fit(source_coords) adj_matrix = nearest_neighbour.kneighbors_graph( - target_nodes.x.numpy(), + target_coords, n_neighbors=self.num_nearest_neighbours, mode="distance", ).tocoo() + + # Post-process the adjacency matrix. Add masked nodes. + adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes) + return adj_matrix -class CutOffEdges(BaseEdgeBuilder): +class CutOffEdges(BaseEdgeBuilder, NodeMaskingMixin): """Computes cut-off based edges and adds them to the graph. Attributes @@ -188,15 +268,13 @@ class CutOffEdges(BaseEdgeBuilder): The name of the target nodes. cutoff_factor : float Factor to multiply the grid reference distance to get the cut-off radius. - radius : float - Cut-off radius. + source_mask_attr_name : str | None + The name of the source mask attribute to filter edge connections. + target_mask_attr_name : str | None + The name of the target mask attribute to filter edge connections. Methods ------- - get_cutoff_radius(graph, mask_attr) - Compute the cut-off radius. - get_adjacency_matrix(source_nodes, target_nodes) - Get the adjacency matrix for the cut-off method. register_edges(graph) Register the edges in the graph. register_attributes(graph, config) @@ -205,16 +283,24 @@ class CutOffEdges(BaseEdgeBuilder): Update the graph with the edges. 
""" - def __init__(self, source_name: str, target_name: str, cutoff_factor: float): - super().__init__(source_name, target_name) + def __init__( + self, + source_name: str, + target_name: str, + cutoff_factor: float, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ): + super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name) assert isinstance(cutoff_factor, (int, float)), "Cutoff factor must be a float" assert cutoff_factor > 0, "Cutoff factor must be positive" self.cutoff_factor = cutoff_factor - def get_cutoff_radius(self, graph: HeteroData, mask_attr: Optional[torch.Tensor] = None): + def get_cutoff_radius(self, graph: HeteroData, mask_attr: torch.Tensor | None = None) -> float: """Compute the cut-off radius. - The cut-off radius is computed as the product of the target nodes reference distance and the cut-off factor. + The cut-off radius is computed as the product of the target nodes + reference distance and the cut-off factor. Parameters ---------- @@ -235,11 +321,11 @@ def get_cutoff_radius(self, graph: HeteroData, mask_attr: Optional[torch.Tensor] return radius def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]: - """Prepare nodes information.""" + """Prepare node information and get source and target nodes.""" self.radius = self.get_cutoff_radius(graph) return super().prepare_node_data(graph) - def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): + def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage) -> np.ndarray: """Get the adjacency matrix for the cut-off method. Parameters @@ -248,7 +334,13 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor The source nodes. target_nodes : NodeStorage The target nodes. + + Returns + ------- + np.ndarray + The adjacency matrix. """ + source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes) LOGGER.info( "Using CutOff-Edges (with radius = %.1f km) between %s and %s.", self.radius * EARTH_RADIUS, @@ -257,6 +349,98 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor ) nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) - nearest_neighbour.fit(source_nodes.x) - adj_matrix = nearest_neighbour.radius_neighbors_graph(target_nodes.x, radius=self.radius).tocoo() + nearest_neighbour.fit(source_coords) + adj_matrix = nearest_neighbour.radius_neighbors_graph(target_coords, radius=self.radius).tocoo() + + # Post-process the adjacency matrix. Add masked nodes. + adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes) + return adj_matrix + + +class MultiScaleEdges(BaseEdgeBuilder): + """Base class for multi-scale edges in the nodes of a graph. + + Attributes + ---------- + source_name : str + The name of the source nodes. + target_name : str + The name of the target nodes. + x_hops : int + Number of hops (in the refined icosahedron) between two nodes to connect + them with an edge. + + Methods + ------- + register_edges(graph) + Register the edges in the graph. + register_attributes(graph, config) + Register attributes in the edges of the graph. + update_graph(graph, attrs_config) + Update the graph with the edges. 
+ """ + + VALID_NODES = [TriNodes, HexNodes, LimitedAreaTriNodes, LimitedAreaHexNodes, StretchedTriNodes] + + def __init__(self, source_name: str, target_name: str, x_hops: int, **kwargs): + super().__init__(source_name, target_name) + assert source_name == target_name, f"{self.__class__.__name__} requires source and target nodes to be the same." + assert isinstance(x_hops, int), "Number of x_hops must be an integer" + assert x_hops > 0, "Number of x_hops must be positive" + self.x_hops = x_hops + self.node_type = None + + def add_edges_from_tri_nodes(self, nodes: NodeStorage) -> NodeStorage: + nodes["_nx_graph"] = tri_icosahedron.add_edges_to_nx_graph( + nodes["_nx_graph"], + resolutions=nodes["_resolutions"], + x_hops=self.x_hops, + area_mask_builder=nodes.get("_area_mask_builder", None), + ) + + return nodes + + def add_edges_from_stretched_tri_nodes(self, nodes: NodeStorage) -> NodeStorage: + all_points_mask_builder = KNNAreaMaskBuilder("all_nodes", 1.0) + all_points_mask_builder.fit_coords(nodes.x.numpy()) + + nodes["_nx_graph"] = tri_icosahedron.add_edges_to_nx_graph( + nodes["_nx_graph"], + resolutions=nodes["_resolutions"], + x_hops=self.x_hops, + area_mask_builder=all_points_mask_builder, + ) + return nodes + + def add_edges_from_hex_nodes(self, nodes: NodeStorage) -> NodeStorage: + nodes["_nx_graph"] = hex_icosahedron.add_edges_to_nx_graph( + nodes["_nx_graph"], + resolutions=nodes["_resolutions"], + x_hops=self.x_hops, + ) + + return nodes + + def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): + if self.node_type in [TriNodes.__name__, LimitedAreaTriNodes.__name__]: + source_nodes = self.add_edges_from_tri_nodes(source_nodes) + elif self.node_type in [HexNodes.__name__, LimitedAreaHexNodes.__name__]: + source_nodes = self.add_edges_from_hex_nodes(source_nodes) + elif self.node_type == StretchedTriNodes.__name__: + source_nodes = self.add_edges_from_stretched_tri_nodes(source_nodes) + else: + raise ValueError(f"Invalid node type {self.node_type}") + + adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo") + + return adjmat + + def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData: + self.node_type = graph[self.source_name].node_type + valid_node_names = [n.__name__ for n in self.VALID_NODES] + assert ( + self.node_type in valid_node_names + ), f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes." + + return super().update_graph(graph, attrs_config) diff --git a/src/anemoi/graphs/edges/directional.py b/src/anemoi/graphs/edges/directional.py index 9c7cdea..8a23a51 100644 --- a/src/anemoi/graphs/edges/directional.py +++ b/src/anemoi/graphs/edges/directional.py @@ -1,4 +1,13 @@ -from typing import Optional +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
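# --- Conceptual sketch (not part of the patch): what NodeMaskingMixin.undo_masking achieves ---
# When source/target mask attributes are given, KNNEdges and CutOffEdges build the adjacency
# matrix only between the masked subsets of nodes; the row/col indices must then be mapped back
# to positions in the full node set before the edges are registered. The toy masks and values
# below are invented for illustration and use plain numpy/scipy.
import numpy as np
from scipy.sparse import coo_matrix

target_mask = np.array([True, False, True, True])   # 4 target nodes, 3 selected
source_mask = np.array([False, True, True])         # 3 source nodes, 2 selected

# adjacency computed on the masked subsets: shape (3 selected targets, 2 selected sources)
masked_adj = coo_matrix(([1.0, 1.0], ([0, 2], [1, 0])), shape=(3, 2))

# map "position within the masked subset" back to "index in the full node set"
rows = np.where(target_mask)[0][masked_adj.row]      # -> [0, 3]
cols = np.where(source_mask)[0][masked_adj.col]      # -> [2, 1]
full_adj = coo_matrix((masked_adj.data, (rows, cols)), shape=(target_mask.size, source_mask.size))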
+ +from __future__ import annotations import numpy as np from scipy.spatial.transform import Rotation @@ -28,7 +37,7 @@ def get_rotation_from_unit_vecs(points: np.ndarray, reference: np.ndarray) -> Ro return Rotation.from_rotvec(np.transpose(v_unit * theta)) -def compute_directions(loc1: np.ndarray, loc2: np.ndarray, pole_vec: Optional[np.ndarray] = None) -> np.ndarray: +def compute_directions(loc1: np.ndarray, loc2: np.ndarray, pole_vec: np.ndarray | None = None) -> np.ndarray: """Compute the direction of the edge joining the nodes considered. Parameters diff --git a/src/anemoi/graphs/generate/hex_icosahedron.py b/src/anemoi/graphs/generate/hex_icosahedron.py new file mode 100644 index 0000000..1e50f35 --- /dev/null +++ b/src/anemoi/graphs/generate/hex_icosahedron.py @@ -0,0 +1,236 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import h3 +import networkx as nx +import numpy as np + +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder +from anemoi.graphs.generate.utils import get_coordinates_ordering + + +def create_hex_nodes( + resolution: int, + area_mask_builder: KNNAreaMaskBuilder | None = None, +) -> tuple[nx.Graph, np.ndarray, list[int]]: + """Creates a global mesh from a refined icosahedron. + + This method relies on the H3 python library, which covers the earth with hexagons (and 5 pentagons). At each + refinement level, a hexagon cell (nodes) has 7 child cells (aperture 7). + + Parameters + ---------- + resolution : int + Level of mesh resolution to consider. + area_mask_builder : KNNAreaMaskBuilder, optional + KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. + + Returns + ------- + graph : networkx.Graph + The specified graph (only nodes) sorted by latitude and longitude. + coords_rad : np.ndarray + The node coordinates (not ordered) in radians. + node_ordering : list[int] + Order of the node coordinates to be sorted by latitude and longitude. + """ + nodes = get_nodes_at_resolution(resolution) + + coords_rad = np.deg2rad(np.array([h3.h3_to_geo(node) for node in nodes])) + + node_ordering = get_coordinates_ordering(coords_rad) + + if area_mask_builder is not None: + area_mask = area_mask_builder.get_mask(coords_rad) + node_ordering = node_ordering[area_mask[node_ordering]] + + graph = create_nx_graph_from_hex_coords(nodes, node_ordering) + + return graph, coords_rad, list(node_ordering) + + +def create_nx_graph_from_hex_coords(nodes: list[str], node_ordering: np.ndarray) -> nx.Graph: + """Add all nodes at a specified refinement level to a graph. + + Parameters + ---------- + nodes : list[str] + The set of H3 indexes (nodes). + node_ordering: np.ndarray + Order of the node coordinates to be sorted by latitude and longitude. + + Returns + ------- + graph : networkx.Graph + The graph with the added nodes. + """ + graph = nx.Graph() + + for node_pos in node_ordering: + h3_idx = nodes[node_pos] + graph.add_node(h3_idx, hcoords_rad=np.deg2rad(h3.h3_to_geo(h3_idx))) + + return graph + + +def get_nodes_at_resolution( + resolution: int, +) -> list[str]: + """Get nodes at a specified refinement level over the entire globe. 
+ + Parameters + ---------- + resolution : int + The H3 refinement level. It can be an integer from 0 to 15. + + Returns + ------- + nodes : list[str] + The list of H3 indexes at the specified resolution level. + """ + return list(h3.uncompact(h3.get_res0_indexes(), resolution)) + + +def add_edges_to_nx_graph( + graph: nx.Graph, + resolutions: list[int], + x_hops: int = 1, + depth_children: int = 0, +) -> nx.Graph: + """Adds the edges to the graph. + + This method includes multi-scale connections to the existing graph. The different scales + are defined by the resolutions (or refinement levels) specified. + + Parameters + ---------- + graph : networkx.Graph + The graph to add the edges. + resolutions : list[int] + Levels of mesh resolution to consider. + x_hops: int + The number of hops to consider for the neighbours. + depth_children : int + The number of resolution levels to consider for the connections of children. Defaults to 1, which includes + connections up to the next resolution level. + + Returns + ------- + graph : networkx.Graph + The graph with the added edges. + """ + + graph = add_neighbour_edges(graph, resolutions, x_hops) + graph = add_edges_to_children(graph, resolutions, depth_children) + return graph + + +def add_neighbour_edges( + graph: nx.Graph, + refinement_levels: tuple[int], + x_hops: int = 1, +) -> nx.Graph: + """Adds edges between neighbours at the specified refinement levels.""" + for resolution in refinement_levels: + nodes = select_nodes_from_graph_at_resolution(graph, resolution) + + for idx in nodes: + # neighbours + for idx_neighbour in h3.k_ring(idx, k=x_hops) & set(nodes): + graph = add_edge( + graph, + h3.h3_to_center_child(idx, max(refinement_levels)), + h3.h3_to_center_child(idx_neighbour, max(refinement_levels)), + ) + + return graph + + +def add_edges_to_children( + graph: nx.Graph, + refinement_levels: tuple[int], + depth_children: int | None = None, +) -> nx.Graph: + """Adds edges to the children of the nodes at the specified resolution levels. + + Parameters + ---------- + graph : nx.Graph + graph to which the edges will be added + refinement_levels : tuple[int] + set of refinement levels + depth_children : Optional[int], optional + The number of resolution levels to consider for the connections of children. Defaults to 1, which includes + connections up to the next resolution level, by default None. + + Returns + ------- + nx.Graph + Graph with the added edges. 
+ """ + if depth_children is None: + depth_children = len(refinement_levels) + elif depth_children == 0: + return graph + + for i_level, resolution_parent in enumerate(list(sorted(refinement_levels))[0:-1]): + parent_nodes = select_nodes_from_graph_at_resolution(graph, resolution_parent) + + for parent_idx in parent_nodes: + # add own children + for resolution_child in refinement_levels[i_level + 1 : i_level + depth_children + 1]: + for child_idx in h3.h3_to_children(parent_idx, res=resolution_child): + graph = add_edge( + graph, + h3.h3_to_center_child(parent_idx, refinement_levels[-1]), + h3.h3_to_center_child(child_idx, refinement_levels[-1]), + ) + + return graph + + +def select_nodes_from_graph_at_resolution(graph: nx.Graph, resolution: int) -> set[str]: + """Select nodes from a graph at a specified resolution level.""" + nodes_at_lower_resolution = [n for n in h3.compact(graph.nodes) if h3.h3_get_resolution(n) <= resolution] + nodes_at_resolution = h3.uncompact(nodes_at_lower_resolution, resolution) + return nodes_at_resolution + + +def add_edge( + graph: nx.Graph, + source_node_h3_idx: str, + target_node_h3_idx: str, +) -> nx.Graph: + """Add edge between two nodes to a graph. + + The edge will only be added in case both target and source nodes are included in the graph. + + Parameters + ---------- + graph : networkx.Graph + The graph to add the nodes. + source_node_h3_idx : str + The H3 index of the tail of the edge. + target_node_h3_idx : str + The H3 index of the head of the edge. + + Returns + ------- + graph : networkx.Graph + The graph with the added edge. + """ + if not graph.has_node(source_node_h3_idx) or not graph.has_node(target_node_h3_idx): + return graph + + if source_node_h3_idx != target_node_h3_idx: + graph.add_edge(source_node_h3_idx, target_node_h3_idx) + + return graph diff --git a/src/anemoi/graphs/generate/masks.py b/src/anemoi/graphs/generate/masks.py new file mode 100644 index 0000000..4f9c8c1 --- /dev/null +++ b/src/anemoi/graphs/generate/masks.py @@ -0,0 +1,99 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging + +import numpy as np +from sklearn.neighbors import NearestNeighbors +from torch_geometric.data import HeteroData + +from anemoi.graphs import EARTH_RADIUS + +LOGGER = logging.getLogger(__name__) + + +class KNNAreaMaskBuilder: + """Class to build a mask based on distance to masked reference nodes using KNN. + + Attributes + ---------- + nearest_neighbour : NearestNeighbors + Nearest neighbour object to compute the KNN. + margin_radius_km : float + Maximum distance to the reference nodes to consider a node as valid, in kilometers. Defaults to 100 km. + reference_node_name : str + Name of the reference nodes in the graph to consider for the Area Mask. + mask_attr_name : str + Name of a node to attribute to mask the reference nodes, if desired. Defaults to consider all reference nodes. + + Methods + ------- + fit_coords(coords_rad: np.ndarray) + Fit the KNN model to the coordinates in radians. + fit(graph: HeteroData) + Fit the KNN model to the reference nodes. 
+ get_mask(coords_rad: np.ndarray) -> np.ndarray + Get the mask for the nodes based on the distance to the reference nodes. + """ + + def __init__(self, reference_node_name: str, margin_radius_km: float = 100, mask_attr_name: str | None = None): + assert isinstance(margin_radius_km, (int, float)), "The margin radius must be a number." + assert margin_radius_km > 0, "The margin radius must be positive." + + self.nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) + self.margin_radius_km = margin_radius_km + self.reference_node_name = reference_node_name + self.mask_attr_name = mask_attr_name + + def get_reference_coords(self, graph: HeteroData) -> np.ndarray: + """Retrive coordinates from the reference nodes.""" + assert ( + self.reference_node_name in graph.node_types + ), f'Reference node "{self.reference_node_name}" not found in the graph.' + + coords_rad = graph[self.reference_node_name].x.numpy() + if self.mask_attr_name is not None: + assert ( + self.mask_attr_name in graph[self.reference_node_name].node_attrs() + ), f'Mask attribute "{self.mask_attr_name}" not found in the reference nodes.' + mask = graph[self.reference_node_name][self.mask_attr_name].squeeze() + coords_rad = coords_rad[mask] + + return coords_rad + + def fit_coords(self, coords_rad: np.ndarray): + """Fit the KNN model to the coordinates in radians.""" + self.nearest_neighbour.fit(coords_rad) + + def fit(self, graph: HeteroData): + """Fit the KNN model to the nodes of interest.""" + # Prepare string for logging + reference_mask_str = self.reference_node_name + if self.mask_attr_name is not None: + reference_mask_str += f" ({self.mask_attr_name})" + + # Fit to the reference nodes + coords_rad = self.get_reference_coords(graph) + self.fit_coords(coords_rad) + + LOGGER.info( + 'Fitting %s with %d reference nodes from "%s".', + self.__class__.__name__, + len(coords_rad), + reference_mask_str, + ) + + def get_mask(self, coords_rad: np.ndarray) -> np.ndarray: + """Compute a mask based on the distance to the reference nodes.""" + + neigh_dists, _ = self.nearest_neighbour.kneighbors(coords_rad, n_neighbors=1) + mask = neigh_dists[:, 0] * EARTH_RADIUS <= self.margin_radius_km + return mask diff --git a/src/anemoi/graphs/generate/transforms.py b/src/anemoi/graphs/generate/transforms.py index 99e838e..9392241 100644 --- a/src/anemoi/graphs/generate/transforms.py +++ b/src/anemoi/graphs/generate/transforms.py @@ -1,3 +1,12 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + import numpy as np diff --git a/src/anemoi/graphs/generate/tri_icosahedron.py b/src/anemoi/graphs/generate/tri_icosahedron.py new file mode 100644 index 0000000..72cd5cb --- /dev/null +++ b/src/anemoi/graphs/generate/tri_icosahedron.py @@ -0,0 +1,272 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
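# --- Usage sketch (not part of the patch) for the KNNAreaMaskBuilder defined above ---
# fit_coords() fits a haversine NearestNeighbors model on the reference (lat, lon) points in
# radians; get_mask() then keeps every candidate node whose nearest reference node lies within
# margin_radius_km. The coordinates below are invented purely for illustration.
import numpy as np
from anemoi.graphs.generate.masks import KNNAreaMaskBuilder

reference_coords = np.deg2rad([[45.0, 10.0], [46.0, 11.0], [44.5, 9.5]])   # (lat, lon) in radians
candidate_coords = np.deg2rad([[45.2, 10.3], [-30.0, 150.0]])

builder = KNNAreaMaskBuilder("reference_nodes", margin_radius_km=200.0)
builder.fit_coords(reference_coords)
mask = builder.get_mask(candidate_coords)
# expected: [True, False] -- only the first candidate lies within ~200 km of the reference cloud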
+ +from __future__ import annotations + +from collections.abc import Iterable + +import networkx as nx +import numpy as np +import trimesh +from sklearn.neighbors import BallTree + +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder +from anemoi.graphs.generate.transforms import cartesian_to_latlon_rad +from anemoi.graphs.generate.utils import get_coordinates_ordering + + +def create_tri_nodes( + resolution: int, area_mask_builder: KNNAreaMaskBuilder | None = None +) -> tuple[nx.DiGraph, np.ndarray, list[int]]: + """Creates a global mesh from a refined icosahedron. + + This method relies on the trimesh python library. + + Parameters + ---------- + resolution : int + Level of mesh resolution to consider. + area_mask_builder : KNNAreaMaskBuilder + KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. + + Returns + ------- + graph : networkx.Graph + The specified graph (only nodes) sorted by latitude and longitude. + coords_rad : np.ndarray + The node coordinates (not ordered) in radians. + node_ordering : list[int] + Order of the node coordinates to be sorted by latitude and longitude. + """ + coords_rad = get_latlon_coords_icosphere(resolution) + + node_ordering = get_coordinates_ordering(coords_rad) + + if area_mask_builder is not None: + area_mask = area_mask_builder.get_mask(coords_rad) + node_ordering = node_ordering[area_mask[node_ordering]] + + # Creates the graph, with the nodes sorted by latitude and longitude. + nx_graph = create_nx_graph_from_tri_coords(coords_rad, node_ordering) + + return nx_graph, coords_rad, list(node_ordering) + + +def create_stretched_tri_nodes( + base_resolution: int, + lam_resolution: int, + area_mask_builder: KNNAreaMaskBuilder | None = None, +) -> tuple[nx.DiGraph, np.ndarray, list[int]]: + """Creates a global mesh with 2 levels of resolution. + + The base resolution is used to define the nodes outside the Area Of Interest (AOI), + while the lam_resolution is used to define the nodes inside the AOI. + + Parameters + --------- + base_resolution : int + Global resolution level. + lam_resolution : int + Local resolution level. + area_mask_builder : KNNAreaMaskBuilder + NearestNeighbors with the cloud of points to limit the mesh area. + + Returns + ------- + nx_graph : nx.DiGraph + The graph with the added nodes. + coords_rad : np.ndarray + The node coordinates (not ordered) in radians. + node_ordering : list[int] + Order of the node coordinates to be sorted by latitude and longitude. + """ + assert area_mask_builder is not None, "AOI mask builder must be provided to build refined grid." + # Get the low resolution nodes outside the AOI + base_coords_rad = get_latlon_coords_icosphere(base_resolution) + base_area_mask = ~area_mask_builder.get_mask(base_coords_rad) + + # Get the high resolution nodes inside the AOI + lam_coords_rad = get_latlon_coords_icosphere(lam_resolution) + lam_area_mask = area_mask_builder.get_mask(lam_coords_rad) + + coords_rad = np.concatenate([base_coords_rad[base_area_mask], lam_coords_rad[lam_area_mask]]) + + node_ordering = get_coordinates_ordering(coords_rad) + + # Creates the graph, with the nodes sorted by latitude and longitude. + nx_graph = create_nx_graph_from_tri_coords(coords_rad, node_ordering) + + return nx_graph, coords_rad, list(node_ordering) + + +def get_latlon_coords_icosphere(resolution: int) -> np.ndarray: + """Get the latitude and longitude coordinates (in radians) of the icosphere. + + Parameters + ---------- + resolution : int + The resolution level of the icosphere. 
+ + Returns + ------- + np.ndarray + The latitude and longitude coordinates, in radians, of the icosphere. + """ + sphere = trimesh.creation.icosphere(subdivisions=resolution, radius=1.0) + coords_rad = cartesian_to_latlon_rad(sphere.vertices) + return coords_rad + + +def create_nx_graph_from_tri_coords(coords_rad: np.ndarray, node_ordering: np.ndarray) -> nx.DiGraph: + """Creates the networkx graph from the coordinates and the node ordering.""" + graph = nx.DiGraph() + for i, coords in enumerate(coords_rad[node_ordering]): + node_id = node_ordering[i] + graph.add_node(node_id, hcoords_rad=coords) + + assert list(graph.nodes.keys()) == list(node_ordering), "Nodes are not correctly added to the graph." + assert graph.number_of_nodes() == len(node_ordering), "The number of nodes must be the same." + return graph + + +def add_edges_to_nx_graph( + graph: nx.DiGraph, + resolutions: list[int], + x_hops: int = 1, + area_mask_builder: KNNAreaMaskBuilder | None = None, +) -> nx.DiGraph: + """Adds the edges to the graph. + + This method adds multi-scale connections to the existing graph. The corresponfing nodes or vertices + are defined by an isophere at the different esolutions (or refinement levels) specified. + + Parameters + ---------- + graph : nx.DiGraph + The graph to add the edges. It should correspond to the mesh nodes, without edges. + resolutions : list[int] + Levels of mesh refinement levels to consider. + x_hops : int, optional + Number of hops between 2 nodes to consider them neighbours, by default 1. + area_mask_builder : KNNAreaMaskBuilder + NearestNeighbors with the cloud of points to limit the mesh area, by default None. + + Returns + ------- + graph : nx.DiGraph + The graph with the added edges. + """ + assert x_hops > 0, "x_hops == 0, graph would have no edges ..." + + graph_vertices = np.array([graph.nodes[i]["hcoords_rad"] for i in sorted(graph.nodes)]) + tree = BallTree(graph_vertices, metric="haversine") + + # Build the multi-scale connections + for resolution in resolutions: + # Define the coordinates of the isophere vertices at specified 'resolution' level + r_sphere = trimesh.creation.icosphere(subdivisions=resolution, radius=1.0) + r_vertices_rad = cartesian_to_latlon_rad(r_sphere.vertices) + + # Limit area of mesh points. + if area_mask_builder is not None: + area_mask = area_mask_builder.get_mask(r_vertices_rad) + valid_nodes = np.where(area_mask)[0] + else: + valid_nodes = None + + node_neighbours = get_neighbours_within_hops(r_sphere, x_hops, valid_nodes=valid_nodes) + + _, vertex_mapping_index = tree.query(r_vertices_rad, k=1) + for idx_node, idx_neighbours in node_neighbours.items(): + graph = add_neigbours_edges(graph, idx_node, idx_neighbours, vertex_mapping_index=vertex_mapping_index) + + return graph + + +def get_neighbours_within_hops( + tri_mesh: trimesh.Trimesh, x_hops: int, valid_nodes: list[int] | None = None +) -> dict[int, set[int]]: + """Get the neigbour connections in the graph. + + Parameters + ---------- + tri_mesh : trimesh.Trimesh + The mesh to consider. + x_hops : int + Number of hops between 2 nodes to consider them neighbours. + valid_nodes : list[int], optional + List of valid nodes to consider, by default None. It is useful to consider only a subset of the nodes to save + computation time. + + Returns + ------- + neighbours : dict[int, set[int]] + A list with the neighbours for each vertex. The element at position 'i' correspond to the neighbours to the + i-th vertex of the mesh. 
+ """ + edges = tri_mesh.edges_unique + + if valid_nodes is not None: + edges = edges[np.isin(tri_mesh.edges_unique, valid_nodes).all(axis=1)] + else: + valid_nodes = list(range(len(tri_mesh.vertices))) + graph = nx.from_edgelist(edges) + + neighbours = { + i: set(nx.ego_graph(graph, i, radius=x_hops, center=False) if i in graph else []) for i in valid_nodes + } + + return neighbours + + +def add_neigbours_edges( + graph: nx.Graph, + node_idx: int, + neighbour_indices: Iterable[int], + self_loops: bool = False, + vertex_mapping_index: np.ndarray | None = None, +) -> nx.Graph: + """Adds the edges of one node to its neighbours. + + Parameters + ---------- + graph : nx.Graph + The graph. + node_idx : int + The node considered. + neighbour_indices : list[int] + The neighbours of the node. + self_loops : bool, optional + Whether is supported to add self-loops, by default False. + vertex_mapping_index : np.ndarray, optional + Index to map the vertices from the refined sphere to the original one, by default None. + + Returns + ------- + nx.Graph + The graph with the added edges. + """ + graph_nodes_idx = list(sorted(graph.nodes)) + for neighbour_idx in neighbour_indices: + if not self_loops and node_idx == neighbour_idx: # no self-loops + continue + + if vertex_mapping_index is not None: + # Use the same method to add edge in all spheres + node_neighbour = graph_nodes_idx[vertex_mapping_index[neighbour_idx][0]] + node = graph_nodes_idx[vertex_mapping_index[node_idx][0]] + else: + node_neighbour = graph_nodes_idx[neighbour_idx] + node = graph_nodes_idx[node_idx] + + # add edge to the graph + if node in graph and node_neighbour in graph: + graph.add_edge(node_neighbour, node) + + return graph diff --git a/src/anemoi/graphs/generate/utils.py b/src/anemoi/graphs/generate/utils.py new file mode 100644 index 0000000..2dfc320 --- /dev/null +++ b/src/anemoi/graphs/generate/utils.py @@ -0,0 +1,31 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import numpy as np + + +def get_coordinates_ordering(coords: np.ndarray) -> np.ndarray: + """Sort node coordinates by latitude and longitude. + + Parameters + ---------- + coords : np.ndarray of shape (N, 2) + The node coordinates, with the latitude in the first column and the + longitude in the second column. + + Returns + ------- + np.ndarray + The order of the node coordinates to be sorted by latitude and longitude. + """ + # Get indices to sort points by lon & lat in radians. + index_latitude = np.argsort(coords[:, 1]) + index_longitude = np.argsort(coords[index_latitude][:, 0])[::-1] + node_ordering = np.arange(coords.shape[0])[index_latitude][index_longitude] + return node_ordering diff --git a/src/anemoi/graphs/inspect.py b/src/anemoi/graphs/inspect.py new file mode 100644 index 0000000..fa59018 --- /dev/null +++ b/src/anemoi/graphs/inspect.py @@ -0,0 +1,72 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +import os +from pathlib import Path +from typing import Optional +from typing import Union + +import torch + +from anemoi.graphs.plotting.displots import plot_distribution_edge_attributes +from anemoi.graphs.plotting.displots import plot_distribution_node_attributes +from anemoi.graphs.plotting.displots import plot_distribution_node_derived_attributes +from anemoi.graphs.plotting.interactive_html import plot_interactive_nodes +from anemoi.graphs.plotting.interactive_html import plot_interactive_subgraph +from anemoi.graphs.plotting.interactive_html import plot_isolated_nodes + +LOGGER = logging.getLogger(__name__) + + +class GraphInspector: + """Inspect the graph.""" + + def __init__( + self, + path: Union[str, Path], + output_path: Path, + show_attribute_distributions: Optional[bool] = True, + show_nodes: Optional[bool] = False, + **kwargs, + ): + self.path = path + self.graph = torch.load(self.path) + self.output_path = output_path + self.show_attribute_distributions = show_attribute_distributions + self.show_nodes = show_nodes + + if isinstance(self.output_path, str): + self.output_path = Path(self.output_path) + + os.makedirs(self.output_path, exist_ok=True) + + assert self.output_path.is_dir(), f"Path {self.output_path} is not a directory." + assert os.access(self.output_path, os.W_OK), f"Path {self.output_path} is not writable." + + def inspect(self): + """Run all the inspector methods.""" + LOGGER.info("Saving interactive plots of isolated nodes ...") + plot_isolated_nodes(self.graph, self.output_path / "isolated_nodes.html") + + LOGGER.info("Saving interactive plots of subgraphs ...") + for edges_subgraph in self.graph.edge_types: + ofile = self.output_path / f"{edges_subgraph[0]}_to_{edges_subgraph[2]}.html" + plot_interactive_subgraph(self.graph, edges_subgraph, out_file=ofile) + + if self.show_attribute_distributions: + LOGGER.info("Saving distribution plots of node ande edge attributes ...") + plot_distribution_node_derived_attributes(self.graph, self.output_path / "distribution_node_adjancency.png") + plot_distribution_edge_attributes(self.graph, self.output_path / "distribution_edge_attributes.png") + plot_distribution_node_attributes(self.graph, self.output_path / "distribution_node_attributes.png") + + if self.show_nodes: + LOGGER.info("Saving interactive plots of nodes ...") + for nodes_name in self.graph.node_types: + plot_interactive_nodes(self.graph, nodes_name, out_file=self.output_path / f"{nodes_name}_nodes.html") diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index 737f27f..ef48d45 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,4 +1,23 @@ -from .builder import NPZFileNodes -from .builder import ZarrDatasetNodes +from .builders.from_file import LimitedAreaNPZFileNodes +from .builders.from_file import NPZFileNodes +from .builders.from_file import ZarrDatasetNodes +from .builders.from_healpix import HEALPixNodes +from .builders.from_healpix import LimitedAreaHEALPixNodes +from .builders.from_refined_icosahedron import HexNodes +from .builders.from_refined_icosahedron import LimitedAreaHexNodes +from .builders.from_refined_icosahedron import LimitedAreaTriNodes +from .builders.from_refined_icosahedron import StretchedTriNodes +from 
.builders.from_refined_icosahedron import TriNodes -__all__ = ["ZarrDatasetNodes", "NPZFileNodes"] +__all__ = [ + "ZarrDatasetNodes", + "NPZFileNodes", + "TriNodes", + "HexNodes", + "HEALPixNodes", + "LimitedAreaHEALPixNodes", + "LimitedAreaNPZFileNodes", + "LimitedAreaTriNodes", + "LimitedAreaHexNodes", + "StretchedTriNodes", +] diff --git a/src/anemoi/graphs/nodes/attributes.py b/src/anemoi/graphs/nodes/attributes.py index 911ce99..1498e6d 100644 --- a/src/anemoi/graphs/nodes/attributes.py +++ b/src/anemoi/graphs/nodes/attributes.py @@ -1,40 +1,53 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + import logging from abc import ABC from abc import abstractmethod -from typing import Optional +from typing import Type import numpy as np import torch +from anemoi.datasets import open_dataset from scipy.spatial import SphericalVoronoi from torch_geometric.data import HeteroData from torch_geometric.data.storage import NodeStorage from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian -from anemoi.graphs.normalizer import NormalizerMixin +from anemoi.graphs.normalise import NormaliserMixin LOGGER = logging.getLogger(__name__) -class BaseWeights(ABC, NormalizerMixin): +class BaseNodeAttribute(ABC, NormaliserMixin): """Base class for the weights of the nodes.""" - def __init__(self, norm: Optional[str] = None) -> None: + def __init__(self, norm: str | None = None, dtype: str = "float32") -> None: self.norm = norm + self.dtype = dtype @abstractmethod - def get_raw_values(self, nodes: NodeStorage, *args, **kwargs): ... + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: ... def post_process(self, values: np.ndarray) -> torch.Tensor: """Post-process the values.""" if values.ndim == 1: values = values[:, np.newaxis] - norm_values = self.normalize(values) + norm_values = self.normalise(values) - return torch.tensor(norm_values, dtype=torch.float32) + return torch.tensor(norm_values.astype(self.dtype)) - def compute(self, graph: HeteroData, nodes_name: str, *args, **kwargs) -> torch.Tensor: - """Get the node weights. + def compute(self, graph: HeteroData, nodes_name: str, **kwargs) -> torch.Tensor: + """Get the nodes attribute. Parameters ---------- @@ -42,43 +55,53 @@ def compute(self, graph: HeteroData, nodes_name: str, *args, **kwargs) -> torch. Graph. nodes_name : str Name of the nodes. + kwargs : dict + Additional keyword arguments. Returns ------- torch.Tensor - Weights associated to the nodes. + Attributes associated to the nodes. """ nodes = graph[nodes_name] - weights = self.get_raw_values(nodes, *args, **kwargs) - return self.post_process(weights) + attributes = self.get_raw_values(nodes, **kwargs) + return self.post_process(attributes) + +class UniformWeights(BaseNodeAttribute): + """Implements a uniform weight for the nodes. -class UniformWeights(BaseWeights): - """Implements a uniform weight for the nodes.""" + Methods + ------- + compute(self, graph, nodes_name) + Compute the area attributes for each node. 
+ """ - def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: """Compute the weights. Parameters ---------- nodes : NodeStorage Nodes of the graph. + kwargs : dict + Additional keyword arguments. Returns ------- np.ndarray - Weights. + Attributes. """ return np.ones(nodes.num_nodes) -class AreaWeights(BaseWeights): +class AreaWeights(BaseNodeAttribute): """Implements the area of the nodes as the weights. Attributes ---------- norm : str - Normalization of the weights. + Normalisation of the weights. radius : float Radius of the sphere. centre : np.ndarray @@ -86,20 +109,22 @@ class AreaWeights(BaseWeights): Methods ------- - get_raw_values(nodes, *args, **kwargs) - Compute the area associated to each node. - compute(nodes, *args, **kwargs) + compute(self, graph, nodes_name) Compute the area attributes for each node. """ def __init__( - self, norm: Optional[str] = None, radius: float = 1.0, centre: np.ndarray = np.array([0, 0, 0]) + self, + norm: str | None = None, + radius: float = 1.0, + centre: np.ndarray = np.array([0, 0, 0]), + dtype: str = "float32", ) -> None: - super().__init__(norm) + super().__init__(norm, dtype) self.radius = radius self.centre = centre - def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: """Compute the area associated to each node. It uses Voronoi diagrams to compute the area of each node. @@ -108,14 +133,16 @@ def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: ---------- nodes : NodeStorage Nodes of the graph. + kwargs : dict + Additional keyword arguments. Returns ------- np.ndarray - Weights. + Attributes. """ latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1] - points = latlon_rad_to_cartesian((latitudes, longitudes)) + points = latlon_rad_to_cartesian((np.asarray(latitudes), np.asarray(longitudes))) sv = SphericalVoronoi(points, self.radius, self.centre) area_weights = sv.calculate_areas() LOGGER.debug( @@ -124,3 +151,83 @@ def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: np.array(area_weights).sum(), ) return area_weights + + +class BooleanBaseNodeAttribute(BaseNodeAttribute, ABC): + """Base class for boolean node attributes.""" + + def __init__(self) -> None: + super().__init__(norm=None, dtype="bool") + + +class NonmissingZarrVariable(BooleanBaseNodeAttribute): + """Mask of valid (not missing) values of a Zarr dataset variable. + + It reads a variable from a Zarr dataset and returns a boolean mask of nonmissing values in the first timestep. + + Attributes + ---------- + variable : str + Variable to read from the Zarr dataset. + norm : str + Normalization of the weights. + + Methods + ------- + compute(self, graph, nodes_name) + Compute the attribute for each node. + """ + + def __init__(self, variable: str) -> None: + super().__init__() + self.variable = variable + + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: + assert ( + nodes["node_type"] == "ZarrDatasetNodes" + ), f"{self.__class__.__name__} can only be used with ZarrDatasetNodes." + ds = open_dataset(nodes["_dataset"], select=self.variable)[0].squeeze() + return ~np.isnan(ds) + + +class CutOutMask(BooleanBaseNodeAttribute): + """Cut out mask.""" + + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: + assert isinstance(nodes["_dataset"], dict), "The 'dataset' attribute must be a dictionary." 
+ assert "cutout" in nodes["_dataset"], "The 'dataset' attribute must contain a 'cutout' key." + num_lam, num_other = open_dataset(nodes["_dataset"]).grids + return np.array([True] * num_lam + [False] * num_other, dtype=bool) + + +class BooleanOperation(BooleanBaseNodeAttribute, ABC): + """Base class for boolean operations.""" + + def __init__(self, masks: list[str | Type[BooleanBaseNodeAttribute]]) -> None: + super().__init__() + self.masks = masks + + @staticmethod + def get_mask_values(mask: str | Type[BaseNodeAttribute], nodes: NodeStorage, **kwargs) -> np.array: + if isinstance(mask, str): + attributes = nodes[mask] + assert ( + attributes.dtype == "bool" + ), f"The mask attribute '{mask}' must be a boolean but is {attributes.dtype}." + return attributes + + return mask.get_raw_values(nodes, **kwargs) + + +class BooleanAndMask(BooleanOperation): + """Boolean AND mask.""" + + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: + return np.logical_and.reduce([BooleanOperation.get_mask_values(mask, nodes, **kwargs) for mask in self.masks]) + + +class BooleanOrMask(BooleanOperation): + """Boolean OR mask.""" + + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: + return np.logical_or.reduce([BooleanOperation.get_mask_values(mask, nodes, **kwargs) for mask in self.masks]) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py deleted file mode 100644 index 6ff37a1..0000000 --- a/src/anemoi/graphs/nodes/builder.py +++ /dev/null @@ -1,190 +0,0 @@ -import logging -from abc import ABC -from abc import abstractmethod -from pathlib import Path -from typing import Optional - -import numpy as np -import torch -from anemoi.datasets import open_dataset -from anemoi.utils.config import DotDict -from hydra.utils import instantiate -from torch_geometric.data import HeteroData - -LOGGER = logging.getLogger(__name__) - - -class BaseNodeBuilder(ABC): - """Base class for node builders. - - The node coordinates are stored in the `x` attribute of the nodes and they are stored in radians. - """ - - def __init__(self, name: str) -> None: - self.name = name - - def register_nodes(self, graph: HeteroData) -> None: - """Register nodes in the graph. - - Parameters - ---------- - graph : HeteroData - The graph to register the nodes. - """ - graph[self.name].x = self.get_coordinates() - graph[self.name].node_type = type(self).__name__ - return graph - - def register_attributes(self, graph: HeteroData, config: Optional[DotDict] = None) -> HeteroData: - """Register attributes in the nodes of the graph specified. - - Parameters - ---------- - graph : HeteroData - The graph to register the attributes. - config : DotDict - The configuration of the attributes. - - Returns - ------- - HeteroData - The graph with the registered attributes. - """ - for attr_name, attr_config in config.items(): - graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name) - return graph - - @abstractmethod - def get_coordinates(self) -> torch.Tensor: ... - - def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> torch.Tensor: - """Reshape latitude and longitude coordinates. - - Parameters - ---------- - latitudes : np.ndarray of shape (N, ) - Latitude coordinates, in degrees. - longitudes : np.ndarray of shape (N, ) - Longitude coordinates, in degrees. - - Returns - ------- - torch.Tensor of shape (N, 2) - A 2D tensor with the coordinates, in radians. 
- """ - coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) - coords = np.deg2rad(coords) - return torch.tensor(coords, dtype=torch.float32) - - def update_graph(self, graph: HeteroData, attr_config: Optional[DotDict] = None) -> HeteroData: - """Update the graph with new nodes. - - Parameters - ---------- - graph : HeteroData - Input graph. - attr_config : DotDict - The configuration of the attributes. - - Returns - ------- - HeteroData - The graph with new nodes included. - """ - graph = self.register_nodes(graph) - - if attr_config is None: - return graph - - graph = self.register_attributes(graph, attr_config) - - return graph - - -class ZarrDatasetNodes(BaseNodeBuilder): - """Nodes from Zarr dataset. - - Attributes - ---------- - ds : zarr.core.Array - The dataset. - - Methods - ------- - get_coordinates() - Get the lat-lon coordinates of the nodes. - register_nodes(graph, name) - Register the nodes in the graph. - register_attributes(graph, name, config) - Register the attributes in the nodes of the graph specified. - update_graph(graph, name, attr_config) - Update the graph with new nodes and attributes. - """ - - def __init__(self, dataset: DotDict, name: str) -> None: - LOGGER.info("Reading the dataset from %s.", dataset) - self.ds = open_dataset(dataset) - super().__init__(name) - - def get_coordinates(self) -> torch.Tensor: - """Get the coordinates of the nodes. - - Returns - ------- - torch.Tensor of shape (N, 2) - Coordinates of the nodes. - """ - return self.reshape_coords(self.ds.latitudes, self.ds.longitudes) - - -class NPZFileNodes(BaseNodeBuilder): - """Nodes from NPZ defined grids. - - Attributes - ---------- - resolution : str - The resolution of the grid. - grid_definition_path : str - Path to the folder containing the grid definition files. - grid_definition : dict[str, np.ndarray] - The grid definition. - - Methods - ------- - get_coordinates() - Get the lat-lon coordinates of the nodes. - register_nodes(graph, name) - Register the nodes in the graph. - register_attributes(graph, name, config) - Register the attributes in the nodes of the graph specified. - update_graph(graph, name, attr_config) - Update the graph with new nodes and attributes. - """ - - def __init__(self, resolution: str, grid_definition_path: str, name: str) -> None: - """Initialize the NPZFileNodes builder. - - The builder suppose the grids are stored in files with the name `grid-{resolution}.npz`. - - Parameters - ---------- - resolution : str - The resolution of the grid. - grid_definition_path : str - Path to the folder containing the grid definition files. - """ - self.resolution = resolution - self.grid_definition_path = grid_definition_path - self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") - super().__init__(name) - - def get_coordinates(self) -> torch.Tensor: - """Get the coordinates of the nodes. - - Returns - ------- - torch.Tensor of shape (N, 2) - Coordinates of the nodes. - """ - coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) - return coords diff --git a/src/anemoi/graphs/nodes/builders/__init__.py b/src/anemoi/graphs/nodes/builders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/anemoi/graphs/nodes/builders/base.py b/src/anemoi/graphs/nodes/builders/base.py new file mode 100644 index 0000000..a141ebf --- /dev/null +++ b/src/anemoi/graphs/nodes/builders/base.py @@ -0,0 +1,125 @@ +# (C) Copyright 2024 Anemoi contributors. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod + +import numpy as np +import torch +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from torch_geometric.data import HeteroData + + +class BaseNodeBuilder(ABC): + """Base class for node builders. + + The node coordinates are stored in the `x` attribute of the nodes and they are stored in radians. + + Attributes + ---------- + name : str + name of the nodes, key for the nodes in the HeteroData graph object. + area_mask_builder : KNNAreaMaskBuilder + The area of interest mask builder, if any. Defaults to None. + """ + + hidden_attributes: set[str] = set() + + def __init__(self, name: str) -> None: + self.name = name + self.area_mask_builder = None + + def register_nodes(self, graph: HeteroData) -> HeteroData: + """Register nodes in the graph. + + Parameters + ---------- + graph : HeteroData + The graph to register the nodes. + + Returns + ------- + HeteroData + The graph with the registered nodes. + """ + graph[self.name].x = self.get_coordinates() + graph[self.name].node_type = type(self).__name__ + return graph + + def register_attributes(self, graph: HeteroData, config: DotDict | None = None) -> HeteroData: + """Register attributes in the nodes of the graph specified. + + Parameters + ---------- + graph : HeteroData + The graph to register the attributes. + config : DotDict + The configuration of the attributes. + + Returns + ------- + HeteroData + The graph with the registered attributes. + """ + for hidden_attr in self.hidden_attributes: + graph[self.name][f"_{hidden_attr}"] = getattr(self, hidden_attr) + + for attr_name, attr_config in config.items(): + graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name) + + return graph + + @abstractmethod + def get_coordinates(self) -> torch.Tensor: ... + + def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> torch.Tensor: + """Reshape latitude and longitude coordinates. + + Parameters + ---------- + latitudes : np.ndarray of shape (num_nodes, ) + Latitude coordinates, in degrees. + longitudes : np.ndarray of shape (num_nodes, ) + Longitude coordinates, in degrees. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. + """ + coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) + coords = np.deg2rad(coords) + return torch.tensor(coords, dtype=torch.float32) + + def update_graph(self, graph: HeteroData, attr_config: DotDict | None = None) -> HeteroData: + """Update the graph with new nodes. + + Parameters + ---------- + graph : HeteroData + Input graph. + attr_config : DotDict + The configuration of the attributes. + + Returns + ------- + HeteroData + The graph with new nodes included. 
+ """ + graph = self.register_nodes(graph) + + if attr_config is None: + return graph + + graph = self.register_attributes(graph, attr_config) + + return graph diff --git a/src/anemoi/graphs/nodes/builders/from_file.py b/src/anemoi/graphs/nodes/builders/from_file.py new file mode 100644 index 0000000..786ae76 --- /dev/null +++ b/src/anemoi/graphs/nodes/builders/from_file.py @@ -0,0 +1,152 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +from pathlib import Path + +import numpy as np +import torch +from anemoi.datasets import open_dataset +from omegaconf import DictConfig +from omegaconf import OmegaConf +from torch_geometric.data import HeteroData + +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder + +LOGGER = logging.getLogger(__name__) + + +class ZarrDatasetNodes(BaseNodeBuilder): + """Nodes from Zarr dataset. + + Attributes + ---------- + dataset : str | DictConfig + The dataset. + + Methods + ------- + get_coordinates() + Get the lat-lon coordinates of the nodes. + register_nodes(graph, name) + Register the nodes in the graph. + register_attributes(graph, name, config) + Register the attributes in the nodes of the graph specified. + update_graph(graph, name, attr_config) + Update the graph with new nodes and attributes. + """ + + def __init__(self, dataset: DictConfig, name: str) -> None: + LOGGER.info("Reading the dataset from %s.", dataset) + self.dataset = dataset if isinstance(dataset, str) else OmegaConf.to_container(dataset) + super().__init__(name) + self.hidden_attributes = BaseNodeBuilder.hidden_attributes | {"dataset"} + + def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. + """ + dataset = open_dataset(self.dataset) + return self.reshape_coords(dataset.latitudes, dataset.longitudes) + + +class NPZFileNodes(BaseNodeBuilder): + """Nodes from NPZ defined grids. + + Attributes + ---------- + resolution : str + The resolution of the grid. + grid_definition_path : str + Path to the folder containing the grid definition files. + grid_definition : dict[str, np.ndarray] + The grid definition. + + Methods + ------- + get_coordinates() + Get the lat-lon coordinates of the nodes. + register_nodes(graph, name) + Register the nodes in the graph. + register_attributes(graph, name, config) + Register the attributes in the nodes of the graph specified. + update_graph(graph, name, attr_config) + Update the graph with new nodes and attributes. + """ + + def __init__(self, resolution: str, grid_definition_path: str, name: str) -> None: + """Initialize the NPZFileNodes builder. + + The builder suppose the grids are stored in files with the name `grid-{resolution}.npz`. + + Parameters + ---------- + resolution : str + The resolution of the grid. + grid_definition_path : str + Path to the folder containing the grid definition files. 
+ """ + self.resolution = resolution + self.grid_definition_path = grid_definition_path + self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") + super().__init__(name) + + def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. + """ + coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) + return coords + + +class LimitedAreaNPZFileNodes(NPZFileNodes): + """Nodes from NPZ defined grids using an area of interest.""" + + def __init__( + self, + resolution: str, + grid_definition_path: str, + reference_node_name: str, + name: str, + mask_attr_name: str | None = None, + margin_radius_km: float = 100.0, + ) -> None: + + self.area_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + + super().__init__(resolution, grid_definition_path, name) + + def register_nodes(self, graph: HeteroData) -> None: + self.area_mask_builder.fit(graph) + return super().register_nodes(graph) + + def get_coordinates(self) -> np.ndarray: + coords = super().get_coordinates() + + LOGGER.info( + "Limiting the processor mesh to a radius of %.2f km from the output mesh.", + self.area_mask_builder.margin_radius_km, + ) + area_mask = self.area_mask_builder.get_mask(coords) + + LOGGER.info("Dropping %d nodes from the processor mesh.", len(area_mask) - area_mask.sum()) + coords = coords[area_mask] + + return coords diff --git a/src/anemoi/graphs/nodes/builders/from_healpix.py b/src/anemoi/graphs/nodes/builders/from_healpix.py new file mode 100644 index 0000000..f36d91e --- /dev/null +++ b/src/anemoi/graphs/nodes/builders/from_healpix.py @@ -0,0 +1,106 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging + +import numpy as np +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder + +LOGGER = logging.getLogger(__name__) + + +class HEALPixNodes(BaseNodeBuilder): + """Nodes from HEALPix grid. + + HEALPix is an acronym for Hierarchical Equal Area isoLatitude Pixelization of a sphere. + + Attributes + ---------- + resolution : int + The resolution of the grid. + + Methods + ------- + get_coordinates() + Get the lat-lon coordinates of the nodes. + register_nodes(graph, name) + Register the nodes in the graph. + register_attributes(graph, name, config) + Register the attributes in the nodes of the graph specified. + update_graph(graph, name, attr_config) + Update the graph with new nodes and attributes. + """ + + def __init__(self, resolution: int, name: str) -> None: + """Initialize the HEALPixNodes builder.""" + self.resolution = resolution + super().__init__(name) + + assert isinstance(resolution, int), "Resolution must be an integer." + assert resolution > 0, "Resolution must be positive." + + def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. 
+ + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + Coordinates of the nodes, in radians. + """ + import healpy as hp + + spatial_res_degrees = hp.nside2resol(2**self.resolution, arcmin=True) / 60 + LOGGER.info(f"Creating HEALPix nodes with resolution {spatial_res_degrees:.2} deg.") + + npix = hp.nside2npix(2**self.resolution) + hpxlon, hpxlat = hp.pix2ang(2**self.resolution, range(npix), nest=True, lonlat=True) + + return self.reshape_coords(hpxlat, hpxlon) + + +class LimitedAreaHEALPixNodes(HEALPixNodes): + """Nodes from HEALPix grid using an area of interest.""" + + def __init__( + self, + resolution: str, + reference_node_name: str, + name: str, + mask_attr_name: str | None = None, + margin_radius_km: float = 100.0, + ) -> None: + + self.area_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + + super().__init__(resolution, name) + + def register_nodes(self, graph: HeteroData) -> None: + self.area_mask_builder.fit(graph) + return super().register_nodes(graph) + + def get_coordinates(self) -> np.ndarray: + coords = super().get_coordinates() + + LOGGER.info( + 'Limiting the "%s" nodes to a radius of %.2f km from the nodes of interest.', + self.name, + self.area_mask_builder.margin_radius_km, + ) + area_mask = self.area_mask_builder.get_mask(coords) + + LOGGER.info('Masking out %d nodes from "%s".', len(area_mask) - area_mask.sum(), self.name) + coords = coords[area_mask] + + return coords diff --git a/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py new file mode 100644 index 0000000..1f961e0 --- /dev/null +++ b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py @@ -0,0 +1,191 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +from abc import ABC +from abc import abstractmethod + +import networkx as nx +import numpy as np +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.generate.hex_icosahedron import create_hex_nodes +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder +from anemoi.graphs.generate.tri_icosahedron import create_stretched_tri_nodes +from anemoi.graphs.generate.tri_icosahedron import create_tri_nodes +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder + +LOGGER = logging.getLogger(__name__) + + +class IcosahedralNodes(BaseNodeBuilder, ABC): + """Nodes based on iterative refinements of an icosahedron. + + Attributes + ---------- + resolution : list[int] | int + Refinement level of the mesh. + """ + + def __init__( + self, + resolution: int | list[int], + name: str, + ) -> None: + if isinstance(resolution, int): + self.resolutions = list(range(resolution + 1)) + else: + self.resolutions = resolution + + super().__init__(name) + self.hidden_attributes = BaseNodeBuilder.hidden_attributes | { + "resolutions", + "nx_graph", + "node_ordering", + "area_mask_builder", + } + + def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. 
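# Quick sanity sketch for the HEALPixNodes builder above: the integer resolution is
# used as nside = 2**resolution, so the number of generated nodes is
# 12 * (2**resolution) ** 2. Requires the optional healpy dependency, just like
# get_coordinates(); the resolutions are the ones exercised in tests/nodes/test_healpix.py.
import healpy as hp

for resolution in (2, 5, 7):
    assert hp.nside2npix(2**resolution) == 12 * (2**resolution) ** 2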
+ """ + self.nx_graph, coords_rad, self.node_ordering = self.create_nodes() + return torch.tensor(coords_rad[self.node_ordering], dtype=torch.float32) + + @abstractmethod + def create_nodes(self) -> tuple[nx.DiGraph, np.ndarray, list[int]]: ... + + +class LimitedAreaIcosahedralNodes(IcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron using an area of interest. + + Attributes + ---------- + area_mask_builder : KNNAreaMaskBuilder + The area of interest mask builder. + """ + + def __init__( + self, + resolution: int | list[int], + reference_node_name: str, + name: str, + mask_attr_name: str | None = None, + margin_radius_km: float = 100.0, + ) -> None: + + super().__init__(resolution, name) + + self.area_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + + def register_nodes(self, graph: HeteroData) -> None: + self.area_mask_builder.fit(graph) + return super().register_nodes(graph) + + +class TriNodes(IcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron. + + It depends on the trimesh Python library. + """ + + def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: + return create_tri_nodes(resolution=max(self.resolutions)) + + +class HexNodes(IcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron. + + It depends on the h3 Python library. + """ + + def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: + return create_hex_nodes(resolution=max(self.resolutions)) + + +class LimitedAreaTriNodes(LimitedAreaIcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron using an area of interest. + + It depends on the trimesh Python library. + + Parameters + ---------- + area_mask_builder: KNNAreaMaskBuilder + The area of interest mask builder. + """ + + def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: + return create_tri_nodes(resolution=max(self.resolutions), area_mask_builder=self.area_mask_builder) + + +class LimitedAreaHexNodes(LimitedAreaIcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron using an area of interest. + + It depends on the h3 Python library. + + Parameters + ---------- + area_mask_builder: KNNAreaMaskBuilder + The area of interest mask builder. + """ + + def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: + return create_hex_nodes(resolution=max(self.resolutions), area_mask_builder=self.area_mask_builder) + + +class StretchedIcosahedronNodes(IcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron with 2 + different resolutions. + + Attributes + ---------- + area_mask_builder : KNNAreaMaskBuilder + The area of interest mask builder. + """ + + def __init__( + self, + global_resolution: int, + lam_resolution: int, + name: str, + reference_node_name: str, + mask_attr_name: str, + margin_radius_km: float = 100.0, + ) -> None: + + super().__init__(lam_resolution, name) + self.global_resolution = global_resolution + + self.area_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + + def register_nodes(self, graph: HeteroData) -> None: + self.area_mask_builder.fit(graph) + return super().register_nodes(graph) + + +class StretchedTriNodes(StretchedIcosahedronNodes): + """Nodes based on iterative refinements of an icosahedron with 2 + different resolutions. + + It depends on the trimesh Python library. 
+ """ + + def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: + return create_stretched_tri_nodes( + base_resolution=self.global_resolution, + lam_resolution=max(self.resolutions), + area_mask_builder=self.area_mask_builder, + ) diff --git a/src/anemoi/graphs/normalise.py b/src/anemoi/graphs/normalise.py new file mode 100644 index 0000000..4e50118 --- /dev/null +++ b/src/anemoi/graphs/normalise.py @@ -0,0 +1,55 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import numpy as np + +LOGGER = logging.getLogger(__name__) + + +class NormaliserMixin: + """Mixin class for normalising attributes.""" + + def normalise(self, values: np.ndarray) -> np.ndarray: + """Normalise the given values. + + It supports different normalisation methods: None, 'l1', + 'l2', 'unit-max', 'unit-range' and 'unit-std'. + + Parameters + ---------- + values : np.ndarray of shape (N, M) + Values to normalise. + + Returns + ------- + np.ndarray + Normalised values. + """ + if self.norm is None: + LOGGER.debug(f"{self.__class__.__name__} values are not normalised.") + return values + if self.norm == "l1": + return values / np.sum(values) + if self.norm == "l2": + return values / np.linalg.norm(values) + if self.norm == "unit-max": + return values / np.amax(values) + if self.norm == "unit-range": + return (values - np.amin(values)) / (np.amax(values) - np.amin(values)) + if self.norm == "unit-std": + std = np.std(values) + if std == 0: + LOGGER.warning(f"Std. dev. of the {self.__class__.__name__} values is 0. Normalisation is skipped.") + return values + return values / std + raise ValueError( + f"Attribute normalisation \"{self.norm}\" is not valid. Options are: 'l1', 'l2', 'unit-max' or 'unit-std'." + ) diff --git a/src/anemoi/graphs/normalizer.py b/src/anemoi/graphs/normalizer.py deleted file mode 100644 index c625417..0000000 --- a/src/anemoi/graphs/normalizer.py +++ /dev/null @@ -1,44 +0,0 @@ -import logging - -import numpy as np - -LOGGER = logging.getLogger(__name__) - - -class NormalizerMixin: - """Mixin class for normalizing attributes.""" - - def normalize(self, values: np.ndarray) -> np.ndarray: - """Normalize the given values. - - It supports different normalization methods: None, 'l1', - 'l2', 'unit-max' and 'unit-std'. - - Parameters - ---------- - values : np.ndarray of shape (N, M) - Values to normalize. - - Returns - ------- - np.ndarray - Normalized values. - """ - if self.norm is None: - LOGGER.debug(f"{self.__class__.__name__} values are not normalized.") - return values - if self.norm == "l1": - return values / np.sum(values) - if self.norm == "l2": - return values / np.linalg.norm(values) - if self.norm == "unit-max": - return values / np.amax(values) - if self.norm == "unit-std": - std = np.std(values) - if std == 0: - LOGGER.warning(f"Std. dev. of the {self.__class__.__name__} values is 0. Normalization is skipped.") - return values - return values / std - raise ValueError( - f"Attribute normalization \"{self.norm}\" is not valid. Options are: 'l1', 'l2', 'unit-max' or 'unit-std'." 
- ) diff --git a/src/anemoi/graphs/plotting/__init__.py b/src/anemoi/graphs/plotting/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/anemoi/graphs/plotting/displots.py b/src/anemoi/graphs/plotting/displots.py new file mode 100644 index 0000000..1830352 --- /dev/null +++ b/src/anemoi/graphs/plotting/displots.py @@ -0,0 +1,115 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from pathlib import Path +from typing import Literal +from typing import Optional +from typing import Union + +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch_geometric.data import HeteroData +from torch_geometric.data.storage import EdgeStorage +from torch_geometric.data.storage import NodeStorage + +from anemoi.graphs.plotting.prepare import compute_node_adjacencies +from anemoi.graphs.plotting.prepare import get_edge_attribute_dims +from anemoi.graphs.plotting.prepare import get_node_attribute_dims + +LOGGER = logging.getLogger(__name__) + + +def plot_distribution_node_attributes(graph: HeteroData, out_file: Optional[Union[str, Path]] = None) -> None: + """Figure with the distribution of the node attributes. + + Each row represents a node type and each column an attribute dimension. + """ + num_nodes = len(graph.node_types) + attr_dims = get_node_attribute_dims(graph) + plot_distribution_attributes(graph.node_items(), num_nodes, attr_dims, "Node", out_file) + + +def plot_distribution_edge_attributes(graph: HeteroData, out_file: Optional[Union[str, Path]] = None) -> None: + """Figure with the distribution of the edge attributes. + + Each row represents a edge type and each column an attribute dimension. + """ + num_edges = len(graph.edge_types) + attr_dims = get_edge_attribute_dims(graph) + plot_distribution_attributes(graph.edge_items(), num_edges, attr_dims, "Edge", out_file) + + +def plot_distribution_node_derived_attributes(graph, outfile: Optional[Union[str, Path]] = None): + """Figure with the distribution of the node derived attributes. + + Each row represents a node type and each column an attribute dimension. 
+ """ + node_adjs = {} + node_attr_dims = {} + for source_name, _, target_name in graph.edge_types: + node_adj_tensor = compute_node_adjacencies(graph, source_name, target_name) + node_adj_tensor = torch.from_numpy(node_adj_tensor.reshape((node_adj_tensor.shape[0], -1))) + node_adj_key = f"# edges from {source_name}" + + # Store node adjacencies + if target_name in node_adjs: + node_adjs[target_name] = node_adjs[target_name] | {node_adj_key: node_adj_tensor} + else: + node_adjs[target_name] = {node_adj_key: node_adj_tensor} + + # Store attribute dimension + if node_adj_key not in node_attr_dims: + node_attr_dims[node_adj_key] = node_adj_tensor.shape[1] + + node_adj_list = [(k, v) for k, v in node_adjs.items()] + + plot_distribution_attributes(node_adj_list, len(node_adjs), node_attr_dims, "Node", outfile) + + +def plot_distribution_attributes( + graph_items: Union[NodeStorage, EdgeStorage], + num_items: int, + attr_dims: dict, + item_type: Literal["Edge", "Node"], + out_file: Optional[Union[str, Path]] = None, +) -> None: + """Figure with the distribution of the node and edge attributes. + + Each row represents a node or edge type and each column an attribute dimension. + """ + dim_attrs = sum(attr_dims.values()) + + if dim_attrs == 0: + LOGGER.warning("No edge attributes found in the graph.") + return None + + # Define the layout + _, axs = plt.subplots(num_items, dim_attrs, figsize=(10 * num_items, 10)) + if num_items == dim_attrs == 1: + axs = np.array([[axs]]) + elif axs.ndim == 1: + axs = axs.reshape(num_items, dim_attrs) + + for i, (item_name, item_store) in enumerate(graph_items): + for j, (attr_name, attr_values) in enumerate(attr_dims.items()): + for dim in range(attr_values): + if attr_name in item_store: + axs[i, j + dim].hist(item_store[attr_name][:, dim].float(), bins=50) + + axs[i, j + dim].set_ylabel("".join(item_name).replace("to", " --> ")) + axs[i, j + dim].set_title(attr_name if attr_values == 1 else f"{attr_name}_{dim}") + if i == num_items - 1: + axs[i, j + dim].set_xlabel(attr_name if attr_values == 1 else f"{attr_name}_{dim}") + else: + axs[i, j + dim].set_axis_off() + + plt.suptitle(f"{item_type} Attributes distribution", fontsize=14) + plt.savefig(out_file) diff --git a/src/anemoi/graphs/plotting/interactive_html.py b/src/anemoi/graphs/plotting/interactive_html.py new file mode 100644 index 0000000..7021bf1 --- /dev/null +++ b/src/anemoi/graphs/plotting/interactive_html.py @@ -0,0 +1,252 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import logging +from pathlib import Path +from typing import Optional +from typing import Union + +import matplotlib.pyplot as plt +import numpy as np +import plotly.graph_objects as go +from matplotlib.colors import rgb2hex +from torch_geometric.data import HeteroData + +from anemoi.graphs.plotting.prepare import compute_isolated_nodes +from anemoi.graphs.plotting.prepare import compute_node_adjacencies +from anemoi.graphs.plotting.prepare import edge_list +from anemoi.graphs.plotting.prepare import node_list + +annotations_style = {"text": "", "showarrow": False, "xref": "paper", "yref": "paper", "x": 0.005, "y": -0.002} +plotly_axis_config = {"showgrid": False, "zeroline": False, "showticklabels": False} + +LOGGER = logging.getLogger(__name__) + + +def plot_interactive_subgraph( + graph: HeteroData, + edges_to_plot: tuple[str, str, str], + out_file: Optional[Union[str, Path]] = None, +) -> None: + """Plots a bipartite graph (bi-graph). + + This methods plots the bipartite graph passed in an interactive window (using Ploty). + + Parameters + ---------- + graph : dict + The graph to plot. + edges_to_plot : tuple[str, str] + Names of the edges to plot. + out_file : str | Path, optional + Name of the file to save the plot. Default is None. + """ + source_name, _, target_name = edges_to_plot + edge_x, edge_y = edge_list(graph, source_nodes_name=source_name, target_nodes_name=target_name) + assert source_name in graph.node_types, f"edges_to_plot ({source_name}) should be in the graph" + assert target_name in graph.node_types, f"edges_to_plot ({target_name}) should be in the graph" + lats_source_nodes, lons_source_nodes = node_list(graph, source_name) + lats_target_nodes, lons_target_nodes = node_list(graph, target_name) + + # Compute node adjacencies + node_adjacencies = compute_node_adjacencies(graph, source_name, target_name) + node_text = [f"# of connections: {x}" for x in node_adjacencies] + + edge_trace = go.Scattergeo( + lat=edge_x, + lon=edge_y, + line={"width": 0.5, "color": "#888"}, + hoverinfo="none", + mode="lines", + name="Connections", + ) + + source_node_trace = go.Scattergeo( + lat=lats_source_nodes, + lon=lons_source_nodes, + mode="markers", + hoverinfo="text", + name=source_name, + marker={ + "showscale": False, + "color": "red", + "size": 2, + "line_width": 2, + }, + ) + + target_node_trace = go.Scattergeo( + lat=lats_target_nodes, + lon=lons_target_nodes, + mode="markers", + hoverinfo="text", + name=target_name, + text=node_text, + marker={ + "showscale": True, + "colorscale": "YlGnBu", + "reversescale": True, + "color": list(node_adjacencies), + "size": 10, + "colorbar": {"thickness": 15, "title": "Node Connections", "xanchor": "left", "titleside": "right"}, + "line_width": 2, + }, + ) + layout = go.Layout( + title="
" + f"Graph {source_name} --> {target_name}", + titlefont_size=16, + showlegend=True, + hovermode="closest", + margin={"b": 20, "l": 5, "r": 5, "t": 40}, + annotations=[annotations_style], + legend={"x": 0, "y": 1}, + xaxis=plotly_axis_config, + yaxis=plotly_axis_config, + ) + fig = go.Figure(data=[edge_trace, source_node_trace, target_node_trace], layout=layout) + fig.update_geos(fitbounds="locations") + + if out_file is not None: + fig.write_html(out_file) + else: + fig.show() + + +def plot_isolated_nodes(graph: HeteroData, out_file: Optional[Union[str, Path]] = None) -> None: + """Plot isolated nodes. + + This method creates an interactive visualization of the isolated nodes in the graph. + + Parameters + ---------- + graph : AnemoiGraph + The graph to plot. + out_file : str | Path, optional + Name of the file to save the plot. Default is None. + """ + isolated_nodes = compute_isolated_nodes(graph) + + if len(isolated_nodes) == 0: + LOGGER.warning("No isolated nodes found.") + return + + colorbar = plt.cm.rainbow(np.linspace(0, 1, len(isolated_nodes))) + nodes = [] + for name, (lat, lon) in isolated_nodes.items(): + nodes.append( + go.Scattergeo( + lat=lat, + lon=lon, + mode="markers", + hoverinfo="text", + name=name, + marker={"showscale": False, "color": rgb2hex(colorbar[len(nodes)]), "size": 10}, + ), + ) + + layout = go.Layout( + title="
Orphan nodes", + titlefont_size=16, + showlegend=True, + hovermode="closest", + margin={"b": 20, "l": 5, "r": 5, "t": 40}, + annotations=[annotations_style], + legend={"x": 0, "y": 1}, + xaxis=plotly_axis_config, + yaxis=plotly_axis_config, + ) + fig = go.Figure(data=nodes, layout=layout) + fig.update_geos(fitbounds="locations") + + if out_file is not None: + fig.write_html(out_file) + else: + fig.show() + + +def plot_interactive_nodes(graph: HeteroData, nodes_name: str, out_file: Optional[str] = None) -> None: + """Plot nodes. + + This method creates an interactive visualization of a set of nodes. + + Parameters + ---------- + graph : HeteroData + Graph. + nodes_name : str + Name of the nodes to plot. + out_file : str, optional + Name of the file to save the plot. Default is None. + """ + node_latitudes, node_longitudes = node_list(graph, nodes_name) + node_attrs = graph[nodes_name].node_attrs() + # Remove x to avoid plotting the coordinates as an attribute + node_attrs.remove("x") + + if len(node_attrs) == 0: + LOGGER.warning(f"No node attributes found for {nodes_name} nodes.") + return + + node_traces = {} + for node_attr in node_attrs: + node_attr_values = graph[nodes_name][node_attr].float().numpy() + + # Skip multi-dimensional attributes. Supported only: (N, 1) or (N,) tensors + if node_attr_values.ndim > 1 and node_attr_values.shape[1] > 1: + continue + + node_traces[node_attr] = go.Scattergeo( + lat=node_latitudes, + lon=node_longitudes, + name=" ".join(node_attr.split("_")).capitalize(), + mode="markers", + hoverinfo="text", + marker={ + "color": node_attr_values.squeeze().tolist(), + "showscale": True, + "colorscale": "RdBu", + "colorbar": {"thickness": 15, "title": node_attr, "xanchor": "left"}, + "size": 5, + }, + visible=False, + ) + + # Create and add slider + slider_steps = [] + for i, node_attr in enumerate(node_traces.keys()): + step = dict( + label=f"Node attribute: {node_attr}", + method="update", + args=[{"visible": [False] * len(node_traces)}], + ) + step["args"][0]["visible"][i] = True # Toggle i'th trace to "visible" + slider_steps.append(step) + + fig = go.Figure( + data=list(node_traces.values()), + layout=go.Layout( + title=f"
Map of {nodes_name} nodes", + sliders=[ + dict(active=0, currentvalue={"visible": False}, len=0.4, x=0.5, xanchor="center", steps=slider_steps) + ], + titlefont_size=16, + showlegend=False, + hovermode="closest", + margin={"b": 20, "l": 5, "r": 5, "t": 40}, + annotations=[annotations_style], + xaxis=plotly_axis_config, + yaxis=plotly_axis_config, + ), + ) + fig.data[0].visible = True + + if out_file is not None: + fig.write_html(out_file) + else: + fig.show() diff --git a/src/anemoi/graphs/plotting/prepare.py b/src/anemoi/graphs/plotting/prepare.py new file mode 100644 index 0000000..f6c2fe3 --- /dev/null +++ b/src/anemoi/graphs/plotting/prepare.py @@ -0,0 +1,197 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from typing import Optional + +import numpy as np +import torch +from torch_geometric.data import HeteroData + + +def node_list(graph: HeteroData, nodes_name: str, mask: Optional[list[bool]] = None) -> tuple[list[float], list[float]]: + """Get the latitude and longitude of the nodes. + + Parameters + ---------- + graph : dict[str, torch.Tensor] + Graph to plot. + nodes_name : str + Name of the nodes. + mask : list[bool], optional + Mask to filter the nodes. Default is None. + + Returns + ------- + latitudes : list[float] + Latitude coordinates of the nodes. + longitudes : list[float] + Longitude coordinates of the nodes. + """ + coords = np.rad2deg(graph[nodes_name].x.numpy()) + latitudes = coords[:, 0] + longitudes = coords[:, 1] + if mask is not None: + latitudes = latitudes[mask] + longitudes = longitudes[mask] + return latitudes.tolist(), longitudes.tolist() + + +def edge_list(graph: HeteroData, source_nodes_name: str, target_nodes_name: str) -> tuple[np.ndarray, np.ndarray]: + """Get the edge list. + + This method returns the edge list to be represented in a graph. It computes the coordinates of the points connected + and include NaNs to separate the edges. + + Parameters + ---------- + graph : HeteroData + Graph to plot. + source_nodes_name : str + Name of the source nodes. + target_nodes_name : str + Name of the target nodes. + + Returns + ------- + latitudes : np.ndarray + Latitude coordinates of the points connected. + longitudes : np.ndarray + Longitude coordinates of the points connected. + """ + sub_graph = graph[(source_nodes_name, "to", target_nodes_name)].edge_index + x0 = np.rad2deg(graph[source_nodes_name].x[sub_graph[0]]) + y0 = np.rad2deg(graph[target_nodes_name].x[sub_graph[1]]) + nans = np.full_like(x0[:, :1], np.nan) + latitudes = np.concatenate([x0[:, :1], y0[:, :1], nans], axis=1).flatten() + longitudes = np.concatenate([x0[:, 1:2], y0[:, 1:2], nans], axis=1).flatten() + return latitudes, longitudes + + +def compute_node_adjacencies( + graph: HeteroData, source_nodes_name: str, target_nodes_name: str +) -> tuple[list[int], list[str]]: + """Compute the number of adjacencies of each target node in a bipartite graph. + + Parameters + ---------- + graph : HeteroData + Graph to plot. + source_nodes_name : str + Name of the dimension of the coordinates for the head nodes. + target_nodes_name : str + Name of the dimension of the coordinates for the tail nodes. 
+ + Returns + ------- + num_adjacencies : np.ndarray + Number of adjacencies of each node. + """ + node_adjacencies = np.zeros(graph[target_nodes_name].num_nodes, dtype=int) + vals, counts = np.unique(graph[(source_nodes_name, "to", target_nodes_name)].edge_index[1], return_counts=True) + node_adjacencies[vals] = counts + return node_adjacencies + + +def get_node_adjancency_attributes(graph: HeteroData) -> dict[str, tuple[str, np.ndarray]]: + """Get the node adjacencies for each subgraph.""" + node_adj_attr = {} + for (source_nodes_name, _, target_nodes_name), _ in graph.edge_items(): + attr_name = f"# connections from {source_nodes_name}" + node_adj_vector = compute_node_adjacencies(graph, source_nodes_name, target_nodes_name) + if target_nodes_name not in node_adj_attr: + node_adj_attr[target_nodes_name] = {attr_name: node_adj_vector} + else: + node_adj_attr[target_nodes_name][attr_name] = node_adj_vector + + return node_adj_attr + + +def compute_isolated_nodes(graph: HeteroData) -> dict[str, tuple[list, list]]: + """Compute the isolated nodes. + + Parameters + ---------- + graph : HeteroData + Graph. + + Returns + ------- + dict[str, list[int]] + Dictionary with the isolated nodes for each subgraph. + """ + isolated_nodes = {} + for (source_name, _, target_name), sub_graph in graph.edge_items(): + head_isolated = np.ones(graph[source_name].num_nodes, dtype=bool) + tail_isolated = np.ones(graph[target_name].num_nodes, dtype=bool) + head_isolated[sub_graph.edge_index[0]] = False + tail_isolated[sub_graph.edge_index[1]] = False + if np.any(head_isolated): + isolated_nodes[f"{source_name} isolated (--> {target_name})"] = node_list( + graph, source_name, mask=list(head_isolated) + ) + if np.any(tail_isolated): + isolated_nodes[f"{target_name} isolated ({source_name} -->)"] = node_list( + graph, target_name, mask=list(tail_isolated) + ) + + return isolated_nodes + + +def get_node_attribute_dims(graph: HeteroData) -> dict[str, int]: + """Get dimensions of the node attributes. + + Parameters + ---------- + graph : HeteroData + The graph to inspect. + + Returns + ------- + dict[str, int] + A dictionary with the attribute names as keys and the number of dimensions as values. + """ + attr_dims = {} + for nodes in graph.node_stores: + for attr in nodes.node_attrs(): + if attr == "x" or not isinstance(nodes[attr], torch.Tensor): + continue + elif attr not in attr_dims: + attr_dims[attr] = nodes[attr].shape[1] + else: + assert ( + nodes[attr].shape[1] == attr_dims[attr] + ), f"Attribute {attr} has different dimensions in different nodes." + return attr_dims + + +def get_edge_attribute_dims(graph: HeteroData) -> dict[str, int]: + """Get dimensions of the node attributes. + + Parameters + ---------- + graph : HeteroData + The graph to inspect. + + Returns + ------- + dict[str, int] + A dictionary with the attribute names as keys and the number of dimensions as values. + """ + attr_dims = {} + for edges in graph.edge_stores: + for attr in edges.edge_attrs(): + if attr == "edge_index": + continue + elif attr not in attr_dims: + attr_dims[attr] = edges[attr].shape[1] + else: + assert ( + edges[attr].shape[1] == attr_dims[attr] + ), f"Attribute {attr} has different dimensions in different edges." + return attr_dims diff --git a/src/anemoi/graphs/utils.py b/src/anemoi/graphs/utils.py index 8999bc6..a68d6e7 100644 --- a/src/anemoi/graphs/utils.py +++ b/src/anemoi/graphs/utils.py @@ -1,11 +1,20 @@ -from typing import Optional +# (C) Copyright 2024 Anemoi contributors. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations import numpy as np import torch from sklearn.neighbors import NearestNeighbors -def get_nearest_neighbour(coords_rad: torch.Tensor, mask: Optional[torch.Tensor] = None) -> NearestNeighbors: +def get_nearest_neighbour(coords_rad: torch.Tensor, mask: torch.Tensor | None = None) -> NearestNeighbors: """Get NearestNeighbour object fitted to coordinates. Parameters @@ -32,7 +41,7 @@ def get_nearest_neighbour(coords_rad: torch.Tensor, mask: Optional[torch.Tensor] return nearest_neighbour -def get_grid_reference_distance(coords_rad: torch.Tensor, mask: Optional[torch.Tensor] = None) -> float: +def get_grid_reference_distance(coords_rad: torch.Tensor, mask: torch.Tensor | None = None) -> float: """Get the reference distance of the grid. It is the maximum distance of a node in the mesh with respect to its nearest neighbour. diff --git a/tests/conftest.py b/tests/conftest.py index 1dc76de..59cf21d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,12 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + import numpy as np import pytest import torch @@ -11,10 +20,12 @@ class MockZarrDataset: """Mock Zarr dataset with latitudes and longitudes attributes.""" - def __init__(self, latitudes, longitudes): + def __init__(self, latitudes, longitudes, grids=None): self.latitudes = latitudes self.longitudes = longitudes self.num_nodes = len(latitudes) + if grids is not None: + self.grids = grids @pytest.fixture @@ -24,6 +35,14 @@ def mock_zarr_dataset() -> MockZarrDataset: return MockZarrDataset(latitudes=coords[:, 0], longitudes=coords[:, 1]) +@pytest.fixture +def mock_zarr_dataset_cutout() -> MockZarrDataset: + """Mock zarr dataset with nodes.""" + coords = 2 * torch.pi * np.array([[lat, lon] for lat in lats for lon in lons]) + grids = int(0.3 * len(coords)), int(0.7 * len(coords)) + return MockZarrDataset(latitudes=coords[:, 0], longitudes=coords[:, 1], grids=grids) + + @pytest.fixture def mock_grids_path(tmp_path) -> tuple[str, int]: """Mock grid_definition_path with files for 3 resolutions.""" @@ -40,6 +59,7 @@ def graph_with_nodes() -> HeteroData: coords = np.array([[lat, lon] for lat in lats for lon in lons]) graph = HeteroData() graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) + graph["test_nodes"].mask = torch.tensor([True] * len(coords)) return graph @@ -49,6 +69,7 @@ def graph_nodes_and_edges() -> HeteroData: coords = np.array([[lat, lon] for lat in lats for lon in lons]) graph = HeteroData() graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) + graph["test_nodes"].mask = torch.tensor([True] * len(coords)) graph[("test_nodes", "to", "test_nodes")].edge_index = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]]) return graph @@ -57,16 +78,15 @@ def graph_nodes_and_edges() -> HeteroData: def config_file(tmp_path) -> tuple[str, str]: 
"""Mock grid_definition_path with files for 3 resolutions.""" cfg = { - "nodes": [ - { - "name": "test_nodes", + "nodes": { + "test_nodes": { "node_builder": { "_target_": "anemoi.graphs.nodes.NPZFileNodes", "grid_definition_path": str(tmp_path), "resolution": "o16", }, - } - ], + }, + }, "edges": [ { "source_name": "test_nodes", diff --git a/tests/edges/test_cutoff.py b/tests/edges/test_cutoff.py index 838134c..623e335 100644 --- a/tests/edges/test_cutoff.py +++ b/tests/edges/test_cutoff.py @@ -1,3 +1,12 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + import pytest from anemoi.graphs.edges import CutOffEdges diff --git a/tests/edges/test_edge_attributes.py b/tests/edges/test_edge_attributes.py index 40cba1c..74ff198 100644 --- a/tests/edges/test_edge_attributes.py +++ b/tests/edges/test_edge_attributes.py @@ -1,3 +1,12 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + import pytest import torch diff --git a/tests/edges/test_knn.py b/tests/edges/test_knn.py index 9f6cae9..d0688b1 100644 --- a/tests/edges/test_knn.py +++ b/tests/edges/test_knn.py @@ -1,3 +1,12 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + import pytest from anemoi.graphs.edges import KNNEdges diff --git a/tests/edges/test_multiscale_edges.py b/tests/edges/test_multiscale_edges.py new file mode 100644 index 0000000..b6d424a --- /dev/null +++ b/tests/edges/test_multiscale_edges.py @@ -0,0 +1,101 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import numpy as np +import pytest +from torch_geometric.data import HeteroData + +from anemoi.graphs.edges import MultiScaleEdges +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder +from anemoi.graphs.nodes import HexNodes +from anemoi.graphs.nodes import StretchedTriNodes +from anemoi.graphs.nodes import TriNodes + + +class TestMultiScaleEdgesInit: + def test_init(self): + """Test MultiScaleEdges initialization.""" + assert isinstance(MultiScaleEdges("test_nodes", "test_nodes", 1), MultiScaleEdges) + + @pytest.mark.parametrize("x_hops", [-0.5, "hello", None, -4]) + def test_fail_init(self, x_hops: str): + """Test MultiScaleEdges initialization with invalid x_hops.""" + with pytest.raises(AssertionError): + MultiScaleEdges("test_nodes", "test_nodes", x_hops) + + def test_fail_init_diff_nodes(self): + """Test MultiScaleEdges initialization with invalid nodes.""" + with pytest.raises(AssertionError): + MultiScaleEdges("test_nodes", "test_nodes2", 0) + + +class TestMultiScaleEdgesTransform: + + @pytest.fixture() + def tri_ico_graph(self) -> HeteroData: + """Return a HeteroData object with MultiScaleEdges.""" + graph = HeteroData() + graph = TriNodes(1, "test_tri_nodes").update_graph(graph, {}) + graph["fail_nodes"].x = [1, 2, 3] + graph["fail_nodes"].node_type = "FailNodes" + return graph + + @pytest.fixture() + def hex_ico_graph(self) -> HeteroData: + """Return a HeteroData object with TriNodes.""" + graph = HeteroData() + graph = HexNodes(1, "test_hex_nodes").update_graph(graph, {}) + graph["fail_nodes"].x = [1, 2, 3] + graph["fail_nodes"].node_type = "FailNodes" + return graph + + def test_transform_same_src_dst_tri_nodes(self, tri_ico_graph: HeteroData): + """Test MultiScaleEdges update method.""" + + edges = MultiScaleEdges("test_tri_nodes", "test_tri_nodes", 1) + graph = edges.update_graph(tri_ico_graph) + assert ("test_tri_nodes", "to", "test_tri_nodes") in graph.edge_types + + def test_transform_same_src_dst_hex_nodes(self, hex_ico_graph: HeteroData): + """Test MultiScaleEdges update method.""" + + edges = MultiScaleEdges("test_hex_nodes", "test_hex_nodes", 1) + graph = edges.update_graph(hex_ico_graph) + assert ("test_hex_nodes", "to", "test_hex_nodes") in graph.edge_types + + def test_transform_fail_nodes(self, tri_ico_graph: HeteroData): + """Test MultiScaleEdges update method with wrong node type.""" + edges = MultiScaleEdges("fail_nodes", "fail_nodes", 1) + with pytest.raises(AssertionError): + edges.update_graph(tri_ico_graph) + + +class TestMultiScaleEdgesStretched: + + @pytest.fixture() + def tri_graph(self, mocker) -> HeteroData: + """Return a HeteroData object with stretched Tri nodes.""" + graph = HeteroData() + node_builder = StretchedTriNodes(5, 7, "hidden", None, None) + node_builder.area_mask_builder = KNNAreaMaskBuilder("hidden", 400) + node_builder.area_mask_builder.fit_coords(np.array([[0, 0]])) + # We are considering a 400km radius circle centered at (0, 0) as the area of + # interest for the stretched graph. 
+ + mocker.patch.object(node_builder.area_mask_builder, "fit", return_value=None) + + graph = node_builder.update_graph(graph, {}) + return graph + + def test_edges(self, tri_graph: HeteroData): + """Test MultiScaleEdges update method.""" + edges = MultiScaleEdges("hidden", "hidden", x_hops=1) + graph = edges.update_graph(tri_graph) + assert ("hidden", "to", "hidden") in graph.edge_types + assert len(graph[("hidden", "to", "hidden")].edge_index) > 0 diff --git a/tests/generate/test_masks.py b/tests/generate/test_masks.py new file mode 100644 index 0000000..d43c5f5 --- /dev/null +++ b/tests/generate/test_masks.py @@ -0,0 +1,57 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import pytest +from sklearn.neighbors import NearestNeighbors +from torch_geometric.data import HeteroData + +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder + + +def test_init(): + """Test KNNAreaMaskBuilder initialization.""" + mask_builder1 = KNNAreaMaskBuilder("nodes") + mask_builder2 = KNNAreaMaskBuilder("nodes", margin_radius_km=120) + mask_builder3 = KNNAreaMaskBuilder("nodes", mask_attr_name="mask") + mask_builder4 = KNNAreaMaskBuilder("nodes", margin_radius_km=120, mask_attr_name="mask") + + assert isinstance(mask_builder1, KNNAreaMaskBuilder) + assert isinstance(mask_builder2, KNNAreaMaskBuilder) + assert isinstance(mask_builder3, KNNAreaMaskBuilder) + assert isinstance(mask_builder4, KNNAreaMaskBuilder) + + assert isinstance(mask_builder1.nearest_neighbour, NearestNeighbors) + assert isinstance(mask_builder2.nearest_neighbour, NearestNeighbors) + assert isinstance(mask_builder3.nearest_neighbour, NearestNeighbors) + assert isinstance(mask_builder4.nearest_neighbour, NearestNeighbors) + + +@pytest.mark.parametrize("margin", [-1, "120", None]) +def test_fail_init_wrong_margin(margin: int): + """Test KNNAreaMaskBuilder initialization with invalid margin.""" + with pytest.raises(AssertionError): + KNNAreaMaskBuilder("nodes", margin_radius_km=margin) + + +@pytest.mark.parametrize("mask", [None, "mask"]) +def test_fit(graph_with_nodes: HeteroData, mask: str): + """Test KNNAreaMaskBuilder fit.""" + mask_builder = KNNAreaMaskBuilder("test_nodes", mask_attr_name=mask) + assert not hasattr(mask_builder.nearest_neighbour, "n_samples_fit_") + + mask_builder.fit(graph_with_nodes) + + assert mask_builder.nearest_neighbour.n_samples_fit_ == graph_with_nodes["test_nodes"].num_nodes + + +def test_fit_fail(graph_with_nodes): + """Test KNNAreaMaskBuilder fit with wrong graph.""" + mask_builder = KNNAreaMaskBuilder("wrong_nodes") + with pytest.raises(AssertionError): + mask_builder.fit(graph_with_nodes) diff --git a/tests/nodes/test_cutout_nodes.py b/tests/nodes/test_cutout_nodes.py new file mode 100644 index 0000000..8ce6613 --- /dev/null +++ b/tests/nodes/test_cutout_nodes.py @@ -0,0 +1,60 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import pytest +import torch +from omegaconf import OmegaConf +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes.attributes import AreaWeights +from anemoi.graphs.nodes.attributes import UniformWeights +from anemoi.graphs.nodes.builders import from_file + + +def test_init(mocker, mock_zarr_dataset_cutout): + """Test ZarrDatasetNodes initialization with cutout.""" + mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) + node_builder = from_file.ZarrDatasetNodes( + OmegaConf.create({"cutout": ["lam.zarr", "global.zarr"]}), name="test_nodes" + ) + + assert isinstance(node_builder, from_file.BaseNodeBuilder) + assert isinstance(node_builder, from_file.ZarrDatasetNodes) + + +def test_register_nodes(mocker, mock_zarr_dataset_cutout): + """Test ZarrDatasetNodes register correctly the nodes with cutout operation.""" + mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) + node_builder = from_file.ZarrDatasetNodes( + OmegaConf.create({"cutout": ["lam.zarr", "global.zarr"]}), name="test_nodes" + ) + graph = HeteroData() + + graph = node_builder.register_nodes(graph) + + assert graph["test_nodes"].x is not None + assert isinstance(graph["test_nodes"].x, torch.Tensor) + assert graph["test_nodes"].x.shape == (mock_zarr_dataset_cutout.num_nodes, 2) + assert graph["test_nodes"].node_type == "ZarrDatasetNodes" + + +@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) +def test_register_attributes(mocker, mock_zarr_dataset_cutout, graph_with_nodes: HeteroData, attr_class): + """Test ZarrDatasetNodes register correctly the weights with cutout operation.""" + mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) + node_builder = from_file.ZarrDatasetNodes( + OmegaConf.create({"cutout": ["lam.zarr", "global.zarr"]}), name="test_nodes" + ) + config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} + + graph = node_builder.register_attributes(graph_with_nodes, config) + + assert graph["test_nodes"]["test_attr"] is not None + assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor) + assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0] diff --git a/tests/nodes/test_healpix.py b/tests/nodes/test_healpix.py new file mode 100644 index 0000000..384d13a --- /dev/null +++ b/tests/nodes/test_healpix.py @@ -0,0 +1,60 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import pytest +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes.attributes import AreaWeights +from anemoi.graphs.nodes.attributes import UniformWeights +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder +from anemoi.graphs.nodes.builders.from_healpix import HEALPixNodes + + +@pytest.mark.parametrize("resolution", [2, 5, 7]) +def test_init(resolution: int): + """Test HEALPixNodes initialization.""" + node_builder = HEALPixNodes(resolution, "test_nodes") + assert isinstance(node_builder, BaseNodeBuilder) + assert isinstance(node_builder, HEALPixNodes) + + +@pytest.mark.parametrize("resolution", ["2", 4.3, -7]) +def test_fail_init(resolution: int): + """Test HEALPixNodes initialization with invalid resolution.""" + with pytest.raises(AssertionError): + HEALPixNodes(resolution, "test_nodes") + + +@pytest.mark.parametrize("resolution", [2, 5, 7]) +def test_register_nodes(resolution: int): + """Test HEALPixNodes register correctly the nodes.""" + node_builder = HEALPixNodes(resolution, "test_nodes") + graph = HeteroData() + + graph = node_builder.register_nodes(graph) + + assert graph["test_nodes"].x is not None + assert isinstance(graph["test_nodes"].x, torch.Tensor) + assert graph["test_nodes"].x.shape[1] == 2 + assert graph["test_nodes"].node_type == "HEALPixNodes" + + +@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) +@pytest.mark.parametrize("resolution", [2, 5, 7]) +def test_register_attributes(graph_with_nodes: HeteroData, attr_class, resolution: int): + """Test HEALPixNodes register correctly the weights.""" + node_builder = HEALPixNodes(resolution, "test_nodes") + config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} + + graph = node_builder.register_attributes(graph_with_nodes, config) + + assert graph["test_nodes"]["test_attr"] is not None + assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor) + assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0] diff --git a/tests/nodes/test_hex_nodes.py b/tests/nodes/test_hex_nodes.py new file mode 100644 index 0000000..81b909c --- /dev/null +++ b/tests/nodes/test_hex_nodes.py @@ -0,0 +1,43 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import pytest +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes import HexNodes +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder + + +@pytest.mark.parametrize("resolution", [0, 2]) +def test_init(resolution: list[int]): + """Test TrirefinedIcosahedralNodes initialization.""" + + node_builder = HexNodes(resolution, "test_nodes") + assert isinstance(node_builder, BaseNodeBuilder) + assert isinstance(node_builder, HexNodes) + + +def test_get_coordinates(): + """Test get_coordinates method.""" + node_builder = HexNodes(0, "test_nodes") + coords = node_builder.get_coordinates() + assert isinstance(coords, torch.Tensor) + assert coords.shape == (122, 2) + + +def test_update_graph(): + """Test update_graph method.""" + node_builder = HexNodes(0, "test_nodes") + graph = HeteroData() + graph = node_builder.update_graph(graph, {}) + assert "_resolutions" in graph["test_nodes"] + assert "_nx_graph" in graph["test_nodes"] + assert "_node_ordering" in graph["test_nodes"] + assert len(graph["test_nodes"]["_node_ordering"]) == graph["test_nodes"].num_nodes diff --git a/tests/nodes/test_node_attributes.py b/tests/nodes/test_node_attributes.py index 7347d88..ba824a4 100644 --- a/tests/nodes/test_node_attributes.py +++ b/tests/nodes/test_node_attributes.py @@ -1,3 +1,12 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + import pytest import torch from torch_geometric.data import HeteroData diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz.py index 95d09c0..9f0c85c 100644 --- a/tests/nodes/test_npz.py +++ b/tests/nodes/test_npz.py @@ -1,10 +1,19 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + import pytest import torch from torch_geometric.data import HeteroData from anemoi.graphs.nodes.attributes import AreaWeights from anemoi.graphs.nodes.attributes import UniformWeights -from anemoi.graphs.nodes.builder import NPZFileNodes +from anemoi.graphs.nodes.builders.from_file import NPZFileNodes @pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) diff --git a/tests/nodes/test_tri_nodes.py b/tests/nodes/test_tri_nodes.py new file mode 100644 index 0000000..8114ab7 --- /dev/null +++ b/tests/nodes/test_tri_nodes.py @@ -0,0 +1,43 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import pytest +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes import TriNodes +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder + + +@pytest.mark.parametrize("resolution", [0, 2]) +def test_init(resolution: list[int]): + """Test TrirefinedIcosahedralNodes initialization.""" + + node_builder = TriNodes(resolution, "test_nodes") + assert isinstance(node_builder, BaseNodeBuilder) + assert isinstance(node_builder, TriNodes) + + +def test_get_coordinates(): + """Test get_coordinates method.""" + node_builder = TriNodes(2, "test_nodes") + coords = node_builder.get_coordinates() + assert isinstance(coords, torch.Tensor) + assert coords.shape == (162, 2) + + +def test_update_graph(): + """Test update_graph method.""" + node_builder = TriNodes(1, "test_nodes") + graph = HeteroData() + graph = node_builder.update_graph(graph, {}) + assert "_resolutions" in graph["test_nodes"] + assert "_nx_graph" in graph["test_nodes"] + assert "_node_ordering" in graph["test_nodes"] + assert len(graph["test_nodes"]["_node_ordering"]) == graph["test_nodes"].num_nodes diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr.py index e3a2687..b8656b7 100644 --- a/tests/nodes/test_zarr.py +++ b/tests/nodes/test_zarr.py @@ -1,47 +1,57 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + import pytest import torch import zarr from torch_geometric.data import HeteroData -from anemoi.graphs.nodes import builder from anemoi.graphs.nodes.attributes import AreaWeights from anemoi.graphs.nodes.attributes import UniformWeights +from anemoi.graphs.nodes.builders import from_file def test_init(mocker, mock_zarr_dataset): """Test ZarrDatasetNodes initialization.""" - mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) - node_builder = builder.ZarrDatasetNodes("dataset.zarr", name="test_nodes") + mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset) + node_builder = from_file.ZarrDatasetNodes("dataset.zarr", name="test_nodes") - assert isinstance(node_builder, builder.BaseNodeBuilder) - assert isinstance(node_builder, builder.ZarrDatasetNodes) + assert isinstance(node_builder, from_file.BaseNodeBuilder) + assert isinstance(node_builder, from_file.ZarrDatasetNodes) -def test_fail_init(): - """Test ZarrDatasetNodes initialization with invalid resolution.""" +def test_fail(): + """Test ZarrDatasetNodes with invalid dataset.""" + node_builder = from_file.ZarrDatasetNodes("invalid_path.zarr", name="test_nodes") with pytest.raises(zarr.errors.PathNotFoundError): - builder.ZarrDatasetNodes("invalid_path.zarr", name="test_nodes") + node_builder.update_graph(HeteroData()) def test_register_nodes(mocker, mock_zarr_dataset): """Test ZarrDatasetNodes register correctly the nodes.""" - mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) - node_builder = builder.ZarrDatasetNodes("dataset.zarr", name="test_nodes") + mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset) + node_builder = from_file.ZarrDatasetNodes("dataset.zarr", name="test_nodes") graph = HeteroData() graph = node_builder.register_nodes(graph) assert graph["test_nodes"].x is 
     assert isinstance(graph["test_nodes"].x, torch.Tensor)
-    assert graph["test_nodes"].x.shape == (node_builder.ds.num_nodes, 2)
+    assert graph["test_nodes"].x.shape == (mock_zarr_dataset.num_nodes, 2)
     assert graph["test_nodes"].node_type == "ZarrDatasetNodes"
 
 
 @pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights])
 def test_register_attributes(mocker, graph_with_nodes: HeteroData, attr_class):
     """Test ZarrDatasetNodes register correctly the weights."""
-    mocker.patch.object(builder, "open_dataset", return_value=None)
-    node_builder = builder.ZarrDatasetNodes("dataset.zarr", name="test_nodes")
+    mocker.patch.object(from_file, "open_dataset", return_value=None)
+    node_builder = from_file.ZarrDatasetNodes("dataset.zarr", name="test_nodes")
     config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}}
 
     graph = node_builder.register_attributes(graph_with_nodes, config)
diff --git a/tests/test_create.py b/tests/test_create.py
new file mode 100644
index 0000000..2c5f700
--- /dev/null
+++ b/tests/test_create.py
@@ -0,0 +1,56 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+from pathlib import Path
+
+import pytest
+import torch
+from torch_geometric.data import HeteroData
+
+from anemoi.graphs.create import GraphCreator
+
+
+class TestGraphCreator:
+
+    @pytest.mark.parametrize("name", ["graph.pt", None])
+    def test_generate_graph(self, config_file: tuple[Path, str], mock_grids_path: tuple[str, int], name: str):
+        """Test GraphCreator workflow."""
+        tmp_path, config_name = config_file
+        graph_path = tmp_path / name if isinstance(name, str) else None
+        config_path = tmp_path / config_name
+
+        graph = GraphCreator(config=config_path).create(save_path=graph_path)
+
+        assert isinstance(graph, HeteroData)
+        assert "test_nodes" in graph.node_types
+        assert ("test_nodes", "to", "test_nodes") in graph.edge_types
+
+        for nodes in graph.node_stores:
+            for node_attr in nodes.node_attrs():
+                assert isinstance(nodes[node_attr], torch.Tensor)
+                assert nodes[node_attr].dtype in [torch.int32, torch.float32]
+
+        for edges in graph.edge_stores:
+            for edge_attr in edges.edge_attrs():
+                assert isinstance(edges[edge_attr], torch.Tensor)
+                assert edges[edge_attr].dtype in [torch.int32, torch.float32]
+
+        for nodes in graph.node_stores:
+            for node_attr in nodes.node_attrs():
+                assert not node_attr.startswith("_")
+        for edges in graph.edge_stores:
+            for edge_attr in edges.edge_attrs():
+                assert not edge_attr.startswith("_")
+
+        if graph_path is not None:
+            assert graph_path.exists()
+            graph_saved = torch.load(graph_path)
+            assert graph.node_types == graph_saved.node_types
+            assert graph.edge_types == graph_saved.edge_types
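The new TestGraphCreator above, together with the removal of tests/test_graphs.py below, tracks a change in the GraphCreator call pattern: the recipe is now passed to the constructor and the output path to create(), which also returns the HeteroData graph. A minimal sketch of the two styles, with illustrative file names (only the keyword arguments shown in the tests are assumed):

from anemoi.graphs.create import GraphCreator

# New-style call, mirrored by TestGraphCreator.test_generate_graph:
# the recipe goes to the constructor, the optional save path to create().
graph = GraphCreator(config="recipe.yaml").create(save_path="graph.pt")

# Old-style call, removed together with tests/test_graphs.py:
# GraphCreator("graph.pt", "recipe.yaml").create()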
diff --git a/tests/test_graphs.py b/tests/test_graphs.py
deleted file mode 100644
index ba2704f..0000000
--- a/tests/test_graphs.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
-# This software is licensed under the terms of the Apache Licence Version 2.0
-# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
-# In applying this licence, ECMWF does not waive the privileges and immunities
-# granted to it by virtue of its status as an intergovernmental organisation
-# nor does it submit to any jurisdiction.
-
-
-from pathlib import Path
-
-import torch
-from torch_geometric.data import HeteroData
-
-from anemoi.graphs import create
-
-
-def test_graphs(config_file: tuple[Path, str], mock_grids_path: tuple[str, int]):
-    """Test GraphCreator workflow."""
-    tmp_path, config_name = config_file
-    graph_path = tmp_path / "graph.pt"
-    config_path = tmp_path / config_name
-
-    create.GraphCreator(graph_path, config_path).create()
-
-    graph = torch.load(graph_path)
-    assert isinstance(graph, HeteroData)
-    assert "test_nodes" in graph.node_types
-    assert ("test_nodes", "to", "test_nodes") in graph.edge_types
-
-    for nodes in graph.node_stores:
-        for node_attr in nodes.node_attrs():
-            assert isinstance(nodes[node_attr], torch.Tensor)
-            assert nodes[node_attr].dtype in [torch.int32, torch.float32]
-
-    for edges in graph.edge_stores:
-        for edge_attr in edges.edge_attrs():
-            assert isinstance(edges[edge_attr], torch.Tensor)
-            assert edges[edge_attr].dtype in [torch.int32, torch.float32]
diff --git a/tests/test_normaliser.py b/tests/test_normaliser.py
new file mode 100644
index 0000000..83ca761
--- /dev/null
+++ b/tests/test_normaliser.py
@@ -0,0 +1,64 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import numpy as np
+import pytest
+
+from anemoi.graphs.normalise import NormaliserMixin
+
+
+@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-range", "unit-std"])
+def test_normaliser(norm: str):
+    """Test NormaliserMixin normalise method."""
+
+    class Normaliser(NormaliserMixin):
+        def __init__(self, norm):
+            self.norm = norm
+
+        def __call__(self, data):
+            return self.normalise(data)
+
+    normaliser = Normaliser(norm=norm)
+    data = np.random.rand(10, 5)
+    normalised_data = normaliser(data)
+    assert isinstance(normalised_data, np.ndarray)
+    assert normalised_data.shape == data.shape
+
+
+@pytest.mark.parametrize("norm", ["l3", "invalid"])
+def test_normaliser_wrong_norm(norm: str):
+    """Test NormaliserMixin normalise method with an invalid norm."""
+
+    class Normaliser(NormaliserMixin):
+        def __init__(self, norm: str):
+            self.norm = norm
+
+        def __call__(self, data):
+            return self.normalise(data)
+
+    with pytest.raises(ValueError):
+        normaliser = Normaliser(norm=norm)
+        data = np.random.rand(10, 5)
+        normaliser(data)
+
+
+def test_normaliser_wrong_inheritance():
+    """Test NormaliserMixin normalise method without a norm attribute."""
+
+    class Normaliser(NormaliserMixin):
+        def __init__(self, attr):
+            self.attr = attr
+
+        def __call__(self, data):
+            return self.normalise(data)
+
+    with pytest.raises(AttributeError):
+        normaliser = Normaliser(attr="attr_name")
+        data = np.random.rand(10, 5)
+        normaliser(data)
diff --git a/tests/test_normalizer.py b/tests/test_normalizer.py
deleted file mode 100644
index c63acce..0000000
--- a/tests/test_normalizer.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import numpy as np
-import pytest
-
-from anemoi.graphs.normalizer import NormalizerMixin
-
-
-@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-std"])
-def test_normalizer(norm: str):
-    """Test NormalizerMixin normalize method."""
-
-    class Normalizer(NormalizerMixin):
-        def __init__(self, norm):
-            self.norm = norm
-
-        def __call__(self, data):
-            return self.normalize(data)
-
-    normalizer = Normalizer(norm=norm)
-    data = np.random.rand(10, 5)
-    normalized_data = normalizer(data)
-    assert isinstance(normalized_data, np.ndarray)
-    assert normalized_data.shape == data.shape
-
-
-@pytest.mark.parametrize("norm", ["l3", "invalid"])
-def test_normalizer_wrong_norm(norm: str):
-    """Test NormalizerMixin normalize method."""
-
-    class Normalizer(NormalizerMixin):
-        def __init__(self, norm: str):
-            self.norm = norm
-
-        def __call__(self, data):
-            return self.normalize(data)
-
-    with pytest.raises(ValueError):
-        normalizer = Normalizer(norm=norm)
-        data = np.random.rand(10, 5)
-        normalizer(data)
-
-
-def test_normalizer_wrong_inheritance():
-    """Test NormalizerMixin normalize method."""
-
-    class Normalizer(NormalizerMixin):
-        def __init__(self, attr):
-            self.attr = attr
-
-        def __call__(self, data):
-            return self.normalize(data)
-
-    with pytest.raises(AttributeError):
-        normalizer = Normalizer(attr="attr_name")
-        data = np.random.rand(10, 5)
-        normalizer(data)
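The new tests/test_normaliser.py replaces tests/test_normalizer.py one-for-one (British spelling, plus the added unit-range norm). For context, a minimal sketch of the mixin contract those tests exercise; the class name here is hypothetical, and only the module path, norm names, and error behaviour shown in the diff are assumed:

import numpy as np

from anemoi.graphs.normalise import NormaliserMixin


class ExampleAttribute(NormaliserMixin):
    # Illustrative consumer only; the real attribute classes are not part of this diff.
    def __init__(self, norm: str = "unit-max"):
        # Accepted values per the tests: "l1", "l2", "unit-max", "unit-range", "unit-std".
        self.norm = norm

    def __call__(self, values: np.ndarray) -> np.ndarray:
        # normalise() raises ValueError for an unknown norm and AttributeError
        # if the consuming class does not define self.norm.
        return self.normalise(values)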