From bd30eee441b540ac82835c11840c8e87437fe754 Mon Sep 17 00:00:00 2001 From: Paolo Davini Date: Thu, 17 Oct 2024 17:21:52 +0200 Subject: [PATCH] linting --- smmregrid/cdo_weights.py | 30 +++++++++--------- smmregrid/checker.py | 22 ++++++------- smmregrid/gridinspector.py | 21 +++++++------ smmregrid/gridtype.py | 33 ++++++++++---------- smmregrid/log.py | 7 +++-- smmregrid/regrid.py | 64 ++++++++++++++++++-------------------- 6 files changed, 89 insertions(+), 88 deletions(-) diff --git a/smmregrid/cdo_weights.py b/smmregrid/cdo_weights.py index e2cb9c1..f7605dc 100644 --- a/smmregrid/cdo_weights.py +++ b/smmregrid/cdo_weights.py @@ -11,10 +11,12 @@ from .weights import compute_weights_matrix3d, compute_weights_matrix, mask_weights, check_mask from .log import setup_logger + def worker(wlist, nnn, *args, **kwargs): """Run a worker process""" wlist[nnn] = cdo_generate_weights2d(*args, **kwargs).compute() + def cdo_generate_weights(source_grid, target_grid, method="con", extrapolate=True, remap_norm="fracarea", remap_area_min=0.0, icongridpath=None, gridpath=None, extra=None, cdo_extra=None, cdo_options=None, vertical_dim=None, @@ -79,7 +81,7 @@ def cdo_generate_weights(source_grid, target_grid, method="con", extrapolate=Tru if not vertical_dim in sgrid: raise KeyError(f'Cannot find vertical dim {vertical_dim} in {list(sgrid.dims)}') - + nvert = sgrid[vertical_dim].values.size loggy.info('Vertical dimensions has length: %s', nvert) @@ -101,17 +103,17 @@ def cdo_generate_weights(source_grid, target_grid, method="con", extrapolate=Tru ppp = Process(target=worker, args=(wlist, lev, source_grid, target_grid), kwargs={ - "method": method, - "extrapolate": extrapolate, - "remap_norm": remap_norm, - "remap_area_min": remap_area_min, - "icongridpath": icongridpath, - "gridpath": gridpath, - "cdo_extra": cdo_extra + cdo_extra_vertical, - "cdo_options": cdo_options, - "cdo": cdo, - "nproc": nproc - }) + "method": method, + "extrapolate": extrapolate, + "remap_norm": remap_norm, + "remap_area_min": remap_area_min, + "icongridpath": icongridpath, + "gridpath": gridpath, + "cdo_extra": cdo_extra + cdo_extra_vertical, + "cdo_options": cdo_options, + "cdo": cdo, + "nproc": nproc + }) ppp.start() processes.append(ppp) @@ -125,7 +127,7 @@ def cdo_generate_weights(source_grid, target_grid, method="con", extrapolate=Tru weights = mask_weights(weights, weights_matrix, vertical_dim) masked = check_mask(weights, vertical_dim) masked = [int(x) for x in masked] # convert to list of int - masked_xa = xarray.DataArray(masked, + masked_xa = xarray.DataArray(masked, coords={vertical_dim: range(0, len(masked))}, name="dst_grid_masked") @@ -187,7 +189,7 @@ def cdo_generate_weights2d(source_grid, target_grid, method="con", extrapolate=T source_grid_file = tempfile.NamedTemporaryFile() source_grid.to_netcdf(source_grid_file.name) sgrid = source_grid_file.name - + if isinstance(target_grid, str): tgrid = target_grid else: diff --git a/smmregrid/checker.py b/smmregrid/checker.py index 76cfbfd..6906d8d 100644 --- a/smmregrid/checker.py +++ b/smmregrid/checker.py @@ -47,9 +47,9 @@ def check_cdo_regrid(finput, ftarget, remap_method='con', access='Dataset', smmvar = find_var(xfield) cdovar = find_var(cdofield) - #if len(smmvar) == 1 and access == 'DataArray': + # if len(smmvar) == 1 and access == 'DataArray': # xfield = xfield[smmvar[0]] - #if len(cdovar) == 1 and access == 'DataArray': + # if len(cdovar) == 1 and access == 'DataArray': # cdofield = cdofield[cdovar[0]] # interpolation with smmregrid (CDO-based) @@ -98,26 +98,26 @@ def check_cdo_regrid_levels(finput, ftarget, vertical_dim, levels, remap_method= smmvar = find_var(xfield) cdovar = find_var(cdofield) - #if len(smmvar) == 1 and access == 'DataArray': + # if len(smmvar) == 1 and access == 'DataArray': # xfield = xfield[smmvar[0]] - #if len(cdovar) == 1 and access == 'DataArray': + # if len(cdovar) == 1 and access == 'DataArray': # cdofield = cdofield[cdovar[0]] # compute weights if vertical_dim == 'plev': wfield = cdo_generate_weights(finput, ftarget, - method=remap_method) + method=remap_method) else: wfield = cdo_generate_weights(finput, ftarget, - method=remap_method, - vertical_dim=vertical_dim) - + method=remap_method, + vertical_dim=vertical_dim) + # Pass full 3D weights interpolator = Regridder(weights=wfield) # Add a helper idx_3d coordinate (unclear why it was here) - #idx = list(range(0, len(xfield.coords[vertical_dim]))) - #xfield = xfield.assign_coords(idx_3d=(vertical_dim, idx)) + # idx = list(range(0, len(xfield.coords[vertical_dim]))) + # xfield = xfield.assign_coords(idx_3d=(vertical_dim, idx)) # subselect some levels xfield = xfield.isel(**{vertical_dim: levels}) @@ -131,4 +131,4 @@ def check_cdo_regrid_levels(finput, ftarget, vertical_dim, levels, remap_method= # check if arrays are equal with numerical tolerance checker = np.allclose(cdofield, rfield, equal_nan=True) - return checker \ No newline at end of file + return checker diff --git a/smmregrid/gridinspector.py b/smmregrid/gridinspector.py index 01914ce..1165fca 100644 --- a/smmregrid/gridinspector.py +++ b/smmregrid/gridinspector.py @@ -4,13 +4,14 @@ from smmregrid.log import setup_logger from .gridtype import GridType + class GridInspector(): def __init__(self, data, cdo_weights=False, extra_dims=None, clean=True, loglevel='warning'): """ GridInspector class to detect information on the data, based on GridType class - + Parameters: data (xr.Datase or xr.DataArray): The input dataset. clean (bool): apply the cleaning of grids which are assumed to be not relevant @@ -25,7 +26,7 @@ def __init__(self, data, cdo_weights=False, extra_dims=None, self.cdo_weights = cdo_weights self.clean = clean self.grids = [] # List to hold all grids info - + def _inspect_grids(self): """ Inspects the dataset and identifies different grids. @@ -41,14 +42,14 @@ def _inspect_grids(self): for gridtype in self.grids: gridtype.identify_variables(self.data) - #gridtype.identify_sizes(self.data) + # gridtype.identify_sizes(self.data) def _inspect_dataarray_grid(self, data_array): """ Helper method to inspect a single DataArray and identify its grid type. """ grid_key = tuple(data_array.dims) - gridtype = GridType(dims=grid_key,extra_dims=self.extra_dims) + gridtype = GridType(dims=grid_key, extra_dims=self.extra_dims) if gridtype not in self.grids: self.grids.append(gridtype) @@ -56,7 +57,7 @@ def _inspect_weights(self): """ Return basic information about CDO weights """ - + gridtype = GridType(dims=[], weights=self.data) # get vertical info from the weights coords if available @@ -82,7 +83,7 @@ def get_grid_info(self): if self.clean: self._clean_grids() - #self.loggy.info('Grids that have been identifed are: %s', self.grids.) + # self.loggy.info('Grids that have been identifed are: %s', self.grids.) for gridtype in self.grids: self.loggy.debug('More details on gridtype %s:', gridtype.dims) if gridtype.horizontal_dims: @@ -92,7 +93,7 @@ def get_grid_info(self): self.loggy.debug(' Variables are: %s', list(gridtype.variables.keys())) self.loggy.debug(' Bounds are: %s', gridtype.bounds) return self.grids - + def _clean_grids(self): """ Remove degenerate grids which are used by not relevant variables @@ -108,11 +109,11 @@ def _clean_grids(self): if any('bounds' in variable for variable in gridtype.variables): removed.append(gridtype) # Add to removed list self.loggy.info('Removing the grid defined by %s with variables containing "bounds"', gridtype.dims) - + for remove in removed: self.grids.remove(remove) - #def get_variable_grids(self): + # def get_variable_grids(self): # """ # Return a dictionary with the variable - grids pairs # """ @@ -120,4 +121,4 @@ def _clean_grids(self): # for gridtype in self.grids: # for variable in gridtype.variables: # all_grids[variable] = grid_dims - # return all_grids \ No newline at end of file + # return all_grids diff --git a/smmregrid/gridtype.py b/smmregrid/gridtype.py index 2a18744..9083b4d 100644 --- a/smmregrid/gridtype.py +++ b/smmregrid/gridtype.py @@ -4,12 +4,13 @@ # default spatial dimensions and vertical coordinates DEFAULT_DIMS = { 'horizontal': ['i', 'j', 'x', 'y', 'lon', 'lat', 'longitude', 'latitude', - 'cell', 'cells', 'ncells', 'values', 'value', 'nod2', 'pix', 'elem', - 'nav_lon', 'nav_lat'], + 'cell', 'cells', 'ncells', 'values', 'value', 'nod2', 'pix', 'elem', + 'nav_lon', 'nav_lat'], 'vertical': ['lev', 'nz1', 'nz', 'depth', 'depth_full', 'depth_half'], 'time': ['time'] } + class GridType: def __init__(self, dims, extra_dims=None, weights=None): """ @@ -45,10 +46,10 @@ def _handle_default_dimensions(self, extra_dims): update_dims = DEFAULT_DIMS for dim in extra_dims.keys(): - if extra_dims[dim]: + if extra_dims[dim]: update_dims[dim] = update_dims[dim] + extra_dims[dim] return update_dims - + def __eq__(self, other): # so far equality based on dims only if isinstance(other, GridType): @@ -65,12 +66,12 @@ def _identify_dims(self, axis, default_dims): """ identified_dims = list(set(self.dims).intersection(default_dims[axis])) if axis == 'vertical': - if len(identified_dims)>1: + if len(identified_dims) > 1: raise ValueError(f'Only one vertical dimension can be processed at the time: check {identified_dims}') - if len(identified_dims)==1: - identified_dims=identified_dims[0] #unlist the single vertical dimension + if len(identified_dims) == 1: + identified_dims = identified_dims[0] # unlist the single vertical dimension return identified_dims if identified_dims else None - + def _identify_spatial_bounds(self, data): """ Find all bounds variables in the dataset by looking for variables @@ -81,13 +82,13 @@ def _identify_spatial_bounds(self, data): for var in data.data_vars: if (var.endswith('_bnds') or var.endswith('_bounds')) and 'time' not in var: # store all the bounds fro each grid. not fancy, but effective - #boundvar = var.split('_')[0] - #if boundvar in self.dims: + # boundvar = var.split('_')[0] + # if boundvar in self.dims: bounds_variables.append(var) return bounds_variables - - #def identify_sizes(self, data): + + # def identify_sizes(self, data): # """ # Idenfity the sizes of the dataset # """ @@ -115,16 +116,16 @@ def identify_variables(self, data): for var in data.data_vars: self._identify_variable(data[var], var) self.bounds = self._identify_spatial_bounds(data) - + elif isinstance(data, xr.DataArray): self._identify_variable(data) - + # def _identify_grid_type(self, grid_key): # """ # Determines the grid type (e.g., structured, unstructured, curvilinear). # This could be expanded based on more detailed metadata inspection. # """ - # horizontal_dims = self._identify_horizontal_dims(grid_key) + # horizontal_dims = self._identify_horizontal_dims(grid_key) # if 'mesh' in self.dataset.attrs.get('grid_type', '').lower(): # return 'unstructured' # elif any('lat' in coord and 'lon' in coord for coord in horizontal_dims): @@ -132,4 +133,4 @@ def identify_variables(self, data): # elif 'curvilinear' in self.dataset.attrs.get('grid_type', '').lower(): # return 'curvilinear' # else: - # return 'unknown' \ No newline at end of file + # return 'unknown' diff --git a/smmregrid/log.py b/smmregrid/log.py index 36610de..f2bb55b 100644 --- a/smmregrid/log.py +++ b/smmregrid/log.py @@ -13,7 +13,7 @@ def setup_logger(level=None, name=None): logger = logging.getLogger(name) # Create a logger specific to your module if logger.handlers: - #logger.warning('Logging is already setup with name %s', name) + # logger.warning('Logging is already setup with name %s', name) if level != logging.getLevelName(logger.getEffectiveLevel()): logger.setLevel(loglev) logger.info('Updating the log_level to %s', loglev) @@ -30,12 +30,13 @@ def setup_logger(level=None, name=None): # Create a handler for the logger handler = logging.StreamHandler() - #handler.setLevel(loglev) # Set the desired log level for the handler + # handler.setLevel(loglev) # Set the desired log level for the handler handler.setFormatter(formatter) # Assign the formatter to the handler logger.addHandler(handler) # Add the handler to the logger return logger + def convert_logger(loglev=None): """Convert a string or integer to a valid logging level""" @@ -67,4 +68,4 @@ def convert_logger(loglev=None): loglev, loglev_default) loglev = loglev_default - return loglev \ No newline at end of file + return loglev diff --git a/smmregrid/regrid.py b/smmregrid/regrid.py index 3483f13..63335c1 100644 --- a/smmregrid/regrid.py +++ b/smmregrid/regrid.py @@ -81,7 +81,7 @@ def __init__(self, source_grid=None, target_grid=None, weights=None, raise ValueError( "Either weights or source_grid/target_grid must be supplied" ) - + # Check for deprecated 'vert_coord' argument if vert_coord is not None: warnings.warn( @@ -91,7 +91,7 @@ def __init__(self, source_grid=None, target_grid=None, weights=None, # If cdo_extra is not provided, use the value from extra if vertical_dim is None: vertical_dim = vert_coord - + # Check for deprecated 'space_dim' argument if space_dims is not None: warnings.warn( @@ -103,11 +103,10 @@ def __init__(self, source_grid=None, target_grid=None, weights=None, self.loggy = setup_logger(level=loglevel, name='smmregrid.Regrid') self.loglevel = loglevel self.transpose = transpose - self.vertical_dim = [vertical_dim] #need a list + self.vertical_dim = [vertical_dim] # need a list if vertical_dim: self.loggy.info('Forcing vertical_dim: expecting a single gridtype dataset') - # Is there already a weights file? if weights is not None: self.grids = self._gridtype_from_weights(weights) @@ -123,14 +122,14 @@ def __init__(self, source_grid=None, target_grid=None, weights=None, self.grids = self._gridtype_from_data(source_grid_array) - len_grids = len(self.grids) + len_grids = len(self.grids) if len_grids == 0: raise KeyError('Cannot find any gridtype in your data, aborting!') if len_grids == 1: self.loggy.info('One gridtype found! Standard procedure') else: self.loggy.info('%s gridtypes found! We are in uncharted territory!', len_grids) - + for gridtype in self.grids: self.loggy.debug('Processing grids %s', gridtype.dims) self.loggy.debug('Horizontal dimension is %s', gridtype.horizontal_dims) @@ -139,7 +138,7 @@ def __init__(self, source_grid=None, target_grid=None, weights=None, # always prefer to pass file (i.e. source_grid) when possible to cdo_generate_weights # this will limit errors from xarray and speed up CDO itself # it wil work only for single gridtype dataset - if isinstance(source_grid, str) and len_grids==1: + if isinstance(source_grid, str) and len_grids == 1: source_grid_array_to_cdo = source_grid else: # when feeding from xarray, select the variable and its bounds @@ -153,7 +152,7 @@ def __init__(self, source_grid=None, target_grid=None, weights=None, gridtype.weights = cdo_generate_weights(source_grid_array_to_cdo, target_grid, method=method, vertical_dim=gridtype.vertical_dim, cdo=cdo, loglevel=loglevel) - + for gridtype in self.grids: if gridtype.vertical_dim: gridtype.weights_matrix = compute_weights_matrix3d(gridtype.weights, gridtype.vertical_dim) @@ -168,31 +167,30 @@ def __init__(self, source_grid=None, target_grid=None, weights=None, gridtype.weights = mask_weights(gridtype.weights, gridtype.weights_matrix, gridtype.vertical_dim) gridtype.masked = check_mask(gridtype.weights, gridtype.vertical_dim) - def _gridtype_from_weights(self, weights): """ Initialize the gridtype reading from weights """ - + self.loggy.warning('Precomputed weights support so far single-gridtype datasets') if not isinstance(weights, xarray.Dataset): weights = xarray.open_mfdataset(weights) grid_info = GridInspector(weights, cdo_weights=True, extra_dims={'vertical': self.vertical_dim}, - clean=False, loglevel=self.loglevel) + clean=False, loglevel=self.loglevel) gridtype = grid_info.get_grid_info() - #if not gridtype[0].dims: + # if not gridtype[0].dims: # self.loggy.warning('Missing weights dimension information, support only single-gridtype datasets') - + return gridtype - + def _gridtype_from_data(self, source_grid_array): """ Initialize the gridtype reading from source_data """ - grid_info = GridInspector(source_grid_array, extra_dims={'vertical': self.vertical_dim}, + grid_info = GridInspector(source_grid_array, extra_dims={'vertical': self.vertical_dim}, clean=True, loglevel=self.loglevel) return grid_info.get_grid_info() @@ -208,7 +206,7 @@ def regrid(self, source_data): version of the source variable """ - # apply the regridder on each DataArray + # apply the regridder on each DataArray if isinstance(source_data, xarray.Dataset): out = source_data.map(self.regrid_array, keep_attrs=False) @@ -225,7 +223,7 @@ def regrid(self, source_data): def regrid_array(self, source_data): """Regridding selection through 2d and 3d arrays""" - grid_inspect = GridInspector(source_data, clean=True, + grid_inspect = GridInspector(source_data, clean=True, extra_dims={'vertical': self.vertical_dim}, loglevel=self.loglevel) datagrids = grid_inspect.get_grid_info() @@ -235,7 +233,7 @@ def regrid_array(self, source_data): return self.regrid3d(source_data, datagridtype) # 2d case return self.regrid2d(source_data, datagridtype) - + def _get_gridtype(self, datagridtype): # special case for CDO weights without any dimensional information @@ -244,7 +242,7 @@ def _get_gridtype(self, datagridtype): self.loggy.warning('Assuming gridtype from data to be the same from weights') self.grids[0].dims = datagridtype.dims self.grids[0].horizontal_dims = datagridtype.horizontal_dims - + # match the grid gridtype = next((grid for grid in self.grids if grid == datagridtype), None) @@ -271,7 +269,7 @@ def regrid3d(self, source_data, datagridtype): if gridtype is None: self.loggy.info('%s will be excluded from the output', source_data.name) return xarray.DataArray(data=None) - + # select the gridtype to be used vertical_dim = gridtype.vertical_dim weights = gridtype.weights @@ -317,7 +315,7 @@ def regrid3d(self, source_data, datagridtype): # get dimensional info on target grid. TODO: can be moved at the init? target_gridtypes = GridInspector(data3d, clean=True, loglevel=self.loglevel).get_grid_info() target_horizontal_dims = target_gridtypes[0].horizontal_dims - + if self.transpose: dims = list(data3d.dims) index = min([i for i, s in enumerate(dims) if s in target_horizontal_dims]) @@ -347,16 +345,16 @@ def regrid2d(self, source_data, datagridtype): if gridtype is None: self.loggy.info('%s will be excluded from the output', source_data.name) return xarray.DataArray(data=None) - + return self.apply_weights( - source_data, - gridtype.weights, - weights_matrix= gridtype.weights_matrix, - masked= gridtype.masked, - horizontal_dims=gridtype.horizontal_dims) - + source_data, + gridtype.weights, + weights_matrix=gridtype.weights_matrix, + masked=gridtype.masked, + horizontal_dims=gridtype.horizontal_dims) + def apply_weights(self, source_data, weights, weights_matrix=None, - masked=True, horizontal_dims=None): + masked=True, horizontal_dims=None): """ Apply the CDO weights ``weights`` to ``source_data``, performing a regridding operation @@ -370,7 +368,6 @@ def apply_weights(self, source_data, weights, weights_matrix=None, xarray.DataArray: Regridded version of the source dataset """ - # Understand immediately if we need to return something or not # This is done if we have bounds variables if any(substring in source_data.name for substring in ["bnds", "bounds", "vertices"]): @@ -416,11 +413,12 @@ def apply_weights(self, source_data, weights, weights_matrix=None, axis_scale = 180.0 / math.pi # Weight lat/lon in radians # Dimension on which we can produce the interpolation - #if horizontal_dims is None: + # if horizontal_dims is None: # horizontal_dims = default_horizontal_dims if not any(x in source_data.dims for x in horizontal_dims): - self.loggy.error("None of dimensions on which we can interpolate is found in the DataArray. Does your DataArray include any of these?") + self.loggy.error( + "None of dimensions on which we can interpolate is found in the DataArray. Does your DataArray include any of these?") self.loggy.error(horizontal_dims) self.loggy.error('smmregrid can identify only %s', source_data.dims) raise KeyError('Dimensions mismatch') @@ -520,7 +518,6 @@ def apply_weights(self, source_data, weights, weights_matrix=None, return target_da - def regrid(source_data, target_grid=None, weights=None, transpose=True, cdo='cdo'): """ A simple regrid. Inefficient if you are regridding more than one dataset @@ -541,7 +538,6 @@ def regrid(source_data, target_grid=None, weights=None, transpose=True, cdo='cdo :class:`xarray.DataArray` with a regridded version of the source variable """ - regridder = Regridder(source_data, target_grid=target_grid, weights=weights, cdo=cdo, transpose=transpose) return regridder.regrid(source_data)