Skip to content

Commit

Permalink
update community visualization for new json format and change optimiz…
Browse files Browse the repository at this point in the history
…ation to save the 10 best partitions
  • Loading branch information
formularin committed Mar 1, 2023
1 parent f5a566f commit 1b681b3
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 29 deletions.
6 changes: 3 additions & 3 deletions rba/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,9 @@ def ensemble_analysis(graph_file, difference_file, vra_config_file, num_steps, n
for precinct in precincts:
districts_precinct_df.loc[precinct] = [district_scores[district], homogeneity]

if optimize_vis:
output_dir = vis_dir

# Save a histogram of statewide scores.
plt.hist(scores_df["state_gerry_score"], bins=30)
plt.axvline(scores_df["state_gerry_score"].mean(), color='k', linestyle='dashed', linewidth=1)
Expand All @@ -462,9 +465,6 @@ def get_z_score(precinct, metric):
districts_assignment[node] = district
districts_partition = Partition(graph, assignment=districts_assignment)

# TODO: this doesn't work with Maryland for some reason
if optimize_vis:
output_dir = vis_dir
_, ax = plt.subplots(figsize=(12.8, 9.6))
visualize_gradient_geopandas(
sorted_node_names,
Expand Down
20 changes: 12 additions & 8 deletions rba/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def __next__(self):

@dataclass(order=True)
class ScoredPartition:
"""A comparable class for storing partitions and they gerrymandering scores."""
score: float
"""A comparable class for storing partitions and (1 - difference score)."""
goodness_score: float
partition: Partition=field(compare=False)


Expand Down Expand Up @@ -175,6 +175,8 @@ def generate_districts_simulated_annealing(graph, differences, num_vra_districts
# )
)

# TODO: save every partition

restarted = False
while True:
try:
Expand Down Expand Up @@ -242,7 +244,7 @@ def generate_districts_simulated_annealing(graph, differences, num_vra_districts
if verbose:
print("Running Markov chain...")

good_partitions = [] # min heap based on "goodness" score
good_partitions = [] # min heap based on (1 - difference score)
if verbose:
chain_iter = chain.with_progress_bar()
else:
Expand All @@ -256,14 +258,16 @@ def generate_districts_simulated_annealing(graph, differences, num_vra_districts
if i < 10:
heapq.heappush(
good_partitions,
ScoredPartition(score=state_score, partition=partition)
ScoredPartition(goodness_score=(1 - state_score), partition=partition)
)
elif state_score > good_partitions[0].score: # better than the worst good score.
elif (1 - state_score) > good_partitions[0].goodness_score: # better than the worst good score.
heapq.heapreplace(
good_partitions,
ScoredPartition(score=state_score, partition=partition)
ScoredPartition(goodness_score=(1 - state_score), partition=partition)
)



if verbose:
chain_iter.set_description(f"State score: {round(state_score, 4)}")

Expand Down Expand Up @@ -349,8 +353,8 @@ def optimize(graph_file, communitygen_out_file, vra_config_file, num_steps, num_
except FileExistsError:
pass

# Save districts in order of decreasing goodness.
for i, partition in enumerate(sorted(plans, key=lambda p: p["gerry_scores"][1], reverse=True)):
# Save districts in order of increasing gerrymandering score.
for i, partition in enumerate(sorted(plans, key=lambda p: p["gerry_scores"][1])):
# save_assignment(partition, os.path.join(output_dir, f"Plan_{i + 1}.json"))
with open(os.path.join(output_dir, f"Plan_{i + 1}.json"), "w+") as f:
json.dump({part: list(nodes) for part, nodes in partition.parts.items()}, f)
Expand Down
47 changes: 29 additions & 18 deletions rba/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,22 +404,18 @@ def visualize_community_generation(difference_fpath, output_fpath, graph, num_fr
with open(difference_fpath, "r") as f:
supercommunity_output = json.load(f) # Contains strings as keys.

differences = {}
edge_lifetimes = {}
for edge, lifetime in supercommunity_output.items():
u = edge.split(",")[0][2:-1]
v = edge.split(",")[1][2:-2]
# print(edge, (u, v), type(u), type(v))
if u == '19001001022' and v =='19001001021':
print("DETECTED", type((u,v)), (u,v))
if ((u,v) == ('19001001022', '19001001021')):
print("should be added!")
differences[frozenset((u, v))] = lifetime
if u in graph[v]:
edge_lifetimes[frozenset((u, v))] = lifetime
print("Done!")

max_lt = max(differences.values())
min_lt = min(differences.values())
max_lt = max(edge_lifetimes.values())
min_lt = min(edge_lifetimes.values())
edge_widths = {
edge: int((lt - min_lt) / max_lt * EDGE_WIDTH_FACTOR) + 1 for edge, lt in differences.items()
edge: int((lt - min_lt) / max_lt * EDGE_WIDTH_FACTOR) + 1 for edge, lt in edge_lifetimes.items()
}

# node_colors = {
Expand All @@ -437,12 +433,11 @@ def visualize_community_generation(difference_fpath, output_fpath, graph, num_fr
pass


living_edges = set(frozenset(e) for e in graph.edges)
# unrendered_contractions = [frozenset(supercommunity_output[e]) for e in graph.edges] # Not a set because order must be preserved.
unrendered_contractions = [] # Not a set because order must be preserved.
for edge in graph.edges:
# print(edge[0], edge[1], differences[edge], "edge")
unrendered_contractions.append((edge[0], edge[1], differences[frozenset(edge)]))
unrendered_contractions.append((edge[0], edge[1], edge_lifetimes[frozenset(edge)]))
unrendered_contractions = sorted(unrendered_contractions, key=lambda x: x[2])
community_graph = util.copy_adjacency(graph)
for edge in community_graph.edges:
Expand Down Expand Up @@ -470,8 +465,8 @@ def visualize_community_generation(difference_fpath, output_fpath, graph, num_fr
sys.stdout.flush()
t = (f - 1) / (num_frames - 1)
edge_colors = {}
for u, v in living_edges:
if differences[frozenset((u, v))] < t:
for u, v in graph.edges:
if edge_lifetimes[frozenset((u, v))] < t:
if graph.nodes[u]["partition"] != graph.nodes[v]["partition"]:
edge_colors[frozenset((u, v))] = (156, 156, 255)
else:
Expand All @@ -483,16 +478,33 @@ def visualize_community_generation(difference_fpath, output_fpath, graph, num_fr
edge_colors[frozenset((u, v))] = (0, 0, 0)

this_iter_contractions = set()
for c1, c2, time in unrendered_contractions:
for u, v, time in unrendered_contractions:
# if (c1, c2, time) in this_iter_contractions:
# continue
# if (c2, c1, time) in this_iter_contractions:
# continue
if time < t:
this_iter_contractions.add((u, v, time))

c1 = None
c2 = None
skip = False
for community in community_graph.nodes:
if u in community_graph.nodes[community]["constituent_nodes"]:
if v in community_graph.nodes[community]["constituent_nodes"]:
skip = True
else:
c1 = community
for second_community in community_graph.nodes:
if v in community_graph.nodes[second_community]["constituent_nodes"]:
c2 = second_community
break
if skip:
continue

# for neighbor in community_graph.neighbors(c2):
# this_iter_contractions.add((neighbor, c2, differences[frozenset((neighbor, c2))]))
# Update graph
print(c1, c2, time, this_iter_contractions)
for neighbor in community_graph[c2]:
if neighbor == c1:
continue
Expand All @@ -519,7 +531,6 @@ def visualize_community_generation(difference_fpath, output_fpath, graph, num_fr
# node_colors[node] = get_partisanship_color(total_rep / total_votes)

# Delete c2
this_iter_contractions.add((c1, c2, time))
community_graph.remove_node(c2)
# del node_coords[c2]
# del node_colors[c2]
Expand Down Expand Up @@ -641,7 +652,7 @@ def visualize_graph(graph, output_path, coords, colors=None, edge_colors=None, n
draw.line(
centers,
fill=edge_color,
width=1
width=5
)

print("Edges drawn")
Expand Down

0 comments on commit 1b681b3

Please sign in to comment.