Update
parent
ad00c75e8c
commit
a3633d7ffe
|
@ -53,12 +53,12 @@ class V6Detect(nn.Module):
|
||||||
self.inplace = inplace # use inplace ops (e.g. slice assignment)
|
self.inplace = inplace # use inplace ops (e.g. slice assignment)
|
||||||
self.stride = torch.zeros(self.nl) # strides computed during build
|
self.stride = torch.zeros(self.nl) # strides computed during build
|
||||||
|
|
||||||
c2, c3 = max(ch[0] // 4, self.reg_max * 4), max(ch[0], self.nc) # channels
|
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], self.nc) # channels
|
||||||
self.cv2 = nn.ModuleList(
|
self.cv2 = nn.ModuleList(
|
||||||
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), Conv(c2, 4 * self.reg_max, 1)) for x in ch)
|
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), Conv(c2, 4 * self.reg_max, 1)) for x in ch)
|
||||||
self.cv3 = nn.ModuleList(
|
self.cv3 = nn.ModuleList(
|
||||||
nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), Conv(c3, self.nc, 1, act=False)) for x in ch)
|
nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), Conv(c3, self.nc, 1, act=False)) for x in ch)
|
||||||
self.dfl = DFL(self.reg_max)
|
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
shape = x[0].shape # BCHW
|
shape = x[0].shape # BCHW
|
||||||
|
@ -71,8 +71,8 @@ class V6Detect(nn.Module):
|
||||||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||||
self.shape = shape
|
self.shape = shape
|
||||||
|
|
||||||
dbox = dist2bbox(self.dfl(box) if self.reg_max > 1 else box, self.anchors.unsqueeze(0), xywh=True, dim=1)
|
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
||||||
y = torch.cat((dbox * self.strides, cls.sigmoid()), 1)
|
y = torch.cat((dbox, cls.sigmoid()), 1)
|
||||||
return y if self.export else (y, (x, box, cls))
|
return y if self.export else (y, (x, box, cls))
|
||||||
|
|
||||||
def bias_init(self):
|
def bias_init(self):
|
||||||
|
|
Loading…
Reference in New Issue