mmyolo/demo/image_demo.py

138 lines
4.3 KiB
Python
Raw Normal View History

2022-09-18 10:11:55 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
import os
2022-09-18 10:11:55 +08:00
from argparse import ArgumentParser
import mmcv
from mmdet.apis import inference_detector, init_detector
from mmengine.logging import print_log
from mmengine.utils import ProgressBar, path
2022-09-20 10:57:33 +08:00
2022-09-18 10:11:55 +08:00
from mmyolo.registry import VISUALIZERS
[Feature] Support YOLOv6 training (#183) * init v6 loss * init v6s train * Add train pipeline * Add lr scheduler * update * update * update * update * update * update * update * update * update * fix detach bug * fix detach bug * update * Add stop aug hook * Add save best ckpt * update * Add PipelineSwitchHook * Fix train pipeline stage 2 * update * Fix train pipeline * update * fix stage2 randomaffine bug update update clean clean * update letterResize param * add v6affine config * add v6 randomaffine * update v6 config * update * update * update * update * update config param * update * update * refactor iou loss % rm v6affine * update * rm dfl * add v6 300 epoch config * Factor batch atss assigner * Format code * Format code * Roll back * Refactor dist_calculator * Refactor select_candidates_in_gts * Refactor select_highest_overlaps * Refactor iou_calculator * Refactor all code * Improve docstr * Improve code * clean config * add nano tiny config * pre-commit * Refactor * Improve code * Improve naming and link * Add UT * pre commit * Add UT * Add UT * Improve code, using mmdet.BboxOverlaps2D for all iou calculation * Improve code, using mmdet.BboxOverlaps2D for all iou calculation * Improve code * pre commit * pre commit * Add UT * fix config * pre commit * Improve code * Improve code * Improve code * Improve code * [Refactor] YOLOv6 BatchATSSAssigner (#179) * Factor batch atss assigner * Format code * Format code * Roll back * Refactor dist_calculator * Refactor select_candidates_in_gts * Refactor select_highest_overlaps * Refactor iou_calculator * Refactor all code * Improve docstr * Improve code * Improve code * Improve naming and link * Add UT * pre commit * Add UT * Add UT * Improve code, using mmdet.BboxOverlaps2D for all iou calculation * Improve code, using mmdet.BboxOverlaps2D for all iou calculation * Improve code * pre commit * Fix conflicts * Improve code * Improve code * Improve code * Improve code * Improve code * Improve code * add utils.py, order 
the input param * Improve docstr * Fix lint * Improve param mapping * Improve param mapping * Improve naming * assigner return dict * update * update config * update config * Fix * Fix UT * Improve UT * Improve naming * Improve coding * pre commit * pre commit * pre commit * Fix ci * Improve naming * Improve coding * Fix training iou calculate error * Improve naming * Improve naming * Improve type hint * fix lint * fix conflicts * fix UT * Improve type hint * Improve naming * Improve coding * Improve coding * Fix UT * Refactor SIoU * Pre commit * Fix * Improve ciou * Improve ciou * refactor varifocal * Improve ciou * Improve ciou * Improve siou * Improve type hint * Improve siou * Improve siou * Fix lint * refactor varifocal * fix iou bug * fix siou and loss_cls bug * update * update * add scope * update * update * Improve func `gt_instances_preprocess` * support deploy mode * Improve func `gt_instances_preprocess` * Improve func `gt_instances_preprocess` * Improve func `gt_instances_preprocess` * Improve func `bbox_overlaps` * Improve coding * Improve bbox_overlaps * Delete useless code * add yolov6 deploy mode hook * fix lint * Add common attributes to reduce calculation * Improve code * Improve code * Fix bug * Fix bug * update * add readme * update readme * update readme url Co-authored-by: HinGwenWoong <peterhuang0323@qq.com>
2022-11-02 20:23:25 +08:00
from mmyolo.utils import register_all_modules, switch_to_deploy
from mmyolo.utils.labelme_utils import LabelmeFormat
from mmyolo.utils.misc import get_file_list, show_data_classes
2022-09-18 10:11:55 +08:00
def parse_args():
    """Build the demo's command-line interface and parse ``sys.argv``.

    Positional arguments are the input source (``img``: an image file, a
    directory or a URL), the model ``config`` and the ``checkpoint``.
    Optional flags select the output directory, inference device,
    interactive display, deploy mode, score threshold, class filter and
    labelme-style export.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    cli = ArgumentParser()

    # Required positionals, registered in this fixed order.
    for name, description in (
            ('img', 'Image path, include image file, dir and URL.'),
            ('config', 'Config file'),
            ('checkpoint', 'Checkpoint file'),
    ):
        cli.add_argument(name, help=description)

    # Where results go and how inference runs.
    cli.add_argument('--out-dir', default='./output',
                     help='Path to output file')
    cli.add_argument('--device', default='cuda:0',
                     help='Device used for inference')
    cli.add_argument('--show', action='store_true',
                     help='Show the detection results')
    cli.add_argument('--deploy', action='store_true',
                     help='Switch model to deployment mode')

    # Prediction filtering options.
    cli.add_argument('--score-thr', type=float, default=0.3,
                     help='Bbox score threshold')
    cli.add_argument('--class-name', nargs='+', type=str,
                     help='Only Save those classes if set')

    # Alternative output format.
    cli.add_argument('--to-labelme', action='store_true',
                     help='Output labelme style label file')

    return cli.parse_args()
def main():
    """Run the detector over every input image and save or show results.

    Steps, in order: validate the flag combination, build the model
    (optionally switching it to deploy mode), collect the input files, then
    for each file run inference and either write a labelme-style ``.json``
    label file (``--to-labelme``) or hand the prediction to the visualizer,
    which shows it (``--show``) or saves it under ``--out-dir``.

    Raises:
        RuntimeError: if ``--to-labelme`` and ``--show`` are both given, or
            if a ``--class-name`` entry is not one of the model's classes.
    """
    args = parse_args()

    # Labelme export needs an output path, which ``--show`` does not set
    # (out_file is None when showing).
    if args.to_labelme and args.show:
        raise RuntimeError('`--to-labelme` or `--show` only '
                           'can choose one at the same time.')

    # register all modules into the registries
    register_all_modules()

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)

    if args.deploy:
        switch_to_deploy(model)

    if not args.show:
        path.mkdir_or_exist(args.out_dir)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.dataset_meta = model.dataset_meta

    # get file list
    files, source_type = get_file_list(args.img)

    # get model class name
    dataset_classes = model.dataset_meta.get('classes')

    # ready for labelme format if it is needed
    to_label_format = LabelmeFormat(classes=dataset_classes)

    # check class name
    if args.class_name is not None:
        for class_name in args.class_name:
            if class_name in dataset_classes:
                continue
            show_data_classes(dataset_classes)
            raise RuntimeError(
                'Expected args.class_name to be one of the list, '
                f'but got "{class_name}"')

    # start detector inference
    progress_bar = ProgressBar(len(files))
    for file in files:
        result = inference_detector(model, file)

        img = mmcv.imread(file)
        img = mmcv.imconvert(img, 'bgr', 'rgb')

        if source_type['is_dir']:
            # Flatten sub-directories into one file name. relpath returns
            # os.sep-separated paths, so split on os.sep (not a literal '/').
            filename = os.path.relpath(file, args.img).replace(os.sep, '_')
        else:
            filename = os.path.basename(file)
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        progress_bar.update()

        # Get candidate predict info with score threshold
        pred_instances = result.pred_instances[
            result.pred_instances.scores > args.score_thr]

        if args.to_labelme:
            # Swap only the trailing extension for '.json'. A plain
            # str.replace of the extension would also rewrite earlier
            # matches in the path, and for an extension-less name
            # (ext == '') would insert '.json' between every character.
            out_file = os.path.splitext(out_file)[0] + '.json'
            to_label_format(pred_instances, result.metainfo, out_file,
                            args.class_name)
            continue

        visualizer.add_datasample(
            filename,
            img,
            data_sample=result,
            draw_gt=False,
            show=args.show,
            wait_time=0,
            out_file=out_file,
            pred_score_thr=args.score_thr)

    if not args.show and not args.to_labelme:
        print_log(
            f'\nResults have been saved at {os.path.abspath(args.out_dir)}')

    elif args.to_labelme:
        print_log('\nLabelme format label files '
                  f'had all been saved in {args.out_dir}')


if __name__ == '__main__':
    main()