# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

# Copied from
# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/
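
# Correctness and timing checks for the DCNv3 CUDA operator against the
# pure-PyTorch reference implementation (dcnv3_core_pytorch). Assumes a
# CUDA device and the compiled DCNv3 extension are available.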
from __future__ import absolute_import, division, print_function

import math  # noqa
import time

import torch
import torch.nn as nn  # noqa
from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
from torch.autograd import gradcheck  # noqa
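
# Test configuration: N is the batch size, M the number of sampling
# groups, D the channels per group, and (Kh, Kw) the kernel size.
# remove_center is a bool, so P = Kh * Kw - 1 sampling points per group
# when the center point is dropped, Kh * Kw otherwise. H_out and W_out
# follow the standard convolution output-size formula.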
H_in, W_in = 8, 8
N, M, D = 2, 4, 16
Kh, Kw = 3, 3
remove_center = False
P = Kh * Kw - remove_center
offset_scale = 2.0
pad = 1
dilation = 1
stride = 1
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

torch.manual_seed(3)


@torch.no_grad()
def check_forward_equal_with_pytorch_double():
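    # Run both the reference and the CUDA paths in float64, where the two
    # implementations should agree almost exactly.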
    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
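    # Random positive modulation mask, normalized so the P sampling-point
    # weights sum to 1 for each group and output location.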
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)

    output_pytorch = dcnv3_core_pytorch(input.double(), offset.double(),
                                        mask.double(), Kh, Kw, stride, stride,
                                        Kh // 2, Kw // 2, dilation, dilation,
                                        M, D, offset_scale,
                                        remove_center).detach().cpu()

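    # im2col_step sets how many samples the CUDA kernel processes per
    # im2col batch; it should only affect performance, not the result.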
    im2col_step = 2
    output_cuda = DCNv3Function.apply(input.double(), offset.double(),
                                      mask.double(), Kh, Kw, stride, stride,
                                      Kh // 2, Kw // 2, dilation, dilation, M,
                                      D, offset_scale, im2col_step,
                                      remove_center).detach().cpu()

    fwdok = torch.allclose(output_cuda, output_pytorch)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    print('>>> forward double')
    print(f'* {fwdok} check_forward_equal_with_pytorch_double:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_forward_equal_with_pytorch_float():
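    # Same check in float32. Round-off differs between the two paths, so
    # the comparison uses loose tolerances (rtol=1e-2, atol=1e-3).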
    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)

    output_pytorch = dcnv3_core_pytorch(input, offset, mask, Kh, Kw, stride,
                                        stride, Kh // 2, Kw // 2, dilation,
                                        dilation, M, D, offset_scale,
                                        remove_center).detach().cpu()

    im2col_step = 2
    output_cuda = DCNv3Function.apply(input, offset, mask, Kh, Kw, stride,
                                      stride, Kh // 2, Kw // 2, dilation,
                                      dilation, M, D, offset_scale,
                                      im2col_step,
                                      remove_center).detach().cpu()

    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    print('>>> forward float')
    print(f'* {fwdok} check_forward_equal_with_pytorch_float:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def check_backward_equal_with_pytorch_double(channels=4,
                                             grad_input=True,
                                             grad_offset=True,
                                             grad_mask=True):
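    # Backpropagate a scalar (the output sum) through both paths in
    # float64 and compare the input, offset, and mask gradients.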
    # H_in, W_in = 4, 4
    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    D = channels
    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M * P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask

    output_pytorch = dcnv3_core_pytorch(input0.double(), offset0.double(),
                                        mask0.double(), Kh, Kw, stride, stride,
                                        Kh // 2, Kw // 2, dilation, dilation,
                                        M, D, offset_scale, remove_center)
    output_pytorch.sum().backward()

    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask

    im2col_step = 2
    output_cuda = DCNv3Function.apply(input1.double(), offset1.double(),
                                      mask1.double(), Kh, Kw, stride, stride,
                                      Kh // 2, Kw // 2, dilation, dilation, M,
                                      D, offset_scale, im2col_step,
                                      remove_center)
    output_cuda.sum().backward()

    print(f'>>> backward double: channels {D}')
    bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (input0.grad - input1.grad).abs().max()
    max_rel_err = ((input0.grad - input1.grad).abs() / input0.grad.abs()).max()
    print(f'* {bwdok} input_grad check_backward_equal_with_pytorch_double:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (offset0.grad - offset1.grad).abs().max()
    max_rel_err = ((offset0.grad - offset1.grad).abs() /
                   offset0.grad.abs()).max()
    print(f'* {bwdok} offset_grad check_backward_equal_with_pytorch_double:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (mask0.grad - mask1.grad).abs().max()
    max_rel_err = ((mask0.grad - mask1.grad).abs() / mask0.grad.abs()).max()
    print(f'* {bwdok} mask_grad check_backward_equal_with_pytorch_double:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def check_backward_equal_with_pytorch_float(channels=4,
                                            grad_input=True,
                                            grad_offset=True,
                                            grad_mask=True):
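    # Same gradient comparison in float32, again with loose tolerances.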
    # H_in, W_in = 4, 4
    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    D = channels
    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M * P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask

    output_pytorch = dcnv3_core_pytorch(input0, offset0, mask0, Kh, Kw, stride,
                                        stride, Kh // 2, Kw // 2, dilation,
                                        dilation, M, D, offset_scale,
                                        remove_center)
    output_pytorch.sum().backward()

    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask

    im2col_step = 2
    output_cuda = DCNv3Function.apply(input1, offset1, mask1, Kh, Kw, stride,
                                      stride, Kh // 2, Kw // 2, dilation,
                                      dilation, M, D, offset_scale,
                                      im2col_step, remove_center)
    output_cuda.sum().backward()

    print(f'>>> backward float: channels {D}')
    bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (input0.grad - input1.grad).abs().max()
    max_rel_err = ((input0.grad - input1.grad).abs() / input0.grad.abs()).max()
    print(f'* {bwdok} input_grad check_backward_equal_with_pytorch_float:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (offset0.grad - offset1.grad).abs().max()
    max_rel_err = ((offset0.grad - offset1.grad).abs() /
                   offset0.grad.abs()).max()
    print(f'* {bwdok} offset_grad check_backward_equal_with_pytorch_float:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (mask0.grad - mask1.grad).abs().max()
    max_rel_err = ((mask0.grad - mask1.grad).abs() / mask0.grad.abs()).max()
    print(f'* {bwdok} mask_grad check_backward_equal_with_pytorch_float:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_time_cost(im2col_step=128):
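    # Benchmark the CUDA forward pass on a larger problem size and report
    # the mean per-iteration latency.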
    N = 512
    H_in, W_in = 64, 64
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)
    print(f'>>> time cost: im2col_step {im2col_step};'
          f' input {input.shape}; points {P}')
    repeat = 100
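    # Warm-up iterations so one-time CUDA initialization and caching do
    # not distort the timed loop below.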
    for i in range(repeat):
        output_cuda = DCNv3Function.apply(input, offset, mask, Kh, Kw, stride,
                                          stride, Kh // 2, Kw // 2, dilation,
                                          dilation, M, D, 1.0, im2col_step,
                                          remove_center)
    torch.cuda.synchronize()
    start = time.time()
    for i in range(repeat):
        output_cuda = DCNv3Function.apply(  # noqa
            input, offset, mask, Kh, Kw, stride, stride, Kh // 2, Kw // 2,
            dilation, dilation, M, D, 1.0, im2col_step, remove_center)
    torch.cuda.synchronize()
    print(f'forward time cost: {(time.time() - start) / repeat}')


if __name__ == '__main__':
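    # Forward checks run once; backward checks sweep channel counts,
    # including non-power-of-two and large values (presumably to hit
    # different CUDA kernel code paths).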
    check_forward_equal_with_pytorch_double()
    check_forward_equal_with_pytorch_float()
    for channels in [1, 16, 30, 32, 64, 71, 1025]:
        check_backward_equal_with_pytorch_double(channels, True, True, True)
    for channels in [1, 16, 30, 32, 64, 71, 1025]:
        check_backward_equal_with_pytorch_float(channels, True, True, True)
    for i in range(3):
        im2col_step = 128 * (2**i)
        check_time_cost(im2col_step)