[CodeCamp2023-584]Support DINO self-supervised learning in project (#1756)

* feat: implement DINO

* chore: delete debug code

* chore: implement pre-commit

* fix: fix imported package

* chore: pre-commit check
pull/1782/head
LALBJ 2023-08-23 10:45:18 +08:00 committed by GitHub
parent 732b0f4c98
commit d2ccc44a2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 612 additions and 0 deletions

View File

@ -0,0 +1,26 @@
# Implementation for DINO
**NOTE**: We only guarantee the correctness of the forward pass and do not guarantee a full reimplementation of DINO.
First, make sure you are in the root directory of MMPretrain. You then have two ways
to run DINO in MMPretrain:
## Slurm
If you are using a cluster managed by Slurm, you can use the following command to
start your job:
```shell
GPUS_PER_NODE=8 GPUS=8 CPUS_PER_TASK=16 bash projects/dino/tools/slurm_train.sh mm_model dino projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py --amp
```
The above command will pre-train the model on a single node with 8 GPUs.
## PyTorch
If you are using a single machine without any cluster management software, you can use the following command to start your job:
```shell
NNODES=1 bash projects/dino/tools/dist_train.sh projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py 8 --amp
```

View File

@ -0,0 +1,104 @@
# Model definition: a ViT-B/16 backbone followed by the DINO projection neck
# and the DINO loss head.
model = {
    'type': 'DINO',
    # Per-channel normalisation statistics; input is converted BGR -> RGB.
    'data_preprocessor': {
        'mean': [123.675, 116.28, 103.53],
        'std': [58.395, 57.12, 57.375],
        'bgr_to_rgb': True,
    },
    'backbone': {
        'type': 'mmpretrain.VisionTransformer',
        'arch': 'b',
        'patch_size': 16,
    },
    # The neck's output width must match the head's ``out_channels``.
    'neck': {
        'type': 'DINONeck',
        'in_channels': 768,
        'out_channels': 65536,
        'hidden_channels': 2048,
        'bottleneck_channels': 256,
    },
    # ``num_crops`` is the number of chunks the student output is split into.
    'head': {
        'type': 'DINOHead',
        'out_channels': 65536,
        'num_crops': 10,
        'student_temp': 0.1,
        'center_momentum': 0.9,
    },
}
# Data pipeline: load the image, expand it into the DINO multi-crop views,
# then pack the views into the model input format.
train_pipeline = [
    {'type': 'LoadImageFromFile'},
    {
        'type': 'DINOMultiCrop',
        'global_crops_scale': (0.4, 1.0),
        'local_crops_scale': (0.05, 0.4),
        'local_crops_number': 8,
    },
    {'type': 'PackInputs'},
]

# Training dataloader reading ImageNet from ``data_root``.
train_dataloader = {
    'batch_size': 32,
    'num_workers': 16,
    'persistent_workers': True,
    'sampler': {'type': 'DefaultSampler', 'shuffle': True},
    'collate_fn': {'type': 'default_collate'},
    'dataset': {
        'type': 'mmpretrain.ImageNet',
        'data_root': '/data/imagenet/',
        'ann_file': 'meta/train.txt',
        'data_prefix': {'img_path': 'train/'},
        'pipeline': train_pipeline,
    },
}
# Stand-alone optimizer settings; the same values are repeated inside
# ``optim_wrapper`` below.
optimizer = {
    'type': 'AdamW',
    'lr': 0.0024,
    'betas': (0.9, 0.95),
    'weight_decay': 0.05,
}

# Mixed-precision optimizer wrapper. Parameters whose names match one of the
# ``custom_keys`` get a weight-decay multiplier of 0.
optim_wrapper = {
    'type': 'AmpOptimWrapper',
    'optimizer': {
        'type': 'AdamW',
        'lr': 0.0024,
        'betas': (0.9, 0.95),
        'weight_decay': 0.05,
    },
    'paramwise_cfg': {
        'custom_keys': {
            'ln': {'decay_mult': 0.0},
            'bias': {'decay_mult': 0.0},
            'pos_embed': {'decay_mult': 0.0},
            'mask_token': {'decay_mult': 0.0},
            'cls_token': {'decay_mult': 0.0},
        },
    },
    'loss_scale': 'dynamic',
}

# Learning-rate schedule: linear warm-up over epochs [0, 10), then cosine
# annealing over epochs [10, 100). Both are converted to per-iteration steps.
param_scheduler = [
    {
        'type': 'LinearLR',
        'start_factor': 1e-9,
        'by_epoch': True,
        'begin': 0,
        'end': 10,
        'convert_to_iter_based': True,
    },
    {
        'type': 'CosineAnnealingLR',
        'T_max': 90,
        'by_epoch': True,
        'begin': 10,
        'end': 100,
        'convert_to_iter_based': True,
    },
]

# Epoch-based training loop for 100 epochs.
train_cfg = {'type': 'EpochBasedTrainLoop', 'max_epochs': 100}
# Registry scope used to resolve un-prefixed ``type`` names.
default_scope = 'mmpretrain'

# Default runtime hooks: logging every 100 iterations and a checkpoint every
# epoch, keeping only the most recent one.
default_hooks = {
    'runtime_info': {'type': 'RuntimeInfoHook'},
    'timer': {'type': 'IterTimerHook'},
    'logger': {'type': 'LoggerHook', 'interval': 100},
    'param_scheduler': {'type': 'ParamSchedulerHook'},
    'checkpoint': {
        'type': 'CheckpointHook',
        'interval': 1,
        'max_keep_ckpts': 1,
    },
    'sampler_seed': {'type': 'DistSamplerSeedHook'},
}

# Environment: cuDNN benchmark off, fork-based multiprocessing, NCCL backend.
env_cfg = {
    'cudnn_benchmark': False,
    'mp_cfg': {'mp_start_method': 'fork', 'opencv_num_threads': 0},
    'dist_cfg': {'backend': 'nccl'},
}

# Log smoothing window plus one custom global-mean entry.
log_processor = {
    'window_size': 10,
    'custom_cfg': [
        {'data_src': '', 'method': 'mean', 'window_size': 'global'},
    ],
}

# Visualisation is written to the local backend only.
vis_backends = [{'type': 'LocalVisBackend'}]
visualizer = {
    'type': 'UniversalVisualizer',
    'vis_backends': [{'type': 'LocalVisBackend'}],
    'name': 'visualizer',
}

log_level = 'INFO'

# No explicit checkpoint to load; ``resume`` is enabled.
load_from = None
resume = True

randomness = {'seed': 2, 'diff_rank_seed': True}

# Teacher temperature schedule: with zero warm-up epochs and identical start
# and end values, the temperature is a constant 0.04 for all 100 epochs.
custom_hooks = [
    {
        'type': 'DINOTeacherTempWarmupHook',
        'warmup_teacher_temp': 0.04,
        'teacher_temp': 0.04,
        'teacher_temp_warmup_epochs': 0,
        'max_epochs': 100,
    },
]

View File

@ -0,0 +1 @@
from .transform import * # noqa: F401,F403

View File

@ -0,0 +1,3 @@
from .processing import DINOMultiCrop
__all__ = ['DINOMultiCrop']

View File

@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
from mmcv.transforms import RandomApply # noqa: E501
from mmcv.transforms import BaseTransform, Compose, RandomFlip, RandomGrayscale
from mmpretrain.datasets.transforms import (ColorJitter, GaussianBlur,
RandomResizedCrop, Solarize)
from mmpretrain.registry import TRANSFORMS
@TRANSFORMS.register_module()
class DINOMultiCrop(BaseTransform):
    """Multi-crop transform for DINO.

    From one input image this transform produces two 224x224 "global" crops
    followed by ``local_crops_number`` 96x96 "local" crops. Each crop is
    randomly flipped, colour-jittered, optionally converted to grayscale and
    blurred; the second global crop is additionally solarized with
    probability 0.2.

    Args:
        global_crops_scale (tuple): ``(min, max)`` range forwarded as
            ``crop_ratio_range`` to the global ``RandomResizedCrop``.
        local_crops_scale (tuple): ``(min, max)`` range forwarded as
            ``crop_ratio_range`` to the local ``RandomResizedCrop``.
        local_crops_number (int): Number of local crops.
    """

    def __init__(self, global_crops_scale: tuple, local_crops_scale: tuple,
                 local_crops_number: int) -> None:
        super().__init__()
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number

        # Photometric augmentation shared by all three crop pipelines.
        flip_and_color_jitter = Compose([
            RandomFlip(prob=0.5, direction='horizontal'),
            RandomApply([
                ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
            ],
                        prob=0.8),
            RandomGrayscale(
                prob=0.2,
                keep_channels=True,
                channel_weights=(0.114, 0.587, 0.2989),
            )
        ])

        def random_blur() -> GaussianBlur:
            # Leave ``radius`` unset so it is re-drawn uniformly from
            # [0.1, 2.0] on every call. Passing
            # ``radius=random.uniform(0.1, 2.0)`` here would sample once at
            # construction time and reuse that single radius for every image.
            return GaussianBlur(
                magnitude_range=(0.1, 2.0), magnitude_std='inf', prob=1.0)

        self.global_transform_1 = Compose([
            RandomResizedCrop(
                224,
                crop_ratio_range=global_crops_scale,
                interpolation='bicubic'),
            flip_and_color_jitter,
            random_blur(),
        ])
        self.global_transform_2 = Compose([
            RandomResizedCrop(
                224,
                crop_ratio_range=global_crops_scale,
                interpolation='bicubic'),
            flip_and_color_jitter,
            random_blur(),
            Solarize(thr=128, prob=0.2),
        ])
        self.local_transform = Compose([
            RandomResizedCrop(
                96,
                crop_ratio_range=local_crops_scale,
                interpolation='bicubic'),
            flip_and_color_jitter,
            random_blur(),
        ])

    def transform(self, results: dict) -> dict:
        """Replace ``results['img']`` with the list of augmented crops.

        Args:
            results (dict): Pipeline result dict holding the image under
                ``'img'``.

        Returns:
            dict: The same dict, with ``'img'`` set to a list of
            ``2 + local_crops_number`` crops (the two global crops first).
        """
        ori_img = results['img']
        crops = []

        # Every sub-pipeline works on ``results`` in place, so the original
        # image has to be restored before each crop is generated.
        results['img'] = ori_img
        crops.append(self.global_transform_1(results)['img'])
        results['img'] = ori_img
        crops.append(self.global_transform_2(results)['img'])
        for _ in range(self.local_crops_number):
            results['img'] = ori_img
            crops.append(self.local_transform(results)['img'])

        results['img'] = crops
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(global_crops_scale = {self.global_crops_scale}, '
        repr_str += f'local_crops_scale = {self.local_crops_scale}, '
        repr_str += f'local_crops_number = {self.local_crops_number})'
        return repr_str

View File

@ -0,0 +1 @@
from .hooks import * # noqa

View File

@ -0,0 +1,3 @@
from .dino_teacher_temp_warmup_hook import DINOTeacherTempWarmupHook
__all__ = ['DINOTeacherTempWarmupHook']

View File

@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmengine.hooks import Hook
from mmpretrain.registry import HOOKS
@HOOKS.register_module()
class DINOTeacherTempWarmupHook(Hook):
    """Warmup teacher temperature for DINO.

    This hook precomputes one teacher temperature per epoch and writes the
    value for the current epoch into ``model.head.teacher_temp`` before each
    training epoch. The schedule is a linear ramp from
    ``warmup_teacher_temp`` to ``teacher_temp`` followed by a constant
    ``teacher_temp``.

    Args:
        warmup_teacher_temp (float): Warmup temperature for teacher.
        teacher_temp (float): Temperature for teacher.
        teacher_temp_warmup_epochs (int): Warmup epochs for teacher
            temperature.
        max_epochs (int): Maximum epochs for training.
    """

    def __init__(self, warmup_teacher_temp: float, teacher_temp: float,
                 teacher_temp_warmup_epochs: int, max_epochs: int) -> None:
        super().__init__()
        self.teacher_temps = np.concatenate(
            (np.linspace(warmup_teacher_temp, teacher_temp,
                         teacher_temp_warmup_epochs),
             np.ones(max_epochs - teacher_temp_warmup_epochs) * teacher_temp))

    def before_train_epoch(self, runner) -> None:
        """Set the head's teacher temperature for the upcoming epoch.

        Args:
            runner: The runner of the training process; ``runner.model`` and
                ``runner.epoch`` are read.
        """
        # Imported locally so this block does not depend on a file-level
        # import being added.
        from mmengine.model import is_model_wrapper

        # Only unwrap ``.module`` when the model is really wrapped (e.g. by
        # DDP); a bare model in non-distributed training has no ``.module``.
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        # If training runs longer than the precomputed schedule, keep using
        # the final temperature instead of failing with an IndexError.
        epoch = min(runner.epoch, len(self.teacher_temps) - 1)
        model.head.teacher_temp = self.teacher_temps[epoch]

View File

@ -0,0 +1,3 @@
from .algorithm import * # noqa
from .head import * # noqa
from .neck import * # noqa

View File

@ -0,0 +1,3 @@
from .dino import DINO
__all__ = ['DINO']

View File

@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch
from torch import nn
from mmpretrain.models import BaseSelfSupervisor, CosineEMA
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
@MODELS.register_module()
class DINO(BaseSelfSupervisor):
    """Implementation for DINO.

    This module is proposed in `DINO: Emerging Properties in Self-Supervised
    Vision Transformers <https://arxiv.org/abs/2104.14294>`_.

    The student is ``backbone`` + ``neck``; the teacher is a ``CosineEMA``
    wrapper built from the same two modules. The last linear layer of both
    necks is weight-normalised with its gain fixed to 1.

    Args:
        backbone (dict): Config for backbone.
        neck (dict): Config for neck.
        head (dict): Config for head.
        pretrained (str, optional): Path for pretrained model.
            Defaults to None.
        base_momentum (float, optional): Base momentum for momentum update.
            Defaults to 0.99.
        data_preprocessor (dict, optional): Config for data preprocessor.
            Defaults to None.
        init_cfg (list[dict] | dict, optional): Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 pretrained: Optional[str] = None,
                 base_momentum: float = 0.99,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # create momentum model
        self.teacher = CosineEMA(
            nn.Sequential(self.backbone, self.neck), momentum=base_momentum)

        # weight normalization layer (student, then teacher)
        self.neck.last_layer = self._unit_weight_norm(self.neck.last_layer)
        teacher_neck = self.teacher.module[1]
        teacher_neck.last_layer = self._unit_weight_norm(
            teacher_neck.last_layer)

        # ``weight_norm`` registers brand-new parameters, so the teacher's
        # ``weight_v`` would be trainable again. The teacher must never
        # receive gradients, so freeze the whole teacher once more.
        self.teacher.module.requires_grad_(False)

    @staticmethod
    def _unit_weight_norm(layer: nn.Module) -> nn.Module:
        """Apply weight norm to ``layer`` with a frozen gain of 1."""
        layer = nn.utils.weight_norm(layer)
        layer.weight_g.data.fill_(1)
        layer.weight_g.requires_grad = False
        return layer

    def loss(self, inputs: List[torch.Tensor],
             data_samples: List[DataSample]) -> dict:
        """Compute the DINO loss for one batch of multi-crop views.

        Args:
            inputs (List[torch.Tensor]): One batched tensor per crop. The
                first two entries are used as the global crops and the rest
                as the local crops.
            data_samples (List[DataSample]): Unused here; kept for the
                common ``loss`` interface.

        Returns:
            dict: A dict with the single key ``loss``.
        """
        # NOTE(review): nothing in this class updates ``self.teacher`` from
        # the student (no ``update_parameters`` call is visible), so the
        # teacher appears to stay at its initial weights. Confirm whether
        # the EMA update is expected to happen elsewhere.
        global_crops = torch.cat(inputs[:2])
        local_crops = torch.cat(inputs[2:])

        # teacher forward; no autograd graph is needed because the head only
        # uses the teacher output as a detached target
        with torch.no_grad():
            teacher_output = self.teacher(global_crops)

        # student forward global
        student_output_global = self.backbone(global_crops)
        student_output_global = self.neck(student_output_global)

        # student forward local
        student_output_local = self.backbone(local_crops)
        student_output_local = self.neck(student_output_local)

        student_output = torch.cat(
            (student_output_global, student_output_local))

        # compute loss
        loss = self.head(student_output, teacher_output)

        return dict(loss=loss)

View File

@ -0,0 +1,3 @@
from .dino_head import DINOHead
__all__ = ['DINOHead']

View File

@ -0,0 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmengine.dist import all_reduce, get_world_size
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
@MODELS.register_module()
class DINOHead(BaseModule):
    """Loss head for DINO.

    Introduced in `DINO: Emerging Properties in Self-Supervised
    Vision Transformers <https://arxiv.org/abs/2104.14294>`_.

    The head keeps a running ``center`` buffer of the teacher outputs and
    computes the cross-entropy between centred, sharpened teacher
    distributions and the student's log-probabilities over every pair of
    distinct views. ``teacher_temp`` starts at 0 and is overwritten from
    outside the head (``DINOTeacherTempWarmupHook`` assigns it before each
    training epoch).

    Args:
        out_channels (int): Output channels of the head.
        num_crops (int): Number of crops.
        student_temp (float): Temperature for student output.
        center_momentum (float): Momentum for center update.
    """

    def __init__(self, out_channels: int, num_crops: int, student_temp: float,
                 center_momentum: float) -> None:
        super().__init__()
        self.num_crops = num_crops
        self.center_momentum = center_momentum
        self.student_temp = student_temp
        # Placeholder value; replaced externally before the first forward.
        self.teacher_temp = 0
        self.register_buffer('center', torch.zeros(1, out_channels))

    def forward(self, student_output: torch.Tensor,
                teacher_output: torch.Tensor) -> torch.Tensor:
        """Return the mean cross-entropy over all cross-view pairs.

        Args:
            student_output (torch.Tensor): Student outputs for all crops,
                concatenated along dim 0.
            teacher_output (torch.Tensor): Teacher outputs for the two
                global crops, concatenated along dim 0.

        Returns:
            torch.Tensor: The averaged loss.
        """
        # Keep the raw teacher output for the centre update at the end.
        raw_teacher_output = teacher_output

        student_views = (student_output / self.student_temp).chunk(
            self.num_crops, dim=0)

        # teacher centering and sharpening
        teacher_probs = F.softmax(
            (teacher_output - self.center) / self.teacher_temp, dim=-1)
        teacher_views = teacher_probs.detach().chunk(2, dim=0)

        total_loss = 0
        n_loss_terms = 0
        for t_idx, target in enumerate(teacher_views):
            for s_idx, logits in enumerate(student_views):
                # A view is never compared against itself.
                if t_idx == s_idx:
                    continue
                pair_loss = (-target * logits.log_softmax(dim=-1)).sum(dim=-1)
                total_loss += pair_loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms

        self.update_center(raw_teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output: torch.Tensor) -> None:
        """Update the ``center`` buffer with the mean teacher output.

        Args:
            teacher_output (torch.Tensor): Raw (un-centred) teacher outputs
                of the current batch.
        """
        summed = torch.sum(teacher_output, dim=0, keepdim=True)
        all_reduce(summed)
        batch_center = summed / (len(teacher_output) * get_world_size())

        # ema update batch center
        momentum = self.center_momentum
        self.center = self.center * momentum + batch_center * (1 - momentum)

View File

@ -0,0 +1,3 @@
from .dino_neck import DINONeck
__all__ = ['DINONeck']

View File

@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule
from torch import nn
from mmpretrain.registry import MODELS
@MODELS.register_module()
class DINONeck(BaseModule):
    """Projection neck for DINO.

    Introduced in `DINO: Emerging Properties in Self-Supervised
    Vision Transformers <https://arxiv.org/abs/2104.14294>`_.

    A three-layer MLP with GELU activations maps the input to
    ``bottleneck_channels``; the result is L2-normalised and projected to
    ``out_channels`` by a bias-free linear layer.

    Args:
        in_channels (int): Input channels.
        hidden_channels (int): Hidden channels.
        out_channels (int): Output channels.
        bottleneck_channels (int): Bottleneck channels.
    """

    def __init__(self, in_channels: int, hidden_channels: int,
                 out_channels: int, bottleneck_channels: int) -> None:
        super().__init__()
        mlp_layers = [
            nn.Linear(in_channels, hidden_channels),
            nn.GELU(),
            nn.Linear(hidden_channels, hidden_channels),
            nn.GELU(),
            nn.Linear(hidden_channels, bottleneck_channels),
        ]
        self.mlp = nn.Sequential(*mlp_layers)
        self.last_layer = nn.Linear(
            bottleneck_channels, out_channels, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the first element of ``x`` to ``out_channels``.

        Args:
            x: Indexable input; only ``x[0]`` is used.

        Returns:
            torch.Tensor: Output of the final linear layer.
        """
        feat = self.mlp(x[0])
        # L2-normalise along the last dimension before the final projection.
        feat = nn.functional.normalize(feat, p=2, dim=-1)
        return self.last_layer(feat)

View File

@ -0,0 +1,19 @@
#!/usr/bin/env bash
# Launch DINO pre-training with torch.distributed.launch.
#
# Usage: [NNODES=..] [NODE_RANK=..] [PORT=..] [MASTER_ADDR=..] \
#            dist_train.sh CONFIG GPUS [extra train.py arguments...]

CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

# The project root (parent of this script's directory) is put on PYTHONPATH
# so train.py can import the project's packages. All expansions are quoted so
# that paths and extra arguments containing spaces are passed through intact.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes="$NNODES" \
    --node_rank="$NODE_RANK" \
    --master_addr="$MASTER_ADDR" \
    --nproc_per_node="$GPUS" \
    --master_port="$PORT" \
    "$(dirname "$0")/train.py" \
    "$CONFIG" \
    --launcher pytorch "${@:3}"

View File

@ -0,0 +1,23 @@
#!/usr/bin/env bash
# Launch DINO pre-training on a Slurm cluster via srun.
#
# Usage: [GPUS=..] [GPUS_PER_NODE=..] [CPUS_PER_TASK=..] [SRUN_ARGS=..] \
#            slurm_train.sh PARTITION JOB_NAME CONFIG [extra train.py arguments...]

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
# Keep the extra arguments as an array so arguments containing spaces are
# not re-split when forwarded to train.py.
PY_ARGS=("${@:4}")

# SRUN_ARGS is intentionally left unquoted: it holds several
# whitespace-separated srun options. train.py is located relative to this
# script, the same way dist_train.sh does it.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
srun -p "${PARTITION}" \
    --job-name="${JOB_NAME}" \
    --gres=gpu:"${GPUS_PER_NODE}" \
    --ntasks="${GPUS}" \
    --ntasks-per-node="${GPUS_PER_NODE}" \
    --cpus-per-task="${CPUS_PER_TASK}" \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u "$(dirname "$0")/train.py" "${CONFIG}" --launcher="slurm" "${PY_ARGS[@]}"

View File

@ -0,0 +1,104 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
from dataset import * # noqa: F401,F403
from engine import * # noqa: F401,F403
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
from models.algorithm import * # noqa: F401,F403
from models.head import * # noqa: F401,F403
from models.neck import * # noqa: F401,F403
from mmpretrain.utils import register_all_modules
def parse_args():
    """Parse the command-line arguments of the training script.

    As a side effect, ``LOCAL_RANK`` is exported to the environment from
    ``--local_rank`` when the variable is not already set.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume',
        nargs='?',
        type=str,
        const='auto',
        help='If specify checkpoint path, resume from it, while if not '
        'specify, try to auto resume from the latest checkpoint '
        'in the work directory.')
    parser.add_argument(
        '--amp',
        action='store_true',
        help='enable automatic-mixed-precision training')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # Accept both spellings: older ``torch.distributed.launch`` passes
    # ``--local_rank`` while PyTorch >= 2.0 passes ``--local-rank``. The
    # destination stays ``args.local_rank`` in both cases.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args
def main():
    """Build a runner from the config and CLI options, then start training."""
    args = parse_args()

    # register all modules in mmpretrain into the registries
    # do not init the default scope here because it will be init in the runner
    register_all_modules(init_default_scope=False)

    # load config
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        config_name = osp.splitext(osp.basename(args.config))[0]
        path_parts = args.config.split('/')
        if len(path_parts) > 1:
            # e.g. ``projects/dino/config/x.py`` -> ``./work_dirs/dino/x``
            cfg.work_dir = osp.join('./work_dirs', path_parts[1], config_name)
        else:
            # A bare file name has no second path component; indexing it
            # would raise IndexError, so fall back to the config name alone.
            cfg.work_dir = osp.join('./work_dirs', config_name)

    # enable automatic-mixed-precision training
    if args.amp:
        optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper')
        assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \
            '`--amp` is not supported custom optimizer wrapper type ' \
            f'`{optim_wrapper}.'
        cfg.optim_wrapper.type = 'AmpOptimWrapper'
        cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')

    # resume training
    if args.resume == 'auto':
        # ``--resume`` without a value: auto-resume from the work directory
        cfg.resume = True
        cfg.load_from = None
    elif args.resume is not None:
        # ``--resume PATH``: resume from the given checkpoint
        cfg.resume = True
        cfg.load_from = args.resume

    # build the runner from config
    runner = Runner.from_cfg(cfg)

    # start training
    runner.train()


if __name__ == '__main__':
    main()