EasyCV/easycv/datasets/video_recognition/data_sources/pose_datasource.py

177 lines
6.4 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import os.path as osp
from abc import ABCMeta
from collections import defaultdict
import mmcv
import numpy as np
import torch
from torch.utils.data import Dataset
from easycv.datasets.registry import DATASOURCES
from easycv.utils.logger import get_root_logger
@DATASOURCES.register_module()
class PoseDataSourceForVideoRec(Dataset, metaclass=ABCMeta):
"""Pose data source for video recognition.
Args:
ann_file (str): Path to the annotation file.
data_prefix (str | None): Path to a directory where videos are held.
Default: None.
multi_class (bool): Determines whether the dataset is a multi-class
dataset. Default: False.
num_classes (int | None): Number of classes of the dataset, used in
multi-class datasets. Default: None.
start_index (int): Specify a start index for frames in consideration of
different filename format. However, when taking videos as input,
it should be set to 0, since frames loaded from videos count
from 0. Default: 1.
sample_by_class (bool): Sampling by class, should be set `True` when
performing inter-class data balancing. Only compatible with
`multi_class == False`. Only applies for training. Default: False.
power (float): We support sampling data with the probability
proportional to the power of its label frequency (freq ^ power)
when sampling data. `power == 1` indicates uniformly sampling all
data; `power == 0` indicates uniformly sampling all classes.
Default: 0.
dynamic_length (bool): If the dataset length is dynamic (used by
ClassSpecificDistributedSampler). Default: False.
"""
def __init__(
self,
ann_file,
data_prefix=None,
multi_class=False,
num_classes=None,
start_index=1,
sample_by_class=False,
power=0,
dynamic_length=False,
split=None,
valid_ratio=None,
box_thr=None,
class_prob=None,
):
super().__init__()
self.modality = 'Pose'
# split, applicable to ucf or hmdb
self.split = split
self.ann_file = ann_file
self.data_prefix = osp.realpath(
data_prefix) if data_prefix is not None and osp.isdir(
data_prefix) else data_prefix
self.multi_class = multi_class
self.num_classes = num_classes
self.start_index = start_index
self.sample_by_class = sample_by_class
self.power = power
self.dynamic_length = dynamic_length
assert not (self.multi_class and self.sample_by_class)
self.video_infos = self.load_annotations()
if self.sample_by_class:
self.video_infos_by_class = self.parse_by_class()
class_prob = []
for _, samples in self.video_infos_by_class.items():
class_prob.append(len(samples) / len(self.video_infos))
class_prob = [x**self.power for x in class_prob]
summ = sum(class_prob)
class_prob = [x / summ for x in class_prob]
self.class_prob = dict(zip(self.video_infos_by_class, class_prob))
# box_thr, which should be a string
self.box_thr = box_thr
if self.box_thr is not None:
assert box_thr in ['0.5', '0.6', '0.7', '0.8', '0.9']
# Thresholding Training Examples
self.valid_ratio = valid_ratio
if self.valid_ratio is not None:
assert isinstance(self.valid_ratio, float)
if self.box_thr is None:
self.video_infos = self.video_infos = [
x for x in self.video_infos
if x['valid_frames'] / x['total_frames'] >= valid_ratio
]
else:
key = f'valid@{self.box_thr}'
self.video_infos = [
x for x in self.video_infos
if x[key] / x['total_frames'] >= valid_ratio
]
if self.box_thr != '0.5':
box_thr = float(self.box_thr)
for item in self.video_infos:
inds = [
i for i, score in enumerate(item['box_score'])
if score >= box_thr
]
item['anno_inds'] = np.array(inds)
if class_prob is not None:
self.class_prob = class_prob
logger = get_root_logger()
logger.info(f'{len(self)} videos remain after valid thresholding')
def load_annotations(self):
"""Load annotation file to get video information."""
assert self.ann_file.endswith('.pkl')
return self.load_pkl_annotations()
def load_pkl_annotations(self):
data = mmcv.load(self.ann_file)
if self.split:
split, data = data['split'], data['annotations']
identifier = 'filename' if 'filename' in data[0] else 'frame_dir'
data = [x for x in data if x[identifier] in split[self.split]]
for item in data:
# Sometimes we may need to load anno from the file
if 'filename' in item:
item['filename'] = osp.join(self.data_prefix, item['filename'])
if 'frame_dir' in item:
item['frame_dir'] = osp.join(self.data_prefix,
item['frame_dir'])
return data
def parse_by_class(self):
video_infos_by_class = defaultdict(list)
for item in self.video_infos:
label = item['label']
video_infos_by_class[label].append(item)
return video_infos_by_class
def prepare_frames(self, idx):
"""Prepare the frames for training given the index."""
results = copy.deepcopy(self.video_infos[idx])
results['modality'] = self.modality
results['start_index'] = self.start_index
# prepare tensor in getitem
# If HVU, type(results['label']) is dict
if self.multi_class and isinstance(results['label'], list):
onehot = torch.zeros(self.num_classes)
onehot[results['label']] = 1.
results['label'] = onehot
return results
def __len__(self):
"""Get the size of the dataset."""
return len(self.video_infos)
def __getitem__(self, idx):
return self.prepare_frames(idx)