Support stgcn (#293)

* add stgcn
Cathy0908 2023-03-02 19:13:10 +08:00 committed by GitHub
parent c73edeee1c
commit 4cf6f794e4
22 changed files with 1676 additions and 23 deletions

View File

@ -0,0 +1,120 @@
_base_ = 'configs/base.py'
CLASSES = [
'drink water', 'eat meal/snack', 'brushing teeth', 'brushing hair', 'drop',
'pickup', 'throw', 'sitting down', 'standing up (from sitting position)',
'clapping', 'reading', 'writing', 'tear up paper', 'wear jacket',
'take off jacket', 'wear a shoe', 'take off a shoe', 'wear on glasses',
'take off glasses', 'put on a hat/cap', 'take off a hat/cap', 'cheer up',
'hand waving', 'kicking something', 'reach into pocket',
'hopping (one foot jumping)', 'jump up', 'make a phone call/answer phone',
'playing with phone/tablet', 'typing on a keyboard',
'pointing to something with finger', 'taking a selfie',
'check time (from watch)', 'rub two hands together', 'nod head/bow',
'shake head', 'wipe face', 'salute', 'put the palms together',
'cross hands in front (say stop)', 'sneeze/cough', 'staggering', 'falling',
'touch head (headache)', 'touch chest (stomachache/heart pain)',
'touch back (backache)', 'touch neck (neckache)',
'nausea or vomiting condition',
'use a fan (with hand or paper)/feeling warm',
'punching/slapping other person', 'kicking other person',
'pushing other person', 'pat on back of other person',
'point finger at the other person', 'hugging other person',
'giving something to other person', "touch other person's pocket",
'handshaking', 'walking towards each other',
'walking apart from each other'
]
model = dict(
type='SkeletonGCN',
backbone=dict(
type='STGCN',
in_channels=3,
edge_importance_weighting=True,
graph_cfg=dict(layout='coco', strategy='spatial')),
cls_head=dict(
type='STGCNHead',
num_classes=60,
in_channels=256,
loss_cls=dict(type='CrossEntropyLoss')),
train_cfg=None,
test_cfg=None)
dataset_type = 'VideoDataset'
ann_file_train = 'data/posec3d/ntu60_xsub_train.pkl'
ann_file_val = 'data/posec3d/ntu60_xsub_val.pkl'
train_pipeline = [
dict(type='PaddingWithLoop', clip_len=300),
dict(type='PoseDecode'),
dict(type='FormatGCNInput', input_format='NCTVM'),
dict(type='PoseNormalize'),
dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
dict(type='VideoToTensor', keys=['keypoint'])
]
val_pipeline = [
dict(type='PaddingWithLoop', clip_len=300),
dict(type='PoseDecode'),
dict(type='FormatGCNInput', input_format='NCTVM'),
dict(type='PoseNormalize'),
dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
dict(type='VideoToTensor', keys=['keypoint'])
]
test_pipeline = [
dict(type='PaddingWithLoop', clip_len=300),
dict(type='PoseDecode'),
dict(type='FormatGCNInput', input_format='NCTVM'),
dict(type='PoseNormalize'),
dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
dict(type='VideoToTensor', keys=['keypoint'])
]
data = dict(
imgs_per_gpu=16,
workers_per_gpu=2,
train=dict(
type=dataset_type,
data_source=dict(
type='PoseDataSourceForVideoRec',
ann_file=ann_file_train,
data_prefix='',
),
pipeline=train_pipeline),
val=dict(
type=dataset_type,
imgs_per_gpu=1,
data_source=dict(
type='PoseDataSourceForVideoRec',
ann_file=ann_file_val,
data_prefix='',
),
pipeline=val_pipeline),
test=dict(
type=dataset_type,
data_source=dict(
type='PoseDataSourceForVideoRec',
ann_file=ann_file_val,
data_prefix='',
),
pipeline=test_pipeline))
# optimizer
optimizer = dict(
type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001, nesterov=True)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[10, 50])
total_epochs = 80
# eval
eval_config = dict(initial=False, interval=1, gpu_collect=True)
eval_pipelines = [
dict(
mode='test',
data=data['val'],
dist_eval=True,
evaluators=[dict(type='ClsEvaluator', topk=(1, 5))],
)
]
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
checkpoint_config = dict(interval=1)
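# Note (not part of this diff): with this config, skeletons are padded/looped to 300
# frames, formatted as NCTVM tensors over the 17-joint COCO layout, and classified
# into the 60 NTU RGB+D cross-subject classes listed in CLASSES above.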

View File

@ -0,0 +1,217 @@
import argparse
import os
import os.path as osp
import shutil
import cv2
import mmcv
import numpy as np
import torch
from easycv.file.utils import is_url_path
from easycv.predictors.pose_predictor import PoseTopDownPredictor
from easycv.predictors.video_classifier import STGCNPredictor
try:
import moviepy.editor as mpy
except ImportError:
raise ImportError('Please install moviepy to enable writing the output video file')
FONTFACE = cv2.FONT_HERSHEY_DUPLEX
FONTSCALE = 0.75
FONTCOLOR = (255, 255, 255) # BGR, white
THICKNESS = 1
LINETYPE = 1
TMP_DIR = './tmp'
def parse_args():
parser = argparse.ArgumentParser(
description='Skeleton-based video classification demo.')
parser.add_argument(
'--video',
default=
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/demos/videos/ntu_sample.avi',
help='video file/url')
parser.add_argument(
'--out_file',
default=f'{TMP_DIR}/demo_show.mp4',
help='output filename')
parser.add_argument(
'--config',
default=(
'configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py'
),
help='skeleton model config file path')
parser.add_argument(
'--checkpoint',
default=
('http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/video/skeleton_based/stgcn/stgcn_80e_ntu60_xsub.pth'
),
help='skeleton model checkpoint file/url')
parser.add_argument(
'--det-config',
default='configs/detection/yolox/yolox_s_8xb16_300e_coco.py',
help='human detection config file path')
parser.add_argument(
'--det-checkpoint',
default=
('http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_s_bs16_lr002/epoch_300.pt'
),
help='human detection checkpoint file/url')
parser.add_argument(
'--det-predictor-type',
default='YoloXPredictor',
help='detection predictor type')
parser.add_argument(
'--pose-config',
default='configs/pose/hrnet_w48_coco_256x192_udp.py',
help='human pose estimation config file path')
parser.add_argument(
'--pose-checkpoint',
default=
('http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/pose/top_down_hrnet/pose_hrnet_epoch_210_export.pt'
),
help='human pose estimation checkpoint file/url')
parser.add_argument(
'--bbox-thr',
type=float,
default=0.5,
help='the threshold of human detection score')
parser.add_argument(
'--device', type=str, default='cuda:0', help='CPU/CUDA device option')
parser.add_argument(
'--short-side',
type=int,
default=480,
help='specify the short-side length of the image')
args = parser.parse_args()
return args
def frame_extraction(video_path, short_side):
"""Extract frames given video_path.
Args:
video_path (str): Path or URL of the input video.
short_side (int): Target length of the shorter image side after resizing.
"""
if is_url_path(video_path):
from torch.hub import download_url_to_file
cache_video_path = os.path.join(TMP_DIR, os.path.basename(video_path))
print(
f'Download video file from remote to local path "{cache_video_path}"...'
)
download_url_to_file(video_path, cache_video_path)
video_path = cache_video_path
# Load the video, extract frames into ./tmp/video_name
target_dir = osp.join(TMP_DIR, osp.basename(osp.splitext(video_path)[0]))
os.makedirs(target_dir, exist_ok=True)
# Should be able to handle videos up to several hours
frame_tmpl = osp.join(target_dir, 'img_{:06d}.jpg')
vid = cv2.VideoCapture(video_path)
frames = []
frame_paths = []
flag, frame = vid.read()
cnt = 0
new_h, new_w = None, None
while flag:
if new_h is None:
h, w, _ = frame.shape
new_w, new_h = mmcv.rescale_size((w, h), (short_side, np.Inf))
frame = mmcv.imresize(frame, (new_w, new_h))
frames.append(frame)
frame_path = frame_tmpl.format(cnt + 1)
frame_paths.append(frame_path)
cv2.imwrite(frame_path, frame)
cnt += 1
flag, frame = vid.read()
return frame_paths, frames
def main():
args = parse_args()
if not osp.exists(TMP_DIR):
os.makedirs(TMP_DIR)
frame_paths, original_frames = frame_extraction(args.video,
args.short_side)
num_frame = len(frame_paths)
h, w, _ = original_frames[0].shape
# Get Human detection results
pose_predictor = PoseTopDownPredictor(
model_path=args.pose_checkpoint,
config_file=args.pose_config,
detection_predictor_config=dict(
type=args.det_predictor_type,
model_path=args.det_checkpoint,
config_file=args.det_config,
),
bbox_thr=args.bbox_thr,
cat_id=0, # person category id
)
video_cls_predictor = STGCNPredictor(
model_path=args.checkpoint,
config_file=args.config,
ori_image_size=(w, h),
label_map=None)
pose_results = pose_predictor(original_frames)
torch.cuda.empty_cache()
fake_anno = dict(
frame_dir='',
label=-1,
img_shape=(h, w),
original_shape=(h, w),
start_index=0,
modality='Pose',
total_frames=num_frame)
num_person = max([len(x) for x in pose_results])
num_keypoint = 17
keypoints = np.zeros((num_person, num_frame, num_keypoint, 2),
dtype=np.float16)
keypoints_score = np.zeros((num_person, num_frame, num_keypoint),
dtype=np.float16)
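# arrays follow the (num_person, num_frame, num_keypoint[, 2]) layout expected by STGCNInputProcessor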
for i, poses in enumerate(pose_results):
if len(poses) < 1:
continue
_keypoint = poses['keypoints'] # shape = (num_person, num_keypoint, 3)
for j, pose in enumerate(_keypoint):
keypoints[j, i] = pose[:, :2]
keypoints_score[j, i] = pose[:, 2]
fake_anno['keypoint'] = keypoints
fake_anno['keypoint_score'] = keypoints_score
results = video_cls_predictor([fake_anno])
action_label = results[0]['class_name'][0]
print(f'action label: {action_label}')
vis_frames = [
pose_predictor.show_result(original_frames[i], pose_results[i])
if len(pose_results[i]) > 0 else original_frames[i]
for i in range(num_frame)
]
for frame in vis_frames:
cv2.putText(frame, action_label, (10, 30), FONTFACE, FONTSCALE,
FONTCOLOR, THICKNESS, LINETYPE)
vid = mpy.ImageSequenceClip([x[:, :, ::-1] for x in vis_frames], fps=24)
vid.write_videofile(args.out_file, remove_temp=True)
print(f'Write video to {args.out_file} successfully!')
tmp_frame_dir = osp.dirname(frame_paths[0])
shutil.rmtree(tmp_frame_dir)
if __name__ == '__main__':
main()

View File

@ -50,11 +50,17 @@ class ClsEvaluator(Evaluator):
a dict, each key is metric_name, value is metric value
'''
eval_res = OrderedDict()
if isinstance(gt_labels, dict):
assert len(gt_labels) == 1
gt_labels = list(gt_labels.values())[0]
target = gt_labels.long()
# if self.neck_num is not None:
if self.neck_num is None:
predictions = {'neck': predictions['neck']}
if len(predictions) > 1:
predictions = {'neck': predictions['neck']}
else:
predictions = {
'neck_%d_0' % self.neck_num:

View File

@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .pose_datasource import PoseDataSourceForVideoRec
from .video_datasource import VideoDatasource
from .video_text_datasource import VideoTextDatasource

View File

@ -0,0 +1,176 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import os.path as osp
from abc import ABCMeta
from collections import defaultdict
import mmcv
import numpy as np
import torch
from torch.utils.data import Dataset
from easycv.datasets.registry import DATASOURCES
from easycv.utils.logger import get_root_logger
@DATASOURCES.register_module()
class PoseDataSourceForVideoRec(Dataset, metaclass=ABCMeta):
"""Pose data source for video recognition.
Args:
ann_file (str): Path to the annotation file.
data_prefix (str | None): Path to a directory where videos are held.
Default: None.
multi_class (bool): Determines whether the dataset is a multi-class
dataset. Default: False.
num_classes (int | None): Number of classes of the dataset, used in
multi-class datasets. Default: None.
start_index (int): Specify a start index for frames in consideration of
different filename format. However, when taking videos as input,
it should be set to 0, since frames loaded from videos count
from 0. Default: 1.
sample_by_class (bool): Sampling by class, should be set `True` when
performing inter-class data balancing. Only compatible with
`multi_class == False`. Only applies for training. Default: False.
power (float): Samples are drawn with probability proportional to
the power of their label frequency (freq ** power). `power == 1`
indicates uniformly sampling all data; `power == 0` indicates
uniformly sampling all classes. Default: 0.
dynamic_length (bool): If the dataset length is dynamic (used by
ClassSpecificDistributedSampler). Default: False.
"""
def __init__(
self,
ann_file,
data_prefix=None,
multi_class=False,
num_classes=None,
start_index=1,
sample_by_class=False,
power=0,
dynamic_length=False,
split=None,
valid_ratio=None,
box_thr=None,
class_prob=None,
):
super().__init__()
self.modality = 'Pose'
# split, applicable to ucf or hmdb
self.split = split
self.ann_file = ann_file
self.data_prefix = osp.realpath(
data_prefix) if data_prefix is not None and osp.isdir(
data_prefix) else data_prefix
self.multi_class = multi_class
self.num_classes = num_classes
self.start_index = start_index
self.sample_by_class = sample_by_class
self.power = power
self.dynamic_length = dynamic_length
assert not (self.multi_class and self.sample_by_class)
self.video_infos = self.load_annotations()
if self.sample_by_class:
self.video_infos_by_class = self.parse_by_class()
class_prob = []
for _, samples in self.video_infos_by_class.items():
class_prob.append(len(samples) / len(self.video_infos))
class_prob = [x**self.power for x in class_prob]
summ = sum(class_prob)
class_prob = [x / summ for x in class_prob]
self.class_prob = dict(zip(self.video_infos_by_class, class_prob))
# box_thr, which should be a string
self.box_thr = box_thr
if self.box_thr is not None:
assert box_thr in ['0.5', '0.6', '0.7', '0.8', '0.9']
# Thresholding Training Examples
self.valid_ratio = valid_ratio
if self.valid_ratio is not None:
assert isinstance(self.valid_ratio, float)
if self.box_thr is None:
self.video_infos = [
x for x in self.video_infos
if x['valid_frames'] / x['total_frames'] >= valid_ratio
]
else:
key = f'valid@{self.box_thr}'
self.video_infos = [
x for x in self.video_infos
if x[key] / x['total_frames'] >= valid_ratio
]
if self.box_thr != '0.5':
box_thr = float(self.box_thr)
for item in self.video_infos:
inds = [
i for i, score in enumerate(item['box_score'])
if score >= box_thr
]
item['anno_inds'] = np.array(inds)
if class_prob is not None:
self.class_prob = class_prob
logger = get_root_logger()
logger.info(f'{len(self)} videos remain after valid thresholding')
def load_annotations(self):
"""Load annotation file to get video information."""
assert self.ann_file.endswith('.pkl')
return self.load_pkl_annotations()
def load_pkl_annotations(self):
data = mmcv.load(self.ann_file)
if self.split:
split, data = data['split'], data['annotations']
identifier = 'filename' if 'filename' in data[0] else 'frame_dir'
data = [x for x in data if x[identifier] in split[self.split]]
for item in data:
# Sometimes we may need to load anno from the file
if 'filename' in item:
item['filename'] = osp.join(self.data_prefix, item['filename'])
if 'frame_dir' in item:
item['frame_dir'] = osp.join(self.data_prefix,
item['frame_dir'])
return data
def parse_by_class(self):
video_infos_by_class = defaultdict(list)
for item in self.video_infos:
label = item['label']
video_infos_by_class[label].append(item)
return video_infos_by_class
def prepare_frames(self, idx):
"""Prepare the frames for training given the index."""
results = copy.deepcopy(self.video_infos[idx])
results['modality'] = self.modality
results['start_index'] = self.start_index
# prepare tensor in getitem
# If HVU, type(results['label']) is dict
if self.multi_class and isinstance(results['label'], list):
onehot = torch.zeros(self.num_classes)
onehot[results['label']] = 1.
results['label'] = onehot
return results
def __len__(self):
"""Get the size of the dataset."""
return len(self.video_infos)
def __getitem__(self, idx):
return self.prepare_frames(idx)
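# Usage sketch (not part of this diff): reading one sample from this data source;
# the available keys depend on the posec3d .pkl annotation referenced in the config.
source = PoseDataSourceForVideoRec(
    ann_file='data/posec3d/ntu60_xsub_train.pkl', data_prefix='')
sample = source[0]  # dict with e.g. 'keypoint', 'keypoint_score', 'total_frames', 'label'
print(len(source), sample['modality'], sample['start_index'])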

View File

@ -1,7 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# isort:skip_file
# yapf:disable
from .loading import DecordInit, DecordDecode, SampleFrames
from .transform import VideoImgaug, VideoFuse, VideoRandomScale, VideoRandomCrop, VideoRandomResizedCrop, VideoMultiScaleCrop, VideoResize, VideoRandomRescale, VideoFlip, VideoNormalize, VideoColorJitter, VideoCenterCrop, VideoThreeCrop, VideoTenCrop, VideoMultiGroupCrop
from .loading import DecordDecode, DecordInit, SampleFrames
from .pose_transform import (FormatGCNInput, PaddingWithLoop, PoseDecode,
PoseNormalize)
from .text_transform import TextTokenizer
__all__ = [DecordInit, DecordDecode, SampleFrames, VideoImgaug, VideoFuse, VideoRandomScale, VideoRandomCrop, VideoRandomResizedCrop, VideoMultiScaleCrop, VideoResize, VideoRandomRescale, VideoFlip, VideoNormalize, VideoColorJitter, VideoCenterCrop, VideoThreeCrop, VideoTenCrop, VideoMultiGroupCrop, TextTokenizer]
from .transform import (VideoCenterCrop, VideoColorJitter, VideoFlip,
VideoFuse, VideoImgaug, VideoMultiGroupCrop,
VideoMultiScaleCrop, VideoNormalize, VideoRandomCrop,
VideoRandomRescale, VideoRandomResizedCrop,
VideoRandomScale, VideoResize, VideoTenCrop,
VideoThreeCrop)
__all__ = [
'DecordInit', 'DecordDecode', 'SampleFrames', 'VideoImgaug', 'VideoFuse',
'VideoRandomScale', 'VideoRandomCrop', 'VideoRandomResizedCrop',
'VideoMultiScaleCrop', 'VideoResize', 'VideoRandomRescale', 'VideoFlip',
'VideoNormalize', 'VideoColorJitter', 'VideoCenterCrop', 'VideoThreeCrop',
'VideoTenCrop', 'VideoMultiGroupCrop', 'TextTokenizer', 'PaddingWithLoop',
'PoseDecode', 'PoseNormalize', 'FormatGCNInput'
]

View File

@ -0,0 +1,184 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
# Refer to: https://github.com/open-mmlab/mmaction2/blob/master/mmaction/datasets/pipelines/pose_loading.py
import numpy as np
from easycv.datasets.registry import PIPELINES
@PIPELINES.register_module()
class PaddingWithLoop:
"""Sample frames from the video.
To sample an n-frame clip from the video, PaddingWithLoop samples
the frames starting from index zero, and loops over the frames if the
length of the video is less than the value of 'clip_len'.
Required keys are "total_frames", added or modified keys
are "frame_inds", "clip_len", "frame_interval" and "num_clips".
Args:
clip_len (int): Frames of each sampled output clip.
num_clips (int): Number of clips to be sampled. Default: 1.
"""
def __init__(self, clip_len, num_clips=1):
self.clip_len = clip_len
self.num_clips = num_clips
def __call__(self, results):
num_frames = results['total_frames']
start = 0
inds = np.arange(start, start + self.clip_len)
inds = np.mod(inds, num_frames)
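# e.g. clip_len=5 with total_frames=3 gives inds = [0, 1, 2, 0, 1], looping over the video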
results['frame_inds'] = inds.astype(np.int64)
results['clip_len'] = self.clip_len
results['frame_interval'] = None
results['num_clips'] = self.num_clips
return results
@PIPELINES.register_module()
class PoseDecode:
"""Load and decode pose with given indices.
Required keys are "keypoint", "frame_inds" (optional), "keypoint_score"
(optional), added or modified keys are "keypoint", "keypoint_score" (if
applicable).
"""
@staticmethod
def _load_kp(kp, frame_inds):
"""Load keypoints given frame indices.
Args:
kp (np.ndarray): The keypoint coordinates.
frame_inds (np.ndarray): The frame indices.
"""
return [x[frame_inds].astype(np.float32) for x in kp]
@staticmethod
def _load_kpscore(kpscore, frame_inds):
"""Load keypoint scores given frame indices.
Args:
kpscore (np.ndarray): The confidence scores of keypoints.
frame_inds (np.ndarray): The frame indices.
"""
return [x[frame_inds].astype(np.float32) for x in kpscore]
def __call__(self, results):
if 'frame_inds' not in results:
results['frame_inds'] = np.arange(results['total_frames'])
if results['frame_inds'].ndim != 1:
results['frame_inds'] = np.squeeze(results['frame_inds'])
offset = results.get('offset', 0)
frame_inds = results['frame_inds'] + offset
if 'keypoint_score' in results:
kpscore = results['keypoint_score']
results['keypoint_score'] = kpscore[:,
frame_inds].astype(np.float32)
if 'keypoint' in results:
results['keypoint'] = results['keypoint'][:, frame_inds].astype(
np.float32)
return results
def __repr__(self):
repr_str = f'{self.__class__.__name__}()'
return repr_str
@PIPELINES.register_module()
class PoseNormalize:
"""Normalize the range of keypoint values to [-1,1].
Args:
mean (list | tuple): The mean value of the keypoint values.
min_value (list | tuple): The minimum value of the keypoint values.
max_value (list | tuple): The maximum value of the keypoint values.
"""
def __init__(self,
mean=(960., 540., 0.5),
min_value=(0., 0., 0.),
max_value=(1920, 1080, 1.)):
self.mean = np.array(mean, dtype=np.float32).reshape(-1, 1, 1, 1)
self.min_value = np.array(
min_value, dtype=np.float32).reshape(-1, 1, 1, 1)
self.max_value = np.array(
max_value, dtype=np.float32).reshape(-1, 1, 1, 1)
def __call__(self, results):
keypoint = results['keypoint']
keypoint = (keypoint - self.mean) / (self.max_value - self.min_value)
results['keypoint'] = keypoint
results['keypoint_norm_cfg'] = dict(
mean=self.mean, min_value=self.min_value, max_value=self.max_value)
return results
@PIPELINES.register_module()
class FormatGCNInput:
"""Format final skeleton shape to the given input_format.
Required keys are "keypoint" and "keypoint_score"(optional),
added or modified keys are "keypoint" and "input_shape".
Args:
input_format (str): Define the final skeleton format.
"""
def __init__(self, input_format, num_person=2):
self.input_format = input_format
if self.input_format not in ['NCTVM']:
raise ValueError(
f'The input format {self.input_format} is invalid.')
self.num_person = num_person
def __call__(self, results):
"""Performs the FormatShape formatting.
Args:
results (dict): The resulting dict to be modified and passed
to the next transform in pipeline.
"""
keypoint = results['keypoint']
if 'keypoint_score' in results:
keypoint_confidence = results['keypoint_score']
keypoint_confidence = np.expand_dims(keypoint_confidence, -1)
keypoint_3d = np.concatenate((keypoint, keypoint_confidence),
axis=-1)
else:
keypoint_3d = keypoint
keypoint_3d = np.transpose(keypoint_3d,
(3, 1, 2, 0)) # M T V C -> C T V M
if keypoint_3d.shape[-1] < self.num_person:
pad_dim = self.num_person - keypoint_3d.shape[-1]
pad = np.zeros(
keypoint_3d.shape[:-1] + (pad_dim, ), dtype=keypoint_3d.dtype)
keypoint_3d = np.concatenate((keypoint_3d, pad), axis=-1)
elif keypoint_3d.shape[-1] > self.num_person:
keypoint_3d = keypoint_3d[:, :, :, :self.num_person]
results['keypoint'] = keypoint_3d
results['input_shape'] = keypoint_3d.shape
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(input_format='{self.input_format}')"
return repr_str
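# A minimal shape walkthrough (not part of this diff) of the skeleton pipeline used by
# the STGCN config: pad/loop to 300 frames, decode the pose, then format to NCTVM.
fake_sample = dict(
    total_frames=40,
    # (num_person, num_frame, num_keypoint, 2) coordinates plus per-joint confidences
    keypoint=np.random.rand(2, 40, 17, 2).astype(np.float32),
    keypoint_score=np.random.rand(2, 40, 17).astype(np.float32))
for transform in (PaddingWithLoop(clip_len=300), PoseDecode(),
                  FormatGCNInput(input_format='NCTVM'), PoseNormalize()):
    fake_sample = transform(fake_sample)
print(fake_sample['keypoint'].shape)  # (3, 300, 17, 2): C x T x V x M, as consumed by STGCN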

View File

@ -3,6 +3,7 @@ import logging
import traceback
import numpy as np
import torch
from easycv.datasets.registry import DATASETS
from easycv.datasets.shared.base import BaseDataset
@ -42,7 +43,15 @@ class VideoDataset(BaseDataset):
"""
assert len(evaluators) == 1, \
'classification evaluation only support one evaluator'
gt_labels = results.pop('label')
if results.get('label', None) is not None:
gt_labels = results.pop('label')
else:
gt_labels = []
for i in range(len(self.data_source)):
label = self.data_source.video_infos[i]['label']
gt_labels.append(label)
gt_labels = torch.Tensor(gt_labels)
eval_res = evaluators[0].evaluate(results, gt_labels)
return eval_res

View File

@ -2,3 +2,4 @@
from .ClipBertTwoStream import ClipBertTwoStream
from .heads import I3DHead
from .recognizer3d import Recognizer3D
from .skeleton_gcn.skeleton_gcn import SkeletonGCN

View File

@ -58,6 +58,7 @@ class BaseHead(nn.Module, metaclass=ABCMeta):
recognition task. Default: False.
label_smooth_eps (float): Epsilon used in label smooth.
Reference: arxiv.org/abs/1906.02629. Default: 0.
topk (int | tuple): Top-k accuracy. Default: (1, 5).
"""
def __init__(self,
@ -65,13 +66,20 @@ class BaseHead(nn.Module, metaclass=ABCMeta):
in_channels,
loss_cls=dict(type='CrossEntropyLoss', loss_weight=1.0),
multi_class=False,
label_smooth_eps=0.0):
label_smooth_eps=0.0,
topk=(1, 5)):
super().__init__()
self.num_classes = num_classes
self.in_channels = in_channels
self.loss_cls = build_loss(loss_cls)
self.multi_class = multi_class
self.label_smooth_eps = label_smooth_eps
assert isinstance(topk, (int, tuple))
if isinstance(topk, int):
topk = (topk, )
for _topk in topk:
assert _topk > 0, 'Top-k should be larger than 0'
self.topk = topk
@abstractmethod
def init_weights(self):
@ -89,7 +97,7 @@ class BaseHead(nn.Module, metaclass=ABCMeta):
labels (torch.Tensor): The target output of the model.
Returns:
dict: A dict containing field 'loss_cls'(mandatory)
and 'top1_acc', 'top5_acc'(optional).
and 'topk_acc'(optional).
"""
losses = dict()
if labels.shape == torch.Size([]):
@ -103,11 +111,11 @@ class BaseHead(nn.Module, metaclass=ABCMeta):
if not self.multi_class and cls_score.size() != labels.size():
top_k_acc = top_k_accuracy(cls_score.detach().cpu().numpy(),
labels.detach().cpu().numpy(), (1, 5))
losses['top1_acc'] = torch.tensor(
top_k_acc[0], device=cls_score.device)
losses['top5_acc'] = torch.tensor(
top_k_acc[1], device=cls_score.device)
labels.detach().cpu().numpy(),
self.topk)
for k, a in zip(self.topk, top_k_acc):
losses[f'top{k}_acc'] = torch.tensor(
a, device=cls_score.device)
elif self.multi_class and self.label_smooth_eps != 0:
labels = ((1 - self.label_smooth_eps) * labels +

View File

@ -0,0 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .base import BaseGCN
from .skeleton_gcn import SkeletonGCN
from .stgcn_backbone import STGCN
from .stgcn_head import STGCNHead
__all__ = ['BaseGCN', 'SkeletonGCN', 'STGCN', 'STGCNHead']

View File

@ -0,0 +1,115 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
# Refer to: https://github.com/open-mmlab/mmaction2/blob/master/mmaction/models/skeleton_gcn/base.py
from easycv.models import builder
from easycv.models.base import BaseModel
class BaseGCN(BaseModel):
"""Base class for GCN-based action recognition.
All GCN-based recognizers should subclass it.
All subclass should overwrite:
- Methods:``forward_train``, supporting to forward when training.
- Methods:``forward_test``, supporting to forward when testing.
Args:
backbone (dict): Backbone modules to extract feature.
cls_head (dict | None): Classification head to process feature.
Default: None.
train_cfg (dict | None): Config for training. Default: None.
test_cfg (dict | None): Config for testing. Default: None.
"""
def __init__(self, backbone, cls_head=None, train_cfg=None, test_cfg=None):
super().__init__()
self.backbone = builder.build_backbone(backbone)
self.cls_head = builder.build_head(cls_head) if cls_head else None
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights()
@property
def with_cls_head(self):
"""bool: whether the recognizer has a cls_head"""
return hasattr(self, 'cls_head') and self.cls_head is not None
def forward(self, keypoint, label=None, mode='train', **kwargs):
"""Define the computation performed at every call."""
if mode == 'train':
if label is None:
raise ValueError('Label should not be None.')
return self.forward_train(keypoint, label, **kwargs)
return self.forward_test(keypoint, **kwargs)
def extract_feat(self, skeletons):
"""Extract features through a backbone.
Args:
skeletons (torch.Tensor): The input skeletons.
Returns:
torch.tensor: The extracted features.
"""
x = self.backbone(skeletons)
return x
def train_step(self, data_batch, optimizer, **kwargs):
"""The iteration step during training.
This method defines an iteration step during training, except for the
back propagation and optimizer updating, which are done in an optimizer
hook. Note that in some complicated cases or models, the whole process
including back propagation and optimizer updating is also defined in
this method, such as GAN.
Args:
data_batch (dict): The output of dataloader.
optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
runner is passed to ``train_step()``. This argument is unused
and reserved.
Returns:
dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
``num_samples``.
``loss`` is a tensor for back propagation, which can be a
weighted sum of multiple losses.
``log_vars`` contains all the variables to be sent to the
logger.
``num_samples`` indicates the batch size (when the model is
DDP, it means the batch size on each GPU), which is used for
averaging the logs.
"""
skeletons = data_batch['keypoint']
label = data_batch['label']
label = label.squeeze(-1)
losses = self(skeletons, label, return_loss=True)
loss, log_vars = self._parse_losses(losses)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(skeletons.data))
return outputs
def val_step(self, data_batch, optimizer, **kwargs):
"""The iteration step during validation.
This method shares the same signature as :func:`train_step`, but used
during val epochs. Note that the evaluation after training epochs is
not implemented with this method, but an evaluation hook.
"""
skeletons = data_batch['keypoint']
label = data_batch['label']
losses = self(skeletons, label, return_loss=True)
loss, log_vars = self._parse_losses(losses)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(skeletons.data))
return outputs

View File

@ -0,0 +1,33 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
# Refer to: https://github.com/open-mmlab/mmaction2/blob/master/mmaction/models/skeleton_gcn/skeletongcn.py
from easycv.models.builder import MODELS
from .base import BaseGCN
@MODELS.register_module()
class SkeletonGCN(BaseGCN):
"""Spatial temporal graph convolutional networks."""
def forward_train(self, skeletons, labels, **kwargs):
"""Defines the computation performed at every call when training."""
assert self.with_cls_head
losses = dict()
x = self.extract_feat(skeletons)
output = self.cls_head(x)
gt_labels = labels.squeeze(-1)
loss = self.cls_head.loss(output, gt_labels)
losses.update(loss)
return losses
def forward_test(self, skeletons, **kwargs):
"""Defines the computation performed at every call when evaluation and
testing."""
x = self.extract_feat(skeletons)
assert self.with_cls_head
output = self.cls_head(x)
result = {'prob': output.cpu()}
return result

View File

@ -0,0 +1,283 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
# Refer to: https://github.com/open-mmlab/mmaction2/blob/master/mmaction/models/backbones/stgcn.py
import torch
import torch.nn as nn
from mmcv.cnn import constant_init, kaiming_init, normal_init
from mmcv.utils import _BatchNorm
from easycv.models import BACKBONES
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from .utils import Graph
def zero(x):
"""return zero."""
return 0
def identity(x):
"""return input itself."""
return x
class STGCNBlock(nn.Module):
"""Applies a spatial temporal graph convolution over an input graph
sequence.
Args:
in_channels (int): Number of channels in the input sequence data
out_channels (int): Number of channels produced by the convolution
kernel_size (tuple): Size of the temporal convolving kernel and
graph convolving kernel
stride (int, optional): Stride of the temporal convolution. Default: 1
dropout (float, optional): Dropout rate of the final output. Default: 0
residual (bool, optional): If ``True``, applies a residual mechanism.
Default: ``True``
Shape:
- Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)`
format
- Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
- Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out},
V)` format
- Output[1]: Graph adjacency matrix for output data in :math:`(K, V,
V)` format
where
:math:`N` is a batch size,
:math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]
`,
:math:`T_{in}/T_{out}` is a length of input/output sequence,
:math:`V` is the number of graph nodes.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
dropout=0,
residual=True):
super().__init__()
assert len(kernel_size) == 2
assert kernel_size[0] % 2 == 1
padding = ((kernel_size[0] - 1) // 2, 0)
self.gcn = ConvTemporalGraphical(in_channels, out_channels,
kernel_size[1])
self.tcn = nn.Sequential(
nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, (kernel_size[0], 1),
(stride, 1), padding), nn.BatchNorm2d(out_channels),
nn.Dropout(dropout, inplace=True))
if not residual:
self.residual = zero
elif (in_channels == out_channels) and (stride == 1):
self.residual = identity
else:
self.residual = nn.Sequential(
nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=(stride, 1)), nn.BatchNorm2d(out_channels))
self.relu = nn.ReLU(inplace=True)
def forward(self, x, adj_mat):
"""Defines the computation performed at every call."""
res = self.residual(x)
x, adj_mat = self.gcn(x, adj_mat)
x = self.tcn(x) + res
return self.relu(x), adj_mat
class ConvTemporalGraphical(nn.Module):
"""The basic module for applying a graph convolution.
Args:
in_channels (int): Number of channels in the input sequence data
out_channels (int): Number of channels produced by the convolution
kernel_size (int): Size of the graph convolving kernel
t_kernel_size (int): Size of the temporal convolving kernel
t_stride (int, optional): Stride of the temporal convolution.
Default: 1
t_padding (int, optional): Temporal zero-padding added to both sides
of the input. Default: 0
t_dilation (int, optional): Spacing between temporal kernel elements.
Default: 1
bias (bool, optional): If ``True``, adds a learnable bias to the
output. Default: ``True``
Shape:
- Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)`
format
- Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
- Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}
, V)` format
- Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)
` format
where
:math:`N` is a batch size,
:math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]
`,
:math:`T_{in}/T_{out}` is a length of input/output sequence,
:math:`V` is the number of graph nodes.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
t_kernel_size=1,
t_stride=1,
t_padding=0,
t_dilation=1,
bias=True):
super().__init__()
self.kernel_size = kernel_size
self.conv = nn.Conv2d(
in_channels,
out_channels * kernel_size,
kernel_size=(t_kernel_size, 1),
padding=(t_padding, 0),
stride=(t_stride, 1),
dilation=(t_dilation, 1),
bias=bias)
def forward(self, x, adj_mat):
"""Defines the computation performed at every call."""
assert adj_mat.size(0) == self.kernel_size
x = self.conv(x)
n, kc, t, v = x.size()
x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)
x = torch.einsum('nkctv,kvw->nctw', (x, adj_mat))
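# (N, K, C, T, V) x (K, V, W) -> (N, C, T, W): sum over the K adjacency subsets and the V nodes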
return x.contiguous(), adj_mat
@BACKBONES.register_module()
class STGCN(nn.Module):
"""Backbone of Spatial temporal graph convolutional networks.
Args:
in_channels (int): Number of channels in the input data.
graph_cfg (dict): The arguments for building the graph.
edge_importance_weighting (bool): If ``True``, adds a learnable
importance weighting to the edges of the graph. Default: True.
data_bn (bool): If 'True', adds data normalization to the inputs.
Default: True.
pretrained (str | None): Name of pretrained model.
**kwargs (optional): Other parameters for graph convolution units.
Shape:
- Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})`
- Output: :math:`(N, num_class)` where
:math:`N` is a batch size,
:math:`T_{in}` is a length of input sequence,
:math:`V_{in}` is the number of graph nodes,
:math:`M_{in}` is the number of instances in a frame.
"""
def __init__(self,
in_channels,
graph_cfg,
edge_importance_weighting=True,
data_bn=True,
pretrained=None,
**kwargs):
super().__init__()
# load graph
self.graph = Graph(**graph_cfg)
A = torch.tensor(
self.graph.A, dtype=torch.float32, requires_grad=False)
self.register_buffer('A', A)
# build networks
spatial_kernel_size = A.size(0)
temporal_kernel_size = 9
kernel_size = (temporal_kernel_size, spatial_kernel_size)
self.data_bn = nn.BatchNorm1d(in_channels *
A.size(1)) if data_bn else identity
kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
self.st_gcn_networks = nn.ModuleList((
STGCNBlock(
in_channels, 64, kernel_size, 1, residual=False, **kwargs0),
STGCNBlock(64, 64, kernel_size, 1, **kwargs),
STGCNBlock(64, 64, kernel_size, 1, **kwargs),
STGCNBlock(64, 64, kernel_size, 1, **kwargs),
STGCNBlock(64, 128, kernel_size, 2, **kwargs),
STGCNBlock(128, 128, kernel_size, 1, **kwargs),
STGCNBlock(128, 128, kernel_size, 1, **kwargs),
STGCNBlock(128, 256, kernel_size, 2, **kwargs),
STGCNBlock(256, 256, kernel_size, 1, **kwargs),
STGCNBlock(256, 256, kernel_size, 1, **kwargs),
))
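# the final block outputs 256 channels, matching in_channels=256 of STGCNHead in the config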
# initialize parameters for edge importance weighting
if edge_importance_weighting:
self.edge_importance = nn.ParameterList([
nn.Parameter(torch.ones(self.A.size()))
for i in self.st_gcn_networks
])
else:
self.edge_importance = [1 for _ in self.st_gcn_networks]
self.pretrained = pretrained
def init_weights(self):
"""Initiate the parameters either from existing checkpoint or from
scratch."""
if isinstance(self.pretrained, str):
logger = get_root_logger()
logger.info(f'load model from: {self.pretrained}')
load_checkpoint(self, self.pretrained, strict=False, logger=logger)
elif self.pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, nn.Linear):
normal_init(m)
elif isinstance(m, _BatchNorm):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
"""Defines the computation performed at every call.
Args:
x (torch.Tensor): The input data.
Returns:
torch.Tensor: The output of the module.
"""
# data normalization
x = x.float()
n, c, t, v, m = x.size() # bs 3 300 25(17) 2
x = x.permute(0, 4, 3, 1, 2).contiguous() # N M V C T
x = x.view(n * m, v * c, t)
x = self.data_bn(x)
x = x.view(n, m, v, c, t)
x = x.permute(0, 1, 3, 4, 2).contiguous()
x = x.view(n * m, c, t, v) # bsx2 3 300 25(17)
# forward
for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
x, _ = gcn(x, self.A * importance)
return x

View File

@ -0,0 +1,67 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
# Refer to: https://github.com/open-mmlab/mmaction2/blob/master/mmaction/models/heads/stgcn_head.py
import torch.nn as nn
from mmcv.cnn import normal_init
from easycv.models import HEADS
from easycv.models.video_recognition.heads.base_head import BaseHead
@HEADS.register_module()
class STGCNHead(BaseHead):
"""The classification head for STGCN.
Args:
num_classes (int): Number of classes to be classified.
in_channels (int): Number of channels in input feature.
loss_cls (dict): Config for building loss.
Default: dict(type='CrossEntropyLoss')
spatial_type (str): Pooling type in spatial dimension. Default: 'avg'.
num_person (int): Number of person. Default: 2.
init_std (float): Std value for initialization. Default: 0.01.
kwargs (dict, optional): Any keyword argument to be used to initialize
the head.
"""
def __init__(self,
num_classes,
in_channels,
loss_cls=dict(type='CrossEntropyLoss'),
spatial_type='avg',
num_person=2,
init_std=0.01,
**kwargs):
super().__init__(num_classes, in_channels, loss_cls, **kwargs)
self.spatial_type = spatial_type
self.in_channels = in_channels
self.num_classes = num_classes
self.num_person = num_person
self.init_std = init_std
self.pool = None
if self.spatial_type == 'avg':
self.pool = nn.AdaptiveAvgPool2d((1, 1))
elif self.spatial_type == 'max':
self.pool = nn.AdaptiveMaxPool2d((1, 1))
else:
raise NotImplementedError
self.fc = nn.Conv2d(self.in_channels, self.num_classes, kernel_size=1)
def init_weights(self):
normal_init(self.fc, std=self.init_std)
def forward(self, x):
# global pooling
assert self.pool is not None
x = self.pool(x)
x = x.view(x.shape[0] // self.num_person, self.num_person, -1, 1,
1).mean(dim=1)
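# merge the person dimension: (N*M, C, 1, 1) -> (N, M, C, 1, 1), then average over the M persons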
# prediction
x = self.fc(x)
x = x.view(x.shape[0], -1)
return x

View File

@ -0,0 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .graph import Graph
__all__ = ['Graph']

View File

@ -0,0 +1,198 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
# Refer to: https://github.com/open-mmlab/mmaction2/blob/master/mmaction/models/skeleton_gcn/utils/graph.py
import numpy as np
def get_hop_distance(num_node, edge, max_hop=1):
adj_mat = np.zeros((num_node, num_node))
for i, j in edge:
adj_mat[i, j] = 1
adj_mat[j, i] = 1
# compute hop steps
hop_dis = np.zeros((num_node, num_node)) + np.inf
transfer_mat = [
np.linalg.matrix_power(adj_mat, d) for d in range(max_hop + 1)
]
arrive_mat = (np.stack(transfer_mat) > 0)
for d in range(max_hop, -1, -1):
hop_dis[arrive_mat[d]] = d
return hop_dis
def normalize_digraph(adj_matrix):
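# Column-normalize the adjacency matrix: A_norm = A @ D^-1, where D is the diagonal
# degree matrix built from column sums, so each column with non-zero degree sums to 1.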
Dl = np.sum(adj_matrix, 0)
num_nodes = adj_matrix.shape[0]
Dn = np.zeros((num_nodes, num_nodes))
for i in range(num_nodes):
if Dl[i] > 0:
Dn[i, i] = Dl[i]**(-1)
norm_matrix = np.dot(adj_matrix, Dn)
return norm_matrix
def edge2mat(link, num_node):
A = np.zeros((num_node, num_node))
for i, j in link:
A[j, i] = 1
return A
class Graph:
"""The Graph to model the skeletons extracted by the openpose.
Args:
layout (str): must be one of the following candidates
- openpose: 18 or 25 joints. For more information, please refer to:
https://github.com/CMU-Perceptual-Computing-Lab/openpose#output
- ntu-rgb+d: Consists of 25 joints. For more information, please
refer to https://github.com/shahroudy/NTURGB-D
strategy (str): must be one of the follow candidates
- uniform: Uniform Labeling
- distance: Distance Partitioning
- spatial: Spatial Configuration
For more information, please refer to the section 'Partition
Strategies' in the ST-GCN paper (https://arxiv.org/abs/1801.07455).
max_hop (int): the maximal distance between two connected nodes.
Default: 1
dilation (int): controls the spacing between the kernel points.
Default: 1
"""
def __init__(self,
layout='openpose-18',
strategy='uniform',
max_hop=1,
dilation=1):
self.max_hop = max_hop
self.dilation = dilation
assert layout in [
'openpose-18', 'openpose-25', 'ntu-rgb+d', 'ntu_edge', 'coco'
]
assert strategy in ['uniform', 'distance', 'spatial', 'agcn']
self.get_edge(layout)
self.hop_dis = get_hop_distance(
self.num_node, self.edge, max_hop=max_hop)
self.get_adjacency(strategy)
def __str__(self):
return str(self.A)
def get_edge(self, layout):
"""This method returns the edge pairs of the layout."""
if layout == 'openpose-18':
self.num_node = 18
self_link = [(i, i) for i in range(self.num_node)]
neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5),
(13, 12), (12, 11), (10, 9), (9, 8), (11, 5),
(8, 2), (5, 1), (2, 1), (0, 1), (15, 0), (14, 0),
(17, 15), (16, 14)]
self.edge = self_link + neighbor_link
self.center = 1
elif layout == 'openpose-25':
self.num_node = 25
self_link = [(i, i) for i in range(self.num_node)]
neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (23, 22),
(22, 11), (24, 11), (11, 10), (10, 9), (9, 8),
(20, 19), (19, 14), (21, 14), (14, 13), (13, 12),
(12, 8), (8, 1), (5, 1), (2, 1), (0, 1), (15, 0),
(16, 0), (17, 15), (18, 16)]
self.self_link = self_link
self.neighbor_link = neighbor_link
self.edge = self_link + neighbor_link
self.center = 1
elif layout == 'ntu-rgb+d':
self.num_node = 25
self_link = [(i, i) for i in range(self.num_node)]
neighbor_1base = [(1, 2), (2, 21), (3, 21),
(4, 3), (5, 21), (6, 5), (7, 6), (8, 7), (9, 21),
(10, 9), (11, 10), (12, 11), (13, 1), (14, 13),
(15, 14), (16, 15), (17, 1), (18, 17), (19, 18),
(20, 19), (22, 23), (23, 8), (24, 25), (25, 12)]
neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
self.self_link = self_link
self.neighbor_link = neighbor_link
self.edge = self_link + neighbor_link
self.center = 21 - 1
elif layout == 'ntu_edge':
self.num_node = 24
self_link = [(i, i) for i in range(self.num_node)]
neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6),
(8, 7), (9, 2), (10, 9), (11, 10), (12, 11),
(13, 1), (14, 13), (15, 14), (16, 15), (17, 1),
(18, 17), (19, 18), (20, 19), (21, 22), (22, 8),
(23, 24), (24, 12)]
neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
self.edge = self_link + neighbor_link
self.center = 2
elif layout == 'coco':
self.num_node = 17
self_link = [(i, i) for i in range(self.num_node)]
neighbor_1base = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13],
[6, 12], [7, 13], [6, 7], [8, 6], [9, 7],
[10, 8], [11, 9], [2, 3], [2, 1], [3, 1], [4, 2],
[5, 3], [4, 6], [5, 7]]
neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
self.edge = self_link + neighbor_link
self.center = 0
else:
raise ValueError(f'{layout} is not supported.')
def get_adjacency(self, strategy):
"""This method returns the adjacency matrix according to strategy."""
valid_hop = range(0, self.max_hop + 1, self.dilation)
adjacency = np.zeros((self.num_node, self.num_node))
for hop in valid_hop:
adjacency[self.hop_dis == hop] = 1
normalize_adjacency = normalize_digraph(adjacency)
if strategy == 'uniform':
A = np.zeros((1, self.num_node, self.num_node))
A[0] = normalize_adjacency
self.A = A
elif strategy == 'distance':
A = np.zeros((len(valid_hop), self.num_node, self.num_node))
for i, hop in enumerate(valid_hop):
A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis ==
hop]
self.A = A
elif strategy == 'spatial':
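# Spatial configuration partitioning from the ST-GCN paper: for each hop distance,
# neighbors are split into a root group (same distance to the skeleton center) and
# two groups that are nearer to / further from the center than the root node.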
A = []
for hop in valid_hop:
a_root = np.zeros((self.num_node, self.num_node))
a_close = np.zeros((self.num_node, self.num_node))
a_further = np.zeros((self.num_node, self.num_node))
for i in range(self.num_node):
for j in range(self.num_node):
if self.hop_dis[j, i] == hop:
if self.hop_dis[j, self.center] == self.hop_dis[
i, self.center]:
a_root[j, i] = normalize_adjacency[j, i]
elif self.hop_dis[j, self.center] > self.hop_dis[
i, self.center]:
a_close[j, i] = normalize_adjacency[j, i]
else:
a_further[j, i] = normalize_adjacency[j, i]
if hop == 0:
A.append(a_root)
else:
A.append(a_root + a_close)
A.append(a_further)
A = np.stack(A)
self.A = A
elif strategy == 'agcn':
A = []
link_mat = edge2mat(self.self_link, self.num_node)
In = normalize_digraph(edge2mat(self.neighbor_link, self.num_node))
outward = [(j, i) for (i, j) in self.neighbor_link]
Out = normalize_digraph(edge2mat(outward, self.num_node))
A = np.stack((link_mat, In, Out))
self.A = A
else:
raise ValueError(f'The strategy {strategy} is not supported.')
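# Quick sanity check (not part of this diff): the 'coco' + 'spatial' setting used by the
# config above yields K = 3 partition subsets over the 17 COCO joints, which STGCN reads
# as its spatial kernel size (A.size(0)).
graph = Graph(layout='coco', strategy='spatial')
print(graph.A.shape)  # (3, 17, 17)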

View File

@ -5,6 +5,7 @@ import os
import pickle
import cv2
import mmcv
import numpy as np
import torch
from mmcv.parallel import collate, scatter_kwargs
@ -419,9 +420,14 @@ class PredictorV2(object):
inputs = [inputs]
results_list = []
prog_bar = mmcv.ProgressBar(len(inputs))
for i in range(0, len(inputs), self.batch_size):
batch_inputs = inputs[i:min(len(inputs), i + self.batch_size)]
batch_outputs = self.input_processor(batch_inputs)
if len(batch_outputs) < 1:
results_list.append(batch_outputs)
continue
batch_outputs = self._to_device(batch_outputs)
batch_outputs = self.model_forward(batch_outputs)
results = self.output_processor(batch_outputs)
@ -433,6 +439,8 @@ class PredictorV2(object):
else:
results_list.append(results)
prog_bar.update(len(batch_inputs))
# TODO: support append to file
if self.save_results:
self.dump(results_list, self.save_path)

View File

@ -221,8 +221,8 @@ class PoseTopDownInputProcessor(InputProcessor):
output_person_info = []
for person_result in person_results:
box = person_result['bbox'] # x,y,w,h
box = [box[0], box[1], box[2] - box[0], box[3] - box[1]]
box = person_result['bbox'] # x,y,x,y
box = [box[0], box[1], box[2] - box[0], box[3] - box[1]] # x,y,w,h
center, scale = _box2cs(self.cfg.data_cfg['image_size'], box)
data = {
'image_id':
@ -277,6 +277,9 @@ class PoseTopDownInputProcessor(InputProcessor):
for res in self.process_single(inp):
batch_outputs.append(res)
if len(batch_outputs) < 1:
return batch_outputs
batch_outputs = self._collate_fn(batch_outputs)
batch_outputs['img_metas']._data = [[
img_meta[i] for img_meta in batch_outputs['img_metas']._data
@ -384,7 +387,7 @@ class PoseTopDownPredictor(PredictorV2):
image,
keypoints,
radius=4,
thickness=1,
thickness=3,
kpt_score_thr=0.3,
bbox_color='green',
show=False,

View File

@ -196,3 +196,118 @@ class VideoClassificationPredictor(PredictorV2):
def get_output_processor(self):
return VideoClsOutputProcessor(self.label_map, self.topk)
class STGCNInputProcessor(InputProcessor):
def _load_input(self, input):
"""Prepare input sample.
Args:
input (dict): Input sample dict. e.g.
{
'frame_dir': '',
'img_shape': (1080, 1920),
'original_shape': (1080, 1920),
'total_frames': 40,
'keypoint': (2, 40, 17, 2), # shape = (num_person, num_frame, num_keypoints, 2)
'keypoint_score': (2, 40, 17),
'modality': 'Pose',
'start_index': 1
}.
"""
assert isinstance(input, dict)
keypoint = input['keypoint']
assert len(keypoint.shape) == 4
assert keypoint.shape[-1] in [2, 3]
if keypoint.shape[-1] == 3:
if input.get('keypoint_score', None) is None:
input['keypoint_score'] = keypoint[..., -1]
keypoint = keypoint[..., :2]
input['keypoint'] = keypoint
return input
@PREDICTORS.register_module()
class STGCNPredictor(PredictorV2):
"""STGCN predict pipeline.
Args:
model_path (str): Path of model path.
config_file (Optional[str]): Config file path for model and processor to init. Defaults to None.
ori_image_size (Optional[list | tuple]): Original image or video frame size (width, height).
batch_size (int): Batch size for forward.
label_map (Optional[list | tuple]): List or file of labels.
device (str | torch.device): Support str('cuda' or 'cpu') or torch.device, if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
input_processor_threads (int): Number of processes to process inputs.
mode (str): The image mode fed into the model.
"""
def __init__(self,
model_path,
config_file=None,
ori_image_size=None,
batch_size=1,
label_map=None,
topk=1,
device=None,
save_results=False,
save_path=None,
pipelines=None,
input_processor_threads=8,
mode='RGB',
*args,
**kwargs):
super(STGCNPredictor, self).__init__(
model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
input_processor_threads=input_processor_threads,
mode=mode,
*args,
**kwargs)
if ori_image_size is not None:
w, h = ori_image_size
for pipeline in self.cfg.test_pipeline:
if pipeline['type'] == 'PoseNormalize':
pipeline['mean'] = (w // 2, h // 2, .5)
pipeline['max_value'] = (w, h, 1.)
self.topk = topk
if label_map is None:
if 'CLASSES' in self.cfg:
class_list = self.cfg.get('CLASSES', [])
elif 'num_classes' in self.cfg:
class_list = list(range(self.cfg.num_classes))
class_list = [str(i) for i in class_list]
else:
class_list = []
elif isinstance(label_map, str):
with io.open(label_map, 'r') as f:
class_list = f.readlines()
elif isinstance(label_map, (tuple, list)):
class_list = label_map
self.label_map = [i.strip() for i in class_list]
def get_input_processor(self):
return STGCNInputProcessor(
self.cfg,
pipelines=self.pipelines,
batch_size=self.batch_size,
threads=self.input_processor_threads,
mode=self.mode)
def get_output_processor(self):
return VideoClsOutputProcessor(self.label_map, self.topk)

View File

@ -0,0 +1,54 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import torch
from easycv.models.builder import build_model
class STGCNTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def _get_model(self):
model_cfg = dict(
type='SkeletonGCN',
backbone=dict(
type='STGCN',
in_channels=3,
edge_importance_weighting=True,
graph_cfg=dict(layout='coco', strategy='spatial')),
cls_head=dict(
type='STGCNHead',
num_classes=60,
in_channels=256,
loss_cls=dict(type='CrossEntropyLoss')),
train_cfg=None,
test_cfg=None)
model = build_model(model_cfg)
return model
def test_train(self):
model = self._get_model()
model.train()
batch_size = 2
keypoints = torch.randn([batch_size, 3, 300, 17, 2])
label = torch.randint(0, 60, (batch_size, ))
output = model(keypoint=keypoints, label=label)
self.assertIn('loss_cls', output)
self.assertIn('top1_acc', output)
self.assertIn('top5_acc', output)
def test_infer(self):
model = self._get_model()
model.eval()
with torch.no_grad():
keypoints = torch.randn([1, 3, 300, 17, 2])
output = model(keypoint=keypoints, mode='test')
self.assertEqual(output['prob'].shape, (1, 60))
if __name__ == '__main__':
unittest.main()

View File

@ -2,16 +2,13 @@
"""
isort:skip_file
"""
import json
import os
import unittest
import numpy as np
import cv2
import torch
from easycv.predictors.video_classifier import VideoClassificationPredictor
from easycv.utils.test_util import clean_up, get_tmp_dir
from easycv.predictors.video_classifier import VideoClassificationPredictor, STGCNPredictor
from tests.ut_config import (PRETRAINED_MODEL_X3D_XS,
VIDEO_DATA_SMALL_RAW_LOCAL)
VIDEO_DATA_SMALL_RAW_LOCAL, BASE_LOCAL_PATH)
class VideoClassificationPredictorTest(unittest.TestCase):
@ -54,5 +51,39 @@ class VideoClassificationPredictorTest(unittest.TestCase):
self.assertEqual(len(res['class_probs']), 400)
class STGCNPredictorTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_single(self):
checkpoint = os.path.join(
BASE_LOCAL_PATH,
'pretrained_models/video/stgcn/stgcn_80e_ntu60_xsub.pth')
config_file = 'configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py'
predict_op = STGCNPredictor(
model_path=checkpoint, config_file=config_file)
h, w = 480, 853
total_frames = 20
num_person = 2
inp = dict(
frame_dir='',
label=-1,
img_shape=(h, w),
original_shape=(h, w),
start_index=0,
modality='Pose',
total_frames=total_frames,
keypoint=np.random.random((num_person, total_frames, 17, 2)),
keypoint_score=np.random.random((num_person, total_frames, 17)),
)
results = predict_op([inp])[0]
self.assertIn('class', results)
self.assertIn('class_name', results)
self.assertEqual(len(results['class_probs']), 60)
if __name__ == '__main__':
unittest.main()