Add browse dataset

pull/913/head
Ezra-Yu 2022-05-30 03:11:44 +00:00 committed by mzr1996
parent efaa93aae1
commit 4d4b22377d
6 changed files with 207 additions and 25 deletions

1
.gitignore vendored
View File

@ -21,6 +21,7 @@ parts/
sdist/
var/
wheels/
outputs/
*.egg-info/
.installed.cfg
*.egg

View File

@ -1,40 +1,47 @@
# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# file_client_args = dict(
# backend='petrel',
# path_mapping=dict({
# './data/': 's3://openmmlab/datasets/classification/',
# 'data/': 's3://openmmlab/datasets/classification/'
# }))
file_client_args = dict(backend='disk')
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='RandomResizedCrop', size=224),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackClsInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(256, -1), keep_ratio=True),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
dict(type='PackClsInputs')
]
data = dict(
samples_per_gpu=32,
workers_per_gpu=2,
train=dict(
type=dataset_type,
data_prefix='data/imagenet/train',
pipeline=train_pipeline),
val=dict(
train_dataloader = dict(
batch_size=32,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
pipeline=test_pipeline),
test=dict(
# replace `data/val` with `data/test` for standard test
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=32,
num_workers=2,
persistent_workers=True,
dataset=dict(
type=dataset_type,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
pipeline=test_pipeline))
test_dataloader = val_dataloader
evaluation = dict(interval=1, metric='accuracy')

View File

@ -27,6 +27,10 @@ env_cfg = dict(
dist_cfg=dict(backend='nccl'),
)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='ClsVisualizer', vis_backends=vis_backends, name='visualizer')
# Log level configuration
log_level = 'INFO'

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
from typing import Optional, Tuple
import mmcv
import numpy as np
@ -10,6 +10,31 @@ from mmcls.core import ClsDataSample
from mmcls.registry import VISUALIZERS
def _get_adaptive_scale(img_shape: Tuple[int, int],
min_scale: float = 0.3,
max_scale: float = 3.0) -> float:
"""Get adaptive scale according to image shape.
The target scale depends on the the short edge length of the image. If the
short edge length equals 224, the output is 1.0. And output linear scales
according the short edge length.
You can also specify the minimum scale and the maximum scale to limit the
linear scale.
Args:
img_shape (Tuple[int, int]): The shape of the canvas image.
min_size (int): The minimum scale. Defaults to 0.3.
max_size (int): The maximum scale. Defaults to 3.0.
Returns:
int: The adaptive scale.
"""
short_edge_length = min(img_shape)
scale = short_edge_length / 224.
return min(max(scale, min_scale), max_scale)
@VISUALIZERS.register_module()
class ClsVisualizer(Visualizer):
"""Universal Visualizer for classification task.
@ -62,6 +87,7 @@ class ClsVisualizer(Visualizer):
draw_gt: bool = True,
draw_pred: bool = True,
draw_score: bool = True,
rescale_factor: Optional[float] = None,
show: bool = False,
text_cfg: dict = dict(),
wait_time: float = 0,
@ -86,6 +112,8 @@ class ClsVisualizer(Visualizer):
Default to True.
draw_score (bool): Whether to draw the prediction scores
of prediction categories. Default to True.
rescale_factor (float, optional): Rescale the image by the rescale
factor before visualization. Defaults to None.
show (bool): Whether to display the drawn image. Default to False.
text_cfg (dict): Extra text setting, which accepts
arguments of :attr:`mmengine.Visualizer.draw_texts`.
@ -101,6 +129,9 @@ class ClsVisualizer(Visualizer):
if self.dataset_meta is not None:
classes = self.dataset_meta.get('CLASSES', None)
if rescale_factor is not None:
image = mmcv.imrescale(image, rescale_factor)
texts = []
self.set_image(image)
@ -134,8 +165,10 @@ class ClsVisualizer(Visualizer):
prefix = 'Prediction: '
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
img_scale = _get_adaptive_scale(image.shape[:2])
text_cfg = {
'positions': np.array([(5, 5)]),
'positions': np.array([(img_scale * 5, ) * 2]).astype(np.int32),
'font_sizes': int(img_scale * 7),
'font_families': 'monospace',
'colors': 'white',
'bboxes': dict(facecolor='black', alpha=0.5, boxstyle='Round'),

View File

@ -94,5 +94,39 @@ class TestClsVisualizer(TestCase):
self.vis.add_datasample(
'test', image=image, data_sample=data_sample)
# Test adaptive font size
def assert_font_size(target_size):
def draw_texts(text, font_sizes, *_, **__):
self.assertEqual(font_sizes, target_size)
return draw_texts
with patch.object(self.vis, 'draw_texts', assert_font_size(7)):
self.vis.add_datasample(
'test',
image=np.ones((224, 384, 3), np.uint8),
data_sample=data_sample)
with patch.object(self.vis, 'draw_texts', assert_font_size(2)):
self.vis.add_datasample(
'test',
image=np.ones((10, 384, 3), np.uint8),
data_sample=data_sample)
with patch.object(self.vis, 'draw_texts', assert_font_size(21)):
self.vis.add_datasample(
'test',
image=np.ones((1000, 1000, 3), np.uint8),
data_sample=data_sample)
# Test rescale image
with patch.object(self.vis, 'draw_texts', assert_font_size(14)):
self.vis.add_datasample(
'test',
image=np.ones((224, 384, 3), np.uint8),
rescale_factor=2.,
data_sample=data_sample)
def tearDown(self):
self.tmpdir.cleanup()

View File

@ -0,0 +1,103 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import itertools
import os.path as osp
import sys
import mmcv
from mmcv import Config, DictAction
from mmcls.datasets.builder import build_dataset
from mmcls.registry import VISUALIZERS
from mmcls.utils import register_all_modules
def parse_args():
    """Parse the command line arguments of the dataset browsing tool.

    Returns:
        argparse.Namespace: The parsed arguments. ``rescale_factor`` and
        ``cfg_options`` are ``None`` when not given on the command line.
    """
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--output-dir',
        default='./outputs',
        type=str,
        help='If there is no display interface, you can save it')
    parser.add_argument('--not-show', default=False, action='store_true')
    parser.add_argument(
        '--phase',
        default='train',
        type=str,
        choices=['train', 'test', 'val'],
        help='phase of dataset to visualize, accept "train" "test" and "val".'
        ' Default train.')
    parser.add_argument(
        '--show-number',
        type=int,
        default=sys.maxsize,
        help='number of images selected to visualize, must bigger than 0. if '
        'the number is bigger than length of dataset, show all the images in '
        'dataset; default "sys.maxsize", show all images in dataset')
    parser.add_argument(
        '--show-interval',
        type=float,
        default=2,
        help='the interval of show (s)')
    parser.add_argument(
        '--rescale-factor',
        type=float,
        help='image rescale factor, which is useful if the output is too '
        'large or too small.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()

    # Enforce the documented "bigger than 0" contract here, so that an
    # invalid value is reported as a usage error instead of surfacing later
    # as an opaque failure when the number is used to slice the dataset.
    if args.show_number <= 0:
        parser.error('--show-number must be bigger than 0, '
                     f'but got {args.show_number}')
    return args
def main():
    """Browse a dataset and draw each selected sample with the visualizer.

    Loads the config given on the command line, builds the dataset of the
    requested phase and the visualizer declared in the config, then passes
    up to ``--show-number`` samples to ``visualizer.add_datasample``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # register all modules in mmcls into the registries
    register_all_modules()

    # The dataset settings live under ``<phase>_dataloader.dataset``.
    dataloader = cfg[f'{args.phase}_dataloader']
    dataset = build_dataset(dataloader.dataset)

    cfg.visualizer.save_dir = args.output_dir
    visualizer = VISUALIZERS.build(cfg.visualizer)
    visualizer.dataset_meta = dataset.metainfo

    # Never iterate past the end of the dataset.
    display_number = min(args.show_number, len(dataset))
    progress_bar = mmcv.ProgressBar(display_number)
    for item in itertools.islice(dataset, display_number):
        # NOTE(review): presumably ``inputs`` is a (C, H, W) tensor that is
        # moved to (H, W, C) here -- confirm against the packing transform.
        img = item['inputs'].permute(1, 2, 0).numpy()
        data_sample = item['data_sample'].numpy()

        # The file name is used both as the sample name and as the name of
        # the saved file, so compute it only once.
        img_name = osp.basename(item['data_sample'].img_path)
        out_file = None
        if args.output_dir is not None:
            out_file = osp.join(args.output_dir, img_name)

        img = img[..., [2, 1, 0]]  # bgr to rgb
        visualizer.add_datasample(
            img_name,
            img,
            data_sample,
            rescale_factor=args.rescale_factor,
            show=not args.not_show,
            wait_time=args.show_interval,
            out_file=out_file)
        progress_bar.update()


if __name__ == '__main__':
    main()