add DeiT III (#171)

1.Add a backbone: deitiii.
2.Add an optimizer: lamb. 
3.Add a sampler: RASampler. 
4.Add a lr update hook: CosineAnnealingWarmupByEpochLrUpdaterHook.
5.In easycv/models/classification/classification.py, I remove the default mixup_cfg to keep the classification.py clean.
This commit is contained in:
zzoneee 2022-09-14 15:24:54 +08:00 committed by GitHub
parent 29f0e42427
commit 0cb91de0cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 982 additions and 221 deletions

View File

@ -0,0 +1,143 @@
# from PIL import Image
_base_ = 'configs/base.py'
log_config = dict(
interval=10,
hooks=[dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')])
# model settings
model = dict(
type='Classification',
train_preprocess=['mixUp'],
pretrained=False,
mixup_cfg=dict(
mixup_alpha=0.8,
cutmix_alpha=1.0,
cutmix_minmax=None,
prob=1.0,
switch_prob=0.5,
mode='batch',
label_smoothing=0.0,
num_classes=1000),
backbone=dict(
type='VisionTransformer',
img_size=[192],
num_classes=1000,
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
drop_rate=0.,
drop_path_rate=0.2,
use_layer_scale=True),
head=dict(
type='ClsHead',
loss_config=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0,
label_ceil=True),
with_fc=False,
use_num_classes=False))
data_train_list = 'data/imagenet1k/train.txt'
data_train_root = 'data/imagenet1k/train/'
data_test_list = 'data/imagenet1k/val.txt'
data_test_root = 'data/imagenet1k/val/'
dataset_type = 'ClsDataset'
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
three_augment_policies = [[
dict(type='PILGaussianBlur', prob=1.0, radius_min=0.1, radius_max=2.0),
], [
dict(type='Solarization', threshold=128),
], [
dict(type='Grayscale', num_output_channels=3),
]]
train_pipeline = [
dict(
type='RandomResizedCrop', size=192, scale=(0.08, 1.0),
interpolation=3), # interpolation='bicubic'
dict(type='RandomHorizontalFlip'),
dict(type='MMAutoAugment', policies=three_augment_policies),
dict(type='ColorJitter', brightness=0.3, contrast=0.3, saturation=0.3),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Collect', keys=['img', 'gt_labels'])
]
size = int((256 / 224) * 192)
test_pipeline = [
dict(type='Resize', size=size, interpolation=3),
dict(type='CenterCrop', size=192),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Collect', keys=['img', 'gt_labels'])
]
data = dict(
imgs_per_gpu=256,
workers_per_gpu=8,
use_repeated_augment_sampler=True,
train=dict(
type=dataset_type,
data_source=dict(
list_file=data_train_list,
root=data_train_root,
type='ClsSourceImageList'),
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_source=dict(
list_file=data_test_list,
root=data_test_root,
type='ClsSourceImageList'),
pipeline=test_pipeline))
eval_config = dict(initial=True, interval=1, gpu_collect=True)
eval_pipelines = [
dict(
mode='test',
data=data['val'],
dist_eval=True,
evaluators=[dict(type='ClsEvaluator', topk=(1, 5))],
)
]
# additional hooks
custom_hooks = []
# optimizer
optimizer = dict(
type='Lamb',
lr=0.003,
weight_decay=0.05,
eps=1e-8,
paramwise_options={
'cls_token': dict(weight_decay=0.),
'pos_embed': dict(weight_decay=0.),
'bias': dict(weight_decay=0.),
'norm': dict(weight_decay=0.),
'gamma_1': dict(weight_decay=0.),
'gamma_2': dict(weight_decay=0.),
})
optimizer_config = dict(grad_clip=None, update_interval=1)
lr_config = dict(
policy='CosineAnnealingWarmupByEpoch',
by_epoch=True,
min_lr_ratio=0.00001 / 0.003,
warmup='linear',
warmup_by_epoch=True,
warmup_iters=5,
warmup_ratio=0.000001 / 0.003,
)
checkpoint_config = dict(interval=10)
# runtime settings
total_epochs = 800
ema = dict(decay=0.99996)

View File

@ -0,0 +1,17 @@
_base_ = './deitiii_base_patch16_192.py'
# model settings
model = dict(
type='Classification',
backbone=dict(
type='VisionTransformer',
img_size=[192],
num_classes=1000,
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
drop_rate=0.,
drop_path_rate=0.2,
use_layer_scale=True))

View File

@ -0,0 +1,17 @@
_base_ = './deitiii_base_patch16_192.py'
# model settings
model = dict(
type='Classification',
backbone=dict(
type='VisionTransformer',
img_size=[192],
num_classes=1000,
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
drop_rate=0.,
drop_path_rate=0.45,
use_layer_scale=True))

View File

@ -0,0 +1,86 @@
_base_ = './deitiii_base_patch16_192.py'
# model settings
model = dict(
type='Classification',
backbone=dict(
type='VisionTransformer',
img_size=[224],
num_classes=1000,
patch_size=16,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4,
qkv_bias=True,
drop_rate=0.,
drop_path_rate=0.05,
use_layer_scale=True))
data_train_list = 'data/imagenet1k/train.txt'
data_train_root = 'data/imagenet1k/train/'
data_test_list = 'data/imagenet1k/val.txt'
data_test_root = 'data/imagenet1k/val/'
dataset_type = 'ClsDataset'
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
three_augment_policies = [[
dict(type='PILGaussianBlur', prob=1.0, radius_min=0.1, radius_max=2.0),
], [
dict(type='Solarization', threshold=128),
], [
dict(type='Grayscale', num_output_channels=3),
]]
train_pipeline = [
dict(
type='RandomResizedCrop', size=224, scale=(0.08, 1.0),
interpolation=3), # interpolation='bicubic'
dict(type='RandomHorizontalFlip'),
dict(type='MMAutoAugment', policies=three_augment_policies),
dict(type='ColorJitter', brightness=0.3, contrast=0.3, saturation=0.3),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Collect', keys=['img', 'gt_labels'])
]
test_pipeline = [
dict(type='Resize', size=256, interpolation=3),
dict(type='CenterCrop', size=224),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Collect', keys=['img', 'gt_labels'])
]
data = dict(
imgs_per_gpu=256,
workers_per_gpu=8,
use_repeated_augment_sampler=True,
train=dict(
type=dataset_type,
data_source=dict(
list_file=data_train_list,
root=data_train_root,
type='ClsSourceImageList'),
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_source=dict(
list_file=data_test_list,
root=data_test_root,
type='ClsSourceImageList'),
pipeline=test_pipeline))
eval_pipelines = [
dict(
mode='test',
data=data['val'],
dist_eval=True,
evaluators=[dict(type='ClsEvaluator', topk=(1, 5))],
)
]
# optimizer
optimizer = dict(lr=0.004)
lr_config = dict(
min_lr_ratio=0.00001 / 0.004,
warmup_ratio=0.000001 / 0.004,
)

View File

@ -21,6 +21,9 @@
| hrnetw64 | [hrnetw64](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw64_jpg.py) | 79.884 | 95.04 | 5120 | 54.74 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/hrnetw64/epoch_100.pth) |
| vit-base-patch16 | [vit-base-patch16](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/vit/imagenet_vit_base_patch16_224_jpg.py) | 76.082 | 92.026 | 346 | 8.03 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/vit/vit-base-patch16/epoch_300.pth) |
| swin-tiny-patch4-window7 | [swin-tiny-patch4-window7](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/swint/imagenet_swin_tiny_patch4_window7_224_jpg.py) | 80.528 | 94.822 | 132 | 12.94 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/swint/swin-tiny-patch4-window7/epoch_300.pth) |
| deitiii-small-patch16-224 | [deitiii-small-patch16-224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/vit/imagenet_deitiii_small_patch16_224_jpg.py) | 81.408 | 95.388 | 89 | 4.53 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/deitiii/imagenet_deitiii_small_patch16_224/deitiii_small.pth) |
| deitiii-base-patch16-192 | [deitiii-base-patch16-192](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/vit/imagenet_deitiii_base_patch16_192_jpg.py) | 82.982 | 95.95 | 337 | 4.63 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/deitiii/imagenet_deitiii_base_patch16_192/deitiii_base.pth) |
| deitiii-large-patch16-192 | [deitiii-large-patch16-192](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/vit/imagenet_deitiii_large_patch16_192_jpg.py) | 83.902 | 96.296 | 1170 | 10.17 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/deitiii/imagenet_deitiii_large_patch16_192/deitiii_large.pth) |
(ps: 通过EasyCV训练得到模型结果推理的输入尺寸默认为224机器默认为V100 16G其中gpu memory记录的是gpu peak memory)

View File

@ -4,6 +4,7 @@ import torch
from torch.optim import *
from .builder import build_optimizer_constructor
from .lamb import Lamb
from .lars import LARS
from .layer_decay_optimizer_constructor import LayerDecayOptimizerConstructor
from .ranger import Ranger

View File

@ -0,0 +1,166 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import torch
from mmcv.runner import OPTIMIZERS
from torch.optim import Optimizer
@OPTIMIZERS.register_module()
class Lamb(Optimizer):
"""A pure pytorch variant of FuseLAMB (NvLamb variant) optimizer.
This class is copied from `timm`_. The LAMB was proposed in `Large Batch
Optimization for Deep Learning - Training BERT in 76 minutes`_.
.. _timm:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
grad_averaging (bool, optional): whether apply (1-beta2) to grad when
calculating running averages of gradient. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm
(default: 1.0)
trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
weight decay parameter (default: False)
""" # noqa: E501
def __init__(self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-6,
weight_decay=0.01,
grad_averaging=True,
max_grad_norm=1.0,
trust_clip=False,
always_adapt=False):
defaults = dict(
lr=lr,
bias_correction=bias_correction,
betas=betas,
eps=eps,
weight_decay=weight_decay,
grad_averaging=grad_averaging,
max_grad_norm=max_grad_norm,
trust_clip=trust_clip,
always_adapt=always_adapt)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
device = self.param_groups[0]['params'][0].device
one_tensor = torch.tensor(
1.0, device=device
) # because torch.where doesn't handle scalars correctly
global_grad_norm = torch.zeros(1, device=device)
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
'Lamb does not support sparse gradients, consider '
'SparseAdam instead.')
global_grad_norm.add_(grad.pow(2).sum())
global_grad_norm = torch.sqrt(global_grad_norm)
# FIXME it'd be nice to remove explicit tensor conversion of scalars
# when torch.where promotes
# scalar types properly https://github.com/pytorch/pytorch/issues/9190
max_grad_norm = torch.tensor(
self.defaults['max_grad_norm'], device=device)
clip_global_grad_norm = torch.where(global_grad_norm > max_grad_norm,
global_grad_norm / max_grad_norm,
one_tensor)
for group in self.param_groups:
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
grad_averaging = 1 if group['grad_averaging'] else 0
beta3 = 1 - beta1 if grad_averaging else 1.0
# assume same step across group now to simplify things
# per parameter step can be easily support by making it tensor, or
# pass list into kernel
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1
if bias_correction:
bias_correction1 = 1 - beta1**group['step']
bias_correction2 = 1 - beta2**group['step']
else:
bias_correction1, bias_correction2 = 1.0, 1.0
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.div_(clip_global_grad_norm)
state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient valuesa
state['exp_avg'] = torch.zeros_like(p)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t
exp_avg_sq.mul_(beta2).addcmul_(
grad, grad, value=1 - beta2) # v_t
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
group['eps'])
update = (exp_avg / bias_correction1).div_(denom)
weight_decay = group['weight_decay']
if weight_decay != 0:
update.add_(p, alpha=weight_decay)
if weight_decay != 0 or group['always_adapt']:
# Layer-wise LR adaptation. By default, skip adaptation on
# parameters that are
# excluded from weight decay, unless always_adapt == True,
# then always enabled.
w_norm = p.norm(2.0)
g_norm = update.norm(2.0)
# FIXME nested where required since logical and/or not
# working in PT XLA
trust_ratio = torch.where(
w_norm > 0,
torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
one_tensor,
)
if group['trust_clip']:
# LAMBC trust clipping, upper bound fixed at one
trust_ratio = torch.minimum(trust_ratio, one_tensor)
update.mul_(trust_ratio)
p.add_(update, alpha=-group['lr'])
return loss

View File

@ -8,7 +8,7 @@ from typing import Sequence
import mmcv
import numpy as np
from PIL import Image
from PIL import Image, ImageFilter
from easycv.datasets.registry import PIPELINES
from easycv.datasets.shared.pipelines import Compose
@ -1043,3 +1043,37 @@ class Cutout(object):
repr_str += f'pad_val={self.pad_val}, '
repr_str += f'prob={self.prob})'
return repr_str
@PIPELINES.register_module()
class PILGaussianBlur(object):
def __init__(self, prob=0.1, radius_min=0.1, radius_max=2.):
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.'
assert isinstance(radius_min, (int, float)), 'The radius_min type must '\
f'be int or float, but got {type(radius_min)} instead.'
assert isinstance(radius_max, (int, float)), 'The radius_max type must '\
f'be int or float, but got {type(radius_max)} instead.'
self.prob = prob
self.radius_min = radius_min
self.radius_max = radius_max
def __call__(self, results):
if np.random.rand() > self.prob:
return results
for key in results.get('img_fields', ['img']):
img = results[key].filter(
ImageFilter.GaussianBlur(
radius=random.uniform(self.radius_min, self.radius_max)))
results[key] = img
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'radius_min={self.radius_min}, '
repr_str += f'radius_max={self.radius_max})'
return repr_str

View File

@ -14,7 +14,7 @@ from easycv.datasets.shared.odps_reader import set_dataloader_workid
from easycv.utils.dist_utils import sync_random_seed
from easycv.utils.torchacc_util import is_torchacc_enabled
from .collate import CollateWrapper
from .sampler import DistributedMPSampler, DistributedSampler
from .sampler import DistributedMPSampler, DistributedSampler, RASampler
if platform.system() != 'Windows':
# https://github.com/pytorch/pytorch/issues/973
@ -35,6 +35,7 @@ def build_dataloader(dataset,
odps_config=None,
persistent_workers=False,
collate_hooks=None,
use_repeated_augment_sampler=False,
**kwargs):
"""Build PyTorch DataLoader.
In distributed training, each GPU/process has a dataloader.
@ -56,6 +57,8 @@ def build_dataloader(dataset,
data in worker process can be reused.
persistent_workers (bool) : After pytorch1.7, could use persistent_workers=True to
avoid reconstruct dataworker before each epoch, speed up before epoch
use_repeated_augment_sampler (bool) : If set true, it will use RASampler.
Default: False.
kwargs: any keyword argument to be used to initialize DataLoader
Returns:
DataLoader: A PyTorch dataloader.
@ -68,7 +71,9 @@ def build_dataloader(dataset,
'split_huge_listfile_byrank',
False)
if hasattr(dataset, 'm_per_class') and dataset.m_per_class > 1:
if use_repeated_augment_sampler:
sampler = RASampler(dataset, world_size, rank, shuffle=shuffle)
elif hasattr(dataset, 'm_per_class') and dataset.m_per_class > 1:
sampler = DistributedMPSampler(
dataset,
world_size,
@ -88,7 +93,10 @@ def build_dataloader(dataset,
else:
if replace:
raise NotImplementedError
if hasattr(dataset, 'm_per_class') and dataset.m_per_class > 1:
if use_repeated_augment_sampler:
sampler = RASampler(dataset, 1, 0, shuffle=shuffle)
elif hasattr(dataset, 'm_per_class') and dataset.m_per_class > 1:
sampler = DistributedMPSampler(
dataset, 1, 0, shuffle=shuffle, replace=replace)
else:

View File

@ -6,6 +6,7 @@ import random
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info
from torch.utils.data import DistributedSampler as _DistributedSampler
from torch.utils.data import Sampler
@ -469,3 +470,73 @@ class DistributedGivenIterationSampler(Sampler):
def set_epoch(self, epoch):
pass
class RASampler(torch.utils.data.Sampler):
"""Sampler that restricts data loading to a subset of the dataset for distributed,
with repeated augmentation.
It ensures that different each augmented version of a sample will be visible to a
different process (GPU)
Heavily based on torch.utils.data.DistributedSampler
"""
def __init__(self,
dataset,
num_replicas=None,
rank=None,
shuffle=True,
num_repeats: int = 3):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError(
'Requires distributed package to be available')
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError(
'Requires distributed package to be available')
rank = dist.get_rank()
if num_repeats < 1:
raise ValueError('num_repeats should be greater than 0')
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.num_repeats = num_repeats
self.epoch = 0
self.num_samples = int(
math.ceil(
len(self.dataset) * self.num_repeats / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
# self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
self.num_selected_samples = int(
math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
self.shuffle = shuffle
def __iter__(self):
if self.shuffle:
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(self.dataset), generator=g)
else:
indices = torch.arange(start=0, end=len(self.dataset))
# add extra samples to make it evenly divisible
indices = torch.repeat_interleave(
indices, repeats=self.num_repeats, dim=0).tolist()
padding_size: int = self.total_size - len(indices)
if padding_size > 0:
indices += indices[:padding_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices[:self.num_selected_samples])
def __len__(self):
return self.num_selected_samples
def set_epoch(self, epoch):
self.epoch = epoch

View File

@ -13,7 +13,8 @@ from .eval_hook import DistEvalHook, EvalHook
from .export_hook import ExportHook
from .extractor import Extractor
from .logger import PreLoggerHook
from .lr_update_hook import StepFixCosineAnnealingLrUpdaterHook
from .lr_update_hook import (CosineAnnealingWarmupByEpochLrUpdaterHook,
StepFixCosineAnnealingLrUpdaterHook)
from .optimizer_hook import OptimizerHook
from .oss_sync_hook import OSSSyncHook
from .registry import HOOKS
@ -33,7 +34,8 @@ __all__ = [
'OSSSyncHook', 'HOOKS', 'TIMEHook', 'SWAVHook', 'SyncNormHook',
'SyncRandomSizeHook', 'TensorboardLoggerHookV2', 'WandbLoggerHookV2',
'YOLOXLrUpdaterHook', 'YOLOXModeSwitchHook', 'MixupCollateHook',
'PreLoggerHook', 'StepFixCosineAnnealingLrUpdaterHook', 'ThroughputHook'
'PreLoggerHook', 'StepFixCosineAnnealingLrUpdaterHook',
'CosineAnnealingWarmupByEpochLrUpdaterHook', 'ThroughputHook'
]
if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'):

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv import runner
from mmcv.runner import HOOKS
from mmcv.runner.hooks.lr_updater import (CosineAnnealingLrUpdaterHook,
annealing_cos)
@ -54,3 +55,29 @@ class StepFixCosineAnnealingLrUpdaterHook(CosineAnnealingLrUpdaterHook):
target_lr = self.min_lr
return annealing_cos(base_lr, target_lr, progress / max_progress)
@HOOKS.register_module()
class CosineAnnealingWarmupByEpochLrUpdaterHook(CosineAnnealingLrUpdaterHook):
def before_train_iter(self, runner: 'runner.BaseRunner'):
cur_iter = runner.iter
epoch_len = len(runner.data_loader)
assert isinstance(self.warmup_iters, int)
if not self.by_epoch:
self.regular_lr = self.get_regular_lr(runner)
if self.warmup is None or cur_iter >= self.warmup_iters:
self._set_lr(runner, self.regular_lr)
else:
if cur_iter % epoch_len == 0:
warmup_lr = self.get_warmup_lr(cur_iter)
self._set_lr(runner, warmup_lr)
elif self.by_epoch:
if self.warmup is None or cur_iter > self.warmup_iters:
return
elif cur_iter == self.warmup_iters:
self._set_lr(runner, self.regular_lr)
else:
if cur_iter % epoch_len == 0:
warmup_lr = self.get_warmup_lr(cur_iter)
self._set_lr(runner, warmup_lr)

View File

@ -19,4 +19,5 @@ from .resnet_jit import ResNetJIT
from .resnext import ResNeXt
from .shuffle_transformer import ShuffleTransformer
from .swin_transformer import SwinTransformer
from .vision_transformer import VisionTransformer
from .vitdet import ViTDet

View File

@ -10,7 +10,7 @@ from timm.models.layers import trunc_normal_
from easycv.models.registry import BACKBONES
from easycv.models.utils import DropPath
from easycv.models.utils.pos_embed import get_2d_sincos_pos_embed
from .vit_transfomer_dynamic import Block
from .vision_transformer import Block
class PatchEmbed(nn.Module):

View File

@ -0,0 +1,287 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""
Mostly copy-paste from timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
from easycv.models.utils import DropPath, Mlp
from ..registry import BACKBONES
class Attention(nn.Module):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias=None):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn
class Block(nn.Module):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
use_layer_scale=False,
init_values=1e-4):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
self.use_layer_scale = use_layer_scale
if self.use_layer_scale:
self.gamma_1 = nn.Parameter(
init_values * torch.ones((dim)), requires_grad=True)
self.gamma_2 = nn.Parameter(
init_values * torch.ones((dim)), requires_grad=True)
def forward(self, x, return_attention=False, rel_pos_bias=None):
y, attn = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
if return_attention:
return attn
if self.use_layer_scale:
x = x + self.drop_path(self.gamma_1 * y)
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(y)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def forward_fea_and_attn(self, x):
y, attn = self.attn(self.norm1(x))
if self.use_layer_scale:
x = x + self.drop_path(self.gamma_1 * y)
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(y)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x, attn
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
num_patches = (img_size // patch_size) * (img_size // patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
x = self.proj(x).flatten(2).transpose(1, 2)
return x
@BACKBONES.register_module
class VisionTransformer(nn.Module):
""" DeiT III is based on ViT. It uses some strategies to make the vit model
better, just like layer scale, stochastic depth, 3-Augment.
Paper link: https://arxiv.org/pdf/2204.07118.pdf (DeiT III: Revenge of the ViT)
Args:
img_size (list): Input image size. img_size=[224] means the image size is
224*224. img_size=[192, 224] means the image size is 192*224.
patch_size (int): The patch size. Default: 16
in_chans (int): The num of input channels. Default: 3
num_classes (int): The num of picture classes. Default: 1000
embed_dim (int): The dimensions of embedding. Default: 768
depth (int): The num of blocks. Default: 12
num_heads (int): Parallel attention heads. Default: 12
mlp_ratio (float): Mlp expansion ratio. Default: 4.0
qkv_bias (bool): Does kqv use bias. Default: False
qk_scale (float | None): In the step of self-attention, if qk_scale is not
None, it will use qk_scale to scale the q @ k. Otherwise it will use
head_dim**-0.5 instead of qk_scale. Default: None
drop_rate (float): Probability of an element to be zeroed after the feed
forward layer. Default: 0.0
drop_path_rate (float): Stochastic depth rate. Default: 0
norm_layer (nn.Module): normalization layer
use_dense_prediction (bool): If use_dense_prediction is True, the global
pool and norm will before head will be removed.(if any) Default: False
global_pool (bool): Global pool before head. Default: False
use_layer_scale (bool): If use_layer_scale is True, it will use layer
scale. Default: False
init_scale (float): It is used for layer scale in Block to scale the
gamma_1 and gamma_2.
"""
def __init__(self,
img_size=[224],
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
use_dense_prediction=False,
global_pool=False,
use_layer_scale=False,
init_scale=1e-4,
**kwargs):
super().__init__()
self.num_features = self.embed_dim = embed_dim
self.patch_embed = PatchEmbed(
img_size=img_size[0],
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
self.drop_path_rate = drop_path_rate
self.depth = depth
dpr = [drop_path_rate for i in range(depth)]
self.blocks = nn.ModuleList([
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
use_layer_scale=use_layer_scale,
init_values=init_scale) for i in range(depth)
])
self.norm = norm_layer(embed_dim)
# Classifier head
self.head = nn.Linear(
embed_dim, num_classes) if num_classes > 0 else nn.Identity()
# Dense prediction head
self.use_dense_prediction = use_dense_prediction
if self.use_dense_prediction:
self.head_dense = None
# Use global average pooling
self.global_pool = global_pool
if self.global_pool:
self.fc_norm = norm_layer(embed_dim)
self.norm = None
def init_weights(self):
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
x = self.forward_features(x)
x = self.pos_drop(x)
x = self.head(x)
return [x]
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1)
x = x + self.pos_embed
x = torch.cat((cls_tokens, x), dim=1)
for blk in self.blocks:
x = blk(x)
if self.norm is not None:
x = self.norm(x)
if self.use_dense_prediction:
return x[:, 0], x[:, 1:]
else:
if self.global_pool:
x = x[:, 1:, :].mean(dim=1)
return self.fc_norm(x)
else:
return x[:, 0]

View File

@ -12,197 +12,25 @@ from functools import partial
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from easycv.models.utils import DropPath, Mlp
from easycv.models.backbones.vision_transformer import VisionTransformer
class Attention(nn.Module):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias=None):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn
class Block(nn.Module):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def forward(self, x, return_attention=False, rel_pos_bias=None):
y, attn = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
if return_attention:
return attn
x = x + self.drop_path(y)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def forward_fea_and_attn(self, x):
y, attn = self.attn(self.norm1(x))
x = x + self.drop_path(y)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x, attn
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
num_patches = (img_size // patch_size) * (img_size // patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class DynamicVisionTransformer(nn.Module):
class DynamicVisionTransformer(VisionTransformer):
"""Dynamic Vision Transformer """
def __init__(self,
img_size=[224],
patch_size=16,
in_chans=3,
num_classes=0,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
use_dense_prediction=False,
global_pool=False,
**kwargs):
super().__init__()
self.num_features = self.embed_dim = embed_dim
def __init__(self, **kwargs):
super(DynamicVisionTransformer, self).__init__(**kwargs)
self.patch_embed = PatchEmbed(
img_size=img_size[0],
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
torch.zeros(1, num_patches + 1, self.embed_dim))
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer) for i in range(depth)
])
self.norm = norm_layer(embed_dim)
# Classifier head
self.head = nn.Linear(
embed_dim, num_classes) if num_classes > 0 else nn.Identity()
# Dense prediction head
self.use_dense_prediction = use_dense_prediction
if self.use_dense_prediction:
self.head_dense = None
# Use global average pooling
self.global_pool = global_pool
if self.global_pool:
self.fc_norm = norm_layer(embed_dim)
self.norm = None
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
dpr = [
x.item()
for x in torch.linspace(0, self.drop_path_rate, self.depth)
]
def forward(self, x):
# convert to list

View File

@ -53,22 +53,15 @@ class Classification(BaseModel):
if 'mixUp' in train_preprocess:
rank, _ = get_dist_info()
np.random.seed(rank + 12)
if not mixup_cfg:
num_classes = head.get(
'num_classes',
1000) if 'num_classes' in head else backbone.get(
'num_classes', 1000)
mixup_cfg = dict(
mixup_alpha=0.8,
cutmix_alpha=1.0,
cutmix_minmax=None,
prob=1.0,
switch_prob=0.5,
mode='batch',
label_smoothing=0.1,
num_classes=num_classes)
self.mixup = Mixup(**mixup_cfg)
head.loss_config = {'type': 'SoftTargetCrossEntropy'}
if mixup_cfg is not None:
if 'num_classes' in mixup_cfg:
self.mixup = Mixup(**mixup_cfg)
elif 'num_classes' in head or 'num_classes' in backbone:
num_classes = head.get(
'num_classes'
) if 'num_classes' in head else backbone.get('num_classes')
mixup_cfg['num_classes'] = num_classes
self.mixup = Mixup(**mixup_cfg)
train_preprocess.remove('mixUp')
self.train_preprocess = [
self.preprocess_key_map[i] for i in train_preprocess
@ -173,7 +166,10 @@ class Classification(BaseModel):
for preprocess in self.train_preprocess:
img = preprocess(img)
if hasattr(self, 'mixup'):
# When the number of samples in the dataset is odd, the last batch size of each epoch will be odd,
# which will cause mixup to report an error. To avoid this situation, mixup is applied only when
# the batch size is even.
if hasattr(self, 'mixup') and len(img) % 2 == 0:
img, gt_labels = self.mixup(img, gt_labels)
x = self.forward_backbone(img)

View File

@ -28,7 +28,8 @@ class ClsHead(nn.Module):
},
input_feature_index=[0],
init_cfg=dict(
type='Normal', layer='Linear', std=0.01, bias=0.)):
type='Normal', layer='Linear', std=0.01, bias=0.),
use_num_classes=True):
super(ClsHead, self).__init__()
self.with_avg_pool = with_avg_pool
@ -46,7 +47,8 @@ class ClsHead(nn.Module):
'label_smooth must be given as a float number in [0,1]'
logger.info(f'=> Augment: using label smooth={self.label_smooth}')
loss_config['label_smooth'] = label_smooth
loss_config['num_classes'] = num_classes
if use_num_classes:
loss_config['num_classes'] = num_classes
self.criterion = build_from_cfg(loss_config, LOSSES)

View File

@ -115,6 +115,7 @@ def binary_cross_entropy(pred,
class_weight=None,
ignore_index=-100,
avg_non_ignore=False,
label_ceil=False,
**kwargs):
"""Calculate the binary CrossEntropy loss.
@ -132,11 +133,14 @@ def binary_cross_entropy(pred,
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
`New in version 0.23.0.`
label_ceil (bool): When use bce and set label_ceil=True,
it will make elements belong to (0, 1] in label change to 1.
Default: False.
Returns:
torch.Tensor: The calculated loss
"""
if len(pred.shape) > 1 and pred.shape(1) == 1:
if len(pred.shape) > 1 and pred.shape[1] == 1:
# For binary class segmentation, the shape of pred is
# [N, 1, H, W] and that of label is [N, H, W].
# As the ignore_index often set as 255, so the
@ -162,6 +166,8 @@ def binary_cross_entropy(pred,
weight = weight * valid_mask
else:
weight = valid_mask
if label_ceil:
label = label.gt(0.0).type(label.dtype)
# average loss over non-ignored and valid elements
if reduction == 'mean' and avg_factor is None and avg_non_ignore:
avg_factor = valid_mask.sum().item()
@ -234,6 +240,9 @@ class CrossEntropyLoss(nn.Module):
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
`New in version 0.23.0.`
label_ceil (bool): When use bce and set label_ceil=True,
it will make elements belong to (0, 1] in label change to 1.
Default: False.
"""
def __init__(self,
@ -243,10 +252,16 @@ class CrossEntropyLoss(nn.Module):
class_weight=None,
loss_weight=1.0,
loss_name='loss_ce',
avg_non_ignore=False):
avg_non_ignore=False,
label_ceil=False):
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
if label_ceil:
if not use_sigmoid:
raise ValueError(
'label_ceil is supported only when use_sigmoid is true. If not use bce, please set label_ceil=False'
)
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
@ -266,6 +281,7 @@ class CrossEntropyLoss(nn.Module):
else:
self.cls_criterion = cross_entropy
self._loss_name = loss_name
self.label_ceil = label_ceil
def extra_repr(self):
"""Extra repr."""
@ -289,16 +305,29 @@ class CrossEntropyLoss(nn.Module):
else:
class_weight = None
# Note: for BCE loss, label < 0 is invalid.
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
weight,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
avg_non_ignore=self.avg_non_ignore,
ignore_index=ignore_index,
**kwargs)
if self.use_sigmoid:
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
weight,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
avg_non_ignore=self.avg_non_ignore,
ignore_index=ignore_index,
label_ceil=self.label_ceil,
**kwargs)
else:
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
weight,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
avg_non_ignore=self.avg_non_ignore,
ignore_index=ignore_index,
**kwargs)
return loss_cls
@property

View File

@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import numpy as np
import torch
from numpy.testing import assert_array_almost_equal
class DeiTIIITest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
@unittest.skip('skip DeiT III unittest')
def test_deitiii(self):
model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/deitiii/epoch_800.pth'
config_path = 'configs/classification/imagenet/vit/imagenet_deitiii_large_patch16_192_jpg.py'
img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/deitiii_demo.JPEG'
# deitiii = ClsPredictor(model_path, config_path)
deitiii = []
output = deitiii.predict(img)
self.assertIn('prob', output)
self.assertIn('class', output)
self.assertEqual(len(output['prob'][0]), 1000)
assert_array_almost_equal(
output['prob'][0][:10],
torch.Tensor([
2.04629918698628899e-06, 5.27398606209317222e-06,
5.52915162188583054e-06, 3.60625563189387321e-06,
3.29447357216849923e-06, 5.61309570912271738e-06,
8.93703327164985240e-06, 4.89157764604897238e-06,
4.39371024185675196e-06, 5.21611764270346612e-06
]),
decimal=8)
self.assertEqual(int(output['class']), 948)
if __name__ == '__main__':
unittest.main()

View File

@ -273,8 +273,9 @@ def main():
drop_last=getattr(cfg.data, 'drop_last', False),
reuse_worker_cache=cfg.data.get('reuse_worker_cache', False),
persistent_workers=cfg.data.get('persistent_workers', False),
collate_hooks=cfg.data.get('train_collate_hooks', []))
for ds in datasets
collate_hooks=cfg.data.get('train_collate_hooks', []),
use_repeated_augment_sampler=cfg.data.get(
'use_repeated_augment_sampler', False)) for ds in datasets
]
else:
default_args = dict(