# FLatten-Transformer

**Repository Path**: xxuffei/FLatten-Transformer

## Basic Information

- **Project Name**: FLatten-Transformer
- **Description**: No description available
- **Primary Language**: Unknown
- **License**: Not specified
- **Default Branch**: master
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 0
- **Forks**: 0
- **Created**: 2024-07-31
- **Last Updated**: 2024-07-31

## Categories & Tags

**Categories**: Uncategorized

**Tags**: None

## README

# FLatten Transformer

This repo contains the official **PyTorch** code and pre-trained models for FLatten Transformer (ICCV 2023).

+ [FLatten Transformer: Vision Transformer using Focused Linear Attention](https://arxiv.org/abs/2308.00442)

## Updates

- May 28 2024: **Fixed a numerical instability problem.** FLatten Transformers can now be trained with automatic mixed precision (AMP) or float16.

## Introduction

### Motivation

The quadratic computation complexity $\mathcal{O}(N^2)$ of self-attention, where $N$ is the number of tokens, has been a long-standing problem when applying Transformer models to vision tasks. Apart from reducing attention regions, linear attention is also considered an effective way to avoid excessive computation costs. By approximating Softmax with carefully designed mapping functions, linear attention can switch the computation order in the self-attention operation from $(QK^\top)V$ to $Q(K^\top V)$ and achieve linear complexity $\mathcal{O}(N)$. Nevertheless, current linear attention approaches either suffer from a severe performance drop or involve additional computation overhead from the mapping function. In this paper, we propose a novel **Focused Linear Attention** module that achieves both high efficiency and expressiveness.

### Method
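The reordering described in the Motivation above can be illustrated with a minimal NumPy sketch. This is generic kernelized linear attention, not the FLatten module. It makes the following assumptions:

- A simple ReLU feature map `phi` stands in for FLatten's focused mapping function.
- The inputs are single-head arrays of shape `(N, d)`.
- All function names are hypothetical and do not come from this repository.

Computing $(\phi(Q)\phi(K)^\top)V$ builds an $N \times N$ matrix at $\mathcal{O}(N^2 d)$ cost. Computing $\phi(Q)(\phi(K)^\top V)$ only needs a $d \times d$ intermediate at $\mathcal{O}(N d^2)$ cost, and both orderings give the same result.

```python
import numpy as np


def relu_feature_map(x: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in kernel: phi(x) = ReLU(x) + eps keeps every entry
    # positive so the normaliser below is never zero. This is NOT the
    # focused mapping function used by FLatten.
    return np.maximum(x, 0.0) + 1e-6


def quadratic_order_attention(q, k, v, phi=relu_feature_map):
    """(phi(Q) phi(K)^T) V: materialises an N x N matrix, O(N^2 d) cost."""
    qp, kp = phi(q), phi(k)
    scores = qp @ kp.T                        # (N, N) similarity matrix
    norm = scores.sum(axis=1, keepdims=True)  # (N, 1) row-wise normaliser
    return (scores @ v) / norm


def linear_order_attention(q, k, v, phi=relu_feature_map):
    """phi(Q) (phi(K)^T V): only a d x d intermediate, O(N d^2) cost."""
    qp, kp = phi(q), phi(k)
    kv = kp.T @ v                                     # (d, d_v) summary of keys/values
    norm = qp @ kp.sum(axis=0, keepdims=True).T       # (N, 1), same normaliser as above
    return (qp @ kv) / norm


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_tokens, dim = 196, 32  # e.g. a 14 x 14 token grid
    q, k, v = (rng.standard_normal((n_tokens, dim)) for _ in range(3))
    # Associativity of matrix products makes the two orderings agree.
    print(np.allclose(quadratic_order_attention(q, k, v),
                      linear_order_attention(q, k, v)))  # True
```

The row normaliser can also be computed in linear time, because the sum over $j$ of $\phi(q_i)^\top \phi(k_j)$ equals $\phi(q_i)^\top \sum_j \phi(k_j)$. Dropping Softmax is therefore what makes the reordering possible.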

In this paper, we first perform a detailed analysis of the inferior performance of linear attention from two perspectives: focus ability and feature diversity. Then, we introduce a simple yet effective mapping function and an efficient rank restoration module, and propose our **Focused Linear Attention (FLatten)**, which adequately addresses these concerns and achieves both high efficiency and expressive capability.

### Results

- Comparison of different models on ImageNet-1K.

- Accuracy-Runtime curve on ImageNet.

## Dependencies

- Python 3.9
- PyTorch == 1.11.0
- torchvision == 0.12.0
- numpy
- timm == 0.4.12
- einops
- yacs

## Data preparation

The ImageNet dataset should be prepared as follows:

```
$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
```

## Pretrained Models

Based on different model architectures, we provide several pretrained models, as listed below.

| model | Resolution | acc@1 | config | pretrained weights |
| :---: | :---: | :---: | :---: | :---: |
| FLatten-PVT-T | $224^2$ | 77.8 (+2.7) | [config](cfgs/flatten_pvt_t.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/3ab1d773f19d45648690/?dl=1) |
| FLatten-PVTv2-B0 | $224^2$ | 71.1 (+0.6) | [config](cfgs/flatten_pvt_v2_b0.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/5d1f01532b104da28e7b/?dl=1) |
| FLatten-Swin-T | $224^2$ | 82.1 (+0.8) | [config](cfgs/flatten_swin_t.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/e1518e76703e4e57a7f2/?dl=1) |
| FLatten-Swin-S | $224^2$ | 83.5 (+0.5) | [config](cfgs/flatten_swin_s.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/94188e52af354bf4a88b/?dl=1) |
| FLatten-Swin-B | $224^2$ | 83.8 (+0.3) | [config](cfgs/flatten_swin_b.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/7a9e5186bad04e7fb3a9/?dl=1) |
| FLatten-Swin-B | $384^2$ | 85.0 (+0.5) | [config](cfgs/flatten_swin_b_384.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/0d0330cf2e5249f1abb6/?dl=1) |
| FLatten-CSwin-T | $224^2$ | 83.1 (+0.4) | [config](cfgs/flatten_cswin_t.yaml) | [TsinghuaCloud](https://cloud.tsinghua.edu.cn/f/48ba765ba8b0451d9d5a/?dl=1) |

Evaluate one model on ImageNet:

```shell
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg --data-path --output --eval --resume
```

Outputs of the four T/B0 pretrained models are:

```
[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 294): INFO * Acc@1 77.758 Acc@5 93.910
[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 149): INFO Accuracy of the network on the 50000 test images: 77.8%

[2023-07-21 07:51:36 flatten_pvt_v2_b0] (main.py 294): INFO * Acc@1 71.098 Acc@5 90.596
[2023-07-21 07:51:36 flatten_pvt_v2_b0] (main.py 149): INFO Accuracy of the network on the 50000 test images: 71.1%

[2023-07-21 07:46:13 flatten_swin_tiny_patch4_224] (main.py 294): INFO * Acc@1 82.106 Acc@5 95.900
[2023-07-21 07:46:13 flatten_swin_tiny_patch4_224] (main.py 149): INFO Accuracy of the network on the 50000 test images: 82.1%

[2023-07-21 07:52:46 FLatten_CSWin_tiny](main.py 294): INFO * Acc@1 83.130 Acc@5 96.376
[2023-07-21 07:52:46 FLatten_CSWin_tiny](main.py 149): INFO Accuracy of the network on the 50000 test images: 83.1%
```

## Train Models from Scratch

- **To train `FLatten-PVT-T/S/M/B` on ImageNet from scratch, run:**

```shell
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_t.yaml --data-path --output --find-unused-params
```

```shell
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_s.yaml --data-path --output --find-unused-params
```

```shell
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_m.yaml --data-path --output --find-unused-params
```

```shell
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_b.yaml --data-path --output --find-unused-params
```

- **To train `FLatten-PVT-v2-b0/1/2/3/4` on ImageNet from scratch, run:**

```shell
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b0.yaml --data-path --output
```

```shell
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b1.yaml --data-path --output
```

```shell
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b2.yaml --data-path --output
```

```shell
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b3.yaml --data-path --output
```

```shell
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b4.yaml --data-path --output
```

- **To train `FLatten-Swin-T/S/B` on ImageNet from scratch, run:**

```shell
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_t.yaml --data-path --output
```

```shell
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_s.yaml --data-path --output
```

```shell
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_b.yaml --data-path --output
```

- **To train `FLatten-CSwin-T/S/B` on ImageNet from scratch, run:**

```shell
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_t.yaml --data-path --output --model-ema --model-ema-decay 0.99984
```

```shell
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_s.yaml --data-path --output --model-ema --model-ema-decay 0.99984
```

```shell
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_b.yaml --data-path --output --model-ema --model-ema-decay 0.99982
```

## Fine-tuning on higher resolution

Fine-tune a `FLatten-Swin-B` model pre-trained at 224x224 resolution to 384x384 resolution:

```shell
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_b_384.yaml --data-path --output --pretrained
```

Fine-tune a `FLatten-CSwin-B` model pre-trained at 224x224 resolution to 384x384 resolution:

```shell
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_b_384.yaml --data-path --output --pretrained --model-ema --model-ema-decay 0.99982
```

## Visualization

We provide code for visualizing flatten attention.
For example, to visualize flatten attention in FLatten-Swin-T, add the following code at [this line](https://github.com/LeapLabTHU/FLatten-Transformer/blob/96b7dac65e9688d947a3afa01a0c70b92d9654c8/models/flatten_swin.py#L229):

```python
# Paste inside the attention module's forward pass at the linked line,
# where the query/key tensors `q`, `k` and the depthwise conv `self.dwc` are in scope.
from visualize import AttnVisualizer

visualizer = AttnVisualizer(qk=[q, k], kernel=self.dwc.weight, name='flatten_swin_t')
visualizer.visualize_all_attn(max_num=196, image='./visualize/img_ori_00809.png')
```

Then run:

```shell
python visualize.py
```

**Note:** Don't forget to set the path to the FLatten-Swin-T pretrained weights in `visualize.py`.

## Acknowledgements

This code is built on top of [Swin Transformer](https://github.com/microsoft/Swin-Transformer). The computational resources supporting this work are provided by [Hangzhou High-Flyer AI Fundamental Research Co., Ltd.](https://www.high-flyer.cn/)

## Citation

If you find this repo helpful, please consider citing us.

```latex
@InProceedings{han2023flatten,
  title={FLatten Transformer: Vision Transformer using Focused Linear Attention},
  author={Han, Dongchen and Pan, Xuran and Han, Yizeng and Song, Shiji and Huang, Gao},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year={2023}
}
```

## Contact

If you have any questions, please feel free to contact the authors.

Dongchen Han: [hdc23@mails.tsinghua.edu.cn](mailto:hdc23@mails.tsinghua.edu.cn)

Xuran Pan: [pxr18@mails.tsinghua.edu.cn](mailto:pxr18@mails.tsinghua.edu.cn)