[CodeCamp2023-584]Support DINO self-supervised learning in project (#1756)
* feat: implement DINO
* chore: delete debug code
* chore: implement pre-commit
* fix: fix imported package
* chore: pre-commit check

pull/1782/head
parent 732b0f4c98
commit d2ccc44a2c
@@ -0,0 +1,26 @@
# Implementation for DINO

**NOTE**: We only guarantee the correctness of the forward pass; we are not responsible for a full reimplementation of the original training results.

First, ensure you are in the root directory of MMPretrain. You then have two ways to play with DINO in MMPretrain:

## Slurm

If you are using a cluster managed by Slurm, you can use the following command to start your job:

```shell
GPUS_PER_NODE=8 GPUS=8 CPUS_PER_TASK=16 bash projects/dino/tools/slurm_train.sh mm_model dino projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py --amp
```

The above command will pre-train the model on a single node with 8 GPUs.

## PyTorch

If you are using a single machine without any cluster management software, you can use the following command:

```shell
NNODES=1 bash projects/dino/tools/dist_train.sh projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py 8 --amp
```
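Before submitting a long job, it can be worth checking that the config parses. A minimal sketch using `mmengine.Config`, assuming it is run from the MMPretrain root directory (the config path is the one used in the commands above):

```python
# Sketch: sanity-check that the DINO config parses before launching a job.
from mmengine.config import Config

cfg = Config.fromfile(
    'projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py')
print(cfg.model.type)                   # DINO
print(cfg.train_dataloader.batch_size)  # 32
print(cfg.train_cfg.max_epochs)         # 100
```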
@@ -0,0 +1,104 @@
# model settings
model = dict(
    type='DINO',
    data_preprocessor=dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        type='mmpretrain.VisionTransformer', arch='b', patch_size=16),
    neck=dict(
        type='DINONeck',
        in_channels=768,
        out_channels=65536,
        hidden_channels=2048,
        bottleneck_channels=256),
    head=dict(
        type='DINOHead',
        out_channels=65536,
        num_crops=10,
        student_temp=0.1,
        center_momentum=0.9))

# dataset settings
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='DINOMultiCrop',
        global_crops_scale=(0.4, 1.0),
        local_crops_scale=(0.05, 0.4),
        local_crops_number=8),
    dict(type='PackInputs')
]
train_dataloader = dict(
    batch_size=32,
    num_workers=16,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        type='mmpretrain.ImageNet',
        data_root='/data/imagenet/',
        ann_file='meta/train.txt',
        data_prefix=dict(img_path='train/'),
        pipeline=train_pipeline,
    ))

# optimizer settings
# NOTE: the optimizer actually used for training is the one defined inside
# `optim_wrapper` below; this standalone `optimizer` dict is not read by the
# runner.
optimizer = dict(type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05)
optim_wrapper = dict(
    type='AmpOptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05),
    paramwise_cfg=dict(
        custom_keys=dict(
            ln=dict(decay_mult=0.0),
            bias=dict(decay_mult=0.0),
            pos_embed=dict(decay_mult=0.0),
            mask_token=dict(decay_mult=0.0),
            cls_token=dict(decay_mult=0.0))),
    loss_scale='dynamic')

# learning rate scheduler: 10 epochs of linear warmup, then cosine annealing
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-09,
        by_epoch=True,
        begin=0,
        end=10,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=90,
        by_epoch=True,
        begin=10,
        end=100,
        convert_to_iter_based=True)
]
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100)

# runtime settings
default_scope = 'mmpretrain'
default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1),
    sampler_seed=dict(type='DistSamplerSeedHook'))
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))
log_processor = dict(
    window_size=10,
    custom_cfg=[dict(data_src='', method='mean', window_size='global')])
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='UniversalVisualizer',
    vis_backends=[dict(type='LocalVisBackend')],
    name='visualizer')
log_level = 'INFO'
load_from = None
resume = True
randomness = dict(seed=2, diff_rank_seed=True)

# custom hooks: set the teacher temperature before every training epoch
custom_hooks = [
    dict(
        type='DINOTeacherTempWarmupHook',
        warmup_teacher_temp=0.04,
        teacher_temp=0.04,
        teacher_temp_warmup_epochs=0,
        max_epochs=100)
]
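The two schedulers above combine into a warmup-then-cosine learning-rate curve. The following is only an illustrative approximation of its shape at epoch granularity; MMEngine applies the schedule per iteration because of `convert_to_iter_based=True`, and `eta_min=0` is assumed (the `CosineAnnealingLR` default):

```python
# Illustrative sketch of the LR schedule defined above (epoch granularity).
# Not MMEngine's exact implementation.
import math

base_lr = 0.0024
warmup_epochs, max_epochs = 10, 100

def lr_at(epoch: int) -> float:
    if epoch < warmup_epochs:
        # LinearLR: factor ramps from 1e-9 to 1.0 over the first 10 epochs
        factor = 1e-9 + (1.0 - 1e-9) * epoch / warmup_epochs
        return base_lr * factor
    # CosineAnnealingLR: decay over the remaining 90 epochs (T_max=90)
    t = epoch - warmup_epochs
    return 0.5 * base_lr * (1 + math.cos(math.pi * t / (max_epochs - warmup_epochs)))

for e in (0, 5, 10, 50, 99):
    print(e, f'{lr_at(e):.6f}')
```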
@@ -0,0 +1 @@
from .transform import *  # noqa: F401,F403
@@ -0,0 +1,3 @@
from .processing import DINOMultiCrop

__all__ = ['DINOMultiCrop']
@@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random

from mmcv.transforms import RandomApply  # noqa: E501
from mmcv.transforms import BaseTransform, Compose, RandomFlip, RandomGrayscale

from mmpretrain.datasets.transforms import (ColorJitter, GaussianBlur,
                                             RandomResizedCrop, Solarize)
from mmpretrain.registry import TRANSFORMS


@TRANSFORMS.register_module()
class DINOMultiCrop(BaseTransform):
    """Multi-crop transform for DINO.

    This module applies the multi-crop transform for DINO: two global crops
    and several local crops are generated from each input image.

    Args:
        global_crops_scale (tuple): Scale range of the global crops.
        local_crops_scale (tuple): Scale range of the local crops.
        local_crops_number (int): Number of local crops.
    """

    def __init__(self, global_crops_scale: tuple, local_crops_scale: tuple,
                 local_crops_number: int) -> None:
        super().__init__()
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale

        flip_and_color_jitter = Compose([
            RandomFlip(prob=0.5, direction='horizontal'),
            RandomApply([
                ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
            ],
                        prob=0.8),
            RandomGrayscale(
                prob=0.2,
                keep_channels=True,
                channel_weights=(0.114, 0.587, 0.2989),
            )
        ])

        # NOTE: the blur radius is sampled once at construction time and then
        # reused for every image.
        self.global_transform_1 = Compose([
            RandomResizedCrop(
                224,
                crop_ratio_range=global_crops_scale,
                interpolation='bicubic'),
            flip_and_color_jitter,
            GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)),
        ])

        self.global_transform_2 = Compose([
            RandomResizedCrop(
                224,
                crop_ratio_range=global_crops_scale,
                interpolation='bicubic'),
            flip_and_color_jitter,
            GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)),
            Solarize(thr=128, prob=0.2),
        ])

        self.local_crops_number = local_crops_number
        self.local_transform = Compose([
            RandomResizedCrop(
                96,
                crop_ratio_range=local_crops_scale,
                interpolation='bicubic'),
            flip_and_color_jitter,
            GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)),
        ])

    def transform(self, results: dict) -> dict:
        ori_img = results['img']
        crops = []
        results['img'] = ori_img
        crops.append(self.global_transform_1(results)['img'])
        results['img'] = ori_img
        crops.append(self.global_transform_2(results)['img'])
        for _ in range(self.local_crops_number):
            results['img'] = ori_img
            crops.append(self.local_transform(results)['img'])
        results['img'] = crops
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(global_crops_scale = {self.global_crops_scale}, '
        repr_str += f'local_crops_scale = {self.local_crops_scale}, '
        repr_str += f'local_crops_number = {self.local_crops_number})'
        return repr_str
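A minimal sketch of building the transform through the registry, assuming the project modules have been imported so that `DINOMultiCrop` is registered (for example via `from dataset import *`, as the bundled `tools/train.py` does). The parameter values mirror the config above:

```python
# Sketch: build DINOMultiCrop from the registry and inspect it.
from mmpretrain.registry import TRANSFORMS

multi_crop = TRANSFORMS.build(
    dict(
        type='DINOMultiCrop',
        global_crops_scale=(0.4, 1.0),
        local_crops_scale=(0.05, 0.4),
        local_crops_number=8))
print(multi_crop)  # uses the __repr__ defined above

# Applied inside the pipeline, results['img'] becomes a list of
# 2 global crops (224x224) + 8 local crops (96x96).
```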
@@ -0,0 +1 @@
from .hooks import *  # noqa
@@ -0,0 +1,3 @@
from .dino_teacher_temp_warmup_hook import DINOTeacherTempWarmupHook

__all__ = ['DINOTeacherTempWarmupHook']
@@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmengine.hooks import Hook

from mmpretrain.registry import HOOKS


@HOOKS.register_module()
class DINOTeacherTempWarmupHook(Hook):
    """Warmup teacher temperature for DINO.

    This hook warms up the teacher temperature to stabilize the training
    process.

    Args:
        warmup_teacher_temp (float): Initial (warmup) temperature for the
            teacher.
        teacher_temp (float): Final temperature for the teacher.
        teacher_temp_warmup_epochs (int): Number of warmup epochs for the
            teacher temperature.
        max_epochs (int): Maximum number of training epochs.
    """

    def __init__(self, warmup_teacher_temp: float, teacher_temp: float,
                 teacher_temp_warmup_epochs: int, max_epochs: int) -> None:
        super().__init__()
        # linearly increase the temperature during warmup, then keep it fixed
        self.teacher_temps = np.concatenate(
            (np.linspace(warmup_teacher_temp, teacher_temp,
                         teacher_temp_warmup_epochs),
             np.ones(max_epochs - teacher_temp_warmup_epochs) * teacher_temp))

    def before_train_epoch(self, runner) -> None:
        runner.model.module.head.teacher_temp = self.teacher_temps[
            runner.epoch]
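The resulting per-epoch schedule is easy to inspect in isolation. A small sketch reproducing the `np.linspace`/`np.concatenate` logic with an illustrative non-zero warmup (the bundled config uses `teacher_temp_warmup_epochs=0`, i.e. a constant temperature of 0.04):

```python
# Sketch: the per-epoch teacher temperature schedule built in __init__.
import numpy as np

warmup_teacher_temp, teacher_temp = 0.04, 0.07   # illustrative values
teacher_temp_warmup_epochs, max_epochs = 30, 100

teacher_temps = np.concatenate(
    (np.linspace(warmup_teacher_temp, teacher_temp,
                 teacher_temp_warmup_epochs),
     np.ones(max_epochs - teacher_temp_warmup_epochs) * teacher_temp))

print(teacher_temps[0], teacher_temps[15], teacher_temps[29])  # ramping up
print(teacher_temps[30:].min(), teacher_temps[30:].max())      # constant 0.07
```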
@@ -0,0 +1,3 @@
from .algorithm import *  # noqa
from .head import *  # noqa
from .neck import *  # noqa
@@ -0,0 +1,3 @@
from .dino import DINO

__all__ = ['DINO']
@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
from torch import nn

from mmpretrain.models import BaseSelfSupervisor, CosineEMA
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample


@MODELS.register_module()
class DINO(BaseSelfSupervisor):
    """Implementation for DINO.

    This module is proposed in `DINO: Emerging Properties in Self-Supervised
    Vision Transformers <https://arxiv.org/abs/2104.14294>`_.

    Args:
        backbone (dict): Config for backbone.
        neck (dict): Config for neck.
        head (dict): Config for head.
        pretrained (str, optional): Path for pretrained model.
            Defaults to None.
        base_momentum (float, optional): Base momentum for momentum update.
            Defaults to 0.99.
        data_preprocessor (dict, optional): Config for data preprocessor.
            Defaults to None.
        init_cfg (list[dict] | dict, optional): Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 pretrained: Optional[str] = None,
                 base_momentum: float = 0.99,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # create momentum model
        self.teacher = CosineEMA(
            nn.Sequential(self.backbone, self.neck), momentum=base_momentum)
        # weight normalization layer
        self.neck.last_layer = nn.utils.weight_norm(self.neck.last_layer)
        self.neck.last_layer.weight_g.data.fill_(1)
        self.neck.last_layer.weight_g.requires_grad = False
        self.teacher.module[1].last_layer = nn.utils.weight_norm(
            self.teacher.module[1].last_layer)
        self.teacher.module[1].last_layer.weight_g.data.fill_(1)
        self.teacher.module[1].last_layer.weight_g.requires_grad = False

    def loss(self, inputs: torch.Tensor,
             data_samples: List[DataSample]) -> dict:
        global_crops = torch.cat(inputs[:2])
        local_crops = torch.cat(inputs[2:])
        # teacher forward
        teacher_output = self.teacher(global_crops)

        # student forward global
        student_output_global = self.backbone(global_crops)
        student_output_global = self.neck(student_output_global)

        # student forward local
        student_output_local = self.backbone(local_crops)
        student_output_local = self.neck(student_output_local)

        student_output = torch.cat(
            (student_output_global, student_output_local))

        # compute loss
        loss = self.head(student_output, teacher_output)

        return dict(loss=loss)
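As the indexing in `loss()` shows, `inputs` is expected to be a list with one batched tensor per crop: the two global crops first, then the local crops, which is what the multi-crop pipeline produces. A shape-only sketch with dummy tensors (standing in for preprocessed crops):

```python
# Shape-only sketch of how loss() splits and concatenates the crop batches.
import torch

batch_size = 4
inputs = [torch.randn(batch_size, 3, 224, 224) for _ in range(2)] + \
         [torch.randn(batch_size, 3, 96, 96) for _ in range(8)]

global_crops = torch.cat(inputs[:2])   # (8, 3, 224, 224) -> teacher and student
local_crops = torch.cat(inputs[2:])    # (32, 3, 96, 96)  -> student only
print(global_crops.shape, local_crops.shape)
```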
@@ -0,0 +1,3 @@
from .dino_head import DINOHead

__all__ = ['DINOHead']
@@ -0,0 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmengine.dist import all_reduce, get_world_size
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class DINOHead(BaseModule):
    """Implementation for DINO head.

    This module is proposed in `DINO: Emerging Properties in Self-Supervised
    Vision Transformers <https://arxiv.org/abs/2104.14294>`_.

    Args:
        out_channels (int): Output channels of the head.
        num_crops (int): Number of crops.
        student_temp (float): Temperature for student output.
        center_momentum (float): Momentum for center update.
    """

    def __init__(self, out_channels: int, num_crops: int, student_temp: float,
                 center_momentum: float) -> None:
        super().__init__()
        self.student_temp = student_temp
        # the teacher temperature is set by DINOTeacherTempWarmupHook before
        # each training epoch
        self.teacher_temp = 0
        self.center_momentum = center_momentum
        self.num_crops = num_crops
        self.register_buffer('center', torch.zeros(1, out_channels))

    def forward(self, student_output: torch.Tensor,
                teacher_output: torch.Tensor) -> torch.Tensor:
        current_teacher_output = teacher_output
        student_output = student_output / self.student_temp
        student_output = student_output.chunk(self.num_crops, dim=0)

        # teacher centering and sharpening
        teacher_output = F.softmax(
            (teacher_output - self.center) / self.teacher_temp, dim=-1)
        teacher_output = teacher_output.detach().chunk(2, dim=0)

        total_loss = 0
        n_loss_terms = 0

        # cross-entropy between each teacher crop and every other student crop
        for i in range(len(teacher_output)):
            for j in range(len(student_output)):
                if i == j:
                    continue
                total_loss += (-teacher_output[i] *
                               student_output[j].log_softmax(dim=-1)).sum(
                                   dim=-1).mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(current_teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output: torch.Tensor) -> None:
        # average the teacher output over the global batch
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        all_reduce(batch_center)
        batch_center = batch_center / (len(teacher_output) * get_world_size())

        # ema update batch center
        self.center = self.center * self.center_momentum + batch_center * (
            1 - self.center_momentum)
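The head can be exercised on its own with random features. A minimal sketch, using small dimensions for readability, setting `teacher_temp` by hand (it is normally assigned by `DINOTeacherTempWarmupHook`), and assuming the project root is on `PYTHONPATH` so `models.head` is importable, as the bundled `tools/train.py` assumes; outside a distributed run `all_reduce`/`get_world_size` fall back to single-process behaviour:

```python
# Sketch: run DINOHead standalone with random student/teacher features.
import torch
from models.head import DINOHead

batch, dim, num_crops = 4, 16, 10
head = DINOHead(
    out_channels=dim, num_crops=num_crops, student_temp=0.1,
    center_momentum=0.9)
head.teacher_temp = 0.04  # normally set by DINOTeacherTempWarmupHook

student = torch.randn(num_crops * batch, dim)  # all 10 crops, stacked
teacher = torch.randn(2 * batch, dim)          # 2 global crops only
loss = head(student, teacher)
print(loss)  # scalar tensor; head.center has also been EMA-updated
```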
@@ -0,0 +1,3 @@
from .dino_neck import DINONeck

__all__ = ['DINONeck']
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule
from torch import nn

from mmpretrain.registry import MODELS


@MODELS.register_module()
class DINONeck(BaseModule):
    """Implementation for DINO neck.

    This module is proposed in `DINO: Emerging Properties in Self-Supervised
    Vision Transformers <https://arxiv.org/abs/2104.14294>`_.

    Args:
        in_channels (int): Input channels.
        hidden_channels (int): Hidden channels.
        out_channels (int): Output channels.
        bottleneck_channels (int): Bottleneck channels.
    """

    def __init__(self, in_channels: int, hidden_channels: int,
                 out_channels: int, bottleneck_channels: int) -> None:
        super().__init__()
        self.mlp = nn.Sequential(*[
            nn.Linear(in_channels, hidden_channels),
            nn.GELU(),
            nn.Linear(hidden_channels, hidden_channels),
            nn.GELU(),
            nn.Linear(hidden_channels, bottleneck_channels),
        ])

        self.last_layer = nn.Linear(
            bottleneck_channels, out_channels, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the backbone returns a tuple of features; take its first element
        x = self.mlp(x[0])
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
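A quick shape check for the neck. Note that `forward` indexes `x[0]`, so it expects the tuple a backbone returns; small channel sizes are used here instead of the 768/2048/65536/256 from the config, and the project root is assumed to be on `PYTHONPATH` so `models.neck` is importable:

```python
# Sketch: check DINONeck input/output shapes with small channel sizes.
import torch
from models.neck import DINONeck

neck = DINONeck(
    in_channels=8, hidden_channels=16, out_channels=32, bottleneck_channels=4)

feats = (torch.randn(2, 8), )   # a backbone returns a tuple of features
out = neck(feats)
print(out.shape)                # torch.Size([2, 32])
```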
@@ -0,0 +1,19 @@
#!/usr/bin/env bash

CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    $(dirname "$0")/train.py \
    $CONFIG \
    --launcher pytorch ${@:3}
@@ -0,0 +1,23 @@
#!/usr/bin/env bash

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
PY_ARGS=${@:4}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u projects/dino/tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS}
@@ -0,0 +1,104 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp

from dataset import *  # noqa: F401,F403
from engine import *  # noqa: F401,F403
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
from models.algorithm import *  # noqa: F401,F403
from models.head import *  # noqa: F401,F403
from models.neck import *  # noqa: F401,F403

from mmpretrain.utils import register_all_modules


def parse_args():
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume',
        nargs='?',
        type=str,
        const='auto',
        help='If a checkpoint path is specified, resume from it; otherwise '
        'try to auto-resume from the latest checkpoint in the work '
        'directory.')
    parser.add_argument(
        '--amp',
        action='store_true',
        help='enable automatic-mixed-precision training')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args


def main():
    args = parse_args()

    # register all modules in mmpretrain into the registries
    # do not init the default scope here because it will be init in the runner
    register_all_modules(init_default_scope=False)

    # load config
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        work_type = args.config.split('/')[1]
        cfg.work_dir = osp.join('./work_dirs', work_type,
                                osp.splitext(osp.basename(args.config))[0])

    # enable automatic-mixed-precision training
    if args.amp is True:
        optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper')
        assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \
            '`--amp` is not supported by custom optimizer wrapper type ' \
            f'`{optim_wrapper}`.'
        cfg.optim_wrapper.type = 'AmpOptimWrapper'
        cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')

    # resume training
    if args.resume == 'auto':
        cfg.resume = True
        cfg.load_from = None
    elif args.resume is not None:
        cfg.resume = True
        cfg.load_from = args.resume

    # build the runner from config
    runner = Runner.from_cfg(cfg)

    # start training
    runner.train()


if __name__ == '__main__':
    main()
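Values passed through `--cfg-options` are parsed by `DictAction` and end up in `cfg.merge_from_dict`. A small sketch of the equivalent call, using the config path from the README and an illustrative data root:

```python
# Sketch: what --cfg-options amounts to after DictAction parsing.
from mmengine.config import Config

cfg = Config.fromfile(
    'projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py')
cfg.merge_from_dict({
    'train_dataloader.batch_size': 64,
    'train_dataloader.dataset.data_root': '/path/to/imagenet/',  # illustrative
})
print(cfg.train_dataloader.batch_size)  # 64
```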