Skip to content

Commit

Permalink
New Colorwheel class
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Jun 27, 2024
1 parent a44147a commit 9c05dae
Showing 1 changed file with 183 additions and 0 deletions.
183 changes: 183 additions & 0 deletions bnpm/plotting_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import torch
import cv2

from . import spectral
from . import math_functions


###############
#### PLOTS ####
Expand Down Expand Up @@ -774,6 +777,186 @@ def complex_colormap(
return rgb


class Colorwheel:
"""
Generates a 2D colorwheel colormap (magnitude and angle). Useful for
visualizing complex/polar values, optical flow, and other cyclic data.
RH 2024
Args:
rotation (float):
Rotation of the colorwheel in degrees.
saturation (float):
Saturation of the colors.
center (int):
Center of the colorwheel.
radius (int):
Radius of the colorwheel.
dtype (np.dtype):
Data type of the output colormap.
bit_depth (int):
Bit depth of the colorwheel.
exponent (float):
Exponent used to adjust the color intensity.
normalize (bool):
Whether to normalize the colorwheel.
"""
def __init__(
self,
rotation: float = 0.0,
saturation: float = 1.0,
center: int = 0,
radius: int = 255,
dtype: np.dtype = np.uint8,
bit_depth: int = 16,
exponent: float = 1.2,
normalize: bool = True,
colors: List[Union[List, Tuple]] = [
[1 , 0 , 0 ],
[1 , 1 , 0 ],
[0 , 1 , 0 ],
[0 , 1 , 1 ],
[0 , 0 , 1 ],
[1 , 0 , 1 ],
],
):
"""
Initializes the ColorwheelColormap with given parameters.
"""
import scipy.interpolate
import scipy.special

self.rotation = rotation
self.saturation = saturation
self.center = center
self.radius = radius
self.dtype = dtype
self.bit_depth = bit_depth
self.exponent = exponent
self.normalize = normalize
self.colors = np.array(colors)

# Make a rainbow colorwheel
# Create 3 single cosine waves centered at 0, 120, and 240 degrees spanning 120 degrees each
import scipy.signal
waves, x = spectral.generate_multiphasic_sinewave(
n_samples=int(2**bit_depth),
n_periods=1,
n_waves=len(colors),
return_x=True,
)
waves = ((waves + 1).astype(np.float64) / 2) ** exponent

waves = (waves - waves.min()) / (waves.max() - waves.min())

if normalize:
# waves = waves / np.linalg.norm(waves, axis=0, keepdims=True)
waves = waves / np.sum(waves, axis=0, keepdims=True)

waves = (waves * (radius - (1-saturation) * radius) + (1-saturation) * radius)
waves = np.roll(waves, int(rotation * 2**bit_depth / (2*np.pi)), axis=1)

# Create interpolation function
self.fn_interp = scipy.interpolate.interp1d(
x=x,
y=waves,
kind='linear',
axis=1,
bounds_error=False,
fill_value='extrapolate',
)

def __call__(
self,
angles: np.ndarray,
magnitudes: np.ndarray = None,
normalize_magnitudes: bool = True,
) -> np.ndarray:
"""
Outputs colors for a given set of angles and magnitudes.
RH 2024
Args:
angles (np.ndarray):
Array of angles in radians. *Shape: (n_samples,)*
magnitudes (np.ndarray, optional):
Array of magnitudes. *Shape: (n_samples,)*
normalize_magnitudes (bool):
If True, applies min-max normalization to the magnitudes. (Default is ``True``)
Returns:
np.ndarray:
Array with RGB values. *Shape: (n_samples, 3)*
"""
# Normalize the magnitudes
if magnitudes is not None:
if normalize_magnitudes:
magnitudes = (magnitudes - np.min(magnitudes)) / (np.max(magnitudes) - np.min(magnitudes))
magnitudes = np.clip(magnitudes, 0, 1)
else:
magnitudes = np.ones_like(angles)

# Get the saturated color by interpolating the colorwheel
sample_colors = self.fn_interp(angles % (2*np.pi))

# Clip the colors
sample_colors = np.clip(sample_colors, 0, self.radius)

# Project to RGB
rgb = self.colors.T @ sample_colors


# Apply the saturation
rgb = rgb * magnitudes[None, :] + (1 - magnitudes)[None, :] * self.center

# Convert to dtype
rgb = rgb.astype(self.dtype)

return rgb.T

def plot_colorwheel(self, n_samples: int = 100000):
"""
Plots the colorwheel colormap.
RH 2024
Args:
n_samples (int):
Number of samples to plot. (Default is ``100000``)
"""
import matplotlib.pyplot as plt
l = int(np.ceil(n_samples**0.5))
grid = np.meshgrid(np.linspace(-1, 1, l), np.linspace(-1, 1, l), indexing='xy')
grid = grid[0] + 1j*grid[1]
grid = grid.reshape(-1)
angles = np.angle(grid)
magnitudes = np.abs(grid)

mask = magnitudes > 1
magnitudes = np.clip(magnitudes, 0, 1)
colors = self(angles, magnitudes)
colors = np.clip(colors, 0, 255)

im = np.zeros((l, l, 3), dtype=self.dtype)
im[*np.meshgrid(range(l), range(l), indexing='ij')] = colors.reshape(im.shape[:2] + (3,))

fig, axs = plt.subplots(2, 1, figsize=(5, 10))
axs[0].imshow(im)
x = np.linspace(0, 2*np.pi, l)
[axs[1].plot(x, v, color=c) for v, c in zip(self(x).T, self.colors)]
axs[1].set_ylabel('Channel magnitude')
axs[1].set_xlabel('Phase (rads)')

def __repr__(self) -> str:
"""
Returns a string representation of the ColorwheelColormap object.
"""
return (f"ColorwheelColormap(rotation={self.rotation}, "
f"saturation={self.saturation}, center={self.center}, "
f"radius={self.radius}, dtype={self.dtype}, "
f"bit_depth={self.bit_depth}, exponent={self.exponent}, "
f"normalize={self.normalize})")


class Figure_Saver:
"""
Class for saving figures
Expand Down

0 comments on commit 9c05dae

Please sign in to comment.