# Source: mmfewshot/tools/detection/misc/visualize_saved_dataset.py
# (270 lines, 10 KiB, Python)

# Copyright (c) OpenMMLab. All rights reserved.
"""Visualize instances of a saved dataset.

Example:
    python tools/detection/misc/visualize_saved_dataset.py \
        --src ./work_dirs/xx_saved_data.json \
        --dir ./vis_images
"""
import argparse
import json
import os
import mmcv
import numpy as np
from terminaltables import AsciiTable
# cv2 is used below for reading, drawing on, resizing and writing images, so
# fail early with an actionable hint when it is missing. `except ()` would
# match no exception at all, hence the explicit `ImportError`.
try:
    import cv2
except ImportError as err:
    raise ImportError('please install cv2 manually') from err
class Visualizer:
    """Visualize instances of saved dataset.

    Args:
        src (str): Path to saved dataset.
        out_dir (str): Saving directory for output image. Default: ''.
        classes (list[str]): Classes of saved dataset. Default: None.
        img_prefix (str): Prefix for images path. Default: ''.
    """

    def __init__(self, src, out_dir='', classes=None, img_prefix=''):
        if classes is None:
            classes = []
        self.CLASSES = classes
        self.img_prefix = img_prefix
        self.ann_file = src
        self.out_dir = out_dir
        # Loading may overwrite `CLASSES` and `img_prefix` with the meta
        # information stored in the json file, so it must run before the
        # color map is sized from `CLASSES`.
        self.data_infos = self.load_annotations_saved(src)
        mmcv.mkdir_or_exist(os.path.abspath(out_dir))
        # one random color per class, used for the bbox outlines
        self.color_map = np.random.randint(
            0, 255, (len(self.CLASSES), 3)).tolist()

    def __repr__(self):
        """Return dataset statistics with a table of instance counts.

        Returns:
            str: Number of images followed by a category/count table. When
                no class names are available only a short notice is added.
        """
        result = (f'\n dataset statistics '
                  f'with number of images {len(self.data_infos)}, '
                  f'and instance counts: \n')
        if len(self.CLASSES) == 0:
            result += 'Category names are not provided. \n'
            return result
        # the extra last slot counts images without any instance
        instance_count = np.zeros(len(self.CLASSES) + 1).astype(int)
        # count the instance number in each image
        for data_info in self.data_infos:
            label = data_info['ann']['labels']
            unique, counts = np.unique(label, return_counts=True)
            if len(unique) > 0:
                # add the occurrence number to each class
                instance_count[unique] += counts
            else:
                # background is the last index
                instance_count[-1] += 1
        # create a table with category count, five pairs per row
        pairs_per_row = 5
        cells_per_row = 2 * pairs_per_row
        table_data = [['category', 'count'] * pairs_per_row]
        row_data = []
        for cls, count in enumerate(instance_count):
            if cls < len(self.CLASSES):
                row_data += [f'{cls} [{self.CLASSES[cls]}]', f'{count}']
            else:
                # add the background number
                row_data += ['-1 background', f'{count}']
            if len(row_data) == cells_per_row:
                table_data.append(row_data)
                row_data = []
        if row_data:
            # keep the trailing, partially filled row instead of dropping
            # it; pad with empty cells so every row has the same width
            row_data += [''] * (cells_per_row - len(row_data))
            table_data.append(row_data)
        table = AsciiTable(table_data)
        result += table.table
        return result

    def visualize(self,
                  save_name='',
                  num_rows=20,
                  num_cols=10,
                  instance_size=256):
        """Visualizing cropped instances in grid layout.

        Instances of same classes will be placed in same row, If the number
        of total classes is larger than `num_rows`, then it will output
        multiple images. If the number of instances is larger than
        `num_cols`, then the exceeded instances will not be visualized.

        Args:
            save_name (str): Name of output image. When empty, the base
                name of the annotation file (text before the first '.')
                is used. Default: ''.
            num_rows (int): Number of rows (classes) in single output image.
                Default: 20.
            num_cols (int): Number of column (instances of each class) in
                single output image. Default: 10.
            instance_size (int): Size of cropped instance. Default: 256.

        Raises:
            ValueError: If `num_rows` is smaller than 1.
            OSError: If an image containing instances can not be read.
        """
        if num_rows < 1:
            raise ValueError(f'num_rows must be >= 1, but got {num_rows}')
        if save_name == '':
            save_name = os.path.split(self.ann_file)[1].split('.')[0]
        instances = {i: [] for i in range(len(self.CLASSES))}
        for data_info in self.data_infos:
            labels = data_info['ann']['labels']
            bboxes = data_info['ann']['bboxes']
            if len(labels) == 0:
                # nothing to crop from this image, so skip decoding it
                continue
            image_path = os.path.join(self.img_prefix, data_info['filename'])
            img = cv2.imread(image_path)
            if img is None:
                # cv2.imread signals failure by returning None
                raise OSError(f'cannot read image: {image_path}')
            # draw every box first so that each crop also shows the
            # outlines of neighbouring instances
            for bbox, label in zip(bboxes, labels):
                bbox = list(map(int, bbox))
                cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
                              self.color_map[int(label)], 2)
            for bbox, label in zip(bboxes, labels):
                cropped_instance = self.crop_instance(img, bbox,
                                                      instance_size)
                instances[int(label)].append(cropped_instance)
        classes = list(instances.keys())
        # write one output image per chunk of `num_rows` classes
        for start in range(0, len(classes), num_rows):
            end = min(start + num_rows, len(classes))
            instances_ = {i: instances[i] for i in classes[start:end]}
            image_name = f'{save_name}_class_{start}_to_{end}.jpg'
            img_out = self.concat_images(instances_, num_cols, instance_size)
            image_path = os.path.join(self.out_dir, image_name)
            cv2.imwrite(image_path, img_out)
            print(f'image {image_name} is saved to {self.out_dir}')

    @staticmethod
    def crop_instance(img, bbox, instance_size=256):
        """Crop and resize instance.

        The box is extended to a square region around the instance. A side
        shorter than `instance_size` is padded up to `instance_size`,
        otherwise 10% of its length is added on both ends. Parts of the
        region outside the image stay black.

        Args:
            img (numpy.ndarray): Image to be cropped.
            bbox (list[int]): BBox of instance, indexed as (x0, y0, x1, y1).
            instance_size (int): Resize cropped instance to `instance_size`.
                Default: 256.

        Returns:
            numpy.ndarray: cropped instance.
        """
        bbox = list(map(int, bbox))
        img_h, img_w = img.shape[0], img.shape[1]
        box_h = bbox[3] - bbox[1]
        box_w = bbox[2] - bbox[0]

        def _padding(length):
            # pad small sides up to `instance_size`, large ones by 10%
            if length < instance_size:
                return (instance_size - length) // 2
            return int(length * 0.1)

        # the longer side defines the square; the shorter side receives
        # whatever padding is needed to reach the same size
        if box_h > box_w:
            pad_y = _padding(box_h)
            target_size = pad_y * 2 + box_h
            pad_x = (target_size - box_w) // 2
        else:
            pad_x = _padding(box_w)
            target_size = pad_x * 2 + box_w
            pad_y = (target_size - box_h) // 2
        region = np.zeros((target_size + 4, target_size + 4, 3), np.uint8)
        # clip the padded box to the image borders
        y0 = max(bbox[1] - pad_y, 0)
        y1 = min(bbox[3] + pad_y, img_h)
        x0 = max(bbox[0] - pad_x, 0)
        x1 = min(bbox[2] + pad_x, img_w)
        # center the clipped crop inside the square region
        t_y0 = (target_size - (y1 - y0)) // 2
        t_y1 = t_y0 + (y1 - y0)
        t_x0 = (target_size - (x1 - x0)) // 2
        t_x1 = t_x0 + (x1 - x0)
        region[t_y0:t_y1, t_x0:t_x1] = img[y0:y1, x0:x1]
        return cv2.resize(region, (instance_size, instance_size))

    def concat_images(self, cropped_images, num_cols, offset):
        """Concat cropped images in grid layout.

        Each row starts with a white cell showing the class name, followed
        by up to `num_cols` instances of that class; missing instances are
        left black.

        Args:
            cropped_images (dict): Images of cropped instance, keyed by
                class index.
            num_cols (int): Number of columns of grid layout.
            offset (int): Size of each cropped instance.

        Returns:
            numpy.ndarray: visualized image.
        """
        num_rows = len(cropped_images)
        # one additional leading column holds the class name
        total_cols = num_cols + 1
        canvas = np.zeros((num_rows * offset, total_cols * offset, 3),
                          np.uint8)
        for row, (label, images) in enumerate(cropped_images.items()):
            top, bottom = offset * row, offset * (row + 1)
            label_image = np.full((offset, offset, 3), 255, dtype=np.uint8)
            label_image = cv2.putText(label_image, self.CLASSES[label],
                                      (10, offset // 2),
                                      cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 0),
                                      2, cv2.LINE_AA)
            canvas[top:bottom, 0:offset, :] = label_image
            for col in range(1, total_cols):
                if col > len(images):
                    img = np.zeros((offset, offset, 3), np.uint8)
                else:
                    img = images[col - 1]
                canvas[top:bottom, offset * col:offset * (col + 1), :] = img
        return canvas

    def load_annotations_saved(self, ann_file):
        """Load data_infos from saved json.

        An entry containing the key 'CLASSES' is treated as meta
        information: it overwrites `self.CLASSES` (and `self.img_prefix`
        when present) and is removed from the returned list. List valued
        annotation fields of all other entries are converted to numpy
        arrays, float32 for keys containing 'box' and int64 otherwise.

        Args:
            ann_file (str): Path to the saved json file.

        Returns:
            list[dict]: Loaded data infos without the meta entry.
        """
        # JSON text is UTF-8 encoded, do not depend on the locale default
        with open(ann_file, encoding='utf-8') as f:
            data_infos = json.load(f)
        meta_index = None
        for i, data_info in enumerate(data_infos):
            if 'CLASSES' in data_info:
                self.CLASSES = data_info['CLASSES']
                if 'img_prefix' in data_info:
                    self.img_prefix = data_info['img_prefix']
                meta_index = i
                continue
            ann = data_info['ann']
            for k in ann:
                if not isinstance(ann[k], list):
                    continue
                if len(ann[k]) == 0 and k == 'bboxes_ignore':
                    # keep the (0, 4) box layout for empty ignore boxes
                    ann[k] = np.zeros((0, 4))
                else:
                    ann[k] = np.array(ann[k])
                if 'box' in k:
                    ann[k] = ann[k].astype(np.float32)
                else:
                    ann[k] = ann[k].astype(np.int64)
        assert len(self.CLASSES) > 0, 'missing CLASSES for saved dataset json.'
        if meta_index is not None:
            data_infos.pop(meta_index)
        return data_infos
def parse_args():
    """Parse the command line arguments of the visualization script.

    `--src` is mandatory because the dataset file is always opened. `--dir`
    defaults to an empty string so that a missing option no longer yields
    `None`, which path handling can not process.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes `src`,
            `dir`, `save_name`, `row` and `col`.
    """
    parser = argparse.ArgumentParser(
        description='Visualize a saved FewShot Dataset')
    parser.add_argument(
        '--src',
        type=str,
        required=True,
        help='saved few shot dataset file path')
    parser.add_argument(
        '--dir',
        type=str,
        default='',
        help='output dir to save visualize images')
    parser.add_argument(
        '--save-name',
        default='',
        type=str,
        help='saved name of visualize images')
    parser.add_argument(
        '--row',
        default=20,
        type=int,
        help='number of classes to show in one image')
    parser.add_argument(
        '--col',
        default=10,
        type=int,
        help='number of instance to show for each class')
    return parser.parse_args()
def main():
    """Run the script: load the dataset, print its statistics, save grids."""
    cli_args = parse_args()
    vis = Visualizer(cli_args.src, cli_args.dir)
    # printing the visualizer shows the statistics built by its __repr__
    print(vis)
    vis.visualize(
        cli_args.save_name, num_rows=cli_args.row, num_cols=cli_args.col)


if __name__ == '__main__':
    main()