mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
refactoring for optflow
This commit is contained in:
parent
7ccd8b0b74
commit
978ecfda84
@ -1,3 +1,4 @@
|
||||
from .arraymisc import *
|
||||
from .utils import *
|
||||
from .fileio import *
|
||||
from .opencv_info import *
|
||||
|
1
mmcv/arraymisc/__init__.py
Normal file
1
mmcv/arraymisc/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .quantization import *
|
58
mmcv/arraymisc/quantization.py
Normal file
58
mmcv/arraymisc/quantization.py
Normal file
@ -0,0 +1,58 @@
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['quantize', 'dequantize']
|
||||
|
||||
|
||||
def quantize(arr, min_val, max_val, levels, dtype=np.int64):
|
||||
"""Quantize an array of (-inf, inf) to [0, levels-1].
|
||||
|
||||
Args:
|
||||
arr (ndarray): Input array.
|
||||
min_val (scalar): Minimum value to be clipped.
|
||||
max_val (scalar): Maximum value to be clipped.
|
||||
levels (int): Quantization levels.
|
||||
dtype (np.type): The type of the quantized array.
|
||||
|
||||
Returns:
|
||||
tuple: Quantized array.
|
||||
"""
|
||||
if not (isinstance(levels, int) and levels > 1):
|
||||
raise ValueError(
|
||||
'levels must be a positive integer, but got {}'.format(levels))
|
||||
if min_val >= max_val:
|
||||
raise ValueError(
|
||||
'min_val ({}) must be smaller than max_val ({})'.format(
|
||||
min_val, max_val))
|
||||
|
||||
arr = np.clip(arr, min_val, max_val) - min_val
|
||||
quantized_arr = np.minimum(
|
||||
np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
|
||||
|
||||
return quantized_arr
|
||||
|
||||
|
||||
def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
|
||||
"""Dequantize an array.
|
||||
|
||||
Args:
|
||||
arr (ndarray): Input array.
|
||||
min_val (scalar): Minimum value to be clipped.
|
||||
max_val (scalar): Maximum value to be clipped.
|
||||
levels (int): Quantization levels.
|
||||
dtype (np.type): The type of the dequantized array.
|
||||
|
||||
Returns:
|
||||
tuple: Dequantized array.
|
||||
"""
|
||||
if not (isinstance(levels, int) and levels > 1):
|
||||
raise ValueError(
|
||||
'levels must be a positive integer, but got {}'.format(levels))
|
||||
if min_val >= max_val:
|
||||
raise ValueError(
|
||||
'min_val ({}) must be smaller than max_val ({})'.format(
|
||||
min_val, max_val))
|
||||
|
||||
dequantized_arr = (arr + 0.5).astype(dtype) * (
|
||||
max_val - min_val) / levels + min_val
|
||||
|
||||
return dequantized_arr
|
@ -1,26 +1,19 @@
|
||||
import numpy as np
|
||||
|
||||
from mmcv.arraymisc import quantize, dequantize
|
||||
from mmcv.image import imread, imwrite
|
||||
from mmcv.utils import is_str
|
||||
|
||||
|
||||
def _pair_name(filename, suffix=('_dx', '_dy')):
|
||||
parts = filename.split('.')
|
||||
path_wo_ext = parts[-2]
|
||||
parts[-2] = path_wo_ext + suffix[0]
|
||||
dx_filename = '.'.join(parts)
|
||||
parts[-2] = path_wo_ext + suffix[1]
|
||||
dy_filename = '.'.join(parts)
|
||||
return dx_filename, dy_filename
|
||||
|
||||
|
||||
def read_flow(flow_or_path, quantize=False, *args, **kwargs):
|
||||
def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
|
||||
"""Read an optical flow map.
|
||||
|
||||
Args:
|
||||
flow_or_path (ndarray or str): A flow map or filepath.
|
||||
quantize (bool): whether to read quantized pair, if set to True,
|
||||
remaining args will be passed to :func:`dequantize_flow`.
|
||||
concat_axis (int): The axis that dx and dy are concatenated,
|
||||
can be either 0 or 1. Ignored if quantize is False.
|
||||
|
||||
Returns:
|
||||
ndarray: Optical flow represented as a (h, w, 2) numpy array
|
||||
@ -51,23 +44,34 @@ def read_flow(flow_or_path, quantize=False, *args, **kwargs):
|
||||
h = np.fromfile(f, np.int32, 1).squeeze()
|
||||
flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
|
||||
else:
|
||||
dx_filename, dy_filename = _pair_name(flow_or_path)
|
||||
dx = imread(dx_filename, flag='unchanged')
|
||||
dy = imread(dy_filename, flag='unchanged')
|
||||
assert concat_axis in [0, 1]
|
||||
cat_flow = imread(flow_or_path, flag='unchanged')
|
||||
if cat_flow.ndim != 2:
|
||||
raise IOError(
|
||||
'{} is not a valid quantized flow file, its dimension is {}.'.
|
||||
format(flow_or_path, cat_flow.ndim))
|
||||
assert cat_flow.shape[concat_axis] % 2 == 0
|
||||
dx, dy = np.split(cat_flow, 2, axis=concat_axis)
|
||||
flow = dequantize_flow(dx, dy, *args, **kwargs)
|
||||
|
||||
return flow.astype(np.float32)
|
||||
|
||||
|
||||
def write_flow(flow, filename, quantize=False, *args, **kwargs):
|
||||
def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
|
||||
"""Write optical flow to file.
|
||||
|
||||
If the flow is not quantized, it will be saved as a .flo file losslessly,
|
||||
otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
|
||||
will be concatenated horizontally into a single image if quantize is True.)
|
||||
|
||||
Args:
|
||||
flow (ndarray): (h, w, 2) array of optical flow.
|
||||
filename (str): Output filepath.
|
||||
quantize (bool): Whether to quantize the flow and save it to 2 jpeg
|
||||
images. If set to True, remaining args will be passed to
|
||||
:func:`quantize_flow`.
|
||||
concat_axis (int): The axis that dx and dy are concatenated,
|
||||
can be either 0 or 1. Ignored if quantize is False.
|
||||
"""
|
||||
if not quantize:
|
||||
with open(filename, 'wb') as f:
|
||||
@ -77,10 +81,10 @@ def write_flow(flow, filename, quantize=False, *args, **kwargs):
|
||||
flow.tofile(f)
|
||||
f.flush()
|
||||
else:
|
||||
assert concat_axis in [0, 1]
|
||||
dx, dy = quantize_flow(flow, *args, **kwargs)
|
||||
dx_filename, dy_filename = _pair_name(filename)
|
||||
imwrite(dx, dx_filename)
|
||||
imwrite(dy, dy_filename)
|
||||
dxdy = np.concatenate((dx, dy), axis=concat_axis)
|
||||
imwrite(dxdy, filename)
|
||||
|
||||
|
||||
def quantize_flow(flow, max_val=0.02, norm=True):
|
||||
@ -96,7 +100,7 @@ def quantize_flow(flow, max_val=0.02, norm=True):
|
||||
norm (bool): Whether to divide flow values by image width/height.
|
||||
|
||||
Returns:
|
||||
tuple: Quantized dx and dy.
|
||||
tuple[ndarray]: Quantized dx and dy.
|
||||
"""
|
||||
h, w, _ = flow.shape
|
||||
dx = flow[..., 0]
|
||||
@ -104,11 +108,11 @@ def quantize_flow(flow, max_val=0.02, norm=True):
|
||||
if norm:
|
||||
dx = dx / w # avoid inplace operations
|
||||
dy = dy / h
|
||||
dx = np.maximum(0, np.minimum(dx + max_val, 2 * max_val))
|
||||
dy = np.maximum(0, np.minimum(dy + max_val, 2 * max_val))
|
||||
dx = np.round(dx * 255 / (max_val * 2)).astype(np.uint8)
|
||||
dy = np.round(dy * 255 / (max_val * 2)).astype(np.uint8)
|
||||
return dx, dy
|
||||
# use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
|
||||
flow_comps = [
|
||||
quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
|
||||
]
|
||||
return tuple(flow_comps)
|
||||
|
||||
|
||||
def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
|
||||
@ -121,12 +125,13 @@ def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
|
||||
denorm (bool): Whether to multiply flow values with width/height.
|
||||
|
||||
Returns:
|
||||
tuple: Dequantized dx and dy
|
||||
ndarray: Dequantized flow.
|
||||
"""
|
||||
assert dx.shape == dy.shape
|
||||
assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
|
||||
dx = dx.astype(np.float32) * max_val * 2 / 255 - max_val
|
||||
dy = dy.astype(np.float32) * max_val * 2 / 255 - max_val
|
||||
|
||||
dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
|
||||
|
||||
if denorm:
|
||||
dx *= dx.shape[1]
|
||||
dy *= dx.shape[0]
|
||||
|
BIN
tests/data/optflow_concat0.jpg
Normal file
BIN
tests/data/optflow_concat0.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.9 KiB |
BIN
tests/data/optflow_concat1.jpg
Normal file
BIN
tests/data/optflow_concat1.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.9 KiB |
Binary file not shown.
Before Width: | Height: | Size: 1.5 KiB |
Binary file not shown.
Before Width: | Height: | Size: 1.7 KiB |
54
tests/test_arraymisc.py
Normal file
54
tests/test_arraymisc.py
Normal file
@ -0,0 +1,54 @@
|
||||
from __future__ import division
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_quantize():
|
||||
arr = np.random.randn(10, 10)
|
||||
levels = 20
|
||||
|
||||
qarr = mmcv.quantize(arr, -1, 1, levels)
|
||||
assert qarr.shape == arr.shape
|
||||
assert qarr.dtype == np.dtype('int64')
|
||||
for i in range(arr.shape[0]):
|
||||
for j in range(arr.shape[1]):
|
||||
ref = min(levels - 1,
|
||||
int(np.floor(10 * (1 + max(min(arr[i, j], 1), -1)))))
|
||||
assert qarr[i, j] == ref
|
||||
|
||||
qarr = mmcv.quantize(arr, -1, 1, 20, dtype=np.uint8)
|
||||
assert qarr.shape == arr.shape
|
||||
assert qarr.dtype == np.dtype('uint8')
|
||||
|
||||
|
||||
def test_dequantize():
|
||||
levels = 20
|
||||
qarr = np.random.randint(levels, size=(10, 10))
|
||||
|
||||
arr = mmcv.dequantize(qarr, -1, 1, levels)
|
||||
assert arr.shape == qarr.shape
|
||||
assert arr.dtype == np.dtype('float64')
|
||||
for i in range(qarr.shape[0]):
|
||||
for j in range(qarr.shape[1]):
|
||||
assert arr[i, j] == (qarr[i, j] + 0.5) / 10 - 1
|
||||
|
||||
arr = mmcv.dequantize(qarr, -1, 1, levels, dtype=np.float32)
|
||||
assert arr.shape == qarr.shape
|
||||
assert arr.dtype == np.dtype('float32')
|
||||
|
||||
|
||||
def test_joint():
|
||||
arr = np.random.randn(100, 100)
|
||||
levels = 1000
|
||||
qarr = mmcv.quantize(arr, -1, 1, levels)
|
||||
recover = mmcv.dequantize(qarr, -1, 1, levels)
|
||||
assert np.abs(recover[arr < -1] + 0.999).max() < 1e-6
|
||||
assert np.abs(recover[arr > 1] - 0.999).max() < 1e-6
|
||||
assert np.abs((recover - arr)[(arr >= -1) & (arr <= 1)]).max() <= 1e-3
|
||||
|
||||
arr = np.clip(np.random.randn(100) / 1000, -0.01, 0.01)
|
||||
levels = 99
|
||||
qarr = mmcv.quantize(arr, -1, 1, levels)
|
||||
recover = mmcv.dequantize(qarr, -1, 1, levels)
|
||||
assert np.all(recover == 0)
|
@ -8,39 +8,67 @@ import pytest
|
||||
from numpy.testing import assert_array_equal, assert_array_almost_equal
|
||||
|
||||
|
||||
def test_read_flow():
|
||||
flow = mmcv.read_flow(osp.join(osp.dirname(__file__), 'data/optflow.flo'))
|
||||
assert flow.ndim == 3 and flow.shape[-1] == 2
|
||||
flow_same = mmcv.read_flow(flow)
|
||||
def test_flowread():
|
||||
flow_shape = (60, 80, 2)
|
||||
|
||||
# read .flo file
|
||||
flow = mmcv.flowread(osp.join(osp.dirname(__file__), 'data/optflow.flo'))
|
||||
assert flow.shape == flow_shape
|
||||
|
||||
# pseudo read
|
||||
flow_same = mmcv.flowread(flow)
|
||||
assert_array_equal(flow, flow_same)
|
||||
flow = mmcv.read_flow(
|
||||
osp.join(osp.dirname(__file__), 'data/optflow.jpg'),
|
||||
|
||||
# read quantized flow concatenated vertically
|
||||
flow = mmcv.flowread(
|
||||
osp.join(osp.dirname(__file__), 'data/optflow_concat0.jpg'),
|
||||
quantize=True,
|
||||
denorm=True)
|
||||
assert flow.ndim == 3 and flow.shape[-1] == 2
|
||||
with pytest.raises(IOError):
|
||||
mmcv.read_flow(osp.join(osp.dirname(__file__), 'data/color.jpg'))
|
||||
with pytest.raises(ValueError):
|
||||
mmcv.read_flow(np.zeros((100, 100, 1)))
|
||||
assert flow.shape == flow_shape
|
||||
|
||||
# read quantized flow concatenated horizontally
|
||||
flow = mmcv.flowread(
|
||||
osp.join(osp.dirname(__file__), 'data/optflow_concat1.jpg'),
|
||||
quantize=True,
|
||||
concat_axis=1,
|
||||
denorm=True)
|
||||
assert flow.shape == flow_shape
|
||||
|
||||
# test exceptions
|
||||
notflow_file = osp.join(osp.dirname(__file__), 'data/color.jpg')
|
||||
with pytest.raises(TypeError):
|
||||
mmcv.read_flow(1)
|
||||
mmcv.flowread(1)
|
||||
with pytest.raises(IOError):
|
||||
mmcv.flowread(notflow_file)
|
||||
with pytest.raises(IOError):
|
||||
mmcv.flowread(notflow_file, quantize=True)
|
||||
with pytest.raises(ValueError):
|
||||
mmcv.flowread(np.zeros((100, 100, 1)))
|
||||
|
||||
|
||||
def test_write_flow():
|
||||
def test_flowwrite():
|
||||
flow = np.random.rand(100, 100, 2).astype(np.float32)
|
||||
|
||||
# write to a .flo file
|
||||
_, filename = tempfile.mkstemp()
|
||||
mmcv.write_flow(flow, filename)
|
||||
flow_from_file = mmcv.read_flow(filename)
|
||||
mmcv.flowwrite(flow, filename)
|
||||
flow_from_file = mmcv.flowread(filename)
|
||||
assert_array_equal(flow, flow_from_file)
|
||||
os.remove(filename)
|
||||
|
||||
# write to two .jpg files
|
||||
tmp_dir = tempfile.gettempdir()
|
||||
mmcv.write_flow(flow, osp.join(tmp_dir, 'test_flow.jpg'), quantize=True)
|
||||
assert osp.isfile(osp.join(tmp_dir, 'test_flow_dx.jpg'))
|
||||
assert osp.isfile(osp.join(tmp_dir, 'test_flow_dy.jpg'))
|
||||
os.remove(osp.join(tmp_dir, 'test_flow_dx.jpg'))
|
||||
os.remove(osp.join(tmp_dir, 'test_flow_dy.jpg'))
|
||||
tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test_flow.jpg')
|
||||
for concat_axis in range(2):
|
||||
mmcv.flowwrite(
|
||||
flow, tmp_filename, quantize=True, concat_axis=concat_axis)
|
||||
shape = (200, 100) if concat_axis == 0 else (100, 200)
|
||||
assert osp.isfile(tmp_filename)
|
||||
assert mmcv.imread(tmp_filename, flag='unchanged').shape == shape
|
||||
os.remove(tmp_filename)
|
||||
|
||||
# test exceptions
|
||||
with pytest.raises(AssertionError):
|
||||
mmcv.flowwrite(flow, tmp_filename, quantize=True, concat_axis=2)
|
||||
|
||||
|
||||
def test_quantize_flow():
|
||||
@ -53,7 +81,7 @@ def test_quantize_flow():
|
||||
for k in range(ref.shape[2]):
|
||||
val = flow[i, j, k] + max_val
|
||||
val = min(max(val, 0), 2 * max_val)
|
||||
ref[i, j, k] = np.round(255 * val / (2 * max_val))
|
||||
ref[i, j, k] = min(np.floor(255 * val / (2 * max_val)), 254)
|
||||
assert_array_equal(dx, ref[..., 0])
|
||||
assert_array_equal(dy, ref[..., 1])
|
||||
max_val = 0.5
|
||||
@ -65,7 +93,7 @@ def test_quantize_flow():
|
||||
scale = flow.shape[1] if k == 0 else flow.shape[0]
|
||||
val = flow[i, j, k] / scale + max_val
|
||||
val = min(max(val, 0), 2 * max_val)
|
||||
ref[i, j, k] = np.round(255 * val / (2 * max_val))
|
||||
ref[i, j, k] = min(np.floor(255 * val / (2 * max_val)), 254)
|
||||
assert_array_equal(dx, ref[..., 0])
|
||||
assert_array_equal(dy, ref[..., 1])
|
||||
|
||||
@ -78,8 +106,8 @@ def test_dequantize_flow():
|
||||
ref = np.zeros_like(flow, dtype=np.float32)
|
||||
for i in range(ref.shape[0]):
|
||||
for j in range(ref.shape[1]):
|
||||
ref[i, j, 0] = float(dx[i, j]) * 2 * max_val / 255 - max_val
|
||||
ref[i, j, 1] = float(dy[i, j]) * 2 * max_val / 255 - max_val
|
||||
ref[i, j, 0] = float(dx[i, j] + 0.5) * 2 * max_val / 255 - max_val
|
||||
ref[i, j, 1] = float(dy[i, j] + 0.5) * 2 * max_val / 255 - max_val
|
||||
assert_array_almost_equal(flow, ref)
|
||||
max_val = 0.5
|
||||
flow = mmcv.dequantize_flow(dx, dy, max_val=max_val, denorm=True)
|
||||
@ -87,8 +115,10 @@ def test_dequantize_flow():
|
||||
ref = np.zeros_like(flow, dtype=np.float32)
|
||||
for i in range(ref.shape[0]):
|
||||
for j in range(ref.shape[1]):
|
||||
ref[i, j, 0] = (float(dx[i, j]) * 2 * max_val / 255 - max_val) * w
|
||||
ref[i, j, 1] = (float(dy[i, j]) * 2 * max_val / 255 - max_val) * h
|
||||
ref[i, j,
|
||||
0] = (float(dx[i, j] + 0.5) * 2 * max_val / 255 - max_val) * w
|
||||
ref[i, j,
|
||||
1] = (float(dy[i, j] + 0.5) * 2 * max_val / 255 - max_val) * h
|
||||
assert_array_almost_equal(flow, ref)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user