mirror of https://github.com/alibaba/EasyCV.git
parent c73edeee1c
commit 4cf6f794e4
@@ -0,0 +1,120 @@
_base_ = 'configs/base.py'

CLASSES = [
    'drink water', 'eat meal/snack', 'brushing teeth', 'brushing hair', 'drop',
    'pickup', 'throw', 'sitting down', 'standing up (from sitting position)',
    'clapping', 'reading', 'writing', 'tear up paper', 'wear jacket',
    'take off jacket', 'wear a shoe', 'take off a shoe', 'wear on glasses',
    'take off glasses', 'put on a hat/cap', 'take off a hat/cap', 'cheer up',
    'hand waving', 'kicking something', 'reach into pocket',
    'hopping (one foot jumping)', 'jump up', 'make a phone call/answer phone',
    'playing with phone/tablet', 'typing on a keyboard',
    'pointing to something with finger', 'taking a selfie',
    'check time (from watch)', 'rub two hands together', 'nod head/bow',
    'shake head', 'wipe face', 'salute', 'put the palms together',
    'cross hands in front (say stop)', 'sneeze/cough', 'staggering', 'falling',
    'touch head (headache)', 'touch chest (stomachache/heart pain)',
    'touch back (backache)', 'touch neck (neckache)',
    'nausea or vomiting condition',
    'use a fan (with hand or paper)/feeling warm',
    'punching/slapping other person', 'kicking other person',
    'pushing other person', 'pat on back of other person',
    'point finger at the other person', 'hugging other person',
    'giving something to other person', "touch other person's pocket",
    'handshaking', 'walking towards each other',
    'walking apart from each other'
]

model = dict(
    type='SkeletonGCN',
    backbone=dict(
        type='STGCN',
        in_channels=3,
        edge_importance_weighting=True,
        graph_cfg=dict(layout='coco', strategy='spatial')),
    cls_head=dict(
        type='STGCNHead',
        num_classes=60,
        in_channels=256,
        loss_cls=dict(type='CrossEntropyLoss')),
    train_cfg=None,
    test_cfg=None)

dataset_type = 'VideoDataset'
ann_file_train = 'data/posec3d/ntu60_xsub_train.pkl'
ann_file_val = 'data/posec3d/ntu60_xsub_val.pkl'

train_pipeline = [
    dict(type='PaddingWithLoop', clip_len=300),
    dict(type='PoseDecode'),
    dict(type='FormatGCNInput', input_format='NCTVM'),
    dict(type='PoseNormalize'),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
    dict(type='VideoToTensor', keys=['keypoint'])
]
val_pipeline = [
    dict(type='PaddingWithLoop', clip_len=300),
    dict(type='PoseDecode'),
    dict(type='FormatGCNInput', input_format='NCTVM'),
    dict(type='PoseNormalize'),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
    dict(type='VideoToTensor', keys=['keypoint'])
]
test_pipeline = [
    dict(type='PaddingWithLoop', clip_len=300),
    dict(type='PoseDecode'),
    dict(type='FormatGCNInput', input_format='NCTVM'),
    dict(type='PoseNormalize'),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
    dict(type='VideoToTensor', keys=['keypoint'])
]
data = dict(
    imgs_per_gpu=16,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_source=dict(
            type='PoseDataSourceForVideoRec',
            ann_file=ann_file_train,
            data_prefix='',
        ),
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        imgs_per_gpu=1,
        data_source=dict(
            type='PoseDataSourceForVideoRec',
            ann_file=ann_file_val,
            data_prefix='',
        ),
        pipeline=val_pipeline),
    test=dict(
        type=dataset_type,
        data_source=dict(
            type='PoseDataSourceForVideoRec',
            ann_file=ann_file_val,
            data_prefix='',
        ),
        pipeline=test_pipeline))

# optimizer
optimizer = dict(
    type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001, nesterov=True)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[10, 50])
total_epochs = 80

# eval
eval_config = dict(initial=False, interval=1, gpu_collect=True)
eval_pipelines = [
    dict(
        mode='test',
        data=data['val'],
        dist_eval=True,
        evaluators=[dict(type='ClsEvaluator', topk=(1, 5))],
    )
]

log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
checkpoint_config = dict(interval=1)
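As a quick sanity check, the config above parses like any mmcv-style config. A minimal sketch (it assumes mmcv is installed, that EasyCV's config loader resolves the `_base_` include, and that the file lives at the path the demo script below uses):

from mmcv import Config

cfg = Config.fromfile(
    'configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py')
print(cfg.model.type)    # 'SkeletonGCN'
print(len(cfg.CLASSES))  # 60 NTU-60 action names
print([t['type'] for t in cfg.train_pipeline])  # PaddingWithLoop ... VideoToTensor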
@@ -0,0 +1,217 @@
import argparse
import os
import os.path as osp
import shutil

import cv2
import mmcv
import numpy as np
import torch

from easycv.file.utils import is_url_path
from easycv.predictors.pose_predictor import PoseTopDownPredictor
from easycv.predictors.video_classifier import STGCNPredictor

try:
    import moviepy.editor as mpy
except ImportError:
    raise ImportError('Please install moviepy to enable output file')

FONTFACE = cv2.FONT_HERSHEY_DUPLEX
FONTSCALE = 0.75
FONTCOLOR = (255, 255, 255)  # BGR, white
THICKNESS = 1
LINETYPE = 1
TMP_DIR = './tmp'


def parse_args():
    parser = argparse.ArgumentParser(
        description='Video classification demo based on skeleton.')
    parser.add_argument(
        '--video',
        default=
        'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/demos/videos/ntu_sample.avi',
        help='video file/url')
    parser.add_argument(
        '--out_file',
        default=f'{TMP_DIR}/demo_show.mp4',
        help='output filename')
    parser.add_argument(
        '--config',
        default=(
            'configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py'
        ),
        help='skeleton model config file path')
    parser.add_argument(
        '--checkpoint',
        default=
        ('http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/video/skeleton_based/stgcn/stgcn_80e_ntu60_xsub.pth'
         ),
        help='skeleton model checkpoint file/url')
    parser.add_argument(
        '--det-config',
        default='configs/detection/yolox/yolox_s_8xb16_300e_coco.py',
        help='human detection config file path')
    parser.add_argument(
        '--det-checkpoint',
        default=
        ('http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_s_bs16_lr002/epoch_300.pt'
         ),
        help='human detection checkpoint file/url')
    parser.add_argument(
        '--det-predictor-type',
        default='YoloXPredictor',
        help='detection predictor type')
    parser.add_argument(
        '--pose-config',
        default='configs/pose/hrnet_w48_coco_256x192_udp.py',
        help='human pose estimation config file path')
    parser.add_argument(
        '--pose-checkpoint',
        default=
        ('http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/pose/top_down_hrnet/pose_hrnet_epoch_210_export.pt'
         ),
        help='human pose estimation checkpoint file/url')
    parser.add_argument(
        '--bbox-thr',
        type=float,
        default=0.5,
        help='the threshold of human detection score')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument(
        '--short-side',
        type=int,
        default=480,
        help='specify the short-side length of the image')
    args = parser.parse_args()
    return args


def frame_extraction(video_path, short_side):
    """Extract frames given video_path.

    Args:
        video_path (str): The video path.
        short_side (int): Target length of the short side after rescaling.
    """
    if is_url_path(video_path):
        from torch.hub import download_url_to_file
        cache_video_path = os.path.join(TMP_DIR, os.path.basename(video_path))
        print(
            f'Download video file from remote to local path "{cache_video_path}"...'
        )
        download_url_to_file(video_path, cache_video_path)
        video_path = cache_video_path

    # Load the video, extract frames into ./tmp/video_name
    target_dir = osp.join(TMP_DIR, osp.basename(osp.splitext(video_path)[0]))
    os.makedirs(target_dir, exist_ok=True)
    # Should be able to handle videos up to several hours
    frame_tmpl = osp.join(target_dir, 'img_{:06d}.jpg')
    vid = cv2.VideoCapture(video_path)
    frames = []
    frame_paths = []
    flag, frame = vid.read()
    cnt = 0
    new_h, new_w = None, None
    while flag:
        if new_h is None:
            h, w, _ = frame.shape
            new_w, new_h = mmcv.rescale_size((w, h), (short_side, np.Inf))
        frame = mmcv.imresize(frame, (new_w, new_h))
        frames.append(frame)
        frame_path = frame_tmpl.format(cnt + 1)
        frame_paths.append(frame_path)

        cv2.imwrite(frame_path, frame)
        cnt += 1
        flag, frame = vid.read()

    return frame_paths, frames


def main():
    args = parse_args()

    if not osp.exists(TMP_DIR):
        os.makedirs(TMP_DIR)

    frame_paths, original_frames = frame_extraction(args.video,
                                                    args.short_side)
    num_frame = len(frame_paths)
    h, w, _ = original_frames[0].shape

    # Get human detection results
    pose_predictor = PoseTopDownPredictor(
        model_path=args.pose_checkpoint,
        config_file=args.pose_config,
        detection_predictor_config=dict(
            type=args.det_predictor_type,
            model_path=args.det_checkpoint,
            config_file=args.det_config,
        ),
        bbox_thr=args.bbox_thr,
        cat_id=0,  # person category id
    )

    video_cls_predictor = STGCNPredictor(
        model_path=args.checkpoint,
        config_file=args.config,
        ori_image_size=(w, h),
        label_map=None)

    pose_results = pose_predictor(original_frames)

    torch.cuda.empty_cache()

    fake_anno = dict(
        frame_dir='',
        label=-1,
        img_shape=(h, w),
        original_shape=(h, w),
        start_index=0,
        modality='Pose',
        total_frames=num_frame)
    num_person = max([len(x) for x in pose_results])

    num_keypoint = 17
    keypoints = np.zeros((num_person, num_frame, num_keypoint, 2),
                         dtype=np.float16)
    keypoints_score = np.zeros((num_person, num_frame, num_keypoint),
                               dtype=np.float16)
    for i, poses in enumerate(pose_results):
        if len(poses) < 1:
            continue
        _keypoint = poses['keypoints']  # shape = (num_person, num_keypoint, 3)
        for j, pose in enumerate(_keypoint):
            keypoints[j, i] = pose[:, :2]
            keypoints_score[j, i] = pose[:, 2]

    fake_anno['keypoint'] = keypoints
    fake_anno['keypoint_score'] = keypoints_score

    results = video_cls_predictor([fake_anno])

    action_label = results[0]['class_name'][0]
    print(f'action label: {action_label}')

    vis_frames = [
        pose_predictor.show_result(original_frames[i], pose_results[i])
        if len(pose_results[i]) > 0 else original_frames[i]
        for i in range(num_frame)
    ]
    for frame in vis_frames:
        cv2.putText(frame, action_label, (10, 30), FONTFACE, FONTSCALE,
                    FONTCOLOR, THICKNESS, LINETYPE)

    vid = mpy.ImageSequenceClip([x[:, :, ::-1] for x in vis_frames], fps=24)
    vid.write_videofile(args.out_file, remove_temp=True)
    print(f'Write video to {args.out_file} successfully!')

    tmp_frame_dir = osp.dirname(frame_paths[0])
    shutil.rmtree(tmp_frame_dir)


if __name__ == '__main__':
    main()
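The packing loop in main() above flattens per-frame pose results into fixed-shape arrays. The same idea in isolation, with toy numbers (a hypothetical stand-in for pose_results, which in the real script is a dict-like result per frame):

import numpy as np

# Toy per-frame poses: (num_person, 17, 3) arrays of (x, y, score);
# person count may vary per frame.
num_frame, num_keypoint = 3, 17
pose_results = [np.random.rand(2, num_keypoint, 3),
                np.random.rand(1, num_keypoint, 3),
                np.random.rand(2, num_keypoint, 3)]

num_person = max(len(p) for p in pose_results)
keypoints = np.zeros((num_person, num_frame, num_keypoint, 2), dtype=np.float16)
scores = np.zeros((num_person, num_frame, num_keypoint), dtype=np.float16)
for i, poses in enumerate(pose_results):
    for j, pose in enumerate(poses):
        keypoints[j, i] = pose[:, :2]  # x, y
        scores[j, i] = pose[:, 2]      # confidence
print(keypoints.shape)  # (2, 3, 17, 2) = (M, T, V, C)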
@@ -50,11 +50,17 @@ class ClsEvaluator(Evaluator):
             a dict, each key is metric_name, value is metric value
         '''
         eval_res = OrderedDict()
 
+        if isinstance(gt_labels, dict):
+            assert len(gt_labels) == 1
+            gt_labels = list(gt_labels.values())[0]
+
         target = gt_labels.long()
 
         # if self.neck_num is not None:
         if self.neck_num is None:
-            predictions = {'neck': predictions['neck']}
+            if len(predictions) > 1:
+                predictions = {'neck': predictions['neck']}
         else:
             predictions = {
                 'neck_%d_0' % self.neck_num:
@@ -1,3 +1,4 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+from .pose_datasource import PoseDataSourceForVideoRec
 from .video_datasource import VideoDatasource
 from .video_text_datasource import VideoTextDatasource
@@ -0,0 +1,176 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import os.path as osp
from abc import ABCMeta
from collections import defaultdict

import mmcv
import numpy as np
import torch
from torch.utils.data import Dataset

from easycv.datasets.registry import DATASOURCES
from easycv.utils.logger import get_root_logger


@DATASOURCES.register_module()
class PoseDataSourceForVideoRec(Dataset, metaclass=ABCMeta):
    """Pose data source for video recognition.
    Args:
        ann_file (str): Path to the annotation file.
        data_prefix (str | None): Path to a directory where videos are held.
            Default: None.
        multi_class (bool): Determines whether the dataset is a multi-class
            dataset. Default: False.
        num_classes (int | None): Number of classes of the dataset, used in
            multi-class datasets. Default: None.
        start_index (int): Specify a start index for frames in consideration of
            different filename format. However, when taking videos as input,
            it should be set to 0, since frames loaded from videos count
            from 0. Default: 1.
        sample_by_class (bool): Sampling by class, should be set `True` when
            performing inter-class data balancing. Only compatible with
            `multi_class == False`. Only applies for training. Default: False.
        power (float): We support sampling data with the probability
            proportional to the power of its label frequency (freq ^ power)
            when sampling data. `power == 1` indicates uniformly sampling all
            data; `power == 0` indicates uniformly sampling all classes.
            Default: 0.
        dynamic_length (bool): If the dataset length is dynamic (used by
            ClassSpecificDistributedSampler). Default: False.
    """

    def __init__(
        self,
        ann_file,
        data_prefix=None,
        multi_class=False,
        num_classes=None,
        start_index=1,
        sample_by_class=False,
        power=0,
        dynamic_length=False,
        split=None,
        valid_ratio=None,
        box_thr=None,
        class_prob=None,
    ):
        super().__init__()

        self.modality = 'Pose'
        # split, applicable to ucf or hmdb
        self.split = split

        self.ann_file = ann_file
        self.data_prefix = osp.realpath(
            data_prefix) if data_prefix is not None and osp.isdir(
                data_prefix) else data_prefix
        self.multi_class = multi_class
        self.num_classes = num_classes
        self.start_index = start_index

        self.sample_by_class = sample_by_class
        self.power = power
        self.dynamic_length = dynamic_length

        assert not (self.multi_class and self.sample_by_class)

        self.video_infos = self.load_annotations()
        if self.sample_by_class:
            self.video_infos_by_class = self.parse_by_class()

            class_prob = []
            for _, samples in self.video_infos_by_class.items():
                class_prob.append(len(samples) / len(self.video_infos))
            class_prob = [x**self.power for x in class_prob]

            summ = sum(class_prob)
            class_prob = [x / summ for x in class_prob]

            self.class_prob = dict(zip(self.video_infos_by_class, class_prob))

        # box_thr, which should be a string
        self.box_thr = box_thr
        if self.box_thr is not None:
            assert box_thr in ['0.5', '0.6', '0.7', '0.8', '0.9']

        # Thresholding training examples
        self.valid_ratio = valid_ratio
        if self.valid_ratio is not None:
            assert isinstance(self.valid_ratio, float)
            if self.box_thr is None:
                self.video_infos = [
                    x for x in self.video_infos
                    if x['valid_frames'] / x['total_frames'] >= valid_ratio
                ]
            else:
                key = f'valid@{self.box_thr}'
                self.video_infos = [
                    x for x in self.video_infos
                    if x[key] / x['total_frames'] >= valid_ratio
                ]
                if self.box_thr != '0.5':
                    box_thr = float(self.box_thr)
                    for item in self.video_infos:
                        inds = [
                            i for i, score in enumerate(item['box_score'])
                            if score >= box_thr
                        ]
                        item['anno_inds'] = np.array(inds)

        if class_prob is not None:
            self.class_prob = class_prob

        logger = get_root_logger()
        logger.info(f'{len(self)} videos remain after valid thresholding')

    def load_annotations(self):
        """Load annotation file to get video information."""
        assert self.ann_file.endswith('.pkl')
        return self.load_pkl_annotations()

    def load_pkl_annotations(self):
        data = mmcv.load(self.ann_file)

        if self.split:
            split, data = data['split'], data['annotations']
            identifier = 'filename' if 'filename' in data[0] else 'frame_dir'
            data = [x for x in data if x[identifier] in split[self.split]]

        for item in data:
            # Sometimes we may need to load anno from the file
            if 'filename' in item:
                item['filename'] = osp.join(self.data_prefix, item['filename'])
            if 'frame_dir' in item:
                item['frame_dir'] = osp.join(self.data_prefix,
                                             item['frame_dir'])
        return data

    def parse_by_class(self):
        video_infos_by_class = defaultdict(list)
        for item in self.video_infos:
            label = item['label']
            video_infos_by_class[label].append(item)
        return video_infos_by_class

    def prepare_frames(self, idx):
        """Prepare the frames for training given the index."""
        results = copy.deepcopy(self.video_infos[idx])
        results['modality'] = self.modality
        results['start_index'] = self.start_index

        # prepare tensor in getitem
        # If HVU, type(results['label']) is dict
        if self.multi_class and isinstance(results['label'], list):
            onehot = torch.zeros(self.num_classes)
            onehot[results['label']] = 1.
            results['label'] = onehot

        return results

    def __len__(self):
        """Get the size of the dataset."""
        return len(self.video_infos)

    def __getitem__(self, idx):
        return self.prepare_frames(idx)
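For reference, a minimal hand-built annotation containing the keys this data source and the pose pipeline actually read (a sketch with made-up values; with split=None, as in the config above, the pickle is just a list of such dicts):

import numpy as np
import mmcv

anno = dict(
    frame_dir='S001C001P001R001A001',  # hypothetical clip identifier
    label=0,
    img_shape=(1080, 1920),
    original_shape=(1080, 1920),
    total_frames=40,
    keypoint=np.zeros((2, 40, 17, 2), dtype=np.float16),     # M, T, V, C
    keypoint_score=np.zeros((2, 40, 17), dtype=np.float16),  # M, T, V
)
mmcv.dump([anno], 'ntu60_xsub_train.pkl')  # list of such dicts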
@@ -1,7 +1,20 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # isort:skip_file
 # yapf:disable
-from .loading import DecordInit, DecordDecode, SampleFrames
-from .transform import VideoImgaug, VideoFuse, VideoRandomScale, VideoRandomCrop, VideoRandomResizedCrop, VideoMultiScaleCrop, VideoResize, VideoRandomRescale, VideoFlip, VideoNormalize, VideoColorJitter, VideoCenterCrop, VideoThreeCrop, VideoTenCrop, VideoMultiGroupCrop
+from .loading import DecordDecode, DecordInit, SampleFrames
+from .pose_transform import (FormatGCNInput, PaddingWithLoop, PoseDecode,
+                             PoseNormalize)
 from .text_transform import TextTokenizer
-__all__ = [DecordInit, DecordDecode, SampleFrames, VideoImgaug, VideoFuse, VideoRandomScale, VideoRandomCrop, VideoRandomResizedCrop, VideoMultiScaleCrop, VideoResize, VideoRandomRescale, VideoFlip, VideoNormalize, VideoColorJitter, VideoCenterCrop, VideoThreeCrop, VideoTenCrop, VideoMultiGroupCrop, TextTokenizer]
+from .transform import (VideoCenterCrop, VideoColorJitter, VideoFlip,
+                        VideoFuse, VideoImgaug, VideoMultiGroupCrop,
+                        VideoMultiScaleCrop, VideoNormalize, VideoRandomCrop,
+                        VideoRandomRescale, VideoRandomResizedCrop,
+                        VideoRandomScale, VideoResize, VideoTenCrop,
+                        VideoThreeCrop)
+
+__all__ = [
+    'DecordInit', 'DecordDecode', 'SampleFrames', 'VideoImgaug', 'VideoFuse',
+    'VideoRandomScale', 'VideoRandomCrop', 'VideoRandomResizedCrop',
+    'VideoMultiScaleCrop', 'VideoResize', 'VideoRandomRescale', 'VideoFlip',
+    'VideoNormalize', 'VideoColorJitter', 'VideoCenterCrop', 'VideoThreeCrop',
+    'VideoTenCrop', 'VideoMultiGroupCrop', 'TextTokenizer', 'PaddingWithLoop',
+    'PoseDecode', 'PoseNormalize', 'FormatGCNInput'
+]
@@ -0,0 +1,184 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
# Refer to: https://github.com/open-mmlab/mmaction2/blob/master/mmaction/datasets/pipelines/pose_loading.py
import numpy as np

from easycv.datasets.registry import PIPELINES


@PIPELINES.register_module()
class PaddingWithLoop:
    """Sample frames from the video.

    To sample an n-frame clip from the video, PaddingWithLoop samples
    the frames from index zero, and loops the frames if the number of
    video frames is less than the value of 'clip_len'.

    Required keys are "total_frames", added or modified keys
    are "frame_inds", "clip_len", "frame_interval" and "num_clips".

    Args:
        clip_len (int): Frames of each sampled output clip.
        num_clips (int): Number of clips to be sampled. Default: 1.
    """

    def __init__(self, clip_len, num_clips=1):

        self.clip_len = clip_len
        self.num_clips = num_clips

    def __call__(self, results):
        num_frames = results['total_frames']

        start = 0
        inds = np.arange(start, start + self.clip_len)
        inds = np.mod(inds, num_frames)

        results['frame_inds'] = inds.astype(np.int64)
        results['clip_len'] = self.clip_len
        results['frame_interval'] = None
        results['num_clips'] = self.num_clips
        return results
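A quick illustration of the looping behaviour (not part of the file): with clip_len=300 and a 40-frame clip, indices wrap modulo the frame count.

import numpy as np

inds = np.mod(np.arange(300), 40)
print(inds[:43])  # [0 1 ... 39 0 1 2] -- short clips are looped up to clip_len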


@PIPELINES.register_module()
class PoseDecode:
    """Load and decode pose with given indices.

    Required keys are "keypoint", "frame_inds" (optional), "keypoint_score"
    (optional), added or modified keys are "keypoint", "keypoint_score" (if
    applicable).
    """

    @staticmethod
    def _load_kp(kp, frame_inds):
        """Load keypoints given frame indices.

        Args:
            kp (np.ndarray): The keypoint coordinates.
            frame_inds (np.ndarray): The frame indices.
        """

        return [x[frame_inds].astype(np.float32) for x in kp]

    @staticmethod
    def _load_kpscore(kpscore, frame_inds):
        """Load keypoint scores given frame indices.

        Args:
            kpscore (np.ndarray): The confidence scores of keypoints.
            frame_inds (np.ndarray): The frame indices.
        """

        return [x[frame_inds].astype(np.float32) for x in kpscore]

    def __call__(self, results):

        if 'frame_inds' not in results:
            results['frame_inds'] = np.arange(results['total_frames'])

        if results['frame_inds'].ndim != 1:
            results['frame_inds'] = np.squeeze(results['frame_inds'])

        offset = results.get('offset', 0)
        frame_inds = results['frame_inds'] + offset

        if 'keypoint_score' in results:
            kpscore = results['keypoint_score']
            results['keypoint_score'] = kpscore[:,
                                                frame_inds].astype(np.float32)

        if 'keypoint' in results:
            results['keypoint'] = results['keypoint'][:, frame_inds].astype(
                np.float32)

        return results

    def __repr__(self):
        repr_str = f'{self.__class__.__name__}()'
        return repr_str


@PIPELINES.register_module()
class PoseNormalize:
    """Normalize the range of keypoint values.

    Args:
        mean (list | tuple): The mean value of the keypoint values.
        min_value (list | tuple): The minimum value of the keypoint values.
        max_value (list | tuple): The maximum value of the keypoint values.
    """

    def __init__(self,
                 mean=(960., 540., 0.5),
                 min_value=(0., 0., 0.),
                 max_value=(1920, 1080, 1.)):
        self.mean = np.array(mean, dtype=np.float32).reshape(-1, 1, 1, 1)
        self.min_value = np.array(
            min_value, dtype=np.float32).reshape(-1, 1, 1, 1)
        self.max_value = np.array(
            max_value, dtype=np.float32).reshape(-1, 1, 1, 1)

    def __call__(self, results):
        keypoint = results['keypoint']
        keypoint = (keypoint - self.mean) / (self.max_value - self.min_value)
        results['keypoint'] = keypoint
        results['keypoint_norm_cfg'] = dict(
            mean=self.mean, min_value=self.min_value, max_value=self.max_value)
        return results
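A numeric check of the transform for the x channel with the default values above: coordinates are centered on the 1920x1080 frame and scaled by its extent, so full-HD x values land in [-0.5, 0.5].

mean, min_v, max_v = 960., 0., 1920.
for x in (0., 960., 1920.):
    print((x - mean) / (max_v - min_v))  # -0.5, 0.0, 0.5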


@PIPELINES.register_module()
class FormatGCNInput:
    """Format final skeleton shape to the given input_format.

    Required keys are "keypoint" and "keypoint_score" (optional),
    added or modified keys are "keypoint" and "input_shape".

    Args:
        input_format (str): Define the final skeleton format.
    """

    def __init__(self, input_format, num_person=2):
        self.input_format = input_format
        if self.input_format not in ['NCTVM']:
            raise ValueError(
                f'The input format {self.input_format} is invalid.')
        self.num_person = num_person

    def __call__(self, results):
        """Performs the FormatShape formatting.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.
        """
        keypoint = results['keypoint']

        if 'keypoint_score' in results:
            keypoint_confidence = results['keypoint_score']
            keypoint_confidence = np.expand_dims(keypoint_confidence, -1)
            keypoint_3d = np.concatenate((keypoint, keypoint_confidence),
                                         axis=-1)
        else:
            keypoint_3d = keypoint

        keypoint_3d = np.transpose(keypoint_3d,
                                   (3, 1, 2, 0))  # M T V C -> C T V M

        if keypoint_3d.shape[-1] < self.num_person:
            pad_dim = self.num_person - keypoint_3d.shape[-1]
            pad = np.zeros(
                keypoint_3d.shape[:-1] + (pad_dim, ), dtype=keypoint_3d.dtype)
            keypoint_3d = np.concatenate((keypoint_3d, pad), axis=-1)
        elif keypoint_3d.shape[-1] > self.num_person:
            keypoint_3d = keypoint_3d[:, :, :, :self.num_person]

        results['keypoint'] = keypoint_3d
        results['input_shape'] = keypoint_3d.shape
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(input_format='{self.input_format}')"
        return repr_str
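Putting the pieces together, a standalone sketch of what FormatGCNInput does to a single-person sample (hypothetical shapes matching the demo script: 40 frames, 17 COCO joints):

import numpy as np

kp = np.random.rand(1, 40, 17, 2).astype(np.float32)    # M, T, V, C
score = np.random.rand(1, 40, 17).astype(np.float32)

kp3 = np.concatenate((kp, score[..., None]), axis=-1)   # M, T, V, 3
kp3 = np.transpose(kp3, (3, 1, 2, 0))                   # C, T, V, M
pad = np.zeros(kp3.shape[:-1] + (1,), dtype=kp3.dtype)  # pad to 2 persons
kp3 = np.concatenate((kp3, pad), axis=-1)
print(kp3.shape)  # (3, 40, 17, 2) = C, T, V, M, ready for NCTVM batching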
@@ -3,6 +3,7 @@ import logging
 import traceback
 
 import numpy as np
+import torch
 
 from easycv.datasets.registry import DATASETS
 from easycv.datasets.shared.base import BaseDataset
@@ -42,7 +43,15 @@ class VideoDataset(BaseDataset):
         """
         assert len(evaluators) == 1, \
             'classification evaluation only support one evaluator'
-        gt_labels = results.pop('label')
+        if results.get('label', None) is not None:
+            gt_labels = results.pop('label')
+        else:
+            gt_labels = []
+            for i in range(len(self.data_source)):
+                label = self.data_source.video_infos[i]['label']
+                gt_labels.append(label)
+            gt_labels = torch.Tensor(gt_labels)
 
         eval_res = evaluators[0].evaluate(results, gt_labels)
 
         return eval_res
@@ -2,3 +2,4 @@
 from .ClipBertTwoStream import ClipBertTwoStream
 from .heads import I3DHead
 from .recognizer3d import Recognizer3D
+from .skeleton_gcn.skeleton_gcn import SkeletonGCN
@@ -58,6 +58,7 @@ class BaseHead(nn.Module, metaclass=ABCMeta):
             recognition task. Default: False.
         label_smooth_eps (float): Epsilon used in label smooth.
             Reference: arxiv.org/abs/1906.02629. Default: 0.
+        topk (int | tuple): Top-k accuracy. Default: (1, 5).
     """
 
     def __init__(self,
@@ -65,13 +66,20 @@ class BaseHead(nn.Module, metaclass=ABCMeta):
                  in_channels,
                  loss_cls=dict(type='CrossEntropyLoss', loss_weight=1.0),
                  multi_class=False,
-                 label_smooth_eps=0.0):
+                 label_smooth_eps=0.0,
+                 topk=(1, 5)):
         super().__init__()
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.loss_cls = build_loss(loss_cls)
         self.multi_class = multi_class
         self.label_smooth_eps = label_smooth_eps
+        assert isinstance(topk, (int, tuple))
+        if isinstance(topk, int):
+            topk = (topk, )
+        for _topk in topk:
+            assert _topk > 0, 'Top-k should be larger than 0'
+        self.topk = topk
 
     @abstractmethod
     def init_weights(self):
@@ -89,7 +97,7 @@ class BaseHead(nn.Module, metaclass=ABCMeta):
             labels (torch.Tensor): The target output of the model.
         Returns:
             dict: A dict containing field 'loss_cls'(mandatory)
-                and 'top1_acc', 'top5_acc'(optional).
+                and 'topk_acc'(optional).
         """
         losses = dict()
         if labels.shape == torch.Size([]):
@@ -103,11 +111,11 @@ class BaseHead(nn.Module, metaclass=ABCMeta):
 
         if not self.multi_class and cls_score.size() != labels.size():
             top_k_acc = top_k_accuracy(cls_score.detach().cpu().numpy(),
-                                       labels.detach().cpu().numpy(), (1, 5))
-            losses['top1_acc'] = torch.tensor(
-                top_k_acc[0], device=cls_score.device)
-            losses['top5_acc'] = torch.tensor(
-                top_k_acc[1], device=cls_score.device)
+                                       labels.detach().cpu().numpy(),
+                                       self.topk)
+            for k, a in zip(self.topk, top_k_acc):
+                losses[f'top{k}_acc'] = torch.tensor(
+                    a, device=cls_score.device)
 
         elif self.multi_class and self.label_smooth_eps != 0:
             labels = ((1 - self.label_smooth_eps) * labels +
@@ -0,0 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .base import BaseGCN
from .skeleton_gcn import SkeletonGCN
from .stgcn_backbone import STGCN
from .stgcn_head import STGCNHead

__all__ = ['BaseGCN', 'SkeletonGCN', 'STGCN', 'STGCNHead']
@@ -0,0 +1,115 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
# Refer to: https://github.com/open-mmlab/mmaction2/blob/master/mmaction/models/skeleton_gcn/base.py
from easycv.models import builder
from easycv.models.base import BaseModel


class BaseGCN(BaseModel):
    """Base class for GCN-based action recognition.

    All GCN-based recognizers should subclass it.
    All subclasses should overwrite:

    - Methods: ``forward_train``, supporting forward during training.
    - Methods: ``forward_test``, supporting forward during testing.

    Args:
        backbone (dict): Backbone modules to extract feature.
        cls_head (dict | None): Classification head to process feature.
            Default: None.
        train_cfg (dict | None): Config for training. Default: None.
        test_cfg (dict | None): Config for testing. Default: None.
    """

    def __init__(self, backbone, cls_head=None, train_cfg=None, test_cfg=None):
        super().__init__()
        self.backbone = builder.build_backbone(backbone)
        self.cls_head = builder.build_head(cls_head) if cls_head else None

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.init_weights()

    @property
    def with_cls_head(self):
        """bool: whether the recognizer has a cls_head"""
        return hasattr(self, 'cls_head') and self.cls_head is not None

    def forward(self, keypoint, label=None, mode='train', **kwargs):
        """Define the computation performed at every call."""
        if mode == 'train':
            if label is None:
                raise ValueError('Label should not be None.')
            return self.forward_train(keypoint, label, **kwargs)

        return self.forward_test(keypoint, **kwargs)

    def extract_feat(self, skeletons):
        """Extract features through a backbone.

        Args:
            skeletons (torch.Tensor): The input skeletons.

        Returns:
            torch.Tensor: The extracted features.
        """
        x = self.backbone(skeletons)
        return x

    def train_step(self, data_batch, optimizer, **kwargs):
        """The iteration step during training.

        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an optimizer
        hook. Note that in some complicated cases or models, the whole process
        including back propagation and optimizer updating is also defined in
        this method, such as GAN.

        Args:
            data_batch (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
                ``num_samples``.
                ``loss`` is a tensor for back propagation, which can be a
                weighted sum of multiple losses.
                ``log_vars`` contains all the variables to be sent to the
                logger.
                ``num_samples`` indicates the batch size (when the model is
                DDP, it means the batch size on each GPU), which is used for
                averaging the logs.
        """
        skeletons = data_batch['keypoint']
        label = data_batch['label']
        label = label.squeeze(-1)

        losses = self(skeletons, label, return_loss=True)

        loss, log_vars = self._parse_losses(losses)
        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=len(skeletons.data))

        return outputs

    def val_step(self, data_batch, optimizer, **kwargs):
        """The iteration step during validation.

        This method shares the same signature as :func:`train_step`, but is
        used during val epochs. Note that the evaluation after training epochs
        is not implemented with this method, but with an evaluation hook.
        """
        skeletons = data_batch['keypoint']
        label = data_batch['label']

        losses = self(skeletons, label, return_loss=True)

        loss, log_vars = self._parse_losses(losses)
        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=len(skeletons.data))

        return outputs
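For orientation, a minimal hypothetical data_batch containing just the two keys train_step reads (shapes follow the NCTVM layout produced by the pipeline above):

import torch

# Batch of 8 clips, 3 channels, 300 frames, 17 joints, 2 persons.
data_batch = dict(
    keypoint=torch.randn(8, 3, 300, 17, 2),
    label=torch.randint(0, 60, (8, 1)),  # squeezed to (8,) inside train_step
)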
@@ -0,0 +1,33 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
# Refer to: https://github.com/open-mmlab/mmaction2/blob/master/mmaction/models/skeleton_gcn/skeletongcn.py
from easycv.models.builder import MODELS
from .base import BaseGCN


@MODELS.register_module()
class SkeletonGCN(BaseGCN):
    """Spatial temporal graph convolutional networks."""

    def forward_train(self, skeletons, labels, **kwargs):
        """Defines the computation performed at every call when training."""
        assert self.with_cls_head
        losses = dict()

        x = self.extract_feat(skeletons)
        output = self.cls_head(x)
        gt_labels = labels.squeeze(-1)
        loss = self.cls_head.loss(output, gt_labels)
        losses.update(loss)

        return losses

    def forward_test(self, skeletons, **kwargs):
        """Defines the computation performed at every call during evaluation
        and testing."""
        x = self.extract_feat(skeletons)
        assert self.with_cls_head
        output = self.cls_head(x)

        result = {'prob': output.cpu()}
        return result
@@ -0,0 +1,283 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
# Refer to: https://github.com/open-mmlab/mmaction2/blob/master/mmaction/models/backbones/stgcn.py
import torch
import torch.nn as nn
from mmcv.cnn import constant_init, kaiming_init, normal_init
from mmcv.utils import _BatchNorm

from easycv.models import BACKBONES
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from .utils import Graph


def zero(x):
    """return zero."""
    return 0


def identity(x):
    """return input itself."""
    return x


class STGCNBlock(nn.Module):
    """Applies a spatial temporal graph convolution over an input graph
    sequence.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (tuple): Size of the temporal convolving kernel and
            graph convolving kernel
        stride (int, optional): Stride of the temporal convolution. Default: 1
        dropout (int, optional): Dropout rate of the final output. Default: 0
        residual (bool, optional): If ``True``, applies a residual mechanism.
            Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)`
          format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in
          :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in
          :math:`(K, V, V)` format

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dropout=0,
                 residual=True):
        super().__init__()

        assert len(kernel_size) == 2
        assert kernel_size[0] % 2 == 1
        padding = ((kernel_size[0] - 1) // 2, 0)

        self.gcn = ConvTemporalGraphical(in_channels, out_channels,
                                         kernel_size[1])
        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, (kernel_size[0], 1),
                      (stride, 1), padding), nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True))

        if not residual:
            self.residual = zero

        elif (in_channels == out_channels) and (stride == 1):
            self.residual = identity

        else:
            self.residual = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=(stride, 1)), nn.BatchNorm2d(out_channels))

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, adj_mat):
        """Defines the computation performed at every call."""
        res = self.residual(x)
        x, adj_mat = self.gcn(x, adj_mat)
        x = self.tcn(x) + res

        return self.relu(x), adj_mat


class ConvTemporalGraphical(nn.Module):
    """The basic module for applying a graph convolution.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int): Size of the graph convolving kernel
        t_kernel_size (int): Size of the temporal convolving kernel
        t_stride (int, optional): Stride of the temporal convolution.
            Default: 1
        t_padding (int, optional): Temporal zero-padding added to both sides
            of the input. Default: 0
        t_dilation (int, optional): Spacing between temporal kernel elements.
            Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)`
          format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in
          :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in
          :math:`(K, V, V)` format

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 t_kernel_size=1,
                 t_stride=1,
                 t_padding=0,
                 t_dilation=1,
                 bias=True):
        super().__init__()

        self.kernel_size = kernel_size
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * kernel_size,
            kernel_size=(t_kernel_size, 1),
            padding=(t_padding, 0),
            stride=(t_stride, 1),
            dilation=(t_dilation, 1),
            bias=bias)

    def forward(self, x, adj_mat):
        """Defines the computation performed at every call."""
        assert adj_mat.size(0) == self.kernel_size

        x = self.conv(x)

        n, kc, t, v = x.size()
        x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)
        x = torch.einsum('nkctv,kvw->nctw', (x, adj_mat))

        return x.contiguous(), adj_mat


@BACKBONES.register_module()
class STGCN(nn.Module):
    """Backbone of spatial temporal graph convolutional networks.

    Args:
        in_channels (int): Number of channels in the input data.
        graph_cfg (dict): The arguments for building the graph.
        edge_importance_weighting (bool): If ``True``, adds a learnable
            importance weighting to the edges of the graph. Default: True.
        data_bn (bool): If 'True', adds data normalization to the inputs.
            Default: True.
        pretrained (str | None): Name of pretrained model.
        **kwargs (optional): Other parameters for graph convolution units.

    Shape:
        - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})`
        - Output: :math:`(N, num_class)` where
            :math:`N` is a batch size,
            :math:`T_{in}` is a length of input sequence,
            :math:`V_{in}` is the number of graph nodes,
            :math:`M_{in}` is the number of instances in a frame.
    """

    def __init__(self,
                 in_channels,
                 graph_cfg,
                 edge_importance_weighting=True,
                 data_bn=True,
                 pretrained=None,
                 **kwargs):
        super().__init__()

        # load graph
        self.graph = Graph(**graph_cfg)
        A = torch.tensor(
            self.graph.A, dtype=torch.float32, requires_grad=False)
        self.register_buffer('A', A)

        # build networks
        spatial_kernel_size = A.size(0)
        temporal_kernel_size = 9
        kernel_size = (temporal_kernel_size, spatial_kernel_size)
        self.data_bn = nn.BatchNorm1d(in_channels *
                                      A.size(1)) if data_bn else identity

        kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
        self.st_gcn_networks = nn.ModuleList((
            STGCNBlock(
                in_channels, 64, kernel_size, 1, residual=False, **kwargs0),
            STGCNBlock(64, 64, kernel_size, 1, **kwargs),
            STGCNBlock(64, 64, kernel_size, 1, **kwargs),
            STGCNBlock(64, 64, kernel_size, 1, **kwargs),
            STGCNBlock(64, 128, kernel_size, 2, **kwargs),
            STGCNBlock(128, 128, kernel_size, 1, **kwargs),
            STGCNBlock(128, 128, kernel_size, 1, **kwargs),
            STGCNBlock(128, 256, kernel_size, 2, **kwargs),
            STGCNBlock(256, 256, kernel_size, 1, **kwargs),
            STGCNBlock(256, 256, kernel_size, 1, **kwargs),
        ))

        # initialize parameters for edge importance weighting
        if edge_importance_weighting:
            self.edge_importance = nn.ParameterList([
                nn.Parameter(torch.ones(self.A.size()))
                for _ in self.st_gcn_networks
            ])
        else:
            self.edge_importance = [1 for _ in self.st_gcn_networks]

        self.pretrained = pretrained

    def init_weights(self):
        """Initiate the parameters either from existing checkpoint or from
        scratch."""
        if isinstance(self.pretrained, str):
            logger = get_root_logger()
            logger.info(f'load model from: {self.pretrained}')

            load_checkpoint(self, self.pretrained, strict=False, logger=logger)

        elif self.pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.Linear):
                    normal_init(m)
                elif isinstance(m, _BatchNorm):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Defines the computation performed at every call.
        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        # data normalization
        x = x.float()
        n, c, t, v, m = x.size()  # bs 3 300 25(17) 2
        x = x.permute(0, 4, 3, 1, 2).contiguous()  # N M V C T
        x = x.view(n * m, v * c, t)
        x = self.data_bn(x)
        x = x.view(n, m, v, c, t)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        x = x.view(n * m, c, t, v)  # bsx2 3 300 25(17)

        # forward
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x, _ = gcn(x, self.A * importance)

        return x
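A shape walkthrough of the backbone (a sketch; the import path assumes the package layout this commit creates). With layout='coco' and strategy='spatial', the two stride-2 blocks quarter the temporal axis:

import torch
from easycv.models.video_recognition.skeleton_gcn import STGCN

net = STGCN(in_channels=3, graph_cfg=dict(layout='coco', strategy='spatial'))
net.init_weights()
x = torch.randn(2, 3, 300, 17, 2)  # N, C, T, V, M
feat = net(x)
print(feat.shape)  # torch.Size([4, 256, 75, 17]) = (N*M, 256, T/4, V)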
@@ -0,0 +1,67 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
# Refer to: https://github.com/open-mmlab/mmaction2/blob/master/mmaction/models/heads/stgcn_head.py
import torch.nn as nn
from mmcv.cnn import normal_init

from easycv.models import HEADS
from easycv.models.video_recognition.heads.base_head import BaseHead


@HEADS.register_module()
class STGCNHead(BaseHead):
    """The classification head for STGCN.

    Args:
        num_classes (int): Number of classes to be classified.
        in_channels (int): Number of channels in input feature.
        loss_cls (dict): Config for building loss.
            Default: dict(type='CrossEntropyLoss')
        spatial_type (str): Pooling type in spatial dimension. Default: 'avg'.
        num_person (int): Number of persons. Default: 2.
        init_std (float): Std value for initialization. Default: 0.01.
        kwargs (dict, optional): Any keyword argument to be used to initialize
            the head.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 loss_cls=dict(type='CrossEntropyLoss'),
                 spatial_type='avg',
                 num_person=2,
                 init_std=0.01,
                 **kwargs):
        super().__init__(num_classes, in_channels, loss_cls, **kwargs)

        self.spatial_type = spatial_type
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_person = num_person
        self.init_std = init_std

        self.pool = None
        if self.spatial_type == 'avg':
            self.pool = nn.AdaptiveAvgPool2d((1, 1))
        elif self.spatial_type == 'max':
            self.pool = nn.AdaptiveMaxPool2d((1, 1))
        else:
            raise NotImplementedError

        self.fc = nn.Conv2d(self.in_channels, self.num_classes, kernel_size=1)

    def init_weights(self):
        normal_init(self.fc, std=self.init_std)

    def forward(self, x):
        # global pooling
        assert self.pool is not None
        x = self.pool(x)
        x = x.view(x.shape[0] // self.num_person, self.num_person, -1, 1,
                   1).mean(dim=1)

        # prediction
        x = self.fc(x)
        x = x.view(x.shape[0], -1)

        return x
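Continuing the sketch from the backbone above (same import-path assumption): the head pools spatially, folds the num_person streams back into the batch by averaging, and emits class logits.

import torch
from easycv.models.video_recognition.skeleton_gcn import STGCNHead

head = STGCNHead(num_classes=60, in_channels=256)
head.init_weights()
feat = torch.randn(4, 256, 75, 17)  # (N*M, C, T, V) from the backbone
logits = head(feat)
print(logits.shape)  # torch.Size([2, 60]) = (N, num_classes)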
@@ -0,0 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .graph import Graph

__all__ = ['Graph']
@@ -0,0 +1,198 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
# Refer to: https://github.com/open-mmlab/mmaction2/blob/master/mmaction/models/skeleton_gcn/utils/graph.py
import numpy as np


def get_hop_distance(num_node, edge, max_hop=1):
    adj_mat = np.zeros((num_node, num_node))
    for i, j in edge:
        adj_mat[i, j] = 1
        adj_mat[j, i] = 1

    # compute hop steps
    hop_dis = np.zeros((num_node, num_node)) + np.inf
    transfer_mat = [
        np.linalg.matrix_power(adj_mat, d) for d in range(max_hop + 1)
    ]
    arrive_mat = (np.stack(transfer_mat) > 0)
    for d in range(max_hop, -1, -1):
        hop_dis[arrive_mat[d]] = d
    return hop_dis


def normalize_digraph(adj_matrix):
    Dl = np.sum(adj_matrix, 0)
    num_nodes = adj_matrix.shape[0]
    Dn = np.zeros((num_nodes, num_nodes))
    for i in range(num_nodes):
        if Dl[i] > 0:
            Dn[i, i] = Dl[i]**(-1)
    norm_matrix = np.dot(adj_matrix, Dn)
    return norm_matrix


def edge2mat(link, num_node):
    A = np.zeros((num_node, num_node))
    for i, j in link:
        A[j, i] = 1
    return A


class Graph:
    """The Graph to model the skeletons extracted by pose estimators.

    Args:
        layout (str): must be one of the following candidates
            - openpose: 18 or 25 joints. For more information, please refer to:
                https://github.com/CMU-Perceptual-Computing-Lab/openpose#output
            - ntu-rgb+d: Consists of 25 joints. For more information, please
                refer to https://github.com/shahroudy/NTURGB-D
            - coco: 17 joints, following the COCO keypoint order.

        strategy (str): must be one of the following candidates
            - uniform: Uniform Labeling
            - distance: Distance Partitioning
            - spatial: Spatial Configuration
            For more information, please refer to the section 'Partition
            Strategies' in the ST-GCN paper (https://arxiv.org/abs/1801.07455).

        max_hop (int): the maximal distance between two connected nodes.
            Default: 1
        dilation (int): controls the spacing between the kernel points.
            Default: 1
    """

    def __init__(self,
                 layout='openpose-18',
                 strategy='uniform',
                 max_hop=1,
                 dilation=1):
        self.max_hop = max_hop
        self.dilation = dilation

        assert layout in [
            'openpose-18', 'openpose-25', 'ntu-rgb+d', 'ntu_edge', 'coco'
        ]
        assert strategy in ['uniform', 'distance', 'spatial', 'agcn']
        self.get_edge(layout)
        self.hop_dis = get_hop_distance(
            self.num_node, self.edge, max_hop=max_hop)
        self.get_adjacency(strategy)

    def __str__(self):
        return str(self.A)

    def get_edge(self, layout):
        """This method returns the edge pairs of the layout."""

        if layout == 'openpose-18':
            self.num_node = 18
            self_link = [(i, i) for i in range(self.num_node)]
            neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5),
                             (13, 12), (12, 11), (10, 9), (9, 8), (11, 5),
                             (8, 2), (5, 1), (2, 1), (0, 1), (15, 0), (14, 0),
                             (17, 15), (16, 14)]
            self.edge = self_link + neighbor_link
            self.center = 1
        elif layout == 'openpose-25':
            self.num_node = 25
            self_link = [(i, i) for i in range(self.num_node)]
            neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (23, 22),
                             (22, 11), (24, 11), (11, 10), (10, 9), (9, 8),
                             (20, 19), (19, 14), (21, 14), (14, 13), (13, 12),
                             (12, 8), (8, 1), (5, 1), (2, 1), (0, 1), (15, 0),
                             (16, 0), (17, 15), (18, 16)]
            self.self_link = self_link
            self.neighbor_link = neighbor_link
            self.edge = self_link + neighbor_link
            self.center = 1
        elif layout == 'ntu-rgb+d':
            self.num_node = 25
            self_link = [(i, i) for i in range(self.num_node)]
            neighbor_1base = [(1, 2), (2, 21), (3, 21),
                              (4, 3), (5, 21), (6, 5), (7, 6), (8, 7), (9, 21),
                              (10, 9), (11, 10), (12, 11), (13, 1), (14, 13),
                              (15, 14), (16, 15), (17, 1), (18, 17), (19, 18),
                              (20, 19), (22, 23), (23, 8), (24, 25), (25, 12)]
            neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
            self.self_link = self_link
            self.neighbor_link = neighbor_link
            self.edge = self_link + neighbor_link
            self.center = 21 - 1
        elif layout == 'ntu_edge':
            self.num_node = 24
            self_link = [(i, i) for i in range(self.num_node)]
            neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6),
                              (8, 7), (9, 2), (10, 9), (11, 10), (12, 11),
                              (13, 1), (14, 13), (15, 14), (16, 15), (17, 1),
                              (18, 17), (19, 18), (20, 19), (21, 22), (22, 8),
                              (23, 24), (24, 12)]
            neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
            self.edge = self_link + neighbor_link
            self.center = 2
        elif layout == 'coco':
            self.num_node = 17
            self_link = [(i, i) for i in range(self.num_node)]
            neighbor_1base = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13],
                              [6, 12], [7, 13], [6, 7], [8, 6], [9, 7],
                              [10, 8], [11, 9], [2, 3], [2, 1], [3, 1], [4, 2],
                              [5, 3], [4, 6], [5, 7]]
            neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
            self.edge = self_link + neighbor_link
            self.center = 0
        else:
            raise ValueError(f'{layout} is not supported.')

    def get_adjacency(self, strategy):
        """This method returns the adjacency matrix according to strategy."""

        valid_hop = range(0, self.max_hop + 1, self.dilation)
        adjacency = np.zeros((self.num_node, self.num_node))
        for hop in valid_hop:
            adjacency[self.hop_dis == hop] = 1
        normalize_adjacency = normalize_digraph(adjacency)

        if strategy == 'uniform':
            A = np.zeros((1, self.num_node, self.num_node))
            A[0] = normalize_adjacency
            self.A = A
        elif strategy == 'distance':
            A = np.zeros((len(valid_hop), self.num_node, self.num_node))
            for i, hop in enumerate(valid_hop):
                A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis ==
                                                                hop]
            self.A = A
        elif strategy == 'spatial':
            A = []
            for hop in valid_hop:
                a_root = np.zeros((self.num_node, self.num_node))
                a_close = np.zeros((self.num_node, self.num_node))
                a_further = np.zeros((self.num_node, self.num_node))
                for i in range(self.num_node):
                    for j in range(self.num_node):
                        if self.hop_dis[j, i] == hop:
                            if self.hop_dis[j, self.center] == self.hop_dis[
                                    i, self.center]:
                                a_root[j, i] = normalize_adjacency[j, i]
                            elif self.hop_dis[j, self.center] > self.hop_dis[
                                    i, self.center]:
                                a_close[j, i] = normalize_adjacency[j, i]
                            else:
                                a_further[j, i] = normalize_adjacency[j, i]
                if hop == 0:
                    A.append(a_root)
                else:
                    A.append(a_root + a_close)
                    A.append(a_further)
            A = np.stack(A)
            self.A = A
        elif strategy == 'agcn':
            A = []
            link_mat = edge2mat(self.self_link, self.num_node)
            In = normalize_digraph(edge2mat(self.neighbor_link, self.num_node))
            outward = [(j, i) for (i, j) in self.neighbor_link]
            Out = normalize_digraph(edge2mat(outward, self.num_node))
            A = np.stack((link_mat, In, Out))
            self.A = A
        else:
            raise ValueError(f'{strategy} strategy is not supported.')
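A small sanity check of the graph construction (a sketch; the import path assumes the utils package added in this commit). The coco/spatial combination used by the config at the top of the commit yields three 17x17 partition matrices, which is exactly the spatial kernel size the STGCN backbone reads via A.size(0):

from easycv.models.video_recognition.skeleton_gcn.utils import Graph

graph = Graph(layout='coco', strategy='spatial')
print(graph.A.shape)  # (3, 17, 17): root, centripetal and centrifugal subsets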
|
|
@ -5,6 +5,7 @@ import os
|
|||
import pickle
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.parallel import collate, scatter_kwargs
|
||||
|
@@ -419,9 +420,14 @@ class PredictorV2(object):
            inputs = [inputs]

        results_list = []

        prog_bar = mmcv.ProgressBar(len(inputs))
        for i in range(0, len(inputs), self.batch_size):
            batch_inputs = inputs[i:min(len(inputs), i + self.batch_size)]
            batch_outputs = self.input_processor(batch_inputs)
            if len(batch_outputs) < 1:
                results_list.append(batch_outputs)
                continue
            batch_outputs = self._to_device(batch_outputs)
            batch_outputs = self.model_forward(batch_outputs)
            results = self.output_processor(batch_outputs)

@@ -433,6 +439,8 @@ class PredictorV2(object):
            else:
                results_list.append(results)

            prog_bar.update(len(batch_inputs))

        # TODO: support append to file
        if self.save_results:
            self.dump(results_list, self.save_path)
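
# Sketch of the batching pattern introduced above (illustrative): inputs are
# consumed in slices of `self.batch_size`, the final slice may be shorter,
# and the progress bar advances by the slice length rather than once per call.
#
#     for i in range(0, len(inputs), batch_size):
#         batch = inputs[i:min(len(inputs), i + batch_size)]
#         ...  # preprocess -> to device -> forward -> postprocess
#         prog_bar.update(len(batch))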
@@ -221,8 +221,8 @@ class PoseTopDownInputProcessor(InputProcessor):

        output_person_info = []
        for person_result in person_results:
            box = person_result['bbox']  # x,y,w,h
            box = [box[0], box[1], box[2] - box[0], box[3] - box[1]]
            box = person_result['bbox']  # x,y,x,y
            box = [box[0], box[1], box[2] - box[0], box[3] - box[1]]  # x,y,w,h
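            # e.g. the (x1, y1, x2, y2) box [10, 20, 110, 220] becomes the
            # (x, y, w, h) box [10, 20, 100, 200]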
            center, scale = _box2cs(self.cfg.data_cfg['image_size'], box)
            data = {
                'image_id':

@@ -277,6 +277,9 @@ class PoseTopDownInputProcessor(InputProcessor):
        for res in self.process_single(inp):
            batch_outputs.append(res)

        if len(batch_outputs) < 1:
            return batch_outputs

        batch_outputs = self._collate_fn(batch_outputs)
        batch_outputs['img_metas']._data = [[
            img_meta[i] for img_meta in batch_outputs['img_metas']._data

@@ -384,7 +387,7 @@ class PoseTopDownPredictor(PredictorV2):
                    image,
                    keypoints,
                    radius=4,
                    thickness=1,
                    thickness=3,
                    kpt_score_thr=0.3,
                    bbox_color='green',
                    show=False,

@@ -196,3 +196,118 @@ class VideoClassificationPredictor(PredictorV2):

    def get_output_processor(self):
        return VideoClsOutputProcessor(self.label_map, self.topk)


class STGCNInputProcessor(InputProcessor):

    def _load_input(self, input):
        """Prepare the input sample.

        Args:
            input (dict): Input sample dict, e.g.
                {
                    'frame_dir': '',
                    'img_shape': (1080, 1920),
                    'original_shape': (1080, 1920),
                    'total_frames': 40,
                    'keypoint': (2, 40, 17, 2),  # shape = (num_person, num_frame, num_keypoints, 2)
                    'keypoint_score': (2, 40, 17),
                    'modality': 'Pose',
                    'start_index': 1
                }.
        """
        assert isinstance(input, dict)

        keypoint = input['keypoint']

        assert len(keypoint.shape) == 4
        assert keypoint.shape[-1] in [2, 3]

        if keypoint.shape[-1] == 3:
            if input.get('keypoint_score', None) is None:
                input['keypoint_score'] = keypoint[..., -1]

            keypoint = keypoint[..., :2]
            input['keypoint'] = keypoint

        return input
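
# Illustrative sketch of the `_load_input` contract (shapes assumed for the
# example): a 3-channel keypoint array is split into (x, y) coordinates and
# a confidence channel.
#
#     kp = np.random.random((2, 40, 17, 3))  # (person, frame, joint, x/y/score)
#     out = processor._load_input({'keypoint': kp})
#     out['keypoint'].shape        # (2, 40, 17, 2)
#     out['keypoint_score'].shape  # (2, 40, 17)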


@PREDICTORS.register_module()
class STGCNPredictor(PredictorV2):
    """STGCN predict pipeline.

    Args:
        model_path (str): Path of the model file.
        config_file (Optional[str]): Config file path used to initialize the model and processors. Defaults to None.
        ori_image_size (Optional[list|tuple]): Original image or video frame size (width, height).
        batch_size (int): Batch size for forward.
        label_map (Optional[list|tuple]): List or file of labels.
        topk (int): Return the top-k classification results.
        device (str | torch.device): Support str('cuda' or 'cpu') or torch.device; if None, the device is detected automatically.
        save_results (bool): Whether to save predict results.
        save_path (str): File path for saving results, only valid when `save_results` is True.
        pipelines (list[dict]): Data pipeline configs.
        input_processor_threads (int): Number of processes to process inputs.
        mode (str): The image mode fed into the model.
    """

    def __init__(self,
                 model_path,
                 config_file=None,
                 ori_image_size=None,
                 batch_size=1,
                 label_map=None,
                 topk=1,
                 device=None,
                 save_results=False,
                 save_path=None,
                 pipelines=None,
                 input_processor_threads=8,
                 mode='RGB',
                 *args,
                 **kwargs):
        super(STGCNPredictor, self).__init__(
            model_path,
            config_file=config_file,
            batch_size=batch_size,
            device=device,
            save_results=save_results,
            save_path=save_path,
            pipelines=pipelines,
            input_processor_threads=input_processor_threads,
            mode=mode,
            *args,
            **kwargs)

        if ori_image_size is not None:
            w, h = ori_image_size
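            # assumption: PoseNormalize subtracts `mean` and divides by
            # `max_value`, so this centers keypoints on the frame midpoint
            # and scales coordinates roughly into [-0.5, 0.5]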
            for pipeline in self.cfg.test_pipeline:
                if pipeline['type'] == 'PoseNormalize':
                    pipeline['mean'] = (w // 2, h // 2, .5)
                    pipeline['max_value'] = (w, h, 1.)

        self.topk = topk
        if label_map is None:
            if 'CLASSES' in self.cfg:
                class_list = self.cfg.get('CLASSES', [])
            elif 'num_classes' in self.cfg:
                class_list = list(range(self.cfg.num_classes))
                class_list = [str(i) for i in class_list]
            else:
                class_list = []
        elif isinstance(label_map, str):
            with io.open(label_map, 'r') as f:
                class_list = f.readlines()
        elif isinstance(label_map, (tuple, list)):
            class_list = label_map

        self.label_map = [i.strip() for i in class_list]

    def get_input_processor(self):
        return STGCNInputProcessor(
            self.cfg,
            pipelines=self.pipelines,
            batch_size=self.batch_size,
            threads=self.input_processor_threads,
            mode=self.mode)

    def get_output_processor(self):
        return VideoClsOutputProcessor(self.label_map, self.topk)
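
# Usage sketch (illustrative; the checkpoint path is an assumption, the
# config path matches the one used by the unit test below):
#
#     predictor = STGCNPredictor(
#         model_path='work_dir/stgcn_80e_ntu60_xsub.pth',
#         config_file='configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py',
#         ori_image_size=(1920, 1080))
#     results = predictor([sample])  # `sample` as documented in
#                                    # STGCNInputProcessor._load_input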


@@ -0,0 +1,54 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import torch

from easycv.models.builder import build_model


class STGCNTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def _get_model(self):
        model_cfg = dict(
            type='SkeletonGCN',
            backbone=dict(
                type='STGCN',
                in_channels=3,
                edge_importance_weighting=True,
                graph_cfg=dict(layout='coco', strategy='spatial')),
            cls_head=dict(
                type='STGCNHead',
                num_classes=60,
                in_channels=256,
                loss_cls=dict(type='CrossEntropyLoss')),
            train_cfg=None,
            test_cfg=None)
        model = build_model(model_cfg)
        return model

    def test_train(self):
        model = self._get_model()
        model.train()
        batch_size = 2
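        # keypoint tensor follows the 'NCTVM' layout produced by
        # FormatGCNInput: (batch, channels, frames, joints, persons)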
        keypoints = torch.randn([batch_size, 3, 300, 17, 2])
        label = torch.randint(0, 60, (batch_size, ))
        output = model(keypoint=keypoints, label=label)
        self.assertIn('loss_cls', output)
        self.assertIn('top1_acc', output)
        self.assertIn('top5_acc', output)

    def test_infer(self):
        model = self._get_model()
        model.eval()

        with torch.no_grad():
            keypoints = torch.randn([1, 3, 300, 17, 2])
            output = model(keypoint=keypoints, mode='test')
            self.assertEqual(output['prob'].shape, (1, 60))


if __name__ == '__main__':
    unittest.main()

@@ -2,16 +2,13 @@
"""
isort:skip_file
"""
import json
import os
import unittest
import numpy as np

import cv2
import torch
from easycv.predictors.video_classifier import VideoClassificationPredictor
from easycv.utils.test_util import clean_up, get_tmp_dir
from easycv.predictors.video_classifier import VideoClassificationPredictor, STGCNPredictor
from tests.ut_config import (PRETRAINED_MODEL_X3D_XS,
                             VIDEO_DATA_SMALL_RAW_LOCAL)
                             VIDEO_DATA_SMALL_RAW_LOCAL, BASE_LOCAL_PATH)


class VideoClassificationPredictorTest(unittest.TestCase):

@@ -54,5 +51,39 @@ class VideoClassificationPredictorTest(unittest.TestCase):
        self.assertEqual(len(res['class_probs']), 400)


class STGCNPredictorTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def test_single(self):
        checkpoint = os.path.join(
            BASE_LOCAL_PATH,
            'pretrained_models/video/stgcn/stgcn_80e_ntu60_xsub.pth')
        config_file = 'configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py'
        predict_op = STGCNPredictor(
            model_path=checkpoint, config_file=config_file)

        h, w = 480, 853
        total_frames = 20
        num_person = 2
        inp = dict(
            frame_dir='',
            label=-1,
            img_shape=(h, w),
            original_shape=(h, w),
            start_index=0,
            modality='Pose',
            total_frames=total_frames,
            keypoint=np.random.random((num_person, total_frames, 17, 2)),
            keypoint_score=np.random.random((num_person, total_frames, 17)),
        )

        results = predict_op([inp])[0]
        self.assertIn('class', results)
        self.assertIn('class_name', results)
        self.assertEqual(len(results['class_probs']), 60)


if __name__ == '__main__':
    unittest.main()