mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
update yolo.py TTA flexibility and extensibility (#506)
* update yolo.py TTA flexibility and extensibility * Update scale_img()
This commit is contained in:
parent
4b5f4806bc
commit
1d17b9af0f
@@ -82,18 +82,19 @@ class Model(nn.Module):
|
||||
def forward(self, x, augment=False, profile=False):
    """Run the model on ``x``, optionally with test-time augmentation (TTA).

    Args:
        x: Input batch. Its last two dimensions are read as (height, width)
            and dimension 3 is the one flipped for a left-right flip.
            Presumably shaped (batch, channels, height, width) -- confirm
            against callers.
        augment: When True, run one forward pass per (scale, flip) pair listed
            in ``s`` / ``f`` below and concatenate the results. To add or
            remove an augmentation, edit those two parallel lists.
        profile: Forwarded to ``forward_once`` on the non-augmented path only.

    Returns:
        With ``augment=True``: the tuple ``(torch.cat(y, 1), None)``, where
        ``y`` holds the first output of ``forward_once`` for each augmented
        input after de-scaling and de-flipping.
        Otherwise: whatever ``self.forward_once(x, profile)`` returns.
    """
    if augment:
        img_size = x.shape[-2:]  # height, width
        s = [1, 0.83, 0.67]  # scales
        f = [None, 3, None]  # flips (2-ud, 3-lr), None = no flip
        y = []  # outputs
        for si, fi in zip(s, f):
            # Compare against None explicitly so a flip over dim 0 is not skipped by truthiness.
            xi = torch_utils.scale_img(x.flip(fi) if fi is not None else x, si)
            yi = self.forward_once(xi)[0]  # forward
            # debug: cv2.imwrite('img%g.jpg' % si, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
            # The first four values of the last dim are rescaled to the original image size
            # (presumably box x, y, w, h -- confirm against the detection head).
            yi[..., :4] /= si  # de-scale
            # Use == rather than `is`: identity on int literals relies on interpreter
            # caching and raises SyntaxWarning on Python 3.8+.
            if fi == 2:
                yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip ud
            elif fi == 3:
                yi[..., 0] = img_size[1] - yi[..., 0]  # de-flip lr
            y.append(yi)
        return torch.cat(y, 1), None  # augmented inference, train
    else:
        return self.forward_once(x, profile)  # single-scale inference, train
|
||||
|
@@ -164,13 +164,16 @@ def load_classifier(name='resnet101', n=2):
|
||||
|
||||
def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416), r=ratio
    """Scale img(bs,3,y,x) by ``ratio``, then pad/crop the result.

    Args:
        img: 4-D image batch; dimensions 2 and 3 are read as height and width.
        ratio: Resize factor applied to height and width. With ``ratio == 1.0``
            the input is returned unchanged (same object, no resize or padding).
        same_shape: If True, the resized image is padded/cropped back to the
            original height and width. If False, it is padded up to the
            smallest multiples of ``gs`` that hold the scaled height and width.
        gs: (pixels) grid size used when ``same_shape`` is False. Defaults to
            32, the value that was previously hard-coded.

    Returns:
        The resized batch, padded on the right and bottom edges with the
        constant 0.447. A negative pad amount (possible with
        ``same_shape=True`` and ``ratio > 1``) is passed straight to ``F.pad``,
        which presumably crops -- confirm against the PyTorch docs.
    """
    if ratio == 1.0:
        return img

    h, w = img.shape[2:]
    s = (int(h * ratio), int(w * ratio))  # new size
    img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
    if not same_shape:  # pad/crop img
        # Round the (un-truncated) scaled size up to the next multiple of the grid size.
        h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
    # F.pad order for the last two dims is [left, right, top, bottom].
    return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean
|
||||
|
||||
|
||||
def copy_attr(a, b, include=(), exclude=()):
|
||||
|
Loading…
x
Reference in New Issue
Block a user