diff --git a/eis_toolkit/cli.py b/eis_toolkit/cli.py index 86dafe56..369b6aca 100644 --- a/eis_toolkit/cli.py +++ b/eis_toolkit/cli.py @@ -786,15 +786,15 @@ def compute_pca_vector_cli( # DESCRIPTIVE STATISTICS (RASTER) @app.command() -def descriptive_statistics_raster_cli(input_file: INPUT_FILE_OPTION): +def descriptive_statistics_raster_cli(input_raster: INPUT_FILE_OPTION, band: int = 1): """Generate descriptive statistics from raster data.""" from eis_toolkit.exploratory_analyses.descriptive_statistics import descriptive_statistics_raster typer.echo("Progress: 10%") - with rasterio.open(input_file) as raster: + with rasterio.open(input_raster) as raster: typer.echo("Progress: 25%") - results_dict = descriptive_statistics_raster(raster) + results_dict = descriptive_statistics_raster(raster, band) typer.echo("Progress: 75%") typer.echo("Progress: 100% \n") diff --git a/eis_toolkit/exploratory_analyses/descriptive_statistics.py b/eis_toolkit/exploratory_analyses/descriptive_statistics.py index 3bc064aa..e0da06f5 100644 --- a/eis_toolkit/exploratory_analyses/descriptive_statistics.py +++ b/eis_toolkit/exploratory_analyses/descriptive_statistics.py @@ -3,14 +3,14 @@ import pandas as pd import rasterio from beartype import beartype -from beartype.typing import Union +from beartype.typing import Dict, Union from statsmodels.stats import stattools from statsmodels.stats.weightstats import DescrStatsW -from eis_toolkit.exceptions import InvalidColumnException +from eis_toolkit.exceptions import InvalidColumnException, InvalidRasterBandException -def _descriptive_statistics(data: Union[rasterio.io.DatasetReader, pd.DataFrame, gpd.GeoDataFrame]) -> dict: +def _descriptive_statistics(data: Union[rasterio.io.DatasetReader, pd.DataFrame, gpd.GeoDataFrame]) -> Dict[str, float]: statistics = DescrStatsW(data) min = np.min(data) max = np.max(data) @@ -38,14 +38,25 @@ def _descriptive_statistics(data: Union[rasterio.io.DatasetReader, pd.DataFrame, @beartype -def descriptive_statistics_dataframe(input_data: Union[pd.DataFrame, gpd.GeoDataFrame], column: str) -> dict: - """Generate descriptive statistics from vector data. +def descriptive_statistics_dataframe( + input_data: Union[pd.DataFrame, gpd.GeoDataFrame], column: str +) -> Dict[str, float]: + """Compute descriptive statistics from vector data. - Generates min, max, mean, quantiles(25%, 50% and 75%), standard deviation, relative standard deviation and skewness. + Computes the following statistics: + - min + - max + - mean + - quantiles 25% + - quantile 50% (median) + - quantile 75% + - standard deviation + - relative standard deviation + - skewness Args: - input_data: Data to generate descriptive statistics from. - column: Specify the column to generate descriptive statistics from. + input_data: Input vector data. + column: Column in vector data to compute descriptive statistics from. Returns: The descriptive statistics in previously described order. @@ -58,19 +69,33 @@ def descriptive_statistics_dataframe(input_data: Union[pd.DataFrame, gpd.GeoData @beartype -def descriptive_statistics_raster(input_data: rasterio.io.DatasetReader) -> dict: - """Generate descriptive statistics from raster data. +def descriptive_statistics_raster(input_data: rasterio.io.DatasetReader, band: int = 1) -> Dict[str, float]: + """Compute descriptive statistics from raster data. + + Computes the following statistics: + - min + - max + - mean + - quantiles 25% + - quantile 50% (median) + - quantile 75% + - standard deviation + - relative standard deviation + - skewness - Generates min, max, mean, quantiles(25%, 50% and 75%), standard deviation, relative standard deviation and skewness. Nodata values are removed from the data before the statistics are computed. Args: - input_data: Data to generate descriptive statistics from. + input_data: Input raster data. + band: Raster band to compute descriptive statistics from. Returns: The descriptive statistics in previously described order. """ - data = input_data.read().flatten() + if band not in range(1, input_data.count + 1): + raise InvalidRasterBandException(f"Input raster does not contain the selected band: {band}.") + + data = input_data.read(band) nodata_value = input_data.nodata data = data[data != nodata_value] statistics = _descriptive_statistics(data) diff --git a/tests/exploratory_analyses/descriptive_statistics_test.py b/tests/exploratory_analyses/descriptive_statistics_test.py index ee8f4a69..b7aa634c 100644 --- a/tests/exploratory_analyses/descriptive_statistics_test.py +++ b/tests/exploratory_analyses/descriptive_statistics_test.py @@ -61,7 +61,7 @@ def test_descriptive_statistics_geodataframe(): def test_descriptive_statistics_raster(): """Checks that returned statistics are correct when using numpy.ndarray.""" - test = descriptive_statistics_raster(src_raster) + test = descriptive_statistics_raster(src_raster, 1) np.testing.assert_almost_equal(test["min"], 2.503) np.testing.assert_almost_equal(test["max"], 9.67) np.testing.assert_almost_equal(test["mean"], 5.1865644)