diff --git a/pyrealm/demography/community.py b/pyrealm/demography/community.py index 7c8b1bc3..4f6a05b5 100644 --- a/pyrealm/demography/community.py +++ b/pyrealm/demography/community.py @@ -7,9 +7,8 @@ post processing to align the input formats to the initialisation arguments to the Community class. -Internally, the cohort data in the Community class is represented as a pandas dataframe, -which makes it possible to update cohort attributes in parallel across all cohorts but -also provide a clean interface for adding and removing cohorts to a Community. +Internally, the cohort data in the Community class is represented as a dictionary of +`numpy` arrays. Worked example ============== @@ -90,9 +89,10 @@ ... cohort_pft_names=cohort_pft_names ... ) -Display the community's cohort data with calculated T Model predictions: +Convert the community cohort data to a :class:`pandas.DataFrame` for nicer display and +show some of the calculated T Model predictions: ->>> community.cohort_data[ +>>> pd.DataFrame(community.cohort_data)[ ... ['name', 'dbh', 'n_individuals', 'height', 'crown_area', 'stem_mass'] ... ] name dbh n_individuals height crown_area stem_mass @@ -354,9 +354,7 @@ class Community: # Post init properties number_of_cohorts: int = field(init=False) - - # Dataframe of cohort data - cohort_data: pd.DataFrame = field(init=False) + cohort_data: dict[str, NDArray] = field(init=False) def __post_init__( self, @@ -401,22 +399,37 @@ def __post_init__( f"Plant functional types unknown in flora: {','.join(unknown_pfts)}" ) - # Convert to a dataframe - cohort_data = pd.DataFrame( - { - "name": cohort_pft_names, - "dbh": cohort_dbh_values, - "n_individuals": cohort_n_individuals, - } - ) - # Broadcast the pft trait data to the cohort data by merging with the flora data - # and then store as the cohort data attribute - self.cohort_data = pd.merge(cohort_data, self.flora.data) - self.number_of_cohorts = self.cohort_data.shape[0] + # Store as a dictionary + self.cohort_data: dict[str, NDArray] = { + "name": cohort_pft_names, + "dbh": cohort_dbh_values, + "n_individuals": cohort_n_individuals, + } + + # Duplicate the pft trait data to match the cohort data and add to the cohort + # data dictionary. + self.cohort_data.update(self._unpack_pft_data(cohort_pft_names)) + + self.number_of_cohorts = len(cohort_pft_names) # Populate the T model fields self._calculate_t_model() + def _unpack_pft_data( + self, cohort_pft_names: NDArray[np.str_] + ) -> dict[str, NDArray]: + """Creates a dictionary of PFT data for a set of cohorts. + + Args: + cohort_pft_names: The PFT name for each cohort + """ + # Get the indices for the cohort PFT names in the flora PFT data + pft_index = [self.flora.pft_indices[str(nm)] for nm in cohort_pft_names] + + # Use that index to duplicate the PFT specific data into a per cohort entry for + # each of the PFT traits + return {k: v[pft_index] for k, v in self.flora.data.items()} + def _calculate_t_model(self) -> None: """Calculate T Model predictions across cohort data. diff --git a/pyrealm/demography/flora.py b/pyrealm/demography/flora.py index 57b72100..cc49ad42 100644 --- a/pyrealm/demography/flora.py +++ b/pyrealm/demography/flora.py @@ -28,6 +28,7 @@ import numpy as np import pandas as pd from marshmallow.exceptions import ValidationError +from numpy.typing import NDArray from pyrealm.demography.t_model_functions import ( calculate_canopy_q_m, @@ -237,12 +238,11 @@ def __init__(self, pfts: Sequence[type[PlantFunctionalTypeStrict]]) -> None: [getattr(pft, pft_field) for pft in self.values()] ) - self.data: pd.DataFrame = pd.DataFrame(data) - """A dataframe of trait values as numpy arrays. + self.data: dict[str, NDArray] = data + """A dictionary of trait values as numpy arrays.""" - The 'name' column can be used with cohort names to broadcast plant functional - type data out to cohorts. - """ + self.pft_indices = {v: k for k, v in enumerate(self.data["name"])} + """An dictionary giving the index of each PFT name in the PFT data.""" @classmethod def _from_file_data(cls, file_data: dict) -> Flora: diff --git a/pyrealm/demography/t_model_functions.py b/pyrealm/demography/t_model_functions.py index 1960f08e..1dd8e89b 100644 --- a/pyrealm/demography/t_model_functions.py +++ b/pyrealm/demography/t_model_functions.py @@ -6,10 +6,42 @@ """ # noqa: D205 import numpy as np -from pandas import Series +from numpy.typing import NDArray +from pyrealm.core.utilities import check_input_shapes -def calculate_heights(h_max: Series, a_hd: Series, dbh: Series) -> Series: + +def _validate_t_model_args(pft_args: list[NDArray], size_args: list[NDArray]) -> None: + """Shared validation for T model function inputs. + + Args: + pft_args: A list of row arrays representing trait values + size_args: A list of arrays representing stem sizes at which to evaluate + functions. + """ + + try: + pft_args_shape = check_input_shapes(*pft_args) + except ValueError: + raise ValueError("PFT trait values are not of equal length") + + try: + size_args_shape = check_input_shapes(*size_args) + except ValueError: + raise ValueError("Size arrays are not of equal length") + + if len(pft_args_shape) > 1: + raise ValueError("T model functions only accept 1D arrays of PFT trait values") + + try: + _ = np.broadcast_shapes(pft_args_shape, size_args_shape) + except ValueError: + raise ValueError("PFT and size inputs to T model function are not compatible.") + + +def calculate_heights( + h_max: NDArray[np.float32], a_hd: NDArray[np.float32], dbh: NDArray[np.float32] +) -> NDArray[np.float32]: r"""Calculate tree height under the T Model. The height of trees (:math:`H`) are calculated from individual diameters at breast @@ -31,8 +63,11 @@ def calculate_heights(h_max: Series, a_hd: Series, dbh: Series) -> Series: def calculate_crown_areas( - ca_ratio: Series, a_hd: Series, dbh: Series, height: Series -) -> Series: + ca_ratio: NDArray[np.float32], + a_hd: NDArray[np.float32], + dbh: NDArray[np.float32], + height: NDArray[np.float32], +) -> NDArray[np.float32]: r"""Calculate tree crown area under the T Model. The tree crown area (:math:`A_{c}`)is calculated from individual diameters at breast @@ -55,7 +90,9 @@ def calculate_crown_areas( return ((np.pi * ca_ratio) / (4 * a_hd)) * dbh * height -def calculate_crown_fractions(a_hd: Series, height: Series, dbh: Series) -> Series: +def calculate_crown_fractions( + a_hd: NDArray[np.float32], height: NDArray[np.float32], dbh: NDArray[np.float32] +) -> NDArray[np.float32]: r"""Calculate tree crown fraction under the T Model. The crown fraction (:math:`f_{c}`)is calculated from individual diameters at breast @@ -76,7 +113,9 @@ def calculate_crown_fractions(a_hd: Series, height: Series, dbh: Series) -> Seri return height / (a_hd * dbh) -def calculate_stem_masses(rho_s: Series, height: Series, dbh: Series) -> Series: +def calculate_stem_masses( + rho_s: NDArray[np.float32], height: NDArray[np.float32], dbh: NDArray[np.float32] +) -> NDArray[np.float32]: r"""Calculate stem mass under the T Model. The stem mass (:math:`W_{s}`) is calculated from individual diameters at breast @@ -96,7 +135,9 @@ def calculate_stem_masses(rho_s: Series, height: Series, dbh: Series) -> Series: return (np.pi / 8) * rho_s * (dbh**2) * height -def calculate_foliage_masses(sla: Series, lai: Series, crown_area: Series) -> Series: +def calculate_foliage_masses( + sla: NDArray[np.float32], lai: NDArray[np.float32], crown_area: NDArray[np.float32] +) -> NDArray[np.float32]: r"""Calculate foliage mass under the T Model. The foliage mass (:math:`W_{f}`) is calculated from the crown area (:math:`A_{c}`), @@ -117,12 +158,12 @@ def calculate_foliage_masses(sla: Series, lai: Series, crown_area: Series) -> Se def calculate_sapwood_masses( - rho_s: Series, - ca_ratio: Series, - height: Series, - crown_area: Series, - crown_fraction: Series, -) -> Series: + rho_s: NDArray[np.float32], + ca_ratio: NDArray[np.float32], + height: NDArray[np.float32], + crown_area: NDArray[np.float32], + crown_fraction: NDArray[np.float32], +) -> NDArray[np.float32]: r"""Calculate sapwood mass under the T Model. The sapwood mass (:math:`W_{\cdot s}`) is calculated from the individual crown area @@ -146,8 +187,11 @@ def calculate_sapwood_masses( def calculate_whole_crown_gpp( - potential_gpp: Series, crown_area: Series, par_ext: Series, lai: Series -) -> Series: + potential_gpp: NDArray[np.float32], + crown_area: NDArray[np.float32], + par_ext: NDArray[np.float32], + lai: NDArray[np.float32], +) -> NDArray[np.float32]: r"""Calculate whole crown gross primary productivity. This function calculates individual GPP across the whole crown, given the @@ -170,7 +214,9 @@ def calculate_whole_crown_gpp( return potential_gpp * crown_area * (1 - np.exp(-(par_ext * lai))) -def calculate_sapwood_respiration(resp_s: Series, sapwood_mass: Series) -> Series: +def calculate_sapwood_respiration( + resp_s: NDArray[np.float32], sapwood_mass: NDArray[np.float32] +) -> NDArray[np.float32]: r"""Calculate sapwood respiration. Calculates the total sapwood respiration (:math:`R_{\cdot s}`) given the individual @@ -187,7 +233,9 @@ def calculate_sapwood_respiration(resp_s: Series, sapwood_mass: Series) -> Serie return sapwood_mass * resp_s -def calculate_foliar_respiration(resp_f: Series, whole_crown_gpp: Series) -> Series: +def calculate_foliar_respiration( + resp_f: NDArray[np.float32], whole_crown_gpp: NDArray[np.float32] +) -> NDArray[np.float32]: r"""Calculate foliar respiration. Calculates the total foliar respiration (:math:`R_{f}`) given the individual crown @@ -207,8 +255,11 @@ def calculate_foliar_respiration(resp_f: Series, whole_crown_gpp: Series) -> Ser def calculate_fine_root_respiration( - zeta: Series, sla: Series, resp_r: Series, foliage_mass: Series -) -> Series: + zeta: NDArray[np.float32], + sla: NDArray[np.float32], + resp_r: NDArray[np.float32], + foliage_mass: NDArray[np.float32], +) -> NDArray[np.float32]: r"""Calculate foliar respiration. Calculates the total fine root respiration (:math:`R_{r}`) given the individual @@ -230,12 +281,12 @@ def calculate_fine_root_respiration( def calculate_net_primary_productivity( - yld: Series, - whole_crown_gpp: Series, - foliar_respiration: Series, - fine_root_respiration: Series, - sapwood_respiration: Series, -) -> Series: + yld: NDArray[np.float32], + whole_crown_gpp: NDArray[np.float32], + foliar_respiration: NDArray[np.float32], + fine_root_respiration: NDArray[np.float32], + sapwood_respiration: NDArray[np.float32], +) -> NDArray[np.float32]: r"""Calculate net primary productivity. The net primary productivity (NPP, :math:`P_{net}`) is calculated as a plant @@ -270,12 +321,12 @@ def calculate_net_primary_productivity( def calculate_foliage_and_fine_root_turnover( - sla: Series, - zeta: Series, - tau_f: Series, - tau_r: Series, - foliage_mass: Series, -) -> Series: + sla: NDArray[np.float32], + zeta: NDArray[np.float32], + tau_f: NDArray[np.float32], + tau_r: NDArray[np.float32], + foliage_mass: NDArray[np.float32], +) -> NDArray[np.float32]: r"""Calculate turnover costs. This function calculates the costs associated with the turnover of fine roots and @@ -301,18 +352,18 @@ def calculate_foliage_and_fine_root_turnover( def calculate_growth_increments( - rho_s: Series, - a_hd: Series, - h_max: Series, - lai: Series, - ca_ratio: Series, - sla: Series, - zeta: Series, - npp: Series, - turnover: Series, - dbh: Series, - height: Series, -) -> tuple[Series, Series, Series]: + rho_s: NDArray[np.float32], + a_hd: NDArray[np.float32], + h_max: NDArray[np.float32], + lai: NDArray[np.float32], + ca_ratio: NDArray[np.float32], + sla: NDArray[np.float32], + zeta: NDArray[np.float32], + npp: NDArray[np.float32], + turnover: NDArray[np.float32], + dbh: NDArray[np.float32], + height: NDArray[np.float32], +) -> tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]: r"""Calculate growth increments. Given an estimate of net primary productivity (:math:`P_{net}`), less associated @@ -437,7 +488,9 @@ def calculate_canopy_z_max_proportion(m: float, n: float) -> float: return ((n - 1) / (m * n - 1)) ** (1 / n) -def calculate_canopy_z_max(z_max_prop: Series, height: Series) -> Series: +def calculate_canopy_z_max( + z_max_prop: NDArray[np.float32], height: NDArray[np.float32] +) -> NDArray[np.float32]: r"""Calculate height of maximum crown radius. The height of the maximum crown radius (:math:`z_m`) is derived from the canopy @@ -461,7 +514,9 @@ def calculate_canopy_z_max(z_max_prop: Series, height: Series) -> Series: return height * z_max_prop -def calculate_canopy_r0(q_m: Series, crown_area: Series) -> Series: +def calculate_canopy_r0( + q_m: NDArray[np.float32], crown_area: NDArray[np.float32] +) -> NDArray[np.float32]: r"""Calculate scaling factor for height of maximum crown radius. This scaling factor (:math:`r_0`) is derived from the canopy shape parameters @@ -487,10 +542,10 @@ def calculate_canopy_r0(q_m: Series, crown_area: Series) -> Series: def calculate_relative_canopy_radii( z: float, - height: Series, - m: Series, - n: Series, -) -> Series: + height: NDArray[np.float32], + m: NDArray[np.float32], + n: NDArray[np.float32], +) -> NDArray[np.float32]: r"""Calculate relative canopy radius at a given height. The canopy shape parameters ``m`` and ``n`` define the vertical distribution of diff --git a/tests/unit/demography/test_flora.py b/tests/unit/demography/test_flora.py index ad975788..c03ca0f0 100644 --- a/tests/unit/demography/test_flora.py +++ b/tests/unit/demography/test_flora.py @@ -2,11 +2,9 @@ import sys from contextlib import nullcontext as does_not_raise -from dataclasses import fields from importlib import resources from json import JSONDecodeError -import pandas as pd import pytest from marshmallow.exceptions import ValidationError from pandas.errors import ParserError @@ -177,11 +175,9 @@ def test_Flora__init__(flora_inputs, outcome): assert k == v.name # Check data view is correct - assert isinstance(flora.data, pd.DataFrame) - assert flora.data.shape == ( - len(flora_inputs), - len(fields(next(iter(flora.values())))), - ) + assert isinstance(flora.data, dict) + for trait_array in flora.data.values(): + assert trait_array.shape == (len(flora),) # diff --git a/tests/unit/demography/test_t_model_functions.py b/tests/unit/demography/test_t_model_functions.py index 26c9523c..a442a2a8 100644 --- a/tests/unit/demography/test_t_model_functions.py +++ b/tests/unit/demography/test_t_model_functions.py @@ -1,6 +1,9 @@ """test the functions in t_model_functions.py.""" +from contextlib import nullcontext as does_not_raise + import numpy as np +import pytest from numpy.testing import assert_array_almost_equal from pyrealm.demography.t_model_functions import ( @@ -14,6 +17,78 @@ ) +@pytest.mark.parametrize( + argnames="pft_args, size_args, outcome, excep_message", + argvalues=[ + pytest.param( + [np.ones(4), np.ones(4)], + [np.ones(4), np.ones(4)], + does_not_raise(), + None, + id="all_1d_ok", + ), + pytest.param( + [np.ones(5), np.ones(4)], + [np.ones(4), np.ones(4)], + pytest.raises(ValueError), + "PFT trait values are not of equal length", + id="pfts_unequal", + ), + pytest.param( + [np.ones(4), np.ones(4)], + [np.ones(5), np.ones(4)], + pytest.raises(ValueError), + "Size arrays are not of equal length", + id="shape_unequal", + ), + pytest.param( + [np.ones((4, 2)), np.ones((4, 2))], + [np.ones(4), np.ones(4)], + pytest.raises(ValueError), + "T model functions only accept 1D arrays of PFT trait values", + id="pfts_not_row_arrays", + ), + pytest.param( + [np.ones(4), np.ones(4)], + [np.ones(5), np.ones(5)], + pytest.raises(ValueError), + "PFT and size inputs to T model function are not compatible.", + id="sizes_row_array_of_bad_length", + ), + pytest.param( + [np.ones(4), np.ones(4)], + [np.ones((5, 1)), np.ones((5, 1))], + does_not_raise(), + None, + id="size_2d_columns_ok", + ), + pytest.param( + [np.ones(4), np.ones(4)], + [np.ones((5, 2)), np.ones((5, 2))], + pytest.raises(ValueError), + "PFT and size inputs to T model function are not compatible.", + id="size_2d_not_ok", + ), + pytest.param( + [np.ones(4), np.ones(4)], + [np.ones((5, 4)), np.ones((5, 4))], + does_not_raise(), + None, + id="size_2d_weird_but_ok", + ), + ], +) +def test__validate_t_model_args(pft_args, size_args, outcome, excep_message): + """Test shared input validation function.""" + from pyrealm.demography.t_model_functions import _validate_t_model_args + + with outcome as excep: + _validate_t_model_args(pft_args=pft_args, size_args=size_args) + return + + assert str(excep.value).startswith(excep_message) + + def test_calculate_heights(): """Tests happy path for calculation of heights of tree from diameter.""" pft_h_max_values = np.array([25.33, 15.33])