Skip to content

Commit

Permalink
Add automatic period adjustment
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhilwoodruff committed Oct 4, 2023
1 parent ffcf4a5 commit 9731667
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 39 deletions.
29 changes: 4 additions & 25 deletions policyengine_core/holders/holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,10 @@ def set_input(
return warnings.warn(warning_message, Warning)
if self.variable.value_type in (float, int) and isinstance(array, str):
array = tools.eval_expression(array)
if self.variable.set_input:
if (
self.variable.set_input
and period.unit != self.variable.definition_period
):
return self.variable.set_input(self, period, array)
return self._set(period, array, branch_name)

Expand Down Expand Up @@ -285,30 +288,6 @@ def _set(
raise ValueError(
"A period must be specified to set values, except for variables with periods.ETERNITY as as period_definition."
)
if (
self.variable.definition_period != period.unit
or period.size > 1
):
name = self.variable.name
period_size_adj = (
f"{period.unit}"
if (period.size == 1)
else f"{period.size}-{period.unit}s"
)
error_message = os.linesep.join(
[
f'Unable to set a value for variable "{name}" for {period_size_adj}-long period "{period}".',
f'"{name}" can only be set for one {self.variable.definition_period} at a time. Please adapt your input.',
f'If you are the maintainer of "{name}", you can consider adding it a set_input attribute to enable automatic period casting.',
]
)

raise PeriodMismatchError(
self.variable.name,
period,
self.variable.definition_period,
error_message,
)

should_store_on_disk = (
self._on_disk_storable
Expand Down
30 changes: 24 additions & 6 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from policyengine_core.errors import CycleError, SpiralError
from policyengine_core.holders.holder import Holder
from policyengine_core.periods import Period
from policyengine_core.periods.config import ETERNITY
from policyengine_core.periods.config import ETERNITY, MONTH, YEAR
from policyengine_core.periods.helpers import period
from policyengine_core.tracers import (
FullTracer,
Expand All @@ -27,7 +27,7 @@
from policyengine_core.experimental import MemoryConfig
from policyengine_core.populations import Population
from policyengine_core.tracers import SimpleTracer
from policyengine_core.variables import Variable
from policyengine_core.variables import Variable, QuantityType
from policyengine_core.reforms.reform import Reform
from policyengine_core.parameters import get_parameter

Expand Down Expand Up @@ -454,8 +454,6 @@ def _calculate(
variable_name, check_existence=True
)

self._check_period_consistency(period, variable)

# Check if we've neutralized via parameters.
try:
if self.tax_benefit_system.parameters(period).gov.abolitions[
Expand All @@ -470,6 +468,20 @@ def _calculate(
if cached_array is not None:
return cached_array

if variable.definition_period == MONTH and period.unit == YEAR:
if variable.quantity_type == QuantityType.STOCK:
contained_months = period.get_subperiods(MONTH)
return self.calculate(variable_name, contained_months[-1])
else:
return self.calculate_add(variable_name, period)
elif variable.definition_period == YEAR and period.unit == MONTH:
if variable.quantity_type == QuantityType.STOCK:
return self.calculate(variable_name, period.this_year)
else:
return self.calculate_divide(variable_name, period)

self._check_period_consistency(period, variable)

if variable.defined_for is not None:
mask = (
self.calculate(
Expand Down Expand Up @@ -607,10 +619,13 @@ def calculate_add(
)
)

return sum(
result = sum(
self.calculate(variable_name, sub_period)
for sub_period in period.get_subperiods(variable.definition_period)
)
holder = self.get_holder(variable.name)
holder.put_in_cache(result, period, self.branch_name)
return result

def calculate_divide(
self,
Expand Down Expand Up @@ -640,9 +655,12 @@ def calculate_divide(

if period.unit == periods.MONTH:
computation_period = period.this_year
return (
result = (
self.calculate(variable_name, period=computation_period) / 12.0
)
holder = self.get_holder(variable.name)
holder.put_in_cache(result, period, self.branch_name)
return result
elif period.unit == periods.YEAR:
return self.calculate(variable_name, period)

Expand Down
29 changes: 21 additions & 8 deletions policyengine_core/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from policyengine_core.entities import Entity
from policyengine_core.enums import Enum, EnumArray
from policyengine_core.periods import Period
from policyengine_core.holders import (
set_input_dispatch_by_period,
set_input_divide_by_period,
)

from . import config, helpers

Expand Down Expand Up @@ -176,13 +180,6 @@ def __init__(self, baseline_variable=None):
periods.ETERNITY,
),
)
self.quantity_type = self.set(
attr,
"quantity_type",
required=False,
allowed_values=(QuantityType.STOCK, QuantityType.FLOW),
default=QuantityType.FLOW,
)
self.label = self.set(
attr, "label", allowed_type=str, setter=self.set_label
)
Expand All @@ -192,13 +189,29 @@ def __init__(self, baseline_variable=None):
attr, "cerfa_field", allowed_type=(str, dict)
)
self.unit = self.set(attr, "unit", allowed_type=str)
self.quantity_type = self.set(
attr,
"quantity_type",
required=False,
allowed_values=(QuantityType.STOCK, QuantityType.FLOW),
default=QuantityType.STOCK
if (self.value_type in (bool, int, Enum, str) or self.unit == "/1")
else QuantityType.FLOW,
)
self.documentation = self.set(
attr,
"documentation",
allowed_type=str,
setter=self.set_documentation,
)
self.set_input = self.set_set_input(attr.pop("set_input", None))
self.set_input = self.set_set_input(
attr.pop(
"set_input",
set_input_dispatch_by_period
if self.quantity_type == QuantityType.STOCK
else set_input_divide_by_period,
)
)
self.calculate_output = self.set_calculate_output(
attr.pop("calculate_output", None)
)
Expand Down

0 comments on commit 9731667

Please sign in to comment.