Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
oloapinivad committed Oct 17, 2024
1 parent 81606cc commit bd30eee
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 88 deletions.
30 changes: 16 additions & 14 deletions smmregrid/cdo_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from .weights import compute_weights_matrix3d, compute_weights_matrix, mask_weights, check_mask
from .log import setup_logger


def worker(wlist, nnn, *args, **kwargs):
    """Process entry point: compute 2D remap weights and store them at index nnn.

    Parameters:
        wlist: shared, index-assignable container collecting per-level results
               (presumably a multiprocessing manager list — confirm at call site)
        nnn: slot/level index under which this worker's weights are stored
        *args, **kwargs: forwarded verbatim to cdo_generate_weights2d
    """
    weights = cdo_generate_weights2d(*args, **kwargs)
    wlist[nnn] = weights.compute()


def cdo_generate_weights(source_grid, target_grid, method="con", extrapolate=True,
remap_norm="fracarea", remap_area_min=0.0, icongridpath=None,
gridpath=None, extra=None, cdo_extra=None, cdo_options=None, vertical_dim=None,
Expand Down Expand Up @@ -79,7 +81,7 @@ def cdo_generate_weights(source_grid, target_grid, method="con", extrapolate=Tru

if not vertical_dim in sgrid:
raise KeyError(f'Cannot find vertical dim {vertical_dim} in {list(sgrid.dims)}')

nvert = sgrid[vertical_dim].values.size
loggy.info('Vertical dimensions has length: %s', nvert)

Expand All @@ -101,17 +103,17 @@ def cdo_generate_weights(source_grid, target_grid, method="con", extrapolate=Tru
ppp = Process(target=worker,
args=(wlist, lev, source_grid, target_grid),
kwargs={
"method": method,
"extrapolate": extrapolate,
"remap_norm": remap_norm,
"remap_area_min": remap_area_min,
"icongridpath": icongridpath,
"gridpath": gridpath,
"cdo_extra": cdo_extra + cdo_extra_vertical,
"cdo_options": cdo_options,
"cdo": cdo,
"nproc": nproc
})
"method": method,
"extrapolate": extrapolate,
"remap_norm": remap_norm,
"remap_area_min": remap_area_min,
"icongridpath": icongridpath,
"gridpath": gridpath,
"cdo_extra": cdo_extra + cdo_extra_vertical,
"cdo_options": cdo_options,
"cdo": cdo,
"nproc": nproc
})
ppp.start()
processes.append(ppp)

Expand All @@ -125,7 +127,7 @@ def cdo_generate_weights(source_grid, target_grid, method="con", extrapolate=Tru
weights = mask_weights(weights, weights_matrix, vertical_dim)
masked = check_mask(weights, vertical_dim)
masked = [int(x) for x in masked] # convert to list of int
masked_xa = xarray.DataArray(masked,
masked_xa = xarray.DataArray(masked,
coords={vertical_dim: range(0, len(masked))},
name="dst_grid_masked")

Expand Down Expand Up @@ -187,7 +189,7 @@ def cdo_generate_weights2d(source_grid, target_grid, method="con", extrapolate=T
source_grid_file = tempfile.NamedTemporaryFile()
source_grid.to_netcdf(source_grid_file.name)
sgrid = source_grid_file.name

if isinstance(target_grid, str):
tgrid = target_grid
else:
Expand Down
22 changes: 11 additions & 11 deletions smmregrid/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def check_cdo_regrid(finput, ftarget, remap_method='con', access='Dataset',
smmvar = find_var(xfield)
cdovar = find_var(cdofield)

#if len(smmvar) == 1 and access == 'DataArray':
# if len(smmvar) == 1 and access == 'DataArray':
# xfield = xfield[smmvar[0]]
#if len(cdovar) == 1 and access == 'DataArray':
# if len(cdovar) == 1 and access == 'DataArray':
# cdofield = cdofield[cdovar[0]]

# interpolation with smmregrid (CDO-based)
Expand Down Expand Up @@ -98,26 +98,26 @@ def check_cdo_regrid_levels(finput, ftarget, vertical_dim, levels, remap_method=
smmvar = find_var(xfield)
cdovar = find_var(cdofield)

#if len(smmvar) == 1 and access == 'DataArray':
# if len(smmvar) == 1 and access == 'DataArray':
# xfield = xfield[smmvar[0]]
#if len(cdovar) == 1 and access == 'DataArray':
# if len(cdovar) == 1 and access == 'DataArray':
# cdofield = cdofield[cdovar[0]]

# compute weights
if vertical_dim == 'plev':
wfield = cdo_generate_weights(finput, ftarget,
method=remap_method)
method=remap_method)
else:
wfield = cdo_generate_weights(finput, ftarget,
method=remap_method,
vertical_dim=vertical_dim)
method=remap_method,
vertical_dim=vertical_dim)

# Pass full 3D weights
interpolator = Regridder(weights=wfield)

# Add a helper idx_3d coordinate (unclear why it was here)
#idx = list(range(0, len(xfield.coords[vertical_dim])))
#xfield = xfield.assign_coords(idx_3d=(vertical_dim, idx))
# idx = list(range(0, len(xfield.coords[vertical_dim])))
# xfield = xfield.assign_coords(idx_3d=(vertical_dim, idx))

# subselect some levels
xfield = xfield.isel(**{vertical_dim: levels})
Expand All @@ -131,4 +131,4 @@ def check_cdo_regrid_levels(finput, ftarget, vertical_dim, levels, remap_method=

# check if arrays are equal with numerical tolerance
checker = np.allclose(cdofield, rfield, equal_nan=True)
return checker
return checker
21 changes: 11 additions & 10 deletions smmregrid/gridinspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from smmregrid.log import setup_logger
from .gridtype import GridType


class GridInspector():

def __init__(self, data, cdo_weights=False, extra_dims=None,
clean=True, loglevel='warning'):
"""
GridInspector class to detect information on the data, based on GridType class
Parameters:
data (xr.Datase or xr.DataArray): The input dataset.
clean (bool): apply the cleaning of grids which are assumed to be not relevant
Expand All @@ -25,7 +26,7 @@ def __init__(self, data, cdo_weights=False, extra_dims=None,
self.cdo_weights = cdo_weights
self.clean = clean
self.grids = [] # List to hold all grids info

def _inspect_grids(self):
"""
Inspects the dataset and identifies different grids.
Expand All @@ -41,22 +42,22 @@ def _inspect_grids(self):

for gridtype in self.grids:
gridtype.identify_variables(self.data)
#gridtype.identify_sizes(self.data)
# gridtype.identify_sizes(self.data)

def _inspect_dataarray_grid(self, data_array):
    """
    Helper method to inspect a single DataArray and identify its grid type.

    Builds a GridType from the DataArray's dimension tuple (plus any
    user-supplied extra_dims) and appends it to self.grids, skipping
    duplicates (GridType equality is dimension-based).
    """
    grid_key = tuple(data_array.dims)
    # NOTE(review): the scraped diff showed a stale duplicate of this
    # assignment (pre-lint spacing); only the linted line is kept.
    gridtype = GridType(dims=grid_key, extra_dims=self.extra_dims)
    if gridtype not in self.grids:
        self.grids.append(gridtype)

def _inspect_weights(self):
"""
Return basic information about CDO weights
"""

gridtype = GridType(dims=[], weights=self.data)

# get vertical info from the weights coords if available
Expand All @@ -82,7 +83,7 @@ def get_grid_info(self):
if self.clean:
self._clean_grids()

#self.loggy.info('Grids that have been identifed are: %s', self.grids.)
# self.loggy.info('Grids that have been identifed are: %s', self.grids.)
for gridtype in self.grids:
self.loggy.debug('More details on gridtype %s:', gridtype.dims)
if gridtype.horizontal_dims:
Expand All @@ -92,7 +93,7 @@ def get_grid_info(self):
self.loggy.debug(' Variables are: %s', list(gridtype.variables.keys()))
self.loggy.debug(' Bounds are: %s', gridtype.bounds)
return self.grids

def _clean_grids(self):
"""
Remove degenerate grids which are used by not relevant variables
Expand All @@ -108,16 +109,16 @@ def _clean_grids(self):
if any('bounds' in variable for variable in gridtype.variables):
removed.append(gridtype) # Add to removed list
self.loggy.info('Removing the grid defined by %s with variables containing "bounds"', gridtype.dims)

for remove in removed:
self.grids.remove(remove)

#def get_variable_grids(self):
# def get_variable_grids(self):
# """
# Return a dictionary with the variable - grids pairs
# """
# all_grids = {}
# for gridtype in self.grids:
# for variable in gridtype.variables:
# all_grids[variable] = grid_dims
# return all_grids
# return all_grids
33 changes: 17 additions & 16 deletions smmregrid/gridtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
# default spatial dimensions and vertical coordinates
DEFAULT_DIMS = {
'horizontal': ['i', 'j', 'x', 'y', 'lon', 'lat', 'longitude', 'latitude',
'cell', 'cells', 'ncells', 'values', 'value', 'nod2', 'pix', 'elem',
'nav_lon', 'nav_lat'],
'cell', 'cells', 'ncells', 'values', 'value', 'nod2', 'pix', 'elem',
'nav_lon', 'nav_lat'],
'vertical': ['lev', 'nz1', 'nz', 'depth', 'depth_full', 'depth_half'],
'time': ['time']
}


class GridType:
def __init__(self, dims, extra_dims=None, weights=None):
"""
Expand Down Expand Up @@ -45,10 +46,10 @@ def _handle_default_dimensions(self, extra_dims):

update_dims = DEFAULT_DIMS
for dim in extra_dims.keys():
if extra_dims[dim]:
if extra_dims[dim]:
update_dims[dim] = update_dims[dim] + extra_dims[dim]
return update_dims

def __eq__(self, other):
# so far equality based on dims only
if isinstance(other, GridType):
Expand All @@ -65,12 +66,12 @@ def _identify_dims(self, axis, default_dims):
"""
identified_dims = list(set(self.dims).intersection(default_dims[axis]))
if axis == 'vertical':
if len(identified_dims)>1:
if len(identified_dims) > 1:
raise ValueError(f'Only one vertical dimension can be processed at the time: check {identified_dims}')
if len(identified_dims)==1:
identified_dims=identified_dims[0] #unlist the single vertical dimension
if len(identified_dims) == 1:
identified_dims = identified_dims[0] # unlist the single vertical dimension
return identified_dims if identified_dims else None

def _identify_spatial_bounds(self, data):
"""
Find all bounds variables in the dataset by looking for variables
Expand All @@ -81,13 +82,13 @@ def _identify_spatial_bounds(self, data):
for var in data.data_vars:
if (var.endswith('_bnds') or var.endswith('_bounds')) and 'time' not in var:
# store all the bounds fro each grid. not fancy, but effective
#boundvar = var.split('_')[0]
#if boundvar in self.dims:
# boundvar = var.split('_')[0]
# if boundvar in self.dims:
bounds_variables.append(var)

return bounds_variables
#def identify_sizes(self, data):

# def identify_sizes(self, data):
# """
# Idenfity the sizes of the dataset
# """
Expand Down Expand Up @@ -115,21 +116,21 @@ def identify_variables(self, data):
for var in data.data_vars:
self._identify_variable(data[var], var)
self.bounds = self._identify_spatial_bounds(data)

elif isinstance(data, xr.DataArray):
self._identify_variable(data)

# def _identify_grid_type(self, grid_key):
# """
# Determines the grid type (e.g., structured, unstructured, curvilinear).
# This could be expanded based on more detailed metadata inspection.
# """
# horizontal_dims = self._identify_horizontal_dims(grid_key)
# horizontal_dims = self._identify_horizontal_dims(grid_key)
# if 'mesh' in self.dataset.attrs.get('grid_type', '').lower():
# return 'unstructured'
# elif any('lat' in coord and 'lon' in coord for coord in horizontal_dims):
# return 'regular'
# elif 'curvilinear' in self.dataset.attrs.get('grid_type', '').lower():
# return 'curvilinear'
# else:
# return 'unknown'
# return 'unknown'
7 changes: 4 additions & 3 deletions smmregrid/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def setup_logger(level=None, name=None):
logger = logging.getLogger(name) # Create a logger specific to your module

if logger.handlers:
#logger.warning('Logging is already setup with name %s', name)
# logger.warning('Logging is already setup with name %s', name)
if level != logging.getLevelName(logger.getEffectiveLevel()):
logger.setLevel(loglev)
logger.info('Updating the log_level to %s', loglev)
Expand All @@ -30,12 +30,13 @@ def setup_logger(level=None, name=None):

# Create a handler for the logger
handler = logging.StreamHandler()
#handler.setLevel(loglev) # Set the desired log level for the handler
# handler.setLevel(loglev) # Set the desired log level for the handler
handler.setFormatter(formatter) # Assign the formatter to the handler
logger.addHandler(handler) # Add the handler to the logger

return logger


def convert_logger(loglev=None):
"""Convert a string or integer to a valid logging level"""

Expand Down Expand Up @@ -67,4 +68,4 @@ def convert_logger(loglev=None):
loglev, loglev_default)
loglev = loglev_default

return loglev
return loglev
Loading

0 comments on commit bd30eee

Please sign in to comment.