`process_mask_native()` cleanup (#10366)
* process_mask_native() cleanup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix arg name * cleanup anno_json * Remove scale_image * Remove scale_image * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update to native Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/10059/head
parent
a1b6e79ccf
commit
9722e6ffe5
|
@ -44,7 +44,7 @@ from models.common import DetectMultiBackend
|
|||
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
|
||||
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
|
||||
increment_path, non_max_suppression, print_args, scale_boxes, scale_segments,
|
||||
strip_optimizer, xyxy2xywh)
|
||||
strip_optimizer)
|
||||
from utils.plots import Annotator, colors, save_one_box
|
||||
from utils.segment.general import masks2segments, process_mask, process_mask_native
|
||||
from utils.torch_utils import select_device, smart_inference_mode
|
||||
|
@ -161,10 +161,9 @@ def run(
|
|||
|
||||
# Segments
|
||||
if save_txt:
|
||||
segments = reversed(masks2segments(masks))
|
||||
segments = [
|
||||
scale_segments(im0.shape if retina_masks else im.shape[2:], x, im0.shape, normalize=True)
|
||||
for x in segments]
|
||||
for x in reversed(masks2segments(masks))]
|
||||
|
||||
# Print results
|
||||
for c in det[:, 5].unique():
|
||||
|
@ -172,15 +171,17 @@ def run(
|
|||
s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
|
||||
|
||||
# Mask plotting
|
||||
plot_img = torch.as_tensor(im0, dtype=torch.float16).to(device).permute(2, 0, 1).flip(0).contiguous() / 255. \
|
||||
if retina_masks else im[i]
|
||||
annotator.masks(masks, colors=[colors(x, True) for x in det[:, 5]], im_gpu=plot_img)
|
||||
annotator.masks(
|
||||
masks,
|
||||
colors=[colors(x, True) for x in det[:, 5]],
|
||||
im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(device).permute(2, 0, 1).flip(0).contiguous() /
|
||||
255 if retina_masks else im[i])
|
||||
|
||||
# Write results
|
||||
for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
|
||||
if save_txt: # Write to file
|
||||
segj = segments[j].reshape(-1) # (n,2) to (n*2)
|
||||
line = (cls, *segj, conf) if save_conf else (cls, *segj) # label format
|
||||
seg = segments[j].reshape(-1) # (n,2) to (n*2)
|
||||
line = (cls, *seg, conf) if save_conf else (cls, *seg) # label format
|
||||
with open(f'{txt_path}.txt', 'a') as f:
|
||||
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ from utils.general import (LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, Profile, check_
|
|||
from utils.metrics import ConfusionMatrix, box_iou
|
||||
from utils.plots import output_to_target, plot_val_study
|
||||
from utils.segment.dataloaders import create_dataloader
|
||||
from utils.segment.general import mask_iou, process_mask, process_mask_upsample, scale_image
|
||||
from utils.segment.general import mask_iou, process_mask, process_mask_native, scale_image
|
||||
from utils.segment.metrics import Metrics, ap_per_class_box_and_mask
|
||||
from utils.segment.plots import plot_images_and_masks
|
||||
from utils.torch_utils import de_parallel, select_device, smart_inference_mode
|
||||
|
@ -160,7 +160,7 @@ def run(
|
|||
):
|
||||
if save_json:
|
||||
check_requirements(['pycocotools'])
|
||||
process = process_mask_upsample # more accurate
|
||||
process = process_mask_native # more accurate
|
||||
else:
|
||||
process = process_mask # faster
|
||||
|
||||
|
@ -312,7 +312,7 @@ def run(
|
|||
|
||||
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
|
||||
if plots and batch_i < 3:
|
||||
plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
|
||||
plot_masks.append(pred_masks[:15]) # filter top 15 to plot
|
||||
|
||||
# Save/log
|
||||
if save_txt:
|
||||
|
@ -367,8 +367,8 @@ def run(
|
|||
# Save JSON
|
||||
if save_json and len(jdict):
|
||||
w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights
|
||||
anno_json = str(Path(data.get('path', '../coco')) / 'annotations/instances_val2017.json') # annotations json
|
||||
pred_json = str(save_dir / f"{w}_predictions.json") # predictions json
|
||||
anno_json = str(Path('../datasets/coco/annotations/instances_val2017.json')) # annotations
|
||||
pred_json = str(save_dir / f"{w}_predictions.json") # predictions
|
||||
LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...')
|
||||
with open(pred_json, 'w') as f:
|
||||
json.dump(jdict, f)
|
||||
|
|
|
@ -25,10 +25,10 @@ def crop_mask(masks, boxes):
|
|||
def process_mask_upsample(protos, masks_in, bboxes, shape):
|
||||
"""
|
||||
Crop after upsample.
|
||||
proto_out: [mask_dim, mask_h, mask_w]
|
||||
out_masks: [n, mask_dim], n is number of masks after nms
|
||||
protos: [mask_dim, mask_h, mask_w]
|
||||
masks_in: [n, mask_dim], n is number of masks after nms
|
||||
bboxes: [n, 4], n is number of masks after nms
|
||||
shape:input_image_size, (h, w)
|
||||
shape: input_image_size, (h, w)
|
||||
|
||||
return: h, w, n
|
||||
"""
|
||||
|
@ -67,25 +67,25 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
|
|||
return masks.gt_(0.5)
|
||||
|
||||
|
||||
def process_mask_native(protos, masks_in, bboxes, dst_shape):
|
||||
def process_mask_native(protos, masks_in, bboxes, shape):
|
||||
"""
|
||||
Crop after upsample.
|
||||
proto_out: [mask_dim, mask_h, mask_w]
|
||||
out_masks: [n, mask_dim], n is number of masks after nms
|
||||
protos: [mask_dim, mask_h, mask_w]
|
||||
masks_in: [n, mask_dim], n is number of masks after nms
|
||||
bboxes: [n, 4], n is number of masks after nms
|
||||
shape:input_image_size, (h, w)
|
||||
shape: input_image_size, (h, w)
|
||||
|
||||
return: h, w, n
|
||||
"""
|
||||
c, mh, mw = protos.shape # CHW
|
||||
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
|
||||
gain = min(mh / dst_shape[0], mw / dst_shape[1]) # gain = old / new
|
||||
pad = (mw - dst_shape[1] * gain) / 2, (mh - dst_shape[0] * gain) / 2 # wh padding
|
||||
gain = min(mh / shape[0], mw / shape[1]) # gain = old / new
|
||||
pad = (mw - shape[1] * gain) / 2, (mh - shape[0] * gain) / 2 # wh padding
|
||||
top, left = int(pad[1]), int(pad[0]) # y, x
|
||||
bottom, right = int(mh - pad[1]), int(mw - pad[0])
|
||||
masks = masks[:, top:bottom, left:right]
|
||||
|
||||
masks = F.interpolate(masks[None], dst_shape, mode='bilinear', align_corners=False)[0] # CHW
|
||||
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
|
||||
masks = crop_mask(masks, bboxes) # CHW
|
||||
return masks.gt_(0.5)
|
||||
|
||||
|
|
4
val.py
4
val.py
|
@ -302,8 +302,8 @@ def run(
|
|||
# Save JSON
|
||||
if save_json and len(jdict):
|
||||
w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights
|
||||
anno_json = str(Path(data.get('path', '../coco')) / 'annotations/instances_val2017.json') # annotations json
|
||||
pred_json = str(save_dir / f"{w}_predictions.json") # predictions json
|
||||
anno_json = str(Path('../datasets/coco/annotations/instances_val2017.json')) # annotations
|
||||
pred_json = str(save_dir / f"{w}_predictions.json") # predictions
|
||||
LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...')
|
||||
with open(pred_json, 'w') as f:
|
||||
json.dump(jdict, f)
|
||||
|
|
Loading…
Reference in New Issue