From 56d24a9fe02f3e97b92fa76ec4ba9aeb90702bbf Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Mon, 9 Sep 2024 10:49:47 +0200 Subject: [PATCH] Add test and improve documentation --- doc/quickstart/find_data.rst | 2 +- esmvalcore/dataset.py | 91 +++++++++++++---------- tests/integration/dataset/test_dataset.py | 12 ++- 3 files changed, 62 insertions(+), 43 deletions(-) diff --git a/doc/quickstart/find_data.rst b/doc/quickstart/find_data.rst index 50486a2d3f..33135b719b 100644 --- a/doc/quickstart/find_data.rst +++ b/doc/quickstart/find_data.rst @@ -401,7 +401,7 @@ unstructured grids, which is required by many software packages or tools to work correctly and specifically by Iris to interpret the grid as a :ref:`mesh `. An example is the horizontal regridding of native ICON data to a regular grid. -While the :ref:`built-in regridding schemes ` +While the :ref:`built-in regridding schemes ` `linear` and `nearest` can handle unstructured grids (i.e., not UGRID-compliant) and meshes (i.e., UGRID-compliant), the `area_weighted` scheme requires the input data in UGRID format. This automatic UGRIDization is enabled by default, but can be switched off with diff --git a/esmvalcore/dataset.py b/esmvalcore/dataset.py index d3efdb4e96..60a1b5cb29 100644 --- a/esmvalcore/dataset.py +++ b/esmvalcore/dataset.py @@ -6,13 +6,15 @@ import re import textwrap import uuid +from collections.abc import Iterable from copy import deepcopy from fnmatch import fnmatchcase from itertools import groupby from pathlib import Path -from typing import Any, Iterator, Sequence, Union +from typing import Any, Iterator, Sequence, TypeVar, Union import dask +from dask.delayed import Delayed from iris.cube import Cube from esmvalcore import esgf, local @@ -80,8 +82,12 @@ def _ismatch(facet_value: FacetValue, pattern: FacetValue) -> bool: and fnmatchcase(facet_value, pattern)) -def _first(elems): - return elems[0] +T = TypeVar('T') + + +def _first(elems: Iterable[T]) -> T: + """Return the first element.""" + return next(iter(elems)) class Dataset: @@ -669,16 +675,16 @@ def files(self) -> Sequence[File]: def files(self, value): self._files = value - def load(self, compute=True) -> Cube: + def load(self, compute=True) -> Cube | Delayed: """Load dataset. Parameters ---------- compute: - If :obj:`True`, return the cube immediately. If :obj:`False`, - return a :class:`~dask.delayed.Delayed` object that can be used - to load the cube by calling its - :func:`~dask.delayed.Delayed.compute` method. Multiple datasets + If :obj:`True`, return the :class:`~iris.cube.Cube` immediately. + If :obj:`False`, return a :class:`~dask.delayed.Delayed` object + that can be used to load the cube by calling its + :meth:`~dask.delayed.Delayed.compute` method. Multiple datasets can be loaded in parallel by passing a list of such delayeds to :func:`dask.compute`. @@ -731,7 +737,14 @@ def _load(self) -> Cube: msg = "\n".join(lines) raise InputFilesNotFound(msg) + input_files = [ + file.local_file(self.session['download_dir']) if isinstance( + file, esgf.ESGFFile) else file for file in self.files + ] output_file = _get_output_file(self.facets, self.session.preproc_dir) + debug = self.session['save_intermediary_cubes'] + + # Load all input files and concatenate them. fix_dir_prefix = Path( self.session._fixed_file_dir, self._get_joined_summary_facets('_', join_lists=True) + '_', @@ -757,36 +770,6 @@ def _load(self) -> Cube: settings['concatenate'] = { 'check_level': self.session['check_level'] } - settings['cmor_check_metadata'] = { - 'check_level': self.session['check_level'], - 'cmor_table': self.facets['project'], - 'mip': self.facets['mip'], - 'frequency': self.facets['frequency'], - 'short_name': self.facets['short_name'], - } - if 'timerange' in self.facets: - settings['clip_timerange'] = { - 'timerange': self.facets['timerange'], - } - settings['fix_data'] = { - 'check_level': self.session['check_level'], - 'session': self.session, - **self.facets, - } - settings['cmor_check_data'] = { - 'check_level': self.session['check_level'], - 'cmor_table': self.facets['project'], - 'mip': self.facets['mip'], - 'frequency': self.facets['frequency'], - 'short_name': self.facets['short_name'], - } - - input_files = [ - file.local_file(self.session['download_dir']) if isinstance( - file, esgf.ESGFFile) else file for file in self.files - ] - - debug = self.session['save_intermediary_cubes'] result = [] for input_file in input_files: @@ -798,6 +781,7 @@ def _load(self) -> Cube: debug=debug, **settings['fix_file'], ) + # Multiple cubes may be present in a file. cubes = dask.delayed(preprocess)( files, 'load', @@ -806,6 +790,7 @@ def _load(self) -> Cube: debug=debug, **settings['load'], ) + # Combine the cubes into a single cube per file. cubes = dask.delayed(preprocess)( cubes, 'fix_metadata', @@ -817,6 +802,7 @@ def _load(self) -> Cube: cube = dask.delayed(_first)(cubes) result.append(cube) + # Concatenate the cubes from all files. result = dask.delayed(preprocess)( result, 'concatenate', @@ -825,7 +811,34 @@ def _load(self) -> Cube: debug=debug, **settings['concatenate'], ) - for step, kwargs in dict(tuple(settings.items())[4:]).items(): + + # At this point `result` is a list containing a single cube. Apply the + # remaining preprocessor functions to this cube. + settings.clear() + settings['cmor_check_metadata'] = { + 'check_level': self.session['check_level'], + 'cmor_table': self.facets['project'], + 'mip': self.facets['mip'], + 'frequency': self.facets['frequency'], + 'short_name': self.facets['short_name'], + } + if 'timerange' in self.facets: + settings['clip_timerange'] = { + 'timerange': self.facets['timerange'], + } + settings['fix_data'] = { + 'check_level': self.session['check_level'], + 'session': self.session, + **self.facets, + } + settings['cmor_check_data'] = { + 'check_level': self.session['check_level'], + 'cmor_table': self.facets['project'], + 'mip': self.facets['mip'], + 'frequency': self.facets['frequency'], + 'short_name': self.facets['short_name'], + } + for step, kwargs in settings.items(): result = dask.delayed(preprocess)( result, step, diff --git a/tests/integration/dataset/test_dataset.py b/tests/integration/dataset/test_dataset.py index 0c94dc8c48..19d5aa1f4a 100644 --- a/tests/integration/dataset/test_dataset.py +++ b/tests/integration/dataset/test_dataset.py @@ -3,6 +3,7 @@ import iris.coords import iris.cube import pytest +from dask.delayed import Delayed from esmvalcore.config import CFG from esmvalcore.dataset import Dataset @@ -34,7 +35,8 @@ def example_data(tmp_path, monkeypatch): monkeypatch.setitem(CFG, 'output_dir', tmp_path / 'output_dir') -def test_load(example_data): +@pytest.mark.parametrize('lazy', [True, False]) +def test_load(example_data, lazy): tas = Dataset( short_name='tas', mip='Amon', @@ -51,7 +53,11 @@ def test_load(example_data): tas.find_files() print(tas.files) - cube = tas.load() - + if lazy: + result = tas.load(compute=False) + assert isinstance(result, Delayed) + cube = result.compute() + else: + cube = tas.load() assert isinstance(cube, iris.cube.Cube) assert cube.cell_measures()