diff --git a/plantcv/geospatial/read_geotif.py b/plantcv/geospatial/read_geotif.py index 154c2dd..d89eebf 100644 --- a/plantcv/geospatial/read_geotif.py +++ b/plantcv/geospatial/read_geotif.py @@ -4,10 +4,13 @@ import cv2 import rasterio import numpy as np +import fiona +from rasterio.mask import mask from plantcv.plantcv import params from plantcv.plantcv import fatal_error from plantcv.plantcv._debug import _debug from plantcv.plantcv.classes import Spectral_data +from shapely.geometry import shape, MultiPoint, mapping def _find_closest_unsorted(array, target): @@ -25,24 +28,46 @@ def _find_closest_unsorted(array, target): return min(range(len(array)), key=lambda i: abs(array[i]-target)) -def read_geotif(filename, bands="R,G,B"): +def read_geotif(filename, bands="R,G,B", cropto=None): """Read Georeferenced TIF image from file. Inputs: filename: Path of the TIF image file. - bands: Comma separated string representing the order of image bands (default bands="R,G,B"), - or a list of wavelengths (e.g. bands=[650,560,480]) + bands: Comma separated string representing the order of image bands + (default bands="R,G,B"), or a list of wavelengths (e.g. bands=[650,560,480]) + cropto: Path to a geoJSON-type shape file for cropping input image. Returns: spectral_array: PlantCV format Spectral data object instance :param filename: str :param bands: str, list :return spectral_array: __main__.Spectral_data """ - img = rasterio.open(filename) - img_data = img.read() + + if cropto: + with fiona.open(cropto, 'r') as shapefile: + # polygon-type shapefile + if len(shapefile) == 1: + shapes = [feature['geometry'] for feature in shapefile] + # points-type shapefile + else: + points = [shape(feature["geometry"]) for feature in shapefile] + multi_point = MultiPoint(points) + convex_hull = multi_point.convex_hull + shapes = [mapping(convex_hull)] + # rasterio does the cropping within open + with rasterio.open(filename, 'r') as src: + img_data, geo_transform = mask(src, shapes, crop=True) + d_type = src.dtypes[0] + geo_crs = src.crs.wkt + + else: + img = rasterio.open(filename) + img_data = img.read() + d_type = img.dtypes[0] + geo_transform = img.transform + geo_crs = img.crs.wkt + img_data = img_data.transpose(1, 2, 0) # reshape such that z-dimension is last - height = img.height - width = img.width - geo_transform = img.transform + height, width, _ = img_data.shape wavelengths = {} if isinstance(bands, str): @@ -54,7 +79,8 @@ def read_geotif(filename, bands="R,G,B"): for i, band in enumerate(list_bands): if band.upper() not in wavelength_keys: - fatal_error(f"Currently {band} is not supported, instead provide list of wavelengths in order.") + fatal_error(f"Currently {band} is not supported, instead + provide list of wavelengths in order.") else: wavelength = default_wavelengths[band.upper()] wavelengths[wavelength] = i @@ -74,12 +100,13 @@ def read_geotif(filename, bands="R,G,B"): max_wavelength=None, min_wavelength=None, max_value=np.max(pseudo_rgb), min_value=np.min(pseudo_rgb), - d_type=img.dtypes[0], + d_type=d_type, wavelength_dict=wavelengths, samples=int(width), lines=int(height), interleave=None, wavelength_units="nm", array_type="datacube", pseudo_rgb=pseudo_rgb, filename=filename, default_bands=None, - geo_transform=geo_transform) + geo_transform=geo_transform, + geo_crs=geo_crs) _debug(visual=pseudo_rgb, filename=os.path.join(params.debug_outdir, str(params.device) + "pseudo_rgb.png")) @@ -106,12 +133,13 @@ def read_geotif(filename, bands="R,G,B"): max_wavelength=None, min_wavelength=None, max_value=np.max(img_data), min_value=np.min(img_data), - d_type=img.dtypes[0], + d_type=d_type, wavelength_dict=wavelengths, samples=int(width), lines=int(height), interleave=None, wavelength_units="nm", array_type="datacube", pseudo_rgb=pseudo_rgb, filename=filename, default_bands=None, - geo_transform=geo_transform) + geo_transform=geo_transform, + geo_crs=geo_crs) _debug(visual=pseudo_rgb, filename=os.path.join(params.debug_outdir, str(params.device) + "pseudo_rgb.png"))