[Enhancement] Optimize the visualization results of large images in `debug` mode (#346)

* Optimize the visualization results of large images in debug mode

* to test

* to test

* remove

* update

* update
pull/350/head
Haian Huang(深度眸) 2022-12-06 18:47:43 +08:00 committed by GitHub
parent 78a23ca8b9
commit 92fd724e87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 91 additions and 37 deletions

View File

@ -7,14 +7,16 @@ wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_m-v61_
python demo/large_image_demo.py \
demo/large_image.jpg \
configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \
checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth
```
"""
import os
import random
from argparse import ArgumentParser
import mmcv
import numpy as np
from mmdet.apis import inference_detector, init_detector
from mmengine.logging import print_log
from mmengine.utils import ProgressBar
@ -27,7 +29,7 @@ except ImportError:
from mmyolo.registry import VISUALIZERS
from mmyolo.utils import register_all_modules, switch_to_deploy
from mmyolo.utils.large_image import merge_results_by_nms
from mmyolo.utils.large_image import merge_results_by_nms, shift_predictions
from mmyolo.utils.misc import get_file_list
@ -75,7 +77,12 @@ def parse_args():
parser.add_argument(
'--debug',
action='store_true',
help='Export debug images at each stage for 1 input')
help='Export debug results before merging')
parser.add_argument(
'--save-patch',
action='store_true',
help='Save the results of each patch. '
'The `--debug` must be enabled.')
args = parser.parse_args()
return args
@ -102,13 +109,9 @@ def main():
# get file list
files, source_type = get_file_list(args.img)
# if debug, only process the first file
if args.debug:
files = files[:1]
# start detector inference
print(f'Performing inference on {len(files)} images... \
This may take a while.')
print(f'Performing inference on {len(files)} images.... '
'This may take a while.')
progress_bar = ProgressBar(len(files))
for file in files:
# read image
@ -147,33 +150,87 @@ This may take a while.')
else:
filename = os.path.basename(file)
img = mmcv.imconvert(img, 'bgr', 'rgb')
out_file = None if args.show else os.path.join(args.out_dir, filename)
# export debug images
if args.debug:
# export sliced images
for i, image in enumerate(sliced_image_object.images):
image = mmcv.imconvert(image, 'bgr', 'rgb')
out_file = os.path.join(args.out_dir, 'sliced_images',
filename + f'_slice_{i}.jpg')
mmcv.imwrite(image, out_file)
# export sliced image results
for i, slice_result in enumerate(slice_results):
out_file = os.path.join(args.out_dir, 'sliced_image_results',
filename + f'_slice_{i}_result.jpg')
image = mmcv.imconvert(sliced_image_object.images[i], 'bgr',
'rgb')
name, suffix = os.path.splitext(filename)
visualizer.add_datasample(
os.path.basename(out_file),
image,
data_sample=slice_result,
draw_gt=False,
show=args.show,
wait_time=0,
out_file=out_file,
pred_score_thr=args.score_thr,
)
shifted_instances = shift_predictions(
slice_results,
sliced_image_object.starting_pixels,
src_image_shape=(height, width))
merged_result = slice_results[0].clone()
merged_result.pred_instances = shifted_instances
debug_file_name = name + '_debug' + suffix
debug_out_file = None if args.show else os.path.join(
args.out_dir, debug_file_name)
visualizer.set_image(img.copy())
debug_grids = []
for starting_point in sliced_image_object.starting_pixels:
start_point_x = starting_point[0]
start_point_y = starting_point[1]
end_point_x = start_point_x + args.patch_size
end_point_y = start_point_y + args.patch_size
debug_grids.append(
[start_point_x, start_point_y, end_point_x, end_point_y])
debug_grids = np.array(debug_grids)
debug_grids[:, 0::2] = np.clip(debug_grids[:, 0::2], 1,
img.shape[1] - 1)
debug_grids[:, 1::2] = np.clip(debug_grids[:, 1::2], 1,
img.shape[0] - 1)
palette = np.random.randint(0, 256, size=(len(debug_grids), 3))
palette = [tuple(c) for c in palette]
line_styles = random.choices(['-', '-.', ':'], k=len(debug_grids))
visualizer.draw_bboxes(
debug_grids,
edge_colors=palette,
alpha=1,
line_styles=line_styles)
visualizer.draw_bboxes(
debug_grids, face_colors=palette, alpha=0.15)
visualizer.draw_texts(
list(range(len(debug_grids))),
debug_grids[:, :2] + 5,
colors='w')
visualizer.add_datasample(
debug_file_name,
visualizer.get_image(),
data_sample=merged_result,
draw_gt=False,
show=args.show,
wait_time=0,
out_file=debug_out_file,
pred_score_thr=args.score_thr,
)
if args.save_patch:
debug_patch_out_dir = os.path.join(args.out_dir,
f'{name}_patch')
for i, slice_result in enumerate(slice_results):
patch_out_file = os.path.join(
debug_patch_out_dir,
f'{filename}_slice_{i}_result.jpg')
image = mmcv.imconvert(sliced_image_object.images[i],
'bgr', 'rgb')
visualizer.add_datasample(
'patch_result',
image,
data_sample=slice_result,
draw_gt=False,
show=False,
wait_time=0,
out_file=patch_out_file,
pred_score_thr=args.score_thr,
)
image_result = merge_results_by_nms(
slice_results,
@ -184,11 +241,8 @@ This may take a while.')
'iou_thr': args.merge_iou_thr
})
img = mmcv.imconvert(img, 'bgr', 'rgb')
out_file = None if args.show else os.path.join(args.out_dir, filename)
visualizer.add_datasample(
os.path.basename(out_file),
filename,
img,
data_sample=image_result,
draw_gt=False,
@ -199,7 +253,7 @@ This may take a while.')
)
progress_bar.update()
if not args.show:
if not args.show or (args.debug and args.save_patch):
print_log(
f'\nResults have been saved at {os.path.abspath(args.out_dir)}')