exp13-nosoft
Glenn Jocher 2022-12-20 14:18:38 +01:00
parent ad00c75e8c
commit a3633d7ffe
1 changed files with 4 additions and 4 deletions

View File

@ -53,12 +53,12 @@ class V6Detect(nn.Module):
self.inplace = inplace # use inplace ops (e.g. slice assignment)
self.stride = torch.zeros(self.nl) # strides computed during build
c2, c3 = max(ch[0] // 4, self.reg_max * 4), max(ch[0], self.nc) # channels
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], self.nc) # channels
self.cv2 = nn.ModuleList(
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), Conv(c2, 4 * self.reg_max, 1)) for x in ch)
self.cv3 = nn.ModuleList(
nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), Conv(c3, self.nc, 1, act=False)) for x in ch)
self.dfl = DFL(self.reg_max)
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
def forward(self, x):
shape = x[0].shape # BCHW
@ -71,8 +71,8 @@ class V6Detect(nn.Module):
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape
dbox = dist2bbox(self.dfl(box) if self.reg_max > 1 else box, self.anchors.unsqueeze(0), xywh=True, dim=1)
y = torch.cat((dbox * self.strides, cls.sigmoid()), 1)
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
y = torch.cat((dbox, cls.sigmoid()), 1)
return y if self.export else (y, (x, box, cls))
def bias_init(self):