diff --git a/nanodl/__init__.py b/nanodl/__init__.py
index 09a268a..bf7f8fc 100644
--- a/nanodl/__init__.py
+++ b/nanodl/__init__.py
@@ -63,6 +63,14 @@
     LlamaDataParallelTrainer,
     RotaryPositionalEncoding,
 )
+from nanodl.__src.models.kan import (
+    KANLinear,
+    ChebyKANLinear,
+    LegendreKANLinear,
+    MonomialKANLinear,
+    FourierKANLinear,
+    HermiteKANLinear,
+)
 from nanodl.__src.models.mistral import (
     GroupedRotaryShiftedWindowMultiHeadAttention,
     Mistral,
@@ -147,7 +155,7 @@
 )
 
 __all__ = [
-    # Sklearn GPU
+    # Classical
     "NaiveBayesClassifier",
     "PCA",
     "KMeans",
@@ -285,6 +293,12 @@
     "geometric",
     "gamma",
     "chisquare",
+    "KANLinear",
+    "ChebyKANLinear",
+    "LegendreKANLinear",
+    "MonomialKANLinear",
+    "FourierKANLinear",
+    "HermiteKANLinear",
 ]
 
 import importlib
diff --git a/nanodl/__src/experimental/kan.py b/nanodl/__src/experimental/kan.py
deleted file mode 100644
index 7ff74d9..0000000
--- a/nanodl/__src/experimental/kan.py
+++ /dev/null
@@ -1,227 +0,0 @@
-import jax
-import jax.numpy as jnp
-from flax import linen as nn
-from jax.scipy.special import logsumexp
-from jax import random
-
-
-class KANLinear(nn.Module):
-    """
-    KANLinear is a class that represents a linear layer in a Kernelized Attention Network (KAN).
-    It uses B-splines to model the attention mechanism, which allows for more flexibility than traditional attention mechanisms.
-
-    Attributes:
-        in_features (int): The number of input features.
-        out_features (int): The number of output features.
-        grid_size (int): The size of the grid used for the B-splines. Default is 5.
-        spline_order (int): The order of the B-splines. Default is 3.
-        scale_noise (float): The scale of the noise added to the B-splines. Default is 0.1.
-        scale_base (float): The scale of the base weights. Default is 1.0.
-        scale_spline (float): The scale of the spline weights. Default is 1.0.
-        enable_standalone_scale_spline (bool): Whether to enable standalone scaling of the spline weights. Default is True.
-        base_activation (callable): The activation function to use for the base weights. Default is nn.silu.
-        grid_eps (float): The epsilon value used for the grid. Default is 0.02.
-        grid_range (list): The range of the grid. Default is [-1, 1].
- """ - - in_features: int - out_features: int - grid_size: int = 5 - spline_order: int = 3 - scale_noise: float = 0.1 - scale_base: float = 1.0 - scale_spline: float = 1.0 - enable_standalone_scale_spline: bool = True - base_activation: callable = nn.silu - grid_eps: float = 0.02 - grid_range: list = [-1, 1] - - def setup(self): - h = (self.grid_range[1] - self.grid_range[0]) / self.grid_size - grid = jnp.tile( - jnp.arange(-self.spline_order, self.grid_size + self.spline_order + 1) * h - + self.grid_range[0], - (self.in_features, 1), - ) - self.grid = self.param("grid", grid.shape, nn.initializers.zeros) - - self.base_weight = self.param( - "base_weight", - (self.out_features, self.in_features), - nn.initializers.kaiming_uniform(), - ) - self.spline_weight = self.param( - "spline_weight", - (self.out_features, self.in_features, self.grid_size + self.spline_order), - nn.initializers.zeros, - ) - if self.enable_standalone_scale_spline: - self.spline_scaler = self.param( - "spline_scaler", - (self.out_features, self.in_features), - nn.initializers.kaiming_uniform(), - ) - - self.reset_parameters() - - def reset_parameters(self): - self.base_weight = ( - nn.initializers.kaiming_uniform()( - self.base_weight.shape, self.base_weight.dtype - ) - * self.scale_base - ) - noise = ( - ( - random.uniform( - jax.random.PRNGKey(0), - (self.grid_size + 1, self.in_features, self.out_features), - ) - - 1 / 2 - ) - * self.scale_noise - / self.grid_size - ) - self.spline_weight = self.curve2coeff( - self.grid.T[self.spline_order : -self.spline_order], noise - ) * (self.scale_spline if not self.enable_standalone_scale_spline else 1.0) - if self.enable_standalone_scale_spline: - self.spline_scaler = ( - nn.initializers.kaiming_uniform()( - self.spline_scaler.shape, self.spline_scaler.dtype - ) - * self.scale_spline - ) - - def b_splines(self, x): - grid = self.grid - x = x[..., None] - bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).astype(x.dtype) - for k in range(1, self.spline_order + 1): - bases = (x - grid[:, : -(k + 1)]) / ( - grid[:, k:-1] - grid[:, : -(k + 1)] - ) * bases[..., :-1] + (grid[:, k + 1 :] - x) / ( - grid[:, k + 1 :] - grid[:, 1:(-k)] - ) * bases[ - ..., 1: - ] - return bases - - def curve2coeff(self, x, y): - A = self.b_splines(x).transpose((1, 0, 2)) - B = y.transpose((1, 0, 2)) - solution = jnp.linalg.lstsq(A, B)[0] - result = solution.transpose((2, 0, 1)) - return result - - @property - def scaled_spline_weight(self): - return self.spline_weight * ( - self.spline_scaler[..., None] - if self.enable_standalone_scale_spline - else 1.0 - ) - - def __call__(self, x): - base_output = jnp.dot(self.base_activation(x), self.base_weight.T) - spline_output = jnp.dot( - self.b_splines(x).reshape(x.shape[0], -1), - self.scaled_spline_weight.reshape(self.out_features, -1).T, - ) - return base_output + spline_output - - def update_grid(self, x, margin=0.01): - batch = x.shape[0] - - splines = self.b_splines(x).transpose((1, 0, 2)) - orig_coeff = self.scaled_spline_weight.transpose((1, 2, 0)) - unreduced_spline_output = jnp.matmul(splines, orig_coeff).transpose((1, 0, 2)) - - x_sorted = jnp.sort(x, axis=0) - grid_adaptive = x_sorted[ - jnp.linspace(0, batch - 1, self.grid_size + 1, dtype=jnp.int64) - ] - - uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size - grid_uniform = ( - jnp.arange(self.grid_size + 1, dtype=jnp.float32)[..., None] * uniform_step - + x_sorted[0] - - margin - ) - - grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive - grid = 
-            [
-                grid[:1]
-                - uniform_step * jnp.arange(self.spline_order, 0, -1)[..., None],
-                grid,
-                grid[-1:]
-                + uniform_step * jnp.arange(1, self.spline_order + 1)[..., None],
-            ],
-            axis=0,
-        )
-
-        self.grid = grid.T
-        self.spline_weight = self.curve2coeff(x, unreduced_spline_output)
-
-    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
-        l1_fake = jnp.mean(jnp.abs(self.spline_weight), axis=-1)
-        regularization_loss_activation = jnp.sum(l1_fake)
-        p = l1_fake / regularization_loss_activation
-        regularization_loss_entropy = -jnp.sum(p * jnp.log(p))
-        return (
-            regularize_activation * regularization_loss_activation
-            + regularize_entropy * regularization_loss_entropy
-        )
-
-
-class KAN(nn.Module):
-    """
-    KAN is a class that represents a Kernelized Attention Network (KAN).
-    It is a type of neural network that uses a kernelized attention mechanism, which allows for more flexibility than traditional attention mechanisms.
-
-    Attributes:
-        layers_hidden (list): A list of integers representing the number of hidden units in each layer.
-    """
-
-    layers_hidden: list
-    grid_size: int = 5
-    spline_order: int = 3
-    scale_noise: float = 0.1
-    scale_base: float = 1.0
-    scale_spline: float = 1.0
-    base_activation: callable = nn.silu
-    grid_eps: float = 0.02
-    grid_range: list = [-1, 1]
-
-    def setup(self):
-        self.layers = [
-            KANLinear(
-                in_features,
-                out_features,
-                grid_size=self.grid_size,
-                spline_order=self.spline_order,
-                scale_noise=self.scale_noise,
-                scale_base=self.scale_base,
-                scale_spline=self.scale_spline,
-                base_activation=self.base_activation,
-                grid_eps=self.grid_eps,
-                grid_range=self.grid_range,
-            )
-            for in_features, out_features in zip(
-                self.layers_hidden, self.layers_hidden[1:]
-            )
-        ]
-
-    def __call__(self, x, update_grid=False):
-        for layer in self.layers:
-            if update_grid:
-                layer.update_grid(x)
-            x = layer(x)
-        return x
-
-    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
-        return sum(
-            layer.regularization_loss(regularize_activation, regularize_entropy)
-            for layer in self.layers
-        )
diff --git a/nanodl/__src/models/kan.py b/nanodl/__src/models/kan.py
new file mode 100644
index 0000000..ed6a1cc
--- /dev/null
+++ b/nanodl/__src/models/kan.py
@@ -0,0 +1,243 @@
+from flax import linen as nn
+import jax.numpy as jnp
+import jax
+from jax import random
+from typing import Any
+
+
+class KANLinear(nn.Module):
+    """
+    A Flax module implementing a B-spline Neural Network layer, where the basis functions are B-splines.
+
+    Attributes:
+        in_features (int): Number of input features.
+        out_features (int): Number of output features.
+        degree (int): Degree of the B-splines.
+ """ + in_features: int + out_features: int + degree: int + + def setup(self) -> None: + assert self.degree > 0, "Degree of the B-splines must be greater than 0" + mean, std = 0.0, 1 / (self.in_features * (self.degree + 1)) + self.coefficients = self.param( + "coefficients", + lambda key, shape: mean + std * random.normal(key, shape), + (self.in_features, self.out_features, self.degree + 1) + ) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = jnp.tanh(x) + + knots = jnp.linspace(-1, 1, self.degree + self.in_features + 1) + + b_spline_values = jnp.array([ + self.bspline_basis(x[:, i], self.degree, knots) for i in range(self.in_features) + ]) + + b_spline_values = b_spline_values.transpose((1, 0, 2)) + + output = jnp.einsum('bid,ijd->bj', b_spline_values, self.coefficients) + return output + + def bspline_basis(self, x: jnp.ndarray, degree: int, knots: jnp.ndarray) -> jnp.ndarray: + + def cox_de_boor(x, k, i, t): + if k == 0: + return jnp.where((t[i] <= x) & (x < t[i + 1]), 1.0, 0.0) + else: + denom1 = t[i + k] - t[i] + denom2 = t[i + k + 1] - t[i + 1] + + term1 = jnp.where(denom1 != 0, (x - t[i]) / denom1, 0.0) * cox_de_boor(x, k - 1, i, t) + term2 = jnp.where(denom2 != 0, (t[i + k + 1] - x) / denom2, 0.0) * cox_de_boor(x, k - 1, i + 1, t) + + return term1 + term2 + + n_basis = len(knots) - degree - 1 + basis = jnp.array([cox_de_boor(x, degree, i, knots) for i in range(n_basis)]) + return jnp.transpose(basis) + + +class ChebyKANLinear(nn.Module): + """ + A Flax module implementing a Chebyshev Neural Network layer, where the basis functions are Chebyshev polynomials. + + Inspired by https://github.com/CG80499/KAN-GPT-2/blob/master/chebykan_layer.py + + Attributes: + in_features (int): Number of input features. + out_features (int): Number of output features. + degree (int): Degree of the Chebyshev polynomials. + """ + in_features: int + out_features: int + degree: int + + def setup(self) -> None: + mean, std = 0.0, 1 / (self.in_features * (self.degree + 1)) + self.coefficients = self.param( + "coefficients", + lambda key, shape: mean + std * random.normal(key, shape), + (self.in_features, self.out_features, self.degree + 1) + ) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + + x = jnp.tanh(x) + + cheby_values = jnp.ones((x.shape[0], self.in_features, self.degree + 1)) + cheby_values = cheby_values.at[:, :, 1].set(x) + + for i in range(2, self.degree + 1): + next_value = 2 * x * cheby_values[:, :, i - 1] - cheby_values[:, :, i - 2] + cheby_values = cheby_values.at[:, :, i].set(next_value) + + output = jnp.einsum('bid,ijd->bj', cheby_values, self.coefficients) + return output + + +class LegendreKANLinear(nn.Module): + """ + A Flax module implementing a Legendre Neural Network layer, where the basis functions are Legendre polynomials. + + Attributes: + in_features (int): Number of input features. + out_features (int): Number of output features. + degree (int): Degree of the Legendre polynomials. 
+ """ + in_features: int + out_features: int + degree: int + + def setup(self) -> None: + assert self.degree > 0, "Degree of the Legendre polynomials must be greater than 0" + mean, std = 0.0, 1 / (self.in_features * (self.degree + 1)) + self.coefficients = self.param( + "coefficients", + lambda key, shape: mean + std * random.normal(key, shape), + (self.in_features, self.out_features, self.degree + 1) + ) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = jnp.tanh(x) + + legendre_values = jnp.ones((x.shape[0], self.in_features, self.degree + 1)) + legendre_values = legendre_values.at[:, :, 1].set(x) + + for i in range(2, self.degree + 1): + next_value = ((2 * i - 1) * x * legendre_values[:, :, i - 1] - (i - 1) * legendre_values[:, :, i - 2]) / i + legendre_values = legendre_values.at[:, :, i].set(next_value) + + output = jnp.einsum('bid,ijd->bj', legendre_values, self.coefficients) + return output + + +class MonomialKANLinear(nn.Module): + """ + A Flax module implementing a Monomial Neural Network layer, where the basis functions are monomials. + + Attributes: + in_features (int): Number of input features. + out_features (int): Number of output features. + degree (int): Degree of the monomial basis functions. + """ + in_features: int + out_features: int + degree: int + + def setup(self) -> None: + assert self.degree > 0, "Degree of the monomial basis functions must be greater than 0" + mean, std = 0.0, 1 / (self.in_features * (self.degree + 1)) + self.coefficients = self.param( + "coefficients", + lambda key, shape: mean + std * random.normal(key, shape), + (self.in_features, self.out_features, self.degree + 1) + ) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = jnp.tanh(x) + monomial_values = jnp.ones((x.shape[0], self.in_features, self.degree + 1)) + + for i in range(1, self.degree + 1): + monomial_values = monomial_values.at[:, :, i].set(x ** i) + + output = jnp.einsum('bid,ijd->bj', monomial_values, self.coefficients) + return output + + +class FourierKANLinear(nn.Module): + """ + A Flax module implementing a Fourier Neural Network layer, where the basis functions are sine and cosine functions. + + Attributes: + in_features (int): Number of input features. + out_features (int): Number of output features. + degree (int): Degree of the Fourier series (i.e., number of harmonics). + """ + in_features: int + out_features: int + degree: int + + def setup(self) -> None: + assert self.degree > 0, "Degree of the Fourier series must be greater than 0" + mean, std = 0.0, 1 / (self.in_features * (2 * self.degree + 1)) + self.sine_coefficients = self.param( + "sine_coefficients", + lambda key, shape: mean + std * random.normal(key, shape), + (self.in_features, self.out_features, self.degree) + ) + self.cosine_coefficients = self.param( + "cosine_coefficients", + lambda key, shape: mean + std * random.normal(key, shape), + (self.in_features, self.out_features, self.degree) + ) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = jnp.tanh(x) + sine_values = jnp.sin(jnp.pi * jnp.arange(1, self.degree + 1) * x[..., None]) + cosine_values = jnp.cos(jnp.pi * jnp.arange(1, self.degree + 1) * x[..., None]) + + output = (jnp.einsum('bid,ijd->bj', sine_values, self.sine_coefficients) + + jnp.einsum('bid,ijd->bj', cosine_values, self.cosine_coefficients)) + return output + + +class HermiteKANLinear(nn.Module): + """ + A Flax module implementing a Hermite Neural Network layer, where the basis functions are Hermite polynomials. 
+
+    Attributes:
+        in_features (int): Number of input features.
+        out_features (int): Number of output features.
+        degree (int): Degree of the Hermite polynomials.
+    """
+    in_features: int
+    out_features: int
+    degree: int
+
+    def setup(self) -> None:
+        assert self.degree > 0, "Degree of the Hermite polynomials must be greater than 0"
+        mean, std = 0.0, 1 / (self.in_features * (self.degree + 1))
+        self.coefficients = self.param(
+            "coefficients",
+            lambda key, shape: mean + std * random.normal(key, shape),
+            (self.in_features, self.out_features, self.degree + 1)
+        )
+
+    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+        x = jnp.tanh(x)
+
+        hermite_values = jnp.ones((x.shape[0], self.in_features, self.degree + 1))
+
+        if self.degree >= 1:
+            hermite_values = hermite_values.at[:, :, 1].set(2 * x)
+        for i in range(2, self.degree + 1):
+            hermite_values = hermite_values.at[:, :, i].set(
+                2 * x * hermite_values[:, :, i - 1] - 2 * (i - 1) * hermite_values[:, :, i - 2]
+            )
+
+        output = jnp.einsum('bid,ijd->bj', hermite_values, self.coefficients)
+        return output
+
\ No newline at end of file
diff --git a/setup.py b/setup.py
index d437069..c3f6e77 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 
 setup(
     name="nanodl",
-    version="1.2.5.dev1",
+    version="0.0.0",
    author="Henry Ndubuaku",
     author_email="ndubuakuhenry@gmail.com",
     description="A Jax-based library for designing and training transformer models from scratch.",
diff --git a/tests/test_sklearn_gpu.py b/tests/test_classic.py
similarity index 100%
rename from tests/test_sklearn_gpu.py
rename to tests/test_classic.py
diff --git a/tests/test_kan.py b/tests/test_kan.py
new file mode 100644
index 0000000..0b0b337
--- /dev/null
+++ b/tests/test_kan.py
@@ -0,0 +1,33 @@
+import unittest
+import jax.numpy as jnp
+from jax import random
+from nanodl import *
+
+class TestKANLinearVariants(unittest.TestCase):
+    def setUp(self):
+        self.in_features = 4
+        self.out_features = 3
+        self.degree = 5
+
+        self.key = random.PRNGKey(0)
+        self.x = random.normal(self.key, (10, self.in_features))
+
+        self.models = {
+            "ChebyKANLinear": ChebyKANLinear(self.in_features, self.out_features, self.degree),
+            "LegendreKANLinear": LegendreKANLinear(self.in_features, self.out_features, self.degree),
+            "MonomialKANLinear": MonomialKANLinear(self.in_features, self.out_features, self.degree),
+            "FourierKANLinear": FourierKANLinear(self.in_features, self.out_features, self.degree),
+            "HermiteKANLinear": HermiteKANLinear(self.in_features, self.out_features, self.degree),
+        }
+
+    def test_model_outputs(self):
+        for model_name, model in self.models.items():
+            with self.subTest(model=model_name):
+                variables = model.init(self.key, self.x)
+                output = model.apply(variables, self.x)
+                self.assertEqual(output.shape, (10, self.out_features))
+                self.assertTrue(jnp.all(jnp.isfinite(output)))
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
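
Usage sketch: the new layers are plain Flax modules, so they follow the usual init/apply pattern exercised in tests/test_kan.py. The snippet below is a minimal illustration with the same sizes as the test (4 inputs, 3 outputs, degree 5) and assumes nanodl is installed with this change; any of the other KAN variants can be swapped in the same way.

from jax import random
from nanodl import ChebyKANLinear

# Chebyshev KAN layer mapping 4 input features to 3 outputs with degree-5 polynomials.
layer = ChebyKANLinear(in_features=4, out_features=3, degree=5)

key = random.PRNGKey(0)
x = random.normal(key, (10, 4))   # batch of 10 samples

params = layer.init(key, x)       # creates the "coefficients" parameter
y = layer.apply(params, x)        # output shape: (10, 3)
print(y.shape)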