Kaidi Cao, Maria Brbic, Jure Leskovec
A fundamental limitation of applying semi-supervised learning in real-world settings is the assumption that unlabeled test data contains only classes previously encountered in the labeled training data. However, this assumption rarely holds for data in the wild, where instances belonging to novel classes may appear at test time. Here, we introduce a novel open-world semi-supervised learning setting that formalizes the notion that novel classes may appear in the unlabeled test data. In this setting, the goal is to resolve the class distribution mismatch between labeled and unlabeled data: at test time, each input instance must either be assigned to one of the existing classes or initialize a new, previously unseen class. To tackle this challenging problem, we propose ORCA, an end-to-end deep learning approach that introduces an uncertainty-adaptive margin mechanism to counter the bias towards seen classes, which arises because discriminative features are learned faster for seen classes than for novel ones. In this way, ORCA reduces the gap between the intra-class variance of seen and novel classes. Experiments on image classification datasets and a single-cell annotation dataset demonstrate that ORCA consistently outperforms alternative baselines, achieving a 25% improvement on seen classes and a 96% improvement on novel classes of the ImageNet dataset.
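The uncertainty-adaptive margin can be illustrated with a small sketch. It rests on assumptions that are our reading, not the authors' released code: the margin equals a hypothetical coefficient `lam` times the mean uncertainty ū of unlabeled predictions (uncertainty taken as one minus the maximum softmax probability), and it is applied only to the ground-truth logit inside a scaled cross-entropy on labeled data. The names `adaptive_margin_ce`, `mean_uncertainty`, `lam`, and `scale` are illustrative.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mean_uncertainty(unlabeled_logits: Sequence[Sequence[float]]) -> float:
    """Mean of (1 - max softmax probability) over unlabeled samples."""
    if not unlabeled_logits:
        return 0.0
    return sum(1.0 - max(softmax(row)) for row in unlabeled_logits) / len(unlabeled_logits)


def adaptive_margin_ce(
    labeled_logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    unlabeled_logits: Sequence[Sequence[float]],
    lam: float = -1.0,    # hypothetical coefficient; its sign and value are assumptions
    scale: float = 10.0,  # hypothetical logit scale
) -> float:
    """Mean cross-entropy on labeled data, with margin lam * u_bar subtracted from the target logit."""
    if not labels:
        return 0.0
    margin = lam * mean_uncertainty(unlabeled_logits)
    total = 0.0
    for row, y in zip(labeled_logits, labels):
        # Only the ground-truth logit is shifted; a negative margin raises it.
        shifted = [scale * (v - margin) if j == y else scale * v for j, v in enumerate(row)]
        top = max(shifted)
        log_sum_exp = top + math.log(sum(math.exp(s - top) for s in shifted))
        total += log_sum_exp - shifted[y]  # -log softmax(shifted)[y]
    return total / len(labels)


if __name__ == "__main__":
    loss = adaptive_margin_ce([[2.0, 0.5, -1.0]], [0], [[0.1, 0.0, 0.2]])
    print(f"loss with adaptive margin: {loss:.4f}")
```

Under these assumptions, a negative `lam` boosts the ground-truth logit, so seen classes incur a smaller loss and are fit more slowly while ū is high. As the model becomes confident on unlabeled data, ū and the margin shrink toward zero and the loss approaches ordinary cross-entropy.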
| Task | Dataset | Metric | Value (%) | Model |
|---|---|---|---|---|
| Image Classification | CIFAR-100 | All accuracy (10% Labeled) | 38.6 | ORCA (ResNet-18) |
| Image Classification | CIFAR-100 | All accuracy (50% Labeled) | 48.1 | ORCA (ResNet-18) |
| Image Classification | CIFAR-100 | Novel accuracy (10% Labeled) | 31.8 | ORCA (ResNet-18) |
| Image Classification | CIFAR-100 | Novel accuracy (50% Labeled) | 43 | ORCA (ResNet-18) |
| Image Classification | CIFAR-100 | Seen accuracy (10% Labeled) | 52.5 | ORCA (ResNet-18) |
| Image Classification | CIFAR-100 | Seen accuracy (50% Labeled) | 66.9 | ORCA (ResNet-18) |
| Image Classification | ImageNet-100 (TEMI Split) | All accuracy (10% Labeled) | 69.7 | ORCA (ResNet-50) |
| Image Classification | ImageNet-100 (TEMI Split) | All accuracy (50% Labeled) | 77.8 | ORCA (ResNet-50) |
| Image Classification | ImageNet-100 (TEMI Split) | Novel accuracy (10% Labeled) | 60.5 | ORCA (ResNet-50) |
| Image Classification | ImageNet-100 (TEMI Split) | Novel accuracy (50% Labeled) | 72.1 | ORCA (ResNet-50) |
| Image Classification | ImageNet-100 (TEMI Split) | Seen accuracy (10% Labeled) | 83.9 | ORCA (ResNet-50) |
| Image Classification | ImageNet-100 (TEMI Split) | Seen accuracy (50% Labeled) | 89.1 | ORCA (ResNet-50) |
| Image Classification | CIFAR-10 | All accuracy (10% Labeled) | 84.1 | ORCA (ResNet-18) |
| Image Classification | CIFAR-10 | All accuracy (50% Labeled) | 89.7 | ORCA (ResNet-18) |
| Image Classification | CIFAR-10 | Novel accuracy (10% Labeled) | 85.5 | ORCA (ResNet-18) |
| Image Classification | CIFAR-10 | Novel accuracy (50% Labeled) | 90.4 | ORCA (ResNet-18) |
| Image Classification | CIFAR-10 | Seen accuracy (10% Labeled) | 82.8 | ORCA (ResNet-18) |
| Image Classification | CIFAR-10 | Seen accuracy (50% Labeled) | 88.2 | ORCA (ResNet-18) |
| 2D Object Detection | LVIS v1.0 val | All mAP | 2.03 | ORCA Cao et al. (2022) |
| 2D Object Detection | LVIS v1.0 val | Known mAP | 20.57 | ORCA Cao et al. (2022) |
| 2D Object Detection | LVIS v1.0 val | Novel mAP | 0.49 | ORCA Cao et al. (2022) |
| Semi-Supervised Image Classification | CIFAR-100 | All accuracy (10% Labeled) | 38.6 | ORCA (ResNet-18) |
| Semi-Supervised Image Classification | CIFAR-100 | All accuracy (50% Labeled) | 48.1 | ORCA (ResNet-18) |
| Semi-Supervised Image Classification | CIFAR-100 | Novel accuracy (10% Labeled) | 31.8 | ORCA (ResNet-18) |
| Semi-Supervised Image Classification | CIFAR-100 | Novel accuracy (50% Labeled) | 43 | ORCA (ResNet-18) |
| Semi-Supervised Image Classification | CIFAR-100 | Seen accuracy (10% Labeled) | 52.5 | ORCA (ResNet-18) |
| Semi-Supervised Image Classification | CIFAR-100 | Seen accuracy (50% Labeled) | 66.9 | ORCA (ResNet-18) |
| Semi-Supervised Image Classification | ImageNet-100 (TEMI Split) | All accuracy (10% Labeled) | 69.7 | ORCA (ResNet-50) |
| Semi-Supervised Image Classification | ImageNet-100 (TEMI Split) | All accuracy (50% Labeled) | 77.8 | ORCA (ResNet-50) |
| Semi-Supervised Image Classification | ImageNet-100 (TEMI Split) | Novel accuracy (10% Labeled) | 60.5 | ORCA (ResNet-50) |
| Semi-Supervised Image Classification | ImageNet-100 (TEMI Split) | Novel accuracy (50% Labeled) | 72.1 | ORCA (ResNet-50) |
| Semi-Supervised Image Classification | ImageNet-100 (TEMI Split) | Seen accuracy (10% Labeled) | 83.9 | ORCA (ResNet-50) |
| Semi-Supervised Image Classification | ImageNet-100 (TEMI Split) | Seen accuracy (50% Labeled) | 89.1 | ORCA (ResNet-50) |
| Semi-Supervised Image Classification | CIFAR-10 | All accuracy (10% Labeled) | 84.1 | ORCA (ResNet-18) |
| Semi-Supervised Image Classification | CIFAR-10 | All accuracy (50% Labeled) | 89.7 | ORCA (ResNet-18) |
| Semi-Supervised Image Classification | CIFAR-10 | Novel accuracy (10% Labeled) | 85.5 | ORCA (ResNet-18) |
| Semi-Supervised Image Classification | CIFAR-10 | Novel accuracy (50% Labeled) | 90.4 | ORCA (ResNet-18) |
| Semi-Supervised Image Classification | CIFAR-10 | Seen accuracy (10% Labeled) | 82.8 | ORCA (ResNet-18) |
| Semi-Supervised Image Classification | CIFAR-10 | Seen accuracy (50% Labeled) | 88.2 | ORCA (ResNet-18) |
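The table reports seen, novel, and all-class accuracy. The sketch below shows one way such numbers are commonly computed, assuming the usual open-world protocol: seen accuracy is ordinary accuracy on test samples whose true class is seen, while novel and all-class accuracy first find the best one-to-one mapping between predicted cluster ids and true labels. The exhaustive permutation search is a stand-in for the Hungarian algorithm (in practice `scipy.optimize.linear_sum_assignment` is the usual choice) and is only practical for a handful of classes. The function names are hypothetical, and the results are fractions that would be multiplied by 100 to match the table.

```python
from collections import Counter
from itertools import permutations
from typing import Dict, Sequence, Set


def clustering_accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Accuracy under the best one-to-one mapping of predicted ids to true labels (brute force)."""
    if not y_true:
        return 0.0
    true_ids = sorted(set(y_true))
    pred_ids = sorted(set(y_pred))
    n = max(len(true_ids), len(pred_ids))
    counts = Counter(zip(y_pred, y_true))  # (pred, true) -> co-occurrence count
    best = 0
    for perm in permutations(range(n)):
        matched = 0
        for i, p in enumerate(pred_ids):
            j = perm[i]
            if j < len(true_ids):  # indices past true_ids are unmatched padding
                matched += counts[(p, true_ids[j])]
        best = max(best, matched)
    return best / len(y_true)


def open_world_accuracies(
    y_true: Sequence[int], y_pred: Sequence[int], seen: Set[int]
) -> Dict[str, float]:
    """Seen accuracy (direct), novel and all-class accuracy (after matching)."""
    seen_idx = [i for i, t in enumerate(y_true) if t in seen]
    novel_idx = [i for i, t in enumerate(y_true) if t not in seen]
    seen_acc = (
        sum(y_pred[i] == y_true[i] for i in seen_idx) / len(seen_idx) if seen_idx else 0.0
    )
    novel_acc = clustering_accuracy(
        [y_true[i] for i in novel_idx], [y_pred[i] for i in novel_idx]
    )
    all_acc = clustering_accuracy(list(y_true), list(y_pred))
    return {"seen": seen_acc, "novel": novel_acc, "all": all_acc}


if __name__ == "__main__":
    # Toy labels: classes 0-1 are seen; 2-3 are novel and predicted as clusters 5 and 4.
    print(open_world_accuracies([0, 0, 1, 1, 2, 2, 3, 3], [0, 0, 1, 0, 5, 5, 4, 4], {0, 1}))
```

Because novel classes have no labels, the model's ids for them are arbitrary; the matching step is what makes a "novel accuracy" meaningful, whereas seen classes keep their original ids and can be scored directly.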