-
-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Kye
committed
May 6, 2024
1 parent
78af525
commit 333eaa8
Showing
4 changed files
with
513 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[tool.poetry] | ||
name = "zetascale" | ||
version = "2.4.3" | ||
version = "2.4.5" | ||
description = "Rapidly Build, Optimize, and Deploy SOTA AI Models" | ||
authors = ["Zeta Team <[email protected]>"] | ||
license = "MIT" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,362 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
import math | ||
from typing import List | ||
|
||
|
||
class KANLinear(torch.nn.Module): | ||
def __init__( | ||
self, | ||
in_features: int, | ||
out_features: int, | ||
grid_size: int = 5, | ||
spline_order: int = 3, | ||
scale_noise: float = 0.1, | ||
scale_base: float = 1.0, | ||
scale_spline: float = 1.0, | ||
enable_standalone_scale_spline: bool = True, | ||
base_activation: torch.nn.Module = torch.nn.SiLU, | ||
grid_eps: float = 0.02, | ||
grid_range: List[float] = [-1, 1], | ||
): | ||
super(KANLinear, self).__init__() | ||
self.in_features = in_features | ||
self.out_features = out_features | ||
self.grid_size = grid_size | ||
self.spline_order = spline_order | ||
|
||
h = (grid_range[1] - grid_range[0]) / grid_size | ||
grid = ( | ||
( | ||
torch.arange(-spline_order, grid_size + spline_order + 1) * h | ||
+ grid_range[0] | ||
) | ||
.expand(in_features, -1) | ||
.contiguous() | ||
) | ||
self.register_buffer("grid", grid) | ||
|
||
self.base_weight = torch.nn.Parameter( | ||
torch.Tensor(out_features, in_features) | ||
) | ||
self.spline_weight = torch.nn.Parameter( | ||
torch.Tensor(out_features, in_features, grid_size + spline_order) | ||
) | ||
if enable_standalone_scale_spline: | ||
self.spline_scaler = torch.nn.Parameter( | ||
torch.Tensor(out_features, in_features) | ||
) | ||
|
||
self.scale_noise = scale_noise | ||
self.scale_base = scale_base | ||
self.scale_spline = scale_spline | ||
self.enable_standalone_scale_spline = enable_standalone_scale_spline | ||
self.base_activation = base_activation() | ||
self.grid_eps = grid_eps | ||
|
||
self.reset_parameters() | ||
|
||
def reset_parameters(self): | ||
torch.nn.init.kaiming_uniform_( | ||
self.base_weight, a=math.sqrt(5) * self.scale_base | ||
) | ||
with torch.no_grad(): | ||
noise = ( | ||
( | ||
torch.rand( | ||
self.grid_size + 1, self.in_features, self.out_features | ||
) | ||
- 1 / 2 | ||
) | ||
* self.scale_noise | ||
/ self.grid_size | ||
) | ||
self.spline_weight.data.copy_( | ||
( | ||
self.scale_spline | ||
if not self.enable_standalone_scale_spline | ||
else 1.0 | ||
) | ||
* self.curve2coeff( | ||
self.grid.T[self.spline_order : -self.spline_order], | ||
noise, | ||
) | ||
) | ||
if self.enable_standalone_scale_spline: | ||
# torch.nn.init.constant_(self.spline_scaler, self.scale_spline) | ||
torch.nn.init.kaiming_uniform_( | ||
self.spline_scaler, a=math.sqrt(5) * self.scale_spline | ||
) | ||
|
||
def b_splines(self, x: torch.Tensor): | ||
""" | ||
Compute the B-spline bases for the given input tensor. | ||
Args: | ||
x (torch.Tensor): Input tensor of shape (batch_size, in_features). | ||
Returns: | ||
torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). | ||
""" | ||
assert x.dim() == 2 and x.size(1) == self.in_features | ||
|
||
grid: torch.Tensor = ( | ||
self.grid | ||
) # (in_features, grid_size + 2 * spline_order + 1) | ||
x = x.unsqueeze(-1) | ||
bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype) | ||
for k in range(1, self.spline_order + 1): | ||
bases = ( | ||
(x - grid[:, : -(k + 1)]) | ||
/ (grid[:, k:-1] - grid[:, : -(k + 1)]) | ||
* bases[:, :, :-1] | ||
) + ( | ||
(grid[:, k + 1 :] - x) | ||
/ (grid[:, k + 1 :] - grid[:, 1:(-k)]) | ||
* bases[:, :, 1:] | ||
) | ||
|
||
assert bases.size() == ( | ||
x.size(0), | ||
self.in_features, | ||
self.grid_size + self.spline_order, | ||
) | ||
return bases.contiguous() | ||
|
||
def curve2coeff(self, x: torch.Tensor, y: torch.Tensor): | ||
""" | ||
Compute the coefficients of the curve that interpolates the given points. | ||
Args: | ||
x (torch.Tensor): Input tensor of shape (batch_size, in_features). | ||
y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features). | ||
Returns: | ||
torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order). | ||
""" | ||
assert x.dim() == 2 and x.size(1) == self.in_features | ||
assert y.size() == (x.size(0), self.in_features, self.out_features) | ||
|
||
A = self.b_splines(x).transpose( | ||
0, 1 | ||
) # (in_features, batch_size, grid_size + spline_order) | ||
B = y.transpose(0, 1) # (in_features, batch_size, out_features) | ||
solution = torch.linalg.lstsq( | ||
A, B | ||
).solution # (in_features, grid_size + spline_order, out_features) | ||
result = solution.permute( | ||
2, 0, 1 | ||
) # (out_features, in_features, grid_size + spline_order) | ||
|
||
assert result.size() == ( | ||
self.out_features, | ||
self.in_features, | ||
self.grid_size + self.spline_order, | ||
) | ||
return result.contiguous() | ||
|
||
@property | ||
def scaled_spline_weight(self): | ||
return self.spline_weight * ( | ||
self.spline_scaler.unsqueeze(-1) | ||
if self.enable_standalone_scale_spline | ||
else 1.0 | ||
) | ||
|
||
def forward(self, x: torch.Tensor): | ||
assert x.dim() == 2 and x.size(1) == self.in_features | ||
|
||
base_output = F.linear(self.base_activation(x), self.base_weight) | ||
spline_output = F.linear( | ||
self.b_splines(x).view(x.size(0), -1), | ||
self.scaled_spline_weight.view(self.out_features, -1), | ||
) | ||
return base_output + spline_output | ||
|
||
@torch.no_grad() | ||
def update_grid(self, x: torch.Tensor, margin=0.01): | ||
assert x.dim() == 2 and x.size(1) == self.in_features | ||
batch = x.size(0) | ||
|
||
splines = self.b_splines(x) # (batch, in, coeff) | ||
splines = splines.permute(1, 0, 2) # (in, batch, coeff) | ||
orig_coeff = self.scaled_spline_weight # (out, in, coeff) | ||
orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) | ||
unreduced_spline_output = torch.bmm( | ||
splines, orig_coeff | ||
) # (in, batch, out) | ||
unreduced_spline_output = unreduced_spline_output.permute( | ||
1, 0, 2 | ||
) # (batch, in, out) | ||
|
||
# sort each channel individually to collect data distribution | ||
x_sorted = torch.sort(x, dim=0)[0] | ||
grid_adaptive = x_sorted[ | ||
torch.linspace( | ||
0, | ||
batch - 1, | ||
self.grid_size + 1, | ||
dtype=torch.int64, | ||
device=x.device, | ||
) | ||
] | ||
|
||
uniform_step = ( | ||
x_sorted[-1] - x_sorted[0] + 2 * margin | ||
) / self.grid_size | ||
grid_uniform = ( | ||
torch.arange( | ||
self.grid_size + 1, dtype=torch.float32, device=x.device | ||
).unsqueeze(1) | ||
* uniform_step | ||
+ x_sorted[0] | ||
- margin | ||
) | ||
|
||
grid = ( | ||
self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive | ||
) | ||
grid = torch.concatenate( | ||
[ | ||
grid[:1] | ||
- uniform_step | ||
* torch.arange( | ||
self.spline_order, 0, -1, device=x.device | ||
).unsqueeze(1), | ||
grid, | ||
grid[-1:] | ||
+ uniform_step | ||
* torch.arange( | ||
1, self.spline_order + 1, device=x.device | ||
).unsqueeze(1), | ||
], | ||
dim=0, | ||
) | ||
|
||
self.grid.copy_(grid.T) | ||
self.spline_weight.data.copy_( | ||
self.curve2coeff(x, unreduced_spline_output) | ||
) | ||
|
||
def regularization_loss( | ||
self, regularize_activation=1.0, regularize_entropy=1.0 | ||
): | ||
""" | ||
Compute the regularization loss. | ||
This is a dumb simulation of the original L1 regularization as stated in the | ||
paper, since the original one requires computing absolutes and entropy from the | ||
expanded (batch, in_features, out_features) intermediate tensor, which is hidden | ||
behind the F.linear function if we want an memory efficient implementation. | ||
The L1 regularization is now computed as mean absolute value of the spline | ||
weights. The authors implementation also includes this term in addition to the | ||
sample-based regularization. | ||
""" | ||
l1_fake = self.spline_weight.abs().mean(-1) | ||
regularization_loss_activation = l1_fake.sum() | ||
p = l1_fake / regularization_loss_activation | ||
regularization_loss_entropy = -torch.sum(p * p.log()) | ||
return ( | ||
regularize_activation * regularization_loss_activation | ||
+ regularize_entropy * regularization_loss_entropy | ||
) | ||
|
||
|
||
class KAN(torch.nn.Module): | ||
""" | ||
KAN (Kernel Activation Network) module. | ||
Args: | ||
layers_hidden (list): List of integers representing the number of hidden units in each layer. | ||
grid_size (int, optional): Size of the grid. Defaults to 5. | ||
spline_order (int, optional): Order of the spline. Defaults to 3. | ||
scale_noise (float, optional): Scale factor for the noise. Defaults to 0.1. | ||
scale_base (float, optional): Scale factor for the base. Defaults to 1.0. | ||
scale_spline (float, optional): Scale factor for the spline. Defaults to 1.0. | ||
base_activation (torch.nn.Module, optional): Activation function for the base. Defaults to torch.nn.SiLU. | ||
grid_eps (float, optional): Epsilon value for the grid. Defaults to 0.02. | ||
grid_range (list, optional): Range of the grid. Defaults to [-1, 1]. | ||
Example: | ||
>>> kan = KAN([2, 3, 1]) | ||
>>> x = torch.randn(10, 2) | ||
>>> y = kan(x) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
layers_hidden: List[int], | ||
grid_size: int = 5, | ||
spline_order: int = 3, | ||
scale_noise: float = 0.1, | ||
scale_base: float = 1.0, | ||
scale_spline: float = 1.0, | ||
base_activation: torch.nn.Module = torch.nn.SiLU, | ||
grid_eps: float = 0.02, | ||
grid_range: List[float] = [-1, 1], | ||
) -> None: | ||
super(KAN, self).__init__() | ||
self.grid_size = grid_size | ||
self.spline_order = spline_order | ||
|
||
self.layers = torch.nn.ModuleList() | ||
for in_features, out_features in zip(layers_hidden, layers_hidden[1:]): | ||
self.layers.append( | ||
KANLinear( | ||
in_features, | ||
out_features, | ||
grid_size=grid_size, | ||
spline_order=spline_order, | ||
scale_noise=scale_noise, | ||
scale_base=scale_base, | ||
scale_spline=scale_spline, | ||
base_activation=base_activation, | ||
grid_eps=grid_eps, | ||
grid_range=grid_range, | ||
) | ||
) | ||
|
||
def forward(self, x: torch.Tensor, update_grid=False): | ||
""" | ||
Forward pass of the KAN module. | ||
Args: | ||
x (torch.Tensor): Input tensor. | ||
update_grid (bool, optional): Whether to update the grid. Defaults to False. | ||
Returns: | ||
torch.Tensor: Output tensor. | ||
""" | ||
for layer in self.layers: | ||
if update_grid: | ||
layer.update_grid(x) | ||
x = layer(x) | ||
return x | ||
|
||
def regularization_loss( | ||
self, regularize_activation=1.0, regularize_entropy=1.0 | ||
): | ||
""" | ||
Compute the regularization loss of the KAN module. | ||
Args: | ||
regularize_activation (float, optional): Regularization factor for activation. Defaults to 1.0. | ||
regularize_entropy (float, optional): Regularization factor for entropy. Defaults to 1.0. | ||
Returns: | ||
torch.Tensor: Regularization loss. | ||
""" | ||
return sum( | ||
layer.regularization_loss(regularize_activation, regularize_entropy) | ||
for layer in self.layers | ||
) | ||
|
||
|
||
# x = torch.randn(2, 3, 1) | ||
# kan = KAN( | ||
# layers_hidden=[2, 3, 1], | ||
# ) | ||
# y = kan(x) | ||
# print(y) |
Oops, something went wrong.