add linear probing
parent be47fef7a7
commit 6a2ba40229

@@ -5,7 +5,7 @@ possible.
## Pull Requests
We actively welcome your pull requests.

-1. Fork the repo and create your branch from `master`.
+1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.

FINETUNE.md

@@ -124,10 +124,54 @@ OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 main_fin
```
- Here the effective batch size is 32 (`batch_size` per gpu) * 4 (`accum_iter`) * 8 (gpus) = 1024. `--accum_iter 4` simulates 4 nodes.

-### Notes
+#### Notes

- The [pre-trained models we provide](https://github.com/fairinternal/mae/#pre-trained-checkpoints) are trained with *normalized* pixels `--norm_pix_loss` (1600 epochs, Table 3 in the paper). The fine-tuning hyper-parameters are slightly different from the default baseline that uses *unnormalized* pixels.

- The original MAE implementation was in TensorFlow+TPU with no explicit mixed precision. This re-implementation is in PyTorch+GPU with automatic mixed precision (`torch.cuda.amp`). We have observed different numerical behavior between the two platforms. In this repo, we use `--global_pool` for fine-tuning; using `--cls_token` performs similarly, but there is a chance of producing NaN when fine-tuning ViT-Huge on GPUs. We did not observe this issue on TPUs. Turning off amp could solve this issue, but it is slower.

- Here we use RandErase following DeiT: `--reprob 0.25`. Its effect is smaller than the run-to-run random variance.

### Linear Probing

Run the following on 4 nodes with 8 GPUs each:
```
python submitit_linprobe.py \
    --job_dir ${JOB_DIR} \
    --nodes 4 \
    --batch_size 512 \
    --model vit_base_patch16 --cls_token \
    --finetune ${PRETRAIN_CHKPT} \
    --epochs 90 \
    --blr 0.1 \
    --weight_decay 0.0 \
    --dist_eval --data_path ${IMAGENET_DIR}
```
- Here the effective batch size is 512 (`batch_size` per gpu) * 4 (`nodes`) * 8 (gpus per node) = 16384.
- `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256.
- Training time is ~2h20m for 90 epochs on 32 V100 GPUs.
- To run single-node training, follow the instructions in fine-tuning (see the hedged sketch below).
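
A hedged sketch of a single-node launch, assuming the same pattern as the fine-tuning command above; `main_linprobe.py` is the script added in this commit, the flags below come from its argument parser, and `--accum_iter 4` stands in for the three missing nodes so the effective batch size stays 512 * 4 * 8 = 16384:
```
OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 main_linprobe.py \
    --accum_iter 4 \
    --batch_size 512 \
    --model vit_base_patch16 --cls_token \
    --finetune ${PRETRAIN_CHKPT} \
    --epochs 90 \
    --blr 0.1 \
    --weight_decay 0.0 \
    --dist_eval --data_path ${IMAGENET_DIR}
```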

To train ViT-Large or ViT-Huge, set `--model vit_large_patch16` or `--model vit_huge_patch14`. It is sufficient to train for 50 epochs (`--epochs 50`).

This PT/GPU code produces *better* results for ViT-L/H (see the table below). This is likely caused by the system difference between TF and PT.

<table><tbody>
<!-- START TABLE -->
<!-- TABLE HEADER -->
<th valign="bottom"></th>
<th valign="bottom">ViT-Base</th>
<th valign="bottom">ViT-Large</th>
<th valign="bottom">ViT-Huge</th>
<!-- TABLE BODY -->
<tr><td align="left">paper (TF/TPU)</td>
<td align="center">68.0</td>
<td align="center">75.8</td>
<td align="center">76.6</td>
</tr>
<tr><td align="left">this repo (PT/GPU)</td>
<td align="center">67.8</td>
<td align="center">76.0</td>
<td align="center">77.2</td>
</tr>
</tbody></table>

@@ -20,4 +20,4 @@ python submitit_pretrain.py \
- The exact same hyper-parameters and configs (initialization, augmentation, etc.) are used as in our TF/TPU implementation. In our sanity checks, this PT/GPU re-implementation can reproduce the TF/TPU results within reasonable random variation. We get 85.5% [fine-tuning](FINETUNE.md) accuracy by pre-training ViT-Large for 800 epochs (85.4% in paper Table 1d with TF/TPU).
- Training time is ~42h on 64 V100 GPUs (800 epochs).

-To train ViT-Base or ViT-Huge, set `--model vit_base_patch16` or `--model vit_huge_patch14`.
+To train ViT-Base or ViT-Huge, set `--model mae_vit_base_patch16` or `--model mae_vit_huge_patch14`.
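
For illustration, a ViT-Base pre-training launch would then look like the sketch below; this reuses the ViT-Large recipe from PRETRAIN.md with only the model name swapped, and is an assumption rather than a command taken from this commit:
```
python submitit_pretrain.py \
    --job_dir ${JOB_DIR} \
    --nodes 8 \
    --batch_size 64 \
    --model mae_vit_base_patch16 \
    --norm_pix_loss \
    --mask_ratio 0.75 \
    --epochs 800 \
    --warmup_epochs 40 \
    --blr 1.5e-4 --weight_decay 0.05 \
    --data_path ${IMAGENET_DIR}
```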

main_linprobe.py
@@ -0,0 +1,316 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------

import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import timm

assert timm.__version__ == "0.3.2"  # version check
from timm.models.layers import trunc_normal_

import util.misc as misc
from util.pos_embed import interpolate_pos_embed
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.lars import LARS
from util.crop import RandomResizedCrop

import models_vit

from engine_finetune import train_one_epoch, evaluate

def get_args_parser():
    parser = argparse.ArgumentParser('MAE linear probing for image classification', add_help=False)
    parser.add_argument('--batch_size', default=512, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=90, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model parameters
    parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight decay (default: 0 for linear probe following MoCo v1)')

    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=0.1, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')

    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')

    parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR')

    # * Finetuning params
    parser.add_argument('--finetune', default='',
                        help='finetune from checkpoint')
    parser.add_argument('--global_pool', action='store_true')
    parser.set_defaults(global_pool=False)
    parser.add_argument('--cls_token', action='store_false', dest='global_pool',
                        help='Use class token instead of global pool for classification')

    # Dataset parameters
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--nb_classes', default=1000, type=int,
                        help='number of the classification types')

    parser.add_argument('--output_dir', default='./output_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./output_dir',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--dist_eval', action='store_true', default=False,
                        help='Enabling distributed evaluation (recommended during training for faster monitoring)')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    return parser

def main(args):
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # linear probe: weak augmentation
    transform_train = transforms.Compose([
            RandomResizedCrop(224, interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    transform_val = transforms.Compose([
            transforms.Resize(256, interpolation=3),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
    dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val)
    print(dataset_train)
    print(dataset_val)

    if True:  # args.distributed:
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
                      'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True)  # shuffle=True to reduce monitor bias
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    if global_rank == 0 and args.log_dir is not None and not args.eval:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    model = models_vit.__dict__[args.model](
        num_classes=args.nb_classes,
        global_pool=args.global_pool,
    )

    if args.finetune and not args.eval:
        checkpoint = torch.load(args.finetune, map_location='cpu')

        print("Load pre-trained checkpoint from: %s" % args.finetune)
        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        for k in ['head.weight', 'head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
        interpolate_pos_embed(model, checkpoint_model)

        # load pre-trained model
        msg = model.load_state_dict(checkpoint_model, strict=False)
        print(msg)

        if args.global_pool:
            assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
        else:
            assert set(msg.missing_keys) == {'head.weight', 'head.bias'}

        # manually initialize fc layer: following MoCo v3
        trunc_normal_(model.head.weight, std=0.01)

    # for linear prob only
    # hack: revise model's head with BN
    model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head)
    # freeze all but the head
    for _, p in model.named_parameters():
        p.requires_grad = False
    for _, p in model.head.named_parameters():
        p.requires_grad = True

    model.to(device)

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model = %s" % str(model_without_ddp))
    print('number of params (M): %.2f' % (n_parameters / 1.e6))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = LARS(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    print(optimizer)
    loss_scaler = NativeScaler()

    criterion = torch.nn.CrossEntropyLoss()

    print("criterion = %s" % str(criterion))

    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        exit(0)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            max_norm=None,
            log_writer=log_writer,
            args=args
        )
        if args.output_dir:
            misc.save_model(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)

        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        if log_writer is not None:
            log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
            log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
            log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
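
Once a probe has been trained, the same script supports evaluation-only runs via `--eval`; the sketch below is an assumption based on the flags defined above, and `${LINPROBE_CHKPT}` is a placeholder for a saved linear-probe checkpoint, not something this commit provides:
```
python main_linprobe.py --eval \
    --model vit_base_patch16 --cls_token \
    --resume ${LINPROBE_CHKPT} \
    --batch_size 512 \
    --data_path ${IMAGENET_DIR}
```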

@@ -55,7 +55,7 @@ class MaskedAutoencoderViT(nn.Module):
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
-        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # encoder to decoder
+        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss


submitit_linprobe.py
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# A script to run multinode training with submitit.
# --------------------------------------------------------

import argparse
import os
import uuid
from pathlib import Path

import main_linprobe as classification
import submitit


def parse_args():
    classification_parser = classification.get_args_parser()
    parser = argparse.ArgumentParser("Submitit for MAE linear probe", parents=[classification_parser])
    parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
    parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
    parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job")
    parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")

    parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit")
    parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs")
    parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")
    return parser.parse_args()


def get_shared_folder() -> Path:
    user = os.getenv("USER")
    if Path("/checkpoint/").is_dir():
        p = Path(f"/checkpoint/{user}/experiments")
        p.mkdir(exist_ok=True)
        return p
    raise RuntimeError("No shared folder available")


def get_init_file():
    # Init file must not exist, but its parent dir must exist.
    os.makedirs(str(get_shared_folder()), exist_ok=True)
    init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
    if init_file.exists():
        os.remove(str(init_file))
    return init_file

class Trainer(object):
    def __init__(self, args):
        self.args = args

    def __call__(self):
        import main_linprobe as classification

        self._setup_gpu_args()
        classification.main(self.args)

    def checkpoint(self):
        import os
        import submitit

        self.args.dist_url = get_init_file().as_uri()
        checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth")
        if os.path.exists(checkpoint_file):
            self.args.resume = checkpoint_file
        print("Requeuing ", self.args)
        empty_trainer = type(self)(self.args)
        return submitit.helpers.DelayedSubmission(empty_trainer)

    def _setup_gpu_args(self):
        import submitit
        from pathlib import Path

        job_env = submitit.JobEnvironment()
        self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
        self.args.log_dir = self.args.output_dir
        self.args.gpu = job_env.local_rank
        self.args.rank = job_env.global_rank
        self.args.world_size = job_env.num_tasks
        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")


def main():
    args = parse_args()
    if args.job_dir == "":
        args.job_dir = get_shared_folder() / "%j"

    # Note that the folder will depend on the job_id, to easily track experiments
    executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)

    num_gpus_per_node = args.ngpus
    nodes = args.nodes
    timeout_min = args.timeout

    partition = args.partition
    kwargs = {}
    if args.use_volta32:
        kwargs['slurm_constraint'] = 'volta32gb'
    if args.comment:
        kwargs['slurm_comment'] = args.comment

    executor.update_parameters(
        mem_gb=40 * num_gpus_per_node,
        gpus_per_node=num_gpus_per_node,
        tasks_per_node=num_gpus_per_node,  # one task per GPU
        cpus_per_task=10,
        nodes=nodes,
        timeout_min=timeout_min,
        # Below are cluster dependent parameters
        slurm_partition=partition,
        slurm_signal_delay_s=120,
        **kwargs
    )

    executor.update_parameters(name="mae")

    args.dist_url = get_init_file().as_uri()
    args.output_dir = args.job_dir

    trainer = Trainer(args)
    job = executor.submit(trainer)

    # print("Submitted job_id:", job.job_id)
    print(job.job_id)


if __name__ == "__main__":
    main()

util/crop.py
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch

from torchvision import transforms
from torchvision.transforms import functional as F


class RandomResizedCrop(transforms.RandomResizedCrop):
    """
    RandomResizedCrop for matching the TF/TPU implementation: no for-loop is used.
    This may lead to results different from torchvision's version.
    Following BYOL's TF code:
    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
    """
    @staticmethod
    def get_params(img, scale, ratio):
        width, height = F._get_image_size(img)
        area = height * width

        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
        log_ratio = torch.log(torch.tensor(ratio))
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        w = min(w, width)
        h = min(h, height)

        i = torch.randint(0, height - h + 1, size=(1,)).item()
        j = torch.randint(0, width - w + 1, size=(1,)).item()

        return i, j, h, w
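
For context, this class is consumed by `main_linprobe.py` above as a drop-in replacement for torchvision's crop. The snippet below simply restates that usage; the transform values are the ones from the linear-probe training pipeline above, not new choices:
```
import torchvision.transforms as transforms

from util.crop import RandomResizedCrop

# Weak augmentation for linear probing, as in main_linprobe.py above.
transform_train = transforms.Compose([
    RandomResizedCrop(224, interpolation=3),  # TF/TPU-style sampling, no rejection loop
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```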

util/lars.py
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# LARS optimizer, implementation from MoCo v3:
# https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------

import torch


class LARS(torch.optim.Optimizer):
    """
    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
    """
    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim > 1:  # if not normalization gamma/beta or bias
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['trust_coefficient'] * param_norm / update_norm), one),
                                    one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])
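
For reference, `main_linprobe.py` above builds this optimizer over the classifier head only. The sketch below restates that pattern on a toy head; the feature width, class count, and learning rate are illustrative assumptions (lr = 0.1 * 16384 / 256 = 6.4 mirrors the linear scaling rule in FINETUNE.md):
```
import torch

from util.lars import LARS

# Toy stand-in for the frozen-encoder + trainable-head setup in main_linprobe.py:
# parameter-free BatchNorm in front of the linear classifier.
head = torch.nn.Sequential(
    torch.nn.BatchNorm1d(768, affine=False, eps=1e-6),
    torch.nn.Linear(768, 1000),
)
optimizer = LARS(head.parameters(), lr=6.4, weight_decay=0.0)

features = torch.randn(32, 768)  # stand-in for frozen encoder outputs
loss = head(features).mean()
loss.backward()
optimizer.step()       # trust-ratio scaling applies only to the 2-D weight
optimizer.zero_grad()
```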