mirror of https://github.com/open-mmlab/mmcv.git
C++ implementation for Flow Warp Module (#71)
* flow_warp_c * flow_warp_c * beautify format * beautify format * beautify format * beautify format * add Cython * modify * fix details * fix details * fix type
pull/77/head^2
parent
20d7cd30dc
commit
9097957434
|
@ -8,7 +8,7 @@ before_install:
|
|||
- sudo apt-get install -y ffmpeg
|
||||
|
||||
install:
|
||||
- pip install opencv-python pyyaml codecov flake8
|
||||
- pip install opencv-python pyyaml codecov flake8 Cython
|
||||
|
||||
cache:
|
||||
pip: true
|
||||
|
@ -40,4 +40,4 @@ deploy:
|
|||
python: "3.6"
|
||||
distributions: sdist bdist_wheel
|
||||
skip_cleanup: true
|
||||
skip_upload_docs: true
|
||||
skip_upload_docs: true
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
from .io import Cache, VideoReader, frames2video
|
||||
from .processing import convert_video, resize_video, cut_video, concat_video
|
||||
from .optflow import flowread, flowwrite, quantize_flow, dequantize_flow
|
||||
from .optflow import (flowread, flowwrite, quantize_flow,
|
||||
dequantize_flow, flow_warp)
|
||||
|
||||
__all__ = [
|
||||
'Cache', 'VideoReader', 'frames2video', 'convert_video', 'resize_video',
|
||||
'cut_video', 'concat_video', 'flowread', 'flowwrite', 'quantize_flow',
|
||||
'dequantize_flow'
|
||||
'dequantize_flow', 'flow_warp'
|
||||
]
|
||||
|
|
|
@ -3,6 +3,7 @@ import numpy as np
|
|||
from mmcv.arraymisc import quantize, dequantize
|
||||
from mmcv.image import imread, imwrite
|
||||
from mmcv.utils import is_str
|
||||
from mmcv._ext import flow_warp_c
|
||||
|
||||
|
||||
def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
|
||||
|
@ -137,3 +138,33 @@ def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
|
|||
dy *= dx.shape[0]
|
||||
flow = np.dstack((dx, dy))
|
||||
return flow
|
||||
|
||||
|
||||
def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
    """Warp an image according to an optical flow field.

    The output pixel at ``(h, w)`` is sampled from ``img`` at row
    ``h + flow[h, w, 1]`` and column ``w + flow[h, w, 0]``.

    Args:
        img (ndarray, float or uint8): Image to be warped, with shape
            (H, W, C).
        flow (ndarray, float): Optical flow with shape (H, W, 2).
        filling_value (int): Value written to output pixels whose sampling
            location falls outside the image.
        interpolate_mode (str): ``'bilinear'`` for bilinear interpolation,
            ``'nearest'`` for nearest neighbor.

    Returns:
        ndarray: Warped image with the same shape as ``img``. The array is
            produced by the compiled extension from a float64 copy of
            ``img``, so it is float64 regardless of the input dtype.

    Raises:
        AssertionError: If the shapes of ``img`` and ``flow`` are
            inconsistent or ``interpolate_mode`` is not supported.
    """
    interpolate_mode_dict = {'bilinear': 0, 'nearest': 1}
    assert len(img.shape) == 3
    assert len(flow.shape) == 3 and flow.shape[2] == 2
    assert flow.shape[:2] == img.shape[:2]
    assert interpolate_mode in interpolate_mode_dict

    # The extension only accepts C-contiguous float64 buffers. A plain
    # ``astype`` keeps the memory layout of its input, so Fortran-ordered or
    # transposed arrays would be rejected; ``ascontiguousarray`` converts the
    # dtype and the layout in one step.
    img_float = np.ascontiguousarray(img, dtype=np.float64)
    flow_float = np.ascontiguousarray(flow, dtype=np.float64)

    return flow_warp_c(img_float,
                       flow_float,
                       filling_value=filling_value,
                       interpolate_mode=interpolate_mode_dict[interpolate_mode])
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
#include <flow_warp.hpp>
|
||||
|
||||
void flowWarp(double* img, double* flow, double* out, const int height,
|
||||
const int width, const int channels, const int filling_value = 0,
|
||||
const int interpolateMode = 0) {
|
||||
for (int h = 0; h < height; h++) {
|
||||
for (int w = 0; w < width; w++) {
|
||||
int offset_cur = h * width + w;
|
||||
int offset_img = offset_cur * channels;
|
||||
int offset_flow = offset_cur * 2;
|
||||
double x, y;
|
||||
x = h + flow[offset_flow + 1];
|
||||
y = w + flow[offset_flow];
|
||||
|
||||
if (x < 0 || x >= height - 1 || y < 0 || y >= width - 1) {
|
||||
for (int k = 0; k < channels; k++) {
|
||||
out[offset_img + k] = filling_value;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (interpolateMode == 0)
|
||||
BilinearInterpolate(img, width, height, channels, x, y,
|
||||
out + offset_img);
|
||||
else if (interpolateMode == 1)
|
||||
NNInterpolate(img, width, height, channels, x, y, out + offset_img);
|
||||
else
|
||||
throw "Not Implemented Interpolation Method";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BilinearInterpolate(const double* img, int width, int height, int channels,
|
||||
double x, double y, double* out) {
|
||||
int xx, yy, m, n, u, v, offset, offset_img, l;
|
||||
xx = x;
|
||||
yy = y;
|
||||
|
||||
double dx, dy, s;
|
||||
|
||||
dx = __max(__min(x - xx, double(1)), double(0));
|
||||
dy = __max(__min(y - yy, double(1)), double(0));
|
||||
|
||||
for (m = 0; m <= 1; m++)
|
||||
for (n = 0; n <= 1; n++) {
|
||||
u = EnforceRange(yy + n, width);
|
||||
v = EnforceRange(xx + m, height);
|
||||
offset = v * width + u;
|
||||
offset_img = offset * channels;
|
||||
s = fabs(1 - m - dx) * fabs(1 - n - dy);
|
||||
for (l = 0; l < channels; l++) out[l] += img[offset_img + l] * s;
|
||||
}
|
||||
}
|
||||
|
||||
// Nearest-neighbour sampling of `img` (height x width x channels, row-major)
// at the fractional position (row x, column y). The value of every channel of
// the selected pixel is copied to out[0 .. channels - 1].
void NNInterpolate(const double* img, int width, int height, int channels,
                   double x, double y, double* out) {
  // Integer parts (conversion truncates) and clamped fractional parts.
  const int base_row = x;
  const int base_col = y;
  const double frac_row = __max(__min(x - base_row, double(1)), double(0));
  const double frac_col = __max(__min(y - base_col, double(1)), double(0));

  // Step to the next row / column unless the fraction is below one half.
  int row_step = 1;
  if (frac_row < 0.5) row_step = 0;
  int col_step = 1;
  if (frac_col < 0.5) col_step = 0;

  const int col = EnforceRange(base_col + col_step, width);
  const int row = EnforceRange(base_row + row_step, height);

  const double* src = img + (row * width + col) * channels;
  for (int c = 0; c < channels; ++c) out[c] = src[c];
}
|
|
@ -0,0 +1,30 @@
|
|||
#include <math.h>
|
||||
#include <string.h>
|
||||
#include <iostream>
|
||||
|
||||
using namespace std;
|
||||
|
||||
void flowWarp(double* img, double* flow1, double* out, const int height,
|
||||
const int width, const int channels, const int filling_value,
|
||||
const int interpolateMode);
|
||||
|
||||
void BilinearInterpolate(const double* img, int width, int height, int channels,
|
||||
double x, double y, double* out);
|
||||
|
||||
void NNInterpolate(const double* img, int width, int height, int channels,
|
||||
double x, double y, double* out);
|
||||
|
||||
// Returns the smaller of the two arguments; `a` is returned when neither
// compares greater than the other.
// NOTE(review): identifiers containing a double underscore are reserved in
// C++, and some toolchains define `__min` as a macro - confirm portability.
template <typename T>
inline T __min(T a, T b) {
  if (a > b) {
    return b;
  }
  return a;
}
|
||||
|
||||
// Returns the larger of the two arguments; `a` is returned when it does not
// compare less than `b`.
// NOTE(review): identifiers containing a double underscore are reserved in
// C++, and some toolchains define `__max` as a macro - confirm portability.
template <typename T>
inline T __max(T a, T b) {
  if (a < b) {
    return b;
  }
  return a;
}
|
||||
|
||||
// Clamps `x` to the valid index range [0, MaxValue - 1] of an axis with
// `MaxValue` elements.
//
// The upper bound is MaxValue - 1, not MaxValue: the result is used as an
// array index, and MaxValue itself is one past the last element. When
// MaxValue <= 0 there is no valid index and 0 is returned.
template <typename T>
inline T EnforceRange(const T x, const int MaxValue) {
  const T lowest = static_cast<T>(0);
  const T highest = static_cast<T>(MaxValue - 1);
  if (highest < lowest || x < lowest) {
    return lowest;
  }
  return x > highest ? highest : x;
}
|
|
@ -0,0 +1,27 @@
|
|||
STUFF = "Hi"
|
||||
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
|
||||
np.import_array()
|
||||
|
||||
cdef extern from "flow_warp.hpp":
|
||||
void flowWarp(double* img, double* flow1, double* out, const int height, const int width, const int channels, const int filling_value, const int interpolateMode)
|
||||
|
||||
def flow_warp_c(np.ndarray[double, ndim=3, mode="c"] img_array not None,
                np.ndarray[double, ndim=3, mode="c"] flow_array not None,
                int filling_value=0,
                int interpolate_mode=1):
    """Warp ``img_array`` with ``flow_array`` using the C++ ``flowWarp``.

    Args:
        img_array (ndarray): C-contiguous float64 image, shape (H, W, C).
        flow_array (ndarray): C-contiguous float64 flow, shape (H, W, 2).
        filling_value (int): Value passed to ``flowWarp`` for pixels that
            cannot be sampled from the image.
        interpolate_mode (int): 0 for bilinear, 1 for nearest neighbor.

    Returns:
        ndarray: New float64 array with the same shape as ``img_array``.

    Raises:
        ValueError: If ``flow_array`` does not match the image size or
            ``interpolate_mode`` is not 0 or 1.
    """
    # flowWarp indexes the flow buffer using the image's height and width
    # and assumes two flow channels; a mismatching array would be read out
    # of bounds, so reject it here.
    if (flow_array.shape[0] != img_array.shape[0]
            or flow_array.shape[1] != img_array.shape[1]
            or flow_array.shape[2] != 2):
        raise ValueError(
            'flow_array must have shape (H, W, 2) matching img_array '
            '(H, W, C)')

    # The extern declaration has no `except +`, so a C++ exception raised
    # for an unknown mode could not be translated into a Python exception.
    if interpolate_mode != 0 and interpolate_mode != 1:
        raise ValueError(
            'interpolate_mode must be 0 (bilinear) or 1 (nearest)')

    # Zero-initialised output with the same shape and dtype as the image.
    out_array = np.zeros_like(img_array)

    flowWarp(<double*> np.PyArray_DATA(img_array),
             <double*> np.PyArray_DATA(flow_array),
             <double*> np.PyArray_DATA(out_array),
             out_array.shape[0],
             out_array.shape[1],
             out_array.shape[2],
             filling_value,
             interpolate_mode)

    return out_array
|
22
setup.py
22
setup.py
|
@ -1,10 +1,14 @@
|
|||
import sys
|
||||
from io import open # for Python 2 (identical to builtin in Python 3)
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
from setuptools import Extension, find_packages, setup
|
||||
|
||||
import numpy
|
||||
from Cython.Distutils import build_ext
|
||||
|
||||
install_requires = [
|
||||
'numpy>=1.11.1', 'pyyaml', 'six', 'addict', 'requests', 'opencv-python'
|
||||
'numpy>=1.11.1', 'pyyaml', 'six', 'addict', 'requests', 'opencv-python',
|
||||
'Cython'
|
||||
]
|
||||
if sys.version_info < (3, 3):
|
||||
install_requires.append('backports.shutil_get_terminal_size')
|
||||
|
@ -25,6 +29,18 @@ def get_version():
|
|||
return locals()['__version__']
|
||||
|
||||
|
||||
# Compiled extensions shipped with the package. `mmcv._ext` is built from the
# C++ flow-warp implementation and its Cython wrapper, compiled as C++.
EXT_MODULES = [
    Extension(
        name='mmcv._ext',
        sources=[
            './mmcv/video/optflow_warp/flow_warp.cpp',
            './mmcv/video/optflow_warp/flow_warp_module.pyx'
        ],
        # NumPy's header directory plus the directory holding the
        # flow-warp sources.
        include_dirs=[numpy.get_include(), './mmcv/video/optflow_warp/'],
        language="c++",
    ),
]
|
||||
|
||||
setup(
|
||||
name='mmcv',
|
||||
version=get_version(),
|
||||
|
@ -51,4 +67,6 @@ setup(
|
|||
setup_requires=['pytest-runner'],
|
||||
tests_require=['pytest'],
|
||||
install_requires=install_requires,
|
||||
ext_modules=EXT_MODULES,
|
||||
cmdclass={'build_ext': build_ext},
|
||||
zip_safe=False)
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
import os
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
@ -22,18 +23,18 @@ def test_flowread():
|
|||
assert_array_equal(flow, flow_same)
|
||||
|
||||
# read quantized flow concatenated vertically
|
||||
flow = mmcv.flowread(
|
||||
osp.join(osp.dirname(__file__), 'data/optflow_concat0.jpg'),
|
||||
quantize=True,
|
||||
denorm=True)
|
||||
flow = mmcv.flowread(osp.join(osp.dirname(__file__),
|
||||
'data/optflow_concat0.jpg'),
|
||||
quantize=True,
|
||||
denorm=True)
|
||||
assert flow.shape == flow_shape
|
||||
|
||||
# read quantized flow concatenated horizontally
|
||||
flow = mmcv.flowread(
|
||||
osp.join(osp.dirname(__file__), 'data/optflow_concat1.jpg'),
|
||||
quantize=True,
|
||||
concat_axis=1,
|
||||
denorm=True)
|
||||
flow = mmcv.flowread(osp.join(osp.dirname(__file__),
|
||||
'data/optflow_concat1.jpg'),
|
||||
quantize=True,
|
||||
concat_axis=1,
|
||||
denorm=True)
|
||||
assert flow.shape == flow_shape
|
||||
|
||||
# test exceptions
|
||||
|
@ -61,8 +62,10 @@ def test_flowwrite():
|
|||
# write to two .jpg files
|
||||
tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test_flow.jpg')
|
||||
for concat_axis in range(2):
|
||||
mmcv.flowwrite(
|
||||
flow, tmp_filename, quantize=True, concat_axis=concat_axis)
|
||||
mmcv.flowwrite(flow,
|
||||
tmp_filename,
|
||||
quantize=True,
|
||||
concat_axis=concat_axis)
|
||||
shape = (200, 100) if concat_axis == 0 else (100, 200)
|
||||
assert osp.isfile(tmp_filename)
|
||||
assert mmcv.imread(tmp_filename, flag='unchanged').shape == shape
|
||||
|
@ -117,16 +120,16 @@ def test_dequantize_flow():
|
|||
ref = np.zeros_like(flow, dtype=np.float32)
|
||||
for i in range(ref.shape[0]):
|
||||
for j in range(ref.shape[1]):
|
||||
ref[i, j,
|
||||
0] = (float(dx[i, j] + 0.5) * 2 * max_val / 255 - max_val) * w
|
||||
ref[i, j,
|
||||
1] = (float(dy[i, j] + 0.5) * 2 * max_val / 255 - max_val) * h
|
||||
ref[i, j, 0] = (float(dx[i, j] + 0.5) * 2 * max_val / 255 -
|
||||
max_val) * w
|
||||
ref[i, j, 1] = (float(dy[i, j] + 0.5) * 2 * max_val / 255 -
|
||||
max_val) * h
|
||||
assert_array_almost_equal(flow, ref)
|
||||
|
||||
|
||||
def test_flow2rgb():
|
||||
flow = np.array(
|
||||
[[[0, 0], [0.5, 0.5], [1, 1], [2, 1], [3, np.inf]]], dtype=np.float32)
|
||||
flow = np.array([[[0, 0], [0.5, 0.5], [1, 1], [2, 1], [3, np.inf]]],
|
||||
dtype=np.float32)
|
||||
flow_img = mmcv.flow2rgb(flow)
|
||||
# yapf: disable
|
||||
assert_array_almost_equal(
|
||||
|
@ -140,6 +143,45 @@ def test_flow2rgb():
|
|||
# yapf: enable
|
||||
|
||||
|
||||
def test_flow_warp():
    """Check ``mmcv.flow_warp`` against a NumPy reference and a tiny case."""

    def np_flow_warp(flow, img):
        """NumPy nearest-neighbour reference for ``mmcv.flow_warp``."""
        height, width = flow.shape[0], flow.shape[1]
        rows, cols = np.indices((height, width))

        # Flow channel 1 shifts rows, channel 0 shifts columns.
        src_row = rows + flow[:, :, 1]
        src_col = cols + flow[:, :, 0]
        floor_row = np.floor(src_row).astype(int)
        floor_col = np.floor(src_col).astype(int)

        valid = ((floor_row >= 0) & (floor_row < height - 1) &
                 (floor_col >= 0) & (floor_col < width - 1))

        nearest_row = src_row[valid].round().astype(int)
        nearest_col = src_col[valid].round().astype(int)

        output = np.zeros_like(img, dtype=img.dtype)
        output[valid, :] = img[nearest_row, nearest_col, :]
        return output

    # Random image and flow: nearest mode must match the NumPy reference.
    dim = 500
    a = np.random.randn(dim, dim, 3) * 10 + 125
    b = np.random.randn(dim, dim, 2) + 2 + 0.2

    c = mmcv.flow_warp(a, b, interpolate_mode='nearest')
    d = np_flow_warp(b, a)

    # A single bright pixel and a constant flow of one pixel on both axes:
    # the pixel is expected one row up and one column to the left.
    simple_a = np.zeros((5, 5, 3))
    simple_a[2, 2, 0] = 1
    simple_b = np.ones((5, 5, 2))

    simple_res_c = np.zeros((5, 5, 3))
    simple_res_c[1, 1, 0] = 1

    res_c = mmcv.flow_warp(simple_a, simple_b, interpolate_mode='bilinear')

    assert_array_equal(c, d)
    assert_array_equal(res_c, simple_res_c)
|
||||
|
||||
|
||||
def test_make_color_wheel():
|
||||
default_color_wheel = mmcv.make_color_wheel()
|
||||
color_wheel = mmcv.make_color_wheel([2, 2, 2, 2, 2, 2])
|
||||
|
|
Loading…
Reference in New Issue