Skip to content

Commit

Permalink
fix transfer
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudon committed Aug 18, 2023
1 parent e5d8cc2 commit 6d875d5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
14 changes: 11 additions & 3 deletions netsalt/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,10 @@ def flux_on_edges(mode, graph):
def mean_mode_on_edges(mode, graph):
r"""Compute the average :math:`Real(E^2)` on each edge."""
edge_flux = flux_on_edges(mode, graph)
return mean_on_edges(edge_flux, graph)


def mean_on_edges(edge_flux, graph):
mean_edge_solution = np.zeros(len(graph.edges))
for ei in range(len(graph.edges)):
k = 1.0j * graph.graph["ks"][ei]
Expand Down Expand Up @@ -898,14 +901,19 @@ def lasing_threshold_linear(mode, graph, D0):

def get_node_transfer(k, graph, input_flow):
"""Compute node transfer from a given input flow."""
return sc.sparse.linalg.spsolve(construct_laplacian(k, graph), graph.graph["ks"] * input_flow)
BT, _ = construct_incidence_matrix(graph)
bt = np.clip(np.ceil(np.real(BT.toarray())), -1, 0)
K = bt.dot(np.repeat(graph.graph["ks"], 2))
return sc.sparse.linalg.spsolve(construct_laplacian(k, graph), K * input_flow)


def get_edge_transfer(k, graph, input_flow):
"""Compute edge transfer from a given input flow."""
set_wavenumber(graph, k)
BT, B = construct_incidence_matrix(graph)
_r = get_node_transfer(k, graph, BT.dot(input_flow))
_r = sc.sparse.linalg.spsolve(
construct_laplacian(k, graph), BT.dot(np.repeat(graph.graph["ks"], 2) * input_flow)
)
Winv = construct_weight_matrix(graph, with_k=False)
return Winv.dot(B).dot(_r)

Expand All @@ -926,7 +934,7 @@ def estimate_boundary_flow(graph, input_flow, k_frac=1e-2):
)

e_deg = np.array([len(graph[v]) for u, v in graph.edges])
output_ids = list(np.argwhere(e_deg == 1).flatten())
output_ids = list(2 * np.argwhere(e_deg == 1).flatten())
output_ids += list(2 * np.argwhere(e_deg == 1).flatten() + 1)

# get the flows on all nodes
Expand Down
24 changes: 20 additions & 4 deletions netsalt/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,24 +441,40 @@ def plot_single_mode(
def _plot_single_mode(
graph, mode, ax=None, colorbar=True, edge_vmin=None, edge_vmax=None, cmap="coolwarm"
):
positions = [graph.nodes[u]["position"] for u in graph]
edge_solution = mean_mode_on_edges(mode, graph)

return plot_on_graph(
graph,
edge_solution,
edge_vmin=edge_vmin,
edge_vmax=edge_vmax,
ax=ax,
cmap=cmap,
colorbar=colorbar,
)


def plot_on_graph(
graph, edge_data, edge_vmin=None, edge_vmax=None, ax=None, cmap="coolwarm", colorbar=True
):
"""Plot edge data on graph."""
if ax is None:
plt.figure(figsize=(5, 4)) # 14,3
ax = plt.gca()

positions = [graph.nodes[u]["position"] for u in graph]
nx.draw(graph, pos=positions, node_size=0, width=0, ax=ax)

cmap = plt.get_cmap(cmap)
if edge_vmax is None:
edge_vmax = max(abs(edge_solution))
edge_vmax = max(abs(edge_data))
if edge_vmin is None:
edge_vmin = -max(abs(edge_solution))
edge_vmin = -max(abs(edge_data))
print(len(edge_data))
nx.draw_networkx_edges(
graph,
pos=positions,
edge_color=edge_solution,
edge_color=edge_data,
width=2,
edge_cmap=cmap,
edge_vmin=edge_vmin,
Expand Down

0 comments on commit 6d875d5

Please sign in to comment.