Skip to content

Commit

Permalink
style the plots
Browse files Browse the repository at this point in the history
  • Loading branch information
jdebacker committed May 16, 2024
1 parent 6383224 commit 78b5e75
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 96 deletions.
200 changes: 104 additions & 96 deletions examples/run_og_usa.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,21 @@
import json
import time
from taxcalc import Calculator
import matplotlib.pyplot as plt
from ogusa.calibrate import Calibration
from ogcore.parameters import Specifications
from ogcore import output_tables as ot
from ogcore import output_plots as op
from ogcore.execute import runner
from ogcore.utils import safe_read_pickle

# Use a custom matplotlib style file for plots
style_file_url = (
"https://raw.githubusercontent.com/PSLmodels/OG-Core/"
+ "master/ogcore/OGcorePlots.mplstyle"
)
plt.style.use(style_file_url)


def main():
# Define parameters to use for multiprocessing
Expand All @@ -29,104 +37,104 @@ def main():
------------------------------------------------------------------------
"""
# Set up baseline parameterization
p = Specifications(
baseline=True,
num_workers=num_workers,
baseline_dir=base_dir,
output_base=base_dir,
)
# Update parameters for baseline from default json file
p.update_specifications(
json.load(
open(
os.path.join(
CUR_DIR, "..", "ogusa", "ogusa_default_parameters.json"
)
)
)
)
p.tax_func_type = "GS"
p.age_specific = False
c = Calibration(p, estimate_tax_functions=True, client=client)
# close and delete client bc cache is too large
client.close()
del client
client = Client(n_workers=num_workers, threads_per_worker=1)
d = c.get_dict()
# # additional parameters to change
updated_params = {
"etr_params": d["etr_params"],
"mtrx_params": d["mtrx_params"],
"mtry_params": d["mtry_params"],
"mean_income_data": d["mean_income_data"],
"frac_tax_payroll": d["frac_tax_payroll"],
}
p.update_specifications(updated_params)
# Run model
start_time = time.time()
runner(p, time_path=True, client=client)
print("run time = ", time.time() - start_time)
# p = Specifications(
# baseline=True,
# num_workers=num_workers,
# baseline_dir=base_dir,
# output_base=base_dir,
# )
# # Update parameters for baseline from default json file
# p.update_specifications(
# json.load(
# open(
# os.path.join(
# CUR_DIR, "..", "ogusa", "ogusa_default_parameters.json"
# )
# )
# )
# )
# p.tax_func_type = "GS"
# p.age_specific = False
# c = Calibration(p, estimate_tax_functions=True, client=client)
# # close and delete client bc cache is too large
# client.close()
# del client
# client = Client(n_workers=num_workers, threads_per_worker=1)
# d = c.get_dict()
# # # additional parameters to change
# updated_params = {
# "etr_params": d["etr_params"],
# "mtrx_params": d["mtrx_params"],
# "mtry_params": d["mtry_params"],
# "mean_income_data": d["mean_income_data"],
# "frac_tax_payroll": d["frac_tax_payroll"],
# }
# p.update_specifications(updated_params)
# # Run model
# start_time = time.time()
# runner(p, time_path=True, client=client)
# print("run time = ", time.time() - start_time)

"""
------------------------------------------------------------------------
Run reform policy
------------------------------------------------------------------------
"""
# Grab a reform JSON file already in Tax-Calculator
# In this example the 'reform' is a change to 2017 law (the
# baseline policy is tax law in 2018)
reform_url = (
"github://PSLmodels:examples@main/psl_examples/"
+ "taxcalc/2017_law.json"
)
ref = Calculator.read_json_param_objects(reform_url, None)
iit_reform = ref["policy"]
# """
# ------------------------------------------------------------------------
# Run reform policy
# ------------------------------------------------------------------------
# """
# # Grab a reform JSON file already in Tax-Calculator
# # In this example the 'reform' is a change to 2017 law (the
# # baseline policy is tax law in 2018)
# reform_url = (
# "github://PSLmodels:examples@main/psl_examples/"
# + "taxcalc/2017_law.json"
# )
# ref = Calculator.read_json_param_objects(reform_url, None)
# iit_reform = ref["policy"]

# create new Specifications object for reform simulation
p2 = Specifications(
baseline=False,
num_workers=num_workers,
baseline_dir=base_dir,
output_base=reform_dir,
)
# Update parameters for baseline from default json file
p2.update_specifications(
json.load(
open(
os.path.join(
CUR_DIR, "..", "ogusa", "ogusa_default_parameters.json"
)
)
)
)
p2.tax_func_type = "GS"
p2.age_specific = False
# Use calibration class to estimate reform tax functions from
# Tax-Calculator, specifying reform for Tax-Calculator in iit_reform
c2 = Calibration(
p2, iit_reform=iit_reform, estimate_tax_functions=True, client=client
)
# close and delete client bc cache is too large
client.close()
del client
client = Client(n_workers=num_workers, threads_per_worker=1)
# update tax function parameters in Specifications Object
d = c2.get_dict()
# # additional parameters to change
updated_params = {
"cit_rate": [[0.35]],
"etr_params": d["etr_params"],
"mtrx_params": d["mtrx_params"],
"mtry_params": d["mtry_params"],
"mean_income_data": d["mean_income_data"],
"frac_tax_payroll": d["frac_tax_payroll"],
}
p2.update_specifications(updated_params)
# Run model
start_time = time.time()
runner(p2, time_path=True, client=client)
print("run time = ", time.time() - start_time)
client.close()
# # create new Specifications object for reform simulation
# p2 = Specifications(
# baseline=False,
# num_workers=num_workers,
# baseline_dir=base_dir,
# output_base=reform_dir,
# )
# # Update parameters for baseline from default json file
# p2.update_specifications(
# json.load(
# open(
# os.path.join(
# CUR_DIR, "..", "ogusa", "ogusa_default_parameters.json"
# )
# )
# )
# )
# p2.tax_func_type = "GS"
# p2.age_specific = False
# # Use calibration class to estimate reform tax functions from
# # Tax-Calculator, specifying reform for Tax-Calculator in iit_reform
# c2 = Calibration(
# p2, iit_reform=iit_reform, estimate_tax_functions=True, client=client
# )
# # close and delete client bc cache is too large
# client.close()
# del client
# client = Client(n_workers=num_workers, threads_per_worker=1)
# # update tax function parameters in Specifications Object
# d = c2.get_dict()
# # # additional parameters to change
# updated_params = {
# "cit_rate": [[0.35]],
# "etr_params": d["etr_params"],
# "mtrx_params": d["mtrx_params"],
# "mtry_params": d["mtry_params"],
# "mean_income_data": d["mean_income_data"],
# "frac_tax_payroll": d["frac_tax_payroll"],
# }
# p2.update_specifications(updated_params)
# # Run model
# start_time = time.time()
# runner(p2, time_path=True, client=client)
# print("run time = ", time.time() - start_time)
# client.close()

"""
------------------------------------------------------------------------
Expand Down
9 changes: 9 additions & 0 deletions examples/run_og_usa_current_policy_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import time
from taxcalc import Calculator
import matplotlib.pyplot as plt
from ogusa.calibrate import Calibration
from ogcore.parameters import Specifications
from ogcore import output_tables as ot
Expand All @@ -13,6 +14,14 @@
from ogcore.utils import safe_read_pickle


# Use a custom matplotlib style file for plots
style_file_url = (
"https://raw.githubusercontent.com/PSLmodels/OG-Core/"
+ "master/ogcore/OGcorePlots.mplstyle"
)
plt.style.use(style_file_url)


def main():
# Define parameters to use for multiprocessing
num_workers = min(multiprocessing.cpu_count(), 7)
Expand Down

0 comments on commit 78b5e75

Please sign in to comment.