Sanghyeok Lee, Joonmyung Choi, Hyunwoo J. Kim
Vision Transformer (ViT) has emerged as a prominent backbone for computer vision. To make ViTs more efficient, recent works reduce the quadratic cost of the self-attention layer by pruning or fusing redundant tokens. However, these methods face a speed-accuracy trade-off caused by the loss of information. Here, we argue that token fusion needs to consider diverse relations between tokens to minimize information loss. In this paper, we propose Multi-criteria Token Fusion (MCTF), which gradually fuses tokens based on multiple criteria (e.g., similarity, informativeness, and size of fused tokens). Further, we utilize one-step-ahead attention, an improved approach to capturing the informativeness of tokens. By training the model equipped with MCTF using token reduction consistency, we achieve the best speed-accuracy trade-off in image classification (ImageNet-1K). Experimental results show that MCTF consistently surpasses previous reduction methods both with and without training. Specifically, DeiT-T and DeiT-S with MCTF reduce FLOPs by about 44% while improving performance over the base models by +0.5% and +0.3%, respectively. We also demonstrate the applicability of MCTF to various Vision Transformers (e.g., T2T-ViT, LV-ViT), achieving at least a 31% speedup without performance degradation. Code is available at https://github.com/mlvlab/MCTF.
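To make the idea of fusing tokens under several criteria concrete, here is a minimal sketch in plain Python. It is not the authors' implementation (see the repository linked above for that). The scoring formula, the greedy pair selection, and the names `attraction`, `fuse_once`, and `fuse` are assumptions chosen for illustration. The informativeness scores are assumed to be given as values in [0, 1]; the one-step-ahead attention that the abstract uses to obtain them is not modeled here.

```python
import math


def cosine(u, v):
    """Cosine similarity between two vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv + 1e-12)


def attraction(tokens, info, sizes, i, j):
    """Hypothetical multi-criteria score for fusing tokens i and j.

    The score is high when the two tokens are similar, carry little
    information, and have not already absorbed many other tokens.
    """
    w_sim = 0.5 * (cosine(tokens[i], tokens[j]) + 1.0)  # similarity, in [0, 1]
    w_info = (1.0 - info[i]) * (1.0 - info[j])          # prefer uninformative pairs
    w_size = 1.0 / (sizes[i] * sizes[j])                # prefer small tokens
    return w_sim * w_info * w_size


def fuse_once(tokens, info, sizes):
    """Fuse the single best-scoring pair; returns new (tokens, info, sizes)."""
    n = len(tokens)
    best, pair = -1.0, None
    for i in range(n):
        for j in range(i + 1, n):
            w = attraction(tokens, info, sizes, i, j)
            if w > best:
                best, pair = w, (i, j)
    i, j = pair
    total = sizes[i] + sizes[j]
    # Size-weighted average, so a token that already represents many
    # patches contributes proportionally more to the fused token.
    merged = [(sizes[i] * a + sizes[j] * b) / total
              for a, b in zip(tokens[i], tokens[j])]
    merged_info = (sizes[i] * info[i] + sizes[j] * info[j]) / total
    tokens, info, sizes = list(tokens), list(info), list(sizes)
    tokens[i], info[i], sizes[i] = merged, merged_info, total
    del tokens[j], info[j], sizes[j]  # j > i, so index i stays valid
    return tokens, info, sizes


def fuse(tokens, info, sizes, r):
    """Remove up to r tokens by repeated pairwise fusion."""
    for _ in range(min(r, len(tokens) - 1)):
        tokens, info, sizes = fuse_once(tokens, info, sizes)
    return tokens, info, sizes


if __name__ == "__main__":
    toks = [[1.0, 0.0], [0.99, 0.01], [0.0, 1.0]]
    out_tokens, out_info, out_sizes = fuse(toks, [0.1, 0.1, 0.9], [1, 1, 1], r=1)
    print(out_tokens, out_sizes)
```

In this toy input the first two tokens are nearly identical and uninformative, so they are the pair that gets fused, while the distinct, informative third token is kept. The exhaustive pair search is quadratic in the number of tokens and is used only for readability.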
| Task | Dataset | Metric | Value | Model |
|---|---|---|---|---|
| Image Classification | ImageNet-1K (with LV-ViT-S) | GFLOPs | 4.9 | MCTF ($r=8$) |
| Image Classification | ImageNet-1K (with LV-ViT-S) | Top 1 Accuracy | 83.5 | MCTF ($r=8$) |
| Image Classification | ImageNet-1K (with LV-ViT-S) | GFLOPs | 4.2 | MCTF ($r=12$) |
| Image Classification | ImageNet-1K (with LV-ViT-S) | Top 1 Accuracy | 83.4 | MCTF ($r=12$) |
| Image Classification | ImageNet-1K (with LV-ViT-S) | GFLOPs | 3.6 | MCTF ($r=16$) |
| Image Classification | ImageNet-1K (with LV-ViT-S) | Top 1 Accuracy | 82.3 | MCTF ($r=16$) |
| Image Classification | ImageNet-1K (with DeiT-S) | GFLOPs | 2.6 | MCTF ($r=16$) |
| Image Classification | ImageNet-1K (with DeiT-S) | Top 1 Accuracy | 80.1 | MCTF ($r=16$) |
| Image Classification | ImageNet-1K (with DeiT-S) | GFLOPs | 2.4 | MCTF ($r=18$) |
| Image Classification | ImageNet-1K (with DeiT-S) | Top 1 Accuracy | 79.9 | MCTF ($r=18$) |
| Image Classification | ImageNet-1K (with DeiT-S) | GFLOPs | 2.2 | MCTF ($r=20$) |
| Image Classification | ImageNet-1K (with DeiT-S) | Top 1 Accuracy | 79.5 | MCTF ($r=20$) |
| Image Classification | ImageNet-1K (with DeiT-T) | GFLOPs | 1 | MCTF ($r=8$) |
| Image Classification | ImageNet-1K (with DeiT-T) | Top 1 Accuracy | 72.9 | MCTF ($r=8$) |
| Image Classification | ImageNet-1K (with DeiT-T) | GFLOPs | 0.7 | MCTF ($r=16$) |
| Image Classification | ImageNet-1K (with DeiT-T) | Top 1 Accuracy | 72.7 | MCTF ($r=16$) |
| Image Classification | ImageNet-1K (with DeiT-T) | GFLOPs | 0.6 | MCTF ($r=20$) |
| Image Classification | ImageNet-1K (with DeiT-T) | Top 1 Accuracy | 71.4 | MCTF ($r=20$) |
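
The trade-off in the table can be summarized per backbone by comparing the smallest and largest $r$ reported. The short script below is an illustrative sketch that uses only the values listed above. The names `RESULTS` and `trade_off` are hypothetical, and base-model GFLOPs are not part of the table, so the savings are relative to the smallest-$r$ row rather than to the unmodified backbone.

```python
# (r, GFLOPs, top-1 accuracy in %) copied from the table above.
RESULTS = {
    "LV-ViT-S": [(8, 4.9, 83.5), (12, 4.2, 83.4), (16, 3.6, 82.3)],
    "DeiT-S": [(16, 2.6, 80.1), (18, 2.4, 79.9), (20, 2.2, 79.5)],
    "DeiT-T": [(8, 1.0, 72.9), (16, 0.7, 72.7), (20, 0.6, 71.4)],
}


def trade_off(rows):
    """Compare the smallest-r and largest-r settings of one backbone.

    Returns (percent GFLOPs saved, change in top-1 accuracy in points).
    """
    rows = sorted(rows)  # sort by r
    (_, g_low, a_low), (_, g_high, a_high) = rows[0], rows[-1]
    return 100.0 * (g_low - g_high) / g_low, a_high - a_low


if __name__ == "__main__":
    for backbone, rows in RESULTS.items():
        saved, delta = trade_off(rows)
        print(f"{backbone}: {saved:.1f}% fewer GFLOPs, {delta:+.1f} top-1 points")
```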