mmyolo/tools/analysis_tools/browse_dataset.py

277 lines
9.7 KiB
Python
Raw Normal View History

2022-09-18 10:11:55 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import sys
from typing import Tuple
2022-09-18 10:11:55 +08:00
import cv2
import mmcv
2022-09-18 10:11:55 +08:00
import numpy as np
2022-09-20 10:57:33 +08:00
from mmdet.models.utils import mask2ndarray
from mmdet.structures.bbox import BaseBoxes
2022-09-18 10:11:55 +08:00
from mmengine.config import Config, DictAction
from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
2022-09-18 10:11:55 +08:00
from mmengine.utils import ProgressBar
from mmengine.visualization import Visualizer
2022-09-18 10:11:55 +08:00
from mmyolo.registry import DATASETS, VISUALIZERS
# TODO: Support printing how the keys of the results change.
Support yolox-pose based on mmpose (#694) * add * reproduce map * add typehint and doc * format code * replace key * add ut * format * format * format code * fix ut * fix ut * fix comment * fix comment * fix comment * [WIP][Feature] Support yolov5-Ins training * fix comment * change data flow and fix loss_mask compute * align the data pipeline * remove albu gt mask key * support yolov5 ins inference * fix multi gpu test * align the post_process with v8 * support training * support training * code formatting * code formatting * Support pad_param type (#672) * add half_pad_param * fix default fast_test * fix loss weight compute * add models * add dataset1 * add dataset2 * add dataset3 * add configs * re commit __init__ * re commit __init__ * re commit * del local * add typo * del PoseToDetConverter and BBoxKeypoints * del local changes * fix mask rescale, add segment merge, fix segment2bbox * fix pipeline * add dataset * fix typo * add resize in mmyolo * fix typo * del local * del local changes * del local changes * fix dir name * fix dir name * add FilterAnnotations * fix typo * new config for yolox-pose * fix typo * fix typo * fix clip and fix mask init * del pose dataset changes * fix YOLOv5DetDataPreprocessor * del local file * fix typo * del init_cfg * simplify config * fix batch size * fix batch size * fix typo * code formatting * code formatting * code formatting * code formatting * fix bug for FilterAnnotations * simpler way for FilterAnnotations * update config * [Fix] fix load image from file * shorten eval time * fix typo * add large model * [Add] Add docs and more config * [Fix] config type and test_formatting * [Fix] fix yolov5-ins_m packdetinputs * hand rebase from yolov5-ins * use new PackDetInputs * rebase fix typo * add mapping table * fix typo * add weight * del typo * del typo * add results * install mmpose, Keypoints note, context manager, predict, ota rename * fix test * add unittest for pose_sim_ota_assigner and yolox_head * add unittest for 
pose_sim_ota_assigner and yolox_head * fix typo --------- Co-authored-by: Nioolek <379319054@qq.com> Co-authored-by: josonchan <josonchan1998@163.com> Co-authored-by: Nioolek <40284075+Nioolek@users.noreply.github.com> Co-authored-by: huanghaian <huanghaian@sensetime.com>
2023-05-15 10:58:25 +08:00
# TODO: This tool may still have bugs. If you run into one, please fall back to the original implementation.
2022-09-18 10:11:55 +08:00
def _positive_int(value):
    """argparse ``type`` callback that accepts only integers greater than 0.

    Args:
        value (str): The raw command-line value.

    Returns:
        int: The parsed positive integer.

    Raises:
        argparse.ArgumentTypeError: If the parsed integer is not positive.
        ValueError: If ``value`` is not an integer literal (argparse turns
            this into an "invalid value" usage error).
    """
    number = int(value)
    if number <= 0:
        raise argparse.ArgumentTypeError(
            f'must be bigger than 0, but got {number}')
    return number


def parse_args():
    """Parse the command-line arguments of the dataset browser.

    Returns:
        argparse.Namespace: The parsed arguments. ``show_number`` is
        guaranteed to be a positive integer (validated at parse time
        instead of via a later ``assert``, which ``python -O`` strips).
    """
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--phase',
        '-p',
        default='train',
        type=str,
        choices=['train', 'test', 'val'],
        help='phase of dataset to visualize, accept "train" "test" and "val".'
        ' Defaults to "train".')
    parser.add_argument(
        '--mode',
        '-m',
        default='transformed',
        type=str,
        choices=['original', 'transformed', 'pipeline'],
        help='display mode; display original pictures or '
        'transformed pictures or comparison pictures. "original" '
        'means show images load from disk; "transformed" means '
        'to show images after transformed; "pipeline" means show all '
        'the intermediate images. Defaults to "transformed".')
    parser.add_argument(
        '--out-dir',
        default='output',
        type=str,
        help='If there is no display interface, you can save it.')
    parser.add_argument(
        '--not-show',
        default=False,
        action='store_true',
        help='do not display the images in a window; they are still '
        'written to --out-dir.')
    parser.add_argument(
        '--show-number',
        '-n',
        # Reject non-positive values at parse time with a proper usage error.
        type=_positive_int,
        default=sys.maxsize,
        help='number of images selected to visualize, '
        'must bigger than 0. if the number is bigger than length '
        'of dataset, show all the images in dataset; '
        'default "sys.maxsize", show all images in dataset')
    parser.add_argument(
        '--show-interval',
        '-i',
        type=float,
        default=3,
        help='the interval of show (s)')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args
def _get_adaptive_scale(img_shape: Tuple[int, int],
min_scale: float = 0.3,
max_scale: float = 3.0) -> float:
"""Get adaptive scale according to image shape.
The target scale depends on the the short edge length of the image. If the
short edge length equals 224, the output is 1.0. And output linear
scales according the short edge length. You can also specify the minimum
scale and the maximum scale to limit the linear scale.
Args:
img_shape (Tuple[int, int]): The shape of the canvas image.
min_scale (int): The minimum scale. Defaults to 0.3.
max_scale (int): The maximum scale. Defaults to 3.0.
Returns:
int: The adaptive scale.
"""
short_edge_length = min(img_shape)
scale = short_edge_length / 224.
return min(max(scale, min_scale), max_scale)
def make_grid(imgs, names):
    """Concatenate images side by side into one captioned picture.

    Every image is padded with white so that all tiles share the same height
    (1.1x the tallest input) and is vertically centred. Tiles are separated
    by a horizontal gap of about a tenth of the narrowest input's width. Each
    tile gets an extra white strip at the bottom, on which a caption
    (``execution: <index>``, the corresponding name and the image's original
    ``(height, width)``) is drawn, centred, with the current ``Visualizer``
    instance.

    Args:
        imgs (list[np.ndarray]): Images to concatenate. They must have the
            same number of channels for ``np.concatenate`` to succeed. The
            list itself is not modified.
        names (Sequence[str]): Caption name for each image, in order.

    Returns:
        np.ndarray: The composed image, as returned by
        ``Visualizer.get_image()``.

    Raises:
        ValueError: If ``imgs`` is empty or ``names`` has fewer entries than
            ``imgs``.
    """
    if not imgs:
        raise ValueError('make_grid() needs at least one image.')
    if len(names) < len(imgs):
        raise ValueError(f'make_grid() got {len(imgs)} images but only '
                         f'{len(names)} names.')

    visualizer = Visualizer.get_current_instance()
    max_height = int(max(img.shape[0] for img in imgs) * 1.1)
    min_width = min(img.shape[1] for img in imgs)
    horizontal_gap = min_width // 10
    # Half of the gap is added on each side of every tile.
    pad_width = horizontal_gap // 2
    img_scale = _get_adaptive_scale((max_height, min_width))
    # Height of the white strip below each tile that holds the caption.
    caption_height = int(img_scale * 30 * 2)

    padded_imgs = []
    texts = []
    text_positions = []
    start_x = 0
    for i, (img, name) in enumerate(zip(imgs, names)):
        ori_shape = img.shape[:2]
        pad_top = (max_height - img.shape[0]) // 2
        padded = cv2.copyMakeBorder(
            img,
            pad_top,
            max_height - img.shape[0] - pad_top + caption_height,
            pad_width,
            pad_width,
            cv2.BORDER_CONSTANT,
            value=(255, 255, 255))
        padded_imgs.append(padded)
        texts.append(f'execution: {i}\n{name}\n{ori_shape}')
        # Horizontal centre of this tile, at the top of its caption strip.
        text_positions.append([start_x + padded.shape[1] // 2, max_height])
        # Advance by the real padded width; advancing by ``horizontal_gap``
        # would drift one pixel per tile when the gap is odd.
        start_x += padded.shape[1]

    display_img = np.concatenate(padded_imgs, axis=1)
    visualizer.set_image(display_img)
    # Font size follows the size of the whole canvas, not a single tile.
    font_scale = _get_adaptive_scale(display_img.shape[:2])
    visualizer.draw_texts(
        texts,
        positions=np.array(text_positions),
        font_sizes=font_scale * 7,
        colors='black',
        horizontal_alignments='center',
        font_families='monospace')
    return visualizer.get_image()
def _transform_type_name(tfm):
    """Return the ``type`` name of a transform config, or None if unknown.

    Handles both string types (``type='LoadAnnotations'``) and class types
    (``type=LoadAnnotations``). Entries that are not dicts yield None.
    """
    if not isinstance(tfm, dict):
        return None
    tfm_type = tfm.get('type')
    return getattr(tfm_type, '__name__', tfm_type)


def swap_pipeline_position(dataset_cfg):
    """Move the ``LoadAnnotations`` transform to index 1 of the pipeline.

    The pipeline list in ``dataset_cfg`` is modified in place. The first
    ``LoadAnnotations`` entry is moved right after the first transform
    (presumably the image-loading transform -- TODO confirm for every
    config), so annotations are loaded before the remaining transforms.
    Nothing happens when there is no pipeline or no ``LoadAnnotations``.

    Args:
        dataset_cfg (dict): Dataset config, possibly holding ``pipeline``.

    Returns:
        dict: ``dataset_cfg`` itself, in every case.
    """
    load_ann_tfm_name = 'LoadAnnotations'
    pipeline = dataset_cfg.get('pipeline')
    if pipeline is None:
        return dataset_cfg
    all_transform_types = [_transform_type_name(tfm) for tfm in pipeline]
    if load_ann_tfm_name in all_transform_types:
        load_ann_tfm_index = all_transform_types.index(load_ann_tfm_name)
        load_ann_tfm = pipeline.pop(load_ann_tfm_index)
        pipeline.insert(1, load_ann_tfm)
    return dataset_cfg
class InspectCompose(Compose):
    """Compose transforms sequentially and record a snapshot after each step.

    All transforms except the last one are applied in order. The last
    transform is the packing transform (the original code refers to it as
    ``PackDetInputs``). It is only used to build the recorded snapshots.

    After every intermediate step whose results contain ``img``, the current
    results are packed with the last transform, and
    ``{'name': <transform class name>, 'dataset_sample': <data_samples>}`` is
    appended to ``intermediate_imgs``. If the incoming results already
    contain ``img``, an ``'original'`` record is appended first. That record
    additionally keeps a copy of the raw ``img``.

    ``__call__`` returns the results *before* the last transform is
    applied. It returns None (and discards the records of the current
    sample) when an intermediate transform returns None.

    Args:
        transforms (Sequence): Transforms forwarded to ``Compose``.
        intermediate_imgs (list): Caller-owned list that receives the
            records. It is never cleared here; the caller should clear it
            between samples.
    """

    def __init__(self, transforms, intermediate_imgs):
        super().__init__(transforms=transforms)
        self.intermediate_imgs = intermediate_imgs

    def _pack(self, data):
        """Pack ``data`` with the last transform and return its data sample.

        The last transform's ``meta_keys`` is overwritten with every key of
        ``data``, so the packed sample carries all current fields (e.g.
        ``img``) as meta information.
        """
        pack_transform = self.transforms[-1]
        pack_transform.meta_keys = list(data)
        return pack_transform(data)['data_samples']

    def __call__(self, data):
        # Where this sample's records start, so they can be rolled back.
        num_records_before = len(self.intermediate_imgs)
        if 'img' in data and self.transforms:
            self.intermediate_imgs.append({
                'name': 'original',
                'img': data['img'].copy(),
                # Every record carries a packed sample so that consumers
                # can handle all records uniformly.
                'dataset_sample': self._pack(data)
            })

        # Every transform except the final packing one.
        self.ptransforms = self.transforms[:-1]
        for t in self.ptransforms:
            data = t(data)
            if data is None:
                # The sample was dropped: never pack None, and do not leave
                # partial records behind (e.g. if the dataset then retries
                # with another index).
                del self.intermediate_imgs[num_records_before:]
                return None
            if 'img' in data:
                self.intermediate_imgs.append({
                    'name': t.__class__.__name__,
                    'dataset_sample': self._pack(data)
                })
        return data
2022-09-18 10:11:55 +08:00
def _draw_gt_sample(visualizer, data_sample):
    """Draw the ground truth of one packed data sample and return the canvas.

    ``data_sample.img`` has its channel order reversed (BGR to RGB, per the
    original comment) before drawing. In ``data_sample.gt_instances``,
    ``BaseBoxes`` boxes are replaced by their raw tensor and masks are
    converted to boolean arrays, both in place. Only the ground truth is drawn.

    Args:
        visualizer: The visualizer built from ``cfg.visualizer``.
        data_sample: A packed data sample recorded by ``InspectCompose``.

    Returns:
        np.ndarray: The drawn image from ``visualizer.get_image()``.
    """
    image = data_sample.img[..., [2, 1, 0]]  # bgr to rgb
    gt_instances = data_sample.gt_instances
    gt_bboxes = gt_instances.get('bboxes', None)
    if isinstance(gt_bboxes, BaseBoxes):
        gt_instances.bboxes = gt_bboxes.tensor
    gt_masks = gt_instances.get('masks', None)
    if gt_masks is not None:
        gt_instances.masks = mask2ndarray(gt_masks).astype(bool)
    data_sample.gt_instances = gt_instances
    visualizer.add_datasample(
        'result',
        image,
        data_sample,
        draw_pred=False,
        draw_gt=True,
        show=False)
    return visualizer.get_image()


def main():
    """Browse the ground truth of the first ``--show-number`` samples.

    The dataset of the selected phase is built with its pipeline wrapped in
    ``InspectCompose``. For each sample, the recorded snapshots are drawn
    according to ``--mode``:

    * ``original``: only the first record is drawn.
    * ``transformed``: only the last record is drawn.
    * ``pipeline``: all records are drawn and combined with ``make_grid``.

    The result is written to ``--out-dir`` (when not None) and shown in a
    window unless ``--not-show`` is given.

    Raises:
        ValueError: If ``--show-number`` is not positive.
    """
    import warnings

    args = parse_args()
    # Explicit check instead of ``assert`` (which ``python -O`` strips).
    if args.show_number <= 0:
        raise ValueError(
            f'--show-number must be bigger than 0, got {args.show_number}.')

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    init_default_scope(cfg.get('default_scope', 'mmyolo'))

    dataset_cfg = cfg.get(args.phase + '_dataloader').get('dataset')
    if args.phase in ('test', 'val'):
        swap_pipeline_position(dataset_cfg)
    dataset = DATASETS.build(dataset_cfg)

    visualizer = VISUALIZERS.build(cfg.visualizer)
    visualizer.dataset_meta = dataset.metainfo

    intermediate_imgs = []
    if not hasattr(dataset, 'pipeline'):
        # for dataset_wrapper
        dataset = dataset.dataset
    # TODO: The dataset wrapper occasion is not considered here
    dataset.pipeline = InspectCompose(dataset.pipeline.transforms,
                                      intermediate_imgs)

    display_number = min(args.show_number, len(dataset))
    progress_bar = ProgressBar(display_number)
    # Fetching each item runs the inspecting pipeline, which fills
    # ``intermediate_imgs`` with the records of that sample.
    for i, _ in zip(range(display_number), dataset):
        records = [
            record for record in intermediate_imgs
            if 'dataset_sample' in record
        ]
        intermediate_imgs.clear()
        if not records:
            warnings.warn(f'No intermediate image was recorded for sample '
                          f'{i}; skipping it.')
            progress_bar.update()
            continue

        # Only draw the records that will actually be displayed.
        if args.mode == 'original':
            selected = records[:1]
        elif args.mode == 'transformed':
            selected = records[-1:]
        else:
            selected = records
        drawn = [
            _draw_gt_sample(visualizer, record['dataset_sample'])
            for record in selected
        ]
        if args.mode in ('original', 'transformed'):
            image = drawn[0]
        else:
            image = make_grid(drawn, [record['name'] for record in selected])

        # get filename from dataset or just use index as filename
        last_sample = records[-1]['dataset_sample']
        if hasattr(last_sample, 'img_path'):
            filename = osp.basename(last_sample.img_path)
        else:
            # some dataset have not image path
            filename = f'{i}.jpg'

        if args.out_dir is not None:
            # The drawn image is RGB; reverse channels back for writing.
            mmcv.imwrite(image[..., ::-1], osp.join(args.out_dir, filename))

        if not args.not_show:
            visualizer.show(
                image, win_name=filename, wait_time=args.show_interval)

        progress_bar.update()


if __name__ == '__main__':
    main()