diff --git a/docs/photonics/examples/effective_area.py b/docs/photonics/examples/effective_area.py index e35d3596..84a4c1fe 100644 --- a/docs/photonics/examples/effective_area.py +++ b/docs/photonics/examples/effective_area.py @@ -95,7 +95,7 @@ for mode in modes: if mode.tm_fraction > 0.5: - # mode.show(np.real(mode.E)) + # mode.show("E", part="real") print(f"Effective refractive index: {mode.n_eff:.4f}") print(f"Effective mode area: {mode.calculate_effective_area(field='y'):.4f}") print(f"Mode transversality: {mode.transversality}") diff --git a/docs/photonics/examples/selecting_modes.py b/docs/photonics/examples/selecting_modes.py index dd57eb32..635722fd 100644 --- a/docs/photonics/examples/selecting_modes.py +++ b/docs/photonics/examples/selecting_modes.py @@ -124,7 +124,7 @@ modes = compute_modes(basis0, epsilon, wavelength=wavelength, num_modes=4) for mode in modes: - mode.show(mode.E.real, direction="x") + mode.show("E", part="real") print(f"The effective index of the SiN mode is {np.real(modes[2].n_eff)}") @@ -158,7 +158,7 @@ modes = compute_modes(basis0, epsilon, wavelength=wavelength, num_modes=2) for mode in modes: - mode.show(mode.E.real, direction="x") + mode.show("E", part="real") print(f"The effective index of the SiN mode is {np.real(modes[0].n_eff)}") @@ -180,7 +180,7 @@ modes = compute_modes(basis0, epsilon, wavelength=wavelength, num_modes=2, n_guess=1.62) for mode in modes: - mode.show(mode.E.real, direction="x") + mode.show("E", part="real") print(f"The effective index of the SiN mode is {np.real(modes[1].n_eff)}") diff --git a/docs/photonics/examples/waveguide_modes.py b/docs/photonics/examples/waveguide_modes.py index 72a5062e..49c0249c 100644 --- a/docs/photonics/examples/waveguide_modes.py +++ b/docs/photonics/examples/waveguide_modes.py @@ -83,8 +83,8 @@ modes = compute_modes(basis0, epsilon, wavelength=wavelength, num_modes=2, order=2) for mode in modes: print(f"Effective refractive index: {mode.n_eff:.4f}") - mode.show(mode.E.real, colorbar=True, direction="x") - mode.show(mode.E.imag, colorbar=True, direction="x") + mode.show("E", part="real", colorbar=True) + mode.show("E", part="imag", colorbar=True) # %% [markdown] @@ -92,11 +92,7 @@ # + # %% -fig, ax = plt.subplots() -modes[0].plot_intensity(ax=ax) -plt.title("Normalized Intensity") -plt.tight_layout() -plt.show() +modes[0].show("I", colorbar=True) # %% [markdown] # - diff --git a/femwell/examples/coplanar_waveguide.py b/femwell/examples/coplanar_waveguide.py index 47dae27b..a1fb39cf 100644 --- a/femwell/examples/coplanar_waveguide.py +++ b/femwell/examples/coplanar_waveguide.py @@ -127,7 +127,7 @@ def mesh_coax(filename, radius_inner, radius_outer): ) print("propagation constants", 1 / modes.n_effs) - modes[0].show(modes[0].E.real, plot_vectors=True) + modes[0].show("E", part="real", plot_vectors=True, colorbar=True) from skfem import * from skfem.helpers import * @@ -139,7 +139,7 @@ def current_form(w): currents = np.zeros((len(conductors), len(modes))) for mode_i, mode in enumerate(modes): - mode.show(mode.H.real, plot_vectors=True) + modes[0].show("H", part="real", plot_vectors=True, colorbar=True) (ht, ht_basis), (hz, hz_basis) = mode.basis.split(mode.H) for conductors_i, conductor in enumerate(conductors): diff --git a/femwell/maxwell/waveguide.py b/femwell/maxwell/waveguide.py index 204b2d74..0b4b124b 100644 --- a/femwell/maxwell/waveguide.py +++ b/femwell/maxwell/waveguide.py @@ -2,7 +2,8 @@ from dataclasses import dataclass from functools import cached_property -from typing import List, Tuple +from typing import Callable, List, Literal, Optional, Tuple +from warnings import warn import matplotlib.pyplot as plt import numpy as np @@ -281,9 +282,142 @@ def plot(self, field, plot_vectors=False, colorbar=True, direction="y", title="E direction=direction, ) - def show(self, field, **kwargs): - self.plot(field=field, **kwargs) - plt.show() + def plot_component( + self, + field: Literal["E", "H"], + component: Literal["x", "y", "z", "n", "t"], + part: Literal["real", "imag", "abs"] | Callable = "real", + boundaries: bool = True, + colorbar: bool = False, + ax: Axes = None, + ): + from mpl_toolkits.axes_grid1 import make_axes_locatable + + if part == "real": + conv_func = np.real + elif part == "imag": + conv_func = np.imag + elif part == "abs": + conv_func = np.abs + elif isinstance(part, Callable): + conv_func = part + else: + raise ValueError("A valid part is 'real', 'imag' or 'abs'.") + + if ax is None: + fig = plt.figure() + ax = fig.add_subplot(111) + + if field == "E": + mfield = self.E + elif field == "H": + mfield = self.H + else: + raise ValueError("A valid field is 'E' or 'H'.") + + (mfield_t, mfield_t_basis), (mfield_n, mfield_n_basis) = self.basis.split(mfield) + + if component == "x" or component == "y": + plot_basis = mfield_t_basis.with_element(ElementVector(ElementDG(ElementTriP1()))) + mfield_xy = plot_basis.project(mfield_t_basis.interpolate(conv_func(mfield_t))) + (mfield_x, mfield_x_basis), (mfield_y, mfield_y_basis) = plot_basis.split(mfield_xy) + if component == "x": + mfield_x_basis.plot(mfield_x, ax=ax, shading="gouraud") + else: + mfield_y_basis.plot(mfield_y, ax=ax, shading="gouraud") + elif component == "t": + plot_basis = mfield_t_basis + mfield_t_basis.plot(conv_func(mfield_t), ax=ax, shading="gouraud") + elif component == "z" or component == "n": + plot_basis = mfield_n_basis + mfield_z, mfield_z_basis = mfield_n, mfield_n_basis + mfield_z_basis.plot(conv_func(mfield_z), ax=ax, shading="gouraud") + else: + raise ValueError("A valid component is 'x', 'y', 'z', 'n' or 't'.") + + if boundaries: + for subdomain in plot_basis.mesh.subdomains.keys() - {"gmsh:bounding_entities"}: + plot_basis.mesh.restrict(subdomain).draw(ax=ax, boundaries_only=True, color="k") + if colorbar: + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + plt.colorbar(ax.collections[-1], cax=cax) + + ax.set_title(f"{field}{component} ({conv_func.__name__}. part)") + + return ax + + # def show( + # self, + # field: Literal["E", "H"] , + # part: Literal["real", "imag", "abs"] = "real", + # boundaries: bool = True, + # colorbar: bool = False, + # ): + # fig, axs = plt.subplots(1, 3, subplot_kw=dict(aspect=1)) + + # for id_ax, comp in enumerate("xyz"): + # self.plot_component(field, comp, part, boundaries, colorbar, axs[id_ax]) + # plt.tight_layout() + # plt.show() + + def show( + self, + field: Literal["E", "H", "I"] | NDArray, + part: Literal["real", "imag", "abs"] | Callable = "real", + plot_vectors: bool = False, + boundaries: bool = True, + colorbar: bool = False, + direction: Literal["x", "y"] = "x", + title: Optional[str] = None, + ): + """Plots the different quantities associated with a field. + + Args: + field ("E", "H", "I"): Field of interest, can be the electric field, the magnetic field or the intensity of the mode. + part ("real", "imag", "abs", callable): Function to use to preprocess the field to be plotted. Defaults to "real". + plot_vectors (bool): If set to True, plot the normal and tangential component + boundaries (bool): Superimpose the mesh boundaries on the plot. Defaults to True. + colorbar (bool): Adds a colorbar to the plot. Defaults to False. + direction ("x", "y"): Orientation of the plots ("x" for horizontal and "y" for vertical) Defaults to "x". + Returns: + Tuple[Figure, Axes]: Figure and axes of the plot. + """ + if type(field) is np.ndarray: + warn( + "The behavior of passing an array directly to `show` " + + "is deprecated and will be removed in the future. " + + "Use `plot` instead.", + DeprecationWarning, + stacklevel=2, + ) + self.plot(field, plot_vectors, colorbar, direction, title) + plt.show() + else: + from mpl_toolkits.axes_grid1 import make_axes_locatable + + if plot_vectors is True: + if field == "I": + return ValueError( + "'plot_vectors' is used to plot the tangential components " + + "of a field. Thus it can be used only with 'E' or 'H'." + ) + rc = (2, 1) if direction != "x" else (1, 2) + fig, axs = plt.subplots(*rc, subplot_kw=dict(aspect=1)) + + self.plot_component(field, "t", part, boundaries, colorbar, axs[0]) + self.plot_component(field, "n", part, boundaries, colorbar, axs[1]) + elif field == "I": + fig, ax = self.plot_intensity(ax=None, colorbar=colorbar) + else: + rc = (3, 1) if direction != "x" else (1, 3) + fig, axs = plt.subplots(*rc, subplot_kw=dict(aspect=1)) + + for id_ax, comp in enumerate("xyz"): + self.plot_component(field, comp, part, boundaries, colorbar, axs[id_ax]) + fig.suptitle(title if title else field) + plt.tight_layout() + plt.show() def plot_intensity( self,