mirror of https://github.com/open-mmlab/mmyolo.git
295 lines
10 KiB
Python
295 lines
10 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
"""Perform MMYOLO inference on large images (as satellite imagery) as:
|
|
|
|
```shell
|
|
wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth # noqa: E501, E261.
|
|
|
|
python demo/large_image_demo.py \
|
|
demo/large_image.jpg \
|
|
configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
|
|
checkpoint/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth
|
|
```
|
|
"""
|
|
|
|
import os
|
|
import random
|
|
from argparse import ArgumentParser
|
|
from pathlib import Path
|
|
|
|
import mmcv
|
|
import numpy as np
|
|
from mmdet.apis import inference_detector, init_detector
|
|
from mmengine.config import Config, ConfigDict
|
|
from mmengine.logging import print_log
|
|
from mmengine.utils import ProgressBar
|
|
|
|
try:
    from sahi.slicing import slice_image
except ImportError as err:
    # sahi is an optional dependency only needed by this demo; chain the
    # original error so the underlying import failure stays visible.
    raise ImportError(
        'Please run "pip install -U sahi" '
        'to install sahi first for large image inference.') from err
|
|
|
|
from mmyolo.registry import VISUALIZERS
|
|
from mmyolo.utils import switch_to_deploy
|
|
from mmyolo.utils.large_image import merge_results_by_nms, shift_predictions
|
|
from mmyolo.utils.misc import get_file_list
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line arguments for large-image inference.

    Args:
        argv (list[str], optional): Argument list to parse. Defaults to
            ``None``, which makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Parsed arguments. ``batch_size`` is guaranteed
        to be at least 1.

    Note:
        Exits through ``parser.error`` (``SystemExit`` with status 2) when
        ``--batch-size`` is smaller than 1.
    """
    parser = ArgumentParser(
        description='Perform MMYOLO inference on large images.')
    parser.add_argument(
        'img', help='Image path, include image file, dir and URL.')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--out-dir', default='./output', help='Path to output directory')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--show', action='store_true', help='Show the detection results')
    parser.add_argument(
        '--deploy',
        action='store_true',
        help='Switch model to deployment mode')
    parser.add_argument(
        '--tta',
        action='store_true',
        help='Whether to use test time augmentation')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument(
        '--patch-size', type=int, default=640, help='The size of patches')
    parser.add_argument(
        '--patch-overlap-ratio',
        type=float,
        default=0.25,
        help='Ratio of overlap between two patches')
    parser.add_argument(
        '--merge-iou-thr',
        type=float,
        default=0.25,
        help='IoU threshold for merging results')
    parser.add_argument(
        '--merge-nms-type',
        type=str,
        default='nms',
        help='NMS type for merging results')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=1,
        help='Batch size, must be greater than or equal to 1')
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Export debug results before merging')
    parser.add_argument(
        '--save-patch',
        action='store_true',
        help='Save the results of each patch. '
        'The `--debug` must be enabled.')
    args = parser.parse_args(argv)
    # The patch-batching loop advances by ``batch_size``; a value below 1
    # would never make progress, so reject it up front.
    if args.batch_size < 1:
        parser.error(
            f'--batch-size must be >= 1, but got {args.batch_size}')
    return args
|
|
|
|
|
|
def _build_config(config, tta=False):
    """Load the model config and optionally rewrite it for TTA.

    Args:
        config (str | Path | Config): Config file path or loaded Config.
        tta (bool): If True, wrap the model in ``config.tta_model`` and
            replace the innermost test dataset pipeline with
            ``config.tta_pipeline``.

    Returns:
        Config: The (possibly modified) config.

    Raises:
        TypeError: If ``config`` is neither a path nor a ``Config``.
        AssertionError: If ``tta`` is set but the config lacks
            ``tta_model`` or ``tta_pipeline``.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if 'init_cfg' in config.model.backbone:
        # Presumably avoids backbone pre-initialisation, since weights are
        # loaded from the checkpoint afterwards.
        config.model.backbone.init_cfg = None

    if tta:
        assert 'tta_model' in config, 'Cannot find ``tta_model`` in config.' \
            " Can't use tta !"
        assert 'tta_pipeline' in config, 'Cannot find ``tta_pipeline`` ' \
            "in config. Can't use tta !"
        config.model = ConfigDict(**config.tta_model, module=config.model)
        # Unwrap nested ``dataset`` entries down to the innermost dataset.
        test_data_cfg = config.test_dataloader.dataset
        while 'dataset' in test_data_cfg:
            test_data_cfg = test_data_cfg['dataset']

        # batch_shapes_cfg will force control the size of the output image,
        # it is not compatible with tta.
        if 'batch_shapes_cfg' in test_data_cfg:
            test_data_cfg.batch_shapes_cfg = None
        test_data_cfg.pipeline = config.tta_pipeline
    return config


def _sliced_inference(model, patches, batch_size):
    """Run ``inference_detector`` over ``patches`` in batches.

    Args:
        model: Detector built by ``init_detector``.
        patches (list): Patch images, as produced by ``slice_image``.
        batch_size (int): Number of patches per ``inference_detector`` call.

    Returns:
        list: One detection result per patch, in patch order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (the loop could
            otherwise never advance).
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, but got {batch_size}')
    slice_results = []
    for start in range(0, len(patches), batch_size):
        slice_results.extend(
            inference_detector(model, patches[start:start + batch_size]))
    return slice_results


def _draw_slice_grid(visualizer, img, starting_pixels, patch_size):
    """Draw the patch grid over a copy of ``img`` and return the canvas.

    Each patch is outlined with a random colour and line style, lightly
    filled, and labelled with its index.

    Args:
        visualizer: Visualizer whose canvas is reset to ``img``.
        img (np.ndarray): Image whose first two dims are (height, width).
        starting_pixels (list): Per-patch starting pixel, indexed as
            ``[x, y]``.
        patch_size (int): Side length of each square patch.

    Returns:
        np.ndarray: The visualizer image after drawing.
    """
    visualizer.set_image(img.copy())

    debug_grids = np.array([[p[0], p[1], p[0] + patch_size, p[1] + patch_size]
                            for p in starting_pixels])
    # Clamp x coordinates to [1, width - 1] and y to [1, height - 1].
    debug_grids[:, 0::2] = np.clip(debug_grids[:, 0::2], 1,
                                   img.shape[1] - 1)
    debug_grids[:, 1::2] = np.clip(debug_grids[:, 1::2], 1,
                                   img.shape[0] - 1)

    palette = np.random.randint(0, 256, size=(len(debug_grids), 3))
    palette = [tuple(c) for c in palette]
    line_styles = random.choices(['-', '-.', ':'], k=len(debug_grids))
    visualizer.draw_bboxes(
        debug_grids,
        edge_colors=palette,
        alpha=1,
        line_styles=line_styles)
    visualizer.draw_bboxes(debug_grids, face_colors=palette, alpha=0.15)

    visualizer.draw_texts(
        list(range(len(debug_grids))), debug_grids[:, :2] + 5, colors='w')
    return visualizer.get_image()


def main():
    """Run sliced inference on every input image and visualize results.

    For each image from ``args.img`` this slices it into overlapping
    ``patch_size`` x ``patch_size`` patches with sahi and runs the detector
    on them in batches of ``batch_size``. The per-patch results are merged
    with ``merge_results_by_nms``, then drawn and either shown or written
    to ``args.out_dir``. With ``--debug`` it also exports the un-merged
    predictions over the patch grid, and with ``--save-patch`` it exports
    each patch's own result.
    """
    args = parse_args()

    config = _build_config(args.config, tta=args.tta)

    # TODO: TTA mode will error if cfg_options is not set.
    # This is an mmdet issue and needs to be fixed later.
    # build the model from a config file and a checkpoint file
    model = init_detector(
        config, args.checkpoint, device=args.device, cfg_options={})

    if args.deploy:
        switch_to_deploy(model)

    if not args.show:
        # makedirs creates missing parent directories and tolerates an
        # already existing output directory.
        os.makedirs(args.out_dir, exist_ok=True)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.dataset_meta = model.dataset_meta

    # get file list
    files, source_type = get_file_list(args.img)

    # start detector inference
    print(f'Performing inference on {len(files)} images.... '
          'This may take a while.')
    progress_bar = ProgressBar(len(files))
    for file in files:
        img = mmcv.imread(file)

        # arrange slices
        height, width = img.shape[:2]
        sliced_image_object = slice_image(
            img,
            slice_height=args.patch_size,
            slice_width=args.patch_size,
            auto_slice_resolution=False,
            overlap_height_ratio=args.patch_overlap_ratio,
            overlap_width_ratio=args.patch_overlap_ratio,
        )
        # Read the patch list once; it is reused for inference and for the
        # optional per-patch export below.
        patches = sliced_image_object.images
        starting_pixels = sliced_image_object.starting_pixels

        slice_results = _sliced_inference(model, patches, args.batch_size)

        if source_type['is_dir']:
            # Flatten the relative path into one file name; os.sep keeps
            # this correct for Windows-style separators as well.
            filename = os.path.relpath(file, args.img).replace(os.sep, '_')
        else:
            filename = os.path.basename(file)

        img = mmcv.imconvert(img, 'bgr', 'rgb')
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        # export debug images
        if args.debug:
            name, suffix = os.path.splitext(filename)

            # Shift every patch's predictions by its starting pixel without
            # merging them (the --debug output is "before merging").
            shifted_instances = shift_predictions(
                slice_results,
                starting_pixels,
                src_image_shape=(height, width))
            merged_result = slice_results[0].clone()
            merged_result.pred_instances = shifted_instances

            debug_file_name = name + '_debug' + suffix
            debug_out_file = None if args.show else os.path.join(
                args.out_dir, debug_file_name)
            grid_image = _draw_slice_grid(visualizer, img, starting_pixels,
                                          args.patch_size)
            visualizer.add_datasample(
                debug_file_name,
                grid_image,
                data_sample=merged_result,
                draw_gt=False,
                show=args.show,
                wait_time=0,
                out_file=debug_out_file,
                pred_score_thr=args.score_thr,
            )

            if args.save_patch:
                debug_patch_out_dir = os.path.join(args.out_dir,
                                                   f'{name}_patch')
                for i, (patch, slice_result) in enumerate(
                        zip(patches, slice_results)):
                    patch_out_file = os.path.join(
                        debug_patch_out_dir,
                        f'{filename}_slice_{i}_result.jpg')
                    image = mmcv.imconvert(patch, 'bgr', 'rgb')

                    visualizer.add_datasample(
                        'patch_result',
                        image,
                        data_sample=slice_result,
                        draw_gt=False,
                        show=False,
                        wait_time=0,
                        out_file=patch_out_file,
                        pred_score_thr=args.score_thr,
                    )

        image_result = merge_results_by_nms(
            slice_results,
            starting_pixels,
            src_image_shape=(height, width),
            nms_cfg={
                'type': args.merge_nms_type,
                'iou_threshold': args.merge_iou_thr
            })

        visualizer.add_datasample(
            filename,
            img,
            data_sample=image_result,
            draw_gt=False,
            show=args.show,
            wait_time=0,
            out_file=out_file,
            pred_score_thr=args.score_thr,
        )
        progress_bar.update()

    if not args.show or (args.debug and args.save_patch):
        print_log(
            f'\nResults have been saved at {os.path.abspath(args.out_dir)}')
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the demo only when executed directly.
    main()
|