Skip to content

Commit

Permalink
clean up to run scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
jdebacker committed Aug 2, 2024
1 parent 516df70 commit 63b272c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 36 deletions.
31 changes: 7 additions & 24 deletions examples/run_og_zaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@
from ogcore.utils import safe_read_pickle, param_dump_json

# Use a custom matplotlib style file for plots
style_file_url = (
"https://raw.githubusercontent.com/PSLmodels/OG-Core/"
+ "master/ogcore/OGcorePlots.mplstyle"
)
plt.style.use(style_file_url)
plt.style.use("ogcore.OGcorePlots")


def main():
Expand All @@ -30,8 +26,9 @@ def main():

# Directories to save data
CUR_DIR = os.path.dirname(os.path.realpath(__file__))
base_dir = os.path.join(CUR_DIR, "OG-ZAF-Example", "OUTPUT_BASELINE")
reform_dir = os.path.join(CUR_DIR, "OG-ZAF-Example", "OUTPUT_REFORM")
save_dir = os.path.join(CUR_DIR, "OG-ZAF-Example")
base_dir = os.path.join(save_dir, "OUTPUT_BASELINE")
reform_dir = os.path.join(save_dir, "OUTPUT_REFORM")

"""
---------------------------------------------------------------------------
Expand All @@ -58,21 +55,7 @@ def main():
# Update parameters from calibrate.py Calibration class
c = Calibration(p)
updated_params = c.get_dict()
p.tax_func_type = "linear"
p.age_specific = False
p.update_specifications(updated_params)
# set underlying growth rate to zero, as value from data is negative
p.g_y = 0.0
# set tax rates
p.update_specifications(
{
"cit_rate": [[0.27]],
"etr_params": [[[0.22]]],
"mtrx_params": [[[0.31]]],
"mtry_params": [[[0.25]]],
"tau_c": [[0.15]],
}
)

# Run model
start_time = time.time()
Expand All @@ -90,7 +73,7 @@ def main():
p2.baseline = False
p2.output_base = reform_dir

# additional parameters to change
# Parameter change for the reform run
updated_params_ref = {
"cit_rate": [[0.30]],
}
Expand Down Expand Up @@ -128,12 +111,12 @@ def main():

# create plots of output
op.plot_all(
base_dir, reform_dir, os.path.join(CUR_DIR, "OG-ZAF_example_plots")
base_dir, reform_dir, os.path.join(save_dir, "OG-ZAF_example_plots")
)

print("Percentage changes in aggregates:", ans)
# save percentage change output to csv file
ans.to_csv("ogzaf_example_output.csv")
ans.to_csv(os.path.join(save_dir, "OG-ZAF_example_output.csv"))


if __name__ == "__main__":
Expand Down
28 changes: 16 additions & 12 deletions examples/run_og_zaf_multiple_industry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,29 @@
import json
import time
import copy
import numpy as np

# from taxcalc import Calculator
import matplotlib.pyplot as plt
from ogzaf.calibrate import Calibration
from ogcore.parameters import Specifications
from ogcore import output_tables as ot
from ogcore import output_plots as op
from ogcore.execute import runner
from ogcore.utils import safe_read_pickle, param_dump_json

# Use a custom matplotlib style file for plots
plt.style.use("ogcore.OGcorePlots")


def main():
# Define parameters to use for multiprocessing
client = Client()
num_workers = min(multiprocessing.cpu_count(), 7)
client = Client(n_workers=num_workers, threads_per_worker=1)
print("Number of workers = ", num_workers)

# Directories to save data
CUR_DIR = os.path.dirname(os.path.realpath(__file__))
base_dir = os.path.join(CUR_DIR, "OG-ZAF-CIT_Example", "OUTPUT_BASELINE")
reform_dir = os.path.join(CUR_DIR, "OG-ZAF-CIT_Example", "OUTPUT_REFORM")
save_dir = os.path.join(CUR_DIR, "OG-ZAF-MultipleIndustry-example")
base_dir = os.path.join(save_dir, "OUTPUT_BASELINE")
reform_dir = os.path.join(save_dir, "OUTPUT_REFORM")

"""
---------------------------------------------------------------------------
Expand Down Expand Up @@ -84,12 +86,14 @@ def main():
p2.baseline = False
p2.output_base = reform_dir

# additional parameters to change
# Example reform is a corp tax rate cut (phased in) for all
# industries EXCEPT for secondary ex energy, which has a one point
# increae in the CIT rate
updated_params_ref = {
"cit_rate": [
[0.28, 0.28, 0.28, 0.28],
[0.28, 0.28, 0.28, 0.28],
[0.27, 0.27, 0.27, 0.27],
[0.27, 0.27, 0.27, 0.28],
[0.26, 0.26, 0.26, 0.28],
[0.25, 0.25, 0.25, 0.28],
],
"baseline_spending": True,
}
Expand Down Expand Up @@ -129,12 +133,12 @@ def main():
op.plot_all(
base_dir,
reform_dir,
os.path.join(CUR_DIR, "OG-ZAF_CIT_multi_industry_plots"),
os.path.join(save_dir, "OG-ZAF-MultipleIndustry-example_plots"),
)

print("Percentage changes in aggregates:", ans)
# save percentage change output to csv file
ans.to_csv("ogzaf_CIT_multi_industry_output.csv")
ans.to_csv(os.path.join(save_dir, "OG-ZAF-MultipleIndustry-example_output.csv"))


if __name__ == "__main__":
Expand Down

0 comments on commit 63b272c

Please sign in to comment.