diff --git a/zeta/nn/attention/mgqa.py b/zeta/nn/attention/mgqa.py index 82c3889d..72510c43 100644 --- a/zeta/nn/attention/mgqa.py +++ b/zeta/nn/attention/mgqa.py @@ -4,7 +4,7 @@ from torch import nn from zeta.nn.attention.attend import Attend -from zeta.utils.cache import CacheView +from zeta.nn.modules.cache import CacheView def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int): diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py index 76102c22..df28f73f 100644 --- a/zeta/nn/modules/__init__.py +++ b/zeta/nn/modules/__init__.py @@ -1,25 +1,32 @@ # modules -from zeta.nn.modules.lora import Lora -from zeta.nn.modules.token_learner import TokenLearner -from zeta.nn.modules.dynamic_module import DynamicModule +from zeta.nn.modules.cache import ( + CacheView, + RotatingBufferCache, + RotatingCacheInputMetadata, + interleave_list, + unrotate, +) +from zeta.nn.modules.cnn_text import CNNNew +from zeta.nn.modules.combined_linear import CombinedLinear +from zeta.nn.modules.convnet import ConvNet from zeta.nn.modules.droppath import DropPath +from zeta.nn.modules.dynamic_module import DynamicModule +from zeta.nn.modules.exo import Exo +from zeta.nn.modules.fast_text import FastTextNew from zeta.nn.modules.feedforward_network import FeedForwardNetwork from zeta.nn.modules.layernorm import LayerNorm, l2norm -from zeta.nn.modules.residual import Residual -from zeta.nn.modules.mlp import MLP -from zeta.nn.modules.sublayer import subln -from zeta.nn.modules.combined_linear import CombinedLinear -from zeta.nn.modules.rms_norm import RMSNorm +from zeta.nn.modules.lora import Lora from zeta.nn.modules.mbconv import MBConv -from zeta.nn.modules.super_resolution import SuperResolutionNet -from zeta.nn.modules.convnet import ConvNet -from zeta.nn.modules.shufflenet import ShuffleNet +from zeta.nn.modules.mlp import MLP +from zeta.nn.modules.pulsar import Pulsar +from zeta.nn.modules.residual import Residual from zeta.nn.modules.resnet import ResNet +from zeta.nn.modules.rms_norm import RMSNorm from zeta.nn.modules.rnn_nlp import RNNL -from zeta.nn.modules.cnn_text import CNNNew -from zeta.nn.modules.fast_text import FastTextNew +from zeta.nn.modules.shufflenet import ShuffleNet from zeta.nn.modules.simple_attention import simple_attention from zeta.nn.modules.spacial_transformer import SpacialTransformer +from zeta.nn.modules.sublayer import subln +from zeta.nn.modules.super_resolution import SuperResolutionNet +from zeta.nn.modules.token_learner import TokenLearner from zeta.nn.modules.yolo import yolo -from zeta.nn.modules.pulsar import Pulsar -from zeta.nn.modules.exo import Exo diff --git a/zeta/utils/cache.py b/zeta/nn/modules/cache.py similarity index 100% rename from zeta/utils/cache.py rename to zeta/nn/modules/cache.py