Chongxuan Li, Kun Xu, Jiashuo Liu, Jun Zhu, Bo Zhang
We propose a unified game-theoretical framework to perform classification and conditional image generation given limited supervision. It is formulated as a three-player minimax game consisting of a generator, a classifier and a discriminator, and is therefore referred to as the Triple Generative Adversarial Network (Triple-GAN). The generator and the classifier characterize the conditional distributions between images and labels to perform conditional generation and classification, respectively. The discriminator focuses solely on identifying fake image-label pairs. Under a nonparametric assumption, we prove that the unique equilibrium of the game is attained when the distributions characterized by the generator and the classifier converge to the data distribution. As a byproduct of the three-player mechanism, Triple-GAN can flexibly incorporate different semi-supervised classifiers and GAN architectures. We evaluate Triple-GAN in two challenging settings, namely, semi-supervised learning and the extreme low data regime. In both settings, Triple-GAN achieves excellent classification results while simultaneously generating meaningful samples of a specified class. In particular, using a commonly adopted 13-layer CNN classifier, Triple-GAN substantially outperforms a wide range of semi-supervised learning methods on more than 10 benchmarks, whether or not data augmentation is applied.
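The abstract does not spell out the game's objective. As a rough illustration only, the sketch below evaluates a three-player utility of the form used in the original Triple-GAN formulation: the discriminator is rewarded for accepting real image-label pairs and for rejecting pairs produced by the classifier and by the generator, with the two fake sources mixed by a weight `alpha`. The function name `triple_gan_utility`, its arguments, and the default `alpha=0.5` are assumptions made for this example, and the exact objective used by Triple-GAN-V2 may differ.

```python
import numpy as np

def triple_gan_utility(d_real, d_cls, d_gen, alpha=0.5, eps=1e-12):
    """Illustrative three-player utility (an assumption, not the paper's code).

    d_real: discriminator outputs in (0, 1) on real labeled pairs (x, y)
    d_cls:  outputs on pairs (x, C(x)) labeled by the classifier
    d_gen:  outputs on pairs (G(y, z), y) produced by the generator
    alpha:  hypothetical mixing weight between the two fake sources
    """
    d_real = np.asarray(d_real, dtype=float)
    d_cls = np.asarray(d_cls, dtype=float)
    d_gen = np.asarray(d_gen, dtype=float)

    # The discriminator maximizes this value; the classifier and the
    # generator jointly minimize it.
    real_term = np.mean(np.log(d_real + eps))
    cls_term = np.mean(np.log(1.0 - d_cls + eps))
    gen_term = np.mean(np.log(1.0 - d_gen + eps))
    return real_term + alpha * cls_term + (1.0 - alpha) * gen_term

# A discriminator that cannot tell the three sources apart outputs 0.5 everywhere.
undecided = triple_gan_utility([0.5, 0.5], [0.5, 0.5], [0.5, 0.5])
# A discriminator that separates real from fake pairs attains a higher utility.
confident = triple_gan_utility([0.9, 0.9], [0.1, 0.1], [0.1, 0.1])
```

In this sketch, an undecided discriminator yields a utility of about 2·log(0.5), and any discriminator that separates real pairs from fake ones scores higher, which is the pressure that pushes the classifier and the generator toward the data distribution.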
| Task | Dataset | Metric | Value | Model |
|---|---|---|---|---|
| Image Classification | CIFAR-10, 4000 Labels | Percentage error | 6.54 | Triple-GAN-V2 (ResNet-26) |
| Image Classification | CIFAR-10, 4000 Labels | Percentage error | 10.01 | Triple-GAN-V2 (CNN-13) |
| Image Classification | CIFAR-10, 4000 Labels | Percentage error | 12.41 | Triple-GAN-V2 (CNN-13, no aug) |
| Image Classification | SVHN, 500 Labels | Accuracy | 96.39 | Triple-GAN-V2 (CNN-13) |
| Image Classification | SVHN, 500 Labels | Accuracy | 96.16 | Triple-GAN-V2 (CNN-13, no aug) |
| Image Classification | CIFAR-10, 1000 Labels | Accuracy | 91.59 | Triple-GAN-V2 (ResNet-26) |
| Image Classification | CIFAR-10, 1000 Labels | Accuracy | 85 | Triple-GAN-V2 (CNN-13) |
| Image Classification | CIFAR-10, 1000 Labels | Accuracy | 81.81 | Triple-GAN-V2 (CNN-13, no aug) |
| Image Classification | SVHN, 1000 Labels | Accuracy | 96.55 | Triple-GAN-V2 (CNN-13) |
| Image Classification | SVHN, 1000 Labels | Accuracy | 96.04 | Triple-GAN-V2 (CNN-13, no aug) |
| Image Classification | SVHN, 250 Labels | Accuracy | 96.52 | Triple-GAN-V2 (CNN-13) |
| Image Classification | SVHN, 250 Labels | Accuracy | 95.81 | Triple-GAN-V2 (CNN-13, no aug) |
| Semi-Supervised Image Classification | CIFAR-10, 4000 Labels | Percentage error | 6.54 | Triple-GAN-V2 (ResNet-26) |
| Semi-Supervised Image Classification | CIFAR-10, 4000 Labels | Percentage error | 10.01 | Triple-GAN-V2 (CNN-13) |
| Semi-Supervised Image Classification | CIFAR-10, 4000 Labels | Percentage error | 12.41 | Triple-GAN-V2 (CNN-13, no aug) |
| Semi-Supervised Image Classification | SVHN, 500 Labels | Accuracy | 96.39 | Triple-GAN-V2 (CNN-13) |
| Semi-Supervised Image Classification | SVHN, 500 Labels | Accuracy | 96.16 | Triple-GAN-V2 (CNN-13, no aug) |
| Semi-Supervised Image Classification | CIFAR-10, 1000 Labels | Accuracy | 91.59 | Triple-GAN-V2 (ResNet-26) |
| Semi-Supervised Image Classification | CIFAR-10, 1000 Labels | Accuracy | 85 | Triple-GAN-V2 (CNN-13) |
| Semi-Supervised Image Classification | CIFAR-10, 1000 Labels | Accuracy | 81.81 | Triple-GAN-V2 (CNN-13, no aug) |
| Semi-Supervised Image Classification | SVHN, 1000 Labels | Accuracy | 96.55 | Triple-GAN-V2 (CNN-13) |
| Semi-Supervised Image Classification | SVHN, 1000 Labels | Accuracy | 96.04 | Triple-GAN-V2 (CNN-13, no aug) |
| Semi-Supervised Image Classification | SVHN, 250 Labels | Accuracy | 96.52 | Triple-GAN-V2 (CNN-13) |
| Semi-Supervised Image Classification | SVHN, 250 Labels | Accuracy | 95.81 | Triple-GAN-V2 (CNN-13, no aug) |