# tab-transformer-pytorch
**Repository Path**: frontxiang/tab-transformer-pytorch
## Basic Information
- **Project Name**: tab-transformer-pytorch
- **Description**: No description available
- **Primary Language**: Python
- **License**: MIT
- **Default Branch**: main
- **Homepage**: None
- **GVP Project**: No
## Statistics
- **Stars**: 0
- **Forks**: 0
- **Created**: 2021-09-16
- **Last Updated**: 2021-09-16
## Categories & Tags
**Categories**: Uncategorized
**Tags**: None
## README
## Tab Transformer
Implementation of Tab Transformer, an attention network for tabular data, in PyTorch. This simple architecture came within a hair's breadth of the performance of gradient-boosted decision trees (GBDTs).
## Install
```bash
$ pip install tab-transformer-pytorch
```
## Usage
```python
import torch
import torch.nn as nn
from tab_transformer_pytorch import TabTransformer

cont_mean_std = torch.randn(10, 2)

model = TabTransformer(
    categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
    num_continuous = 10,                # number of continuous values
    dim = 32,                           # dimension, paper set at 32
    dim_out = 1,                        # binary prediction, but could be anything
    depth = 6,                          # depth, paper recommended 6
    heads = 8,                          # heads, paper recommends 8
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1,                   # feed forward dropout
    mlp_hidden_mults = (4, 2),          # relative multiples of each hidden dimension of the last mlp to logits
    mlp_act = nn.ReLU(),                # activation for final mlp, defaults to relu, but could be anything else (selu, etc.)
    continuous_mean_std = cont_mean_std # (optional) - normalize the continuous values before layer norm
)

x_categ = torch.randint(0, 5, (1, 5))   # category values, each in the range 0 to (number of unique values - 1), columns in the order passed to the constructor above
x_cont = torch.randn(1, 10)             # assume continuous values are already normalized individually

pred = model(x_categ, x_cont)           # (1, 1)
```
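For instance, a minimal supervised training step might look like the sketch below. The synthetic batch, batch size, loss, and optimizer settings are illustrative assumptions, not part of the library; only the `TabTransformer` constructor and forward call come from the usage above.

```python
import torch
import torch.nn as nn
from tab_transformer_pytorch import TabTransformer

model = TabTransformer(
    categories = (10, 5, 6, 5, 8),
    num_continuous = 10,
    dim = 32,
    dim_out = 1,          # one logit for binary classification
    depth = 6,
    heads = 8
)

optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)  # illustrative hyperparameters
loss_fn = nn.BCEWithLogitsLoss()

# synthetic batch standing in for a real dataloader
x_categ = torch.randint(0, 5, (64, 5))   # token ids per categorical column
x_cont = torch.randn(64, 10)             # continuous features, assumed pre-normalized
labels = torch.rand(64, 1).round()       # random binary targets

logits = model(x_categ, x_cont)          # (64, 1)
loss = loss_fn(logits, labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```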
## Unsupervised Training
To undergo the type of unsupervised training described in the paper, first convert your categorical tokens to the appropriate unique ids, and then use Electra on `model.transformer`; a rough sketch of the idea follows.
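Below is a minimal sketch of that idea as a replaced-token-detection objective in the style of ELECTRA. Everything here is an illustrative assumption rather than library API: the offset scheme is a simple cumulative-cardinality mapping (the library itself may reserve extra special tokens), the `detector` head is hypothetical, and `model.transformer` is assumed to map token ids of shape `(batch, num_columns)` to contextual embeddings of shape `(batch, num_columns, dim)`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from tab_transformer_pytorch import TabTransformer

categories = (10, 5, 6, 5, 8)
model = TabTransformer(categories = categories, num_continuous = 10, dim = 32, dim_out = 1, depth = 6, heads = 8)

# map per-column category values to globally unique ids by offsetting each
# column by the cumulative cardinality of the columns before it (assumed scheme)
offsets = F.pad(torch.tensor(categories), (1, 0)).cumsum(dim = -1)[:-1]  # tensor([0, 10, 15, 21, 26])

x_categ = torch.randint(0, 5, (64, 5))
unique_ids = x_categ + offsets

# corrupt ~15% of positions with tokens drawn at random from the full vocabulary
num_tokens = sum(categories)
replace = torch.rand(unique_ids.shape) < 0.15
corrupted = torch.where(replace, torch.randint(0, num_tokens, unique_ids.shape), unique_ids)

# assumed: model.transformer takes token ids and returns contextual embeddings
embeddings = model.transformer(corrupted)        # (64, 5, 32)

# hypothetical per-position head predicting which tokens were actually replaced
detector = nn.Linear(32, 1)
logits = detector(embeddings).squeeze(-1)        # (64, 5)
targets = (corrupted != unique_ids).float()      # 1 where the token changed
loss = F.binary_cross_entropy_with_logits(logits, targets)
loss.backward()
```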
## Citations
```bibtex
@misc{huang2020tabtransformer,
    title = {TabTransformer: Tabular Data Modeling Using Contextual Embeddings},
    author = {Xin Huang and Ashish Khetan and Milan Cvitkovic and Zohar Karnin},
    year = {2020},
    eprint = {2012.06678},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```