Ping Yu, Mikel Artetxe, Myle Ott, Sam Shleifer, Hongyu Gong, Ves Stoyanov, Xian Li
All-MLP architectures have attracted increasing interest as an alternative to attention-based models. In NLP, recent work like gMLP shows that all-MLPs can match Transformers in language modeling, but still lag behind in downstream tasks. In this work, we analyze the limitations of MLPs in expressiveness, and propose sparsely activated MLPs with mixtures of experts (MoEs) in both feature and input (token) dimensions. Such sparse all-MLPs significantly increase model capacity and expressiveness while keeping the compute constant. We address critical challenges in incorporating conditional computation with two routing strategies. The proposed sparse all-MLP improves language modeling perplexity and obtains up to a 2$\times$ improvement in training efficiency compared with Transformer-based MoEs (GShard, Switch Transformer, Base Layers, and HASH Layers) as well as dense Transformers and all-MLPs. Finally, we evaluate its zero-shot in-context learning performance on six downstream tasks, and find that it surpasses Transformer-based MoEs and dense Transformers.
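
To make concrete the claim that sparse activation adds capacity while keeping compute constant, here is a minimal sketch of a generic learned top-1 mixture of expert MLPs over tokens. It is an illustration only: it is not the pair of routing strategies the paper proposes (the abstract does not describe them), and the class name `Top1MoE`, the dimensions, and the random initialization are all assumptions made for this example.

```python
import math
import random


def matvec(w, x):
    """Multiply matrix w (a list of rows) by vector x."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def relu(v):
    return [max(0.0, a) for a in v]


def softmax(v):
    m = max(v)
    exps = [math.exp(a - m) for a in v]
    s = sum(exps)
    return [e / s for e in exps]


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


class Top1MoE:
    """Hypothetical top-1 routed mixture of expert MLPs (illustrative only)."""

    def __init__(self, d_model, d_hidden, n_experts, seed=0):
        rng = random.Random(seed)
        # Router: one score per expert for each token.
        self.router = random_matrix(n_experts, d_model, rng)
        # Each expert is a two-layer MLP: d_model -> d_hidden -> d_model.
        self.experts = [
            (random_matrix(d_hidden, d_model, rng),
             random_matrix(d_model, d_hidden, rng))
            for _ in range(n_experts)
        ]

    def forward(self, tokens):
        outputs, choices = [], []
        for x in tokens:
            gates = softmax(matvec(self.router, x))
            # Top-1 routing: only the highest-scoring expert runs for this token.
            k = max(range(len(gates)), key=gates.__getitem__)
            w_in, w_out = self.experts[k]
            y = matvec(w_out, relu(matvec(w_in, x)))
            # Scale by the gate value so the router would receive gradient
            # in a trained model.
            outputs.append([gates[k] * a for a in y])
            choices.append(k)
        return outputs, choices


moe = Top1MoE(d_model=4, d_hidden=8, n_experts=3)
tokens = [[1.0, 0.0, -1.0, 0.5], [0.2, 0.3, 0.1, -0.4]]
outputs, choices = moe.forward(tokens)
```

Adding experts in this sketch grows the parameter count, but each token still passes through exactly one expert MLP, so the per-token expert compute does not depend on `n_experts` (only the small router projection does).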
| Task | Dataset | Metric | Value | Model |
|---|---|---|---|---|
| Question Answering | COPA | Accuracy | 79 | sMLP – deterministic 9.4B (0-shot) |
| Question Answering | COPA | Accuracy | 76 | GShard 9B |
| Question Answering | COPA | Accuracy | 75 | Switch Transformer 9B |
| Question Answering | COPA | Accuracy | 64 | HASH Layers 10B (0-shot) |
| Question Answering | COPA | Accuracy | 63 | Base Layers 10B (0-shot) |
| Question Answering | PIQA | Accuracy | 73 | sMLP – deterministic 9.4B (0-shot) |
| Question Answering | PIQA | Accuracy | 68.1 | GShard 9B |
| Question Answering | PIQA | Accuracy | 63.8 | Base Layers 10B (0-shot) |
| Question Answering | PIQA | Accuracy | 63.8 | HASH Layers 10B (0-shot) |
| Question Answering | StoryCloze | Accuracy | 74.7 | sMLP – deterministic 9.4B (0-shot) |
| Question Answering | StoryCloze | Accuracy | 73.3 | Switch Transformer 9B |
| Question Answering | StoryCloze | Accuracy | 67.9 | GShard 9B |
| Question Answering | StoryCloze | Accuracy | 64.7 | HASH Layers 10B (0-shot) |
| Question Answering | StoryCloze | Accuracy | 61.4 | Base Layers 10B (0-shot) |
| Common Sense Reasoning | WinoGrande | Accuracy | 54.3 | sMLP – deterministic 9.4B (0-shot) |
| Common Sense Reasoning | WinoGrande | Accuracy | 53.4 | Switch Transformer 9B (0-shot) |
| Common Sense Reasoning | WinoGrande | Accuracy | 51.7 | HASH Layers 10B (0-shot) |
| Common Sense Reasoning | WinoGrande | Accuracy | 51.1 | GShard 9B (0-shot) |
| Common Sense Reasoning | WinoGrande | Accuracy | 51 | Base Layers 10B (0-shot) |
| Common Sense Reasoning | ReCoRD | EM | 79.9 | Switch Transformer 9B |
| Common Sense Reasoning | ReCoRD | EM | 73.4 | sMLP – deterministic 9.4B (0-shot) |
| Common Sense Reasoning | ReCoRD | EM | 72.4 | GShard 9B |
| Common Sense Reasoning | ReCoRD | EM | 67.2 | HASH Layers 10B (0-shot) |
| Common Sense Reasoning | ReCoRD | EM | 60.7 | Base Layers 10B (0-shot) |
| Sentence Completion | HellaSwag | Accuracy | 54.5 | sMLP – deterministic 9.4B (0-shot) |
| Sentence Completion | HellaSwag | Accuracy | 52.5 | Switch Transformer 9B |
| Sentence Completion | HellaSwag | Accuracy | 38 | GShard 9B |
| Sentence Completion | HellaSwag | Accuracy | 33 | HASH Layers 10B (0-shot) |
| Sentence Completion | HellaSwag | Accuracy | 30.2 | Base Layers 10B (0-shot) |