add back stop-grad

pull/3/head
Xinlei Chen 2021-07-09 02:22:59 -07:00
parent 6b1cc4cf87
commit 6a8c371a94
3 changed files with 12 additions and 3 deletions


@@ -66,7 +66,7 @@ Note that the smaller batch size: 1) facilitates stable training, as discussed i
 </details>
 <details>
-<summary>ViT-Base, 300-Epoch, 2-Nodes.</summary>
+<summary>ViT-Base, 300-Epoch, 2-Node.</summary>
 With a batch size of 1024, ViT-Base can be trained on 2 nodes:


@@ -115,6 +115,10 @@ parser.add_argument('--moco-m', default=0.99, type=float,
 parser.add_argument('--moco-t', default=1.0, type=float,
                     help='softmax temperature (default: 1.0)')
+# vit specific configs:
+parser.add_argument('--stop-grad-conv1', action='store_true',
+                    help='stop-grad after first conv, or patch embedding')
 # other upgrades
 parser.add_argument('--optimizer', default='lars', type=str,
                     choices=['lars', 'adamw'],
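
A quick standalone check of how the new flag behaves (illustrative parser only, not the script's full argument list): action='store_true' defaults to False, so existing launch commands keep their previous behavior unless --stop-grad-conv1 is passed explicitly.

import argparse

# minimal stand-in parser; only the new flag is reproduced here
parser = argparse.ArgumentParser()
parser.add_argument('--stop-grad-conv1', action='store_true',
                    help='stop-grad after first conv, or patch embedding')

print(parser.parse_args([]).stop_grad_conv1)                     # False by default
print(parser.parse_args(['--stop-grad-conv1']).stop_grad_conv1)  # True when the flag is given
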
@@ -193,7 +197,7 @@ def main_worker(gpu, ngpus_per_node, args):
     print("=> creating model '{}'".format(args.arch))
     if args.arch.startswith('vit'):
         model = moco.builder.MoCo(
-            vits.__dict__[args.arch],
+            partial(vits.__dict__[args.arch], stop_grad_conv1=args.stop_grad_conv1),
             True, # with vit setup
             args.moco_dim, args.moco_mlp_dim, args.moco_t)
     else:
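
A minimal sketch of the partial() pattern used above (the names below are illustrative stand-ins, not the repo's actual builder): functools.partial pre-binds stop_grad_conv1 to the ViT constructor, so the MoCo builder can keep instantiating its base encoder without any extra arguments.

from functools import partial

def vit_base(stop_grad_conv1=False, embed_dim=768):
    # stand-in for vits.__dict__['vit_base']; returns a plain dict instead of a model
    return {'embed_dim': embed_dim, 'stop_grad_conv1': stop_grad_conv1}

# the flag is bound once, where the parsed args are available ...
base_encoder = partial(vit_base, stop_grad_conv1=True)

# ... and the builder can later call base_encoder() unchanged
model = base_encoder()
assert model['stop_grad_conv1'] is True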


@@ -19,11 +19,15 @@ __all__ = [
 class VisionTransformerMoCo(VisionTransformer):
-    def __init__(self, **kwargs):
+    def __init__(self, stop_grad_conv1=False, **kwargs):
         super().__init__(**kwargs)
         # Use 2D sin-cos position embedding
         self.build_2d_sincos_position_embedding()
+        if stop_grad_conv1:
+            self.patch_embed.proj.weight.requires_grad = False
+            self.patch_embed.proj.bias.requires_grad = False
     def build_2d_sincos_position_embedding(self, temperature=10000.):
         h, w = self.patch_embed.grid_size
         grid_w = torch.arange(w, dtype=torch.float32)
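
To illustrate what the stop-grad amounts to, here is a toy example (a generic Conv2d patch projection and linear head, not the repo's ViT): with requires_grad set to False on the first conv, no gradient is computed for the patch projection, so it stays at its random initialization while later layers still train.

import torch
import torch.nn as nn

patch_embed = nn.Conv2d(3, 64, kernel_size=16, stride=16)  # conv1 / patch projection
head = nn.Linear(64, 10)

# "stop-grad after first conv": freeze the projection's parameters
patch_embed.weight.requires_grad = False
patch_embed.bias.requires_grad = False

x = torch.randn(2, 3, 32, 32)
tokens = patch_embed(x).flatten(2).transpose(1, 2)  # B x N x C patch tokens
loss = head(tokens.mean(dim=1)).sum()
loss.backward()

print(patch_embed.weight.grad)   # None: the frozen conv receives no gradient
print(head.weight.grad.shape)    # torch.Size([10, 64]): later layers still train
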
@@ -37,6 +41,7 @@ class VisionTransformerMoCo(VisionTransformer):
         out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
         pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
+        assert self.num_tokens == 1, 'Assuming one and only one token, [cls]'
         pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
         self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
         self.pos_embed.requires_grad = False
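
For context, a self-contained sketch of the 2D sin-cos position embedding this hunk belongs to (not a verbatim copy of the repo's method: the lines between the two vits.py hunks are not shown above, and the grid size and embedding dimension below are example values). Each spatial axis contributes sin/cos features over a bank of frequencies, a zero row is prepended for the [cls] token, and the result is kept fixed rather than learned.

import torch

def build_2d_sincos_pos_embed(h, w, embed_dim, temperature=10000.):
    assert embed_dim % 4 == 0, 'embed_dim must be divisible by 4'
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)            # full h*w coordinate grid
    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature ** omega)                        # frequency bands
    out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
    out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
    pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w),
                         torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
    cls_pe = torch.zeros([1, 1, embed_dim])                    # zero embedding for the [cls] token
    return torch.cat([cls_pe, pos_emb], dim=1)                 # in the diff this becomes a frozen nn.Parameter

pe = build_2d_sincos_pos_embed(14, 14, 768)   # e.g. 224x224 image, 16x16 patches, ViT-Base width
print(pe.shape)                               # torch.Size([1, 197, 768])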