axrwl/fttransformer

A JAX implementation of FT-Transformer, the model proposed in "Revisiting Deep Learning Models for Tabular Data" (Gorishniy et al., 2021), based on lucidrains/tab-transformer-pytorch.
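FT-Transformer first turns every column of a row into a `dim`-sized token, then runs a Transformer over those tokens and predicts from a learned [CLS] token. Following the paper's description, a continuous value x_j becomes x_j * W_j + b_j, and a categorical value becomes a row of a learned embedding table. The sketch below illustrates only this feature-tokenizer step in plain `jax.numpy`. It is not the fttjax API: `init_tokenizer`, `tokenize`, the parameter names, the 0.02 init scale, the shared embedding table with per-column offsets, and placing [CLS] first are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def init_tokenizer(key, categories, num_continuous, dim):
    """Hypothetical parameter layout for a feature tokenizer (not the fttjax API)."""
    k_emb, k_w, k_b, k_cls = jax.random.split(key, 4)
    return {
        # one shared embedding table; each categorical column owns a slice of rows
        "cat_embed": jax.random.normal(k_emb, (sum(categories), dim)) * 0.02,
        # start row of each column's slice, e.g. (0, 10, 15, 21, 26) for (10, 5, 6, 5, 8)
        "cat_offset": jnp.cumsum(jnp.array((0,) + tuple(categories[:-1]))),
        # one weight vector and one bias vector per continuous column
        "num_weight": jax.random.normal(k_w, (num_continuous, dim)) * 0.02,
        "num_bias": jax.random.normal(k_b, (num_continuous, dim)) * 0.02,
        # learned [CLS] token added to every row
        "cls": jax.random.normal(k_cls, (1, dim)) * 0.02,
    }


def tokenize(params, x_categ, x_numer):
    """Map a batch of rows to a (batch, 1 + n_categ + n_numer, dim) token sequence."""
    # categorical: shift each column's index into its slice of the shared table, then look up
    cat_tokens = params["cat_embed"][x_categ + params["cat_offset"]]
    # continuous: token_j = x_j * W_j + b_j, broadcast over the batch
    num_tokens = x_numer[..., None] * params["num_weight"] + params["num_bias"]
    batch = x_categ.shape[0]
    cls = jnp.broadcast_to(params["cls"], (batch, 1, params["cls"].shape[-1]))
    return jnp.concatenate([cls, cat_tokens, num_tokens], axis=1)
```

With the hyperparameters from the usage example (five categorical columns, ten continuous columns, `dim = 32`), a batch of B rows yields tokens of shape `(B, 16, 32)`: one [CLS] token, five categorical tokens, and ten continuous tokens. The Transformer layers operate on this sequence, and the output head reads the final [CLS] representation.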

Installation

pip install fttjax

Usage

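from fttjax import FTTransformer
from jax import random

model = FTTransformer(
    categories = (10, 5, 6, 5, 8),  # number of unique values in each categorical column
    num_continuous = 10,            # number of continuous (numerical) columns
    dim = 32,                       # embedding dimension of each feature token
    dim_out = 1,                    # output dimension, e.g. 1 for regression or binary classification
    depth = 6,                      # number of Transformer layers
    heads = 8,                      # number of attention heads
    attn_dropout = 0.1,
    ff_dropout = 0.1
)

rng = random.PRNGKey(0)
p_rng, d_rng, c_rng, n_rng = random.split(rng, 4)

# categorical inputs: integer indices, one column per entry in `categories`,
# each below that column's cardinality (all values < 5 here)
x_categ = random.randint(c_rng, (1, 5), 0, 5)
# continuous inputs: one column per `num_continuous`
x_numer = random.normal(n_rng, (1, 10))

# `init` returns the model variables (parameters), not predictions
variables = model.init({'params': p_rng, 'dropout': d_rng}, x_categ, x_numer)
pred = model.apply(variables, x_categ, x_numer, rngs={'dropout': d_rng})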
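The `categories` tuple lists how many distinct values each categorical column has, and `x_categ` must hold integer indices below those counts. When starting from raw values, a small helper like the hypothetical `encode_categorical` below (not part of fttjax) can build both. It is a minimal sketch that assumes every value seen at inference time also appears in the data used to build the vocabularies; an unseen value raises a `KeyError`.

```python
import jax.numpy as jnp


def encode_categorical(columns):
    """Hypothetical helper: map raw categorical columns to indices and cardinalities.

    `columns` is a list of equal-length sequences, one per categorical column.
    Returns (categories, x_categ), where categories[j] is the number of distinct
    values in column j and x_categ has shape (n_rows, n_columns).
    """
    # sorted vocabulary per column, so the encoding is deterministic
    vocabularies = [sorted(set(col)) for col in columns]
    categories = tuple(len(v) for v in vocabularies)
    lookups = [{value: i for i, value in enumerate(v)} for v in vocabularies]
    # iterate row by row; each index lies in [0, categories[j])
    x_categ = jnp.array(
        [[lookups[j][value] for j, value in enumerate(row)] for row in zip(*columns)]
    )
    return categories, x_categ
```

For example, `encode_categorical([["red", "blue", "red"], ["s", "m", "l"]])` returns `categories == (2, 3)` and a `(3, 2)` index array, which can be passed as `categories` to `FTTransformer` and as `x_categ` to `init`/`apply`.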

Citation

@article{Gorishniy2021RevisitingDL,
    title   = {Revisiting Deep Learning Models for Tabular Data},
    author  = {Yu. V. Gorishniy and Ivan Rubachev and Valentin Khrulkov and Artem Babenko},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2106.11959}
}
