commit 1d76a72e35 ("start readme")
Author: Xinlei Chen
Date:   2021-07-08 04:34:54 -07:00
Parent: d04dcd5246

4 changed files with 53 additions and 19 deletions

README.md

@@ -1,4 +1,47 @@
-# moco-v3
+# MoCo v3
This is a PyTorch implementation of [MoCo v3](https://arxiv.org/abs/2104.02057):
```
@Article{chen2021mocov3,
author = {Xinlei Chen* and Saining Xie* and Kaiming He},
title = {An Empirical Study of Training Self-Supervised Vision Transformers},
journal = {arXiv preprint arXiv:2104.02057},
year = {2021},
}
```
### Preparation
Install PyTorch and download the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet). Similar to [MoCo](https://github.com/facebookresearch/moco), this release makes only minimal modifications to that code for both unsupervised pre-training and linear classification.
In addition, install [timm](https://github.com/rwightman/pytorch-image-models) for the Vision Transformer [(ViT)](https://arxiv.org/abs/2010.11929) models.
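To quickly verify that timm and its ViT models are available, a minimal check along these lines can be used (the model name is only an example):
```
import timm
import torch

# build a ViT backbone to confirm timm is installed correctly
model = timm.create_model('vit_base_patch16_224', pretrained=False)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```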
### Unsupervised Pre-Training
Similar to MoCo, only **multi-gpu**, **DistributedDataParallel** training is supported; single-gpu and DataParallel training are not supported. The code has also been tested in the **multi-node** setting, and by default uses automatic **mixed-precision** for pre-training.
Below are several example pre-training commands covering different model architectures, training epochs, single-/multi-node setups, etc.
<details>
<summary>
MoCo v3 with ResNet-50, 100-Epoch, 2-Node.
</summary>
This is the default setting for most hyper-parameters. With a batch size of 4096, the training fits into 2 nodes with a total of 16 Volta 32G GPUs.
On the first node, run:
```
python main_moco.py \
--dist-url "tcp://[your node 1 address]:[specified port]" \
--multiprocessing-distributed --world-size 2 --rank 0 \
[your imagenet-folder with train and val folders]
```
On the second node, run:
```
python main_moco.py \
--dist-url "tcp://[your node 1 address]:[specified port]" \
--multiprocessing-distributed --world-size 2 --rank 1 \
[your imagenet-folder with train and val folders]
```
</details>
### License

main_moco.py

@@ -32,7 +32,6 @@ import torchvision.datasets as datasets
import torchvision.models as torchvision_models
from functools import partial
-import apex
import vits
import moco.builder
@@ -126,8 +125,6 @@ parser.add_argument('--optimizer', default='lars', type=str,
help='optimizer used (default: lars)')
parser.add_argument('--warmup-epochs', default=10, type=int, metavar='N',
help='number of warmup epochs')
-parser.add_argument('--mixed-precision', action='store_true',
-                    help='Use mixed precision')
# ===== OTHERS, to delete =====
parser.add_argument('--checkpoint-folder', default='.', type=str, metavar='PATH',
@@ -255,7 +252,7 @@ def main_worker(gpu, ngpus_per_node, args):
optimizer = moco.optimizer.AdamW(model.parameters(), init_lr,
weight_decay=args.weight_decay)
-scaler = torch.cuda.amp.GradScaler() if args.mixed_precision else None
+scaler = torch.cuda.amp.GradScaler()
# optionally resume from a checkpoint
if args.resume:
@@ -270,8 +267,7 @@ def main_worker(gpu, ngpus_per_node, args):
args.start_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
-if scaler is not None:
-    scaler.load_state_dict(checkpoint['scaler'])
+scaler.load_state_dict(checkpoint['scaler'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
@@ -340,7 +336,7 @@ def main_worker(gpu, ngpus_per_node, args):
'arch': args.arch,
'state_dict': model.state_dict(),
'optimizer' : optimizer.state_dict(),
-'scaler': None if scaler is None else scaler.state_dict(),
+'scaler': scaler.state_dict(),
}, is_best=False, filename='%s/checkpoint_%04d.pth.tar' % (args.checkpoint_folder, epoch))
@@ -369,7 +365,7 @@ def train(train_loader, model, criterion, optimizer, scaler, epoch, args):
images[1] = images[1].cuda(args.gpu, non_blocking=True)
# compute output
-with torch.cuda.amp.autocast(scaler is not None):
+with torch.cuda.amp.autocast(True):
output1, output2, target = model(im1=images[0], im2=images[1], m=moco_m)
loss = (criterion(output1, target) + criterion(output2, target)) * (args.moco_t * 2.)
@@ -382,13 +378,9 @@ def train(train_loader, model, criterion, optimizer, scaler, epoch, args):
# compute gradient and do SGD step
optimizer.zero_grad()
-if scaler is None:
-    loss.backward()
-    optimizer.step()
-else:
-    scaler.scale(loss).backward()
-    scaler.step(optimizer)
-    scaler.update()
+scaler.scale(loss).backward()
+scaler.step(optimizer)
+scaler.update()
# measure elapsed time
batch_time.update(time.time() - end)
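For reference, the now-unconditional scaler/autocast pair follows the standard `torch.cuda.amp` recipe. A minimal, self-contained sketch with a dummy model and data (not code from this repo):
```
import torch
import torch.nn as nn

model = nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

images = torch.randn(4, 8, device='cuda')
target = torch.randint(0, 2, (4,), device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast(True):   # forward pass runs in mixed precision
    loss = criterion(model(images), target)
scaler.scale(loss).backward()          # scale loss to avoid fp16 gradient underflow
scaler.step(optimizer)                 # unscales gradients; skips the step on inf/nan
scaler.update()                        # adjust the loss scale for the next iteration
```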

moco/builder.py

@@ -40,7 +40,7 @@ class MoCo(nn.Module):
self.momentum_encoder = base_encoder(num_classes=mlp_dim)
hidden_dim = self.base_encoder.fc.weight.shape[1]
-del self.base_encoder.fc
+del self.base_encoder.fc  # remove original fc layer
self.base_encoder.fc = nn.Sequential(nn.Linear(hidden_dim, mlp_dim, bias=False),
nn.BatchNorm1d(mlp_dim),
nn.ReLU(inplace=True), # first layer
@@ -67,7 +67,7 @@ class MoCo(nn.Module):
self.momentum_encoder = base_encoder(num_classes=mlp_dim)
hidden_dim = self.base_encoder.head.weight.shape[1]
-del self.base_encoder.head
+del self.base_encoder.head  # remove original fc layer
self.base_encoder.head = nn.Sequential(nn.Linear(hidden_dim, mlp_dim, bias=False),
nn.BatchNorm1d(mlp_dim),
nn.GELU(), # first layer
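Both hunks annotate the same surgery: the encoder's original classifier is deleted and replaced with a projection MLP. A rough, self-contained sketch of the ResNet variant (the MLP depth and dimensions here are illustrative, not necessarily the repo's defaults):
```
import torch.nn as nn
import torchvision.models as models

mlp_dim, hidden_mlp = 256, 4096                  # illustrative dimensions
encoder = models.resnet50(num_classes=mlp_dim)
hidden_dim = encoder.fc.weight.shape[1]          # 2048 for ResNet-50
del encoder.fc                                   # remove original fc layer
encoder.fc = nn.Sequential(
    nn.Linear(hidden_dim, hidden_mlp, bias=False),
    nn.BatchNorm1d(hidden_mlp),
    nn.ReLU(inplace=True),                       # first layer
    nn.Linear(hidden_mlp, mlp_dim, bias=False),  # final projection (hypothetical depth)
)
```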

vits.py

@@ -26,7 +26,6 @@ class VisionTransformerMoCo(VisionTransformer):
self.patch_embed.proj.weight.requires_grad = False
self.patch_embed.proj.bias.requires_grad = False
def build_2d_sincos_position_embedding(self, temperature=10000.):
h, w = self.patch_embed.grid_size
grid_w = torch.arange(w, dtype=torch.float32)
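The hunk is truncated just as `build_2d_sincos_position_embedding` begins. For orientation, a minimal sketch of how a 2-D sin-cos position embedding of this kind is commonly built; the standalone helper below is hypothetical and follows the usual recipe, not necessarily this file's exact code:
```
import torch

def sincos_2d(h, w, dim, temperature=10000.):
    # one coordinate grid per axis, then sin/cos features for each
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
    pos_dim = dim // 4                        # dim must be divisible by 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature ** omega)       # geometric frequency schedule
    out_w = grid_w.flatten()[:, None] * omega[None, :]
    out_h = grid_h.flatten()[:, None] * omega[None, :]
    return torch.cat([torch.sin(out_w), torch.cos(out_w),
                      torch.sin(out_h), torch.cos(out_h)], dim=1)  # [h*w, dim]

emb = sincos_2d(14, 14, 768)  # e.g. a ViT-B/16 grid for 224x224 input
```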