pull/3/head
Saining Xie 2021-08-08 15:06:06 -07:00
parent 4aa07d2415
commit 2e4fb50a88
4 changed files with 308 additions and 0 deletions

121
transfer/README.md 100644
View File

@ -0,0 +1,121 @@
## MoCo v3 for Self-supervised ResNet and ViT
This folder includes the transfer learning experiments on CIFAR-10, CIFAR-100, Flowers and Pets datasets. We provide finetuning recipes for the ViT-Base model.
### Transfer Results
The following results are based on ImageNet-1k self-supervised pre-training, followed by end-to-end fine-tuning on downstream datasets. All results are based on a batch size of 128 and 100 training epochs.
#### ViT, transfer learning
<table><tbody>
<!-- START TABLE -->
<!-- TABLE HEADER -->
<th valign="center">dataset</th>
<th valign="center">pretrain<br/>epochs</th>
<th valign="center">pretrain<br/>crops</th>
<th valign="center">e2e<br/>acc</th>
<!-- TABLE BODY -->
<tr>
<td align="left">CIFAR-10</td>
<td align="right">300</td>
<td align="center">2x224</td>
<td align="center">98.9%</td>
</tr>
<tr>
<td align="left">CIFAR-100</td>
<td align="right">300</td>
<td align="center">2x224</td>
<td align="center">90.5%</td>
</tr>
<tr>
<td align="left">Flowers</td>
<td align="right">300</td>
<td align="center">2x224</td>
<td align="center">97.7%</td>
</tr>
<tr>
<td align="left">Pets</td>
<td align="right">300</td>
<td align="center">2x224</td>
<td align="center">93.2%</td>
</tr>
</tbody></table>
Similar to the end-to-end fine-tuning on ImageNet, the transfer learning results are also obtained using the [DeiT](https://github.com/facebookresearch/deit) repo.
### Usage: Transfer learning with ViT
To perform transfer learning for ViT, use our script to convert the pre-trained ViT checkpoint to [DEiT](https://github.com/facebookresearch/deit) format:
```
python convert_to_deit.py \
--input [your checkpoint path]/[your checkpoint file].pth.tar \
--output [target checkpoint file].pth
```
Then run the training (in the DeiT repo) with the converted checkpoint:
1. copy (or replace) the following files to the DeiT folder:
```
datasets.py
oxford_flowers_dataset.py
oxford_pets_dataset.py
```
1. download and prepare the datasets
Pets [Homepage](https://www.robots.ox.ac.uk/~vgg/data/pets/)
```
.
└── ./pets/
├── ./pets/annotations/ # split and label files
└── ./pets/images/ # data images
```
Flowers [Homepage](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/)
```
.
└── ./flowers/
├── ./flowers/jpg/ # jpg images
├── ./flowers/setid.mat # dataset split
└── ./flowers/imagelabels.mat # labels
```
CIFAR-10/CIFAR-100 datasets will be downloaded automatically.
### Transfer learning scripts (with a 8-GPU machine):
#### CIFAR-10
```
python -u -m torch.distributed.launch --nproc_per_node=8 --use_env main.py \
--batch-size 128 --output_dir [your output dir path] --epochs 100 --lr 3e-4 --weight-decay 0.1 --eval-freq 10 \
--no-pin-mem --warmup-epochs 3 --data-set cifar10 --data-path [cifar-10 data path] --no-repeated-aug \
--reprob 0.0 --drop-path 0.1 --mixup 0.8 --cutmix 1
```
#### CIFAR-100
```
python -u -m torch.distributed.launch --nproc_per_node=8 --use_env main.py \
--batch-size 128 --output_dir [your output dir path] --epochs 100 --lr 3e-4 --weight-decay 0.1 --eval-freq 10 \
--no-pin-mem --warmup-epochs 3 --data-set cifar10 --data-path [cifar-100 data path] --no-repeated-aug \
--reprob 0.0 --drop-path 0.1 --mixup 0.5 --cutmix 1
```
#### Flowers
```
python -u -m torch.distributed.launch --nproc_per_node=8 --use_env main.py \
--batch-size 128 --output_dir [your output dir path] --epochs 100 --lr 3e-4 --weight-decay 0.3 --eval-freq 10 \
--no-pin-mem --warmup-epochs 3 --data-set cifar10 --data-path [oxford-flowers data path] --no-repeated-aug \
--reprob 0.25 --drop-path 0.1 --mixup 0 --cutmix 0
```
#### Pets
```
python -u -m torch.distributed.launch --nproc_per_node=8 --use_env main.py \
--batch-size 128 --output_dir [your output dir path] --epochs 100 --lr 3e-4 --weight-decay 0.1 --eval-freq 10 \
--no-pin-mem --warmup-epochs 3 --data-set cifar10 --data-path [oxford-pets data path] --no-repeated-aug \
--reprob 0 --drop-path 0 --mixup 0.8 --cutmix 0
```
**Note**:
Similar to ImageNet end-to-end finetuning experiment, We use `--resume` rather than `--finetune` in the DeiT repo, as its `--finetune` option trains under eval mode. When loading the pre-trained model, revise `model_without_ddp.load_state_dict(checkpoint['model'])` with `strict=False`.

View File

@ -0,0 +1,69 @@
import os
import json
from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
import oxford_flowers_dataset, oxford_pets_dataset
def build_transform(is_train, args):
transform_train = transforms.Compose([
transforms.RandomResizedCrop((args.input_size, args.input_size), scale=(0.05, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
transform_test = transforms.Compose([
transforms.Resize(int((256 / 224) * args.input_size)),
transforms.CenterCrop(args.input_size),
transforms.ToTensor(),
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
return transform_train if is_train else transform_test
def build_dataset(is_train, args):
transform = build_transform(is_train, args)
if args.data_set == 'imagenet':
raise NotImplementedError("Only [cifar10, cifar100, flowers, pets] are supported; \
for imagenet end-to-end finetuning, please refer to the instructions in the main README.")
if args.data_set == 'imagenet':
root = os.path.join(args.data_path, 'train' if is_train else 'val')
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 1000
elif args.data_set == 'cifar10':
dataset = datasets.CIFAR10(root=args.data_path,
train=is_train,
download=True,
transform=transform)
nb_classes = 10
elif args.data_set == "cifar100":
dataset = datasets.CIFAR100(root=args.data_path,
train=is_train,
download=True,
transform=transform)
nb_classes = 100
elif args.data_set == "flowers":
dataset = oxford_flowers_dataset.Flowers(root=args.data_path,
train=is_train,
download=False,
transform=transform)
nb_classes = 102
elif args.data_set == "pets":
dataset = oxford_pets_dataset.Pets(root=args.data_path,
train=is_train,
download=False,
transform=transform)
nb_classes = 37
else:
raise NotImplementedError("Only [cifar10, cifar100, flowers, pets] are supported; \
for imagenet end-to-end finetuning, please refer to the instructions in the main README.")
return dataset, nb_classes

View File

@ -0,0 +1,59 @@
from __future__ import print_function
from PIL import Image
import os
import os.path
import numpy as np
import pickle
import scipy.io
from typing import Any, Callable, Optional, Tuple
from torchvision.datasets.vision import VisionDataset
class Flowers(VisionDataset):
def __init__(
self,
root,
train=True,
transform=None,
target_transform=None,
download=False,
):
super(Flowers, self).__init__(root, transform=transform,
target_transform=target_transform)
base_folder = root
self.image_folder = os.path.join(base_folder, "jpg")
label_file = os.path.join(base_folder, "imagelabels.mat")
setid_file = os.path.join(base_folder, "setid.mat")
self.train = train
self.labels = scipy.io.loadmat(label_file)["labels"][0]
train_list = scipy.io.loadmat(setid_file)["trnid"][0]
val_list = scipy.io.loadmat(setid_file)["valid"][0]
test_list = scipy.io.loadmat(setid_file)["tstid"][0]
trainval_list = np.concatenate([train_list, val_list])
if self.train:
self.img_files = trainval_list
else:
self.img_files = test_list
def __getitem__(self, index):
img_name = "image_%05d.jpg" % self.img_files[index]
target = self.labels[self.img_files[index] - 1] - 1
img = Image.open(os.path.join(self.image_folder, img_name))
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.img_files)

View File

@ -0,0 +1,59 @@
from PIL import Image
import os
import os.path
import numpy as np
import pickle
import scipy.io
from typing import Any, Callable, Optional, Tuple
from torchvision.datasets.vision import VisionDataset
class Pets(VisionDataset):
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(Pets, self).__init__(root, transform=transform,
target_transform=target_transform)
base_folder = root
self.train = train
annotations_path_dir = os.path.join(base_folder, "annotations")
self.image_path_dir = os.path.join(base_folder, "images")
if self.train:
split_file = os.path.join(annotations_path_dir, "trainval.txt")
with open(split_file) as f:
self.images_list = f.readlines()
else:
split_file = os.path.join(annotations_path_dir, "test.txt")
with open(split_file) as f:
self.images_list = f.readlines()
def __getitem__(self, index: int) -> Tuple[Any, Any]:
img_name, label, species, _ = self.images_list[index].strip().split(" ")
img_name += ".jpg"
target = int(label) - 1
img = Image.open(os.path.join(self.image_path_dir, img_name))
img = img.convert('RGB')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self) -> int:
return len(self.images_list)