refactoring for optflow

This commit is contained in:
Kai Chen 2018-08-28 01:38:53 +08:00
parent 7ccd8b0b74
commit 978ecfda84
10 changed files with 203 additions and 54 deletions

View File

@ -1,3 +1,4 @@
from .arraymisc import *
from .utils import *
from .fileio import *
from .opencv_info import *

View File

@ -0,0 +1 @@
from .quantization import *

View File

@ -0,0 +1,58 @@
import numpy as np
__all__ = ['quantize', 'dequantize']
def quantize(arr, min_val, max_val, levels, dtype=np.int64):
    """Quantize an array of (-inf, inf) to [0, levels-1].

    Values are clipped to ``[min_val, max_val]``, that range is split into
    ``levels`` equal-width bins, and every value is replaced by the index of
    the bin it falls into. ``max_val`` itself is assigned to the last bin.

    Args:
        arr (ndarray): Input array.
        min_val (scalar): Minimum value to be clipped.
        max_val (scalar): Maximum value to be clipped.
        levels (int): Quantization levels, must be greater than 1.
        dtype (np.type): The type of the quantized array.

    Returns:
        ndarray: Quantized array.

    Raises:
        ValueError: If ``levels`` is not an integer greater than 1, or if
            ``min_val`` is not smaller than ``max_val``.
    """
    # NumPy integer scalars (e.g. np.int64) are valid level counts as well.
    if not (isinstance(levels, (int, np.integer)) and levels > 1):
        raise ValueError(
            'levels must be a positive integer, but got {}'.format(levels))
    if min_val >= max_val:
        raise ValueError(
            'min_val ({}) must be smaller than max_val ({})'.format(
                min_val, max_val))

    arr = np.clip(arr, min_val, max_val) - min_val
    bin_idx = np.floor(levels * arr / (max_val - min_val))
    # Clamp to the last bin *before* casting: arr == max_val produces the
    # out-of-range index ``levels``, which would wrap around if it were cast
    # to a narrow dtype first (e.g. 256 -> 0 for uint8 with levels=256).
    quantized_arr = np.minimum(bin_idx, levels - 1).astype(dtype)
    return quantized_arr
def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
    """Dequantize an array.

    The range ``[min_val, max_val]`` is split into ``levels`` equal-width
    bins and every bin index in ``arr`` is mapped to the centre of its bin.

    Args:
        arr (ndarray): Input array of bin indices (array-likes such as lists
            are accepted too).
        min_val (scalar): Lower bound of the value range that was quantized.
        max_val (scalar): Upper bound of the value range that was quantized.
        levels (int): Quantization levels, must be greater than 1.
        dtype (np.type): The type of the dequantized array.

    Returns:
        ndarray: Dequantized array.

    Raises:
        ValueError: If ``levels`` is not an integer greater than 1, or if
            ``min_val`` is not smaller than ``max_val``.
    """
    # NumPy integer scalars (e.g. np.int64) are valid level counts as well.
    if not (isinstance(levels, (int, np.integer)) and levels > 1):
        raise ValueError(
            'levels must be a positive integer, but got {}'.format(levels))
    if min_val >= max_val:
        raise ValueError(
            'min_val ({}) must be smaller than max_val ({})'.format(
                min_val, max_val))

    # Use a plain Python int so a NumPy scalar cannot influence the dtype of
    # the arithmetic below.
    levels = int(levels)
    # The +0.5 shifts each index from the lower edge of its bin to the centre.
    dequantized_arr = (np.asanyarray(arr) + 0.5).astype(dtype) * (
        max_val - min_val) / levels + min_val
    return dequantized_arr

View File

@ -1,26 +1,19 @@
import numpy as np
from mmcv.arraymisc import quantize, dequantize
from mmcv.image import imread, imwrite
from mmcv.utils import is_str
def _pair_name(filename, suffix=('_dx', '_dy')):
    """Derive the two filenames of a dx/dy pair from one filename.

    Each suffix is inserted right before the file extension, e.g.
    ``flow.jpg`` -> ``('flow_dx.jpg', 'flow_dy.jpg')``. A filename without an
    extension simply gets the suffix appended.

    Args:
        filename (str): Filepath the pair is named after.
        suffix (tuple[str]): The two suffixes, for dx and dy respectively.

    Returns:
        tuple[str]: The dx filename and the dy filename.
    """
    import os.path as osp

    # splitext only looks at the last path component, so dots in directory
    # names are left alone and extension-less names do not raise.
    root, ext = osp.splitext(filename)
    dx_filename = root + suffix[0] + ext
    dy_filename = root + suffix[1] + ext
    return dx_filename, dy_filename
def read_flow(flow_or_path, quantize=False, *args, **kwargs):
def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
"""Read an optical flow map.
Args:
flow_or_path (ndarray or str): A flow map or filepath.
quantize (bool): whether to read quantized pair, if set to True,
remaining args will be passed to :func:`dequantize_flow`.
concat_axis (int): The axis that dx and dy are concatenated,
can be either 0 or 1. Ignored if quantize is False.
Returns:
ndarray: Optical flow represented as a (h, w, 2) numpy array
@ -51,23 +44,34 @@ def read_flow(flow_or_path, quantize=False, *args, **kwargs):
h = np.fromfile(f, np.int32, 1).squeeze()
flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
else:
dx_filename, dy_filename = _pair_name(flow_or_path)
dx = imread(dx_filename, flag='unchanged')
dy = imread(dy_filename, flag='unchanged')
assert concat_axis in [0, 1]
cat_flow = imread(flow_or_path, flag='unchanged')
if cat_flow.ndim != 2:
raise IOError(
'{} is not a valid quantized flow file, its dimension is {}.'.
format(flow_or_path, cat_flow.ndim))
assert cat_flow.shape[concat_axis] % 2 == 0
dx, dy = np.split(cat_flow, 2, axis=concat_axis)
flow = dequantize_flow(dx, dy, *args, **kwargs)
return flow.astype(np.float32)
def write_flow(flow, filename, quantize=False, *args, **kwargs):
def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
"""Write optical flow to file.
If the flow is not quantized, it will be saved as a .flo file losslessly,
otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
will be concatenated along ``concat_axis`` into a single image if quantize
is True.)
Args:
flow (ndarray): (h, w, 2) array of optical flow.
filename (str): Output filepath.
quantize (bool): Whether to quantize the flow and save it as a single
jpeg image. If set to True, remaining args will be passed to
:func:`quantize_flow`.
concat_axis (int): The axis that dx and dy are concatenated,
can be either 0 or 1. Ignored if quantize is False.
"""
if not quantize:
with open(filename, 'wb') as f:
@ -77,10 +81,10 @@ def write_flow(flow, filename, quantize=False, *args, **kwargs):
flow.tofile(f)
f.flush()
else:
assert concat_axis in [0, 1]
dx, dy = quantize_flow(flow, *args, **kwargs)
dx_filename, dy_filename = _pair_name(filename)
imwrite(dx, dx_filename)
imwrite(dy, dy_filename)
dxdy = np.concatenate((dx, dy), axis=concat_axis)
imwrite(dxdy, filename)
def quantize_flow(flow, max_val=0.02, norm=True):
@ -96,7 +100,7 @@ def quantize_flow(flow, max_val=0.02, norm=True):
norm (bool): Whether to divide flow values by image width/height.
Returns:
tuple: Quantized dx and dy.
tuple[ndarray]: Quantized dx and dy.
"""
h, w, _ = flow.shape
dx = flow[..., 0]
@ -104,11 +108,11 @@ def quantize_flow(flow, max_val=0.02, norm=True):
if norm:
dx = dx / w # avoid inplace operations
dy = dy / h
dx = np.maximum(0, np.minimum(dx + max_val, 2 * max_val))
dy = np.maximum(0, np.minimum(dy + max_val, 2 * max_val))
dx = np.round(dx * 255 / (max_val * 2)).astype(np.uint8)
dy = np.round(dy * 255 / (max_val * 2)).astype(np.uint8)
return dx, dy
# use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
flow_comps = [
quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
]
return tuple(flow_comps)
def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
@ -121,12 +125,13 @@ def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
denorm (bool): Whether to multiply flow values with width/height.
Returns:
tuple: Dequantized dx and dy
ndarray: Dequantized flow.
"""
assert dx.shape == dy.shape
assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
dx = dx.astype(np.float32) * max_val * 2 / 255 - max_val
dy = dy.astype(np.float32) * max_val * 2 / 255 - max_val
dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
if denorm:
dx *= dx.shape[1]
dy *= dx.shape[0]

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.7 KiB

54
tests/test_arraymisc.py Normal file
View File

@ -0,0 +1,54 @@
from __future__ import division
import mmcv
import numpy as np
def test_quantize():
    """Compare ``mmcv.quantize`` with a per-element reference computation."""
    arr = np.random.randn(10, 10)
    levels = 20

    # Default output dtype.
    qarr = mmcv.quantize(arr, -1, 1, levels)
    assert qarr.shape == arr.shape
    assert qarr.dtype == np.dtype('int64')
    for idx, value in np.ndenumerate(arr):
        clipped = max(min(value, 1), -1)
        # 20 levels over [-1, 1] means 10 bins per unit of value.
        expected = min(levels - 1, int(np.floor(10 * (1 + clipped))))
        assert qarr[idx] == expected

    # Explicitly requested output dtype.
    qarr = mmcv.quantize(arr, -1, 1, 20, dtype=np.uint8)
    assert qarr.shape == arr.shape
    assert qarr.dtype == np.dtype('uint8')
def test_dequantize():
    """Compare ``mmcv.dequantize`` with a per-element reference computation."""
    levels = 20
    qarr = np.random.randint(levels, size=(10, 10))

    # Default output dtype.
    arr = mmcv.dequantize(qarr, -1, 1, levels)
    assert arr.shape == qarr.shape
    assert arr.dtype == np.dtype('float64')
    for idx, q in np.ndenumerate(qarr):
        # 20 levels over [-1, 1]: bin centre is (q + 0.5) / 10 - 1.
        assert arr[idx] == (q + 0.5) / 10 - 1

    # Explicitly requested output dtype.
    arr = mmcv.dequantize(qarr, -1, 1, levels, dtype=np.float32)
    assert arr.shape == qarr.shape
    assert arr.dtype == np.dtype('float32')
def test_joint():
    """Round-trip values through ``mmcv.quantize`` and ``mmcv.dequantize``."""
    arr = np.random.randn(100, 100)
    levels = 1000
    qarr = mmcv.quantize(arr, -1, 1, levels)
    recover = mmcv.dequantize(qarr, -1, 1, levels)

    below = arr < -1
    above = arr > 1
    inside = (arr >= -1) & (arr <= 1)
    # Out-of-range values end up at the centre of the first / last bin.
    assert np.abs(recover[below] + 0.999).max() < 1e-6
    assert np.abs(recover[above] - 0.999).max() < 1e-6
    # In-range values are off by at most half a bin width (2 / 1000 / 2).
    assert np.abs((recover - arr)[inside]).max() <= 1e-3

    # With an odd number of levels, small values around zero share the
    # middle bin, whose centre is 0.
    arr = np.clip(np.random.randn(100) / 1000, -0.01, 0.01)
    levels = 99
    qarr = mmcv.quantize(arr, -1, 1, levels)
    recover = mmcv.dequantize(qarr, -1, 1, levels)
    assert np.all(recover == 0)

View File

@ -8,39 +8,67 @@ import pytest
from numpy.testing import assert_array_equal, assert_array_almost_equal
def test_read_flow():
flow = mmcv.read_flow(osp.join(osp.dirname(__file__), 'data/optflow.flo'))
assert flow.ndim == 3 and flow.shape[-1] == 2
flow_same = mmcv.read_flow(flow)
def test_flowread():
flow_shape = (60, 80, 2)
# read .flo file
flow = mmcv.flowread(osp.join(osp.dirname(__file__), 'data/optflow.flo'))
assert flow.shape == flow_shape
# pseudo read
flow_same = mmcv.flowread(flow)
assert_array_equal(flow, flow_same)
flow = mmcv.read_flow(
osp.join(osp.dirname(__file__), 'data/optflow.jpg'),
# read quantized flow concatenated vertically
flow = mmcv.flowread(
osp.join(osp.dirname(__file__), 'data/optflow_concat0.jpg'),
quantize=True,
denorm=True)
assert flow.ndim == 3 and flow.shape[-1] == 2
with pytest.raises(IOError):
mmcv.read_flow(osp.join(osp.dirname(__file__), 'data/color.jpg'))
with pytest.raises(ValueError):
mmcv.read_flow(np.zeros((100, 100, 1)))
assert flow.shape == flow_shape
# read quantized flow concatenated horizontally
flow = mmcv.flowread(
osp.join(osp.dirname(__file__), 'data/optflow_concat1.jpg'),
quantize=True,
concat_axis=1,
denorm=True)
assert flow.shape == flow_shape
# test exceptions
notflow_file = osp.join(osp.dirname(__file__), 'data/color.jpg')
with pytest.raises(TypeError):
mmcv.read_flow(1)
mmcv.flowread(1)
with pytest.raises(IOError):
mmcv.flowread(notflow_file)
with pytest.raises(IOError):
mmcv.flowread(notflow_file, quantize=True)
with pytest.raises(ValueError):
mmcv.flowread(np.zeros((100, 100, 1)))
def test_write_flow():
def test_flowwrite():
flow = np.random.rand(100, 100, 2).astype(np.float32)
# write to a .flo file
_, filename = tempfile.mkstemp()
mmcv.write_flow(flow, filename)
flow_from_file = mmcv.read_flow(filename)
mmcv.flowwrite(flow, filename)
flow_from_file = mmcv.flowread(filename)
assert_array_equal(flow, flow_from_file)
os.remove(filename)
# write to two .jpg files
tmp_dir = tempfile.gettempdir()
mmcv.write_flow(flow, osp.join(tmp_dir, 'test_flow.jpg'), quantize=True)
assert osp.isfile(osp.join(tmp_dir, 'test_flow_dx.jpg'))
assert osp.isfile(osp.join(tmp_dir, 'test_flow_dy.jpg'))
os.remove(osp.join(tmp_dir, 'test_flow_dx.jpg'))
os.remove(osp.join(tmp_dir, 'test_flow_dy.jpg'))
tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test_flow.jpg')
for concat_axis in range(2):
mmcv.flowwrite(
flow, tmp_filename, quantize=True, concat_axis=concat_axis)
shape = (200, 100) if concat_axis == 0 else (100, 200)
assert osp.isfile(tmp_filename)
assert mmcv.imread(tmp_filename, flag='unchanged').shape == shape
os.remove(tmp_filename)
# test exceptions
with pytest.raises(AssertionError):
mmcv.flowwrite(flow, tmp_filename, quantize=True, concat_axis=2)
def test_quantize_flow():
@ -53,7 +81,7 @@ def test_quantize_flow():
for k in range(ref.shape[2]):
val = flow[i, j, k] + max_val
val = min(max(val, 0), 2 * max_val)
ref[i, j, k] = np.round(255 * val / (2 * max_val))
ref[i, j, k] = min(np.floor(255 * val / (2 * max_val)), 254)
assert_array_equal(dx, ref[..., 0])
assert_array_equal(dy, ref[..., 1])
max_val = 0.5
@ -65,7 +93,7 @@ def test_quantize_flow():
scale = flow.shape[1] if k == 0 else flow.shape[0]
val = flow[i, j, k] / scale + max_val
val = min(max(val, 0), 2 * max_val)
ref[i, j, k] = np.round(255 * val / (2 * max_val))
ref[i, j, k] = min(np.floor(255 * val / (2 * max_val)), 254)
assert_array_equal(dx, ref[..., 0])
assert_array_equal(dy, ref[..., 1])
@ -78,8 +106,8 @@ def test_dequantize_flow():
ref = np.zeros_like(flow, dtype=np.float32)
for i in range(ref.shape[0]):
for j in range(ref.shape[1]):
ref[i, j, 0] = float(dx[i, j]) * 2 * max_val / 255 - max_val
ref[i, j, 1] = float(dy[i, j]) * 2 * max_val / 255 - max_val
ref[i, j, 0] = float(dx[i, j] + 0.5) * 2 * max_val / 255 - max_val
ref[i, j, 1] = float(dy[i, j] + 0.5) * 2 * max_val / 255 - max_val
assert_array_almost_equal(flow, ref)
max_val = 0.5
flow = mmcv.dequantize_flow(dx, dy, max_val=max_val, denorm=True)
@ -87,8 +115,10 @@ def test_dequantize_flow():
ref = np.zeros_like(flow, dtype=np.float32)
for i in range(ref.shape[0]):
for j in range(ref.shape[1]):
ref[i, j, 0] = (float(dx[i, j]) * 2 * max_val / 255 - max_val) * w
ref[i, j, 1] = (float(dy[i, j]) * 2 * max_val / 255 - max_val) * h
ref[i, j,
0] = (float(dx[i, j] + 0.5) * 2 * max_val / 255 - max_val) * w
ref[i, j,
1] = (float(dy[i, j] + 0.5) * 2 * max_val / 255 - max_val) * h
assert_array_almost_equal(flow, ref)