From d04dcd52464a8c702abbbb2e91e905ce2372819c Mon Sep 17 00:00:00 2001
From: Xinlei Chen
Date: Wed, 7 Jul 2021 18:47:47 -0700
Subject: [PATCH] fix issues, add vit position embedding

---
 main_moco.py | 11 ++++++-----
 vits.py      | 56 ++++++++++++++++++------------------------------------
 2 files changed, 26 insertions(+), 41 deletions(-)

diff --git a/main_moco.py b/main_moco.py
index dd83491..2febe2a 100755
--- a/main_moco.py
+++ b/main_moco.py
@@ -228,17 +228,17 @@ def main_worker(gpu, ngpus_per_node, args):
             # ourselves based on the total number of GPUs we have
             args.batch_size = int(args.batch_size / args.world_size)
             args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
-            # Use apex DDP to support stop-grad in networks
-            model = apex.parallel.DistributedDataParallel(module=model, delay_allreduce=True)
+            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         else:
             model.cuda()
-            # Use apex DDP to support stop-grad in networks
-            model = apex.parallel.DistributedDataParallel(module=model, delay_allreduce=True)
+            # DistributedDataParallel will divide and allocate batch_size to all
+            # available GPUs if device_ids are not set
+            model = torch.nn.parallel.DistributedDataParallel(model)
     elif args.gpu is not None:
         torch.cuda.set_device(args.gpu)
         model = model.cuda(args.gpu)
         # comment out the following line for debugging
-        raise NotImplementedError("Only DistributedDataParallel is supported.")
+        # raise NotImplementedError("Only DistributedDataParallel is supported.")
     else:
         # AllGather/rank implementation in this code only supports DistributedDataParallel.
         raise NotImplementedError("Only DistributedDataParallel is supported.")
@@ -285,6 +285,7 @@ def main_worker(gpu, ngpus_per_node, args):
                                      std=[0.229, 0.224, 0.225])
 
     # BYOL's augmentation recipe: https://arxiv.org/abs/2006.07733
+    # except min-scale kept as 0.2
     augmentation1 = [
         transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
         transforms.RandomApply([
diff --git a/vits.py b/vits.py
index 9b7c09b..660a623 100644
--- a/vits.py
+++ b/vits.py
@@ -21,45 +21,29 @@ __all__ = [
 class VisionTransformerMoCo(VisionTransformer):
     def __init__(self, stop_grad_conv1=False, **kwargs):
         super().__init__(**kwargs)
-        self.stop_grad_conv1 = stop_grad_conv1
-
-    def forward_features(self, x):
-        x = self.patch_embed(x)
-        # Add stop-grad after conv1
-        if self.stop_grad_conv1:
-            x = x.detach()
-        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
-        if self.dist_token is None:
-            x = torch.cat((cls_token, x), dim=1)
-        else:
-            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
-        x = self.pos_drop(x + self.pos_embed)
-        x = self.blocks(x)
-        x = self.norm(x)
-        if self.dist_token is None:
-            return self.pre_logits(x[:, 0])
-        else:
-            return x[:, 0], x[:, 1]
+        self.build_2d_sincos_position_embedding()
+        if stop_grad_conv1:
+            self.patch_embed.proj.weight.requires_grad = False
+            self.patch_embed.proj.bias.requires_grad = False
 
-def build_pos_embedding_2d_sincos(grid_size, hidden_dim, temperature):
-    grid_h = torch.arange(grid_size, dtype=torch.float32)
-    grid_w = torch.arange(grid_size, dtype=torch.float32)
-    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
+    def build_2d_sincos_position_embedding(self, temperature=10000.):
+        h, w = self.patch_embed.grid_size
+        grid_w = torch.arange(w, dtype=torch.float32)
+        grid_h = torch.arange(h, dtype=torch.float32)
+        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
+        assert self.embed_dim % 4 == 0, 'Hidden dimension must be divisible by 4 for 2D sin-cos position embedding.'
+        pos_dim = self.embed_dim // 4
+        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = 1. / (temperature**omega)
+        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
+        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
+        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
 
-    assert hidden_dim % 4 == 0, 'Hidden dimension must be an even number for position embedding.'
-    pos_dim = hidden_dim // 4
-    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
-    omega = 1. / (temperature**omega)
-
-    out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
-    out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
-
-    pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[:, None, :]
-
-    p = torch.zeros([1, 1, hidden_dim], dtype=torch.float32)
-    pos_emb = torch.cat([p, pos_emb], dim=0)
-    return pos_emb
+        pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
+        del self.pos_embed
+        self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
+        self.pos_embed.requires_grad = False
 
 
 def vit_small(**kwargs):
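
Note: below is a minimal standalone sketch of the fixed 2D sin-cos position embedding that the patched build_2d_sincos_position_embedding constructs, useful for sanity-checking shapes outside the model. The helper name sincos_pos_embed_2d and the example values (14x14 patch grid, embed_dim 768, roughly ViT-B/16 at 224x224 input) are illustrative and not part of the patch.

    import torch
    import torch.nn as nn

    def sincos_pos_embed_2d(h, w, embed_dim, temperature=10000.):
        # Fixed (non-learned) 2D sin-cos position embedding, mirroring the
        # logic added to vits.py above.
        assert embed_dim % 4 == 0, 'embed_dim must be divisible by 4'
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
        # [1, h*w, embed_dim]: sin/cos over width frequencies, then height frequencies
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w),
                             torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        # prepend an all-zero embedding for the class token -> [1, 1 + h*w, embed_dim]
        pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)
        return nn.Parameter(torch.cat([pe_token, pos_emb], dim=1), requires_grad=False)

    pe = sincos_pos_embed_2d(14, 14, 768)
    print(pe.shape)  # torch.Size([1, 197, 768])

As in the patch, the embedding is created with requires_grad=False, so it stays fixed during pre-training rather than being learned.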