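# Building blocks and accuracy metrics for face keypoint models in EasyCV
# (mirror of https://github.com/alibaba/EasyCV.git).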
import copy
import math

import numpy as np
import torch
import torch.nn as nn


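# Plain convolution blocks: Conv2d + BatchNorm, followed by PReLU
# (`conv_bn`) or by no activation at all (`conv_no_relu`).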
def conv_bn(inp, oup, kernel, stride, padding=1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel, stride, padding, bias=False),
        nn.BatchNorm2d(oup), nn.PReLU(oup))


def conv_no_relu(inp, oup, kernel, stride, padding=1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel, stride, padding, bias=False),
        nn.BatchNorm2d(oup))


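# Reshape layer, so that `Tensor.view` can be used inside nn.Sequential.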
class View(nn.Module):

    def __init__(self, shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)


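# Module wrapper applying softmax over a fixed dimension.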
class Softmax(nn.Module):

    def __init__(self, dim):
        super(Softmax, self).__init__()
        self.softmax = nn.Softmax(dim)

    def forward(self, x):
        return self.softmax(x)


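# MobileNetV2-style inverted residual: 1x1 expansion -> depthwise conv ->
# 1x1 linear projection, with an optional identity shortcut (`use_connect`).
# The four `activation` variants pick ReLU or PReLU for the two activation
# sites; 'half_v1' and 'half_v2' mix one of each.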
class InvertedResidual(nn.Module):

    def __init__(self,
                 inp,
                 oup,
                 kernel_size,
                 stride,
                 padding,
                 expand_ratio=2,
                 use_connect=False,
                 activation='relu'):
        super(InvertedResidual, self).__init__()

        hid_channels = int(inp * expand_ratio)
        # (act1, act2) = activations after the expansion and depthwise stages.
        activations = {
            'relu': (nn.ReLU(inplace=True), nn.ReLU(inplace=True)),
            'prelu': (nn.PReLU(hid_channels), nn.PReLU(hid_channels)),
            'half_v1': (nn.ReLU(inplace=True), nn.PReLU(hid_channels)),
            'half_v2': (nn.PReLU(hid_channels), nn.ReLU(inplace=True)),
        }
        if activation not in activations:
            raise ValueError('Unsupported activation: {}'.format(activation))
        act1, act2 = activations[activation]

        self.conv = nn.Sequential(
            # 1x1 pointwise expansion
            nn.Conv2d(inp, hid_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(hid_channels), act1,
            # depthwise convolution
            nn.Conv2d(
                hid_channels,
                hid_channels,
                kernel_size,
                stride,
                padding,
                groups=hid_channels,
                bias=False), nn.BatchNorm2d(hid_channels), act2,
            # 1x1 linear projection (no activation)
            nn.Conv2d(hid_channels, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup))
        self.use_connect = use_connect

    def forward(self, x):
        if self.use_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


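# Depthwise-separable residual block (no channel expansion): depthwise conv
# followed by a 1x1 pointwise conv, with the same four activation variants
# and optional identity shortcut as InvertedResidual.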
class Residual(nn.Module):

    def __init__(self,
                 inp,
                 oup,
                 kernel_size,
                 stride,
                 padding,
                 use_connect=False,
                 activation='relu'):
        super(Residual, self).__init__()

        self.use_connect = use_connect

        # (act1, act2) = activations after the depthwise and pointwise stages.
        activations = {
            'relu': (nn.ReLU(inplace=True), nn.ReLU(inplace=True)),
            'prelu': (nn.PReLU(inp), nn.PReLU(oup)),
            'half_v1': (nn.ReLU(inplace=True), nn.PReLU(oup)),
            'half_v2': (nn.PReLU(inp), nn.ReLU(inplace=True)),
        }
        if activation not in activations:
            raise ValueError('Unsupported activation: {}'.format(activation))
        act1, act2 = activations[activation]

        self.conv = nn.Sequential(
            # depthwise convolution
            nn.Conv2d(
                inp,
                inp,
                kernel_size,
                stride,
                padding,
                groups=inp,
                bias=False), nn.BatchNorm2d(inp), act1,
            # 1x1 pointwise convolution
            nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup),
            act2)

    def forward(self, x):
        if self.use_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


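# Mean absolute error between predicted and target pose values.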
def pose_accuracy(output, target):
    with torch.no_grad():
        output = output.detach().cpu().numpy()
        target = target.detach().cpu().numpy()

        acc = np.mean(np.abs(output - target))
        return acc


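# ION: landmark error normalized by the inter-ocular distance (NME).
# `output` and `target` are flattened (x, y) landmark tensors.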
def ION(output, target, left_eye_left_corner_idx, right_eye_right_corner_idx,
        num_pts):
    with torch.no_grad():
        output = output.view(-1, num_pts, 2).cpu().numpy()
        target = target.view(-1, num_pts, 2).cpu().numpy()

        # Inter-ocular distance per sample (epsilon avoids division by zero).
        interocular = (
            target[:, left_eye_left_corner_idx] -
            target[:, right_eye_right_corner_idx])
        interocular = np.sqrt(
            np.square(interocular[:, 0]) + np.square(interocular[:, 1])) + 1e-5
        # Per-point Euclidean error, summed over points and normalized.
        dist = target - output
        dist = np.sqrt(np.square(dist[:, :, 0]) + np.square(dist[:, :, 1]))
        dist = np.sum(dist, axis=1)
        nme = dist / (interocular * num_pts)

        return np.mean(nme)


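# NME for the 106-point face landmark layout; indices 66 and 79 are the
# outer eye corners used for inter-ocular normalization.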
def get_keypoint_accuracy(output, target_point):
    accuracy = dict()
    num_points = 106
    left_eye_left_corner_index = 66
    right_eye_right_corner_index = 79

    nme = ION(output, target_point, left_eye_left_corner_index,
              right_eye_right_corner_index, num_points)

    accuracy['nme'] = nme

    return accuracy


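# Pose accuracy reported as mean absolute error.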
def get_pose_accuracy(output, target_pose):
    accuracy = dict()
    pose_acc = pose_accuracy(output, target_pose)
    accuracy['pose_acc'] = float(pose_acc)
    return accuracy
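

if __name__ == '__main__':
    # Minimal shape sanity check (illustrative sketch, not part of the
    # mirrored EasyCV source): a stride-1 inverted residual with a shortcut
    # must preserve the input shape.
    block = InvertedResidual(
        32, 32, kernel_size=3, stride=1, padding=1, use_connect=True)
    x = torch.randn(2, 32, 28, 28)
    assert block(x).shape == x.shape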