diff --git a/mmcv/video/io.py b/mmcv/video/io.py index 7dd17f9d4..496abb713 100644 --- a/mmcv/video/io.py +++ b/mmcv/video/io.py @@ -163,8 +163,8 @@ class VideoReader(object): ndarray or None: Return the frame if successful, otherwise None. """ if frame_id < 0 or frame_id >= self._frame_cnt: - raise ValueError('"frame_id" must be between 0 and {}'.format( - self._frame_cnt)) + raise IndexError('"frame_id" must be between 0 and {}'.format( + self._frame_cnt - 1)) if frame_id == self._position: return self.read() if self._cache: @@ -240,7 +240,7 @@ class VideoReader(object): def __getitem__(self, index): if isinstance(index, slice): - raise RuntimeError('slice has not been supported yet') + return [self.get_frame(i) for i in range(*index.indices(self.frame_cnt))] return self.get_frame(index) def __iter__(self): diff --git a/tests/test_video.py b/tests/test_video.py index fc9a85220..1ddd0c850 100644 --- a/tests/test_video.py +++ b/tests/test_video.py @@ -69,6 +69,12 @@ class TestVideo(object): with pytest.raises(ValueError): v.get_frame(self.num_frames + 1) + def test_slice(self): + v = mmcv.VideoReader(self.video_path) + imgs = v[-105:-103] + assert int(round(imgs[0].mean())) == 94 + assert int(round(imgs[1].mean())) == 205 + def test_current_frame(self): v = mmcv.VideoReader(self.video_path) assert v.current_frame() is None