mirror of https://github.com/open-mmlab/mmcv.git
add slice support and unit test
parent
6d51a94196
commit
2ab73c85ee
|
@ -163,8 +163,8 @@ class VideoReader(object):
|
|||
ndarray or None: Return the frame if successful, otherwise None.
|
||||
"""
|
||||
if frame_id < 0 or frame_id >= self._frame_cnt:
|
||||
raise ValueError('"frame_id" must be between 0 and {}'.format(
|
||||
self._frame_cnt))
|
||||
raise IndexError('"frame_id" must be between 0 and {}'.format(
|
||||
self._frame_cnt - 1))
|
||||
if frame_id == self._position:
|
||||
return self.read()
|
||||
if self._cache:
|
||||
|
@ -240,7 +240,7 @@ class VideoReader(object):
|
|||
|
||||
def __getitem__(self, index):
|
||||
if isinstance(index, slice):
|
||||
raise RuntimeError('slice has not been supported yet')
|
||||
return [self.get_frame(i) for i in range(*index.indices(self.frame_cnt))]
|
||||
return self.get_frame(index)
|
||||
|
||||
def __iter__(self):
|
||||
|
|
|
@ -69,6 +69,12 @@ class TestVideo(object):
|
|||
with pytest.raises(ValueError):
|
||||
v.get_frame(self.num_frames + 1)
|
||||
|
||||
def test_slice(self):
|
||||
v = mmcv.VideoReader(self.video_path)
|
||||
imgs = v[-105:-103]
|
||||
assert int(round(imgs[0].mean())) == 94
|
||||
assert int(round(imgs[1].mean())) == 205
|
||||
|
||||
def test_current_frame(self):
|
||||
v = mmcv.VideoReader(self.video_path)
|
||||
assert v.current_frame() is None
|
||||
|
|
Loading…
Reference in New Issue