Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for alternative policy baselines #106

Merged
merged 9 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions examples/run_og_usa_current_policy_baseline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import multiprocessing
from distributed import Client
import os
import requests
import json
import time
from taxcalc import Calculator
from ogusa.calibrate import Calibration
from ogcore.parameters import Specifications
from ogcore import output_tables as ot
from ogcore import output_plots as op
from ogcore.execute import runner
from ogcore.utils import safe_read_pickle


def _load_default_parameters(cur_dir):
    """
    Read the OG-USA default parameters JSON file into a dictionary.

    Args:
        cur_dir (str): directory containing this example script; the
            JSON file is looked up at
            ``<cur_dir>/../ogusa/ogusa_default_parameters.json``

    Returns:
        defaults (dict): parsed contents of the default parameters file
    """
    defaults_path = os.path.join(
        cur_dir, "..", "ogusa", "ogusa_default_parameters.json"
    )
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(defaults_path, encoding="utf-8") as f:
        return json.load(f)


def main():
    """
    Run an OG-USA baseline and reform simulation where the baseline is an
    alternative (non-current-law) policy read from Tax-Calculator's
    ``ext.json`` and the reform is Tax-Calculator's ``2017_law.json``.

    The function estimates tax functions for each policy, solves the
    model for both, then writes a table of percentage changes in macro
    aggregates to ``ogusa_example_output.csv`` (in the current working
    directory) and plots to ``OG-USA_example_plots`` next to this script.

    Returns:
        None
    """
    # Define parameters to use for multiprocessing
    num_workers = min(multiprocessing.cpu_count(), 7)
    client = Client(n_workers=num_workers, threads_per_worker=1)
    print("Number of workers = ", num_workers)

    # Directories to save data
    CUR_DIR = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(CUR_DIR, "OG-USA-CP-Example", "OUTPUT_BASELINE")
    reform_dir = os.path.join(CUR_DIR, "OG-USA-CP-Example", "OUTPUT_REFORM")

    # ------------------------------------------------------------------
    # Run baseline policy
    # ------------------------------------------------------------------
    # Set up baseline parameterization
    p = Specifications(
        baseline=True,
        num_workers=num_workers,
        baseline_dir=base_dir,
        output_base=base_dir,
    )
    # Update parameters for baseline from default json file
    p.update_specifications(_load_default_parameters(CUR_DIR))
    p.tax_func_type = "GS"
    # get current policy JSON file
    base_url = (
        "github://PSLmodels:Tax-Calculator@master/taxcalc/"
        + "reforms/ext.json"
    )
    ref = Calculator.read_json_param_objects(base_url, None)
    iit_baseline = ref["policy"]
    c = Calibration(
        p,
        estimate_tax_functions=True,
        iit_baseline=iit_baseline,
        client=client,
    )
    # close and delete client bc cache is too large
    client.close()
    del client
    client = Client(n_workers=num_workers, threads_per_worker=1)
    d = c.get_dict()
    # additional parameters to change
    updated_params = {
        "etr_params": d["etr_params"],
        "mtrx_params": d["mtrx_params"],
        "mtry_params": d["mtry_params"],
        "mean_income_data": d["mean_income_data"],
        "frac_tax_payroll": d["frac_tax_payroll"],
    }
    p.update_specifications(updated_params)
    # Run model
    start_time = time.time()
    runner(p, time_path=True, client=client)
    print("run time = ", time.time() - start_time)

    # ------------------------------------------------------------------
    # Run reform policy
    # ------------------------------------------------------------------
    # Grab a reform JSON file already in Tax-Calculator
    # In this example the 'reform' is a change to 2017 law
    reform_url = (
        "github://PSLmodels:Tax-Calculator@master/taxcalc/"
        + "reforms/2017_law.json"
    )
    ref = Calculator.read_json_param_objects(reform_url, None)
    iit_reform = ref["policy"]

    # create new Specifications object for reform simulation
    p2 = Specifications(
        baseline=False,
        num_workers=num_workers,
        baseline_dir=base_dir,
        output_base=reform_dir,
    )
    # Update parameters for reform from default json file (re-read from
    # disk so the reform gets its own fresh dictionary)
    p2.update_specifications(_load_default_parameters(CUR_DIR))
    p2.tax_func_type = "GS"
    # Use calibration class to estimate reform tax functions from
    # Tax-Calculator, specifying reform for Tax-Calculator in iit_reform.
    # The alternative baseline is passed as well so the reform is
    # estimated relative to the same baseline policy used above.
    c2 = Calibration(
        p2,
        iit_baseline=iit_baseline,
        iit_reform=iit_reform,
        estimate_tax_functions=True,
        client=client,
    )
    # close and delete client bc cache is too large
    client.close()
    del client
    client = Client(n_workers=num_workers, threads_per_worker=1)
    # update tax function parameters in Specifications Object
    d = c2.get_dict()
    # additional parameters to change
    updated_params = {
        "cit_rate": [[0.35]],
        "etr_params": d["etr_params"],
        "mtrx_params": d["mtrx_params"],
        "mtry_params": d["mtry_params"],
        "mean_income_data": d["mean_income_data"],
        "frac_tax_payroll": d["frac_tax_payroll"],
    }
    p2.update_specifications(updated_params)
    # Run model
    start_time = time.time()
    runner(p2, time_path=True, client=client)
    print("run time = ", time.time() - start_time)
    client.close()

    # ------------------------------------------------------------------
    # Save some results of simulations
    # ------------------------------------------------------------------
    base_tpi = safe_read_pickle(os.path.join(base_dir, "TPI", "TPI_vars.pkl"))
    base_params = safe_read_pickle(os.path.join(base_dir, "model_params.pkl"))
    reform_tpi = safe_read_pickle(
        os.path.join(reform_dir, "TPI", "TPI_vars.pkl")
    )
    reform_params = safe_read_pickle(
        os.path.join(reform_dir, "model_params.pkl")
    )
    ans = ot.macro_table(
        base_tpi,
        base_params,
        reform_tpi=reform_tpi,
        reform_params=reform_params,
        var_list=["Y", "C", "K", "L", "r", "w"],
        output_type="pct_diff",
        num_years=10,
        start_year=base_params.start_year,
    )

    # create plots of output
    op.plot_all(
        base_dir, reform_dir, os.path.join(CUR_DIR, "OG-USA_example_plots")
    )

    print("Percentage changes in aggregates:", ans)
    # save percentage change output to csv file
    ans.to_csv("ogusa_example_output.csv")


if __name__ == "__main__":
    # Script entry point: run the example only when this file is executed
    # directly, not when it is imported as a module.
    main()
36 changes: 32 additions & 4 deletions ogusa/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,34 @@ def __init__(
estimate_chi_n=False,
estimate_pop=False,
tax_func_path=None,
iit_baseline=None,
iit_reform={},
guid="",
data="cps",
client=None,
num_workers=1,
):
"""
Constructor for the Calibration class. This class is used to find
parameter values for the OG-USA model.

Args:
p (OGUSA Parameters object): parameters object
estimate_tax_functions (bool): whether to estimate tax functions
estimate_beta (bool): whether to estimate beta
estimate_chi_n (bool): whether to estimate chi_n
estimate_pop (bool): whether to estimate population
tax_func_path (str): path to tax function parameters
iit_baseline (dict): baseline policy to use
iit_reform (dict): reform tax parameters
guid (str): id for tax function parameters
data (str): data source for microsimulation model
client (Dask client object): client
num_workers (int): number of workers for Dask client

Returns:
Calibration class object instance
"""
self.estimate_tax_functions = estimate_tax_functions
self.estimate_beta = estimate_beta
self.estimate_chi_n = estimate_chi_n
Expand All @@ -36,6 +58,7 @@ def __init__(
run_micro = True
self.tax_function_params = self.get_tax_function_parameters(
p,
iit_baseline,
iit_reform,
guid,
data,
Expand Down Expand Up @@ -104,6 +127,7 @@ def __init__(
def get_tax_function_parameters(
self,
p,
iit_baseline=None,
iit_reform={},
guid="",
data="",
Expand All @@ -117,7 +141,13 @@ def get_tax_function_parameters(
parameters from microsimulation model output.

Args:
p (OG-Core Parameters object): parameters object
iit_baseline (dict): baseline policy to use
iit_reform (dict): reform tax parameters
guid (string): id for tax function parameters
data (string): data source for microsimulation model
client (Dask client object): client
num_workers (int): number of workers for Dask client
run_micro (bool): whether to estimate parameters from
microsimulation model
tax_func_path (string): path where find or save tax
Expand Down Expand Up @@ -152,7 +182,8 @@ def get_tax_function_parameters(
micro_data, taxcalc_version = get_micro_data.get_data(
baseline=p.baseline,
start_year=p.start_year,
reform=iit_reform,
iit_baseline=iit_baseline,
iit_reform=iit_reform,
data=data,
path=p.output_base,
client=client,
Expand All @@ -166,12 +197,9 @@ def get_tax_function_parameters(
p.starting_age,
p.ending_age,
start_year=p.start_year,
baseline=p.baseline,
analytical_mtrs=p.analytical_mtrs,
tax_func_type=p.tax_func_type,
age_specific=p.age_specific,
reform=iit_reform,
data=data,
client=client,
num_workers=num_workers,
tax_func_path=tax_func_path,
Expand Down
45 changes: 31 additions & 14 deletions ogusa/get_micro_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
def get_calculator(
baseline,
calculator_start_year,
reform=None,
iit_baseline=None,
iit_reform=None,
data=None,
gfactors=None,
weights=None,
Expand Down Expand Up @@ -72,17 +73,21 @@ def get_calculator(
records1 = Records() # pragma: no cover

if baseline:
if not reform:
if iit_baseline is None:
print("Running current law policy baseline")
else:
print("Baseline policy is: ", reform)
print("Baseline policy is: ", iit_baseline)
policy1.implement_reform(iit_baseline)
else:
if not reform:
if not iit_reform:
print("Running with current law as reform")
else:
print("Reform policy is: ", reform)
print("TYPE", type(reform))
policy1.implement_reform(reform)
print("Reform policy is: ", iit_reform)
if (
iit_baseline is not None
): # if alt baseline, stack reform on that
policy1.implement_reform(iit_baseline)
policy1.implement_reform(iit_reform)

# the default set up increments year to 2013
calc1 = Calculator(records=records1, policy=policy1)
Expand All @@ -97,7 +102,8 @@ def get_calculator(
def get_data(
baseline=False,
start_year=DEFAULT_START_YEAR,
reform={},
iit_baseline=None,
iit_reform={},
data=None,
path=CUR_PATH,
client=None,
Expand All @@ -112,7 +118,8 @@ def get_data(
Args:
baseline (boolean): True if baseline tax policy
calculator_start_year (int): first year of budget window
reform (dictionary): IIT policy reform parameters, None if
iit_baseline (dictionary): IIT policy parameters for baseline
iit_reform (dictionary): IIT policy reform parameters, None if
baseline
data (DataFrame or str): DataFrame or path to datafile for
Records object
Expand All @@ -132,7 +139,9 @@ def get_data(
lazy_values = []
for year in range(start_year, TC_LAST_YEAR + 1):
lazy_values.append(
delayed(taxcalc_advance)(baseline, start_year, reform, data, year)
delayed(taxcalc_advance)(
baseline, start_year, iit_baseline, iit_reform, data, year
)
)
if client: # pragma: no cover
futures = client.compute(lazy_values, num_workers=num_workers)
Expand Down Expand Up @@ -167,14 +176,21 @@ def get_data(
return micro_data_dict, taxcalc_version


def taxcalc_advance(baseline, start_year, reform, data, year):
def taxcalc_advance(
baseline, start_year, iit_baseline, iit_reform, data, year
):
"""
This function advances the year used in Tax-Calculator, compute
taxes and rates, and save the results to a dictionary.

Args:
calc1 (Tax-Calculator Calculator object): TC calculator
year (int): year to begin advancing from
baseline (boolean): True if baseline tax policy
start_year (int): first year of budget window
iit_baseline (dict): IIT policy parameters for baseline
iit_reform (dict): IIT policy reform parameters for reform
data (DataFrame or str): DataFrame or path to datafile for
Records object
year (int): year to advance to in Tax-Calculator

Returns:
tax_dict (dict): a dictionary of microdata with marginal tax
Expand All @@ -183,7 +199,8 @@ def taxcalc_advance(baseline, start_year, reform, data, year):
calc1 = get_calculator(
baseline=baseline,
calculator_start_year=start_year,
reform=reform,
iit_baseline=iit_baseline,
iit_reform=iit_reform,
data=data,
)
calc1.advance_to_year(year)
Expand Down
Loading
Loading