remove ViT BN and simplify pipeline

parent f8d0325008
commit 6b1cc4cf87

README.md (48 lines changed)

@@ -65,6 +65,24 @@ Note that the smaller batch size: 1) facilitates stable training, as discussed in
 </details>
 
+<details>
+<summary>ViT-Base, 300-Epoch, 2-Node.</summary>
+
+With a batch size of 1024, ViT-Base can be trained on 2 nodes:
+
+```
+python main_moco.py \
+  -a vit_base -b 1024 \
+  --optimizer=adamw --lr=1e-4 --weight-decay=.1 \
+  --epochs=300 --warmup-epochs=40 \
+  --moco-t=.2 \
+  --dist-url 'tcp://[your node 1 address]:[specified port]' \
+  --multiprocessing-distributed --world-size 2 --rank 0 \
+  [your imagenet-folder with train and val folders]
+```
+On the second node, run the same command as above, with `--rank 1`.
+</details>
+
 ### Linear Classification
 
 By default, we use the SGD+Momentum optimizer and a batch size of 1024 for linear classification on frozen features/weights. This fits on an 8-GPU node.

@@ -84,8 +102,6 @@ python main_lincls.py \
 
 ### Reference Setups
 
-#### ResNet-50
-
 For longer pre-trainings with ResNet-50, we find the following hyper-parameters work well (expected performance in the last column; will update logs/pre-trained models soon):
 
 <table><tbody>

@@ -102,21 +118,21 @@ For longer pre-trainings with ResNet-50, we find the following hyper-parameters
 <td align="center">0.45</td>
 <td align="center">1e-6</td>
 <td align="center">0.99</td>
-<td align="center">~67.5</td>
+<td align="center">[TODO]67.5</td>
 </tr>
 <tr>
 <td align="center">300</td>
 <td align="center">0.3</td>
 <td align="center">1e-6</td>
 <td align="center">0.99</td>
-<td align="center">~72.8</td>
+<td align="center">[TODO]72.8</td>
 </tr>
 <tr>
 <td align="center">1000</td>
 <td align="center">0.3</td>
 <td align="center">1.5e-6</td>
 <td align="center">0.996</td>
-<td align="center">~74.8</td>
+<td align="center">[TODO]74.8</td>
 </tr>
 </tbody></table>
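
The "momentum update" column above is the EMA coefficient for the momentum encoder. As a minimal sketch of how such a coefficient acts (illustrative only; `m` plays the role of `--moco-m`, and the function name is ours):

```python
import torch

@torch.no_grad()
def ema_update(base_encoder, momentum_encoder, m):
    # each momentum-encoder parameter decays toward the base-encoder
    # parameter; m = 0.99 (100/300-epoch) or 0.996 (1000-epoch) per the table
    for p_b, p_m in zip(base_encoder.parameters(), momentum_encoder.parameters()):
        p_m.data = p_m.data * m + p_b.data * (1. - m)
```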

@@ -136,28 +152,6 @@ python main_moco.py \
 On the second node, run the same command as above, with `--rank 1`.
 </details>
 
-#### ViT
-
-For Vision Transformers, we also provide the BatchNorm based backbone, where the LayerNorm in each MLP block (and the last one) is replaced with BatchNorm. We recommend the following hyper-parameters as a starting point:
-
-<details>
-<summary>MoCo v3 with ViT-Small, BatchNorm backbone.</summary>
-
-```
-python main_moco.py \
-  -a vit_small -b 1024 \
-  --vit-bn --vit-no-cls-token \
-  --optimizer=adamw --lr=3e-4 --weight-decay=.05 \
-  --epochs=300 --warmup-epochs=40 \
-  --moco-t=.2 \
-  --dist-url 'tcp://localhost:10001' \
-  --multiprocessing-distributed --world-size 1 --rank 0 \
-  [your imagenet-folder with train and val folders]
-```
-
-Note the changes in learning rate, weight decay, and removal of class token.
-</details>
-
 ### License
 
 This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details.
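
For the two-node recipe above, `--multiprocessing-distributed` spawns one worker per GPU, and each worker derives its global rank from the node-level `--rank`. A minimal sketch of that convention (the standard PyTorch ImageNet-example arithmetic; the function name is ours):

```python
def global_rank(node_rank, ngpus_per_node, gpu):
    # node 0 (--rank 0) owns global ranks 0..7 on an 8-GPU node;
    # node 1 (--rank 1) owns global ranks 8..15
    return node_rank * ngpus_per_node + gpu

assert global_rank(1, 8, 0) == 8  # first GPU on the second node
```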

main_lincls.py (10 lines changed)

@@ -85,14 +85,6 @@ parser.add_argument('--multiprocessing-distributed', action='store_true',
                          'fastest way to use PyTorch for either single node or '
                          'multi node data parallel training')
 
-# vit specific configs:
-parser.add_argument('--vit-bn', action='store_true',
-                    help='use batch normalization instead of layer normalization '
-                         'in ViT MLP blocks and in the end')
-parser.add_argument('--vit-no-cls-token', action='store_true',
-                    help='remove class token in ViT, and use average pooled '
-                         'features as embedding')
-
 # additional configs:
 parser.add_argument('--pretrained', default='', type=str,
                     help='path to moco pretrained checkpoint')

@@ -161,7 +153,7 @@ def main_worker(gpu, ngpus_per_node, args):
     # create model
     print("=> creating model '{}'".format(args.arch))
     if args.arch.startswith('vit'):
-        model = vits.__dict__[args.arch](use_bn=args.vit_bn, no_cls_token=args.vit_no_cls_token)
+        model = vits.__dict__[args.arch]()
         linear_keyword = 'head'
     else:
         model = torchvision_models.__dict__[args.arch]()
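
With `linear_keyword = 'head'`, linear classification trains only the ViT classifier layer on frozen features. A minimal sketch of the freezing step (follows the usual MoCo linear-probe recipe; not quoted verbatim from this file):

```python
# assumes `model` is the ViT built above and its classifier is model.head
linear_keyword = 'head'
for name, param in model.named_parameters():
    if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]:
        param.requires_grad = False  # freeze everything except the linear head
```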

main_moco.py (10 lines changed)

@@ -115,14 +115,6 @@ parser.add_argument('--moco-m', default=0.99, type=float,
 parser.add_argument('--moco-t', default=1.0, type=float,
                     help='softmax temperature (default: 1.0)')
 
-# vit specific configs:
-parser.add_argument('--vit-bn', action='store_true',
-                    help='use batch normalization instead of layer normalization '
-                         'in ViT MLP blocks and in the end')
-parser.add_argument('--vit-no-cls-token', action='store_true',
-                    help='remove class token in ViT, and use average pooled '
-                         'features as embedding')
-
 # other upgrades
 parser.add_argument('--optimizer', default='lars', type=str,
                     choices=['lars', 'adamw'],

@@ -201,7 +193,7 @@ def main_worker(gpu, ngpus_per_node, args):
     print("=> creating model '{}'".format(args.arch))
     if args.arch.startswith('vit'):
         model = moco.builder.MoCo(
-            partial(vits.__dict__[args.arch], use_bn=args.vit_bn, no_cls_token=args.vit_no_cls_token),
+            vits.__dict__[args.arch],
             True,  # with vit setup
             args.moco_dim, args.moco_mlp_dim, args.moco_t)
     else:
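
The one-line change works because `moco.builder.MoCo` only needs a backbone factory it can call itself; once `use_bn`/`no_cls_token` are gone, there are no extra arguments to bind and the `functools.partial` wrapper becomes unnecessary. A minimal sketch of the two styles (names follow the diff):

```python
from functools import partial

# before: ViT flags had to be pre-bound into the constructor
encoder_factory = partial(vits.__dict__[args.arch],
                          use_bn=args.vit_bn, no_cls_token=args.vit_no_cls_token)

# after: the bare constructor already serves as the factory
encoder_factory = vits.__dict__[args.arch]
```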

vits.py (55 lines changed)

@@ -19,21 +19,11 @@ __all__ = [
 
 
 class VisionTransformerMoCo(VisionTransformer):
-    def __init__(self, use_bn=False, no_cls_token=False, **kwargs):
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.no_cls_token = no_cls_token
-
         # Use 2D sin-cos position embedding
         del self.pos_embed
         self.build_2d_sincos_position_embedding()
-
-        if use_bn:
-            self.replace_lns_with_bns()
-
-        if no_cls_token:
-            del self.cls_token
-            self.num_tokens -= 1
 
     def build_2d_sincos_position_embedding(self, temperature=10000.):
         h, w = self.patch_embed.grid_size
         grid_w = torch.arange(w, dtype=torch.float32)

@@ -47,49 +37,10 @@ class VisionTransformerMoCo(VisionTransformer):
         out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
         pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
 
-        if not self.no_cls_token:
-            pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
-            self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
-        else:
-            self.pos_embed = nn.Parameter(pos_emb)
+        pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
+        self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
         self.pos_embed.requires_grad = False
 
-    def replace_lns_with_bns(self):
-        # replace LNs with BNs in the MLP blocks
-        for blk in self.blocks:
-            del blk.norm2
-            blk.norm2 = nn.BatchNorm1d(self.embed_dim, eps=1e-6)
-
-        # replace last LN with BN
-        del self.norm
-        self.norm = nn.BatchNorm1d(self.embed_dim, eps=1e-6)
-
-    def forward_features(self, x):
-        x = self.patch_embed(x)
-
-        x_list = []
-        if not self.no_cls_token:
-            x_list.append(self.cls_token.expand(x.shape[0], -1, -1))
-        if self.dist_token is not None:
-            x_list.append(self.dist_token.expand(x.shape[0], -1, -1))
-        x_list.append(x)
-
-        x = torch.cat(x_list, dim=1)
-        x = self.pos_drop(x + self.pos_embed)
-        x = self.blocks(x)
-        x = self.norm(x)
-
-        if self.no_cls_token:
-            x_feat = x[:, self.num_tokens:].mean(dim=1)  # take the mean over all tokens
-        else:
-            x_feat = x[:, 0]
-
-        if self.dist_token is None:
-            return self.pre_logits(x_feat)
-        else:
-            return x_feat, x[:, self.num_tokens-1]
 
 
 def vit_small(**kwargs):
     model = VisionTransformerMoCo(
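
The fixed 2D sin-cos position embedding retained by this commit can be reproduced standalone. A minimal sketch following the same construction (assumes `embed_dim` divisible by 4 and an `h × w` patch grid; the free function is ours, mirroring `build_2d_sincos_position_embedding`):

```python
import torch

def build_2d_sincos_pos_embed(h, w, embed_dim, temperature=10000.):
    # mirrors VisionTransformerMoCo.build_2d_sincos_position_embedding
    assert embed_dim % 4 == 0, 'embed_dim must be divisible by 4 for 2D sin-cos'
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature ** omega)
    out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
    out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
    pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w),
                         torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
    # zero slot prepended for the class token, then frozen
    pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)
    pos_embed = torch.nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
    pos_embed.requires_grad = False
    return pos_embed

pe = build_2d_sincos_pos_embed(14, 14, 384)  # ViT-Small: 224px / patch 16
print(pe.shape)  # torch.Size([1, 197, 384])
```

Because `requires_grad` is disabled, the embedding stays fixed for the whole pre-training, unlike a learned position embedding.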