Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Visualization: more plotting options and handling of missing .shx file #85

Merged
merged 2 commits into from
Oct 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 33 additions & 23 deletions utilities/plotMovie/generate_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
ffmpeg -framerate 1 -r 30 -i frames/frame%05d.png -pix_fmt yuv420p movie.mp4
Can adjust inputs directly at bottom of file or as command-line input:

python generate_frames.py [sim results dir] [.shx dir] [output dir (optional)]
python generate_frames.py [sim results dir] [.shp dir] [output dir] [plotted element]

The quantity being plotted can be altered in the get_raw() helpers defined
inside get_raw_data() and get_raw_data_hdf5(), e.g. to plot cumulative
infections, proportions of infections, or raw counts.

Endpoints of color range can be set in or passed into generate_plot,
or in main function, as necessary.
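The plotted element (default: infected) can be any of the data fields in the
county-level data: infected, never_infected, susceptible, or immune.
Alternatively, pass deaths to plot the number of deaths, which is computed as
total minus infected, never_infected, immune, and susceptible.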

Depending on the choice, adjusting the color scale may be advisable. This can
be done by changing the vmin and vmax values passed to generate_plot() in the
frame loop at the bottom of the file (around line 198).

Since US census tracts include territories around the globe, some sort of
cropping is necessary if we use this data (e.g. to the 48-state mainland).
Expand Down Expand Up @@ -87,26 +91,27 @@ def get_raw(county):
ds.close()
return raw_df

def get_raw_data_hdf5(name: str):
def get_raw_data_hdf5(name: str, plot_option: str = "infected"):
f = h5py.File(name, 'r')
found = 0
i = 0
while found < 2:
if f.attrs['component_' + str(i)] == b'FIPS':
fips_idx = i
found += 1
if f.attrs['component_' + str(i)] == b'infected':
inf_idx = i
found += 1
i += 1

fips = f['level_0']['data:datatype=' + str(fips_idx)][()]
infs = f['level_0']['data:datatype=' + str(inf_idx)][()]
comm_indices = {}
for i in range(f.attrs['num_components'][0]):
comm_indices[f.attrs['component_' + str(i)]] = str(i)

fips = f['level_0']['data:datatype=' + comm_indices[b'FIPS']][()]

if plot_option == "deaths":
plts = (f['level_0']['data:datatype=' + comm_indices[b'total']][()]
- f['level_0']['data:datatype=' + comm_indices[b'infected']][()]
- f['level_0']['data:datatype=' + comm_indices[b'never_infected']][()]
- f['level_0']['data:datatype=' + comm_indices[b'immune']][()]
- f['level_0']['data:datatype=' + comm_indices[b'susceptible']][()])
else:
plts = f['level_0']['data:datatype=' + comm_indices[bytes(plot_option, "utf-8")]][()]
unique_fips = np.unique(fips).astype(int)

def get_raw(county):
mask = fips == county
return np.log(1 + infs[mask].sum())
return np.log(1 + plts[mask].sum())

raw_df = pd.DataFrame()
raw_df["FIPS"] = unique_fips
Expand All @@ -124,10 +129,12 @@ def get_gdf(prefix: str):
# can be done by running code inside a
# with fiona.Env(SHAPE_RESTORE_SHX = "YES"):
# block (remember to import fiona explicitly)
gdf = gpd.read_file(prefix + ".shp", driver="esri")

# with fiona.Env(SHAPE_RESTORE_SHX = "YES"):
# gdf = gpd.read_file(prefix + ".shp", driver="esri")
if os.path.isfile(prefix + ".shx"):
gdf = gpd.read_file(prefix + ".shp", driver="esri")
else:
import fiona
with fiona.Env(SHAPE_RESTORE_SHX = "YES"):
gdf = gpd.read_file(prefix + ".shp", driver="esri")

cols = list(gdf.columns)

Expand Down Expand Up @@ -179,9 +186,12 @@ def generate_plot(per_df, gdf, vmin = None, vmax = None, crop_usa = False):
crop_usa = "_us_" in prefix

output_dir = sys.argv[3] if argc > 3 else "./frames_usa/"

plot_option = sys.argv[4] if argc > 4 else "infected"

for i in range(len(data_names)):
# vmin and vmax are the endpoints of the color range for the log-scaled counts;
# since 16 > log(population of LA), 16 is a safe upper bound.
# For per-capita values, the endpoints should be set much lower.
fig = generate_plot(get_raw_data_hdf5(data_names[i]), gdf, vmin=0, vmax=16, crop_usa = crop_usa)
fig = generate_plot(get_raw_data_hdf5(data_names[i], plot_option), gdf, vmin=0, vmax=16, crop_usa = crop_usa)
fig.savefig(output_dir + "frame{:05d}".format(i))
plt.close(fig)
Loading