Skip to content

Commit

Permalink
Fig 1 updates (#31)
Browse files Browse the repository at this point in the history
* Fig 1 updates

* Fixed bugs found by nicholas + adjust fonts and sizing a bit more

* simplify code

---------

Co-authored-by: Nicholas Landry <[email protected]>
  • Loading branch information
jg-you and nwlandry authored Apr 16, 2024
1 parent cd8483b commit 253d5af
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 36 deletions.
Binary file modified Figures/Fig1/fig1.pdf
Binary file not shown.
Binary file modified Figures/Fig1/fig1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 7 additions & 7 deletions fig_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ def set_colors(n_colors=2):

def set_fonts(extra_params={}):
params = {
"font.family": "Sans-Serif",
"font.sans-serif": ["Tahoma", "DejaVu Sans", "Lucida Grande", "Verdana"],
"font.family": "Serif",
# "font.sans-serif": ["Tahoma", "DejaVu Sans", "Lucida Grande", "Verdana"],
"mathtext.fontset": "cm",
"legend.fontsize": 12,
"axes.labelsize": 15,
"axes.titlesize": 15,
"xtick.labelsize": 15,
"ytick.labelsize": 15,
"figure.titlesize": 15,
"axes.labelsize": 12,
"axes.titlesize": 12,
"xtick.labelsize": 12,
"ytick.labelsize": 12,
"figure.titlesize": 12,
}
for key, value in extra_params.items():
params[key] = value
Expand Down
79 changes: 50 additions & 29 deletions plot_fig1.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
import fig_settings as fs
from lcs import *

axislabel_fontsize = 20
tick_fontsize = 18
axislabel_fontsize = 12
tick_fontsize = 12
fs.set_fonts(
{
"font.family": "sans-serif",
"font.family": "serif",
"axes.labelsize": axislabel_fontsize,
"xtick.labelsize": tick_fontsize,
"ytick.labelsize": tick_fontsize,
Expand All @@ -23,16 +23,18 @@
fs.set_colors()
cmap = fs.cmap

fig = plt.figure(figsize=(8, 6))
plt.subplots_adjust(left=0.1, right=0.86, bottom=0.15, top=0.95, wspace=0.4, hspace=0.4)
fig = plt.figure(figsize=(5.5, 5))
plt.subplots_adjust(left=0.12, right=0.89, bottom=0.15, top=0.95)

gs = GridSpec(2, 2, hspace=0.4, wspace=0.4)
gs = GridSpec(2, 2, hspace=0.6, wspace=0.5)

"""
Panel 1: Network Viz
"""
ax1 = fig.add_subplot(gs[0])

ax1.text(-0.36, 1.05, "(a)", transform=ax1.transAxes, fontsize=13, fontweight='bold', va="top")

el = zkc(format="edgelist")
H = xgi.Hypergraph(el)
A = zkc()
Expand All @@ -53,15 +55,17 @@
infected_color = "C1"
susceptible_color = "white"
subgraph_color = "black"
graph_color = (0.1, 0.1, 0.1, 0.1)
graph_color = (0.7, 0.7, 0.7, 0.5)
subgraph_node_lc = "black"
graph_node_lc = (0.3, 0.3, 0.3)

sg = H.nodes.memberships(i)
nbrs = H.nodes.neighbors(i)
nbrs.add(i)

pos = xgi.pca_transform(xgi.pairwise_spring_layout(H, seed=5, k=0.3))
# rotate pos by 30 degrees
pos = xgi.pca_transform(xgi.pairwise_spring_layout(H, seed=5, k=0.25), theta=30)

node_fc = [infected_color if x[t, i] else susceptible_color for i in H.nodes]
node_ec = [subgraph_node_lc if n in nbrs else graph_node_lc for n in H.nodes]
node_fc[12] = "C0"
Expand All @@ -75,17 +79,22 @@
node_size=6.5,
node_fc=node_fc,
dyad_color=dyad_color,
dyad_lw=0.5,
dyad_lw=0.8,
node_ec=node_ec,
node_lw=0.5,
ax=ax1,
node_lw=0.8,
ax=ax1
)
plt.scatter(pos[13][0], pos[13][1], s=50, c='C0', edgecolors='black', linewidths=0.8, zorder=10, marker='s')

ax1.set_xlim([1.1 * min([pos[i][0] for i in pos]), 1.1 * max([pos[i][0] for i in pos])])


"""
Panel 2:
"""
ax2 = fig.add_subplot(gs[1])
ax2.text(-0.39, 1.05, "(b)", transform=ax2.transAxes, fontsize=12, fontweight='bold', va="top")


with open("Data/zkc_infer_contagion_functions.json") as file:
data = json.load(file)
Expand All @@ -110,7 +119,7 @@

# simple contagion
c1_mean = c1_samples.mean(axis=0)
ax2.plot(nus, c1, "-", color="C0", label="Simple contagion")
ax2.plot(nus, c1, "-", color="C0", lw=5, alpha=0.5)

err_c1 = np.zeros((2, n))
c1_mode = np.zeros(n)
Expand All @@ -119,11 +128,15 @@
x, y = interval
err_c1[0, i] = max(c1_mean[i] - x, 0)
err_c1[1, i] = max(y - c1_mean[i], 0)
ax2.errorbar(nus, c1_mean, err_c1, color="C0", fmt="o")


offset_distance = 0.15
ax2.errorbar(nus - offset_distance, c1_mean, err_c1, color="C0", fmt="o",
capsize=3, markersize=5, markeredgecolor="#315b7d", label="Simple")

# threshold contagion, tau=2
c2_mean = c2_samples.mean(axis=0)
ax2.plot(nus, c2, "-", color="C1", label="Complex contagion")
ax2.plot(nus, c2, "-", color="C1", lw=5, alpha=0.5)

err_c2 = np.zeros((2, n))
c2_mode = np.zeros(n)
Expand All @@ -132,24 +145,28 @@
x, y = interval
err_c2[0, i] = max(c2_mean[i] - x, 0)
err_c2[1, i] = max(y - c2_mean[i], 0)
ax2.errorbar(nus, c2_mean, err_c2, color="C1", fmt="o")
ax2.errorbar(nus + offset_distance, c2_mean, err_c2, color="C1", fmt="o",
capsize=3, markersize=5, markeredgecolor="#391c23", label="Complex")


ax2.set_xticks(np.arange(0, n, 5))
ax2.set_xlabel(r"$\nu$")
ax2.set_ylabel(r"$c(\nu)$")
ax2.set_xlabel(r"# of infected neighbors, $\nu$")
ax2.set_ylabel(r"Probability, $c(\nu)$")

ax2.set_xlim([0, kmax + 2.5])
ax2.set_xlim([0, 13.5])
ax2.set_ylim([0, 1])
ax2.set_yticks([0, 0.5, 1], [0, 0.5, 1])

ax2.legend(loc="upper left")
ax2.legend(loc='upper left', bbox_to_anchor=(-0.1, 0.9, 0.2, 0.2),
handletextpad=0.1, frameon=False)

sns.despine()

""""
Panel 3: recovery vs. tmax
"""
ax3 = fig.add_subplot(gs[2])
ax3.text(-0.35, 1.05, "(c)", transform=ax3.transAxes, fontsize=12, fontweight='bold', va="top")


with open("Data/zkc_infer_vs_tmax.json") as file:
Expand All @@ -162,21 +179,23 @@
fce = np.array(data["fce"], dtype=float)


ax3.semilogx(tmax, sps[0].mean(axis=1), color="C0", label="Simple contagion")
ax3.semilogx(tmax, sps[1].mean(axis=1), color="C1", label="Complex contagion")
ax3.semilogx(tmax, sps[0].mean(axis=1), color="C0", label="Simple")
ax3.semilogx(tmax, sps[1].mean(axis=1), color="C1", label="Complex")
ax3.fill_between(
tmax,
sps[0].mean(axis=1) - sps[0].std(axis=1),
sps[0].mean(axis=1) + sps[0].std(axis=1),
alpha=0.3,
color="C0",
edgecolor="none"
)
ax3.fill_between(
tmax,
sps[1].mean(axis=1) - sps[1].std(axis=1),
sps[1].mean(axis=1) + sps[1].std(axis=1),
alpha=0.3,
color="C1",
edgecolor="none"
)
ax3.set_ylabel("Performance")
ax3.set_xlabel(r"$t_{max}$")
Expand All @@ -188,13 +207,15 @@
ax3.set_ylim([0, 1])
ax3.set_yticks([0, 0.5, 1], [0, 0.5, 1])

ax3.legend(loc="upper left")
ax3.legend(loc="lower right", bbox_to_anchor=(0.85, 0, 0.2, 0.2),
markerfirst=False, frameon=False, handlelength=0.8)
sns.despine()

"""
Panel 4: heatmap of recover vs. beta and f
"""
ax4 = fig.add_subplot(gs[3])
ax4.text(-0.38, 1.05, "(d)", transform=ax4.transAxes, fontsize=12, fontweight='bold', va="top")

with open("Data/zkc_frac_vs_beta.json") as file:
data = json.load(file)
Expand All @@ -210,24 +231,24 @@


c = ax4.imshow(
to_imshow_orientation(sps_summary),
extent=(min(frac), max(frac), min(beta), max(beta)),
np.fliplr(to_imshow_orientation(sps_summary)),
extent=(min(frac), max(frac), max(beta), min(beta)),
aspect="auto",
cmap=cmap,
vmin=0,
vmax=1,
)
ax4.set_xlabel(r"$f$")
ax4.set_ylabel(r"$\beta$")
ax4.set_xlabel(r"Complexity, $\lambda$")
ax4.set_ylabel(r"Infectivity, $\beta$")

ax4.set_xticks([0, 0.5, 1], [0, 0.5, 1])
ax4.set_yticks([0, 0.5, 1], [0, 0.5, 1])


cbar_ax = fig.add_axes([0.875, 0.15, 0.015, 0.335]) # x, y, width, height
cbar_ax = fig.add_axes([0.91, 0.15, 0.015, 0.31]) # x, y, width, height
cbar = plt.colorbar(c, cax=cbar_ax)
cbar.set_label(r"Performance", fontsize=axislabel_fontsize, rotation=270, labelpad=25)
cbar_ax.set_yticks([0, 0.5, 1], [0, 0.5, 1], fontsize=tick_fontsize)
cbar.set_label(r"Performance", fontsize=axislabel_fontsize, rotation=270, labelpad=10)
cbar_ax.set_yticks([0, 1], [0, 1], fontsize=tick_fontsize)

plt.savefig("Figures/Fig1/fig1.png", dpi=1000)
plt.savefig("Figures/Fig1/fig1.pdf", dpi=1000)
Expand Down

0 comments on commit 253d5af

Please sign in to comment.