From a3633d7ffe1312f380c47c4cd5f2bebd59f1c589 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 20 Dec 2022 14:18:38 +0100 Subject: [PATCH] Update --- models/yolo.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/models/yolo.py b/models/yolo.py index d2c67645d..eed2f7f47 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -53,12 +53,12 @@ class V6Detect(nn.Module): self.inplace = inplace # use inplace ops (e.g. slice assignment) self.stride = torch.zeros(self.nl) # strides computed during build - c2, c3 = max(ch[0] // 4, self.reg_max * 4), max(ch[0], self.nc) # channels + c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], self.nc) # channels self.cv2 = nn.ModuleList( nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), Conv(c2, 4 * self.reg_max, 1)) for x in ch) self.cv3 = nn.ModuleList( nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), Conv(c3, self.nc, 1, act=False)) for x in ch) - self.dfl = DFL(self.reg_max) + self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() def forward(self, x): shape = x[0].shape # BCHW @@ -71,8 +71,8 @@ class V6Detect(nn.Module): self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) self.shape = shape - dbox = dist2bbox(self.dfl(box) if self.reg_max > 1 else box, self.anchors.unsqueeze(0), xywh=True, dim=1) - y = torch.cat((dbox * self.strides, cls.sigmoid()), 1) + dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides + y = torch.cat((dbox, cls.sigmoid()), 1) return y if self.export else (y, (x, box, cls)) def bias_init(self):