Shuang Li, Yilun Du, Joshua B. Tenenbaum, Antonio Torralba, Igor Mordatch
Large pre-trained models exhibit distinct and complementary capabilities depending on the data they are trained on. Language models such as GPT-3 are capable of textual reasoning but cannot understand visual information, while vision models such as DALL-E can generate photorealistic images but fail to understand complex language descriptions. In this work, we propose a unified framework for composing ensembles of different pre-trained models, combining the strengths of each individual model to solve various multimodal problems in a zero-shot manner. We use pre-trained models as "generators" or "scorers" and compose them via closed-loop iterative consensus optimization. The generator constructs proposals, and the scorers iteratively provide feedback to refine the generated result. Such closed-loop communication enables models to correct errors caused by other models, significantly boosting performance on downstream tasks, e.g., improving accuracy on grade school math problems by 7.5%, without requiring any model finetuning. We demonstrate that the consensus achieved by an ensemble of scorers outperforms the feedback of a single scorer by leveraging the strengths of each expert model. Results show that the proposed method can be used as a general-purpose framework for a wide range of zero-shot multimodal tasks, such as image generation, video question answering, mathematical reasoning, and robotic manipulation. Project page: https://energy-based-model.github.io/composing-pretrained-models.
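The generator/scorer loop described above can be illustrated with a minimal toy sketch. This is not the authors' implementation: the names `propose`, `scorer_a`, `scorer_b`, `consensus_score`, and `iterative_consensus` are hypothetical, the "generator" is a random perturbation of a number rather than a pre-trained model, and averaging the scores is only one assumed way of forming a consensus.

```python
import random
from typing import Callable, List, Sequence

Scorer = Callable[[float], float]


def consensus_score(candidate: float, scorers: Sequence[Scorer]) -> float:
    """Average the scorers' feedback (an assumed, simple consensus rule)."""
    return sum(s(candidate) for s in scorers) / len(scorers)


def iterative_consensus(
    propose: Callable[[float, random.Random], List[float]],
    scorers: Sequence[Scorer],
    init: float,
    steps: int = 50,
    seed: int = 0,
) -> float:
    """Closed loop: the generator proposes, the scorers rate, the best proposal is kept."""
    rng = random.Random(seed)
    best = init
    best_score = consensus_score(best, scorers)
    for _ in range(steps):
        # The generator refines the current best result into new proposals.
        for cand in propose(best, rng):
            score = consensus_score(cand, scorers)
            # Keep a proposal only if the ensemble of scorers prefers it.
            if score > best_score:
                best, best_score = cand, score
    return best


# Hypothetical toy stand-ins for pre-trained models.
def propose(current: float, rng: random.Random) -> List[float]:
    """Toy generator: perturb the current result to get 8 new proposals."""
    return [current + rng.uniform(-1.0, 1.0) for _ in range(8)]


def scorer_a(x: float) -> float:
    """Toy scorer that prefers values near 2."""
    return -((x - 2.0) ** 2)


def scorer_b(x: float) -> float:
    """Toy scorer that prefers values near 4."""
    return -((x - 4.0) ** 2)


result = iterative_consensus(propose, [scorer_a, scorer_b], init=0.0)
print(f"consensus result: {result:.3f}")
```

In this toy setting each scorer alone would pull the result toward its own optimum (2 or 4), while the averaged feedback settles near 3, the point both scorers jointly prefer. In the framework described above, the generator and scorers would instead be pre-trained models such as those listed in the table that follows.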
| Task | Dataset | Metric | Value | Model |
|---|---|---|---|---|
| Image Generation | ImageNet 64x64 | FID | 29.184 | GLIDE + CLIP + CLS + CLS-FREE |
| Image Generation | ImageNet 64x64 | Inception Score | 34.952 | GLIDE + CLIP + CLS + CLS-FREE |
| Image Generation | ImageNet 64x64 | KID | 3.766 | GLIDE + CLIP + CLS + CLS-FREE |
| Image Generation | ImageNet 64x64 | FID | 29.219 | GLIDE + CLS-FREE |
| Image Generation | ImageNet 64x64 | Inception Score | 25.926 | GLIDE + CLS-FREE |
| Image Generation | ImageNet 64x64 | KID | 5.325 | GLIDE + CLS-FREE |
| Image Generation | ImageNet 64x64 | FID | 30.462 | GLIDE + CLIP |
| Image Generation | ImageNet 64x64 | Inception Score | 25.017 | GLIDE + CLIP |
| Image Generation | ImageNet 64x64 | KID | 6.174 | GLIDE + CLIP |
| Image Generation | ImageNet 64x64 | FID | 30.871 | GLIDE + CLS |
| Image Generation | ImageNet 64x64 | Inception Score | 22.077 | GLIDE + CLS |
| Image Generation | ImageNet 64x64 | KID | 7.952 | GLIDE + CLS |
| Video Question Answering | ActivityNet-QA | Accuracy | 61.2 | GPT-2 + CLIP-14 + CLIP-multilingual (Zero-Shot) |
| Video Question Answering | ActivityNet-QA | Accuracy | 58.4 | GPT-2 + CLIP-32 (Zero-Shot) |
| Arithmetic Reasoning | GSM8K | Accuracy | 20.8 | GPT-2-Medium 355M + question-solution classifier (BS=5) |
| Arithmetic Reasoning | GSM8K | Parameters (Billion) | 0.355 | GPT-2-Medium 355M + question-solution classifier (BS=5) |
| Arithmetic Reasoning | GSM8K | Accuracy | 18.3 | GPT-2-Medium 355M (fine-tuned, BS=5) |
| Arithmetic Reasoning | GSM8K | Parameters (Billion) | 0.355 | GPT-2-Medium 355M (fine-tuned, BS=5) |
| Arithmetic Reasoning | GSM8K | Accuracy | 16.8 | GPT-2-Medium 355M + question-solution classifier (BS=1) |
| Arithmetic Reasoning | GSM8K | Parameters (Billion) | 0.355 | GPT-2-Medium 355M + question-solution classifier (BS=1) |
| Arithmetic Reasoning | GSM8K | Accuracy | 12.2 | GPT-2-Medium 355M (BS=5) |
| Arithmetic Reasoning | GSM8K | Parameters (Billion) | 0.355 | GPT-2-Medium 355M (BS=5) |