diff --git a/rba/__main__.py b/rba/__main__.py index 4709db8..3fef683 100644 --- a/rba/__main__.py +++ b/rba/__main__.py @@ -56,7 +56,11 @@ ensemble_parser.add_argument("--graph_file", type=str, default=os.path.join(package_dir, "data/2010/new_hampshire_geodata_merged.json")) ensemble_parser.add_argument("--community_file", type=str, default=os.path.join(package_dir, "data/2010/new_hampshire_communities.json")) ensemble_parser.add_argument("--vra_config_file", type=str, default=os.path.join(package_dir, "data/2010/vra_nh.json")) + ensemble_parser.add_argument("--num_steps", type=int, default=100) + ensemble_parser.add_argument("--num_districts", type=int, default=2) + ensemble_parser.add_argument("--initial_plan_file", type=str, default=None) ensemble_parser.add_argument("--district_file", type=str, default=os.path.join(package_dir, "data/2010/new_hampshire_districts.json")) + ensemble_parser.add_argument("-o", "--output_dir", type=str) ensemble_parser.set_defaults(func=rba.ensemble.ensemble_analysis) optimize_parser = subparsers.add_parser("optimize") diff --git a/rba/data/2010/vra_ma.json b/rba/data/2010/vra_ma.json deleted file mode 100644 index 28d7db5..0000000 --- a/rba/data/2010/vra_ma.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "black": 0, - "hispanic": 0, - "asian": 0, - "native": 0, - "islander": 0, - "combined": 4, - "opportunity_threshold": 0.51 -} \ No newline at end of file diff --git a/rba/ensemble.py b/rba/ensemble.py index c63d314..54bbaa9 100644 --- a/rba/ensemble.py +++ b/rba/ensemble.py @@ -3,33 +3,63 @@ for a state. Enforces population equality, VRA compliance, and biases towards following county lines. """ +from dataclasses import dataclass from functools import partial import json +import math +import os +import pickle import random -from collections import defaultdict +import sys import statistics +import warnings -import matplotlib.pyplot as plt -import networkx as nx from gerrychain import Partition, Graph, MarkovChain, updaters, constraints, accept +from gerrychain.constraints import Validator from gerrychain.proposals import recom from gerrychain.tree import recursive_tree_part, bipartition_tree -from gerrychain.random import random +from tqdm import tqdm +from welford import Welford +import gerrychain.random +import matplotlib.pyplot as plt +import networkx as nx +import numpy as np import pandas as pd from rba import constants from rba.district_quantification import quantify_gerrymandering, quantify_districts -from rba.util import get_num_vra_districts, get_county_weighted_random_spanning_tree -from rba.visualization import visualize_partition_geopandas, visualize_metric +from rba.util import (get_num_vra_districts, get_county_weighted_random_spanning_tree, + get_county_spanning_forest, choose_cut, create_folder, load_districts) +from rba.visualization import (visualize_gradient_geopandas, visualize_partition_geopandas, + visualize_metric) -# CONSTANTS -# random.seed(2023) -# GEODATA_FILE = "../rba/data/2010/new_hampshire_geodata_merged.json" -# COMMUNITY_OUTPUT_FILE = "../rba/data/2010/new_hampshire_communities.json" -# VRA_CONFIG_FILE = "../rba/data/2010/vra_nh.json" -# NUM_DISTRICTS = 2 +class RBAMarkovChain(MarkovChain): + """Markov Chain for The Rebalancing Act. Changes to gerrychain.MarkovChain: + - Ignores validity of initial plan. + """ + def __init__(self, proposal, constraints, accept, initial_state, total_steps): + try: + super().__init__(proposal, constraints, accept, initial_state, total_steps) + except ValueError as e: + if "initial_state" in repr(e): + warnings.warn("GerryChain error was ignored: " + repr(e)) + # Initialize with no constraints. + super().__init__(proposal, [], accept, initial_state, total_steps) + if callable(constraints): + self.is_valid = constraints + else: + self.is_valid = Validator(constraints) + else: + raise e + +@dataclass +class SimplePartition: + """Only stores parts and assignment for easy pickling. + """ + parts: dict + assignment: dict # UPDATERS @@ -54,11 +84,6 @@ def create_updaters(edge_lifetimes, vra_config, vra_threshold): def create_constraints(initial_partition, vra_config): - # CONSTRAINTS - - # NOTE: we said we wouldn't have a compactness constraint but GerryChain uses one in their example - # showing that maybe it's necessary even for ReCom. This keeps the proposals within 2x the number of - # cut edges in the starting one. # compactness_bound = constraints.UpperBound( # lambda p: len(p["cut_edges"]), # 2 * len(initial_partition["cut_edges"]) @@ -79,139 +104,326 @@ def create_constraints(initial_partition, vra_config): return all_constraints -def markov_chain(graph_file, community_file, vra_config_file, initial_assignment, num_districts, verbose=False): +def generate_ensemble(graph, edge_lifetimes, num_vra_districts, vra_threshold, + pop_equality_threshold, num_steps, num_districts, initial_assignment=None, + output_dir=None, verbose=False): + """Conduct the ensemble analysis for a state. Data is returned, but all partitions are saved + to output_dir (not wasting memory). + + Parameters + ---------- + graph : gerrychain.Graph + The state graph of precincts. + edge_lifetimes : dict + Maps edges (tuples of precinct IDs) + num_vra_districts : dict + Maps the name of each minority to the minimum number of VRA districts required for it. + vra_threshold : float + Between 0 and 1. The minimum percentage required to consider a district + "minority opportunity." + pop_equality_threshold : float + Between 0 and 1. The allowed percent deviation allowed between the population of any two + districts. + num_steps : int + The number of iterations to run the markov chain for. + num_districts : int + The number of districts to partition the state into. + initial_assignment : dict, default=None + Maps nodes to districts for the initial plan in the Markov chain. If set to None, a random + partition is used. + output_dir : str + Path to directory where every partition produced will be saved (as pickle files). In each + pickle, a tuple is stored containing the gerrychain.Partition object as well as the order + of the districts based on gerrymandering score (for matching up with the dataframe). + verbose : boolean, default=False + Controls verbosity. + + Returns + ------- + df : pandas.DataFrame + Contains gerrymandering scores of the state and all the districts for each step in the + Markov Chain. """ - Conduct the ensemble analysis for a state. - """ - # LOADING DATA - - with open(graph_file, "r") as f: - data = json.load(f) - nx_graph = nx.readwrite.json_graph.adjacency_graph(data) - graph = Graph.from_networkx(nx_graph) - del nx_graph - - with open(community_file, "r") as f: - community_data = json.load(f) - - edge_lifetimes = {} - for edge, lifetime in community_data["edge_lifetimes"].items(): - u = edge.split(",")[0][2:-1] - v = edge.split(",")[1][2:-2] - edge_lifetimes[(u, v)] = lifetime - - with open(vra_config_file, "r") as f: - vra_config = json.load(f) - - vra_threshold = vra_config["opportunity_threshold"] - del vra_config["opportunity_threshold"] - - rba_updaters = create_updaters(edge_lifetimes, vra_config, vra_threshold) - - # INITIAL STATE + rba_updaters = create_updaters(edge_lifetimes, num_vra_districts, vra_threshold) state_population = 0 for node in graph: state_population += graph.nodes[node]["total_pop"] ideal_population = state_population / num_districts - # initial_assignment = recursive_tree_part( - # graph, range(num_districts), - # pop_target=ideal_population, - # pop_col="total_pop", - # epsilon=constants.POP_EQUALITY_THRESHOLD) + if initial_assignment is None: + if verbose: + print("Creating random initial partition...", end="") + sys.stdout.flush() + initial_assignment = recursive_tree_part( + graph, range(num_districts), + pop_target=ideal_population, + pop_col="total_pop", + epsilon=constants.POP_EQUALITY_THRESHOLD) + if verbose: + print("done!") initial_partition = Partition(graph, initial_assignment, rba_updaters) - for part in initial_partition.parts: - pop_sum = 0 - for node in initial_partition.parts[part]: - pop_sum += graph.nodes[node]["total_pop"] - print(part, pop_sum) - visualize_partition_geopandas(initial_partition) - # PROPOSAL METHOD + visualize_partition_geopandas(initial_partition) weighted_recom_proposal = partial( recom, pop_col="total_pop", pop_target=ideal_population, - epsilon=constants.POP_EQUALITY_THRESHOLD, + epsilon=pop_equality_threshold, node_repeats=2, + # method=partial( + # bipartition_tree, + # spanning_tree_fn=get_county_weighted_random_spanning_tree) method=partial( bipartition_tree, - spanning_tree_fn=get_county_weighted_random_spanning_tree) + spanning_tree_fn=get_county_spanning_forest, + choice=partial(choose_cut, graph=graph)) ) # recom_proposal = partial(recom, # pop_col="total_pop", # pop_target=ideal_population, - # epsilon=constants.POP_EQUALITY_THRESHOLD, + # epsilon=pop_equality_threshold, # node_repeats=2 # ) - all_constraints = create_constraints(initial_partition, vra_config) + all_constraints = create_constraints(initial_partition, num_vra_districts) - chain = MarkovChain( + chain = RBAMarkovChain( # proposal=recom_proposal, proposal=weighted_recom_proposal, constraints=all_constraints, # accept=lambda p: random.random() < get_county_border_proportion(p), accept=accept.always_accept, initial_state=initial_partition, - total_steps=15 + total_steps=num_steps ) - df = pd.DataFrame(columns=[f"district{i}" for i in range(1, num_districts + 1)] + ["state_gerry_score"], dtype=float) + scores_df = pd.DataFrame(columns=[f"district {i}" for i in range(1, num_districts + 1)] + ["state_gerry_score"], dtype=float) - saved_partitions = [] - precinct_scores = {precinct : [] for precinct in initial_partition.graph.nodes} - for i, partition in enumerate(chain.with_progress_bar()): - district_scores, state_score = partition["gerry_scores"] - assignment = partition.assignment - # print(district_scores) - # print(assignment) - for j, node in enumerate(partition.graph.nodes): - # if j < 3: - # print(node, assignment[node], district_scores[assignment[node]]) - precinct_scores[node].append(district_scores[assignment[node]]) - if i == 0: - graph.nodes[node]["precinct_scores"] = [district_scores[assignment[node]]] - else: - graph.nodes[node]["precinct_scores"].append(district_scores[assignment[node]]) - df.loc[len(df.index)] = sorted(list(district_scores.keys())) + [state_score] - if i % 5 == 0: - saved_partitions.append(partition) + if output_dir is not None: + create_folder(output_dir) + create_folder(os.path.join(output_dir, "plans")) - for i, partition in enumerate(saved_partitions): - visualize_partition_geopandas(partition, i=i) # TODO: add titles for what index in the chain each image is from. - # print(precinct_scores) - plt.hist(df["state_gerry_score"], bins=10) - plt.show() - plt.savefig("ensemble_analysis_results.png") + if verbose: + print("Running Markov chain...") + sys.stdout.flush() + chain_iter = chain.with_progress_bar() + else: + chain_iter = chain + for i, partition in enumerate(chain_iter, start=1): + district_scores, state_score = partition["gerry_scores"] + districts_order = sorted(list(district_scores.keys()), key=lambda d: district_scores[d]) + scores_df.loc[len(scores_df.index)] = [district_scores[d] for d in districts_order] + [state_score] + if output_dir is not None: + with open(os.path.join(output_dir, "plans", f"{i}.pickle"), "wb+") as f: + pickle.dump((SimplePartition(partition.parts, partition.assignment), districts_order), f) - for i, partition in enumerate(saved_partitions): - visualize_partition_geopandas(partition) # TODO: add titles for what index in the chain each image is from. - return graph + return scores_df -def ensemble_analysis(graph_file, community_file, vra_config_file, district_file, verbose=False): +def ensemble_analysis(graph_file, community_file, vra_config_file, num_steps, num_districts, + initial_plan_file, district_file, output_dir, verbose=False): """Conducts a geographic ensemble analysis of a state's gerrymandering. """ - districts, district_scores, state_score = quantify_districts(graph_file, district_file, community_file, verbose) - assignment = {} - for district, node_list in districts.items(): - for node in node_list: - assignment[node] = district - graph = markov_chain(graph_file, community_file, vra_config_file, assignment, len(districts), verbose) - for precinct in graph.nodes: - scores = graph.nodes[precinct]["precinct_scores"] - mean_score = sum(scores)/len(scores) - stdev = statistics.stdev(scores) - real_score = district_scores[assignment[precinct]] - z_score = (real_score-mean_score)/stdev - print(scores, mean_score, stdev, real_score, z_score) - graph.nodes[precinct]["z_score"] = z_score - graph.nodes[precinct]["distribution_score"] = mean_score - graph.nodes[precinct]["real_score"] = real_score - # print(graph.nodes[graph.nodes[0]]) - visualize_metric("geographic_ensemble_analysis.png", graph, "z_score") + # NOTE: does not create reproducibility. + gerrychain.random.random.seed(2023) + random.seed(2023) + + if verbose: + print("Loading precinct graph...", end="") + sys.stdout.flush() + + with open(graph_file, "r") as f: + data = json.load(f) + nx_graph = nx.readwrite.json_graph.adjacency_graph(data) + graph = Graph.from_networkx(nx_graph) + del nx_graph + + if verbose: + print("done!") + print("Loading community algorithm output...", end="") + sys.stdout.flush() + + with open(community_file, "r") as f: + community_data = json.load(f) + + edge_lifetimes = {} + for edge, lifetime in community_data["edge_lifetimes"].items(): + u = edge.split(",")[0][2:-1] + v = edge.split(",")[1][2:-2] + edge_lifetimes[(u, v)] = lifetime + + if verbose: + print("done!") + print("Loading VRA requirements...", end="") + sys.stdout.flush() + + with open(vra_config_file, "r") as f: + vra_config = json.load(f) + vra_threshold = vra_config["opportunity_threshold"] + del vra_config["opportunity_threshold"] + + if verbose: + print("done!") + + if initial_plan_file is not None: + if verbose: + print("Loading starting map...", end="") + sys.stdout.flush() + + initial_plan_node_sets = load_districts(graph, initial_plan_file, verbose) + initial_assignment = {} + for district, nodes in initial_plan_node_sets.items(): + for node in nodes: + initial_assignment[node] = district + + if verbose: + print("done!") + else: + if verbose: + print("No starting map provided. Will generate a random one later.") + initial_assignment = None + + scores_df = generate_ensemble(graph, edge_lifetimes, vra_config, vra_threshold, + constants.POP_EQUALITY_THRESHOLD, num_steps, num_districts, + initial_assignment, output_dir, verbose) + + scores_df.to_csv(os.path.join(output_dir, "scores.csv")) + + # Save a histogram of statewide scores. + plt.hist(scores_df["state_gerry_score"], bins=10) + plt.savefig(os.path.join(output_dir, "score_distribution.png")) + + create_folder(os.path.join(output_dir, "visuals")) + + if verbose: + print("Calculating precinct-level statistics and visualizing partitions...") + + sorted_node_names = sorted(list(graph.nodes)) + # "score" is referring to the "goodness" score and "homogeneity" is referring to the homogeneity + # metric (in this case standard deviation of republican vote share). This uses Welford's + # algorithm to calculate mean and variance one at a time instead of saving all the values to + # memory (there will be num_steps * len(graph.nodes) values. that is a lot.) + score_accumulator = Welford() + homogeneity_accumulator = Welford() + if verbose: + step_iter = tqdm(range(num_steps)) + else: + step_iter = range(num_steps) + for i in step_iter: + with open(os.path.join(output_dir, "plans", f"{i + 1}.pickle"), "rb") as f: + partition, district_order = pickle.load(f) + + part_values = {} # part: (score, homogeneity) + for part in partition.parts: + score = scores_df.loc[i, f"district {district_order.index(part) + 1}"] + homogeneity = statistics.stdev( + [graph.nodes[node]["total_rep"] / graph.nodes[node]["total_votes"] + for node in partition.parts[part]] + ) + part_values[part] = (score, homogeneity) + score_sample = np.zeros((len(sorted_node_names),)) + homogeneity_sample = np.zeros((len(sorted_node_names),)) + for j, precinct in enumerate(sorted_node_names): + score_sample[j] = part_values[partition.assignment[precinct]][0] + homogeneity_sample[j] = part_values[partition.assignment[precinct]][1] + score_accumulator.add(score_sample) + homogeneity_accumulator.add(homogeneity_sample) + + # Visualize 100 partitions, or however many there are if there are less than 100. + if num_steps >= 100: + visualize = i % (num_steps // 100) == 0 + else: + visualize = True + if visualize: + visualize_partition_geopandas( + partition, graph=graph, img_path=os.path.join(output_dir, "visuals", f"{i + 1}.png")) + + if verbose: + print("Evaluating inputted district map...", end="") + sys.stdout.flush() + + precinct_df = pd.DataFrame(columns=["avg_score", "stdev_score", "avg_homogeneity", + "stdev_homogeneity"], + index=sorted_node_names) + for i, precinct in enumerate(sorted_node_names): + precinct_df.loc[precinct] = [ + score_accumulator.mean[i], + math.sqrt(score_accumulator.var_s[i]), + homogeneity_accumulator.mean[i], + math.sqrt(homogeneity_accumulator.var_s[i]) + ] + + districts_precinct_df = pd.DataFrame(columns=["score", "homogeneity"], index=sorted_node_names) + district_node_sets = load_districts(graph, district_file, verbose) + district_scores, _ = quantify_gerrymandering(graph, district_node_sets, edge_lifetimes, verbose) + for district, precincts in district_node_sets.items(): + homogeneity = statistics.stdev( + [graph.nodes[node]["total_rep"] / graph.nodes[node]["total_votes"] + for node in precincts] + ) + for precinct in precincts: + districts_precinct_df.loc[precinct] = [district_scores[district], homogeneity] + + # Create gerrymandering and packing/cracking heatmaps for the inputted districting plan. + + def get_z_score(precinct, metric): + mean = precinct_df.loc[precinct, f"avg_{metric}"] + stdev = precinct_df.loc[precinct, f"stdev_{metric}"] + flag = districts_precinct_df.loc[precinct, metric] + return (flag - mean) / stdev + + # Needed for drawing district boundaries + districts_assignment = {} + for district, nodes in district_node_sets.items(): + for node in nodes: + districts_assignment[node] = district + districts_partition = Partition(graph, assignment=districts_assignment) + + # TODO: this doesn't work with Maryland for some reason + + _, ax = plt.subplots(figsize=(12.8, 9.6)) + visualize_gradient_geopandas( + sorted_node_names, + get_value=partial(get_z_score, metric="score"), + get_geometry=lambda p: graph.nodes[p]["geometry"], + clear=False, + ax=ax, + legend=True, + ) + visualize_partition_geopandas( + districts_partition, + union=True, + img_path=os.path.join(output_dir, "gerry_scores.png"), + clear=True, + ax=ax, + facecolor="none", + edgecolor="black", + linewidth=0.5 + ) + + _, ax = plt.subplots(figsize=(12.8, 9.6)) + visualize_gradient_geopandas( + sorted_node_names, + get_value=partial(get_z_score, metric="homogeneity"), + get_geometry=lambda p: graph.nodes[p]["geometry"], + clear=False, + ax=ax, + legend=True + ) + visualize_partition_geopandas( + districts_partition, + union=True, + img_path=os.path.join(output_dir, "packing_cracking.png"), + clear=True, + ax=ax, + facecolor="none", + edgecolor="black", + linewidth=0.5 + ) + + if verbose: + print("done!") \ No newline at end of file diff --git a/rba/optimization.py b/rba/optimization.py index 55e6f08..b23e838 100644 --- a/rba/optimization.py +++ b/rba/optimization.py @@ -12,8 +12,7 @@ import sys import warnings -from gerrychain import Partition, Graph, MarkovChain, updaters, constraints, accept -from gerrychain.constraints import Validator +from gerrychain import Partition, Graph, updaters, constraints, accept from gerrychain.proposals import recom from gerrychain.tree import recursive_tree_part, bipartition_tree, uniform_spanning_tree import gerrychain.random @@ -22,12 +21,13 @@ from . import constants from .district_quantification import quantify_gerrymandering +from .ensemble import RBAMarkovChain from .util import (get_num_vra_districts, load_districts, get_county_spanning_forest, save_assignment, choose_cut) from .visualization import visualize_partition_geopandas -class SimulatedAnnealingChain(MarkovChain): +class SimulatedAnnealingChain(RBAMarkovChain): """Simulated annealing Markov Chain. Major changes to gerrychain.MarkovChain: - `get_temperature` is now the first positional argument, and it must take the current iteration and return the current temperature. @@ -44,21 +44,8 @@ def get_temperature_linear(i, num_steps): "linear": get_temperature_linear } - def __init__(self, get_temperature, proposal, constraints, accept, initial_state, total_steps): - try: - super().__init__(proposal, constraints, accept, initial_state, total_steps) - except ValueError as e: - if "initial_state" in repr(e): - warnings.warn("GerryChain error was ignored: " + repr(e)) - # Initialize with no constraints. - super().__init__(proposal, [], accept, initial_state, total_steps) - if callable(constraints): - self.is_valid = constraints - else: - self.is_valid = Validator(constraints) - else: - raise e - + def __init__(self, get_temperature, *args, **kwargs): + super().__init__(*args, **kwargs) self.get_temperature = get_temperature def __iter__(self): @@ -188,10 +175,10 @@ def generate_districts_simulated_annealing(graph, edge_lifetimes, num_vra_distri # pop_target=ideal_population, # epsilon=constants.POP_EQUALITY_THRESHOLD, # node_repeats=2, - # method=partial( - # bipartition_tree, - # spanning_tree_fn=get_county_spanning_forest, - # choice=partial(choose_cut, graph=graph)) + # method=partial( + # bipartition_tree, + # spanning_tree_fn=get_county_spanning_forest, + # choice=partial(choose_cut, graph=graph)) # ) recom_proposal = partial(recom, @@ -281,6 +268,7 @@ def optimize(graph_file, communitygen_out_file, vra_config_file, num_steps, num_ initial_plan_file, output_dir, verbose): """Wrapper function for command-line usage. """ + # NOTE: does not create reproducibility. gerrychain.random.random.seed(2023) random.seed(2023) @@ -309,7 +297,7 @@ def optimize(graph_file, communitygen_out_file, vra_config_file, num_steps, num_ if verbose: print("done!") - print("VRA requirements...", end="") + print("Loading VRA requirements...", end="") sys.stdout.flush() with open(vra_config_file, "r") as f: diff --git a/rba/util.py b/rba/util.py index b1eb80c..66bbb82 100644 --- a/rba/util.py +++ b/rba/util.py @@ -3,6 +3,7 @@ import json import random +import os import time from collections import defaultdict @@ -21,6 +22,15 @@ # from .district_quantification import quantify_gerrymandering +def create_folder(path): + """Creates a folder but does not throw an exception if it already exists. + """ + try: + os.mkdir(path) + except FileExistsError: + pass + + def copy_adjacency(graph): """Copies adjacency information from a graph but not attribute data. """ diff --git a/rba/visualization.py b/rba/visualization.py index 3e4b0b9..a850a5b 100644 --- a/rba/visualization.py +++ b/rba/visualization.py @@ -107,45 +107,95 @@ def modify_coords(coords, bounds): return new_coords -def visualize_partition_geopandas(partition, *args, union=False, i=None, **kwargs): +def visualize_gradient_geopandas(precincts, get_value, get_geometry, *args, img_path=None, + show=False, clear=True, **kwargs): + """Visualizes a variable on a gradient using geopandas. + + Parameters + ---------- + precincts : container of precincts + Contains a list of precinct names. + get_value : callable + Returns a metric value for a given precinct + get_geometry : callable + Returns the geometry (GeoJSON list form, not shapely) for a given precinct. + img_path : str, default=None + Optional path to save the image. + show : bool, default=False + Whether or not to call plt.show() + clear : bool, default=True + Whether or not to call plt.clf() + Also takes any parameters taken by geopandas.GeoDataFrame.plot() + """ + gdf = geopandas.GeoDataFrame(columns=["val", "geometry"]) + for precinct in precincts: + # gdf.loc[len(gdf.index)] = [get_value(precinct), shapely.geometry.shape(get_geometry(precinct))] + gdf.loc[len(gdf.index)] = [get_value(precinct), get_geometry(precinct)] + gdf.plot( + column="val", + *args, + **{key: arg for key, arg in kwargs.items() if key not in ["img_path", "show", "clear"]} + ) + if img_path is not None: + plt.savefig(img_path) + if show: + plt.show() + if clear: + plt.clf() + + +def visualize_partition_geopandas(partition, *args, graph=None, union=False, img_path=None, + show=False, clear=True, **kwargs): """Visualizes a gerrychain.Partition object using geopandas. Parameters ---------- partition : gerrychain.Partition Partition to visualize. - union : boolean, default=False + graph : gerrychain.Graph, default=None + State precinct graph. Only needs to be provided if partition is a SimplePartition. + union : bool, default=False Whether or not to visualize the partitions as a single polygon, as opposed to just showing their assignment by coloring. + img_path : str, default=None + Optional path to save the image. + show : bool, default=False + Whether or not to call plt.show() + clear : bool, default=True + Whether or not to call plt.clf() Also takes any parameters taken by geopandas.GeoDataFrame.plot() """ + if graph is None: + graph = partition.graph + if union: data = {"assignment": [], "geometry": []} for part in partition.parts: data["assignment"].append(part) - geoms = [shapely.geometry.shape(partition.graph.nodes[node]["geometry"]) + geoms = [shapely.geometry.shape(graph.nodes[node]["geometry"]) for node in partition.parts[part]] data["geometry"].append(shapely.ops.unary_union(geoms)) else: data = {"assignment": [], "geometry": []} - for node in partition.graph: + for node in graph: data["assignment"].append(partition.assignment[node]) - data["geometry"].append(shapely.geometry.shape(partition.graph.nodes[node]["geometry"])) + data["geometry"].append(shapely.geometry.shape(graph.nodes[node]["geometry"])) gdf = geopandas.GeoDataFrame(data) del data gdf.plot( column="assignment", *args, - **{key: arg for key, arg in kwargs.items() if key != "union"} + **{key: arg for key, arg in kwargs.items() if key not in ["union", "img_path", "show", "clear"]} ) - if i: - plt.savefig(f"partition_plot_{i}.png") - else: - plt.savefig(f"partition_plot_initial.png") - plt.clf() + if img_path is not None: + plt.savefig(img_path) + if show: + plt.show() + if clear: + plt.clf() def visualize_map(graph, output_fpath, node_coords, edge_coords, node_colors=None, edge_colors=None, diff --git a/setup.py b/setup.py index 9c09937..8cb0018 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ "matplotlib==3.5.1", "pandas==1.4.1", "gerrychain==0.2.20", - "maup==1.0.8" + "maup==1.0.8", + "welford==0.2.5" ] ) \ No newline at end of file