From d2ccc44a2c8e5d49bb26187aff42f2abc90aee28 Mon Sep 17 00:00:00 2001 From: LALBJ <40877073+LALBJ@users.noreply.github.com> Date: Wed, 23 Aug 2023 10:45:18 +0800 Subject: [PATCH] [CodeCamp2023-584]Support DINO self-supervised learning in project (#1756) * feat: impelemt DINO * chore: delete debug code * chore: impplement pre-commit * fix: fix imported package * chore: pre-commit check --- projects/dino/README.md | 26 +++++ ..._vit-base-p16_8xb64-amp-coslr-100e_in1k.py | 104 ++++++++++++++++++ projects/dino/dataset/__init__.py | 1 + projects/dino/dataset/transform/__init__.py | 3 + projects/dino/dataset/transform/processing.py | 91 +++++++++++++++ projects/dino/engine/__init__.py | 1 + projects/dino/engine/hooks/__init__.py | 3 + .../hooks/dino_teacher_temp_warmup_hook.py | 33 ++++++ projects/dino/models/__init__.py | 3 + projects/dino/models/algorithm/__init__.py | 3 + projects/dino/models/algorithm/dino.py | 82 ++++++++++++++ projects/dino/models/head/__init__.py | 3 + projects/dino/models/head/dino_head.py | 69 ++++++++++++ projects/dino/models/neck/__init__.py | 3 + projects/dino/models/neck/dino_neck.py | 41 +++++++ projects/dino/tools/dist_train.sh | 19 ++++ projects/dino/tools/slurm_train.sh | 23 ++++ projects/dino/tools/train.py | 104 ++++++++++++++++++ 18 files changed, 612 insertions(+) create mode 100644 projects/dino/README.md create mode 100644 projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py create mode 100644 projects/dino/dataset/__init__.py create mode 100644 projects/dino/dataset/transform/__init__.py create mode 100644 projects/dino/dataset/transform/processing.py create mode 100644 projects/dino/engine/__init__.py create mode 100644 projects/dino/engine/hooks/__init__.py create mode 100644 projects/dino/engine/hooks/dino_teacher_temp_warmup_hook.py create mode 100644 projects/dino/models/__init__.py create mode 100644 projects/dino/models/algorithm/__init__.py create mode 100644 projects/dino/models/algorithm/dino.py create mode 100644 projects/dino/models/head/__init__.py create mode 100644 projects/dino/models/head/dino_head.py create mode 100644 projects/dino/models/neck/__init__.py create mode 100644 projects/dino/models/neck/dino_neck.py create mode 100644 projects/dino/tools/dist_train.sh create mode 100644 projects/dino/tools/slurm_train.sh create mode 100644 projects/dino/tools/train.py diff --git a/projects/dino/README.md b/projects/dino/README.md new file mode 100644 index 00000000..3458fa4c --- /dev/null +++ b/projects/dino/README.md @@ -0,0 +1,26 @@ +# Implementation for DINO + +**NOTE**: We only guarantee correctness of the forward pass, not responsible for full reimplementation. + +First, ensure you are in the root directory of MMPretrain, then you have two choices +to play with DINO in MMPretrain: + +## Slurm + +If you are using a cluster managed by Slurm, you can use the following command to +start your job: + +```shell +GPUS_PER_NODE=8 GPUS=8 CPUS_PER_TASK=16 bash projects/dino/tools/slurm_train.sh mm_model dino projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py --amp +``` + +The above command will pre-train the model on a single node with 8 GPUs. + +## PyTorch + +If you are using a single machine, without any cluster management software, you can use the following command + +```shell +NNODES=1 bash projects/dino/tools/dist_train.sh projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py 8 +--amp +``` diff --git a/projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py b/projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py new file mode 100644 index 00000000..d4a1c240 --- /dev/null +++ b/projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py @@ -0,0 +1,104 @@ +model = dict( + type='DINO', + data_preprocessor=dict( + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + type='mmpretrain.VisionTransformer', arch='b', patch_size=16), + neck=dict( + type='DINONeck', + in_channels=768, + out_channels=65536, + hidden_channels=2048, + bottleneck_channels=256), + head=dict( + type='DINOHead', + out_channels=65536, + num_crops=10, + student_temp=0.1, + center_momentum=0.9)) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='DINOMultiCrop', + global_crops_scale=(0.4, 1.0), + local_crops_scale=(0.05, 0.4), + local_crops_number=8), + dict(type='PackInputs') +] +train_dataloader = dict( + batch_size=32, + num_workers=16, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate'), + dataset=dict( + type='mmpretrain.ImageNet', + data_root='/data/imagenet/', + ann_file='meta/train.txt', + data_prefix=dict(img_path='train/'), + pipeline=train_pipeline, + )) +optimizer = dict(type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05) +optim_wrapper = dict( + type='AmpOptimWrapper', + optimizer=dict( + type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05), + paramwise_cfg=dict( + custom_keys=dict( + ln=dict(decay_mult=0.0), + bias=dict(decay_mult=0.0), + pos_embed=dict(decay_mult=0.0), + mask_token=dict(decay_mult=0.0), + cls_token=dict(decay_mult=0.0))), + loss_scale='dynamic') +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-09, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=90, + by_epoch=True, + begin=10, + end=100, + convert_to_iter_based=True) +] +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100) +default_scope = 'mmpretrain' +default_hooks = dict( + runtime_info=dict(type='RuntimeInfoHook'), + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=100), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1), + sampler_seed=dict(type='DistSamplerSeedHook')) +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl')) +log_processor = dict( + window_size=10, + custom_cfg=[dict(data_src='', method='mean', window_size='global')]) +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='UniversalVisualizer', + vis_backends=[dict(type='LocalVisBackend')], + name='visualizer') +log_level = 'INFO' +load_from = None +resume = True +randomness = dict(seed=2, diff_rank_seed=True) +custom_hooks = [ + dict( + type='DINOTeacherTempWarmupHook', + warmup_teacher_temp=0.04, + teacher_temp=0.04, + teacher_temp_warmup_epochs=0, + max_epochs=100) +] diff --git a/projects/dino/dataset/__init__.py b/projects/dino/dataset/__init__.py new file mode 100644 index 00000000..da65f285 --- /dev/null +++ b/projects/dino/dataset/__init__.py @@ -0,0 +1 @@ +from .transform import * # noqa: F401,F403 diff --git a/projects/dino/dataset/transform/__init__.py b/projects/dino/dataset/transform/__init__.py new file mode 100644 index 00000000..00dacb3f --- /dev/null +++ b/projects/dino/dataset/transform/__init__.py @@ -0,0 +1,3 @@ +from .processing import DINOMultiCrop + +__all__ = ['DINOMultiCrop'] diff --git a/projects/dino/dataset/transform/processing.py b/projects/dino/dataset/transform/processing.py new file mode 100644 index 00000000..df4bf0be --- /dev/null +++ b/projects/dino/dataset/transform/processing.py @@ -0,0 +1,91 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random + +from mmcv.transforms import RandomApply # noqa: E501 +from mmcv.transforms import BaseTransform, Compose, RandomFlip, RandomGrayscale + +from mmpretrain.datasets.transforms import (ColorJitter, GaussianBlur, + RandomResizedCrop, Solarize) +from mmpretrain.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class DINOMultiCrop(BaseTransform): + """Multi-crop transform for DINO. + + This module applies the multi-crop transform for DINO. + + Args: + global_crops_scale (int): Scale of global crops. + local_crops_scale (int): Scale of local crops. + local_crops_number (int): Number of local crops. + """ + + def __init__(self, global_crops_scale: int, local_crops_scale: int, + local_crops_number: int) -> None: + super().__init__() + self.global_crops_scale = global_crops_scale + self.local_crops_scale = local_crops_scale + + flip_and_color_jitter = Compose([ + RandomFlip(prob=0.5, direction='horizontal'), + RandomApply([ + ColorJitter( + brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1) + ], + prob=0.8), + RandomGrayscale( + prob=0.2, + keep_channels=True, + channel_weights=(0.114, 0.587, 0.2989), + ) + ]) + + self.global_transform_1 = Compose([ + RandomResizedCrop( + 224, + crop_ratio_range=global_crops_scale, + interpolation='bicubic'), + flip_and_color_jitter, + GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)), + ]) + + self.global_transform_2 = Compose([ + RandomResizedCrop( + 224, + crop_ratio_range=global_crops_scale, + interpolation='bicubic'), + flip_and_color_jitter, + GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)), + Solarize(thr=128, prob=0.2), + ]) + + self.local_crops_number = local_crops_number + self.local_transform = Compose([ + RandomResizedCrop( + 96, + crop_ratio_range=local_crops_scale, + interpolation='bicubic'), + flip_and_color_jitter, + GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)), + ]) + + def transform(self, results: dict) -> dict: + ori_img = results['img'] + crops = [] + results['img'] = ori_img + crops.append(self.global_transform_1(results)['img']) + results['img'] = ori_img + crops.append(self.global_transform_2(results)['img']) + for _ in range(self.local_crops_number): + results['img'] = ori_img + crops.append(self.local_transform(results)['img']) + results['img'] = crops + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(global_crops_scale = {self.global_crops_scale}, ' + repr_str += f'local_crops_scale = {self.local_crops_scale}, ' + repr_str += f'local_crop_number = {self.local_crops_number})' + return repr_str diff --git a/projects/dino/engine/__init__.py b/projects/dino/engine/__init__.py new file mode 100644 index 00000000..41422545 --- /dev/null +++ b/projects/dino/engine/__init__.py @@ -0,0 +1 @@ +from .hooks import * # noqa diff --git a/projects/dino/engine/hooks/__init__.py b/projects/dino/engine/hooks/__init__.py new file mode 100644 index 00000000..df43c492 --- /dev/null +++ b/projects/dino/engine/hooks/__init__.py @@ -0,0 +1,3 @@ +from .dino_teacher_temp_warmup_hook import DINOTeacherTempWarmupHook + +__all__ = ['DINOTeacherTempWarmupHook'] diff --git a/projects/dino/engine/hooks/dino_teacher_temp_warmup_hook.py b/projects/dino/engine/hooks/dino_teacher_temp_warmup_hook.py new file mode 100644 index 00000000..d66b0250 --- /dev/null +++ b/projects/dino/engine/hooks/dino_teacher_temp_warmup_hook.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from mmengine.hooks import Hook + +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class DINOTeacherTempWarmupHook(Hook): + """Warmup teacher temperature for DINO. + + This hook warmups the temperature for teacher to stabilize the training + process. + + Args: + warmup_teacher_temp (float): Warmup temperature for teacher. + teacher_temp (float): Temperature for teacher. + teacher_temp_warmup_epochs (int): Warmup epochs for teacher + temperature. + max_epochs (int): Maximum epochs for training. + """ + + def __init__(self, warmup_teacher_temp: float, teacher_temp: float, + teacher_temp_warmup_epochs: int, max_epochs: int) -> None: + super().__init__() + self.teacher_temps = np.concatenate( + (np.linspace(warmup_teacher_temp, teacher_temp, + teacher_temp_warmup_epochs), + np.ones(max_epochs - teacher_temp_warmup_epochs) * teacher_temp)) + + def before_train_epoch(self, runner) -> None: + runner.model.module.head.teacher_temp = self.teacher_temps[ + runner.epoch] diff --git a/projects/dino/models/__init__.py b/projects/dino/models/__init__.py new file mode 100644 index 00000000..49d01487 --- /dev/null +++ b/projects/dino/models/__init__.py @@ -0,0 +1,3 @@ +from .algorithm import * # noqa +from .head import * # noqa +from .neck import * # noqa diff --git a/projects/dino/models/algorithm/__init__.py b/projects/dino/models/algorithm/__init__.py new file mode 100644 index 00000000..1125b63f --- /dev/null +++ b/projects/dino/models/algorithm/__init__.py @@ -0,0 +1,3 @@ +from .dino import DINO + +__all__ = ['DINO'] diff --git a/projects/dino/models/algorithm/dino.py b/projects/dino/models/algorithm/dino.py new file mode 100644 index 00000000..2d78922f --- /dev/null +++ b/projects/dino/models/algorithm/dino.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +from torch import nn + +from mmpretrain.models import BaseSelfSupervisor, CosineEMA +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class DINO(BaseSelfSupervisor): + """Implementation for DINO. + + This module is proposed in `DINO: Emerging Properties in Self-Supervised + Vision Transformers `_. + + Args: + backbone (dict): Config for backbone. + neck (dict): Config for neck. + head (dict): Config for head. + pretrained (str, optional): Path for pretrained model. + Defaults to None. + base_momentum (float, optional): Base momentum for momentum update. + Defaults to 0.99. + data_preprocessor (dict, optional): Config for data preprocessor. + Defaults to None. + init_cfg (list[dict] | dict, optional): Config for initialization. + Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + pretrained: Optional[str] = None, + base_momentum: float = 0.99, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # create momentum model + self.teacher = CosineEMA( + nn.Sequential(self.backbone, self.neck), momentum=base_momentum) + # weight normalization layer + self.neck.last_layer = nn.utils.weight_norm(self.neck.last_layer) + self.neck.last_layer.weight_g.data.fill_(1) + self.neck.last_layer.weight_g.requires_grad = False + self.teacher.module[1].last_layer = nn.utils.weight_norm( + self.teacher.module[1].last_layer) + self.teacher.module[1].last_layer.weight_g.data.fill_(1) + self.teacher.module[1].last_layer.weight_g.requires_grad = False + + def loss(self, inputs: torch.Tensor, + data_samples: List[DataSample]) -> dict: + global_crops = torch.cat(inputs[:2]) + local_crops = torch.cat(inputs[2:]) + # teacher forward + teacher_output = self.teacher(global_crops) + + # student forward global + student_output_global = self.backbone(global_crops) + student_output_global = self.neck(student_output_global) + + # student forward local + student_output_local = self.backbone(local_crops) + student_output_local = self.neck(student_output_local) + + student_output = torch.cat( + (student_output_global, student_output_local)) + + # compute loss + loss = self.head(student_output, teacher_output) + + return dict(loss=loss) diff --git a/projects/dino/models/head/__init__.py b/projects/dino/models/head/__init__.py new file mode 100644 index 00000000..fe31e084 --- /dev/null +++ b/projects/dino/models/head/__init__.py @@ -0,0 +1,3 @@ +from .dino_head import DINOHead + +__all__ = ['DINOHead'] diff --git a/projects/dino/models/head/dino_head.py b/projects/dino/models/head/dino_head.py new file mode 100644 index 00000000..e817bfad --- /dev/null +++ b/projects/dino/models/head/dino_head.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F +from mmengine.dist import all_reduce, get_world_size +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class DINOHead(BaseModule): + """Implementation for DINO head. + + This module is proposed in `DINO: Emerging Properties in Self-Supervised + Vision Transformers `_. + + Args: + out_channels (int): Output channels of the head. + num_crops (int): Number of crops. + student_temp (float): Temperature for student output. + center_momentum (float): Momentum for center update. + """ + + def __init__(self, out_channels: int, num_crops: int, student_temp: float, + center_momentum: float) -> None: + super().__init__() + self.student_temp = student_temp + self.teacher_temp = 0 + self.center_momentum = center_momentum + self.num_crops = num_crops + self.register_buffer('center', torch.zeros(1, out_channels)) + + def forward(self, student_output: torch.Tensor, + teacher_output: torch.Tensor) -> torch.Tensor: + + current_teacher_output = teacher_output + student_output = student_output / self.student_temp + student_output = student_output.chunk(self.num_crops, dim=0) + + # teacher centering and sharpening + teacher_output = F.softmax( + (teacher_output - self.center) / self.teacher_temp, dim=-1) + teacher_output = teacher_output.detach().chunk(2, dim=0) + + total_loss = 0 + n_loss_terms = 0 + + for i in range(len(teacher_output)): + for j in range(len(student_output)): + if i == j: + continue + total_loss += (-teacher_output[i] * + student_output[j].log_softmax(dim=-1)).sum( + dim=-1).mean() + n_loss_terms += 1 + total_loss /= n_loss_terms + self.update_center(current_teacher_output) + return total_loss + + @torch.no_grad() + def update_center(self, teacher_output: torch.Tensor) -> None: + + batch_center = torch.sum(teacher_output, dim=0, keepdim=True) + all_reduce(batch_center) + batch_center = batch_center / (len(teacher_output) * get_world_size()) + + # ema update batch center + self.center = self.center * self.center_momentum + batch_center * ( + 1 - self.center_momentum) diff --git a/projects/dino/models/neck/__init__.py b/projects/dino/models/neck/__init__.py new file mode 100644 index 00000000..e5f4aadb --- /dev/null +++ b/projects/dino/models/neck/__init__.py @@ -0,0 +1,3 @@ +from .dino_neck import DINONeck + +__all__ = ['DINONeck'] diff --git a/projects/dino/models/neck/dino_neck.py b/projects/dino/models/neck/dino_neck.py new file mode 100644 index 00000000..8d8881ea --- /dev/null +++ b/projects/dino/models/neck/dino_neck.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class DINONeck(BaseModule): + """Implementation for DINO neck. + + This module is proposed in `DINO: Emerging Properties in Self-Supervised + Vision Transformers `_. + + Args: + in_channels (int): Input channels. + hidden_channels (int): Hidden channels. + out_channels (int): Output channels. + bottleneck_channels (int): Bottleneck channels. + """ + + def __init__(self, in_channels: int, hidden_channels: int, + out_channels: int, bottleneck_channels: int) -> None: + super().__init__() + self.mlp = nn.Sequential(*[ + nn.Linear(in_channels, hidden_channels), + nn.GELU(), + nn.Linear(hidden_channels, hidden_channels), + nn.GELU(), + nn.Linear(hidden_channels, bottleneck_channels), + ]) + + self.last_layer = nn.Linear( + bottleneck_channels, out_channels, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(x[0]) + x = nn.functional.normalize(x, dim=-1, p=2) + x = self.last_layer(x) + return x diff --git a/projects/dino/tools/dist_train.sh b/projects/dino/tools/dist_train.sh new file mode 100644 index 00000000..3fca7641 --- /dev/null +++ b/projects/dino/tools/dist_train.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +CONFIG=$1 +GPUS=$2 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/train.py \ + $CONFIG \ + --launcher pytorch ${@:3} diff --git a/projects/dino/tools/slurm_train.sh b/projects/dino/tools/slurm_train.sh new file mode 100644 index 00000000..7e2ad297 --- /dev/null +++ b/projects/dino/tools/slurm_train.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +set -x + +PARTITION=$1 +JOB_NAME=$2 +CONFIG=$3 +GPUS=${GPUS:-8} +GPUS_PER_NODE=${GPUS_PER_NODE:-8} +CPUS_PER_TASK=${CPUS_PER_TASK:-5} +SRUN_ARGS=${SRUN_ARGS:-""} +PY_ARGS=${@:4} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u projects/dino/tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS} diff --git a/projects/dino/tools/train.py b/projects/dino/tools/train.py new file mode 100644 index 00000000..b9482c3b --- /dev/null +++ b/projects/dino/tools/train.py @@ -0,0 +1,104 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp + +from dataset import * # noqa: F401,F403 +from engine import * # noqa: F401,F403 +from mmengine.config import Config, DictAction +from mmengine.runner import Runner +from models.algorithm import * # noqa: F401,F403 +from models.head import * # noqa: F401,F403 +from models.neck import * # noqa: F401,F403 + +from mmpretrain.utils import register_all_modules + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a model') + parser.add_argument('config', help='train config file path') + parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument( + '--resume', + nargs='?', + type=str, + const='auto', + help='If specify checkpint path, resume from it, while if not ' + 'specify, try to auto resume from the latest checkpoint ' + 'in the work directory.') + parser.add_argument( + '--amp', + action='store_true', + help='enable automatic-mixed-precision training') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + +def main(): + args = parse_args() + + # register all modules in mmpretrain into the registries + # do not init the default scope here because it will be init in the runner + register_all_modules(init_default_scope=False) + + # load config + cfg = Config.fromfile(args.config) + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + work_type = args.config.split('/')[1] + cfg.work_dir = osp.join('./work_dirs', work_type, + osp.splitext(osp.basename(args.config))[0]) + + # enable automatic-mixed-precision training + if args.amp is True: + optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper') + assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \ + '`--amp` is not supported custom optimizer wrapper type ' \ + f'`{optim_wrapper}.' + cfg.optim_wrapper.type = 'AmpOptimWrapper' + cfg.optim_wrapper.setdefault('loss_scale', 'dynamic') + + # resume training + if args.resume == 'auto': + cfg.resume = True + cfg.load_from = None + elif args.resume is not None: + cfg.resume = True + cfg.load_from = args.resume + + # build the runner from config + runner = Runner.from_cfg(cfg) + + # start training + runner.train() + + +if __name__ == '__main__': + main()