Skip to content

Commit

Permalink
Merge pull request #132 from lucasgrjn/plot_component
Browse files Browse the repository at this point in the history
Add separated component field plotting
  • Loading branch information
HelgeGehring authored Mar 7, 2024
2 parents 7918fa8 + a573be7 commit eaf36ae
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 17 deletions.
2 changes: 1 addition & 1 deletion docs/photonics/examples/effective_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@

for mode in modes:
if mode.tm_fraction > 0.5:
# mode.show(np.real(mode.E))
# mode.show("E", part="real")
print(f"Effective refractive index: {mode.n_eff:.4f}")
print(f"Effective mode area: {mode.calculate_effective_area(field='y'):.4f}")
print(f"Mode transversality: {mode.transversality}")
Expand Down
6 changes: 3 additions & 3 deletions docs/photonics/examples/selecting_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@
modes = compute_modes(basis0, epsilon, wavelength=wavelength, num_modes=4)

for mode in modes:
mode.show(mode.E.real, direction="x")
mode.show("E", part="real")

print(f"The effective index of the SiN mode is {np.real(modes[2].n_eff)}")

Expand Down Expand Up @@ -158,7 +158,7 @@
modes = compute_modes(basis0, epsilon, wavelength=wavelength, num_modes=2)

for mode in modes:
mode.show(mode.E.real, direction="x")
mode.show("E", part="real")

print(f"The effective index of the SiN mode is {np.real(modes[0].n_eff)}")

Expand All @@ -180,7 +180,7 @@
modes = compute_modes(basis0, epsilon, wavelength=wavelength, num_modes=2, n_guess=1.62)

for mode in modes:
mode.show(mode.E.real, direction="x")
mode.show("E", part="real")

print(f"The effective index of the SiN mode is {np.real(modes[1].n_eff)}")

Expand Down
10 changes: 3 additions & 7 deletions docs/photonics/examples/waveguide_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,16 @@
modes = compute_modes(basis0, epsilon, wavelength=wavelength, num_modes=2, order=2)
for mode in modes:
print(f"Effective refractive index: {mode.n_eff:.4f}")
mode.show(mode.E.real, colorbar=True, direction="x")
mode.show(mode.E.imag, colorbar=True, direction="x")
mode.show("E", part="real", colorbar=True)
mode.show("E", part="imag", colorbar=True)


# %% [markdown]
# The intensity can be plotted directly from the mode object
# +

# %%
fig, ax = plt.subplots()
modes[0].plot_intensity(ax=ax)
plt.title("Normalized Intensity")
plt.tight_layout()
plt.show()
modes[0].show("I", colorbar=True)

# %% [markdown]
# -
Expand Down
4 changes: 2 additions & 2 deletions femwell/examples/coplanar_waveguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def mesh_coax(filename, radius_inner, radius_outer):
)
print("propagation constants", 1 / modes.n_effs)

modes[0].show(modes[0].E.real, plot_vectors=True)
modes[0].show("E", part="real", plot_vectors=True, colorbar=True)

from skfem import *
from skfem.helpers import *
Expand All @@ -139,7 +139,7 @@ def current_form(w):
currents = np.zeros((len(conductors), len(modes)))

for mode_i, mode in enumerate(modes):
mode.show(mode.H.real, plot_vectors=True)
modes[0].show("H", part="real", plot_vectors=True, colorbar=True)

(ht, ht_basis), (hz, hz_basis) = mode.basis.split(mode.H)
for conductors_i, conductor in enumerate(conductors):
Expand Down
142 changes: 138 additions & 4 deletions femwell/maxwell/waveguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from dataclasses import dataclass
from functools import cached_property
from typing import List, Tuple
from typing import Callable, List, Literal, Optional, Tuple
from warnings import warn

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -281,9 +282,142 @@ def plot(self, field, plot_vectors=False, colorbar=True, direction="y", title="E
direction=direction,
)

def show(self, field, **kwargs):
self.plot(field=field, **kwargs)
plt.show()
def plot_component(
self,
field: Literal["E", "H"],
component: Literal["x", "y", "z", "n", "t"],
part: Literal["real", "imag", "abs"] | Callable = "real",
boundaries: bool = True,
colorbar: bool = False,
ax: Axes = None,
):
from mpl_toolkits.axes_grid1 import make_axes_locatable

if part == "real":
conv_func = np.real
elif part == "imag":
conv_func = np.imag
elif part == "abs":
conv_func = np.abs
elif isinstance(part, Callable):
conv_func = part
else:
raise ValueError("A valid part is 'real', 'imag' or 'abs'.")

if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111)

if field == "E":
mfield = self.E
elif field == "H":
mfield = self.H
else:
raise ValueError("A valid field is 'E' or 'H'.")

(mfield_t, mfield_t_basis), (mfield_n, mfield_n_basis) = self.basis.split(mfield)

if component == "x" or component == "y":
plot_basis = mfield_t_basis.with_element(ElementVector(ElementDG(ElementTriP1())))
mfield_xy = plot_basis.project(mfield_t_basis.interpolate(conv_func(mfield_t)))
(mfield_x, mfield_x_basis), (mfield_y, mfield_y_basis) = plot_basis.split(mfield_xy)
if component == "x":
mfield_x_basis.plot(mfield_x, ax=ax, shading="gouraud")
else:
mfield_y_basis.plot(mfield_y, ax=ax, shading="gouraud")
elif component == "t":
plot_basis = mfield_t_basis
mfield_t_basis.plot(conv_func(mfield_t), ax=ax, shading="gouraud")
elif component == "z" or component == "n":
plot_basis = mfield_n_basis
mfield_z, mfield_z_basis = mfield_n, mfield_n_basis
mfield_z_basis.plot(conv_func(mfield_z), ax=ax, shading="gouraud")
else:
raise ValueError("A valid component is 'x', 'y', 'z', 'n' or 't'.")

if boundaries:
for subdomain in plot_basis.mesh.subdomains.keys() - {"gmsh:bounding_entities"}:
plot_basis.mesh.restrict(subdomain).draw(ax=ax, boundaries_only=True, color="k")
if colorbar:
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(ax.collections[-1], cax=cax)

ax.set_title(f"{field}{component} ({conv_func.__name__}. part)")

return ax

# def show(
# self,
# field: Literal["E", "H"] ,
# part: Literal["real", "imag", "abs"] = "real",
# boundaries: bool = True,
# colorbar: bool = False,
# ):
# fig, axs = plt.subplots(1, 3, subplot_kw=dict(aspect=1))

# for id_ax, comp in enumerate("xyz"):
# self.plot_component(field, comp, part, boundaries, colorbar, axs[id_ax])
# plt.tight_layout()
# plt.show()

def show(
self,
field: Literal["E", "H", "I"] | NDArray,
part: Literal["real", "imag", "abs"] | Callable = "real",
plot_vectors: bool = False,
boundaries: bool = True,
colorbar: bool = False,
direction: Literal["x", "y"] = "x",
title: Optional[str] = None,
):
"""Plots the different quantities associated with a field.
Args:
field ("E", "H", "I"): Field of interest, can be the electric field, the magnetic field or the intensity of the mode.
part ("real", "imag", "abs", callable): Function to use to preprocess the field to be plotted. Defaults to "real".
plot_vectors (bool): If set to True, plot the normal and tangential component
boundaries (bool): Superimpose the mesh boundaries on the plot. Defaults to True.
colorbar (bool): Adds a colorbar to the plot. Defaults to False.
direction ("x", "y"): Orientation of the plots ("x" for horizontal and "y" for vertical) Defaults to "x".
Returns:
Tuple[Figure, Axes]: Figure and axes of the plot.
"""
if type(field) is np.ndarray:
warn(
"The behavior of passing an array directly to `show` "
+ "is deprecated and will be removed in the future. "
+ "Use `plot` instead.",
DeprecationWarning,
stacklevel=2,
)
self.plot(field, plot_vectors, colorbar, direction, title)
plt.show()
else:
from mpl_toolkits.axes_grid1 import make_axes_locatable

if plot_vectors is True:
if field == "I":
return ValueError(
"'plot_vectors' is used to plot the tangential components "
+ "of a field. Thus it can be used only with 'E' or 'H'."
)
rc = (2, 1) if direction != "x" else (1, 2)
fig, axs = plt.subplots(*rc, subplot_kw=dict(aspect=1))

self.plot_component(field, "t", part, boundaries, colorbar, axs[0])
self.plot_component(field, "n", part, boundaries, colorbar, axs[1])
elif field == "I":
fig, ax = self.plot_intensity(ax=None, colorbar=colorbar)
else:
rc = (3, 1) if direction != "x" else (1, 3)
fig, axs = plt.subplots(*rc, subplot_kw=dict(aspect=1))

for id_ax, comp in enumerate("xyz"):
self.plot_component(field, comp, part, boundaries, colorbar, axs[id_ax])
fig.suptitle(title if title else field)
plt.tight_layout()
plt.show()

def plot_intensity(
self,
Expand Down

0 comments on commit eaf36ae

Please sign in to comment.