From 9c05daee5759722337d40491be0ca2fa30b02b2a Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Thu, 27 Jun 2024 16:30:25 -0400 Subject: [PATCH] New Colorwheel class --- bnpm/plotting_helpers.py | 183 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) diff --git a/bnpm/plotting_helpers.py b/bnpm/plotting_helpers.py index 99e6706..128be49 100644 --- a/bnpm/plotting_helpers.py +++ b/bnpm/plotting_helpers.py @@ -12,6 +12,9 @@ import torch import cv2 +from . import spectral +from . import math_functions + ############### #### PLOTS #### @@ -774,6 +777,186 @@ def complex_colormap( return rgb +class Colorwheel: + """ + Generates a 2D colorwheel colormap (magnitude and angle). Useful for + visualizing complex/polar values, optical flow, and other cyclic data. + RH 2024 + + Args: + rotation (float): + Rotation of the colorwheel in degrees. + saturation (float): + Saturation of the colors. + center (int): + Center of the colorwheel. + radius (int): + Radius of the colorwheel. + dtype (np.dtype): + Data type of the output colormap. + bit_depth (int): + Bit depth of the colorwheel. + exponent (float): + Exponent used to adjust the color intensity. + normalize (bool): + Whether to normalize the colorwheel. + """ + def __init__( + self, + rotation: float = 0.0, + saturation: float = 1.0, + center: int = 0, + radius: int = 255, + dtype: np.dtype = np.uint8, + bit_depth: int = 16, + exponent: float = 1.2, + normalize: bool = True, + colors: List[Union[List, Tuple]] = [ + [1 , 0 , 0 ], + [1 , 1 , 0 ], + [0 , 1 , 0 ], + [0 , 1 , 1 ], + [0 , 0 , 1 ], + [1 , 0 , 1 ], + ], + ): + """ + Initializes the ColorwheelColormap with given parameters. + """ + import scipy.interpolate + import scipy.special + + self.rotation = rotation + self.saturation = saturation + self.center = center + self.radius = radius + self.dtype = dtype + self.bit_depth = bit_depth + self.exponent = exponent + self.normalize = normalize + self.colors = np.array(colors) + + # Make a rainbow colorwheel + # Create 3 single cosine waves centered at 0, 120, and 240 degrees spanning 120 degrees each + import scipy.signal + waves, x = spectral.generate_multiphasic_sinewave( + n_samples=int(2**bit_depth), + n_periods=1, + n_waves=len(colors), + return_x=True, + ) + waves = ((waves + 1).astype(np.float64) / 2) ** exponent + + waves = (waves - waves.min()) / (waves.max() - waves.min()) + + if normalize: + # waves = waves / np.linalg.norm(waves, axis=0, keepdims=True) + waves = waves / np.sum(waves, axis=0, keepdims=True) + + waves = (waves * (radius - (1-saturation) * radius) + (1-saturation) * radius) + waves = np.roll(waves, int(rotation * 2**bit_depth / (2*np.pi)), axis=1) + + # Create interpolation function + self.fn_interp = scipy.interpolate.interp1d( + x=x, + y=waves, + kind='linear', + axis=1, + bounds_error=False, + fill_value='extrapolate', + ) + + def __call__( + self, + angles: np.ndarray, + magnitudes: np.ndarray = None, + normalize_magnitudes: bool = True, + ) -> np.ndarray: + """ + Outputs colors for a given set of angles and magnitudes. + RH 2024 + + Args: + angles (np.ndarray): + Array of angles in radians. *Shape: (n_samples,)* + magnitudes (np.ndarray, optional): + Array of magnitudes. *Shape: (n_samples,)* + normalize_magnitudes (bool): + If True, applies min-max normalization to the magnitudes. (Default is ``True``) + + Returns: + np.ndarray: + Array with RGB values. *Shape: (n_samples, 3)* + """ + # Normalize the magnitudes + if magnitudes is not None: + if normalize_magnitudes: + magnitudes = (magnitudes - np.min(magnitudes)) / (np.max(magnitudes) - np.min(magnitudes)) + magnitudes = np.clip(magnitudes, 0, 1) + else: + magnitudes = np.ones_like(angles) + + # Get the saturated color by interpolating the colorwheel + sample_colors = self.fn_interp(angles % (2*np.pi)) + + # Clip the colors + sample_colors = np.clip(sample_colors, 0, self.radius) + + # Project to RGB + rgb = self.colors.T @ sample_colors + + + # Apply the saturation + rgb = rgb * magnitudes[None, :] + (1 - magnitudes)[None, :] * self.center + + # Convert to dtype + rgb = rgb.astype(self.dtype) + + return rgb.T + + def plot_colorwheel(self, n_samples: int = 100000): + """ + Plots the colorwheel colormap. + RH 2024 + + Args: + n_samples (int): + Number of samples to plot. (Default is ``100000``) + """ + import matplotlib.pyplot as plt + l = int(np.ceil(n_samples**0.5)) + grid = np.meshgrid(np.linspace(-1, 1, l), np.linspace(-1, 1, l), indexing='xy') + grid = grid[0] + 1j*grid[1] + grid = grid.reshape(-1) + angles = np.angle(grid) + magnitudes = np.abs(grid) + + mask = magnitudes > 1 + magnitudes = np.clip(magnitudes, 0, 1) + colors = self(angles, magnitudes) + colors = np.clip(colors, 0, 255) + + im = np.zeros((l, l, 3), dtype=self.dtype) + im[*np.meshgrid(range(l), range(l), indexing='ij')] = colors.reshape(im.shape[:2] + (3,)) + + fig, axs = plt.subplots(2, 1, figsize=(5, 10)) + axs[0].imshow(im) + x = np.linspace(0, 2*np.pi, l) + [axs[1].plot(x, v, color=c) for v, c in zip(self(x).T, self.colors)] + axs[1].set_ylabel('Channel magnitude') + axs[1].set_xlabel('Phase (rads)') + + def __repr__(self) -> str: + """ + Returns a string representation of the ColorwheelColormap object. + """ + return (f"ColorwheelColormap(rotation={self.rotation}, " + f"saturation={self.saturation}, center={self.center}, " + f"radius={self.radius}, dtype={self.dtype}, " + f"bit_depth={self.bit_depth}, exponent={self.exponent}, " + f"normalize={self.normalize})") + + class Figure_Saver: """ Class for saving figures