Mirror of https://github.com/ultralytics/yolov5.git (synced 2025-06-03 14:49:29 +08:00)

New scale_segments() function (#9570)

* Rename scale_coords() to scale_boxes()
* Add scale_segments()

commit c8e52304cf
parent d669a74623
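
The rename is mechanical at every call site: scale_coords() becomes scale_boxes() and clip_coords() becomes clip_boxes(), with unchanged signatures. A minimal sketch of the renamed box API as used throughout this diff (the shapes and tensor values below are made-up example data, and the import assumes the yolov5 repo root is on the path):

    import torch
    from utils.general import scale_boxes  # formerly scale_coords

    im_shape = (640, 640)        # letterboxed inference height, width (im.shape[2:])
    im0_shape = (1080, 810, 3)   # original frame shape (im0.shape)
    det = torch.tensor([[32.0, 64.0, 320.0, 480.0, 0.9, 0.0]])  # xyxy, conf, cls (example values)

    # Rescale boxes from inference size back to the original image; scale_boxes
    # shifts out the letterbox padding, divides by the resize gain, and clips
    det[:, :4] = scale_boxes(im_shape, det[:, :4], im0_shape).round()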
@@ -42,7 +42,7 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
 from models.common import DetectMultiBackend
 from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
 from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
-                           increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
+                           increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
 from utils.torch_utils import select_device, smart_inference_mode

@@ -148,7 +148,7 @@ def run(
             annotator = Annotator(im0, line_width=line_thickness, example=str(names))
             if len(det):
                 # Rescale boxes from img_size to im0 size
-                det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()
+                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                 # Print results
                 for c in det[:, 5].unique():
@@ -23,7 +23,7 @@ from torch.cuda import amp

 from utils.dataloaders import exif_transpose, letterbox
 from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
-                           increment_path, make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh,
+                           increment_path, make_divisible, non_max_suppression, scale_boxes, xywh2xyxy, xyxy2xywh,
                            yaml_load)
 from utils.plots import Annotator, colors, save_one_box
 from utils.torch_utils import copy_attr, smart_inference_mode
@@ -703,7 +703,7 @@ class AutoShape(nn.Module):
                                             self.multi_label,
                                             max_det=self.max_det)  # NMS
             for i in range(n):
-                scale_coords(shape1, y[i][:, :4], shape0[i])
+                scale_boxes(shape1, y[i][:, :4], shape0[i])

             return Detections(ims, y, files, dt, self.names, x.shape)

@@ -42,7 +42,7 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
 from models.common import DetectMultiBackend
 from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
 from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
-                           increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
+                           increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
 from utils.segment.general import process_mask
 from utils.torch_utils import select_device, smart_inference_mode
@@ -152,7 +152,7 @@ def run(
                 masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True)  # HWC

                 # Rescale boxes from img_size to im0 size
-                det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()
+                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                 # Print results
                 for c in det[:, 5].unique():
@@ -44,7 +44,7 @@ from models.yolo import SegmentationModel
 from utils.callbacks import Callbacks
 from utils.general import (LOGGER, NUM_THREADS, Profile, check_dataset, check_img_size, check_requirements, check_yaml,
                            coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
-                           scale_coords, xywh2xyxy, xyxy2xywh)
+                           scale_boxes, xywh2xyxy, xyxy2xywh)
 from utils.metrics import ConfusionMatrix, box_iou
 from utils.plots import output_to_target, plot_val_study
 from utils.segment.dataloaders import create_dataloader
@@ -298,12 +298,12 @@ def run(
         if single_cls:
             pred[:, 5] = 0
         predn = pred.clone()
-        scale_coords(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred
+        scale_boxes(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred

         # Evaluate
         if nl:
             tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
-            scale_coords(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
+            scale_boxes(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
             labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
             correct_bboxes = process_batch(predn, labelsn, iouv)
             correct_masks = process_batch(predn, labelsn, iouv, pred_masks, gt_masks, overlap=overlap, masks=True)
@@ -725,7 +725,7 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
 def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
     # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
     if clip:
-        clip_coords(x, (h - eps, w - eps))  # warning: inplace clip
+        clip_boxes(x, (h - eps, w - eps))  # warning: inplace clip
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w  # x center
     y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h  # y center
@@ -769,7 +769,23 @@ def resample_segments(segments, n=1000):
     return segments


-def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
+def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
+    # Rescale boxes (xyxy) from img1_shape to img0_shape
+    if ratio_pad is None:  # calculate from img0_shape
+        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
+        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
+    else:
+        gain = ratio_pad[0][0]
+        pad = ratio_pad[1]
+
+    boxes[:, [0, 2]] -= pad[0]  # x padding
+    boxes[:, [1, 3]] -= pad[1]  # y padding
+    boxes[:, :4] /= gain
+    clip_boxes(boxes, img0_shape)
+    return boxes
+
+
+def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None):
     # Rescale coords (xyxy) from img1_shape to img0_shape
     if ratio_pad is None:  # calculate from img0_shape
         gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
@@ -778,15 +794,15 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
         gain = ratio_pad[0][0]
         pad = ratio_pad[1]

-    coords[:, [0, 2]] -= pad[0]  # x padding
-    coords[:, [1, 3]] -= pad[1]  # y padding
-    coords[:, :4] /= gain
-    clip_coords(coords, img0_shape)
-    return coords
+    segments[:, 0] -= pad[0]  # x padding
+    segments[:, 1] -= pad[1]  # y padding
+    segments /= gain
+    clip_segments(segments, img0_shape)
+    return segments


-def clip_coords(boxes, shape):
-    # Clip bounding xyxy bounding boxes to image shape (height, width)
+def clip_boxes(boxes, shape):
+    # Clip boxes (xyxy) to image shape (height, width)
     if isinstance(boxes, torch.Tensor):  # faster individually
         boxes[:, 0].clamp_(0, shape[1])  # x1
         boxes[:, 1].clamp_(0, shape[0])  # y1
@@ -797,6 +813,16 @@ def clip_coords(boxes, shape):
         boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0])  # y1, y2


+def clip_segments(boxes, shape):
+    # Clip segments (xy1,xy2,...) to image shape (height, width)
+    if isinstance(boxes, torch.Tensor):  # faster individually
+        boxes[:, 0].clamp_(0, shape[1])  # x
+        boxes[:, 1].clamp_(0, shape[0])  # y
+    else:  # np.array (faster grouped)
+        boxes[:, 0] = boxes[:, 0].clip(0, shape[1])  # x
+        boxes[:, 1] = boxes[:, 1].clip(0, shape[0])  # y
+
+
 def non_max_suppression(
         prediction,
         conf_thres=0.25,
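
The segment variants mirror the box functions but operate on an n x 2 array of xy points rather than nx4 xyxy rows, which is why padding is subtracted per column instead of per box corner. A small sketch of the new path (the polygon coordinates below are made-up example data, imported from the repo's utils.general):

    import numpy as np
    from utils.general import scale_segments

    im_shape = (640, 640)      # letterboxed inference height, width
    im0_shape = (480, 640, 3)  # original frame shape
    segment = np.array([[100.0, 200.0], [150.0, 260.0], [120.0, 300.0]])  # xy points (example values)

    # Map polygon points back to original-image coordinates; like scale_boxes,
    # this subtracts the letterbox padding, divides by the gain, then clips
    segment = scale_segments(im_shape, segment, im0_shape)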
@@ -980,7 +1006,7 @@ def apply_classifier(x, model, img, im0):
             d[:, :4] = xywh2xyxy(b).long()

             # Rescale boxes from img_size to im0 size
-            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
+            scale_boxes(img.shape[2:], d[:, :4], im0[i].shape)

             # Classes
             pred_cls1 = d[:, 5].long()
@@ -28,7 +28,7 @@ import torchvision.transforms as T
 import yaml

 from utils.dataloaders import img2label_paths
-from utils.general import check_dataset, scale_coords, xywh2xyxy
+from utils.general import check_dataset, scale_boxes, xywh2xyxy
 from utils.metrics import box_iou

 COMET_PREFIX = "comet://"
@@ -293,14 +293,14 @@ class CometLogger:
             pred[:, 5] = 0

         predn = pred.clone()
-        scale_coords(image.shape[1:], predn[:, :4], shape[0], shape[1])
+        scale_boxes(image.shape[1:], predn[:, :4], shape[0], shape[1])

         labelsn = None
         if nl:
             tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
-            scale_coords(image.shape[1:], tbox, shape[0], shape[1])  # native-space labels
+            scale_boxes(image.shape[1:], tbox, shape[0], shape[1])  # native-space labels
             labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
-            scale_coords(image.shape[1:], predn[:, :4], shape[0], shape[1])  # native-space pred
+            scale_boxes(image.shape[1:], predn[:, :4], shape[0], shape[1])  # native-space pred

         return predn, labelsn

@@ -20,7 +20,7 @@ import torch
 from PIL import Image, ImageDraw, ImageFont

 from utils import TryExcept, threaded
-from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
+from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_boxes, increment_path,
                            is_ascii, xywh2xyxy, xyxy2xywh)
 from utils.metrics import fitness
 from utils.segment.general import scale_image
@@ -565,7 +565,7 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
         b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
     b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
     xyxy = xywh2xyxy(b).long()
-    clip_coords(xyxy, im.shape)
+    clip_boxes(xyxy, im.shape)
     crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
     if save:
         file.parent.mkdir(parents=True, exist_ok=True)  # make directory
val.py (6 changed lines)
@@ -40,7 +40,7 @@ from utils.callbacks import Callbacks
 from utils.dataloaders import create_dataloader
 from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_yaml,
                            coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
-                           scale_coords, xywh2xyxy, xyxy2xywh)
+                           scale_boxes, xywh2xyxy, xyxy2xywh)
 from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
 from utils.plots import output_to_target, plot_images, plot_val_study
 from utils.torch_utils import select_device, smart_inference_mode
@@ -244,12 +244,12 @@ def run(
         if single_cls:
             pred[:, 5] = 0
         predn = pred.clone()
-        scale_coords(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred
+        scale_boxes(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred

         # Evaluate
         if nl:
             tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
-            scale_coords(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
+            scale_boxes(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
             labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
             correct = process_batch(predn, labelsn, iouv)
             if plots: