Skip to content

Commit

Permalink
Modify read_geotif to include crop option
Browse files Browse the repository at this point in the history
  • Loading branch information
k034b363 committed Aug 16, 2024
1 parent d61265d commit 7421618
Showing 1 changed file with 41 additions and 13 deletions.
54 changes: 41 additions & 13 deletions plantcv/geospatial/read_geotif.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
import cv2
import rasterio
import numpy as np
import fiona
from rasterio.mask import mask
from plantcv.plantcv import params
from plantcv.plantcv import fatal_error
from plantcv.plantcv._debug import _debug
from plantcv.plantcv.classes import Spectral_data
from shapely.geometry import shape, MultiPoint, mapping


def _find_closest_unsorted(array, target):
Expand All @@ -25,24 +28,46 @@ def _find_closest_unsorted(array, target):
return min(range(len(array)), key=lambda i: abs(array[i]-target))


def read_geotif(filename, bands="R,G,B"):
def read_geotif(filename, bands="R,G,B", cropto=None):
"""Read Georeferenced TIF image from file.
Inputs:
filename: Path of the TIF image file.
bands: Comma separated string representing the order of image bands (default bands="R,G,B"),
or a list of wavelengths (e.g. bands=[650,560,480])
bands: Comma separated string representing the order of image bands
(default bands="R,G,B"), or a list of wavelengths (e.g. bands=[650,560,480])
cropto: Path to a geoJSON-type shape file for cropping input image.
Returns:
spectral_array: PlantCV format Spectral data object instance
:param filename: str
:param bands: str, list
:return spectral_array: __main__.Spectral_data
"""
img = rasterio.open(filename)
img_data = img.read()

if cropto:
with fiona.open(cropto, 'r') as shapefile:
# polygon-type shapefile
if len(shapefile) == 1:
shapes = [feature['geometry'] for feature in shapefile]
# points-type shapefile
else:
points = [shape(feature["geometry"]) for feature in shapefile]
multi_point = MultiPoint(points)
convex_hull = multi_point.convex_hull
shapes = [mapping(convex_hull)]
# rasterio does the cropping within open
with rasterio.open(filename, 'r') as src:
img_data, geo_transform = mask(src, shapes, crop=True)
d_type = src.dtypes[0]
geo_crs = src.crs.wkt

else:
img = rasterio.open(filename)
img_data = img.read()
d_type = img.dtypes[0]
geo_transform = img.transform
geo_crs = img.crs.wkt

img_data = img_data.transpose(1, 2, 0) # reshape such that z-dimension is last
height = img.height
width = img.width
geo_transform = img.transform
height, width, _ = img_data.shape
wavelengths = {}

if isinstance(bands, str):
Expand All @@ -54,7 +79,8 @@ def read_geotif(filename, bands="R,G,B"):
for i, band in enumerate(list_bands):

if band.upper() not in wavelength_keys:
fatal_error(f"Currently {band} is not supported, instead provide list of wavelengths in order.")
fatal_error(f"Currently {band} is not supported, instead
provide list of wavelengths in order.")
else:
wavelength = default_wavelengths[band.upper()]
wavelengths[wavelength] = i
Expand All @@ -74,12 +100,13 @@ def read_geotif(filename, bands="R,G,B"):
max_wavelength=None,
min_wavelength=None,
max_value=np.max(pseudo_rgb), min_value=np.min(pseudo_rgb),
d_type=img.dtypes[0],
d_type=d_type,
wavelength_dict=wavelengths, samples=int(width),
lines=int(height), interleave=None,
wavelength_units="nm", array_type="datacube",
pseudo_rgb=pseudo_rgb, filename=filename, default_bands=None,
geo_transform=geo_transform)
geo_transform=geo_transform,
geo_crs=geo_crs)
_debug(visual=pseudo_rgb,
filename=os.path.join(params.debug_outdir, str(params.device) + "pseudo_rgb.png"))

Expand All @@ -106,12 +133,13 @@ def read_geotif(filename, bands="R,G,B"):
max_wavelength=None,
min_wavelength=None,
max_value=np.max(img_data), min_value=np.min(img_data),
d_type=img.dtypes[0],
d_type=d_type,
wavelength_dict=wavelengths, samples=int(width),
lines=int(height), interleave=None,
wavelength_units="nm", array_type="datacube",
pseudo_rgb=pseudo_rgb, filename=filename, default_bands=None,
geo_transform=geo_transform)
geo_transform=geo_transform,
geo_crs=geo_crs)

_debug(visual=pseudo_rgb, filename=os.path.join(params.debug_outdir,
str(params.device) + "pseudo_rgb.png"))
Expand Down

0 comments on commit 7421618

Please sign in to comment.