diff --git a/netsalt/quantum_graph.py b/netsalt/quantum_graph.py index e002b75..b20be0b 100644 --- a/netsalt/quantum_graph.py +++ b/netsalt/quantum_graph.py @@ -208,46 +208,45 @@ def oversample_graph(graph, edge_size): # pylint: disable=too-many-locals oversampled_graph = graph.copy() for ei, (u, v) in enumerate(graph.edges): last_node = len(oversampled_graph) - if graph[u][v]["inner"]: - n_nodes = int(graph[u][v]["length"] / edge_size) - if n_nodes > 1: - dielectric_constant = graph[u][v]["dielectric_constant"] - pump = graph[u][v]["pump"] - oversampled_graph.remove_edge(u, v) - - for node_index in range(n_nodes - 1): - node_position_x = graph.nodes[u]["position"][0] + (node_index + 1) / n_nodes * ( - graph.nodes[v]["position"][0] - graph.nodes[u]["position"][0] - ) - node_position_y = graph.nodes[u]["position"][1] + (node_index + 1) / n_nodes * ( - graph.nodes[v]["position"][1] - graph.nodes[u]["position"][1] - ) - node_position = np.array([node_position_x, node_position_y]) - - if node_index == 0: - first, last = u, last_node - else: - first, last = last_node + node_index - 1, last_node + node_index - - oversampled_graph.add_node(last, position=node_position) - oversampled_graph.add_edge( - first, - last, - inner=True, - dielectric_constant=dielectric_constant, - pump=pump, - edgelabel=ei, - ) + n_nodes = int(graph[u][v]["length"] / edge_size) + if n_nodes > 1: + dielectric_constant = graph[u][v]["dielectric_constant"] + pump = graph[u][v]["pump"] + oversampled_graph.remove_edge(u, v) + + for node_index in range(n_nodes - 1): + node_position_x = graph.nodes[u]["position"][0] + (node_index + 1) / n_nodes * ( + graph.nodes[v]["position"][0] - graph.nodes[u]["position"][0] + ) + node_position_y = graph.nodes[u]["position"][1] + (node_index + 1) / n_nodes * ( + graph.nodes[v]["position"][1] - graph.nodes[u]["position"][1] + ) + node_position = np.array([node_position_x, node_position_y]) + if node_index == 0: + first, last = u, last_node + else: + first, last = last_node + node_index - 1, last_node + node_index + + oversampled_graph.add_node(last, position=node_position) oversampled_graph.add_edge( - last_node + node_index, - v, + first, + last, inner=True, dielectric_constant=dielectric_constant, pump=pump, edgelabel=ei, ) + oversampled_graph.add_edge( + last_node + node_index, + v, + inner=True, + dielectric_constant=dielectric_constant, + pump=pump, + edgelabel=ei, + ) + oversampled_graph = nx.convert_node_labels_to_integers(oversampled_graph) _set_edge_lengths(oversampled_graph) params = {"inner": [oversampled_graph[u][v]["inner"] for u, v in oversampled_graph.edges]}