Skip to content

Commit

Permalink
do not use warpedVRT to overwrite nodata value (#669)
Browse files Browse the repository at this point in the history
* do not use warpedVRT to overwrite nodata value

* update changelog
  • Loading branch information
vincentsarago authored Jan 12, 2024
1 parent 838ee7e commit 93b502b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 9 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@

# unreleased

* do not use `warpedVRT` when overwriting the dataset nodata value

# 6.2.10 (2024-01-08)

* remove default Endpoint URL in AWS S3 Client for STAC Reader
Expand Down
22 changes: 13 additions & 9 deletions rio_tiler/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,17 +129,18 @@ def read(
io_resampling = Resampling[resampling_method]
warp_resampling = Resampling[reproject_method]

nodata = nodata if nodata is not None else src_dst.nodata

dst_crs = dst_crs or src_dst.crs
with contextlib.ExitStack() as ctx:
# Use WarpedVRT when Re-projection or Nodata or User VRT Option (cutline)
if (dst_crs != src_dst.crs) or nodata is not None or vrt_options:
# Use WarpedVRT when Re-projection or User VRT Option (cutline)
if (dst_crs != src_dst.crs) or vrt_options:
vrt_params = {
"crs": dst_crs,
"add_alpha": True,
"resampling": warp_resampling,
}

nodata = nodata if nodata is not None else src_dst.nodata
if nodata is not None:
vrt_params.update(
{"nodata": nodata, "add_alpha": False, "src_nodata": nodata}
Expand Down Expand Up @@ -178,7 +179,7 @@ def read(
):
boundless = True

if ColorInterp.alpha in dataset.colorinterp:
if ColorInterp.alpha in dataset.colorinterp and nodata is None:
# If dataset has an alpha band we need to get the mask using the alpha band index
# and then split the data and mask values
alpha_idx = dataset.colorinterp.index(ColorInterp.alpha) + 1
Expand Down Expand Up @@ -226,11 +227,12 @@ def read(
resampling=io_resampling,
boundless=boundless,
masked=True,
fill_value=nodata,
)

# if data has Nodata then we simply make sure the mask == the nodata
if dataset.nodata is not None:
data.mask |= data == dataset.nodata
if nodata is not None:
data.mask |= data == nodata

stats = []
for ix in indexes:
Expand Down Expand Up @@ -332,7 +334,7 @@ def part(

padding = padding or 0
dst_crs = dst_crs or src_dst.crs
if bounds_crs:
if bounds_crs and bounds_crs != dst_crs:
bounds = transform_bounds(bounds_crs, dst_crs, *bounds, densify_pts=21)

if minimum_overlap:
Expand All @@ -354,8 +356,8 @@ def part(
"Dataset covers less than {:.0f}% of tile".format(cover_ratio * 100)
)

# Use WarpedVRT when Re-projection or Nodata or User VRT Option (cutline)
if (dst_crs != src_dst.crs) or nodata is not None or vrt_options:
# Use WarpedVRT when Re-projection or User VRT Option (cutline)
if (dst_crs != src_dst.crs) or vrt_options:
window = None
vrt_transform, vrt_width, vrt_height = get_vrt_transform(
src_dst,
Expand Down Expand Up @@ -436,6 +438,7 @@ def part(
width=width,
height=height,
window=window,
nodata=nodata,
resampling_method=resampling_method,
reproject_method=reproject_method,
force_binary_mask=force_binary_mask,
Expand All @@ -458,6 +461,7 @@ def part(
width=width,
height=height,
window=window,
nodata=nodata,
resampling_method=resampling_method,
reproject_method=reproject_method,
force_binary_mask=force_binary_mask,
Expand Down
28 changes: 28 additions & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def test_tile_read_not_covering_the_whole_tile():
# See https://github.com/cogeotiff/rio-tiler/issues/105#issuecomment-492268836
def test_tile_read_validMask():
"""Dataset mask should be the same as the actual mask."""
# bounds fully outside dataset
bounds = [
-6887893.4928338025,
12210356.646387195,
Expand All @@ -318,6 +319,33 @@ def test_tile_read_validMask():
numpy.testing.assert_array_equal(mask, masknodata)


def test_read_nodata():
    """Mask returned with a user-supplied nodata must match the pixels equal to it.

    Exercises ``reader.part`` on a bounding box, ``reader.read`` on the
    dataset as-is, and ``reader.read`` with ``dst_crs="epsg:3857"``, each
    with ``nodata=1``. In every case the returned mask is compared against
    one rebuilt from the first band: 255 where the value differs from 1,
    0 where it equals 1.
    """
    bounds = [316470, 8094354, 415375, 8148789]

    # One callable per read path; each receives a freshly opened dataset.
    read_calls = (
        lambda src: reader.part(src, bounds, nodata=1),
        lambda src: reader.read(src, nodata=1),
        lambda src: reader.read(src, dst_crs="epsg:3857", nodata=1),
    )

    for do_read in read_calls:
        with rasterio.open(COG) as src_dst:
            arr, mask = do_read(src_dst)

        # Rebuild the expected mask from band 1 using the overridden nodata.
        expected_mask = (arr[0] != 1).astype(numpy.uint8) * 255
        numpy.testing.assert_array_equal(mask, expected_mask)


def test_tile_read_crs():
"""Read tile using different target CRS and bounds CRS."""
bounds = (
Expand Down

0 comments on commit 93b502b

Please sign in to comment.