From 333eaa836733b36a37dabbae26d5dca8f53e4973 Mon Sep 17 00:00:00 2001 From: Kye Date: Mon, 6 May 2024 17:53:04 -0400 Subject: [PATCH] [FEAT][KAN] --- pyproject.toml | 2 +- zeta/nn/modules/__init__.py | 2 + zeta/nn/modules/kan.py | 362 ++++++++++++++++++++++++++++++++++++ zeta/nn/modules/splines.py | 148 +++++++++++++++ 4 files changed, 513 insertions(+), 1 deletion(-) create mode 100644 zeta/nn/modules/kan.py create mode 100644 zeta/nn/modules/splines.py diff --git a/pyproject.toml b/pyproject.toml index 21b974a6..bff5edf4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "zetascale" -version = "2.4.3" +version = "2.4.5" description = "Rapidly Build, Optimize, and Deploy SOTA AI Models" authors = ["Zeta Team "] license = "MIT" diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py index ce9acb4a..d960e0db 100644 --- a/zeta/nn/modules/__init__.py +++ b/zeta/nn/modules/__init__.py @@ -213,6 +213,7 @@ from zeta.nn.modules.query_proposal import TextHawkQueryProposal from zeta.nn.modules.pixel_shuffling import PixelShuffleDownscale +from zeta.nn.modules.kan import KAN # from zeta.nn.modules.img_reshape import image_reshape # from zeta.nn.modules.flatten_features import flatten_features @@ -424,4 +425,5 @@ "video_patch_linear_flatten", "TextHawkQueryProposal", "PixelShuffleDownscale", + "KAN", ] diff --git a/zeta/nn/modules/kan.py b/zeta/nn/modules/kan.py new file mode 100644 index 00000000..03dc13a6 --- /dev/null +++ b/zeta/nn/modules/kan.py @@ -0,0 +1,362 @@ +import torch +import torch.nn.functional as F +import math +from typing import List + + +class KANLinear(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + grid_size: int = 5, + spline_order: int = 3, + scale_noise: float = 0.1, + scale_base: float = 1.0, + scale_spline: float = 1.0, + enable_standalone_scale_spline: bool = True, + base_activation: torch.nn.Module = torch.nn.SiLU, + grid_eps: float = 0.02, + grid_range: List[float] = [-1, 1], + ): + super(KANLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.grid_size = grid_size + self.spline_order = spline_order + + h = (grid_range[1] - grid_range[0]) / grid_size + grid = ( + ( + torch.arange(-spline_order, grid_size + spline_order + 1) * h + + grid_range[0] + ) + .expand(in_features, -1) + .contiguous() + ) + self.register_buffer("grid", grid) + + self.base_weight = torch.nn.Parameter( + torch.Tensor(out_features, in_features) + ) + self.spline_weight = torch.nn.Parameter( + torch.Tensor(out_features, in_features, grid_size + spline_order) + ) + if enable_standalone_scale_spline: + self.spline_scaler = torch.nn.Parameter( + torch.Tensor(out_features, in_features) + ) + + self.scale_noise = scale_noise + self.scale_base = scale_base + self.scale_spline = scale_spline + self.enable_standalone_scale_spline = enable_standalone_scale_spline + self.base_activation = base_activation() + self.grid_eps = grid_eps + + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.kaiming_uniform_( + self.base_weight, a=math.sqrt(5) * self.scale_base + ) + with torch.no_grad(): + noise = ( + ( + torch.rand( + self.grid_size + 1, self.in_features, self.out_features + ) + - 1 / 2 + ) + * self.scale_noise + / self.grid_size + ) + self.spline_weight.data.copy_( + ( + self.scale_spline + if not self.enable_standalone_scale_spline + else 1.0 + ) + * self.curve2coeff( + self.grid.T[self.spline_order : -self.spline_order], + noise, + ) + ) + if self.enable_standalone_scale_spline: + # torch.nn.init.constant_(self.spline_scaler, self.scale_spline) + torch.nn.init.kaiming_uniform_( + self.spline_scaler, a=math.sqrt(5) * self.scale_spline + ) + + def b_splines(self, x: torch.Tensor): + """ + Compute the B-spline bases for the given input tensor. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, in_features). + + Returns: + torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). + """ + assert x.dim() == 2 and x.size(1) == self.in_features + + grid: torch.Tensor = ( + self.grid + ) # (in_features, grid_size + 2 * spline_order + 1) + x = x.unsqueeze(-1) + bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype) + for k in range(1, self.spline_order + 1): + bases = ( + (x - grid[:, : -(k + 1)]) + / (grid[:, k:-1] - grid[:, : -(k + 1)]) + * bases[:, :, :-1] + ) + ( + (grid[:, k + 1 :] - x) + / (grid[:, k + 1 :] - grid[:, 1:(-k)]) + * bases[:, :, 1:] + ) + + assert bases.size() == ( + x.size(0), + self.in_features, + self.grid_size + self.spline_order, + ) + return bases.contiguous() + + def curve2coeff(self, x: torch.Tensor, y: torch.Tensor): + """ + Compute the coefficients of the curve that interpolates the given points. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, in_features). + y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features). + + Returns: + torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order). + """ + assert x.dim() == 2 and x.size(1) == self.in_features + assert y.size() == (x.size(0), self.in_features, self.out_features) + + A = self.b_splines(x).transpose( + 0, 1 + ) # (in_features, batch_size, grid_size + spline_order) + B = y.transpose(0, 1) # (in_features, batch_size, out_features) + solution = torch.linalg.lstsq( + A, B + ).solution # (in_features, grid_size + spline_order, out_features) + result = solution.permute( + 2, 0, 1 + ) # (out_features, in_features, grid_size + spline_order) + + assert result.size() == ( + self.out_features, + self.in_features, + self.grid_size + self.spline_order, + ) + return result.contiguous() + + @property + def scaled_spline_weight(self): + return self.spline_weight * ( + self.spline_scaler.unsqueeze(-1) + if self.enable_standalone_scale_spline + else 1.0 + ) + + def forward(self, x: torch.Tensor): + assert x.dim() == 2 and x.size(1) == self.in_features + + base_output = F.linear(self.base_activation(x), self.base_weight) + spline_output = F.linear( + self.b_splines(x).view(x.size(0), -1), + self.scaled_spline_weight.view(self.out_features, -1), + ) + return base_output + spline_output + + @torch.no_grad() + def update_grid(self, x: torch.Tensor, margin=0.01): + assert x.dim() == 2 and x.size(1) == self.in_features + batch = x.size(0) + + splines = self.b_splines(x) # (batch, in, coeff) + splines = splines.permute(1, 0, 2) # (in, batch, coeff) + orig_coeff = self.scaled_spline_weight # (out, in, coeff) + orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) + unreduced_spline_output = torch.bmm( + splines, orig_coeff + ) # (in, batch, out) + unreduced_spline_output = unreduced_spline_output.permute( + 1, 0, 2 + ) # (batch, in, out) + + # sort each channel individually to collect data distribution + x_sorted = torch.sort(x, dim=0)[0] + grid_adaptive = x_sorted[ + torch.linspace( + 0, + batch - 1, + self.grid_size + 1, + dtype=torch.int64, + device=x.device, + ) + ] + + uniform_step = ( + x_sorted[-1] - x_sorted[0] + 2 * margin + ) / self.grid_size + grid_uniform = ( + torch.arange( + self.grid_size + 1, dtype=torch.float32, device=x.device + ).unsqueeze(1) + * uniform_step + + x_sorted[0] + - margin + ) + + grid = ( + self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive + ) + grid = torch.concatenate( + [ + grid[:1] + - uniform_step + * torch.arange( + self.spline_order, 0, -1, device=x.device + ).unsqueeze(1), + grid, + grid[-1:] + + uniform_step + * torch.arange( + 1, self.spline_order + 1, device=x.device + ).unsqueeze(1), + ], + dim=0, + ) + + self.grid.copy_(grid.T) + self.spline_weight.data.copy_( + self.curve2coeff(x, unreduced_spline_output) + ) + + def regularization_loss( + self, regularize_activation=1.0, regularize_entropy=1.0 + ): + """ + Compute the regularization loss. + + This is a dumb simulation of the original L1 regularization as stated in the + paper, since the original one requires computing absolutes and entropy from the + expanded (batch, in_features, out_features) intermediate tensor, which is hidden + behind the F.linear function if we want an memory efficient implementation. + + The L1 regularization is now computed as mean absolute value of the spline + weights. The authors implementation also includes this term in addition to the + sample-based regularization. + """ + l1_fake = self.spline_weight.abs().mean(-1) + regularization_loss_activation = l1_fake.sum() + p = l1_fake / regularization_loss_activation + regularization_loss_entropy = -torch.sum(p * p.log()) + return ( + regularize_activation * regularization_loss_activation + + regularize_entropy * regularization_loss_entropy + ) + + +class KAN(torch.nn.Module): + """ + KAN (Kernel Activation Network) module. + + Args: + layers_hidden (list): List of integers representing the number of hidden units in each layer. + grid_size (int, optional): Size of the grid. Defaults to 5. + spline_order (int, optional): Order of the spline. Defaults to 3. + scale_noise (float, optional): Scale factor for the noise. Defaults to 0.1. + scale_base (float, optional): Scale factor for the base. Defaults to 1.0. + scale_spline (float, optional): Scale factor for the spline. Defaults to 1.0. + base_activation (torch.nn.Module, optional): Activation function for the base. Defaults to torch.nn.SiLU. + grid_eps (float, optional): Epsilon value for the grid. Defaults to 0.02. + grid_range (list, optional): Range of the grid. Defaults to [-1, 1]. + + Example: + >>> kan = KAN([2, 3, 1]) + >>> x = torch.randn(10, 2) + >>> y = kan(x) + + """ + + def __init__( + self, + layers_hidden: List[int], + grid_size: int = 5, + spline_order: int = 3, + scale_noise: float = 0.1, + scale_base: float = 1.0, + scale_spline: float = 1.0, + base_activation: torch.nn.Module = torch.nn.SiLU, + grid_eps: float = 0.02, + grid_range: List[float] = [-1, 1], + ) -> None: + super(KAN, self).__init__() + self.grid_size = grid_size + self.spline_order = spline_order + + self.layers = torch.nn.ModuleList() + for in_features, out_features in zip(layers_hidden, layers_hidden[1:]): + self.layers.append( + KANLinear( + in_features, + out_features, + grid_size=grid_size, + spline_order=spline_order, + scale_noise=scale_noise, + scale_base=scale_base, + scale_spline=scale_spline, + base_activation=base_activation, + grid_eps=grid_eps, + grid_range=grid_range, + ) + ) + + def forward(self, x: torch.Tensor, update_grid=False): + """ + Forward pass of the KAN module. + + Args: + x (torch.Tensor): Input tensor. + update_grid (bool, optional): Whether to update the grid. Defaults to False. + + Returns: + torch.Tensor: Output tensor. + """ + for layer in self.layers: + if update_grid: + layer.update_grid(x) + x = layer(x) + return x + + def regularization_loss( + self, regularize_activation=1.0, regularize_entropy=1.0 + ): + """ + Compute the regularization loss of the KAN module. + + Args: + regularize_activation (float, optional): Regularization factor for activation. Defaults to 1.0. + regularize_entropy (float, optional): Regularization factor for entropy. Defaults to 1.0. + + Returns: + torch.Tensor: Regularization loss. + """ + return sum( + layer.regularization_loss(regularize_activation, regularize_entropy) + for layer in self.layers + ) + + +# x = torch.randn(2, 3, 1) +# kan = KAN( +# layers_hidden=[2, 3, 1], +# ) +# y = kan(x) +# print(y) diff --git a/zeta/nn/modules/splines.py b/zeta/nn/modules/splines.py new file mode 100644 index 00000000..1446045e --- /dev/null +++ b/zeta/nn/modules/splines.py @@ -0,0 +1,148 @@ +import torch + + +def B_batch(x, grid, k=0, extend=True, device="cpu"): + """ + evaludate x on B-spline bases + + Args: + ----- + x : 2D torch.tensor + inputs, shape (number of splines, number of samples) + grid : 2D torch.tensor + grids, shape (number of splines, number of grid points) + k : int + the piecewise polynomial order of splines. + extend : bool + If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True + device : str + devicde + + Returns: + -------- + spline values : 3D torch.tensor + shape (number of splines, number of B-spline bases (coeffcients), number of samples). The numbef of B-spline bases = number of grid points + k - 1. + + Example + ------- + >>> num_spline = 5 + >>> num_sample = 100 + >>> num_grid_interval = 10 + >>> k = 3 + >>> x = torch.normal(0,1,size=(num_spline, num_sample)) + >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1)) + >>> B_batch(x, grids, k=k).shape + torch.Size([5, 13, 100]) + """ + + # x shape: (size, x); grid shape: (size, grid) + def extend_grid(grid, k_extend=0): + # pad k to left and right + # grid shape: (batch, grid) + h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) + + for i in range(k_extend): + grid = torch.cat([grid[:, [0]] - h, grid], dim=1) + grid = torch.cat([grid, grid[:, [-1]] + h], dim=1) + grid = grid.to(device) + return grid + + if extend is True: + grid = extend_grid(grid, k_extend=k) + + grid = grid.unsqueeze(dim=2).to(device) + x = x.unsqueeze(dim=1).to(device) + + if k == 0: + value = (x >= grid[:, :-1]) * (x < grid[:, 1:]) + else: + B_km1 = B_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1, extend=False) + value = (x - grid[:, : -(k + 1)]) / ( + grid[:, k:-1] - grid[:, : -(k + 1)] + ) * B_km1[:, :-1] + (grid[:, k + 1 :] - x) / ( + grid[:, k + 1 :] - grid[:, 1:(-k)] + ) * B_km1[ + :, 1: + ] + return value + + +def coef2curve(x_eval, grid, coef, k, device="cpu"): + """ + converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up B_batch results over B-spline basis). + + Args: + ----- + x_eval : 2D torch.tensor) + shape (number of splines, number of samples) + grid : 2D torch.tensor) + shape (number of splines, number of grid points) + coef : 2D torch.tensor) + shape (number of splines, number of coef params). number of coef params = number of grid intervals + k + k : int + the piecewise polynomial order of splines. + device : str + devicde + + Returns: + -------- + y_eval : 2D torch.tensor + shape (number of splines, number of samples) + + Example + ------- + >>> num_spline = 5 + >>> num_sample = 100 + >>> num_grid_interval = 10 + >>> k = 3 + >>> x_eval = torch.normal(0,1,size=(num_spline, num_sample)) + >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1)) + >>> coef = torch.normal(0,1,size=(num_spline, num_grid_interval+k)) + >>> coef2curve(x_eval, grids, coef, k=k).shape + torch.Size([5, 100]) + """ + # x_eval: (size, batch), grid: (size, grid), coef: (size, coef) + # coef: (size, coef), B_batch: (size, coef, batch), summer over coef + y_eval = torch.einsum( + "ij,ijk->ik", coef, B_batch(x_eval, grid, k, device=device) + ) + return y_eval + + +def curve2coef(x_eval, y_eval, grid, k, device="cpu"): + """ + converting B-spline curves to B-spline coefficients using least squares. + + Args: + ----- + x_eval : 2D torch.tensor + shape (number of splines, number of samples) + y_eval : 2D torch.tensor + shape (number of splines, number of samples) + grid : 2D torch.tensor + shape (number of splines, number of grid points) + k : int + the piecewise polynomial order of splines. + device : str + devicde + + Example + ------- + >>> num_spline = 5 + >>> num_sample = 100 + >>> num_grid_interval = 10 + >>> k = 3 + >>> x_eval = torch.normal(0,1,size=(num_spline, num_sample)) + >>> y_eval = torch.normal(0,1,size=(num_spline, num_sample)) + >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1)) + >>> curve2coef(x_eval, y_eval, grids, k=k).shape + torch.Size([5, 13]) + """ + # x_eval: (size, batch); y_eval: (size, batch); grid: (size, grid); k: scalar + mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1) + coef = torch.linalg.lstsq( + mat.to("cpu"), y_eval.unsqueeze(dim=2).to("cpu") + ).solution[ + :, :, 0 + ] # sometimes 'cuda' version may diverge + return coef.to(device)