fix issues, add vit position embedding

Xinlei Chen 2021-07-07 18:47:47 -07:00
parent fe5f5a2395
commit d04dcd5246
2 changed files with 26 additions and 41 deletions


@@ -228,17 +228,17 @@ def main_worker(gpu, ngpus_per_node, args):
             # ourselves based on the total number of GPUs we have
             args.batch_size = int(args.batch_size / args.world_size)
             args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
-            # Use apex DDP to support stop-grad in networks
-            model = apex.parallel.DistributedDataParallel(module=model, delay_allreduce=True)
+            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         else:
             model.cuda()
-            # Use apex DDP to support stop-grad in networks
-            model = apex.parallel.DistributedDataParallel(module=model, delay_allreduce=True)
+            # DistributedDataParallel will divide and allocate batch_size to all
+            # available GPUs if device_ids are not set
+            model = torch.nn.parallel.DistributedDataParallel(model)
     elif args.gpu is not None:
         torch.cuda.set_device(args.gpu)
         model = model.cuda(args.gpu)
         # comment out the following line for debugging
-        raise NotImplementedError("Only DistributedDataParallel is supported.")
+        # raise NotImplementedError("Only DistributedDataParallel is supported.")
     else:
         # AllGather/rank implementation in this code only supports DistributedDataParallel.
         raise NotImplementedError("Only DistributedDataParallel is supported.")
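
For reference, a minimal sketch of the torch-native DDP setup this hunk switches to from apex: one process per GPU, with the global batch size divided across processes the same way as above. This is not code from the repo; the helper name, the ResNet-50 backbone, and the batch-size number are illustrative assumptions.

import torch
import torch.distributed as dist
import torchvision.models as models

def setup_ddp_worker(gpu, world_size, global_batch_size=1024):
    # Illustrative helper (not from this repo): one process per GPU,
    # rendezvous via the default env:// method (MASTER_ADDR/MASTER_PORT set elsewhere).
    dist.init_process_group(backend='nccl', rank=gpu, world_size=world_size)
    torch.cuda.set_device(gpu)

    model = models.resnet50().cuda(gpu)
    # torch DDP takes the device list directly; no apex-specific delay_allreduce flag.
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # Divide the global batch size across processes, as the hunk above does.
    batch_size_per_gpu = global_batch_size // world_size
    return model, batch_size_per_gpu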
@@ -285,6 +285,7 @@ def main_worker(gpu, ngpus_per_node, args):
                                      std=[0.229, 0.224, 0.225])
 
     # BYOL's augmentation recipe: https://arxiv.org/abs/2006.07733
+    # except min-scale kept as 0.2
     augmentation1 = [
         transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
         transforms.RandomApply([
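
The augmentation lists themselves are cut off by the hunk above. As context only, a hedged sketch of what one BYOL-style view pipeline with min-scale 0.2 typically looks like; the GaussianBlur helper and the exact jitter strengths and probabilities below are illustrative assumptions, not copied from this file.

import random
from PIL import ImageFilter
import torchvision.transforms as transforms

class GaussianBlur(object):
    """Gaussian blur with a randomly sampled sigma (SimCLR/BYOL style); illustrative helper."""
    def __init__(self, sigma=(0.1, 2.0)):
        self.sigma = sigma

    def __call__(self, img):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        return img.filter(ImageFilter.GaussianBlur(radius=sigma))

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# One example view: crop (min-scale 0.2), color jitter, grayscale, blur, flip, normalize.
example_view = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([GaussianBlur()], p=1.0),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])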

vits.py

@@ -21,45 +21,29 @@ __all__ = [
 class VisionTransformerMoCo(VisionTransformer):
     def __init__(self, stop_grad_conv1=False, **kwargs):
         super().__init__(**kwargs)
-        self.stop_grad_conv1 = stop_grad_conv1
-
-    def forward_features(self, x):
-        x = self.patch_embed(x)
-
-        # Add stop-grad after conv1
-        if self.stop_grad_conv1:
-            x = x.detach()
-
-        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
-        if self.dist_token is None:
-            x = torch.cat((cls_token, x), dim=1)
-        else:
-            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
-        x = self.pos_drop(x + self.pos_embed)
-        x = self.blocks(x)
-        x = self.norm(x)
-        if self.dist_token is None:
-            return self.pre_logits(x[:, 0])
-        else:
-            return x[:, 0], x[:, 1]
+        self.build_2d_sincos_position_embedding()
+
+        if stop_grad_conv1:
+            self.patch_embed.proj.weight.requires_grad = False
+            self.patch_embed.proj.bias.requires_grad = False
 
-def build_pos_embedding_2d_sincos(grid_size, hidden_dim, temperature):
-    grid_h = torch.arange(grid_size, dtype=torch.float32)
-    grid_w = torch.arange(grid_size, dtype=torch.float32)
-    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
-    assert hidden_dim % 4 == 0, 'Hidden dimension must be an even number for position embedding.'
-    pos_dim = hidden_dim // 4
-    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
-    omega = 1. / (temperature**omega)
-    out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
-    out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
-    pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[:, None, :]
-    p = torch.zeros([1, 1, hidden_dim], dtype=torch.float32)
-    pos_emb = torch.cat([p, pos_emb], dim=0)
-    return pos_emb
+    def build_2d_sincos_position_embedding(self, temperature=10000.):
+        h, w = self.patch_embed.grid_size
+        grid_w = torch.arange(w, dtype=torch.float32)
+        grid_h = torch.arange(h, dtype=torch.float32)
+        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
+        assert self.embed_dim % 4 == 0, 'Hidden dimension must be divisible by 4 for 2D sin-cos position embedding.'
+        pos_dim = self.embed_dim // 4
+        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = 1. / (temperature**omega)
+        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
+        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
+        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
+        pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
+        del self.pos_embed
+        self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
+        self.pos_embed.requires_grad = False
 
 
 def vit_small(**kwargs):
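
To make the shapes in build_2d_sincos_position_embedding concrete, here is the same construction rewritten as a standalone function. This is a sketch for illustration only; the function name and the 14x14 grid / 384-dim example are assumptions, not part of the repo.

import torch

def sincos_pos_embed_2d(h, w, embed_dim, temperature=10000.):
    """Fixed 2D sin-cos position embedding of shape [1, 1 + h*w, embed_dim] (illustrative)."""
    assert embed_dim % 4 == 0
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
    pos_dim = embed_dim // 4
    # Frequencies 1 / temperature**(i / pos_dim), i = 0 .. pos_dim-1.
    omega = 1. / (temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim))
    out_w = grid_w.flatten()[:, None] * omega[None, :]   # [h*w, pos_dim]
    out_h = grid_h.flatten()[:, None] * omega[None, :]   # [h*w, pos_dim]
    # sin/cos over width and height frequencies -> [h*w, embed_dim].
    pos = torch.cat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)
    cls_pe = torch.zeros(1, embed_dim)                    # all-zero slot for the class token
    return torch.cat([cls_pe, pos], dim=0)[None]          # [1, 1 + h*w, embed_dim]

# e.g. a 14x14 patch grid with embed_dim 384 (ViT-Small sizes): torch.Size([1, 197, 384])
print(sincos_pos_embed_2d(14, 14, 384).shape)

Because the diff registers the result as a non-trainable nn.Parameter (requires_grad = False), the position embedding stays fixed during pre-training.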