Add browse dataset
parent
efaa93aae1
commit
4d4b22377d
|
@ -21,6 +21,7 @@ parts/
|
|||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
outputs/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
|
|
@ -1,40 +1,47 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
|
||||
# file_client_args = dict(
|
||||
# backend='petrel',
|
||||
# path_mapping=dict({
|
||||
# './data/': 's3://openmmlab/datasets/classification/',
|
||||
# 'data/': 's3://openmmlab/datasets/classification/'
|
||||
# }))
|
||||
|
||||
file_client_args = dict(backend='disk')
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(type='RandomResizedCrop', size=224),
|
||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label'])
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(type='Resize', scale=(256, -1), keep_ratio=True),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
dict(type='PackClsInputs')
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=32,
|
||||
workers_per_gpu=2,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=2,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/val',
|
||||
ann_file='data/imagenet/meta/val.txt',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
# replace `data/val` with `data/test` for standard test
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=2,
|
||||
persistent_workers=True,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/val',
|
||||
ann_file='data/imagenet/meta/val.txt',
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
evaluation = dict(interval=1, metric='accuracy')
|
||||
|
|
|
@ -27,6 +27,10 @@ env_cfg = dict(
|
|||
dist_cfg=dict(backend='nccl'),
|
||||
)
|
||||
|
||||
vis_backends = [dict(type='LocalVisBackend')]
|
||||
visualizer = dict(
|
||||
type='ClsVisualizer', vis_backends=vis_backends, name='visualizer')
|
||||
|
||||
# Log level configuration
|
||||
log_level = 'INFO'
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
@ -10,6 +10,31 @@ from mmcls.core import ClsDataSample
|
|||
from mmcls.registry import VISUALIZERS
|
||||
|
||||
|
||||
def _get_adaptive_scale(img_shape: Tuple[int, int],
                        min_scale: float = 0.3,
                        max_scale: float = 3.0) -> float:
    """Get adaptive scale according to image shape.

    The target scale depends on the short edge length of the image. If the
    short edge length equals 224, the output is 1.0, and the output scales
    linearly with the short edge length.

    You can also specify the minimum scale and the maximum scale to limit the
    linear scale.

    Args:
        img_shape (Tuple[int, int]): The shape of the canvas image. Only the
            first two entries are used, so a full array shape with a
            trailing channel dimension is also accepted.
        min_scale (float): The minimum scale. Defaults to 0.3.
        max_scale (float): The maximum scale. Defaults to 3.0.

    Returns:
        float: The adaptive scale, clamped to ``[min_scale, max_scale]``.
    """
    # Restrict to the first two entries so that a trailing channel count
    # (e.g. 3) is never mistaken for the short edge of the image.
    short_edge_length = min(img_shape[:2])
    scale = short_edge_length / 224.
    return min(max(scale, min_scale), max_scale)
|
||||
|
||||
|
||||
@VISUALIZERS.register_module()
|
||||
class ClsVisualizer(Visualizer):
|
||||
"""Universal Visualizer for classification task.
|
||||
|
@ -62,6 +87,7 @@ class ClsVisualizer(Visualizer):
|
|||
draw_gt: bool = True,
|
||||
draw_pred: bool = True,
|
||||
draw_score: bool = True,
|
||||
rescale_factor: Optional[float] = None,
|
||||
show: bool = False,
|
||||
text_cfg: dict = dict(),
|
||||
wait_time: float = 0,
|
||||
|
@ -86,6 +112,8 @@ class ClsVisualizer(Visualizer):
|
|||
Default to True.
|
||||
draw_score (bool): Whether to draw the prediction scores
|
||||
of prediction categories. Default to True.
|
||||
rescale_factor (float, optional): Rescale the image by the rescale
|
||||
factor before visualization. Defaults to None.
|
||||
show (bool): Whether to display the drawn image. Default to False.
|
||||
text_cfg (dict): Extra text setting, which accepts
|
||||
arguments of :attr:`mmengine.Visualizer.draw_texts`.
|
||||
|
@ -101,6 +129,9 @@ class ClsVisualizer(Visualizer):
|
|||
if self.dataset_meta is not None:
|
||||
classes = self.dataset_meta.get('CLASSES', None)
|
||||
|
||||
if rescale_factor is not None:
|
||||
image = mmcv.imrescale(image, rescale_factor)
|
||||
|
||||
texts = []
|
||||
self.set_image(image)
|
||||
|
||||
|
@ -134,8 +165,10 @@ class ClsVisualizer(Visualizer):
|
|||
prefix = 'Prediction: '
|
||||
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
|
||||
|
||||
img_scale = _get_adaptive_scale(image.shape[:2])
|
||||
text_cfg = {
|
||||
'positions': np.array([(5, 5)]),
|
||||
'positions': np.array([(img_scale * 5, ) * 2]).astype(np.int32),
|
||||
'font_sizes': int(img_scale * 7),
|
||||
'font_families': 'monospace',
|
||||
'colors': 'white',
|
||||
'bboxes': dict(facecolor='black', alpha=0.5, boxstyle='Round'),
|
||||
|
|
|
@ -94,5 +94,39 @@ class TestClsVisualizer(TestCase):
|
|||
self.vis.add_datasample(
|
||||
'test', image=image, data_sample=data_sample)
|
||||
|
||||
# Test adaptive font size
|
||||
def assert_font_size(target_size):
|
||||
|
||||
def draw_texts(text, font_sizes, *_, **__):
|
||||
self.assertEqual(font_sizes, target_size)
|
||||
|
||||
return draw_texts
|
||||
|
||||
with patch.object(self.vis, 'draw_texts', assert_font_size(7)):
|
||||
self.vis.add_datasample(
|
||||
'test',
|
||||
image=np.ones((224, 384, 3), np.uint8),
|
||||
data_sample=data_sample)
|
||||
|
||||
with patch.object(self.vis, 'draw_texts', assert_font_size(2)):
|
||||
self.vis.add_datasample(
|
||||
'test',
|
||||
image=np.ones((10, 384, 3), np.uint8),
|
||||
data_sample=data_sample)
|
||||
|
||||
with patch.object(self.vis, 'draw_texts', assert_font_size(21)):
|
||||
self.vis.add_datasample(
|
||||
'test',
|
||||
image=np.ones((1000, 1000, 3), np.uint8),
|
||||
data_sample=data_sample)
|
||||
|
||||
# Test rescale image
|
||||
with patch.object(self.vis, 'draw_texts', assert_font_size(14)):
|
||||
self.vis.add_datasample(
|
||||
'test',
|
||||
image=np.ones((224, 384, 3), np.uint8),
|
||||
rescale_factor=2.,
|
||||
data_sample=data_sample)
|
||||
|
||||
def tearDown(self):
|
||||
self.tmpdir.cleanup()
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import itertools
|
||||
import os.path as osp
|
||||
import sys
|
||||
|
||||
import mmcv
|
||||
from mmcv import Config, DictAction
|
||||
|
||||
from mmcls.datasets.builder import build_dataset
|
||||
from mmcls.registry import VISUALIZERS
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
|
||||
def parse_args():
    """Define the command-line options of the dataset browser and parse them.

    Returns:
        argparse.Namespace: The parsed options: ``config``, ``output_dir``,
        ``not_show``, ``phase``, ``show_number``, ``show_interval``,
        ``rescale_factor`` and ``cfg_options``.
    """
    # Long help messages are spelled out first to keep the option
    # declarations below compact.
    phase_help = (
        'phase of dataset to visualize, accept "train" "test" and "val".'
        ' Default train.')
    show_number_help = (
        'number of images selected to visualize, must bigger than 0. if '
        'the number is bigger than length of dataset, show all the images in '
        'dataset; default "sys.maxsize", show all images in dataset')
    rescale_help = (
        'image rescale factor, which is useful if the output is too '
        'large or too small.')
    cfg_options_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')

    cli = argparse.ArgumentParser(description='Browse a dataset')
    cli.add_argument('config', help='train config file path')
    cli.add_argument(
        '--output-dir',
        type=str,
        default='./outputs',
        help='If there is no display interface, you can save it')
    cli.add_argument('--not-show', action='store_true', default=False)
    cli.add_argument(
        '--phase',
        type=str,
        default='train',
        choices=['train', 'test', 'val'],
        help=phase_help)
    cli.add_argument(
        '--show-number',
        type=int,
        default=sys.maxsize,
        help=show_number_help)
    cli.add_argument(
        '--show-interval',
        type=float,
        default=2,
        help='the interval of show (s)')
    cli.add_argument('--rescale-factor', type=float, help=rescale_help)
    cli.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help=cfg_options_help)
    return cli.parse_args()
|
||||
|
||||
|
||||
def main():
    """Browse a dataset by drawing its samples with the configured visualizer.

    Loads the config given on the command line, builds the dataset of the
    selected phase (``{phase}_dataloader.dataset``) and the visualizer from
    ``cfg.visualizer``, then passes up to ``--show-number`` samples to
    ``visualizer.add_datasample``, optionally showing them and writing each
    one to ``--output-dir`` under the image's base file name.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # register all modules in mmcls into the registries
    register_all_modules()
    dataloader = cfg[f'{args.phase}_dataloader']
    dataset = build_dataset(dataloader.dataset)

    cfg.visualizer.save_dir = args.output_dir
    visualizer = VISUALIZERS.build(cfg.visualizer)
    visualizer.dataset_meta = dataset.metainfo

    # Never try to display more samples than the dataset holds.
    display_number = min(args.show_number, len(dataset))
    progress_bar = mmcv.ProgressBar(display_number)

    for item in itertools.islice(dataset, display_number):
        # Move the leading axis to the end before converting to numpy
        # (presumably CHW -> HWC; confirm against the packing transform).
        img = item['inputs'].permute(1, 2, 0).numpy()
        data_sample = item['data_sample'].numpy()
        # The base file name serves as both the sample name and the output
        # file name, so compute it only once.
        img_name = osp.basename(item['data_sample'].img_path)

        if args.output_dir is not None:
            out_file = osp.join(args.output_dir, img_name)
        else:
            out_file = None

        img = img[..., [2, 1, 0]]  # bgr to rgb
        visualizer.add_datasample(
            img_name,
            img,
            data_sample,
            rescale_factor=args.rescale_factor,
            show=not args.not_show,
            wait_time=args.show_interval,
            out_file=out_file)

        progress_bar.update()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue