mirror of https://github.com/alibaba/EasyCV.git
122 lines
4.0 KiB
Python
122 lines
4.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# Adapted from https://github.com/open-mmlab/mmpose/blob/master/mmpose/models/heads/topdown_heatmap_base_head.py
|
|
|
|
from abc import ABCMeta, abstractmethod
|
|
|
|
import numpy as np
|
|
import torch.nn as nn
|
|
|
|
from easycv.core.evaluation.top_down_eval import keypoints_from_heatmaps
|
|
|
|
|
|
class TopdownHeatmapBaseHead(nn.Module):
    """Base class for top-down heatmap heads.

    All top-down heatmap heads should subclass it.
    Every subclass should override:

    - ``get_loss``: compute the loss.
    - ``get_accuracy``: compute the accuracy.
    - ``forward``: run the forward pass.
    - ``inference_model``: run inference.

    ``decode`` reads ``self.test_cfg`` (a mapping with a ``get`` method),
    which this base class does not define; subclasses are expected to set
    it before ``decode`` is called.
    """

    # NOTE(review): ``__metaclass__`` is the Python 2 spelling and has no
    # effect on Python 3, so the ``@abstractmethod`` markers below are not
    # enforced at instantiation time. Left unchanged on purpose: switching
    # to ``metaclass=ABCMeta`` would make partially implemented subclasses
    # fail to instantiate.
    __metaclass__ = ABCMeta

    @abstractmethod
    def get_loss(self, **kwargs):
        """Gets the loss."""

    @abstractmethod
    def get_accuracy(self, **kwargs):
        """Gets the accuracy."""

    @abstractmethod
    def forward(self, **kwargs):
        """Forward function."""

    @abstractmethod
    def inference_model(self, **kwargs):
        """Inference function."""

    @staticmethod
    def _collect_bbox_meta(img_metas):
        """Gather per-sample bbox metadata into batch-level containers.

        Args:
            img_metas (list(dict)): One dict per sample; see ``decode`` for
                the keys that are read.

        Returns:
            tuple: ``(centers, scales, image_ids, scores, bbox_ids)`` where
            ``centers`` and ``scales`` are float32 arrays of shape (N, 2),
            ``image_ids`` is a list of length N, ``scores`` is an array of
            shape (N,) defaulting to 1 for samples without ``bbox_score``,
            and ``bbox_ids`` is a list of length N, or None when the first
            sample has no ``bbox_id`` key.

        Raises:
            ValueError: If a ``bbox_score`` does not hold exactly one value.
        """
        batch_size = len(img_metas)
        centers = np.zeros((batch_size, 2), dtype=np.float32)
        scales = np.zeros((batch_size, 2), dtype=np.float32)
        scores = np.ones(batch_size)
        image_ids = []
        # Whether bbox ids are collected is decided by the first sample only.
        bbox_ids = [] if 'bbox_id' in img_metas[0] else None

        for i, meta in enumerate(img_metas):
            centers[i, :] = meta['center']
            scales[i, :] = meta['scale']
            image_ids.append(meta['image_id'])

            if 'bbox_score' in meta:
                # ``item()`` extracts the single value explicitly; assigning
                # a 1-element array to a scalar slot relies on an implicit
                # conversion that NumPy >= 1.25 deprecates.
                scores[i] = np.asarray(meta['bbox_score']).item()
            if bbox_ids is not None:
                bbox_ids.append(meta['bbox_id'])

        return centers, scales, image_ids, scores, bbox_ids

    def decode(self, img_metas, output, **kwargs):
        """Decode keypoints from heatmaps.

        Args:
            img_metas (list(dict)): Information about data augmentation,
                one dict per sample. The keys read here are:

                - "center": center of the bbox
                - "scale": scale of the bbox
                - "image_id": identifier of the source image
                - "bbox_score" (optional): score of the bbox, a single
                  value; defaults to 1 when absent
                - "bbox_id" (optional): id of the bbox; collected only if
                  the first sample provides it
            output (np.ndarray[N, K, H, W]): model predicted heatmaps.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            dict: With the following entries:

            - "preds" (np.ndarray[N, K, 3]): decoded x, y and the value
              returned as ``maxvals`` by ``keypoints_from_heatmaps``
            - "boxes" (np.ndarray[N, 6]): center (2), scale (2),
              ``prod(scale * 200)`` and bbox score
            - "image_ids" (list): the "image_id" of every sample
            - "bbox_ids" (list | None): the "bbox_id" of every sample, or
              None if the first sample has no "bbox_id"
        """
        batch_size = len(img_metas)
        c, s, image_ids, score, bbox_ids = self._collect_bbox_meta(img_metas)

        preds, maxvals = keypoints_from_heatmaps(
            output,
            c,
            s,
            unbiased=self.test_cfg.get('unbiased_decoding', False),
            post_process=self.test_cfg.get('post_process', 'default'),
            kernel=self.test_cfg.get('modulate_kernel', 11),
            valid_radius_factor=self.test_cfg.get('valid_radius_factor',
                                                  0.0546875),
            use_udp=self.test_cfg.get('use_udp', False),
            target_type=self.test_cfg.get('target_type', 'GaussianHeatmap'))

        all_preds = np.zeros((batch_size, preds.shape[1], 3), dtype=np.float32)
        all_boxes = np.zeros((batch_size, 6), dtype=np.float32)
        all_preds[:, :, 0:2] = preds[:, :, 0:2]
        all_preds[:, :, 2:3] = maxvals
        all_boxes[:, 0:2] = c[:, 0:2]
        all_boxes[:, 2:4] = s[:, 0:2]
        # NOTE(review): presumably "scale" is normalised by 200 px as in
        # mmpose, making this the bbox area in pixels -- confirm.
        all_boxes[:, 4] = np.prod(s * 200.0, axis=1)
        all_boxes[:, 5] = score

        result = {}

        result['preds'] = all_preds
        result['boxes'] = all_boxes
        result['image_ids'] = image_ids
        result['bbox_ids'] = bbox_ids

        return result

    @staticmethod
    def _get_deconv_cfg(deconv_kernel):
        """Get configurations for deconv layers.

        Args:
            deconv_kernel (int): Kernel size; 2, 3 and 4 are supported.

        Returns:
            tuple: ``(deconv_kernel, padding, output_padding)``.

        Raises:
            ValueError: If ``deconv_kernel`` is not 2, 3 or 4.
        """
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

        return deconv_kernel, padding, output_padding
|