mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Add large image demo with `sahi` (#284)
* add large image demo with sahi
* fix some typos
* restructure based on reviews
* update default patch size
* add docstring and update docs
* updates based on reviews
* print information
* add debug, update docs, add large image sample
* update docs
* update docs
* update docs
* direct user to install sahi
pull/303/head
parent 5cee9c977d
commit 0fd6444c00
Binary file not shown.
demo/large_image_demo.py
@@ -0,0 +1,208 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Perform MMYOLO inference on large images (e.g. satellite imagery) as:

```shell
wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth  # noqa: E501

python demo/large_image_demo.py \
    demo/large_image.jpg \
    configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
    checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth
```
"""

import os
from argparse import ArgumentParser

import mmcv
from mmdet.apis import inference_detector, init_detector
from mmengine.logging import print_log
from mmengine.utils import ProgressBar

try:
    from sahi.slicing import slice_image
except ImportError:
    raise ImportError('Please run "pip install -U sahi" '
                      'to install sahi first for large image inference.')

from mmyolo.registry import VISUALIZERS
from mmyolo.utils import register_all_modules, switch_to_deploy
from mmyolo.utils.large_image import merge_results_by_nms
from mmyolo.utils.misc import get_file_list


def parse_args():
    parser = ArgumentParser(
        description='Perform MMYOLO inference on large images.')
    parser.add_argument(
        'img', help='Image path: an image file, a directory or a URL.')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--out-dir', default='./output', help='Path to output directory')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--show', action='store_true', help='Show the detection results')
    parser.add_argument(
        '--deploy',
        action='store_true',
        help='Switch model to deployment mode')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument(
        '--patch-size', type=int, default=640, help='The size of patches')
    parser.add_argument(
        '--patch-overlap-ratio',
        type=float,
        default=0.25,
        help='Ratio of overlap between two patches')
    parser.add_argument(
        '--merge-iou-thr',
        type=float,
        default=0.25,
        help='IoU threshold for merging results')
    parser.add_argument(
        '--merge-nms-type',
        type=str,
        default='nms',
        help='NMS type for merging results')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=1,
        help='Batch size, must be greater than or equal to 1')
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Export debug images at each stage for 1 input')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # register all modules in mmdet into the registries
    register_all_modules()

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)

    if args.deploy:
        switch_to_deploy(model)

    if not os.path.exists(args.out_dir) and not args.show:
        os.mkdir(args.out_dir)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.dataset_meta = model.dataset_meta

    # get file list
    files, source_type = get_file_list(args.img)

    # if debug, only process the first file
    if args.debug:
        files = files[:1]

    # start detector inference
    print(f'Performing inference on {len(files)} images... '
          'This may take a while.')
    progress_bar = ProgressBar(len(files))
    for file in files:
        # read image
        img = mmcv.imread(file)

        # arrange slices
        height, width = img.shape[:2]
        sliced_image_object = slice_image(
            img,
            slice_height=args.patch_size,
            slice_width=args.patch_size,
            auto_slice_resolution=False,
            overlap_height_ratio=args.patch_overlap_ratio,
            overlap_width_ratio=args.patch_overlap_ratio,
        )

        # perform sliced inference
        slice_results = []
        start = 0
        while True:
            # prepare batch slices
            end = min(start + args.batch_size, len(sliced_image_object))
            images = []
            for sliced_image in sliced_image_object.images[start:end]:
                images.append(sliced_image)

            # forward the model
            slice_results.extend(inference_detector(model, images))

            if end >= len(sliced_image_object):
                break
            start += args.batch_size

        if source_type['is_dir']:
            filename = os.path.relpath(file, args.img).replace('/', '_')
        else:
            filename = os.path.basename(file)

        # export debug images
        if args.debug:
            # export sliced images
            for i, image in enumerate(sliced_image_object.images):
                image = mmcv.imconvert(image, 'bgr', 'rgb')
                out_file = os.path.join(args.out_dir, 'sliced_images',
                                        filename + f'_slice_{i}.jpg')

                mmcv.imwrite(image, out_file)

            # export sliced image results
            for i, slice_result in enumerate(slice_results):
                out_file = os.path.join(args.out_dir, 'sliced_image_results',
                                        filename + f'_slice_{i}_result.jpg')
                image = mmcv.imconvert(sliced_image_object.images[i], 'bgr',
                                       'rgb')

                visualizer.add_datasample(
                    os.path.basename(out_file),
                    image,
                    data_sample=slice_result,
                    draw_gt=False,
                    show=args.show,
                    wait_time=0,
                    out_file=out_file,
                    pred_score_thr=args.score_thr,
                )

        image_result = merge_results_by_nms(
            slice_results,
            sliced_image_object.starting_pixels,
            src_image_shape=(height, width),
            nms_cfg={
                'type': args.merge_nms_type,
                'iou_thr': args.merge_iou_thr
            })

        img = mmcv.imconvert(img, 'bgr', 'rgb')
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        # use `filename` as the sample name: `out_file` is None in show mode
        visualizer.add_datasample(
            filename,
            img,
            data_sample=image_result,
            draw_gt=False,
            show=args.show,
            wait_time=0,
            out_file=out_file,
            pred_score_thr=args.score_thr,
        )
        progress_bar.update()

    if not args.show:
        print_log(
            f'\nResults have been saved at {os.path.abspath(args.out_dir)}')


if __name__ == '__main__':
    main()
@@ -350,6 +350,59 @@ python tools/analysis_tools/optimize_anchors.py ${CONFIG} \
    --output-dir ${OUTPUT_DIR}
```

## Perform inference on large images

First install [`sahi`](https://github.com/obss/sahi) with:

```shell
pip install -U "sahi>=0.11.4"
```
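
To quickly verify the installation (a sanity check; assumes `sahi` is importable in the current environment):

```shell
python -c "import sahi; print(sahi.__version__)"
```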

Perform MMYOLO inference on large images (e.g. satellite imagery) as:

```shell
wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth

python demo/large_image_demo.py \
    demo/large_image.jpg \
    configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
    checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth
```

Arrange slicing parameters as:

```shell
python demo/large_image_demo.py \
    demo/large_image.jpg \
    configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
    checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \
    --patch-size 512 \
    --patch-overlap-ratio 0.25
```
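
Patches can also be batched through the detector to speed up sliced inference, using the `--batch-size` flag defined in `demo/large_image_demo.py`. A usage sketch (the value 8 is illustrative; pick one that fits your GPU memory):

```shell
python demo/large_image_demo.py \
    demo/large_image.jpg \
    configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
    checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \
    --batch-size 8
```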

Export debug visuals while performing inference on large images as:

```shell
python demo/large_image_demo.py \
    demo/large_image.jpg \
    configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
    checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \
    --debug
```
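
With `--debug`, only the first input is processed; the sliced patches are written to `sliced_images/` and the per-patch detections to `sliced_image_results/` under `--out-dir` (see `demo/large_image_demo.py`).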

[`sahi`](https://github.com/obss/sahi) citation:

```
@article{akyon2022sahi,
  title={Slicing Aided Hyper Inference and Fine-tuning for Small Object Detection},
  author={Akyon, Fatih Cagatay and Altinuc, Sinan Onur and Temizel, Alptekin},
  journal={2022 IEEE International Conference on Image Processing (ICIP)},
  doi={10.1109/ICIP46576.2022.9897990},
  pages={966-970},
  year={2022}
}
```

## Extract a subset of COCO

The COCO2017 training set contains 118K images and the validation set 5K images, so the full dataset is fairly large. Loading its JSON annotations for debugging or quick verification consumes extra resources and slows startup.
mmyolo/utils/large_image.py
@@ -0,0 +1,76 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Tuple

from mmcv.ops import batched_nms
from mmdet.structures import DetDataSample, SampleList
from mmengine.structures import InstanceData


def shift_predictions(det_data_samples: SampleList,
                      offsets: Sequence[Tuple[int, int]],
                      src_image_shape: Tuple[int, int]) -> InstanceData:
    """Shift predictions back to the original image.

    Args:
        det_data_samples (List[:obj:`DetDataSample`]): A list of patch
            results.
        offsets (Sequence[Tuple[int, int]]): Positions of the top-left
            points of patches.
        src_image_shape (Tuple[int, int]): A (height, width) tuple giving
            the shape of the large image.
    Returns:
        :obj:`InstanceData`: shifted predictions.
    """
    try:
        from sahi.slicing import shift_bboxes, shift_masks
    except ImportError:
        raise ImportError('Please run "pip install -U sahi" '
                          'to install sahi first for large image inference.')

    assert len(det_data_samples) == len(offsets), \
        '`det_data_samples` must have the same length as `offsets`.'
    shifted_predictions = []
    for det_data_sample, offset in zip(det_data_samples, offsets):
        pred_inst = det_data_sample.pred_instances.clone()

        # shift bboxes and masks
        pred_inst.bboxes = shift_bboxes(pred_inst.bboxes, offset)
        if 'masks' in pred_inst:
            pred_inst.masks = shift_masks(pred_inst.masks, offset,
                                          src_image_shape)

        shifted_predictions.append(pred_inst.clone())

    shifted_predictions = InstanceData.cat(shifted_predictions)

    return shifted_predictions


def merge_results_by_nms(results: SampleList,
                         offsets: Sequence[Tuple[int, int]],
                         src_image_shape: Tuple[int, int],
                         nms_cfg: dict) -> DetDataSample:
    """Merge patch results by NMS.

    Args:
        results (List[:obj:`DetDataSample`]): A list of patch results.
        offsets (Sequence[Tuple[int, int]]): Positions of the top-left
            points of patches.
        src_image_shape (Tuple[int, int]): A (height, width) tuple giving
            the shape of the large image.
        nms_cfg (dict): It should specify the NMS type and other
            parameters, e.g. `iou_thr`.
    Returns:
        :obj:`DetDataSample`: merged results.
    """
    shifted_instances = shift_predictions(results, offsets, src_image_shape)

    _, keeps = batched_nms(
        boxes=shifted_instances.bboxes,
        scores=shifted_instances.scores,
        idxs=shifted_instances.labels,
        nms_cfg=nms_cfg)
    merged_instances = shifted_instances[keeps]

    merged_result = results[0].clone()
    merged_result.pred_instances = merged_instances
    return merged_result
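

# Example usage of `merge_results_by_nms` (a sketch, mirroring how
# `demo/large_image_demo.py` calls it; `slice_results` would come from
# `inference_detector` on the patches and `starting_pixels` from
# `sahi.slicing.slice_image`):
#
#   merged = merge_results_by_nms(
#       slice_results,
#       sliced_image_object.starting_pixels,
#       src_image_shape=(height, width),
#       nms_cfg=dict(type='nms', iou_thr=0.25))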
@@ -0,0 +1 @@
sahi>=0.11.4