# EasyCV/easycv/models/classification/classification.py

# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict, List, Optional
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import get_dist_info
from timm.data.mixup import Mixup
from easycv.framework.errors import KeyError, NotImplementedError, ValueError
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger, print_log
from easycv.utils.preprocess_function import (bninceptionPre, gaussianBlur,
mixUpCls, randomErasing)
from .. import builder
from ..base import BaseModel
from ..registry import MODELS
from ..utils import Sobel
@MODELS.register_module
class Classification(BaseModel):
"""
Args:
pretrained: Select one {str or True or False/None}.
if pretrained == str, load model from specified path;
if pretrained == True, load model from default path(currently only supports timm);
if pretrained == False or None, load from init weights.
"""
def __init__(self,
backbone,
train_preprocess=[],
with_sobel=False,
head=None,
neck=None,
pretrained=True,
mixup_cfg=None):
super(Classification, self).__init__()
self.with_sobel = with_sobel
self.pretrained = pretrained
if with_sobel:
self.sobel_layer = Sobel()
else:
self.sobel_layer = None
self.preprocess_key_map = {
'bninceptionPre': bninceptionPre,
'gaussianBlur': gaussianBlur,
'mixUpCls': mixUpCls,
'randomErasing': randomErasing
}
if 'mixUp' in train_preprocess:
rank, _ = get_dist_info()
np.random.seed(rank + 12)
if mixup_cfg is not None:
if 'num_classes' in mixup_cfg:
self.mixup = Mixup(**mixup_cfg)
elif 'num_classes' in head or 'num_classes' in backbone:
num_classes = head.get(
'num_classes'
) if 'num_classes' in head else backbone.get('num_classes')
mixup_cfg['num_classes'] = num_classes
self.mixup = Mixup(**mixup_cfg)
train_preprocess.remove('mixUp')
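# A typical mixup_cfg sketch (keyword names follow timm.data.mixup.Mixup; the
# values are illustrative assumptions):
#     mixup_cfg = dict(mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0,
#                      switch_prob=0.5, mode='batch', label_smoothing=0.1,
#                      num_classes=1000)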
self.train_preprocess = [
self.preprocess_key_map[i] for i in train_preprocess
]
self.backbone = builder.build_backbone(backbone)
assert head is not None, 'Classification head should be configured'
if isinstance(head, list):
self.head_num = len(head)
tmp_head_list = [builder.build_head(h) for h in head]
else:
self.head_num = 1
tmp_head_list = [builder.build_head(head)]
# register each head via setattr so it is tracked as a submodule of this nn.Module
for idx, h in enumerate(tmp_head_list):
setattr(self, 'head_%d' % idx, h)
if isinstance(neck, list):
self.neck_num = len(neck)
tmp_neck_list = [builder.build_neck(n) for n in neck]
elif neck is not None:
self.neck_num = 1
tmp_neck_list = [builder.build_neck(neck)]
else:
self.neck_num = 0
tmp_neck_list = []
# register each neck via setattr so it is tracked as a submodule of this nn.Module
for idx, n in enumerate(tmp_neck_list):
setattr(self, 'neck_%d' % idx, n)
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.activate_fn = nn.Softmax(dim=1)
self.extract_list = ['neck']
self.init_weights()
def init_weights(self):
logger = get_root_logger()
if isinstance(self.pretrained, str):
load_checkpoint(
self.backbone, self.pretrained, strict=False, logger=logger)
elif self.pretrained:
if self.backbone.__class__.__name__ == 'PytorchImageModelWrapper':
self.backbone.init_weights(pretrained=self.pretrained)
elif hasattr(self.backbone, 'default_pretrained_model_path'
) and self.backbone.default_pretrained_model_path:
print_log(
'load model from default path: {}'.format(
self.backbone.default_pretrained_model_path), logger)
load_checkpoint(
self.backbone,
self.backbone.default_pretrained_model_path,
strict=False,
logger=logger,
revise_keys=[
(r'^backbone\.', '')
]) # strip the 'backbone.' prefix so checkpoint keys match the backbone's state dict
else:
raise ValueError(
'default_pretrained_model_path for {} not found'.format(
self.backbone.__class__.__name__))
else:
print_log('load model from init weights')
self.backbone.init_weights()
for idx in range(self.head_num):
h = getattr(self, 'head_%d' % idx)
h.init_weights()
for idx in range(self.neck_num):
n = getattr(self, 'neck_%d' % idx)
n.init_weights()
def forward_backbone(self, img: torch.Tensor) -> List[torch.Tensor]:
"""Forward backbone
Returns:
x (tuple): backbone outputs
"""
if self.sobel_layer is not None:
img = self.sobel_layer(img)
x = self.backbone(img)
return x
def forward_onnx(self, img: torch.Tensor) -> torch.Tensor:
"""Generate class probabilities from an image for ONNX export.
Only a single neck + single head configuration is supported.
"""
x = self.forward_backbone(img) # tuple
# if self.neck_num > 0:
if hasattr(self, 'neck_0'):
x = self.neck_0([i for i in x])
out = self.head_0(x)[0].cpu()
out = self.activate_fn(out)
return out
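# Assumed I/O contract (not checked here): img is an (N, C, H, W) float tensor and
# the returned tensor holds softmax probabilities of shape (N, num_classes).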
@torch.jit.unused
def forward_train(self, img, gt_labels) -> Dict[str, torch.Tensor]:
"""
In forward train, model will forward backbone + neck / multi-neck, get alist of output tensor,
and put this list to head / multi-head, to compute each loss
"""
# for mxk sampler, use datasource type = ClsSourceImageListByClass will sample k img in 1 class,
# input data will be m x k x c x h x w, should be reshape to (m x k) x c x h x w
if img.dim() == 5:
new_shape = [
img.shape[0] * img.shape[1], img.shape[2], img.shape[3],
img.shape[4]
]
img = img.view(new_shape)
gt_labels = gt_labels.view([-1])
for preprocess in self.train_preprocess:
img = preprocess(img)
# When the number of samples in the dataset is odd, the last batch of each epoch has an
# odd batch size, which makes mixup raise an error. To avoid this, mixup is applied only
# when the batch size is even.
if hasattr(self, 'mixup') and len(img) % 2 == 0:
img, gt_labels = self.mixup(img, gt_labels)
x = self.forward_backbone(img)
if self.neck_num > 0:
tmp = []
for idx in range(self.neck_num):
n = getattr(self, 'neck_%d' % idx)
tmp += n(x)
x = tmp
losses = {}
for idx in range(self.head_num):
h = getattr(self, 'head_%d' % idx)
outs = h(x)
loss_inputs = (outs, gt_labels)
hlosses = h.loss(*loss_inputs)
if 'loss' in losses.keys():
losses['loss'] += hlosses['loss']
else:
losses['loss'] = hlosses['loss']
return losses
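# Note: all heads' losses are summed into the single 'loss' entry returned above;
# how that entry is consumed (e.g. by the training loop) is outside this file.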
# @torch.jit.unused
def forward_test(self, img: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
forward_test means generate prob/class from image only support one neck + one head
"""
x = self.forward_backbone(img) # tuple
# if self.neck_num > 0:
if hasattr(self, 'neck_0'):
x = self.neck_0([i for i in x])
out = self.head_0(x)[0].cpu()
result = {}
result['prob'] = self.activate_fn(out)
result['class'] = torch.argmax(result['prob'])
return result
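# Note on the dict built above: 'prob' holds softmax scores and 'class' is the argmax
# over the flattened prob tensor, which matches the per-sample prediction only for a
# single-image batch.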
@torch.jit.unused
def forward_test_label(self, img, gt_labels) -> Dict[str, torch.Tensor]:
"""
forward_test_label means generate prob/class from image only support one neck + one head
ps : head init need set the input feature idx
"""
x = self.forward_backbone(img) # tuple
if hasattr(self, 'neck_0'):
x = self.neck_0([i for i in x])
out = [self.head_0(x)[0].cpu()]
keys = ['neck', 'gt_labels']
out.append(gt_labels.cpu())
return dict(zip(keys, out))
def aug_test(self, imgs):
raise NotImplementedError
def forward_feature(self, img) -> Dict[str, torch.Tensor]:
"""Forward feature means forward backbone + neck/multineck ,get dict of output feature,
self.neck_num = 0: means only forward backbone, output backbone feature with avgpool, with key neck,
self.neck_num > 0: means has 1/multi neck, output neck's feature with key neck_neckidx_featureidx, suck as neck_0_0
Returns:
x (torch.Tensor): feature tensor
"""
return_dict = {}
x = self.backbone(img)
# return_dict['backbone'] = x[-1]
if hasattr(self, 'neck_0'):
tmp = []
for idx in range(self.neck_num):
neck_name = 'neck_%d' % idx
h = getattr(self, neck_name)
neck_h = h([i for i in x])
tmp = tmp + neck_h
for j in range(len(neck_h)):
neck_name = 'neck_%d_%d' % (idx, j)
return_dict['neck_%d_%d' % (idx, j)] = neck_h[j]
if neck_name not in self.extract_list:
self.extract_list.append(neck_name)
return_dict['neck'] = tmp[0]
else:
feature = self.avg_pool(x[-1])
feature = feature.view(feature.size(0), -1)
return_dict['neck'] = feature
return return_dict
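# Example of the returned mapping for a single neck with a single output (keys follow
# the code above; shapes are illustrative assumptions):
#     {'neck': Tensor(N, D), 'neck_0_0': Tensor(N, D)}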
def update_extract_list(self, key):
if key not in self.extract_list:
self.extract_list.append(key)
return
def forward(
self,
img: torch.Tensor,
mode: str = 'train',
gt_labels: Optional[torch.Tensor] = None,
img_metas: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
# TODO: support Dict[any, any] type of img_metas
del img_metas # img_metas is a placeholder argument kept to support jit
if mode == 'train':
assert gt_labels is not None
return self.forward_train(img, gt_labels)
elif mode == 'test':
if gt_labels is not None:
return self.forward_test_label(img, gt_labels)
else:
return self.forward_test(img)
elif mode == 'onnx':
return self.forward_onnx(img)
elif mode == 'extract':
rd = self.forward_feature(img)
rv = {}
for name in self.extract_list:
if name in rd.keys():
rv[name] = rd[name].cpu()
else:
raise ValueError(
'Extract {} is not supported in classification models'.format(
name))
if gt_labels is not None:
rv['gt_labels'] = gt_labels.cpu()
return rv
else:
raise KeyError('No such mode: {}'.format(mode))
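# Usage sketch for the mode dispatch above (illustrative; `model`, `img` and `labels`
# are assumed to be constructed elsewhere):
#     losses = model(img, mode='train', gt_labels=labels)  # dict with key 'loss'
#     result = model(img, mode='test')                      # {'prob': ..., 'class': ...}
#     probs = model(img, mode='onnx')                       # softmax probabilities
#     feats = model(img, mode='extract')                    # features named in extract_list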