From 97316678c55e3e88eab8a6dd04fa37b9ed637920 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff <35577657+nikhilwoodruff@users.noreply.github.com> Date: Wed, 4 Oct 2023 14:24:24 +0100 Subject: [PATCH] Add automatic period adjustment --- policyengine_core/holders/holder.py | 29 +++----------------- policyengine_core/simulations/simulation.py | 30 ++++++++++++++++----- policyengine_core/variables/variable.py | 29 ++++++++++++++------ 3 files changed, 49 insertions(+), 39 deletions(-) diff --git a/policyengine_core/holders/holder.py b/policyengine_core/holders/holder.py index 534f673f4..084d86943 100644 --- a/policyengine_core/holders/holder.py +++ b/policyengine_core/holders/holder.py @@ -235,7 +235,10 @@ def set_input( return warnings.warn(warning_message, Warning) if self.variable.value_type in (float, int) and isinstance(array, str): array = tools.eval_expression(array) - if self.variable.set_input: + if ( + self.variable.set_input + and period.unit != self.variable.definition_period + ): return self.variable.set_input(self, period, array) return self._set(period, array, branch_name) @@ -285,30 +288,6 @@ def _set( raise ValueError( "A period must be specified to set values, except for variables with periods.ETERNITY as as period_definition." ) - if ( - self.variable.definition_period != period.unit - or period.size > 1 - ): - name = self.variable.name - period_size_adj = ( - f"{period.unit}" - if (period.size == 1) - else f"{period.size}-{period.unit}s" - ) - error_message = os.linesep.join( - [ - f'Unable to set a value for variable "{name}" for {period_size_adj}-long period "{period}".', - f'"{name}" can only be set for one {self.variable.definition_period} at a time. Please adapt your input.', - f'If you are the maintainer of "{name}", you can consider adding it a set_input attribute to enable automatic period casting.', - ] - ) - - raise PeriodMismatchError( - self.variable.name, - period, - self.variable.definition_period, - error_message, - ) should_store_on_disk = ( self._on_disk_storable diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index 65d607be3..36d368f6d 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -13,7 +13,7 @@ from policyengine_core.errors import CycleError, SpiralError from policyengine_core.holders.holder import Holder from policyengine_core.periods import Period -from policyengine_core.periods.config import ETERNITY +from policyengine_core.periods.config import ETERNITY, MONTH, YEAR from policyengine_core.periods.helpers import period from policyengine_core.tracers import ( FullTracer, @@ -27,7 +27,7 @@ from policyengine_core.experimental import MemoryConfig from policyengine_core.populations import Population from policyengine_core.tracers import SimpleTracer -from policyengine_core.variables import Variable +from policyengine_core.variables import Variable, QuantityType from policyengine_core.reforms.reform import Reform from policyengine_core.parameters import get_parameter @@ -454,8 +454,6 @@ def _calculate( variable_name, check_existence=True ) - self._check_period_consistency(period, variable) - # Check if we've neutralized via parameters. try: if self.tax_benefit_system.parameters(period).gov.abolitions[ @@ -470,6 +468,20 @@ def _calculate( if cached_array is not None: return cached_array + if variable.definition_period == MONTH and period.unit == YEAR: + if variable.quantity_type == QuantityType.STOCK: + contained_months = period.get_subperiods(MONTH) + return self.calculate(variable_name, contained_months[-1]) + else: + return self.calculate_add(variable_name, period) + elif variable.definition_period == YEAR and period.unit == MONTH: + if variable.quantity_type == QuantityType.STOCK: + return self.calculate(variable_name, period.this_year) + else: + return self.calculate_divide(variable_name, period) + + self._check_period_consistency(period, variable) + if variable.defined_for is not None: mask = ( self.calculate( @@ -607,10 +619,13 @@ def calculate_add( ) ) - return sum( + result = sum( self.calculate(variable_name, sub_period) for sub_period in period.get_subperiods(variable.definition_period) ) + holder = self.get_holder(variable.name) + holder.put_in_cache(result, period, self.branch_name) + return result def calculate_divide( self, @@ -640,9 +655,12 @@ def calculate_divide( if period.unit == periods.MONTH: computation_period = period.this_year - return ( + result = ( self.calculate(variable_name, period=computation_period) / 12.0 ) + holder = self.get_holder(variable.name) + holder.put_in_cache(result, period, self.branch_name) + return result elif period.unit == periods.YEAR: return self.calculate(variable_name, period) diff --git a/policyengine_core/variables/variable.py b/policyengine_core/variables/variable.py index 0dcb6d451..bbdca35d9 100644 --- a/policyengine_core/variables/variable.py +++ b/policyengine_core/variables/variable.py @@ -11,6 +11,10 @@ from policyengine_core.entities import Entity from policyengine_core.enums import Enum, EnumArray from policyengine_core.periods import Period +from policyengine_core.holders import ( + set_input_dispatch_by_period, + set_input_divide_by_period, +) from . import config, helpers @@ -176,13 +180,6 @@ def __init__(self, baseline_variable=None): periods.ETERNITY, ), ) - self.quantity_type = self.set( - attr, - "quantity_type", - required=False, - allowed_values=(QuantityType.STOCK, QuantityType.FLOW), - default=QuantityType.FLOW, - ) self.label = self.set( attr, "label", allowed_type=str, setter=self.set_label ) @@ -192,13 +189,29 @@ def __init__(self, baseline_variable=None): attr, "cerfa_field", allowed_type=(str, dict) ) self.unit = self.set(attr, "unit", allowed_type=str) + self.quantity_type = self.set( + attr, + "quantity_type", + required=False, + allowed_values=(QuantityType.STOCK, QuantityType.FLOW), + default=QuantityType.STOCK + if (self.value_type in (bool, int, Enum, str) or self.unit == "/1") + else QuantityType.FLOW, + ) self.documentation = self.set( attr, "documentation", allowed_type=str, setter=self.set_documentation, ) - self.set_input = self.set_set_input(attr.pop("set_input", None)) + self.set_input = self.set_set_input( + attr.pop( + "set_input", + set_input_dispatch_by_period + if self.quantity_type == QuantityType.STOCK + else set_input_divide_by_period, + ) + ) self.calculate_output = self.set_calculate_output( attr.pop("calculate_output", None) )