From bd30eee441b540ac82835c11840c8e87437fe754 Mon Sep 17 00:00:00 2001
From: Paolo Davini
Date: Thu, 17 Oct 2024 17:21:52 +0200
Subject: [PATCH] linting
---
smmregrid/cdo_weights.py | 30 +++++++++---------
smmregrid/checker.py | 22 ++++++-------
smmregrid/gridinspector.py | 21 +++++++------
smmregrid/gridtype.py | 33 ++++++++++----------
smmregrid/log.py | 7 +++--
smmregrid/regrid.py | 64 ++++++++++++++++++--------------------
6 files changed, 89 insertions(+), 88 deletions(-)
diff --git a/smmregrid/cdo_weights.py b/smmregrid/cdo_weights.py
index e2cb9c1..f7605dc 100644
--- a/smmregrid/cdo_weights.py
+++ b/smmregrid/cdo_weights.py
@@ -11,10 +11,12 @@
from .weights import compute_weights_matrix3d, compute_weights_matrix, mask_weights, check_mask
from .log import setup_logger
+
def worker(wlist, nnn, *args, **kwargs):
"""Run a worker process"""
wlist[nnn] = cdo_generate_weights2d(*args, **kwargs).compute()
+
def cdo_generate_weights(source_grid, target_grid, method="con", extrapolate=True,
remap_norm="fracarea", remap_area_min=0.0, icongridpath=None,
gridpath=None, extra=None, cdo_extra=None, cdo_options=None, vertical_dim=None,
@@ -79,7 +81,7 @@ def cdo_generate_weights(source_grid, target_grid, method="con", extrapolate=Tru
if not vertical_dim in sgrid:
raise KeyError(f'Cannot find vertical dim {vertical_dim} in {list(sgrid.dims)}')
-
+
nvert = sgrid[vertical_dim].values.size
loggy.info('Vertical dimensions has length: %s', nvert)
@@ -101,17 +103,17 @@ def cdo_generate_weights(source_grid, target_grid, method="con", extrapolate=Tru
ppp = Process(target=worker,
args=(wlist, lev, source_grid, target_grid),
kwargs={
- "method": method,
- "extrapolate": extrapolate,
- "remap_norm": remap_norm,
- "remap_area_min": remap_area_min,
- "icongridpath": icongridpath,
- "gridpath": gridpath,
- "cdo_extra": cdo_extra + cdo_extra_vertical,
- "cdo_options": cdo_options,
- "cdo": cdo,
- "nproc": nproc
- })
+ "method": method,
+ "extrapolate": extrapolate,
+ "remap_norm": remap_norm,
+ "remap_area_min": remap_area_min,
+ "icongridpath": icongridpath,
+ "gridpath": gridpath,
+ "cdo_extra": cdo_extra + cdo_extra_vertical,
+ "cdo_options": cdo_options,
+ "cdo": cdo,
+ "nproc": nproc
+ })
ppp.start()
processes.append(ppp)
@@ -125,7 +127,7 @@ def cdo_generate_weights(source_grid, target_grid, method="con", extrapolate=Tru
weights = mask_weights(weights, weights_matrix, vertical_dim)
masked = check_mask(weights, vertical_dim)
masked = [int(x) for x in masked] # convert to list of int
- masked_xa = xarray.DataArray(masked,
+ masked_xa = xarray.DataArray(masked,
coords={vertical_dim: range(0, len(masked))},
name="dst_grid_masked")
@@ -187,7 +189,7 @@ def cdo_generate_weights2d(source_grid, target_grid, method="con", extrapolate=T
source_grid_file = tempfile.NamedTemporaryFile()
source_grid.to_netcdf(source_grid_file.name)
sgrid = source_grid_file.name
-
+
if isinstance(target_grid, str):
tgrid = target_grid
else:
diff --git a/smmregrid/checker.py b/smmregrid/checker.py
index 76cfbfd..6906d8d 100644
--- a/smmregrid/checker.py
+++ b/smmregrid/checker.py
@@ -47,9 +47,9 @@ def check_cdo_regrid(finput, ftarget, remap_method='con', access='Dataset',
smmvar = find_var(xfield)
cdovar = find_var(cdofield)
- #if len(smmvar) == 1 and access == 'DataArray':
+ # if len(smmvar) == 1 and access == 'DataArray':
# xfield = xfield[smmvar[0]]
- #if len(cdovar) == 1 and access == 'DataArray':
+ # if len(cdovar) == 1 and access == 'DataArray':
# cdofield = cdofield[cdovar[0]]
# interpolation with smmregrid (CDO-based)
@@ -98,26 +98,26 @@ def check_cdo_regrid_levels(finput, ftarget, vertical_dim, levels, remap_method=
smmvar = find_var(xfield)
cdovar = find_var(cdofield)
- #if len(smmvar) == 1 and access == 'DataArray':
+ # if len(smmvar) == 1 and access == 'DataArray':
# xfield = xfield[smmvar[0]]
- #if len(cdovar) == 1 and access == 'DataArray':
+ # if len(cdovar) == 1 and access == 'DataArray':
# cdofield = cdofield[cdovar[0]]
# compute weights
if vertical_dim == 'plev':
wfield = cdo_generate_weights(finput, ftarget,
- method=remap_method)
+ method=remap_method)
else:
wfield = cdo_generate_weights(finput, ftarget,
- method=remap_method,
- vertical_dim=vertical_dim)
-
+ method=remap_method,
+ vertical_dim=vertical_dim)
+
# Pass full 3D weights
interpolator = Regridder(weights=wfield)
# Add a helper idx_3d coordinate (unclear why it was here)
- #idx = list(range(0, len(xfield.coords[vertical_dim])))
- #xfield = xfield.assign_coords(idx_3d=(vertical_dim, idx))
+ # idx = list(range(0, len(xfield.coords[vertical_dim])))
+ # xfield = xfield.assign_coords(idx_3d=(vertical_dim, idx))
# subselect some levels
xfield = xfield.isel(**{vertical_dim: levels})
@@ -131,4 +131,4 @@ def check_cdo_regrid_levels(finput, ftarget, vertical_dim, levels, remap_method=
# check if arrays are equal with numerical tolerance
checker = np.allclose(cdofield, rfield, equal_nan=True)
- return checker
\ No newline at end of file
+ return checker
diff --git a/smmregrid/gridinspector.py b/smmregrid/gridinspector.py
index 01914ce..1165fca 100644
--- a/smmregrid/gridinspector.py
+++ b/smmregrid/gridinspector.py
@@ -4,13 +4,14 @@
from smmregrid.log import setup_logger
from .gridtype import GridType
+
class GridInspector():
def __init__(self, data, cdo_weights=False, extra_dims=None,
clean=True, loglevel='warning'):
"""
GridInspector class to detect information on the data, based on GridType class
-
+
Parameters:
data (xr.Datase or xr.DataArray): The input dataset.
clean (bool): apply the cleaning of grids which are assumed to be not relevant
@@ -25,7 +26,7 @@ def __init__(self, data, cdo_weights=False, extra_dims=None,
self.cdo_weights = cdo_weights
self.clean = clean
self.grids = [] # List to hold all grids info
-
+
def _inspect_grids(self):
"""
Inspects the dataset and identifies different grids.
@@ -41,14 +42,14 @@ def _inspect_grids(self):
for gridtype in self.grids:
gridtype.identify_variables(self.data)
- #gridtype.identify_sizes(self.data)
+ # gridtype.identify_sizes(self.data)
def _inspect_dataarray_grid(self, data_array):
"""
Helper method to inspect a single DataArray and identify its grid type.
"""
grid_key = tuple(data_array.dims)
- gridtype = GridType(dims=grid_key,extra_dims=self.extra_dims)
+ gridtype = GridType(dims=grid_key, extra_dims=self.extra_dims)
if gridtype not in self.grids:
self.grids.append(gridtype)
@@ -56,7 +57,7 @@ def _inspect_weights(self):
"""
Return basic information about CDO weights
"""
-
+
gridtype = GridType(dims=[], weights=self.data)
# get vertical info from the weights coords if available
@@ -82,7 +83,7 @@ def get_grid_info(self):
if self.clean:
self._clean_grids()
- #self.loggy.info('Grids that have been identifed are: %s', self.grids.)
+ # self.loggy.info('Grids that have been identifed are: %s', self.grids.)
for gridtype in self.grids:
self.loggy.debug('More details on gridtype %s:', gridtype.dims)
if gridtype.horizontal_dims:
@@ -92,7 +93,7 @@ def get_grid_info(self):
self.loggy.debug(' Variables are: %s', list(gridtype.variables.keys()))
self.loggy.debug(' Bounds are: %s', gridtype.bounds)
return self.grids
-
+
def _clean_grids(self):
"""
Remove degenerate grids which are used by not relevant variables
@@ -108,11 +109,11 @@ def _clean_grids(self):
if any('bounds' in variable for variable in gridtype.variables):
removed.append(gridtype) # Add to removed list
self.loggy.info('Removing the grid defined by %s with variables containing "bounds"', gridtype.dims)
-
+
for remove in removed:
self.grids.remove(remove)
- #def get_variable_grids(self):
+ # def get_variable_grids(self):
# """
# Return a dictionary with the variable - grids pairs
# """
@@ -120,4 +121,4 @@ def _clean_grids(self):
# for gridtype in self.grids:
# for variable in gridtype.variables:
# all_grids[variable] = grid_dims
- # return all_grids
\ No newline at end of file
+ # return all_grids
diff --git a/smmregrid/gridtype.py b/smmregrid/gridtype.py
index 2a18744..9083b4d 100644
--- a/smmregrid/gridtype.py
+++ b/smmregrid/gridtype.py
@@ -4,12 +4,13 @@
# default spatial dimensions and vertical coordinates
DEFAULT_DIMS = {
'horizontal': ['i', 'j', 'x', 'y', 'lon', 'lat', 'longitude', 'latitude',
- 'cell', 'cells', 'ncells', 'values', 'value', 'nod2', 'pix', 'elem',
- 'nav_lon', 'nav_lat'],
+ 'cell', 'cells', 'ncells', 'values', 'value', 'nod2', 'pix', 'elem',
+ 'nav_lon', 'nav_lat'],
'vertical': ['lev', 'nz1', 'nz', 'depth', 'depth_full', 'depth_half'],
'time': ['time']
}
+
class GridType:
def __init__(self, dims, extra_dims=None, weights=None):
"""
@@ -45,10 +46,10 @@ def _handle_default_dimensions(self, extra_dims):
update_dims = DEFAULT_DIMS
for dim in extra_dims.keys():
- if extra_dims[dim]:
+ if extra_dims[dim]:
update_dims[dim] = update_dims[dim] + extra_dims[dim]
return update_dims
-
+
def __eq__(self, other):
# so far equality based on dims only
if isinstance(other, GridType):
@@ -65,12 +66,12 @@ def _identify_dims(self, axis, default_dims):
"""
identified_dims = list(set(self.dims).intersection(default_dims[axis]))
if axis == 'vertical':
- if len(identified_dims)>1:
+ if len(identified_dims) > 1:
raise ValueError(f'Only one vertical dimension can be processed at the time: check {identified_dims}')
- if len(identified_dims)==1:
- identified_dims=identified_dims[0] #unlist the single vertical dimension
+ if len(identified_dims) == 1:
+ identified_dims = identified_dims[0] # unlist the single vertical dimension
return identified_dims if identified_dims else None
-
+
def _identify_spatial_bounds(self, data):
"""
Find all bounds variables in the dataset by looking for variables
@@ -81,13 +82,13 @@ def _identify_spatial_bounds(self, data):
for var in data.data_vars:
if (var.endswith('_bnds') or var.endswith('_bounds')) and 'time' not in var:
# store all the bounds fro each grid. not fancy, but effective
- #boundvar = var.split('_')[0]
- #if boundvar in self.dims:
+ # boundvar = var.split('_')[0]
+ # if boundvar in self.dims:
bounds_variables.append(var)
return bounds_variables
-
- #def identify_sizes(self, data):
+
+ # def identify_sizes(self, data):
# """
# Idenfity the sizes of the dataset
# """
@@ -115,16 +116,16 @@ def identify_variables(self, data):
for var in data.data_vars:
self._identify_variable(data[var], var)
self.bounds = self._identify_spatial_bounds(data)
-
+
elif isinstance(data, xr.DataArray):
self._identify_variable(data)
-
+
# def _identify_grid_type(self, grid_key):
# """
# Determines the grid type (e.g., structured, unstructured, curvilinear).
# This could be expanded based on more detailed metadata inspection.
# """
- # horizontal_dims = self._identify_horizontal_dims(grid_key)
+ # horizontal_dims = self._identify_horizontal_dims(grid_key)
# if 'mesh' in self.dataset.attrs.get('grid_type', '').lower():
# return 'unstructured'
# elif any('lat' in coord and 'lon' in coord for coord in horizontal_dims):
@@ -132,4 +133,4 @@ def identify_variables(self, data):
# elif 'curvilinear' in self.dataset.attrs.get('grid_type', '').lower():
# return 'curvilinear'
# else:
- # return 'unknown'
\ No newline at end of file
+ # return 'unknown'
diff --git a/smmregrid/log.py b/smmregrid/log.py
index 36610de..f2bb55b 100644
--- a/smmregrid/log.py
+++ b/smmregrid/log.py
@@ -13,7 +13,7 @@ def setup_logger(level=None, name=None):
logger = logging.getLogger(name) # Create a logger specific to your module
if logger.handlers:
- #logger.warning('Logging is already setup with name %s', name)
+ # logger.warning('Logging is already setup with name %s', name)
if level != logging.getLevelName(logger.getEffectiveLevel()):
logger.setLevel(loglev)
logger.info('Updating the log_level to %s', loglev)
@@ -30,12 +30,13 @@ def setup_logger(level=None, name=None):
# Create a handler for the logger
handler = logging.StreamHandler()
- #handler.setLevel(loglev) # Set the desired log level for the handler
+ # handler.setLevel(loglev) # Set the desired log level for the handler
handler.setFormatter(formatter) # Assign the formatter to the handler
logger.addHandler(handler) # Add the handler to the logger
return logger
+
def convert_logger(loglev=None):
"""Convert a string or integer to a valid logging level"""
@@ -67,4 +68,4 @@ def convert_logger(loglev=None):
loglev, loglev_default)
loglev = loglev_default
- return loglev
\ No newline at end of file
+ return loglev
diff --git a/smmregrid/regrid.py b/smmregrid/regrid.py
index 3483f13..63335c1 100644
--- a/smmregrid/regrid.py
+++ b/smmregrid/regrid.py
@@ -81,7 +81,7 @@ def __init__(self, source_grid=None, target_grid=None, weights=None,
raise ValueError(
"Either weights or source_grid/target_grid must be supplied"
)
-
+
# Check for deprecated 'vert_coord' argument
if vert_coord is not None:
warnings.warn(
@@ -91,7 +91,7 @@ def __init__(self, source_grid=None, target_grid=None, weights=None,
# If cdo_extra is not provided, use the value from extra
if vertical_dim is None:
vertical_dim = vert_coord
-
+
# Check for deprecated 'space_dim' argument
if space_dims is not None:
warnings.warn(
@@ -103,11 +103,10 @@ def __init__(self, source_grid=None, target_grid=None, weights=None,
self.loggy = setup_logger(level=loglevel, name='smmregrid.Regrid')
self.loglevel = loglevel
self.transpose = transpose
- self.vertical_dim = [vertical_dim] #need a list
+ self.vertical_dim = [vertical_dim] # need a list
if vertical_dim:
self.loggy.info('Forcing vertical_dim: expecting a single gridtype dataset')
-
# Is there already a weights file?
if weights is not None:
self.grids = self._gridtype_from_weights(weights)
@@ -123,14 +122,14 @@ def __init__(self, source_grid=None, target_grid=None, weights=None,
self.grids = self._gridtype_from_data(source_grid_array)
- len_grids = len(self.grids)
+ len_grids = len(self.grids)
if len_grids == 0:
raise KeyError('Cannot find any gridtype in your data, aborting!')
if len_grids == 1:
self.loggy.info('One gridtype found! Standard procedure')
else:
self.loggy.info('%s gridtypes found! We are in uncharted territory!', len_grids)
-
+
for gridtype in self.grids:
self.loggy.debug('Processing grids %s', gridtype.dims)
self.loggy.debug('Horizontal dimension is %s', gridtype.horizontal_dims)
@@ -139,7 +138,7 @@ def __init__(self, source_grid=None, target_grid=None, weights=None,
# always prefer to pass file (i.e. source_grid) when possible to cdo_generate_weights
# this will limit errors from xarray and speed up CDO itself
# it wil work only for single gridtype dataset
- if isinstance(source_grid, str) and len_grids==1:
+ if isinstance(source_grid, str) and len_grids == 1:
source_grid_array_to_cdo = source_grid
else:
# when feeding from xarray, select the variable and its bounds
@@ -153,7 +152,7 @@ def __init__(self, source_grid=None, target_grid=None, weights=None,
gridtype.weights = cdo_generate_weights(source_grid_array_to_cdo, target_grid, method=method,
vertical_dim=gridtype.vertical_dim,
cdo=cdo, loglevel=loglevel)
-
+
for gridtype in self.grids:
if gridtype.vertical_dim:
gridtype.weights_matrix = compute_weights_matrix3d(gridtype.weights, gridtype.vertical_dim)
@@ -168,31 +167,30 @@ def __init__(self, source_grid=None, target_grid=None, weights=None,
gridtype.weights = mask_weights(gridtype.weights, gridtype.weights_matrix, gridtype.vertical_dim)
gridtype.masked = check_mask(gridtype.weights, gridtype.vertical_dim)
-
def _gridtype_from_weights(self, weights):
"""
Initialize the gridtype reading from weights
"""
-
+
self.loggy.warning('Precomputed weights support so far single-gridtype datasets')
if not isinstance(weights, xarray.Dataset):
weights = xarray.open_mfdataset(weights)
grid_info = GridInspector(weights, cdo_weights=True, extra_dims={'vertical': self.vertical_dim},
- clean=False, loglevel=self.loglevel)
+ clean=False, loglevel=self.loglevel)
gridtype = grid_info.get_grid_info()
- #if not gridtype[0].dims:
+ # if not gridtype[0].dims:
# self.loggy.warning('Missing weights dimension information, support only single-gridtype datasets')
-
+
return gridtype
-
+
def _gridtype_from_data(self, source_grid_array):
"""
Initialize the gridtype reading from source_data
"""
- grid_info = GridInspector(source_grid_array, extra_dims={'vertical': self.vertical_dim},
+ grid_info = GridInspector(source_grid_array, extra_dims={'vertical': self.vertical_dim},
clean=True, loglevel=self.loglevel)
return grid_info.get_grid_info()
@@ -208,7 +206,7 @@ def regrid(self, source_data):
version of the source variable
"""
- # apply the regridder on each DataArray
+ # apply the regridder on each DataArray
if isinstance(source_data, xarray.Dataset):
out = source_data.map(self.regrid_array, keep_attrs=False)
@@ -225,7 +223,7 @@ def regrid(self, source_data):
def regrid_array(self, source_data):
"""Regridding selection through 2d and 3d arrays"""
- grid_inspect = GridInspector(source_data, clean=True,
+ grid_inspect = GridInspector(source_data, clean=True,
extra_dims={'vertical': self.vertical_dim}, loglevel=self.loglevel)
datagrids = grid_inspect.get_grid_info()
@@ -235,7 +233,7 @@ def regrid_array(self, source_data):
return self.regrid3d(source_data, datagridtype)
# 2d case
return self.regrid2d(source_data, datagridtype)
-
+
def _get_gridtype(self, datagridtype):
# special case for CDO weights without any dimensional information
@@ -244,7 +242,7 @@ def _get_gridtype(self, datagridtype):
self.loggy.warning('Assuming gridtype from data to be the same from weights')
self.grids[0].dims = datagridtype.dims
self.grids[0].horizontal_dims = datagridtype.horizontal_dims
-
+
# match the grid
gridtype = next((grid for grid in self.grids if grid == datagridtype), None)
@@ -271,7 +269,7 @@ def regrid3d(self, source_data, datagridtype):
if gridtype is None:
self.loggy.info('%s will be excluded from the output', source_data.name)
return xarray.DataArray(data=None)
-
+
# select the gridtype to be used
vertical_dim = gridtype.vertical_dim
weights = gridtype.weights
@@ -317,7 +315,7 @@ def regrid3d(self, source_data, datagridtype):
# get dimensional info on target grid. TODO: can be moved at the init?
target_gridtypes = GridInspector(data3d, clean=True, loglevel=self.loglevel).get_grid_info()
target_horizontal_dims = target_gridtypes[0].horizontal_dims
-
+
if self.transpose:
dims = list(data3d.dims)
index = min([i for i, s in enumerate(dims) if s in target_horizontal_dims])
@@ -347,16 +345,16 @@ def regrid2d(self, source_data, datagridtype):
if gridtype is None:
self.loggy.info('%s will be excluded from the output', source_data.name)
return xarray.DataArray(data=None)
-
+
return self.apply_weights(
- source_data,
- gridtype.weights,
- weights_matrix= gridtype.weights_matrix,
- masked= gridtype.masked,
- horizontal_dims=gridtype.horizontal_dims)
-
+ source_data,
+ gridtype.weights,
+ weights_matrix=gridtype.weights_matrix,
+ masked=gridtype.masked,
+ horizontal_dims=gridtype.horizontal_dims)
+
def apply_weights(self, source_data, weights, weights_matrix=None,
- masked=True, horizontal_dims=None):
+ masked=True, horizontal_dims=None):
"""
Apply the CDO weights ``weights`` to ``source_data``, performing a regridding operation
@@ -370,7 +368,6 @@ def apply_weights(self, source_data, weights, weights_matrix=None,
xarray.DataArray: Regridded version of the source dataset
"""
-
# Understand immediately if we need to return something or not
# This is done if we have bounds variables
if any(substring in source_data.name for substring in ["bnds", "bounds", "vertices"]):
@@ -416,11 +413,12 @@ def apply_weights(self, source_data, weights, weights_matrix=None,
axis_scale = 180.0 / math.pi # Weight lat/lon in radians
# Dimension on which we can produce the interpolation
- #if horizontal_dims is None:
+ # if horizontal_dims is None:
# horizontal_dims = default_horizontal_dims
if not any(x in source_data.dims for x in horizontal_dims):
- self.loggy.error("None of dimensions on which we can interpolate is found in the DataArray. Does your DataArray include any of these?")
+ self.loggy.error(
+ "None of dimensions on which we can interpolate is found in the DataArray. Does your DataArray include any of these?")
self.loggy.error(horizontal_dims)
self.loggy.error('smmregrid can identify only %s', source_data.dims)
raise KeyError('Dimensions mismatch')
@@ -520,7 +518,6 @@ def apply_weights(self, source_data, weights, weights_matrix=None,
return target_da
-
def regrid(source_data, target_grid=None, weights=None, transpose=True, cdo='cdo'):
"""
A simple regrid. Inefficient if you are regridding more than one dataset
@@ -541,7 +538,6 @@ def regrid(source_data, target_grid=None, weights=None, transpose=True, cdo='cdo
:class:`xarray.DataArray` with a regridded version of the source variable
"""
-
regridder = Regridder(source_data, target_grid=target_grid, weights=weights, cdo=cdo, transpose=transpose)
return regridder.regrid(source_data)