add back stop-grad

parent 6b1cc4cf87
commit 6a8c371a94

@@ -66,7 +66,7 @@ Note that the smaller batch size: 1) facilitates stable training, as discussed i
 </details>
 
 <details>
-<summary>ViT-Base, 300-Epoch, 2-Nodes.</summary>
+<summary>ViT-Base, 300-Epoch, 2-Node.</summary>
 
 With a batch size of 1024, ViT-Base can be trained on 2 nodes:
 
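(A quick sanity check, assuming the reference setup of 8 GPUs per node: a total batch of 1024 spread over 2 × 8 = 16 GPUs works out to 64 images per GPU.)
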
@@ -115,6 +115,10 @@ parser.add_argument('--moco-m', default=0.99, type=float,
 parser.add_argument('--moco-t', default=1.0, type=float,
                     help='softmax temperature (default: 1.0)')
 
+# vit specific configs:
+parser.add_argument('--stop-grad-conv1', action='store_true',
+                    help='stop-grad after first conv, or patch embedding')
+
 # other upgrades
 parser.add_argument('--optimizer', default='lars', type=str,
                     choices=['lars', 'adamw'],
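
The new flag is a plain `store_true` switch; argparse converts the dashes to underscores, which is why the next hunk reads it as `args.stop_grad_conv1`. A minimal self-contained check of just this flag (a sketch, not the script's full parser):

```python
import argparse

# Reduced parser containing only the flag added above.
parser = argparse.ArgumentParser()
parser.add_argument('--stop-grad-conv1', action='store_true',
                    help='stop-grad after first conv, or patch embedding')

args = parser.parse_args([])                     # flag absent  -> False
assert args.stop_grad_conv1 is False
args = parser.parse_args(['--stop-grad-conv1'])  # flag present -> True
assert args.stop_grad_conv1 is True
```
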
@@ -193,7 +197,7 @@ def main_worker(gpu, ngpus_per_node, args):
     print("=> creating model '{}'".format(args.arch))
     if args.arch.startswith('vit'):
         model = moco.builder.MoCo(
-            vits.__dict__[args.arch],
+            partial(vits.__dict__[args.arch], stop_grad_conv1=args.stop_grad_conv1),
             True, # with vit setup
             args.moco_dim, args.moco_mlp_dim, args.moco_t)
     else:
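
`partial` here is `functools.partial`: it pre-binds `stop_grad_conv1` to the ViT constructor, so `moco.builder.MoCo` can keep instantiating the backbone without knowing about the new flag. A sketch of the pattern with a stand-in constructor (`make_vit` is illustrative, not the repo's API):

```python
from functools import partial

def make_vit(stop_grad_conv1=False, embed_dim=768):
    # Stand-in for vits.__dict__[args.arch]; returns its config for clarity.
    return {'stop_grad_conv1': stop_grad_conv1, 'embed_dim': embed_dim}

# Bind the CLI flag now; later callers still control the remaining kwargs.
backbone_fn = partial(make_vit, stop_grad_conv1=True)

print(backbone_fn())               # {'stop_grad_conv1': True, 'embed_dim': 768}
print(backbone_fn(embed_dim=384))  # {'stop_grad_conv1': True, 'embed_dim': 384}
```
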
vits.py
@@ -19,11 +19,15 @@ __all__ = [
 
 
 class VisionTransformerMoCo(VisionTransformer):
-    def __init__(self, **kwargs):
+    def __init__(self, stop_grad_conv1=False, **kwargs):
         super().__init__(**kwargs)
         # Use 2D sin-cos position embedding
         self.build_2d_sincos_position_embedding()
 
+        if stop_grad_conv1:
+            self.patch_embed.proj.weight.requires_grad = False
+            self.patch_embed.proj.bias.requires_grad = False
+
     def build_2d_sincos_position_embedding(self, temperature=10000.):
         h, w = self.patch_embed.grid_size
         grid_w = torch.arange(w, dtype=torch.float32)
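
Freezing `patch_embed.proj` this way realizes the stop-gradient on the patch embedding (the fixed random patch projection studied in the MoCo v3 paper): the first conv keeps its initialization and receives no updates, while gradients still flow to everything after it. A standalone sketch of the mechanism on a bare conv (shapes assume ViT-Base with 16×16 patches; the `head` layer is illustrative):

```python
import torch
import torch.nn as nn

proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)  # conv patch embedding
proj.weight.requires_grad = False                    # stop-grad after conv1
proj.bias.requires_grad = False
head = nn.Linear(768, 10)                            # trainable layer downstream

x = torch.randn(2, 3, 224, 224)
feats = proj(x).flatten(2).mean(dim=2)               # (B, 768) pooled patches
head(feats).sum().backward()

assert proj.weight.grad is None      # frozen conv accumulated no gradient
assert head.weight.grad is not None  # downstream layer still trains
```
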
@@ -37,6 +41,7 @@ class VisionTransformerMoCo(VisionTransformer):
         out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
         pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
 
         assert self.num_tokens == 1, 'Assuming one and only one token, [cls]'
         pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
         self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
+        self.pos_embed.requires_grad = False
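
For reference, the same 2D sin-cos construction as a standalone function, mirroring the method above (a sketch; assumes `embed_dim` divisible by 4 and uses the 14×14 patch grid of a 224-pixel ViT-Base input):

```python
import torch

def sincos_pos_embed_2d(h, w, embed_dim, temperature=10000.):
    """Fixed 2D sin-cos position embedding of shape (h*w, embed_dim)."""
    assert embed_dim % 4 == 0, 'need sin/cos pairs for each of the two axes'
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
    pos_dim = embed_dim // 4
    omega = 1. / (temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim))
    out_w = torch.einsum('m,d->md', grid_w.flatten(), omega)
    out_h = torch.einsum('m,d->md', grid_h.flatten(), omega)
    return torch.cat([torch.sin(out_w), torch.cos(out_w),
                      torch.sin(out_h), torch.cos(out_h)], dim=1)

pe = sincos_pos_embed_2d(14, 14, 768)
print(pe.shape)  # torch.Size([196, 768])
```

Registering the table as an `nn.Parameter` with `requires_grad = False`, as the diff does, keeps it in the model's `state_dict` without it ever being updated.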