mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
58 lines
1.7 KiB
Python
58 lines
1.7 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import logging
|
|
import traceback
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from easycv.datasets.registry import DATASETS
|
|
from easycv.datasets.shared.base import BaseDataset
|
|
|
|
|
|
@DATASETS.register_module
|
|
class VideoDataset(BaseDataset):
|
|
"""Dataset for video recognition
|
|
"""
|
|
|
|
def __init__(self, data_source, pipeline, profiling=False):
|
|
"""
|
|
|
|
Args:
|
|
data_source: Data_source config dict
|
|
pipeline: Pipeline config list
|
|
profiling: If set True, will print pipeline time
|
|
"""
|
|
super(VideoDataset, self).__init__(
|
|
data_source, pipeline, profiling=profiling)
|
|
|
|
def __getitem__(self, idx):
|
|
try:
|
|
data_dict = self.data_source[idx]
|
|
data_dict = self.pipeline(data_dict)
|
|
except:
|
|
logging.error(
|
|
'When parsing line {}, error happened with msg: {}'.format(
|
|
idx, traceback.format_exc()))
|
|
data_dict = None
|
|
if data_dict is None:
|
|
return self[np.random.randint(len(self))]
|
|
return data_dict
|
|
|
|
def evaluate(self, results, evaluators=[], **kwargs):
|
|
"""Evaluate the dataset.
|
|
"""
|
|
assert len(evaluators) == 1, \
|
|
'classification evaluation only support one evaluator'
|
|
if results.get('label', None) is not None:
|
|
gt_labels = results.pop('label')
|
|
else:
|
|
gt_labels = []
|
|
for i in range(len(self.data_source)):
|
|
label = self.data_source.video_infos[i]['label']
|
|
gt_labels.append(label)
|
|
gt_labels = torch.Tensor(gt_labels)
|
|
|
|
eval_res = evaluators[0].evaluate(results, gt_labels)
|
|
|
|
return eval_res
|