From 1d76a72e35b7ec7e2432c6903204b118f18b8295 Mon Sep 17 00:00:00 2001
From: Xinlei Chen
Date: Thu, 8 Jul 2021 04:34:54 -0700
Subject: [PATCH] start readme

---
 README.md       | 45 ++++++++++++++++++++++++++++++++++++++++++++++-
 main_moco.py    | 22 +++++++---------------
 moco/builder.py |  4 ++--
 vits.py         |  1 -
 4 files changed, 53 insertions(+), 19 deletions(-)

diff --git a/README.md b/README.md
index 871bcf0..09dab66 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,47 @@
-# moco-v3
+# MoCo v3
+
+This is a PyTorch implementation of [MoCo v3](https://arxiv.org/abs/2104.02057):
+```
+@Article{chen2021mocov3,
+  author  = {Xinlei Chen* and Saining Xie* and Kaiming He},
+  title   = {An Empirical Study of Training Self-Supervised Vision Transformers},
+  journal = {arXiv preprint arXiv:2104.02057},
+  year    = {2021},
+}
+```
+
+### Preparation
+
+Install PyTorch and download the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet). Similar to [MoCo](https://github.com/facebookresearch/moco), this code release makes only minimal modifications to that code, covering both unsupervised pre-training and linear classification.
+
+In addition, install [timm](https://github.com/rwightman/pytorch-image-models) for the Vision Transformer [(ViT)](https://arxiv.org/abs/2010.11929) models.
+
+### Unsupervised Pre-Training
+
+Similar to MoCo, only **multi-gpu**, **DistributedDataParallel** training is supported; single-gpu or DataParallel training is not supported. The code has also been tested in a **multi-node** setting, and by default uses automatic **mixed-precision** training for pre-training.
+
+Below are example pre-training commands covering different model architectures, training epochs, and single-/multi-node settings.
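+
+The `--world-size` and `--rank` flags in these commands count **nodes**, not GPUs: each node spawns one process per GPU, and each process derives its global rank from the node rank. A rough sketch of this mapping (simplified from `main_moco.py`; argument names abbreviated):
+```
+import torch.distributed as dist
+
+def worker(gpu, ngpus_per_node, node_rank, num_nodes, dist_url):
+    # global rank = node rank * GPUs per node + local GPU index
+    rank = node_rank * ngpus_per_node + gpu
+    dist.init_process_group(backend='nccl', init_method=dist_url,
+                            world_size=num_nodes * ngpus_per_node, rank=rank)
+```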
+
+<details>
+<summary>MoCo v3 with ResNet-50, 100-Epoch, 2-Node.</summary>
+
+This is the default setting for most hyper-parameters. With a batch size of 4096, training fits on 2 nodes with a total of 16 Volta 32G GPUs (256 images per GPU).
+
+On the first node, run:
+```
+python main_moco.py \
+  --dist-url "tcp://[your node 1 address]:[specified port]" \
+  --multiprocessing-distributed --world-size 2 --rank 0 \
+  [your imagenet-folder with train and val folders]
+```
+On the second node, run the same command with `--rank 1` (`--dist-url` still points to node 1):
+```
+python main_moco.py \
+  --dist-url "tcp://[your node 1 address]:[specified port]" \
+  --multiprocessing-distributed --world-size 2 --rank 1 \
+  [your imagenet-folder with train and val folders]
+```
+</details>
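+
+The objective optimized by these commands is the symmetrized contrastive loss from the paper, with the `2 * tau` scaling that also appears in `main_moco.py`. A minimal single-GPU sketch of one direction of that loss (illustrative only; the actual implementation gathers keys across all GPUs):
+```
+import torch
+import torch.nn.functional as F
+
+def ctr(q, k, tau):
+    # InfoNCE: the positive key for query i is key i of the other crop
+    q, k = F.normalize(q, dim=1), F.normalize(k, dim=1)
+    logits = q @ k.t() / tau
+    labels = torch.arange(q.size(0), device=q.device)
+    return F.cross_entropy(logits, labels) * (2 * tau)
+
+# symmetrized total: loss = ctr(q1, k2) + ctr(q2, k1)
+```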
 
 ### License
 
diff --git a/main_moco.py b/main_moco.py
index 2febe2a..86215cb 100755
--- a/main_moco.py
+++ b/main_moco.py
@@ -32,7 +32,6 @@ import torchvision.datasets as datasets
 import torchvision.models as torchvision_models
 from functools import partial
 
-import apex
 import vits
 
 import moco.builder
@@ -126,8 +125,6 @@ parser.add_argument('--optimizer', default='lars', type=str,
                     help='optimizer used (default: lars)')
 parser.add_argument('--warmup-epochs', default=10, type=int, metavar='N',
                     help='number of warmup epochs')
-parser.add_argument('--mixed-precision', action='store_true',
-                    help='Use mixed precision')
 
 # ===== OTHERS, to delete =====
 parser.add_argument('--checkpoint-folder', default='.', type=str, metavar='PATH',
@@ -255,7 +252,7 @@ def main_worker(gpu, ngpus_per_node, args):
         optimizer = moco.optimizer.AdamW(model.parameters(), init_lr,
                                          weight_decay=args.weight_decay)
 
-    scaler = torch.cuda.amp.GradScaler() if args.mixed_precision else None
+    scaler = torch.cuda.amp.GradScaler()  # mixed precision is now always on
 
     # optionally resume from a checkpoint
     if args.resume:
@@ -270,8 +267,7 @@ def main_worker(gpu, ngpus_per_node, args):
             args.start_epoch = checkpoint['epoch']
             model.load_state_dict(checkpoint['state_dict'])
             optimizer.load_state_dict(checkpoint['optimizer'])
-            if scaler is not None:
-                scaler.load_state_dict(checkpoint['scaler'])
+            scaler.load_state_dict(checkpoint['scaler'])
             print("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch']))
         else:
@@ -340,7 +336,7 @@ def main_worker(gpu, ngpus_per_node, args):
                 'arch': args.arch,
                 'state_dict': model.state_dict(),
                 'optimizer' : optimizer.state_dict(),
-                'scaler': None if scaler is None else scaler.state_dict(),
+                'scaler': scaler.state_dict(),
             }, is_best=False, filename='%s/checkpoint_%04d.pth.tar' % (args.checkpoint_folder, epoch))
@@ -369,7 +365,7 @@ def train(train_loader, model, criterion, optimizer, scaler, epoch, args):
             images[1] = images[1].cuda(args.gpu, non_blocking=True)
 
         # compute output
-        with torch.cuda.amp.autocast(scaler is not None):
+        with torch.cuda.amp.autocast(True):
             output1, output2, target = model(im1=images[0], im2=images[1], m=moco_m)
             loss = (criterion(output1, target) + criterion(output2, target)) * (args.moco_t * 2.)
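Together with the hunk below, `train()` now uses the standard always-on AMP recipe: `autocast` wraps the forward pass, and a `GradScaler` handles the backward pass and optimizer step. A minimal self-contained sketch of that pattern with a stand-in model (illustrative, not the repo's exact code):
```
import torch
import torch.nn.functional as F

model = torch.nn.Linear(8, 2).cuda()     # stand-in for the MoCo model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()     # always constructed, as in this patch

images = torch.randn(4, 8, device='cuda')
target = torch.randint(0, 2, (4,), device='cuda')

with torch.cuda.amp.autocast(True):      # forward + loss in mixed precision
    loss = F.cross_entropy(model(images), target)

optimizer.zero_grad()
scaler.scale(loss).backward()            # scale the loss to avoid fp16 underflow
scaler.step(optimizer)                   # unscales gradients, then steps
scaler.update()                          # adapt the loss scale for the next iteration
```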
@@ -382,13 +378,9 @@ def train(train_loader, model, criterion, optimizer, scaler, epoch, args):
 
         # compute gradient and do SGD step
         optimizer.zero_grad()
-        if scaler is None:
-            loss.backward()
-            optimizer.step()
-        else:
-            scaler.scale(loss).backward()
-            scaler.step(optimizer)
-            scaler.update()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
 
         # measure elapsed time
         batch_time.update(time.time() - end)
diff --git a/moco/builder.py b/moco/builder.py
index d856997..b7dc436 100644
--- a/moco/builder.py
+++ b/moco/builder.py
@@ -40,7 +40,7 @@ class MoCo(nn.Module):
         self.momentum_encoder = base_encoder(num_classes=mlp_dim)
 
         hidden_dim = self.base_encoder.fc.weight.shape[1]
-        del self.base_encoder.fc
+        del self.base_encoder.fc  # remove the original fc layer
         self.base_encoder.fc = nn.Sequential(nn.Linear(hidden_dim, mlp_dim, bias=False),
                                              nn.BatchNorm1d(mlp_dim),
                                              nn.ReLU(inplace=True),  # first layer
@@ -67,7 +67,7 @@ class MoCo(nn.Module):
         self.momentum_encoder = base_encoder(num_classes=mlp_dim)
 
         hidden_dim = self.base_encoder.head.weight.shape[1]
-        del self.base_encoder.head
+        del self.base_encoder.head  # remove the original prediction head
         self.base_encoder.head = nn.Sequential(nn.Linear(hidden_dim, mlp_dim, bias=False),
                                                nn.BatchNorm1d(mlp_dim),
                                                nn.GELU(),  # first layer
diff --git a/vits.py b/vits.py
index 660a623..090a586 100644
--- a/vits.py
+++ b/vits.py
@@ -26,7 +26,6 @@ class VisionTransformerMoCo(VisionTransformer):
             self.patch_embed.proj.weight.requires_grad = False
             self.patch_embed.proj.bias.requires_grad = False
 
-
     def build_2d_sincos_position_embedding(self, temperature=10000.):
         h, w = self.patch_embed.grid_size
         grid_w = torch.arange(w, dtype=torch.float32)
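For reference, the `build_2d_sincos_position_embedding` method touched above constructs the fixed 2D sin-cos position embedding used by the MoCo v3 ViT models: independent sine/cosine features of the x and y patch coordinates, concatenated along the channel dimension. A minimal sketch of that construction (simplified; the repo's version also prepends a class-token embedding):
```
import torch

def sincos_2d(h, w, dim, temperature=10000.):
    # Half of the channels encode the x coordinate, half the y coordinate.
    assert dim % 4 == 0, 'embedding dim must be divisible by 4'
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
    pos_dim = dim // 4
    omega = 1. / (temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim))
    out_w = grid_w.flatten()[:, None] * omega[None, :]   # [h*w, pos_dim]
    out_h = grid_h.flatten()[:, None] * omega[None, :]
    return torch.cat([torch.sin(out_w), torch.cos(out_w),
                      torch.sin(out_h), torch.cos(out_h)], dim=1)  # [h*w, dim]
```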