diff --git a/demo/large_image.jpg b/demo/large_image.jpg new file mode 100644 index 00000000..1abbc5d9 Binary files /dev/null and b/demo/large_image.jpg differ diff --git a/demo/large_image_demo.py b/demo/large_image_demo.py new file mode 100644 index 00000000..9b4a72ac --- /dev/null +++ b/demo/large_image_demo.py @@ -0,0 +1,208 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Perform MMYOLO inference on large images (as satellite imagery) as: + +```shell +wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth # noqa: E501, E261. + +python demo/large_image_demo.py \ + demo/large_image.jpg \ + configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \ + checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \ +``` +""" + +import os +from argparse import ArgumentParser + +import mmcv +from mmdet.apis import inference_detector, init_detector +from mmengine.logging import print_log +from mmengine.utils import ProgressBar + +try: + from sahi.slicing import slice_image +except ImportError: + raise ImportError('Please run "pip install -U sahi" ' + 'to install sahi first for large image inference.') + +from mmyolo.registry import VISUALIZERS +from mmyolo.utils import register_all_modules, switch_to_deploy +from mmyolo.utils.large_image import merge_results_by_nms +from mmyolo.utils.misc import get_file_list + + +def parse_args(): + parser = ArgumentParser( + description='Perform MMYOLO inference on large images.') + parser.add_argument( + 'img', help='Image path, include image file, dir and URL.') + parser.add_argument('config', help='Config file') + parser.add_argument('checkpoint', help='Checkpoint file') + parser.add_argument( + '--out-dir', default='./output', help='Path to output file') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + parser.add_argument( + '--show', action='store_true', help='Show the detection results') + parser.add_argument( + '--deploy', + action='store_true', + help='Switch model to deployment mode') + parser.add_argument( + '--score-thr', type=float, default=0.3, help='Bbox score threshold') + parser.add_argument( + '--patch-size', type=int, default=640, help='The size of patches') + parser.add_argument( + '--patch-overlap-ratio', + type=int, + default=0.25, + help='Ratio of overlap between two patches') + parser.add_argument( + '--merge-iou-thr', + type=float, + default=0.25, + help='IoU threshould for merging results') + parser.add_argument( + '--merge-nms-type', + type=str, + default='nms', + help='NMS type for merging results') + parser.add_argument( + '--batch-size', + type=int, + default=1, + help='Batch size, must greater than or equal to 1') + parser.add_argument( + '--debug', + action='store_true', + help='Export debug images at each stage for 1 input') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # register all modules in mmdet into the registries + register_all_modules() + + # build the model from a config file and a checkpoint file + model = init_detector(args.config, args.checkpoint, device=args.device) + + if args.deploy: + switch_to_deploy(model) + + if not os.path.exists(args.out_dir) and not args.show: + os.mkdir(args.out_dir) + + # init visualizer + visualizer = VISUALIZERS.build(model.cfg.visualizer) + visualizer.dataset_meta = model.dataset_meta + + # get file list + files, source_type = get_file_list(args.img) + + # if debug, only process the first file + if args.debug: + files = files[:1] + + # start detector inference + print(f'Performing inference on {len(files)} images... \ +This may take a while.') + progress_bar = ProgressBar(len(files)) + for file in files: + # read image + img = mmcv.imread(file) + + # arrange slices + height, width = img.shape[:2] + sliced_image_object = slice_image( + img, + slice_height=args.patch_size, + slice_width=args.patch_size, + auto_slice_resolution=False, + overlap_height_ratio=args.patch_overlap_ratio, + overlap_width_ratio=args.patch_overlap_ratio, + ) + + # perform sliced inference + slice_results = [] + start = 0 + while True: + # prepare batch slices + end = min(start + args.batch_size, len(sliced_image_object)) + images = [] + for sliced_image in sliced_image_object.images[start:end]: + images.append(sliced_image) + + # forward the model + slice_results.extend(inference_detector(model, images)) + + if end >= len(sliced_image_object): + break + start += args.batch_size + + if source_type['is_dir']: + filename = os.path.relpath(file, args.img).replace('/', '_') + else: + filename = os.path.basename(file) + + # export debug images + if args.debug: + # export sliced images + for i, image in enumerate(sliced_image_object.images): + image = mmcv.imconvert(image, 'bgr', 'rgb') + out_file = os.path.join(args.out_dir, 'sliced_images', + filename + f'_slice_{i}.jpg') + + mmcv.imwrite(image, out_file) + + # export sliced image results + for i, slice_result in enumerate(slice_results): + out_file = os.path.join(args.out_dir, 'sliced_image_results', + filename + f'_slice_{i}_result.jpg') + image = mmcv.imconvert(sliced_image_object.images[i], 'bgr', + 'rgb') + + visualizer.add_datasample( + os.path.basename(out_file), + image, + data_sample=slice_result, + draw_gt=False, + show=args.show, + wait_time=0, + out_file=out_file, + pred_score_thr=args.score_thr, + ) + + image_result = merge_results_by_nms( + slice_results, + sliced_image_object.starting_pixels, + src_image_shape=(height, width), + nms_cfg={ + 'type': args.merge_nms_type, + 'iou_thr': args.merge_iou_thr + }) + + img = mmcv.imconvert(img, 'bgr', 'rgb') + out_file = None if args.show else os.path.join(args.out_dir, filename) + + visualizer.add_datasample( + os.path.basename(out_file), + img, + data_sample=image_result, + draw_gt=False, + show=args.show, + wait_time=0, + out_file=out_file, + pred_score_thr=args.score_thr, + ) + progress_bar.update() + + if not args.show: + print_log( + f'\nResults have been saved at {os.path.abspath(args.out_dir)}') + + +if __name__ == '__main__': + main() diff --git a/docs/en/user_guides/useful_tools.md b/docs/en/user_guides/useful_tools.md index a1ec5dfc..b23e0b12 100644 --- a/docs/en/user_guides/useful_tools.md +++ b/docs/en/user_guides/useful_tools.md @@ -350,6 +350,59 @@ python tools/analysis_tools/optimize_anchors.py ${CONFIG} \ --output-dir ${OUTPUT_DIR} ``` +## Perform inference on large images + +First install [`sahi`](https://github.com/obss/sahi) with: + +```shell +pip install -U sahi>=0.11.4 +``` + +Perform MMYOLO inference on large images (as satellite imagery) as: + +```shell +wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth + +python demo/large_image_demo.py \ + demo/large_image.jpg \ + configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \ + checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \ +``` + +Arrange slicing parameters as: + +```shell +python demo/large_image_demo.py \ + demo/large_image.jpg \ + configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \ + checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \ + --patch-size 512 + --patch-overlap-ratio 0.25 +``` + +Export debug visuals while performing inference on large images as: + +```shell +python demo/large_image_demo.py \ + demo/large_image.jpg \ + configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \ + checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \ + --debug +``` + +[`sahi`](https://github.com/obss/sahi) citation: + +``` +@article{akyon2022sahi, + title={Slicing Aided Hyper Inference and Fine-tuning for Small Object Detection}, + author={Akyon, Fatih Cagatay and Altinuc, Sinan Onur and Temizel, Alptekin}, + journal={2022 IEEE International Conference on Image Processing (ICIP)}, + doi={10.1109/ICIP46576.2022.9897990}, + pages={966-970}, + year={2022} +} +``` + ## Extracts a subset of COCO The training dataset of the COCO2017 dataset includes 118K images, and the validation set includes 5K images, which is a relatively large dataset. Loading JSON in debugging or quick verification scenarios will consume more resources and bring slower startup speed. diff --git a/mmyolo/utils/large_image.py b/mmyolo/utils/large_image.py new file mode 100644 index 00000000..68c6938e --- /dev/null +++ b/mmyolo/utils/large_image.py @@ -0,0 +1,76 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence, Tuple + +from mmcv.ops import batched_nms +from mmdet.structures import DetDataSample, SampleList +from mmengine.structures import InstanceData + + +def shift_predictions(det_data_samples: SampleList, + offsets: Sequence[Tuple[int, int]], + src_image_shape: Tuple[int, int]) -> SampleList: + """Shift predictions to the original image. + + Args: + det_data_samples (List[:obj:`DetDataSample`]): A list of patch results. + offsets (Sequence[Tuple[int, int]]): Positions of the left top points + of patches. + src_image_shape (Tuple[int, int]): A (height, width) tuple of the large + image's width and height. + Returns: + (List[:obj:`DetDataSample`]): shifted results. + """ + try: + from sahi.slicing import shift_bboxes, shift_masks + except ImportError: + raise ImportError('Please run "pip install -U sahi" ' + 'to install sahi first for large image inference.') + + assert len(det_data_samples) == len( + offsets), 'The `results` should has the ' 'same length with `offsets`.' + shifted_predictions = [] + for det_data_sample, offset in zip(det_data_samples, offsets): + pred_inst = det_data_sample.pred_instances.clone() + + # shift bboxes and masks + pred_inst.bboxes = shift_bboxes(pred_inst.bboxes, offset) + if 'masks' in det_data_sample: + pred_inst.masks = shift_masks(pred_inst.masks, offset, + src_image_shape) + + shifted_predictions.append(pred_inst.clone()) + + shifted_predictions = InstanceData.cat(shifted_predictions) + + return shifted_predictions + + +def merge_results_by_nms(results: SampleList, offsets: Sequence[Tuple[int, + int]], + src_image_shape: Tuple[int, int], + nms_cfg: dict) -> DetDataSample: + """Merge patch results by nms. + + Args: + results (List[:obj:`DetDataSample`]): A list of patch results. + offsets (Sequence[Tuple[int, int]]): Positions of the left top points + of patches. + src_image_shape (Tuple[int, int]): A (height, width) tuple of the large + image's width and height. + nms_cfg (dict): it should specify nms type and other parameters + like `iou_threshold`. + Returns: + :obj:`DetDataSample`: merged results. + """ + shifted_instances = shift_predictions(results, offsets, src_image_shape) + + _, keeps = batched_nms( + boxes=shifted_instances.bboxes, + scores=shifted_instances.scores, + idxs=shifted_instances.labels, + nms_cfg=nms_cfg) + merged_instances = shifted_instances[keeps] + + merged_result = results[0].clone() + merged_result.pred_instances = merged_instances + return merged_result diff --git a/requirements/sahi.txt b/requirements/sahi.txt new file mode 100644 index 00000000..0e7b7b84 --- /dev/null +++ b/requirements/sahi.txt @@ -0,0 +1 @@ +sahi>=0.11.4