From 6d875d5b16e2bb20abd00618e56c62818d148e86 Mon Sep 17 00:00:00 2001 From: arnaudon Date: Thu, 10 Aug 2023 15:25:41 +0200 Subject: [PATCH] fix transfer --- netsalt/modes.py | 14 +++++++++++--- netsalt/plotting.py | 24 ++++++++++++++++++++---- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/netsalt/modes.py b/netsalt/modes.py index 968fff9..67a5fa8 100644 --- a/netsalt/modes.py +++ b/netsalt/modes.py @@ -296,7 +296,10 @@ def flux_on_edges(mode, graph): def mean_mode_on_edges(mode, graph): r"""Compute the average :math:`Real(E^2)` on each edge.""" edge_flux = flux_on_edges(mode, graph) + return mean_on_edges(edge_flux, graph) + +def mean_on_edges(edge_flux, graph): mean_edge_solution = np.zeros(len(graph.edges)) for ei in range(len(graph.edges)): k = 1.0j * graph.graph["ks"][ei] @@ -898,14 +901,19 @@ def lasing_threshold_linear(mode, graph, D0): def get_node_transfer(k, graph, input_flow): """Compute node transfer from a given input flow.""" - return sc.sparse.linalg.spsolve(construct_laplacian(k, graph), graph.graph["ks"] * input_flow) + BT, _ = construct_incidence_matrix(graph) + bt = np.clip(np.ceil(np.real(BT.toarray())), -1, 0) + K = bt.dot(np.repeat(graph.graph["ks"], 2)) + return sc.sparse.linalg.spsolve(construct_laplacian(k, graph), K * input_flow) def get_edge_transfer(k, graph, input_flow): """Compute edge transfer from a given input flow.""" set_wavenumber(graph, k) BT, B = construct_incidence_matrix(graph) - _r = get_node_transfer(k, graph, BT.dot(input_flow)) + _r = sc.sparse.linalg.spsolve( + construct_laplacian(k, graph), BT.dot(np.repeat(graph.graph["ks"], 2) * input_flow) + ) Winv = construct_weight_matrix(graph, with_k=False) return Winv.dot(B).dot(_r) @@ -926,7 +934,7 @@ def estimate_boundary_flow(graph, input_flow, k_frac=1e-2): ) e_deg = np.array([len(graph[v]) for u, v in graph.edges]) - output_ids = list(np.argwhere(e_deg == 1).flatten()) + output_ids = list(2 * np.argwhere(e_deg == 1).flatten()) output_ids += list(2 * np.argwhere(e_deg == 1).flatten() + 1) # get the flows on all nodes diff --git a/netsalt/plotting.py b/netsalt/plotting.py index 2e5cda0..38cce10 100644 --- a/netsalt/plotting.py +++ b/netsalt/plotting.py @@ -441,24 +441,40 @@ def plot_single_mode( def _plot_single_mode( graph, mode, ax=None, colorbar=True, edge_vmin=None, edge_vmax=None, cmap="coolwarm" ): - positions = [graph.nodes[u]["position"] for u in graph] edge_solution = mean_mode_on_edges(mode, graph) + return plot_on_graph( + graph, + edge_solution, + edge_vmin=edge_vmin, + edge_vmax=edge_vmax, + ax=ax, + cmap=cmap, + colorbar=colorbar, + ) + + +def plot_on_graph( + graph, edge_data, edge_vmin=None, edge_vmax=None, ax=None, cmap="coolwarm", colorbar=True +): + """Plot edge data on graph.""" if ax is None: plt.figure(figsize=(5, 4)) # 14,3 ax = plt.gca() + positions = [graph.nodes[u]["position"] for u in graph] nx.draw(graph, pos=positions, node_size=0, width=0, ax=ax) cmap = plt.get_cmap(cmap) if edge_vmax is None: - edge_vmax = max(abs(edge_solution)) + edge_vmax = max(abs(edge_data)) if edge_vmin is None: - edge_vmin = -max(abs(edge_solution)) + edge_vmin = -max(abs(edge_data)) + print(len(edge_data)) nx.draw_networkx_edges( graph, pos=positions, - edge_color=edge_solution, + edge_color=edge_data, width=2, edge_cmap=cmap, edge_vmin=edge_vmin,