mirror of https://github.com/WongKinYiu/yolov7.git
Fuse IAuxDetect
parent
4c207e1ae6
commit
954cde65ab
|
@ -303,6 +303,8 @@ class IKeypoint(nn.Module):
|
|||
class IAuxDetect(nn.Module):
|
||||
stride = None # strides computed during build
|
||||
export = False # onnx export
|
||||
end2end = False
|
||||
include_nms = False
|
||||
|
||||
def __init__(self, nc=80, anchors=(), ch=()): # detection layer
|
||||
super(IAuxDetect, self).__init__()
|
||||
|
@ -338,17 +340,83 @@ class IAuxDetect(nn.Module):
|
|||
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
|
||||
|
||||
y = x[i].sigmoid()
|
||||
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
|
||||
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
||||
if not torch.onnx.is_in_onnx_export():
|
||||
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
|
||||
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
||||
else:
|
||||
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
|
||||
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].data # wh
|
||||
y = torch.cat((xy, wh, y[..., 4:]), -1)
|
||||
z.append(y.view(bs, -1, self.no))
|
||||
|
||||
return x if self.training else (torch.cat(z, 1), x[:self.nl])
|
||||
|
||||
def fuseforward(self, x):
|
||||
# x = x.copy() # for profiling
|
||||
z = [] # inference output
|
||||
self.training |= self.export
|
||||
for i in range(self.nl):
|
||||
x[i] = self.m[i](x[i]) # conv
|
||||
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
|
||||
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
|
||||
|
||||
if not self.training: # inference
|
||||
if self.grid[i].shape[2:4] != x[i].shape[2:4]:
|
||||
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
|
||||
|
||||
y = x[i].sigmoid()
|
||||
if not torch.onnx.is_in_onnx_export():
|
||||
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
|
||||
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
||||
else:
|
||||
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
|
||||
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].data # wh
|
||||
y = torch.cat((xy, wh, y[..., 4:]), -1)
|
||||
z.append(y.view(bs, -1, self.no))
|
||||
|
||||
if self.training:
|
||||
out = x
|
||||
elif self.end2end:
|
||||
out = torch.cat(z, 1)
|
||||
elif self.include_nms:
|
||||
z = self.convert(z)
|
||||
out = (z, )
|
||||
else:
|
||||
out = (torch.cat(z, 1), x)
|
||||
|
||||
return out
|
||||
|
||||
def fuse(self):
|
||||
print("IAuxDetect.fuse")
|
||||
# fuse ImplicitA and Convolution
|
||||
for i in range(len(self.m)):
|
||||
c1,c2,_,_ = self.m[i].weight.shape
|
||||
c1_,c2_, _,_ = self.ia[i].implicit.shape
|
||||
self.m[i].bias += torch.matmul(self.m[i].weight.reshape(c1,c2),self.ia[i].implicit.reshape(c2_,c1_)).squeeze(1)
|
||||
|
||||
# fuse ImplicitM and Convolution
|
||||
for i in range(len(self.m)):
|
||||
c1,c2, _,_ = self.im[i].implicit.shape
|
||||
self.m[i].bias *= self.im[i].implicit.reshape(c2)
|
||||
self.m[i].weight *= self.im[i].implicit.transpose(0,1)
|
||||
|
||||
@staticmethod
|
||||
def _make_grid(nx=20, ny=20):
|
||||
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
|
||||
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
|
||||
|
||||
def convert(self, z):
|
||||
z = torch.cat(z, 1)
|
||||
box = z[:, :, :4]
|
||||
conf = z[:, :, 4:5]
|
||||
score = z[:, :, 5:]
|
||||
score *= conf
|
||||
convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
|
||||
dtype=torch.float32,
|
||||
device=z.device)
|
||||
box @= convert_matrix
|
||||
return (box, score)
|
||||
|
||||
|
||||
class IBin(nn.Module):
|
||||
stride = None # strides computed during build
|
||||
|
@ -623,7 +691,7 @@ class Model(nn.Module):
|
|||
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
||||
delattr(m, 'bn') # remove batchnorm
|
||||
m.forward = m.fuseforward # update forward
|
||||
elif isinstance(m, IDetect):
|
||||
elif isinstance(m, (IDetect, IAuxDetect)):
|
||||
m.fuse()
|
||||
m.forward = m.fuseforward
|
||||
self.info()
|
||||
|
|
Loading…
Reference in New Issue