# EasyCV/easycv/models/detection/detectors/detection.py
# Copyright (c) Alibaba, Inc. and its affiliates.
from easycv.models.base import BaseModel
from easycv.models.builder import (MODELS, build_backbone, build_head,
                                   build_neck)
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger, print_log


@MODELS.register_module
class Detection(BaseModel):
    """Detector that assembles a backbone, an optional neck and a head
    built from registry configs via ``build_backbone``/``build_neck``/
    ``build_head``."""

    def __init__(self, backbone, head=None, neck=None, pretrained=True):
        super(Detection, self).__init__()
        self.pretrained = pretrained
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)
        self.head = build_head(head)
        self.init_weights()

    @property
    def with_neck(self):
        """bool: whether the detector has a neck"""
        return hasattr(self, 'neck') and self.neck is not None

    def init_weights(self):
        logger = get_root_logger()
        if isinstance(self.pretrained, str):
            # ``pretrained`` is an explicit checkpoint path/url.
            load_checkpoint(
                self.backbone, self.pretrained, strict=False, logger=logger)
        elif self.pretrained:
            # ``pretrained`` is truthy: fall back to the backbone's own
            # pretrained weights.
            if self.backbone.__class__.__name__ == 'PytorchImageModelWrapper':
                self.backbone.init_weights(pretrained=self.pretrained)
            elif hasattr(self.backbone, 'default_pretrained_model_path'
                         ) and self.backbone.default_pretrained_model_path:
                print_log(
                    'load model from default path: {}'.format(
                        self.backbone.default_pretrained_model_path), logger)
                load_checkpoint(
                    self.backbone,
                    self.backbone.default_pretrained_model_path,
                    strict=False,
                    logger=logger)
            else:
                raise ValueError(
                    'default_pretrained_model_path for {} not found'.format(
                        self.backbone.__class__.__name__))
        else:
            # ``pretrained`` is falsy: random initialization.
            print_log('load model from init weights', logger)
            self.backbone.init_weights()

        if self.with_neck:
            self.neck.init_weights()
        self.head.init_weights()
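
    # ``pretrained`` accepts three forms, handled above in order (the config
    # names here are illustrative placeholders, not verified EasyCV configs):
    #
    #   Detection(backbone_cfg, head=head_cfg, pretrained='ckpts/backbone.pth')
    #   Detection(backbone_cfg, head=head_cfg, pretrained=True)   # default weights
    #   Detection(backbone_cfg, head=head_cfg, pretrained=False)  # random init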

    def extract_feat(self, img):
        """Directly extract features from the backbone + neck."""
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x
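
    # With a typical FPN-style neck, ``extract_feat`` returns a tuple of
    # multi-scale feature maps; without a neck it returns the raw backbone
    # output unchanged.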

    def forward_train(self, imgs, img_metas, gt_bboxes, gt_labels):
        """
        Args:
            imgs (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dicts where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item is the ground-truth boxes for
                one image, in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # NOTE the batched image size information may be useful, e.g.
        # in DETR, this is needed for the construction of masks, which is
        # then used for the transformer_head.
        batch_input_shape = tuple(imgs[0].size()[-2:])
        for i in range(len(img_metas)):
            img_metas[i]['batch_input_shape'] = batch_input_shape

        x = self.extract_feat(imgs)
        losses = self.head.forward_train(x, img_metas, gt_bboxes, gt_labels)
        return losses
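
    # The returned loss dict is typically reduced by the training loop,
    # e.g. total = sum(v for k, v in losses.items() if 'loss' in k); the
    # exact reduction depends on the trainer, not on this class.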

    def forward_test(self, imgs, img_metas):
        """
        Args:
            imgs (list[Tensor]): The outer list indicates test-time
                augmentations; each inner Tensor has shape (N, C, H, W) and
                contains all images in the batch.
            img_metas (list[list[dict]]): The outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
        """
        # NOTE the batched image size information may be useful, e.g.
        # in DETR, this is needed for the construction of masks, which is
        # then used for the transformer_head.
        for img, img_meta in zip(imgs, img_metas):
            batch_size = len(img_meta)
            for img_id in range(batch_size):
                img_meta[img_id]['batch_input_shape'] = tuple(img.size()[-2:])

        # Only the first test-time augmentation is actually forwarded; the
        # remaining augmentations are not aggregated here.
        x = self.extract_feat(imgs[0])
        results = self.head.forward_test(x, img_metas[0])
        return results
def forward(self, img, mode='train', **kwargs):
if mode == 'train':
return self.forward_train(img, **kwargs)
elif mode == 'test':
return self.forward_test(img, **kwargs)
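
# Usage sketch: minimal train/test calls through ``forward``. The config dicts
# are illustrative placeholders for types registered in the MODELS registry,
# not verified EasyCV configs.
#
#   model = Detection(
#       backbone=dict(type='ResNet', depth=50),
#       neck=dict(type='FPN', in_channels=[256, 512, 1024, 2048],
#                 out_channels=256, num_outs=5),
#       head=head_cfg,  # a registered detection head config
#       pretrained=True)
#   losses = model(imgs, mode='train', img_metas=img_metas,
#                  gt_bboxes=gt_bboxes, gt_labels=gt_labels)
#   results = model([imgs], mode='test', img_metas=[img_metas])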