mirror of https://github.com/alibaba/EasyCV.git
177 lines
6.4 KiB
Python
177 lines
6.4 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import copy
|
|
import os.path as osp
|
|
from abc import ABCMeta
|
|
from collections import defaultdict
|
|
|
|
import mmcv
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
|
|
from easycv.datasets.registry import DATASOURCES
|
|
from easycv.utils.logger import get_root_logger
|
|
|
|
|
|
@DATASOURCES.register_module()
|
|
class PoseDataSourceForVideoRec(Dataset, metaclass=ABCMeta):
|
|
"""Pose data source for video recognition.
|
|
Args:
|
|
ann_file (str): Path to the annotation file.
|
|
data_prefix (str | None): Path to a directory where videos are held.
|
|
Default: None.
|
|
multi_class (bool): Determines whether the dataset is a multi-class
|
|
dataset. Default: False.
|
|
num_classes (int | None): Number of classes of the dataset, used in
|
|
multi-class datasets. Default: None.
|
|
start_index (int): Specify a start index for frames in consideration of
|
|
different filename format. However, when taking videos as input,
|
|
it should be set to 0, since frames loaded from videos count
|
|
from 0. Default: 1.
|
|
sample_by_class (bool): Sampling by class, should be set `True` when
|
|
performing inter-class data balancing. Only compatible with
|
|
`multi_class == False`. Only applies for training. Default: False.
|
|
power (float): We support sampling data with the probability
|
|
proportional to the power of its label frequency (freq ^ power)
|
|
when sampling data. `power == 1` indicates uniformly sampling all
|
|
data; `power == 0` indicates uniformly sampling all classes.
|
|
Default: 0.
|
|
dynamic_length (bool): If the dataset length is dynamic (used by
|
|
ClassSpecificDistributedSampler). Default: False.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
ann_file,
|
|
data_prefix=None,
|
|
multi_class=False,
|
|
num_classes=None,
|
|
start_index=1,
|
|
sample_by_class=False,
|
|
power=0,
|
|
dynamic_length=False,
|
|
split=None,
|
|
valid_ratio=None,
|
|
box_thr=None,
|
|
class_prob=None,
|
|
):
|
|
super().__init__()
|
|
|
|
self.modality = 'Pose'
|
|
# split, applicable to ucf or hmdb
|
|
self.split = split
|
|
|
|
self.ann_file = ann_file
|
|
self.data_prefix = osp.realpath(
|
|
data_prefix) if data_prefix is not None and osp.isdir(
|
|
data_prefix) else data_prefix
|
|
self.multi_class = multi_class
|
|
self.num_classes = num_classes
|
|
self.start_index = start_index
|
|
|
|
self.sample_by_class = sample_by_class
|
|
self.power = power
|
|
self.dynamic_length = dynamic_length
|
|
|
|
assert not (self.multi_class and self.sample_by_class)
|
|
|
|
self.video_infos = self.load_annotations()
|
|
if self.sample_by_class:
|
|
self.video_infos_by_class = self.parse_by_class()
|
|
|
|
class_prob = []
|
|
for _, samples in self.video_infos_by_class.items():
|
|
class_prob.append(len(samples) / len(self.video_infos))
|
|
class_prob = [x**self.power for x in class_prob]
|
|
|
|
summ = sum(class_prob)
|
|
class_prob = [x / summ for x in class_prob]
|
|
|
|
self.class_prob = dict(zip(self.video_infos_by_class, class_prob))
|
|
|
|
# box_thr, which should be a string
|
|
self.box_thr = box_thr
|
|
if self.box_thr is not None:
|
|
assert box_thr in ['0.5', '0.6', '0.7', '0.8', '0.9']
|
|
|
|
# Thresholding Training Examples
|
|
self.valid_ratio = valid_ratio
|
|
if self.valid_ratio is not None:
|
|
assert isinstance(self.valid_ratio, float)
|
|
if self.box_thr is None:
|
|
self.video_infos = self.video_infos = [
|
|
x for x in self.video_infos
|
|
if x['valid_frames'] / x['total_frames'] >= valid_ratio
|
|
]
|
|
else:
|
|
key = f'valid@{self.box_thr}'
|
|
self.video_infos = [
|
|
x for x in self.video_infos
|
|
if x[key] / x['total_frames'] >= valid_ratio
|
|
]
|
|
if self.box_thr != '0.5':
|
|
box_thr = float(self.box_thr)
|
|
for item in self.video_infos:
|
|
inds = [
|
|
i for i, score in enumerate(item['box_score'])
|
|
if score >= box_thr
|
|
]
|
|
item['anno_inds'] = np.array(inds)
|
|
|
|
if class_prob is not None:
|
|
self.class_prob = class_prob
|
|
|
|
logger = get_root_logger()
|
|
logger.info(f'{len(self)} videos remain after valid thresholding')
|
|
|
|
def load_annotations(self):
|
|
"""Load annotation file to get video information."""
|
|
assert self.ann_file.endswith('.pkl')
|
|
return self.load_pkl_annotations()
|
|
|
|
def load_pkl_annotations(self):
|
|
data = mmcv.load(self.ann_file)
|
|
|
|
if self.split:
|
|
split, data = data['split'], data['annotations']
|
|
identifier = 'filename' if 'filename' in data[0] else 'frame_dir'
|
|
data = [x for x in data if x[identifier] in split[self.split]]
|
|
|
|
for item in data:
|
|
# Sometimes we may need to load anno from the file
|
|
if 'filename' in item:
|
|
item['filename'] = osp.join(self.data_prefix, item['filename'])
|
|
if 'frame_dir' in item:
|
|
item['frame_dir'] = osp.join(self.data_prefix,
|
|
item['frame_dir'])
|
|
return data
|
|
|
|
def parse_by_class(self):
|
|
video_infos_by_class = defaultdict(list)
|
|
for item in self.video_infos:
|
|
label = item['label']
|
|
video_infos_by_class[label].append(item)
|
|
return video_infos_by_class
|
|
|
|
def prepare_frames(self, idx):
|
|
"""Prepare the frames for training given the index."""
|
|
results = copy.deepcopy(self.video_infos[idx])
|
|
results['modality'] = self.modality
|
|
results['start_index'] = self.start_index
|
|
|
|
# prepare tensor in getitem
|
|
# If HVU, type(results['label']) is dict
|
|
if self.multi_class and isinstance(results['label'], list):
|
|
onehot = torch.zeros(self.num_classes)
|
|
onehot[results['label']] = 1.
|
|
results['label'] = onehot
|
|
|
|
return results
|
|
|
|
def __len__(self):
|
|
"""Get the size of the dataset."""
|
|
return len(self.video_infos)
|
|
|
|
def __getitem__(self, idx):
|
|
return self.prepare_frames(idx)
|