
Commit

[FEAT][KAN]
Kye committed May 6, 2024
1 parent 78af525 commit 333eaa8
Showing 4 changed files with 513 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "2.4.3"
version = "2.4.5"
description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
authors = ["Zeta Team <[email protected]>"]
license = "MIT"
2 changes: 2 additions & 0 deletions zeta/nn/modules/__init__.py
@@ -213,6 +213,7 @@

from zeta.nn.modules.query_proposal import TextHawkQueryProposal
from zeta.nn.modules.pixel_shuffling import PixelShuffleDownscale
from zeta.nn.modules.kan import KAN

# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
@@ -424,4 +425,5 @@
"video_patch_linear_flatten",
"TextHawkQueryProposal",
"PixelShuffleDownscale",
"KAN",
]
362 changes: 362 additions & 0 deletions zeta/nn/modules/kan.py
@@ -0,0 +1,362 @@
import torch
import torch.nn.functional as F
import math
from typing import List


class KANLinear(torch.nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
grid_size: int = 5,
spline_order: int = 3,
scale_noise: float = 0.1,
scale_base: float = 1.0,
scale_spline: float = 1.0,
enable_standalone_scale_spline: bool = True,
base_activation: torch.nn.Module = torch.nn.SiLU,
grid_eps: float = 0.02,
grid_range: List[float] = [-1, 1],
):
super(KANLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.grid_size = grid_size
self.spline_order = spline_order

h = (grid_range[1] - grid_range[0]) / grid_size
grid = (
(
torch.arange(-spline_order, grid_size + spline_order + 1) * h
+ grid_range[0]
)
.expand(in_features, -1)
.contiguous()
)
self.register_buffer("grid", grid)

self.base_weight = torch.nn.Parameter(
torch.Tensor(out_features, in_features)
)
self.spline_weight = torch.nn.Parameter(
torch.Tensor(out_features, in_features, grid_size + spline_order)
)
if enable_standalone_scale_spline:
self.spline_scaler = torch.nn.Parameter(
torch.Tensor(out_features, in_features)
)

self.scale_noise = scale_noise
self.scale_base = scale_base
self.scale_spline = scale_spline
self.enable_standalone_scale_spline = enable_standalone_scale_spline
self.base_activation = base_activation()
self.grid_eps = grid_eps

self.reset_parameters()

def reset_parameters(self):
torch.nn.init.kaiming_uniform_(
self.base_weight, a=math.sqrt(5) * self.scale_base
)
with torch.no_grad():
noise = (
(
torch.rand(
self.grid_size + 1, self.in_features, self.out_features
)
- 1 / 2
)
* self.scale_noise
/ self.grid_size
)
self.spline_weight.data.copy_(
(
self.scale_spline
if not self.enable_standalone_scale_spline
else 1.0
)
* self.curve2coeff(
self.grid.T[self.spline_order : -self.spline_order],
noise,
)
)
if self.enable_standalone_scale_spline:
# torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
torch.nn.init.kaiming_uniform_(
self.spline_scaler, a=math.sqrt(5) * self.scale_spline
)

def b_splines(self, x: torch.Tensor):
"""
Compute the B-spline bases for the given input tensor.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
Returns:
torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features

grid: torch.Tensor = (
self.grid
) # (in_features, grid_size + 2 * spline_order + 1)
x = x.unsqueeze(-1)
bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
for k in range(1, self.spline_order + 1):
bases = (
(x - grid[:, : -(k + 1)])
/ (grid[:, k:-1] - grid[:, : -(k + 1)])
* bases[:, :, :-1]
) + (
(grid[:, k + 1 :] - x)
/ (grid[:, k + 1 :] - grid[:, 1:(-k)])
* bases[:, :, 1:]
)

assert bases.size() == (
x.size(0),
self.in_features,
self.grid_size + self.spline_order,
)
return bases.contiguous()
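    # Illustrative shape walk-through (a sketch, assuming the default
    # grid_size=5 and spline_order=3): for x of shape (batch, in_features),
    # the order-0 indicator bases above have shape
    # (batch, in_features, grid_size + 2 * spline_order); each refinement step
    # drops one basis, so after spline_order steps the returned tensor has
    # shape (batch, in_features, grid_size + spline_order) = (batch, in_features, 8).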

def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
"""
Compute the coefficients of the curve that interpolates the given points.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).
Returns:
torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features
assert y.size() == (x.size(0), self.in_features, self.out_features)

A = self.b_splines(x).transpose(
0, 1
) # (in_features, batch_size, grid_size + spline_order)
B = y.transpose(0, 1) # (in_features, batch_size, out_features)
solution = torch.linalg.lstsq(
A, B
).solution # (in_features, grid_size + spline_order, out_features)
result = solution.permute(
2, 0, 1
) # (out_features, in_features, grid_size + spline_order)

assert result.size() == (
self.out_features,
self.in_features,
self.grid_size + self.spline_order,
)
return result.contiguous()
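    # Equivalently: for each input feature i this solves a least-squares
    # system A_i @ C_i ~= B_i, where A_i (batch, grid_size + spline_order) is
    # the B-spline basis evaluated at the sample points, B_i (batch, out_features)
    # holds the target values, and C_i are the spline coefficients that are
    # returned (after permuting to (out_features, in_features, coeff)).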

@property
def scaled_spline_weight(self):
return self.spline_weight * (
self.spline_scaler.unsqueeze(-1)
if self.enable_standalone_scale_spline
else 1.0
)

def forward(self, x: torch.Tensor):
assert x.dim() == 2 and x.size(1) == self.in_features

base_output = F.linear(self.base_activation(x), self.base_weight)
spline_output = F.linear(
self.b_splines(x).view(x.size(0), -1),
self.scaled_spline_weight.view(self.out_features, -1),
)
return base_output + spline_output

@torch.no_grad()
def update_grid(self, x: torch.Tensor, margin=0.01):
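        """
        Re-fit the spline grid to the distribution of the activations in x.
        An adaptive (quantile-based) grid is blended with a uniform grid
        according to grid_eps, and the spline weights are then refit so the
        layer keeps producing approximately the same outputs on x.
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            margin (float, optional): Padding added around the observed range
                when building the uniform grid. Defaults to 0.01.
        """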
assert x.dim() == 2 and x.size(1) == self.in_features
batch = x.size(0)

splines = self.b_splines(x) # (batch, in, coeff)
splines = splines.permute(1, 0, 2) # (in, batch, coeff)
orig_coeff = self.scaled_spline_weight # (out, in, coeff)
orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out)
unreduced_spline_output = torch.bmm(
splines, orig_coeff
) # (in, batch, out)
unreduced_spline_output = unreduced_spline_output.permute(
1, 0, 2
) # (batch, in, out)

# sort each channel individually to collect data distribution
x_sorted = torch.sort(x, dim=0)[0]
grid_adaptive = x_sorted[
torch.linspace(
0,
batch - 1,
self.grid_size + 1,
dtype=torch.int64,
device=x.device,
)
]

uniform_step = (
x_sorted[-1] - x_sorted[0] + 2 * margin
) / self.grid_size
grid_uniform = (
torch.arange(
self.grid_size + 1, dtype=torch.float32, device=x.device
).unsqueeze(1)
* uniform_step
+ x_sorted[0]
- margin
)

grid = (
self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
)
grid = torch.concatenate(
[
grid[:1]
- uniform_step
* torch.arange(
self.spline_order, 0, -1, device=x.device
).unsqueeze(1),
grid,
grid[-1:]
+ uniform_step
* torch.arange(
1, self.spline_order + 1, device=x.device
).unsqueeze(1),
],
dim=0,
)

self.grid.copy_(grid.T)
self.spline_weight.data.copy_(
self.curve2coeff(x, unreduced_spline_output)
)

def regularization_loss(
self, regularize_activation=1.0, regularize_entropy=1.0
):
"""
Compute the regularization loss.
        This is a rough stand-in for the L1 regularization described in the
        paper: the original formulation needs absolute values and an entropy
        computed from the expanded (batch, in_features, out_features)
        intermediate tensor, which stays hidden inside F.linear in a
        memory-efficient implementation. The L1 term is therefore approximated
        as the mean absolute value of the spline weights. The authors'
        implementation also includes this term in addition to the
        sample-based regularization.
"""
l1_fake = self.spline_weight.abs().mean(-1)
regularization_loss_activation = l1_fake.sum()
p = l1_fake / regularization_loss_activation
regularization_loss_entropy = -torch.sum(p * p.log())
return (
regularize_activation * regularization_loss_activation
+ regularize_entropy * regularization_loss_entropy
)
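    # Here l1_fake has shape (out_features, in_features): one activation
    # strength per edge, taken as the mean |spline coefficient| on that edge.
    # p normalizes these strengths into a distribution over edges, so the
    # entropy term favors spreading capacity across edges instead of
    # concentrating it in a few.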


class KAN(torch.nn.Module):
"""
    KAN (Kolmogorov-Arnold Network) module.
Args:
        layers_hidden (list): Layer widths from the input dimension through any hidden layers to the output dimension, e.g. [2, 3, 1].
grid_size (int, optional): Size of the grid. Defaults to 5.
spline_order (int, optional): Order of the spline. Defaults to 3.
scale_noise (float, optional): Scale factor for the noise. Defaults to 0.1.
scale_base (float, optional): Scale factor for the base. Defaults to 1.0.
scale_spline (float, optional): Scale factor for the spline. Defaults to 1.0.
base_activation (torch.nn.Module, optional): Activation function for the base. Defaults to torch.nn.SiLU.
grid_eps (float, optional): Epsilon value for the grid. Defaults to 0.02.
grid_range (list, optional): Range of the grid. Defaults to [-1, 1].
Example:
>>> kan = KAN([2, 3, 1])
>>> x = torch.randn(10, 2)
>>> y = kan(x)
"""

def __init__(
self,
layers_hidden: List[int],
grid_size: int = 5,
spline_order: int = 3,
scale_noise: float = 0.1,
scale_base: float = 1.0,
scale_spline: float = 1.0,
base_activation: torch.nn.Module = torch.nn.SiLU,
grid_eps: float = 0.02,
grid_range: List[float] = [-1, 1],
) -> None:
super(KAN, self).__init__()
self.grid_size = grid_size
self.spline_order = spline_order

self.layers = torch.nn.ModuleList()
for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
self.layers.append(
KANLinear(
in_features,
out_features,
grid_size=grid_size,
spline_order=spline_order,
scale_noise=scale_noise,
scale_base=scale_base,
scale_spline=scale_spline,
base_activation=base_activation,
grid_eps=grid_eps,
grid_range=grid_range,
)
)

def forward(self, x: torch.Tensor, update_grid=False):
"""
Forward pass of the KAN module.
Args:
x (torch.Tensor): Input tensor.
update_grid (bool, optional): Whether to update the grid. Defaults to False.
Returns:
torch.Tensor: Output tensor.
"""
for layer in self.layers:
if update_grid:
layer.update_grid(x)
x = layer(x)
return x

def regularization_loss(
self, regularize_activation=1.0, regularize_entropy=1.0
):
"""
Compute the regularization loss of the KAN module.
Args:
regularize_activation (float, optional): Regularization factor for activation. Defaults to 1.0.
regularize_entropy (float, optional): Regularization factor for entropy. Defaults to 1.0.
Returns:
torch.Tensor: Regularization loss.
"""
return sum(
layer.regularization_loss(regularize_activation, regularize_entropy)
for layer in self.layers
)


# x = torch.randn(10, 2)  # input must be 2D: (batch_size, in_features)
# kan = KAN(
# layers_hidden=[2, 3, 1],
# )
# y = kan(x)
# print(y)
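
# A minimal training-step sketch (commented out; the model sizes and the 1e-5
# regularization weight are illustrative, not part of this module): the model
# is occasionally called with update_grid=True to re-fit the spline grids to
# the current activations, and regularization_loss() is added to the task loss.
#
# model = KAN([2, 16, 1])
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# x, target = torch.randn(64, 2), torch.randn(64, 1)
# for step in range(100):
#     pred = model(x, update_grid=(step % 20 == 0))
#     loss = F.mse_loss(pred, target) + 1e-5 * model.regularization_loss()
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()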