mirror of
https://github.com/facebookresearch/moco-v3.git
synced 2025-06-03 14:59:22 +08:00
fix issues, add vit position embedding
This commit is contained in:
parent
fe5f5a2395
commit
d04dcd5246
11
main_moco.py
11
main_moco.py
@ -228,17 +228,17 @@ def main_worker(gpu, ngpus_per_node, args):
|
|||||||
# ourselves based on the total number of GPUs we have
|
# ourselves based on the total number of GPUs we have
|
||||||
args.batch_size = int(args.batch_size / args.world_size)
|
args.batch_size = int(args.batch_size / args.world_size)
|
||||||
args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
|
args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
|
||||||
# Use apex DDP to support stop-grad in networks
|
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
||||||
model = apex.parallel.DistributedDataParallel(module=model, delay_allreduce=True)
|
|
||||||
else:
|
else:
|
||||||
model.cuda()
|
model.cuda()
|
||||||
# Use apex DDP to support stop-grad in networks
|
# DistributedDataParallel will divide and allocate batch_size to all
|
||||||
model = apex.parallel.DistributedDataParallel(module=model, delay_allreduce=True)
|
# available GPUs if device_ids are not set
|
||||||
|
model = torch.nn.parallel.DistributedDataParallel(model)
|
||||||
elif args.gpu is not None:
|
elif args.gpu is not None:
|
||||||
torch.cuda.set_device(args.gpu)
|
torch.cuda.set_device(args.gpu)
|
||||||
model = model.cuda(args.gpu)
|
model = model.cuda(args.gpu)
|
||||||
# comment out the following line for debugging
|
# comment out the following line for debugging
|
||||||
raise NotImplementedError("Only DistributedDataParallel is supported.")
|
# raise NotImplementedError("Only DistributedDataParallel is supported.")
|
||||||
else:
|
else:
|
||||||
# AllGather/rank implementation in this code only supports DistributedDataParallel.
|
# AllGather/rank implementation in this code only supports DistributedDataParallel.
|
||||||
raise NotImplementedError("Only DistributedDataParallel is supported.")
|
raise NotImplementedError("Only DistributedDataParallel is supported.")
|
||||||
@ -285,6 +285,7 @@ def main_worker(gpu, ngpus_per_node, args):
|
|||||||
std=[0.229, 0.224, 0.225])
|
std=[0.229, 0.224, 0.225])
|
||||||
|
|
||||||
# BYOL's augmentation recipe: https://arxiv.org/abs/2006.07733
|
# BYOL's augmentation recipe: https://arxiv.org/abs/2006.07733
|
||||||
|
# except min-scale kept as 0.2
|
||||||
augmentation1 = [
|
augmentation1 = [
|
||||||
transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
|
transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
|
||||||
transforms.RandomApply([
|
transforms.RandomApply([
|
||||||
|
46
vits.py
46
vits.py
@ -21,45 +21,29 @@ __all__ = [
|
|||||||
class VisionTransformerMoCo(VisionTransformer):
|
class VisionTransformerMoCo(VisionTransformer):
|
||||||
def __init__(self, stop_grad_conv1=False, **kwargs):
|
def __init__(self, stop_grad_conv1=False, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.stop_grad_conv1 = stop_grad_conv1
|
self.build_2d_sincos_position_embedding()
|
||||||
|
if stop_grad_conv1:
|
||||||
def forward_features(self, x):
|
self.patch_embed.proj.weight.requires_grad = False
|
||||||
x = self.patch_embed(x)
|
self.patch_embed.proj.bias.requires_grad = False
|
||||||
# Add stop-grad after conv1
|
|
||||||
if self.stop_grad_conv1:
|
|
||||||
x = x.detach()
|
|
||||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
|
|
||||||
if self.dist_token is None:
|
|
||||||
x = torch.cat((cls_token, x), dim=1)
|
|
||||||
else:
|
|
||||||
x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
|
|
||||||
x = self.pos_drop(x + self.pos_embed)
|
|
||||||
x = self.blocks(x)
|
|
||||||
x = self.norm(x)
|
|
||||||
if self.dist_token is None:
|
|
||||||
return self.pre_logits(x[:, 0])
|
|
||||||
else:
|
|
||||||
return x[:, 0], x[:, 1]
|
|
||||||
|
|
||||||
|
|
||||||
def build_pos_embedding_2d_sincos(grid_size, hidden_dim, temperature):
|
def build_2d_sincos_position_embedding(self, temperature=10000.):
|
||||||
grid_h = torch.arange(grid_size, dtype=torch.float32)
|
h, w = self.patch_embed.grid_size
|
||||||
grid_w = torch.arange(grid_size, dtype=torch.float32)
|
grid_w = torch.arange(w, dtype=torch.float32)
|
||||||
|
grid_h = torch.arange(h, dtype=torch.float32)
|
||||||
grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
|
grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
|
||||||
|
assert self.embed_dim % 4 == 0, 'Hidden dimension must be divisible by 4 for 2D sin-cos position embedding.'
|
||||||
assert hidden_dim % 4 == 0, 'Hidden dimension must be an even number for position embedding.'
|
pos_dim = self.embed_dim // 4
|
||||||
pos_dim = hidden_dim // 4
|
|
||||||
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
|
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
|
||||||
omega = 1. / (temperature**omega)
|
omega = 1. / (temperature**omega)
|
||||||
|
|
||||||
out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
|
out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
|
||||||
out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
|
out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
|
||||||
|
pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
|
||||||
|
|
||||||
pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[:, None, :]
|
pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
|
||||||
|
del self.pos_embed
|
||||||
p = torch.zeros([1, 1, hidden_dim], dtype=torch.float32)
|
self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
|
||||||
pos_emb = torch.cat([p, pos_emb], dim=0)
|
self.pos_embed.requires_grad = False
|
||||||
return pos_emb
|
|
||||||
|
|
||||||
|
|
||||||
def vit_small(**kwargs):
|
def vit_small(**kwargs):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user