Skip to content

Commit

Permalink
add ability to reuse previously generated ensembles
Browse files Browse the repository at this point in the history
  • Loading branch information
formularin committed Feb 28, 2023
1 parent 7428934 commit 2890db3
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions rba/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def create_constraints(initial_partition, vra_config):
return all_constraints


def generate_ensemble(graph, edge_lifetimes, num_vra_districts, vra_threshold,
def generate_ensemble(graph, node_differences, num_vra_districts, vra_threshold,
pop_equality_threshold, num_steps, num_districts, initial_assignment=None,
output_dir=None, verbose=False):
"""Conduct the ensemble analysis for a state. Data is returned, but all partitions are saved
Expand All @@ -115,7 +115,7 @@ def generate_ensemble(graph, edge_lifetimes, num_vra_districts, vra_threshold,
----------
graph : gerrychain.Graph
The state graph of precincts.
edge_lifetimes : dict
node_differences : dict
Maps edges (tuples of precinct IDs) to their lifetime values.
num_vra_districts : dict
Maps the name of each minority to the minimum number of VRA districts required for it.
Expand Down Expand Up @@ -145,7 +145,7 @@ def generate_ensemble(graph, edge_lifetimes, num_vra_districts, vra_threshold,
Contains gerrymandering scores of the state and all the districts for each step in the
Markov Chain.
"""
rba_updaters = create_updaters(edge_lifetimes, num_vra_districts, vra_threshold)
rba_updaters = create_updaters(node_differences, num_vra_districts, vra_threshold)

state_population = 0
for node in graph:
Expand Down Expand Up @@ -253,11 +253,11 @@ def ensemble_analysis(graph_file, community_file, vra_config_file, num_steps, nu
with open(community_file, "r") as f:
community_data = json.load(f)

edge_lifetimes = {}
node_differences = {}
for edge, lifetime in community_data["edge_lifetimes"].items():
u = edge.split(",")[0][2:-1]
v = edge.split(",")[1][2:-2]
edge_lifetimes[(u, v)] = lifetime
node_differences[(u, v)] = lifetime

if verbose:
print("done!")
Expand Down Expand Up @@ -290,10 +290,28 @@ def ensemble_analysis(graph_file, community_file, vra_config_file, num_steps, nu
print("No starting map provided. Will generate a random one later.")
initial_assignment = None

scores_df = generate_ensemble(graph, edge_lifetimes, vra_config, vra_threshold,
scores_df = generate_ensemble(graph, node_differences, vra_config, vra_threshold,
constants.POP_EQUALITY_THRESHOLD, num_steps, num_districts,
initial_assignment, output_dir, verbose)

# If the ensemble has already been generated and only the scores, maps, and plots need to be
# regenerated from the saved plans, uncomment the block below and comment out the
# generate_ensemble call above.
# scores_df = pd.DataFrame(columns=[f"district {i}" for i in range(1, num_districts + 1)] + ["state_gerry_score"], dtype=float)
# print("Re-calculating scores from existing ensemble")
# for step in tqdm(range(num_steps)):
# with open(os.path.join(output_dir, "plans", f"{step + 1}.pickle"), "rb") as f:
# partition, _ = pickle.load(f)
# subgraphs = {part: graph.subgraph(partition.parts[part]) for part in partition.parts}
# district_scores, state_score = quantify_gerrymandering(
# graph,
# subgraphs,
# node_differences
# )
# districts_order = sorted(list(district_scores.keys()), key=lambda d: district_scores[d])
# scores_df.loc[len(scores_df.index)] = [district_scores[d] for d in districts_order] + [state_score]
# with open(os.path.join(output_dir, "plans", f"{step + 1}.pickle"), "wb") as f:
# pickle.dump((partition, districts_order), f)

scores_df.to_csv(os.path.join(output_dir, "scores.csv"))

create_folder(os.path.join(output_dir, "visuals"))
Expand Down Expand Up @@ -358,7 +376,7 @@ def ensemble_analysis(graph_file, community_file, vra_config_file, num_steps, nu

districts_precinct_df = pd.DataFrame(columns=["score", "homogeneity"], index=sorted_node_names)
district_node_sets = load_districts(graph, district_file, verbose)
district_scores, state_score = quantify_gerrymandering(graph, district_node_sets, edge_lifetimes, verbose)
district_scores, state_score = quantify_gerrymandering(graph, district_node_sets, node_differences, verbose)
for district, precincts in district_node_sets.items():
homogeneity = statistics.stdev(
[graph.nodes[node]["total_rep"] / graph.nodes[node]["total_votes"]
Expand Down

0 comments on commit 2890db3

Please sign in to comment.