Skip to content

Commit

Permalink
Merge pull request #18 from c-hydro/tools-v1.3.4
Browse files Browse the repository at this point in the history
Tools v1.3.4
  • Loading branch information
ltrotter authored Oct 14, 2024
2 parents e976c0c + 4ba55d1 commit e61e8ce
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 14 deletions.
98 changes: 97 additions & 1 deletion dam/processing/warp.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import xarray as xr
import rioxarray as rxr
import rasterio
import dask

from typing import Optional
import numpy as np
import tempfile
import os

from ..utils.register_process import as_DAM_process

Expand Down Expand Up @@ -31,8 +34,22 @@ def match_grid(input: xr.DataArray,
resampling_method: str|int = 'NearestNeighbour',
nodata_threshold: float = 1,
nodata_value: Optional[float] = None,
engine = 'xarray'
) -> xr.DataArray:

if engine == 'xarray':
return _match_grid_xarray(input, grid, resampling_method, nodata_threshold, nodata_value)
elif engine == 'gdal': # this is for compatibility with a previous version. It is not recommended to use gdal
return _match_grid_gdal(input, grid, resampling_method, nodata_threshold, nodata_value)
else:
raise ValueError('engine must be one of [xarray, gdal]')

def _match_grid_xarray(input: xr.DataArray,
grid: xr.DataArray,
resampling_method: str|int,
nodata_threshold: float,
nodata_value: Optional[float]
) -> xr.DataArray:
input_da = input
mask_da = grid

Expand Down Expand Up @@ -73,4 +90,83 @@ def process_chunk(chunk, nodata_value):
nan_mask = nan_mask * 100
regridded_mask = nan_mask.rio.reproject_match(mask_da, resampling=rasterio.enums.Resampling(5))

return input_reprojected.where(regridded_mask < nodata_threshold*100, nodata_value)
return input_reprojected.where(regridded_mask < nodata_threshold*100, nodata_value)

def _match_grid_gdal(input: xr.DataArray,
                     grid: xr.DataArray,
                     resampling_method: str|int,
                     nodata_threshold: float,
                     nodata_value: Optional[float]) -> xr.DataArray:
    """
    Regrid `input` onto the grid (CRS, bounds, resolution) of `grid` using GDAL.

    Legacy engine kept for compatibility with previous versions; the xarray
    engine is the recommended path. NOTE: unlike the xarray engine, this one
    does not apply `nodata_threshold` — only `nodata_value` is written to the
    source band before warping.

    Parameters
    ----------
    input : xr.DataArray
        Source raster; must carry rio-accessor georeferencing so it can be
        written to GeoTIFF.
    grid : xr.DataArray
        Raster defining the target CRS, bounds and resolution.
    resampling_method : str | int
        Case-insensitive name of a GDAL resampling algorithm, or the integer
        code used by the module-level ``_resampling_methods`` mapping.
    nodata_threshold : float
        Unused by this engine; accepted for signature compatibility with
        ``_match_grid_xarray``.
    nodata_value : float, optional
        Value set as the nodata value of the source band before warping.

    Returns
    -------
    xr.DataArray
        The regridded raster, fully loaded into memory (the backing file is a
        temporary that is deleted before returning).

    Raises
    ------
    ValueError
        If `resampling_method` does not name a supported GDAL algorithm.
    """
    from osgeo import gdal, gdalconst

    _resampling_methods_gdal = ['NearestNeighbour', 'Bilinear',
                                'Cubic', 'CubicSpline',
                                'Lanczos',
                                'Average', 'Mode',
                                'Max', 'Min',
                                'Med', 'Q1', 'Q3']

    # Translate an integer code back into its name.
    # NOTE(review): `_resampling_methods` is assumed to be a module-level
    # name→code mapping defined outside this excerpt — confirm.
    if isinstance(resampling_method, int):
        for method in _resampling_methods:
            if _resampling_methods[method] == resampling_method:
                resampling_method = method
                break

    # Normalise to GDAL's exact capitalisation; the for/else raises when no
    # case-insensitive match exists.
    for method in _resampling_methods_gdal:
        if method.lower() == resampling_method.lower():
            resampling_method = method
            break
    else:
        raise ValueError(f'resampling_method must be one of {_resampling_methods_gdal}')

    # Round-trip through temporary GeoTIFFs, since gdal.Warp operates on files.
    with tempfile.TemporaryDirectory() as tmpdir:
        input_path = os.path.join(tmpdir, 'input.tif')
        grid_path = os.path.join(tmpdir, 'grid.tif')
        output_path = os.path.join(tmpdir, 'input_regridded.tif')
        input.rio.to_raster(input_path)
        grid.rio.to_raster(grid_path)

        # Read the input's projection and data type. Open in update mode when
        # a nodata value must be written: SetNoDataValue on a read-only
        # dataset does not persist to the file, so gdal.Warp would never see it.
        access = gdalconst.GA_Update if nodata_value is not None else gdalconst.GA_ReadOnly
        input_ds = gdal.Open(input_path, access)
        input_projection = input_ds.GetProjection()
        if nodata_value is not None:
            input_ds.GetRasterBand(1).SetNoDataValue(nodata_value)
        in_type = input_ds.GetRasterBand(1).DataType
        input_ds = None  # flush and close before warping from the path

        # Read the target grid's geometry, then release the dataset.
        input_grid = gdal.Open(grid_path, gdalconst.GA_ReadOnly)
        grid_transform = input_grid.GetGeoTransform()
        grid_projection = input_grid.GetProjection()
        grid_x_size = input_grid.RasterXSize
        grid_y_size = input_grid.RasterYSize
        input_grid = None

        # Map the method name onto GDAL's resampling constant.
        resampling = getattr(gdalconst, f'GRA_{resampling_method}')

        # Output bounds = the grid bounds, in GDAL's documented
        # (minX, minY, maxX, maxY) order. grid_transform[5] is negative for
        # north-up rasters, so min/max sort the y edges either way.
        x_edge_0 = grid_transform[0]
        x_edge_1 = grid_transform[0] + grid_transform[1] * grid_x_size
        y_edge_0 = grid_transform[3]
        y_edge_1 = grid_transform[3] + grid_transform[5] * grid_y_size
        output_bounds = [min(x_edge_0, x_edge_1), min(y_edge_0, y_edge_1),
                         max(x_edge_0, x_edge_1), max(y_edge_0, y_edge_1)]

        # Keep the input's data type only for nearest-neighbour resampling;
        # any interpolating method can produce fractional values.
        if resampling == gdalconst.GRA_NearestNeighbour:
            output_type = in_type
        else:
            output_type = gdalconst.GDT_Float32

        # xRes/yRes must be positive georeferenced units; geotransform[5] is
        # typically negative, hence abs().
        gdal.Warp(output_path, input_path, outputBounds=output_bounds,
                  srcSRS=input_projection, dstSRS=grid_projection,
                  xRes=abs(grid_transform[1]), yRes=abs(grid_transform[5]),
                  resampleAlg=resampling, outputType=output_type,
                  format='GTiff', creationOptions=['COMPRESS=LZW'], multithread=True)

        # Load eagerly: the temporary directory (and the output file) is
        # deleted when this context exits, so a lazily-backed DataArray would
        # point at a file that no longer exists.
        output = rxr.open_rasterio(output_path).load()

    return output
2 changes: 1 addition & 1 deletion dam/tools
32 changes: 21 additions & 11 deletions dam/workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ def __init__(self,
self.tmp_dir = tempfile.mkdtemp(dir = tmp_dir)

@classmethod
def from_options(cls, options: Options) -> 'DAMWorkflow':
def from_options(cls, options: Options|dict) -> 'DAMWorkflow':
if isinstance(options, dict): options = Options(options)
input = options.get('input',ignore_case=True)
if isinstance(input, Options):
if isinstance(input, dict):
input = Dataset.from_options(input)
output = options.get('output', None, ignore_case=True)
if isinstance(output, Options):
if isinstance(output, dict):
output = Dataset.from_options(output)
wf_options = options.get('options', None, ignore_case=True)

Expand All @@ -70,23 +71,33 @@ def clean_up(self):
except Exception as e:
print(f'Error cleaning up temporary directory: {e}')

def make_output(self, input: Dataset, output: Optional[Dataset|dict] = None) -> Dataset:
def make_output(self, input: Dataset, output: Optional[Dataset|dict] = None, name = None) -> Dataset:
if isinstance(output, Dataset):
return output

input_pattern = input.key_pattern
input_name = input.name
if name is not None:
extention = os.path.splitext(input_pattern)[1]
output_pattern = input_pattern.replace(f'{extention}', f'_{name}{extention}')
output_name = f'{input_name}_{name}'

if output is None:
key_pattern = input.key_pattern
output_pattern = output_pattern
elif isinstance(output, dict):
key_pattern = output.get('key_pattern', input.key_pattern)
output_pattern = output.get('key_pattern', output_pattern)
else:
raise ValueError('Output must be a Dataset or a dictionary.')

output_type = self.options['intermediate_output']
if output_type == 'Mem':
return MemoryDataset(key_pattern = input.key_pattern)
output_ds = MemoryDataset(key_pattern = output_pattern)
elif output_type == 'Tmp':
filename = os.path.basename(key_pattern)
return LocalDataset(path = self.tmp_dir, filename = filename)
filename = os.path.basename(output_pattern)
output_ds = LocalDataset(path = self.tmp_dir, filename = filename)

output_ds.name = output_name
return output_ds

def add_process(self, function, output: Optional[Dataset|dict] = None, **kwargs) -> None:
if len(self.processes) == 0:
Expand All @@ -96,8 +107,7 @@ def add_process(self, function, output: Optional[Dataset|dict] = None, **kwargs)
previous = self.processes[-1]
this_input = previous.output

this_output = self.make_output(this_input, output)
this_output.name = f'{this_output.name}_{function.__name__}'
this_output = self.make_output(this_input, output, function.__name__)
this_process = DAMProcessor(function = function,
input = this_input,
args = kwargs,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name='dam',
version='1.1.1',
version='1.1.0-alpha',
packages=find_packages(),
description='A package for raster data processing developed at the CIMA Research Foundation',
author='Luca Trotter',
Expand Down

0 comments on commit e61e8ce

Please sign in to comment.