remove ViT BN and simplify pipeline

pull/3/head
Xinlei Chen 2021-07-08 21:00:51 -07:00
parent f8d0325008
commit 6b1cc4cf87
4 changed files with 26 additions and 97 deletions

README.md

@@ -65,6 +65,24 @@ Note that the smaller batch size: 1) facilitates stable training, as discussed i
</details>
<details>
<summary>ViT-Base, 300-Epoch, 2-Nodes.</summary>
With a batch size of 1024, ViT-Base can be trained on 2 nodes:
```
python main_moco.py \
-a vit_base -b 1024 \
--optimizer=adamw --lr=1e-4 --weight-decay=.1 \
--epochs=300 --warmup-epochs=40 \
--moco-t=.2 \
--dist-url 'tcp://[your node 1 address]:[specified port]' \
--multiprocessing-distributed --world-size 2 --rank 0 \
[your imagenet-folder with train and val folders]
```
On the second node, run the same command as above, with `--rank 1`.
</details>
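For reference, each node launches one process per GPU, and the global rank of a process is derived from its node's rank. A minimal sketch of the standard PyTorch pattern these scripts follow (variable names assumed to match the usual ImageNet-example structure):

```python
import torch.distributed as dist

# inside the per-GPU worker; `gpu` is the local GPU index on this node
def init_distributed(gpu, ngpus_per_node, args):
    # args.world_size has already been scaled from number of nodes to
    # total number of processes: ngpus_per_node * number of nodes
    args.rank = args.rank * ngpus_per_node + gpu  # node rank -> global rank
    dist.init_process_group(backend='nccl', init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
```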
### Linear Classification
By default, we use the SGD+Momentum optimizer and a batch size of 1024 for linear classification on frozen features/weights. This fits on an 8-GPU node.
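A minimal single-node invocation, sketched from the flags defined in main_lincls.py later in this diff (the learning rate is an assumed placeholder; the checkpoint and dataset paths are yours to fill in):

```
python main_lincls.py \
  -a resnet50 --lr .1 -b 1024 \
  --pretrained [your checkpoint path] \
  --dist-url 'tcp://localhost:10001' \
  --multiprocessing-distributed --world-size 1 --rank 0 \
  [your imagenet-folder with train and val folders]
```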
@@ -84,8 +102,6 @@ python main_lincls.py \
### Reference Setups
#### ResNet-50
For longer pre-trainings with ResNet-50, we find the following hyper-parameters work well (expected performance in the last column; logs/pre-trained models will be updated soon); an example command is sketched after the table:
<table><tbody>
@@ -102,21 +118,21 @@ For longer pre-trainings with ResNet-50, we find the following hyper-parameters
<td align="center">0.45</td>
<td align="center">1e-6</td>
<td align="center">0.99</td>
<td align="center">~67.5</td>
<td align="center">[TODO]67.5</td>
</tr>
<tr>
<td align="center">300</td>
<td align="center">0.3</td>
<td align="center">1e-6</td>
<td align="center">0.99</td>
<td align="center">~72.8</td>
<td align="center">[TODO]72.8</td>
</tr>
<tr>
<td align="center">1000</td>
<td align="center">0.3</td>
<td align="center">1.5e-6</td>
<td align="center">0.996</td>
<td align="center">~74.8</td>
<td align="center">[TODO]74.8</td>
</tr>
</tbody></table>
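Reading the columns as epochs, learning rate, weight decay, MoCo momentum, and expected linear accuracy, the 300-epoch row would translate into roughly the following single-node command (the batch size is an assumption, and `--optimizer=lars` is the default, spelled out only for clarity):

```
python main_moco.py \
  -a resnet50 -b 1024 \
  --optimizer=lars --lr=.3 --weight-decay=1e-6 \
  --epochs=300 --moco-m=.99 \
  --dist-url 'tcp://localhost:10001' \
  --multiprocessing-distributed --world-size 1 --rank 0 \
  [your imagenet-folder with train and val folders]
```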
@@ -136,28 +152,6 @@ python main_moco.py \
On the second node, run the same command as above, with `--rank 1`.
</details>
#### ViT
For Vision Transformers, we also provide a BatchNorm-based backbone, in which the LayerNorm in each MLP block (and the final LayerNorm) is replaced with BatchNorm. We recommend the following hyper-parameters as a starting point:
<details>
<summary>MoCo v3 with ViT-Small, BatchNorm backbone.</summary>
```
python main_moco.py \
-a vit_small -b 1024 \
--vit-bn --vit-no-cls-token \
--optimizer=adamw --lr=3e-4 --weight-decay=.05 \
--epochs=300 --warmup-epochs=40 \
--moco-t=.2 \
--dist-url 'tcp://localhost:10001' \
--multiprocessing-distributed --world-size 1 --rank 0 \
[your imagenet-folder with train and val folders]
```
Note the changes in learning rate and weight decay, and the removal of the class token.
</details>
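For reference, the backbone change behind `--vit-bn` is compact; a condensed sketch of the `replace_lns_with_bns` logic, whose full (now removed) version appears in the vits.py diff below:

```python
import torch.nn as nn

def replace_lns_with_bns(model):
    # swap the LayerNorm in each block's MLP sub-layer for BatchNorm
    for blk in model.blocks:
        blk.norm2 = nn.BatchNorm1d(model.embed_dim, eps=1e-6)
    # replace the final LayerNorm as well
    model.norm = nn.BatchNorm1d(model.embed_dim, eps=1e-6)
```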
### License
This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details.

main_lincls.py

@@ -85,14 +85,6 @@ parser.add_argument('--multiprocessing-distributed', action='store_true',
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')

# vit specific configs:
parser.add_argument('--vit-bn', action='store_true',
                    help='use batch normalization instead of layer normalization '
                         'in ViT MLP blocks and in the end')
parser.add_argument('--vit-no-cls-token', action='store_true',
                    help='remove class token in ViT, and use average pooled '
                         'features as embedding')

# additional configs:
parser.add_argument('--pretrained', default='', type=str,
                    help='path to moco pretrained checkpoint')
@@ -161,7 +153,7 @@ def main_worker(gpu, ngpus_per_node, args):
    # create model
    print("=> creating model '{}'".format(args.arch))
    if args.arch.startswith('vit'):
        model = vits.__dict__[args.arch](use_bn=args.vit_bn, no_cls_token=args.vit_no_cls_token)
        model = vits.__dict__[args.arch]()
        linear_keyword = 'head'
    else:
        model = torchvision_models.__dict__[args.arch]()

main_moco.py

@@ -115,14 +115,6 @@ parser.add_argument('--moco-m', default=0.99, type=float,
parser.add_argument('--moco-t', default=1.0, type=float,
                    help='softmax temperature (default: 1.0)')

# vit specific configs:
parser.add_argument('--vit-bn', action='store_true',
                    help='use batch normalization instead of layer normalization '
                         'in ViT MLP blocks and in the end')
parser.add_argument('--vit-no-cls-token', action='store_true',
                    help='remove class token in ViT, and use average pooled '
                         'features as embedding')

# other upgrades
parser.add_argument('--optimizer', default='lars', type=str,
                    choices=['lars', 'adamw'],
@@ -201,7 +193,7 @@ def main_worker(gpu, ngpus_per_node, args):
    print("=> creating model '{}'".format(args.arch))
    if args.arch.startswith('vit'):
        model = moco.builder.MoCo(
            partial(vits.__dict__[args.arch], use_bn=args.vit_bn, no_cls_token=args.vit_no_cls_token),
            vits.__dict__[args.arch],
            True,  # with vit setup
            args.moco_dim, args.moco_mlp_dim, args.moco_t)
    else:

vits.py

@@ -19,21 +19,11 @@ __all__ = [
class VisionTransformerMoCo(VisionTransformer):
    def __init__(self, use_bn=False, no_cls_token=False, **kwargs):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.no_cls_token = no_cls_token
        # Use 2D sin-cos position embedding
        del self.pos_embed
        self.build_2d_sincos_position_embedding()
        if use_bn:
            self.replace_lns_with_bns()
        if no_cls_token:
            del self.cls_token
            self.num_tokens -= 1

    def build_2d_sincos_position_embedding(self, temperature=10000.):
        h, w = self.patch_embed.grid_size
        grid_w = torch.arange(w, dtype=torch.float32)
@@ -47,49 +37,10 @@ class VisionTransformerMoCo(VisionTransformer):
        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        if not self.no_cls_token:
            pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
            self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
        else:
            self.pos_embed = nn.Parameter(pos_emb)
        pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
        self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
        self.pos_embed.requires_grad = False
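        # shape note (assuming 224x224 inputs with 16x16 patches): grid_size is
        # (14, 14), so pos_emb is [1, 196, embed_dim] and the final pos_embed,
        # with the class-token slot prepended, is [1, 197, embed_dim]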

    def replace_lns_with_bns(self):
        # replace LNs with BNs in the MLP blocks
        for blk in self.blocks:
            del blk.norm2
            blk.norm2 = nn.BatchNorm1d(self.embed_dim, eps=1e-6)
        # replace last LN with BN
        del self.norm
        self.norm = nn.BatchNorm1d(self.embed_dim, eps=1e-6)

    def forward_features(self, x):
        x = self.patch_embed(x)
        x_list = []
        if not self.no_cls_token:
            x_list.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.dist_token is not None:
            x_list.append(self.dist_token.expand(x.shape[0], -1, -1))
        x_list.append(x)
        x = torch.cat(x_list, dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.no_cls_token:
            x_feat = x[:, self.num_tokens:].mean(dim=1)  # take the mean over all tokens
        else:
            x_feat = x[:, 0]
        if self.dist_token is None:
            return self.pre_logits(x_feat)
        else:
            return x_feat, x[:, self.num_tokens-1]


def vit_small(**kwargs):
    model = VisionTransformerMoCo(