Unofficial implementation of Titans, SOTA memory for transformers, in Pytorch

lucidrains/titans-pytorch

Titans - Pytorch (wip)

Unofficial implementation of Titans in PyTorch. It will also contain some explorations into architectures for the neural memory module beyond the simple 1-4 layer MLP used in the paper, if any of them work well to any degree.
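In the Titans paper the neural memory is trained online with a "surprise" signal: the gradient of an associative loss ||M(k_t) - v_t||^2, accumulated with momentum (S_t = eta * S_{t-1} - theta * grad) and combined with a forgetting gate (M_t = (1 - alpha) * M_{t-1} + S_t). The snippet below is a minimal pure-Python sketch of that rule for the simplest case, a linear (single-layer) memory, assuming constant gates eta, theta and alpha (the paper makes them data-dependent). All function names are hypothetical and are not part of titans-pytorch.

```python
# Hypothetical sketch, NOT the titans-pytorch implementation: a linear memory M (d x d matrix)
# updated at test time with the surprise-based rule, using constant gates for simplicity.

def matvec(m, x):
    return [sum(mij * xj for mij, xj in zip(row, x)) for row in m]

def zeros(n):
    return [[0.0] * n for _ in range(n)]

def assoc_loss(memory, k, v):
    """Associative memory loss ||M k - v||^2."""
    return sum((p - t) ** 2 for p, t in zip(matvec(memory, k), v))

def assoc_grad(memory, k, v):
    """Gradient of ||M k - v||^2 with respect to M, i.e. 2 (M k - v) k^T."""
    err = [p - t for p, t in zip(matvec(memory, k), v)]
    return [[2.0 * e * kj for kj in k] for e in err]

def memory_step(memory, surprise, k, v, eta=0.9, theta=0.1, alpha=0.01):
    """One test-time update (constant gates are an assumption of this sketch).
    S_t = eta * S_{t-1} - theta * grad   (past surprise plus momentary surprise)
    M_t = (1 - alpha) * M_{t-1} + S_t     (alpha acts as a forgetting gate)
    """
    g = assoc_grad(memory, k, v)
    surprise = [[eta * s - theta * gij for s, gij in zip(srow, grow)]
                for srow, grow in zip(surprise, g)]
    memory = [[(1.0 - alpha) * mij + sij for mij, sij in zip(mrow, srow)]
              for mrow, srow in zip(memory, surprise)]
    return memory, surprise

# Usage: present one key/value pair repeatedly and watch the retrieval error shrink.
k, v = [1.0, 0.0, 0.0], [0.0, 1.0, 0.5]
memory, surprise = zeros(3), zeros(3)
loss_before = assoc_loss(memory, k, v)   # 1.25
for _ in range(200):
    memory, surprise = memory_step(memory, surprise, k, v)
loss_after = assoc_loss(memory, k, v)
```

Because of the forgetting gate the error settles slightly above zero rather than at it: the memory trades a little recall for the ability to let go of stale associations.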

Install

$ pip install titans-pytorch

Usage

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    pre_rmsnorm = True
).cuda()

seq = torch.randn(2, 1024, 384).cuda()   # (batch, seq_len, dim)
retrieved = mem(seq)

assert seq.shape == retrieved.shape      # retrieval preserves the input shape
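We assume `chunk_size` is the number of tokens that share one memory update. The Titans paper parallelizes training by splitting the sequence into chunks and taking every gradient inside a chunk with respect to the memory as it stood at the start of that chunk. The sketch below illustrates that idea with a hypothetical scalar linear memory and plain gradient descent (no momentum or decay); it is our reading of the paper, not the library's code.

```python
# Hypothetical sketch of chunk-wise test-time memorization; not the titans-pytorch code.

def chunked_memory_pass(keys, values, chunk_size, lr=0.1):
    """Scalar linear memory m, loss (m * k - v)^2, one update per chunk.

    For each chunk: (1) read the memory as it stood at the start of the chunk,
    (2) take the gradient for every token in the chunk w.r.t. that same m,
    (3) apply a single update with the mean gradient. Because step (2) uses a
    fixed m, the per-token gradients inside a chunk are independent.
    """
    m = 0.0
    outputs, num_updates = [], 0
    for start in range(0, len(keys), chunk_size):
        ks = keys[start:start + chunk_size]
        vs = values[start:start + chunk_size]
        outputs.extend(m * k for k in ks)               # retrieval uses the pre-update memory
        grad = sum(2.0 * (m * k - v) * k for k, v in zip(ks, vs))
        m -= lr * grad / len(ks)                         # one update per chunk
        num_updates += 1
    return outputs, num_updates, m

# Same sequence length and chunk size as the NeuralMemory example above.
seq_len, chunk_size = 1024, 64
keys = [1.0] * seq_len
values = [2.0] * seq_len                                 # the memory should learn m close to 2
outputs, num_updates, m = chunked_memory_pass(keys, values, chunk_size)
# len(outputs) == 1024 (shape preserved), num_updates == 1024 // 64 == 16, m is about 1.94
```

A larger chunk means fewer sequential updates (and more parallelism) at the cost of coarser memorization.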

A transformer with the MAC (memory as context) configuration can be used as follows

import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,              # local attention window size
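    num_persist_mem_tokens = 4,     # number of persistent (learned, input-independent) memory tokens
    num_longterm_mem_tokens = 16,   # number of long-term memory tokens
)

token_ids = torch.randint(0, 256, (1, 1023))   # (batch, seq_len), ids in [0, num_tokens)

logits = transformer(token_ids) # (1, 1023, 256)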
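To get a feel for the MAC configuration above, the sketch below splits the 1023-token input into attention segments of `segment_len` tokens and counts the positions each segment attends over. It assumes, following the paper's description of memory as context, that every segment is prefixed with the persistent memory tokens and the long-term memory tokens, and that those extra tokens are dropped from the output. The helper name is hypothetical and the library may lay tokens out differently.

```python
# Hypothetical helper illustrating MAC segment arithmetic; not part of titans-pytorch.

def mac_segment_layout(seq_len, segment_len, num_persist_mem_tokens, num_longterm_mem_tokens):
    """Split a sequence into segments and report, per segment, how many input tokens
    it holds and how many positions attention sees (assumed: persistent + long-term + segment)."""
    segments = []
    for start in range(0, seq_len, segment_len):
        seg = min(segment_len, seq_len - start)
        segments.append({
            'tokens': seg,
            'attended': num_persist_mem_tokens + num_longterm_mem_tokens + seg,
        })
    return segments

layout = mac_segment_layout(seq_len=1023, segment_len=128,
                            num_persist_mem_tokens=4, num_longterm_mem_tokens=16)

num_segments = len(layout)                         # ceil(1023 / 128) = 8
last_segment = layout[-1]['tokens']                # 1023 - 7 * 128 = 127
max_attended = max(s['attended'] for s in layout)  # 4 + 16 + 128 = 148
output_len = sum(s['tokens'] for s in layout)      # memory tokens dropped: 1023
```

The output length matches the (1, 1023, 256) logits shown above, while each local attention window stays bounded by segment_len plus the memory tokens regardless of total sequence length.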

Experiments

From the root of the repository, install the additional dependencies for the examples

$ pip install '.[examples]'

Then modify train_mac.py as needed and run it to query nature

$ python train_mac.py

Citations

@inproceedings{Behrouz2024TitansLT,
    title   = {Titans: Learning to Memorize at Test Time},
    author  = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:275212078}
}
@software{Kyrylov_Accelerated_Scan_2024,
    author  = {Kyrylov, Volodymyr},
    doi     = {10.5281/zenodo.10600962},
    title   = {Accelerated Scan},
    version = {0.1.2},
    year    = {2024}
}
@inproceedings{Yang2024GatedDN,
    title   = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
    author  = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:274598177}
}
