Update
parent
ad00c75e8c
commit
a3633d7ffe
|
@ -53,12 +53,12 @@ class V6Detect(nn.Module):
|
|||
self.inplace = inplace # use inplace ops (e.g. slice assignment)
|
||||
self.stride = torch.zeros(self.nl) # strides computed during build
|
||||
|
||||
c2, c3 = max(ch[0] // 4, self.reg_max * 4), max(ch[0], self.nc) # channels
|
||||
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], self.nc) # channels
|
||||
self.cv2 = nn.ModuleList(
|
||||
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), Conv(c2, 4 * self.reg_max, 1)) for x in ch)
|
||||
self.cv3 = nn.ModuleList(
|
||||
nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), Conv(c3, self.nc, 1, act=False)) for x in ch)
|
||||
self.dfl = DFL(self.reg_max)
|
||||
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
shape = x[0].shape # BCHW
|
||||
|
@ -71,8 +71,8 @@ class V6Detect(nn.Module):
|
|||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||
self.shape = shape
|
||||
|
||||
dbox = dist2bbox(self.dfl(box) if self.reg_max > 1 else box, self.anchors.unsqueeze(0), xywh=True, dim=1)
|
||||
y = torch.cat((dbox * self.strides, cls.sigmoid()), 1)
|
||||
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
||||
y = torch.cat((dbox, cls.sigmoid()), 1)
|
||||
return y if self.export else (y, (x, box, cls))
|
||||
|
||||
def bias_init(self):
|
||||
|
|
Loading…
Reference in New Issue