add linear probing

pull/56/head
Kaiming He 2022-01-11 12:31:55 -08:00
parent be47fef7a7
commit 6a2ba40229
8 changed files with 584 additions and 4 deletions

View File

@ -5,7 +5,7 @@ possible.
## Pull Requests
We actively welcome your pull requests.
1. Fork the repo and create your branch from `master`.
1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.

View File

@ -124,10 +124,54 @@ OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 main_fin
```
- Here the effective batch size is 32 (`batch_size` per gpu) * 4 (`accum_iter`) * 8 (gpus) = 1024. `--accum_iter 4` simulates 4 nodes.
### Notes
#### Notes
- The [pre-trained models we provide](https://github.com/fairinternal/mae/#pre-trained-checkpoints) are trained with *normalized* pixels `--norm_pix_loss` (1600 epochs, Table 3 in paper). The fine-tuning hyper-parameters are slightly different from the default baseline using *unnormalized* pixels.
- The original MAE implementation was in TensorFlow+TPU with no explicit mixed precision. This re-implementation is in PyTorch+GPU with automatic mixed precision (`torch.cuda.amp`). We have observed different numerical behavior between the two platforms. In this repo, we use `--global_pool` for fine-tuning; using `--cls_token` performs similarly, but there is a chance of producing NaN when fine-tuning ViT-Huge in GPUs. We did not observe this issue in TPUs. Turning off amp could solve this issue, but is slower.
- Here we use RandErase following DeiT: `--reprob 0.25`. Its effect is smaller than random variance.
### Linear Probing
Run the following on 4 nodes with 8 GPUs each:
```
python submitit_linprobe.py \
--job_dir ${JOB_DIR} \
--nodes 4 \
--batch_size 512 \
--model vit_base_patch16 --cls_token \
--finetune ${PRETRAIN_CHKPT} \
--epochs 90 \
--blr 0.1 \
--weight_decay 0.0 \
--dist_eval --data_path ${IMAGENET_DIR}
```
- Here the effective batch size is 512 (`batch_size` per gpu) * 4 (`nodes`) * 8 (gpus per node) = 16384.
- `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256.
- Training time is ~2h20m for 90 epochs in 32 V100 GPUs.
- To run single-node training, follow the instruction in fine-tuning.
To train ViT-Large or ViT-Huge, set `--model vit_large_patch16` or `--model vit_huge_patch14`. It is sufficient to train 50 epochs `--epochs 50`.
This PT/GPU code produces *better* results for ViT-L/H (see the table below). This is likely caused by the system difference between TF and PT.
<table><tbody>
<!-- START TABLE -->
<!-- TABLE HEADER -->
<th valign="bottom"></th>
<th valign="bottom">ViT-Base</th>
<th valign="bottom">ViT-Large</th>
<th valign="bottom">ViT-Huge</th>
<!-- TABLE BODY -->
<tr><td align="left">paper (TF/TPU)</td>
<td align="center">68.0</td>
<td align="center">75.8</td>
<td align="center">76.6</td>
</tr>
<tr><td align="left">this repo (PT/GPU)</td>
<td align="center">67.8</td>
<td align="center">76.0</td>
<td align="center">77.2</td>
</tr>
</tbody></table>

View File

@ -20,4 +20,4 @@ python submitit_pretrain.py \
- The exact same hyper-parameters and configs (initialization, augmentation, etc.) are used as our TF/TPU implementation. In our sanity checks, this PT/GPU re-implementation can reproduce the TF/TPU results within reasonable random variation. We get 85.5% [fine-tuning](FINETUNE.md) accuracy by pre-training ViT-Large for 800 epochs (85.4% in paper Table 1d with TF/TPU).
- Training time is ~42h in 64 V100 GPUs (800 epochs).
To train ViT-Base or ViT-Huge, set `--model vit_base_patch16` or `--model vit_huge_patch14`.
To train ViT-Base or ViT-Huge, set `--model mae_vit_base_patch16` or `--model mae_vit_huge_patch14`.

316
main_linprobe.py 100644
View File

@ -0,0 +1,316 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import timm
assert timm.__version__ == "0.3.2" # version check
from timm.models.layers import trunc_normal_
import util.misc as misc
from util.pos_embed import interpolate_pos_embed
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.lars import LARS
from util.crop import RandomResizedCrop
import models_vit
from engine_finetune import train_one_epoch, evaluate
def get_args_parser():
parser = argparse.ArgumentParser('MAE linear probing for image classification', add_help=False)
parser.add_argument('--batch_size', default=512, type=int,
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
parser.add_argument('--epochs', default=90, type=int)
parser.add_argument('--accum_iter', default=1, type=int,
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
# Model parameters
parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
help='Name of model to train')
# Optimizer parameters
parser.add_argument('--weight_decay', type=float, default=0,
help='weight decay (default: 0 for linear probe following MoCo v1)')
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (absolute lr)')
parser.add_argument('--blr', type=float, default=0.1, metavar='LR',
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
help='epochs to warmup LR')
# * Finetuning params
parser.add_argument('--finetune', default='',
help='finetune from checkpoint')
parser.add_argument('--global_pool', action='store_true')
parser.set_defaults(global_pool=False)
parser.add_argument('--cls_token', action='store_false', dest='global_pool',
help='Use class token instead of global pool for classification')
# Dataset parameters
parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
help='dataset path')
parser.add_argument('--nb_classes', default=1000, type=int,
help='number of the classification types')
parser.add_argument('--output_dir', default='./output_dir',
help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='./output_dir',
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='',
help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.add_argument('--dist_eval', action='store_true', default=False,
help='Enabling distributed evaluation (recommended during training for faster monitor')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
return parser
def main(args):
misc.init_distributed_mode(args)
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
print("{}".format(args).replace(', ', ',\n'))
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
# linear probe: weak augmentation
transform_train = transforms.Compose([
RandomResizedCrop(224, interpolation=3),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
transform_val = transforms.Compose([
transforms.Resize(256, interpolation=3),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val)
print(dataset_train)
print(dataset_val)
if True: # args.distributed:
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
print("Sampler_train = %s" % str(sampler_train))
if args.dist_eval:
if len(dataset_val) % num_tasks != 0:
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias
else:
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
if global_rank == 0 and args.log_dir is not None and not args.eval:
os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=args.log_dir)
else:
log_writer = None
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
)
model = models_vit.__dict__[args.model](
num_classes=args.nb_classes,
global_pool=args.global_pool,
)
if args.finetune and not args.eval:
checkpoint = torch.load(args.finetune, map_location='cpu')
print("Load pre-trained checkpoint from: %s" % args.finetune)
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
# interpolate position embedding
interpolate_pos_embed(model, checkpoint_model)
# load pre-trained model
msg = model.load_state_dict(checkpoint_model, strict=False)
print(msg)
if args.global_pool:
assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
else:
assert set(msg.missing_keys) == {'head.weight', 'head.bias'}
# manually initialize fc layer: following MoCo v3
trunc_normal_(model.head.weight, std=0.01)
# for linear prob only
# hack: revise model's head with BN
model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head)
# freeze all but the head
for _, p in model.named_parameters():
p.requires_grad = False
for _, p in model.head.named_parameters():
p.requires_grad = True
model.to(device)
model_without_ddp = model
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Model = %s" % str(model_without_ddp))
print('number of params (M): %.2f' % (n_parameters / 1.e6))
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
if args.lr is None: # only base_lr is specified
args.lr = args.blr * eff_batch_size / 256
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
print("actual lr: %.2e" % args.lr)
print("accumulate grad iterations: %d" % args.accum_iter)
print("effective batch size: %d" % eff_batch_size)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
optimizer = LARS(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay)
print(optimizer)
loss_scaler = NativeScaler()
criterion = torch.nn.CrossEntropyLoss()
print("criterion = %s" % str(criterion))
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
if args.eval:
test_stats = evaluate(data_loader_val, model, device)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
exit(0)
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
max_accuracy = 0.0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
max_norm=None,
log_writer=log_writer,
args=args
)
if args.output_dir:
misc.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch)
test_stats = evaluate(data_loader_val, model, device)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
max_accuracy = max(max_accuracy, test_stats["acc1"])
print(f'Max accuracy: {max_accuracy:.2f}%')
if log_writer is not None:
log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
if args.output_dir and misc.is_main_process():
if log_writer is not None:
log_writer.flush()
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
args = get_args_parser()
args = args.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)

View File

@ -55,7 +55,7 @@ class MaskedAutoencoderViT(nn.Module):
for i in range(decoder_depth)])
self.decoder_norm = norm_layer(decoder_embed_dim)
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # encoder to decoder
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
# --------------------------------------------------------------------------
self.norm_pix_loss = norm_pix_loss

View File

@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# A script to run multinode training with submitit.
# --------------------------------------------------------
import argparse
import os
import uuid
from pathlib import Path
import main_linprobe as classification
import submitit
def parse_args():
classification_parser = classification.get_args_parser()
parser = argparse.ArgumentParser("Submitit for MAE linear probe", parents=[classification_parser])
parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job")
parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit")
parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs")
parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")
return parser.parse_args()
def get_shared_folder() -> Path:
user = os.getenv("USER")
if Path("/checkpoint/").is_dir():
p = Path(f"/checkpoint/{user}/experiments")
p.mkdir(exist_ok=True)
return p
raise RuntimeError("No shared folder available")
def get_init_file():
# Init file must not exist, but it's parent dir must exist.
os.makedirs(str(get_shared_folder()), exist_ok=True)
init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
if init_file.exists():
os.remove(str(init_file))
return init_file
class Trainer(object):
def __init__(self, args):
self.args = args
def __call__(self):
import main_linprobe as classification
self._setup_gpu_args()
classification.main(self.args)
def checkpoint(self):
import os
import submitit
self.args.dist_url = get_init_file().as_uri()
checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth")
if os.path.exists(checkpoint_file):
self.args.resume = checkpoint_file
print("Requeuing ", self.args)
empty_trainer = type(self)(self.args)
return submitit.helpers.DelayedSubmission(empty_trainer)
def _setup_gpu_args(self):
import submitit
from pathlib import Path
job_env = submitit.JobEnvironment()
self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
self.args.log_dir = self.args.output_dir
self.args.gpu = job_env.local_rank
self.args.rank = job_env.global_rank
self.args.world_size = job_env.num_tasks
print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
def main():
args = parse_args()
if args.job_dir == "":
args.job_dir = get_shared_folder() / "%j"
# Note that the folder will depend on the job_id, to easily track experiments
executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
num_gpus_per_node = args.ngpus
nodes = args.nodes
timeout_min = args.timeout
partition = args.partition
kwargs = {}
if args.use_volta32:
kwargs['slurm_constraint'] = 'volta32gb'
if args.comment:
kwargs['slurm_comment'] = args.comment
executor.update_parameters(
mem_gb=40 * num_gpus_per_node,
gpus_per_node=num_gpus_per_node,
tasks_per_node=num_gpus_per_node, # one task per GPU
cpus_per_task=10,
nodes=nodes,
timeout_min=timeout_min,
# Below are cluster dependent parameters
slurm_partition=partition,
slurm_signal_delay_s=120,
**kwargs
)
executor.update_parameters(name="mae")
args.dist_url = get_init_file().as_uri()
args.output_dir = args.job_dir
trainer = Trainer(args)
job = executor.submit(trainer)
# print("Submitted job_id:", job.job_id)
print(job.job_id)
if __name__ == "__main__":
main()

42
util/crop.py 100644
View File

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torchvision import transforms
from torchvision.transforms import functional as F
class RandomResizedCrop(transforms.RandomResizedCrop):
"""
RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
This may lead to results different with torchvision's version.
Following BYOL's TF code:
https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
"""
@staticmethod
def get_params(img, scale, ratio):
width, height = F._get_image_size(img)
area = height * width
target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
log_ratio = torch.log(torch.tensor(ratio))
aspect_ratio = torch.exp(
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
).item()
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
w = min(w, width)
h = min(h, height)
i = torch.randint(0, height - h + 1, size=(1,)).item()
j = torch.randint(0, width - w + 1, size=(1,)).item()
return i, j, h, w

47
util/lars.py 100644
View File

@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# LARS optimizer, implementation from MoCo v3:
# https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
import torch
class LARS(torch.optim.Optimizer):
"""
LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
"""
def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
super().__init__(params, defaults)
@torch.no_grad()
def step(self):
for g in self.param_groups:
for p in g['params']:
dp = p.grad
if dp is None:
continue
if p.ndim > 1: # if not normalization gamma/beta or bias
dp = dp.add(p, alpha=g['weight_decay'])
param_norm = torch.norm(p)
update_norm = torch.norm(dp)
one = torch.ones_like(param_norm)
q = torch.where(param_norm > 0.,
torch.where(update_norm > 0,
(g['trust_coefficient'] * param_norm / update_norm), one),
one)
dp = dp.mul(q)
param_state = self.state[p]
if 'mu' not in param_state:
param_state['mu'] = torch.zeros_like(p)
mu = param_state['mu']
mu.mul_(g['momentum']).add_(dp)
p.add_(mu, alpha=-g['lr'])