# File: EasyCV/easycv/models/pose/top_down.py
#
# (listing metadata: 239 lines, 8.4 KiB, Python)

# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://raw.githubusercontent.com/open-mmlab/mmpose/master/mmpose/models/detectors/top_down.py
import warnings
import mmcv
import numpy as np
from mmcv.image import imwrite
from mmcv.utils.misc import deprecated_api_warning
from mmcv.visualization.image import imshow
from easycv.core.visualization import imshow_bboxes, imshow_keypoints
from easycv.models import builder
from easycv.models.base import BaseModel
from easycv.models.builder import MODELS
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
@MODELS.register_module()
class TopDown(BaseModel):
    """Top-down pose detectors.

    Args:
        backbone (dict): Backbone modules to extract feature.
        neck (dict, optional): Neck module applied to the backbone output
            before the keypoint head. Default: None.
        keypoint_head (dict): Keypoint head to process feature.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path to the pretrained models.
        loss_pose (None): Deprecated arguments. Please use
            `loss_keypoint` for heads instead.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 keypoint_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 loss_pose=None):
        super().__init__()

        self.pretrained = pretrained
        self.backbone = builder.build_backbone(backbone)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # `neck` / `keypoint_head` attributes are only created when they are
        # configured; `with_neck` / `with_keypoint` rely on their presence.
        if neck is not None:
            self.neck = builder.build_neck(neck)

        if keypoint_head is not None:
            # The head receives the same train/test configs as the detector.
            keypoint_head['train_cfg'] = train_cfg
            keypoint_head['test_cfg'] = test_cfg

            if 'loss_keypoint' not in keypoint_head and loss_pose is not None:
                warnings.warn(
                    '`loss_pose` for TopDown is deprecated, '
                    'use `loss_keypoint` for heads instead. See '
                    'https://github.com/open-mmlab/mmpose/pull/382'
                    ' for more information.', DeprecationWarning)
                keypoint_head['loss_keypoint'] = loss_pose

            self.keypoint_head = builder.build_head(keypoint_head)

        self.init_weights()

    @property
    def with_neck(self):
        """Check if has neck."""
        return hasattr(self, 'neck')

    @property
    def with_keypoint(self):
        """Check if has keypoint_head."""
        return hasattr(self, 'keypoint_head')

    def init_weights(self):
        """Weight initialization for model.

        When ``self.pretrained`` is a string, the backbone is loaded from
        that checkpoint (non-strict); otherwise the backbone runs its own
        ``init_weights``. The neck and keypoint head, if present, always run
        their own ``init_weights``.
        """
        if isinstance(self.pretrained, str):
            logger = get_root_logger()
            load_checkpoint(
                self.backbone, self.pretrained, strict=False, logger=logger)
        else:
            self.backbone.init_weights()

        if self.with_neck:
            self.neck.init_weights()
        if self.with_keypoint:
            self.keypoint_head.init_weights()

    def forward_train(self, img, target, target_weight, img_metas, **kwargs):
        """Defines the computation performed at every call when training.

        Args:
            img: Input images, fed to the backbone.
            target: Training target passed to the head's loss and accuracy.
            target_weight: Target weights passed to the head's loss and
                accuracy.
            img_metas: Image meta information (unused here).

        Returns:
            dict: Losses and accuracy reported by the keypoint head. Empty
                if the model has no keypoint head.
        """
        output = self.backbone(img)
        if self.with_neck:
            output = self.neck(output)
        if self.with_keypoint:
            output = self.keypoint_head(output)

        # if return loss
        losses = dict()
        if self.with_keypoint:
            keypoint_losses = self.keypoint_head.get_loss(
                output, target, target_weight)
            losses.update(keypoint_losses)
            keypoint_accuracy = self.keypoint_head.get_accuracy(
                output, target, target_weight)
            losses.update(keypoint_accuracy)

        return losses

    def forward_test(self, img, img_metas, return_heatmap=False, **kwargs):
        """Defines the computation performed at every call when testing.

        Args:
            img: Batch of input images; unpacked as
                (batch, channels, height, width).
            img_metas (list[dict]): One meta dict per image. Must contain
                'bbox_id' when the batch size is larger than 1, and
                'flip_pairs' (read from the first entry) when flip testing
                is enabled.
            return_heatmap (bool): Whether to include the predicted heatmap
                in the result under 'output_heatmap'. Default: False.

        Returns:
            dict: The decoded keypoint result of the head, plus
                'output_heatmap' when ``return_heatmap`` is True.
        """
        assert img.size(0) == len(img_metas)
        batch_size, _, img_height, img_width = img.shape
        if batch_size > 1:
            assert 'bbox_id' in img_metas[0]

        result = {}

        features = self.backbone(img)
        if self.with_neck:
            features = self.neck(features)
        if self.with_keypoint:
            output_heatmap = self.keypoint_head.inference_model(
                features, flip_pairs=None)

        if self.test_cfg.get('flip_test', True):
            # Flip along the last (width) axis and average both predictions.
            img_flipped = img.flip(3)
            features_flipped = self.backbone(img_flipped)
            if self.with_neck:
                features_flipped = self.neck(features_flipped)
            if self.with_keypoint:
                output_flipped_heatmap = self.keypoint_head.inference_model(
                    features_flipped, img_metas[0]['flip_pairs'])
                output_heatmap = (output_heatmap +
                                  output_flipped_heatmap) * 0.5

        if self.with_keypoint:
            keypoint_result = self.keypoint_head.decode(
                img_metas, output_heatmap, img_size=[img_width, img_height])
            result.update(keypoint_result)

            # Compare against None explicitly: the heatmap is array-like
            # (presumably a numpy array), and the truth value of a
            # multi-element array is ambiguous and raises ValueError.
            if return_heatmap and output_heatmap is not None:
                result['output_heatmap'] = output_heatmap

        return result

    @deprecated_api_warning({'pose_limb_color': 'pose_link_color'},
                            cls_name='TopDown')
    def show_result(self,
                    img,
                    result,
                    skeleton=None,
                    kpt_score_thr=0.3,
                    bbox_color='green',
                    pose_kpt_color=None,
                    pose_link_color=None,
                    text_color='white',
                    radius=4,
                    thickness=1,
                    font_scale=0.5,
                    bbox_thickness=1,
                    win_name='',
                    show=False,
                    show_keypoint_weight=False,
                    wait_time=0,
                    out_file=None):
        """Draw `result` over `img`.

        Args:
            img (str or Tensor): The image to be displayed.
            result (list[dict]): The results to draw over `img`
                (bbox_result, pose_result).
            skeleton (list[list]): The connection of keypoints.
                skeleton is 0-based indexing.
            kpt_score_thr (float, optional): Minimum score of keypoints
                to be shown. Default: 0.3.
            bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
            pose_kpt_color (np.array[Nx3]): Color of N keypoints.
                If None, do not draw keypoints.
            pose_link_color (np.array[Mx3]): Color of M links.
                If None, do not draw links.
            text_color (str or tuple or :obj:`Color`): Color of texts.
            radius (int): Radius of circles.
            thickness (int): Thickness of lines.
            font_scale (float): Font scales of texts.
            bbox_thickness (int): Thickness of bbox lines. Default: 1.
            win_name (str): The window name.
            show (bool): Whether to show the image. Default: False.
            show_keypoint_weight (bool): Whether to change the transparency
                using the predicted confidence scores of keypoints.
                NOTE(review): accepted but not used by this method.
            wait_time (int): Value of waitKey param.
                Default: 0.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            The visualized image. It is always returned, regardless of
            `show` and `out_file`.
        """
        img = mmcv.imread(img)
        # Draw on a copy so the caller's image is left untouched.
        img = img.copy()

        bbox_result = []
        pose_result = []
        for res in result:
            if 'bbox' in res:
                bbox_result.append(res['bbox'])
            pose_result.append(res['keypoints'])

        if bbox_result:
            bboxes = np.vstack(bbox_result)
            labels = None
            if 'label' in result[0]:
                labels = [res['label'] for res in result]
            # draw bounding boxes
            imshow_bboxes(
                img,
                bboxes,
                labels=labels,
                colors=bbox_color,
                text_color=text_color,
                thickness=bbox_thickness,
                font_scale=font_scale,
                show=False)

        imshow_keypoints(img, pose_result, skeleton, kpt_score_thr,
                         pose_kpt_color, pose_link_color, radius, thickness)

        if show:
            imshow(img, win_name, wait_time)

        if out_file is not None:
            imwrite(img, out_file)

        return img