Giorgio Patrini, Alessandro Rozza, Aditya Menon, Richard Nock, Lizhen Qu
We present a theoretically grounded approach to training deep neural networks, including recurrent networks, subject to class-dependent label noise. We propose two loss-correction procedures that are agnostic to both application domain and network architecture: they amount to at most a matrix inversion and a matrix multiplication, provided that we know the probability of each class being corrupted into another. We further show how to estimate these probabilities, adapting a recent technique for noise estimation to the multi-class setting and thus providing an end-to-end framework. Extensive experiments on MNIST, IMDB, CIFAR-10, CIFAR-100 and a large-scale dataset of clothing images, employing a diversity of architectures --- stacking dense, convolutional, pooling, dropout, batch normalization, word embedding, LSTM and residual layers --- demonstrate the noise robustness of our proposals. Incidentally, we also prove that, when ReLU is the only non-linearity, the loss curvature is immune to class-dependent label noise.
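The two corrections described above can be sketched concretely. In a minimal NumPy sketch (function names are illustrative, not from the paper's code), the forward correction passes the model's class probabilities through the noise transition matrix T before computing cross-entropy against the noisy labels, while the backward correction reweights the per-class losses by T⁻¹:

```python
import numpy as np

def forward_correction_loss(probs, labels, T):
    """Forward correction: mix predicted class probabilities by T,
    then compute cross-entropy against the (possibly noisy) labels.
    T[i, j] = probability that true class i is observed as class j."""
    corrected = probs @ T
    return -np.log(corrected[np.arange(len(labels)), labels] + 1e-12)

def backward_correction_loss(probs, labels, T):
    """Backward correction: form the vector of losses "as if" each
    class were the true one, multiply by T^{-1}, and read off the
    component of the observed noisy label (an unbiased estimate of
    the clean loss)."""
    T_inv = np.linalg.inv(T)
    losses = -np.log(probs + 1e-12)            # shape (n, K)
    return np.einsum('ij,ij->i', T_inv[labels], losses)
```

When T is the identity (no noise), both corrections reduce to the ordinary cross-entropy, which is a useful sanity check before plugging either loss into a training loop.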
| Task | Dataset | Metric | Value | Model |
|---|---|---|---|---|
| Image Classification | Clothing1M (using clean data) | Accuracy | 80.27 | Forward |
| Image Classification | mini WebVision 1.0 | ImageNet Top-1 Accuracy | 57.36 | F-Correction (Inception-ResNet-v2) |
| Image Classification | mini WebVision 1.0 | ImageNet Top-5 Accuracy | 82.36 | F-Correction (Inception-ResNet-v2) |
| Image Classification | mini WebVision 1.0 | Top-1 Accuracy | 61.12 | F-Correction (Inception-ResNet-v2) |
| Image Classification | mini WebVision 1.0 | Top-5 Accuracy | 82.68 | F-Correction (Inception-ResNet-v2) |
| Image Classification | CIFAR-10N-Random2 | Accuracy (mean) | 86.28 | Backward-T |
| Image Classification | CIFAR-10N-Random2 | Accuracy (mean) | 86.14 | Forward-T |
| Image Classification | CIFAR-10N-Random3 | Accuracy (mean) | 87.04 | Forward-T |
| Image Classification | CIFAR-10N-Random3 | Accuracy (mean) | 86.86 | Backward-T |
| Image Classification | CIFAR-10N-Aggregate | Accuracy (mean) | 88.24 | Forward-T |
| Image Classification | CIFAR-10N-Aggregate | Accuracy (mean) | 88.13 | Backward-T |
| Image Classification | CIFAR-10N-Random1 | Accuracy (mean) | 87.14 | Backward-T |
| Image Classification | CIFAR-10N-Random1 | Accuracy (mean) | 86.88 | Forward-T |
| Image Classification | CIFAR-100N | Accuracy (mean) | 57.14 | Backward-T |
| Image Classification | CIFAR-100N | Accuracy (mean) | 57.01 | Forward-T |
| Image Classification | CIFAR-10N-Worst | Accuracy (mean) | 79.79 | Forward-T |
| Image Classification | CIFAR-10N-Worst | Accuracy (mean) | 77.61 | Backward-T |