[Feature]: support image visualization for tensorboard and wandb (#15)

* [Feature]: support image visualization for tensorboard and wandb
branch: pull/21/head^2
Cathy0908 2022-04-21 20:48:58 +08:00 committed by GitHub
parent 3a4a3a9d0c
commit 9a3826f0d2
20 changed files with 487 additions and 104 deletions

View File

@@ -1,2 +1,3 @@
 recursive-include easycv/configs *.py
 recursive-include easycv/tools *.py
+recursive-include easycv/resource/ *.ttf

View File

@@ -23,4 +23,5 @@ data = dict(
             dict(type='CenterCrop', size=224),
             dict(type='ToTensor'),
             dict(type='Normalize', **img_norm_cfg),
+            dict(type='Collect', keys=['img', 'gt_labels'])
         ]))

View File

@@ -130,7 +130,14 @@ custom_hooks = [
 ]
 # evaluation
-eval_config = dict(interval=10, gpu_collect=False)
+eval_config = dict(
+    interval=10,
+    gpu_collect=False,
+    visualization_config=dict(
+        vis_num=10,
+        score_thr=0.5,
+    )  # shown by TensorboardLoggerHookV2 and WandbLoggerHookV2
+)
 eval_pipelines = [
     dict(
         mode='test',
@@ -168,7 +175,8 @@ log_config = dict(
     interval=100,
     hooks=[
         dict(type='TextLoggerHook'),
-        dict(type='TensorboardLoggerHook')
+        dict(type='TensorboardLoggerHookV2'),
+        # dict(type='WandbLoggerHookV2'),
     ])
 export = dict(use_jit=False)
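The keys under `visualization_config` are forwarded unchanged as keyword arguments to the dataset's visualize() method on each evaluation round. A minimal sketch of the effective call for the config above (`dataset` and `results` are supplied by the evaluation loop inside EvalHook):

    # vis_num / score_thr come from visualization_config
    output = dataset.visualize(results, vis_num=10, score_thr=0.5)
    # output == {'images': [...], 'img_metas': [...]} is buffered on the
    # runner and written out by TensorboardLoggerHookV2 / WandbLoggerHookV2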

View File

@@ -1,4 +1,4 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from .image import imshow_bboxes, imshow_keypoints
+from .image import imshow_bboxes, imshow_keypoints, imshow_label

-__all__ = ['imshow_bboxes', 'imshow_keypoints']
+__all__ = ['imshow_bboxes', 'imshow_keypoints', 'imshow_label']

View File

@@ -1,11 +1,103 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 # Adapt from https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/visualization/image.py
 import math
+import os
+from os.path import dirname as opd

 import cv2
 import mmcv
 import numpy as np
 from mmcv.utils.misc import deprecated_api_warning
+from PIL import Image, ImageDraw, ImageFont
+
+
+def get_font_path():
+    root_path = opd(opd(opd(os.path.realpath(__file__))))
+    # find in whl
+    find_path_whl = os.path.join(root_path, 'resource/simhei.ttf')
+    # find in source code
+    find_path_source = os.path.join(opd(root_path), 'resource/simhei.ttf')
+    if os.path.exists(find_path_whl):
+        return find_path_whl
+    elif os.path.exists(find_path_source):
+        return find_path_source
+    else:
+        raise ValueError('Font file not found in either %s or %s' %
+                         (find_path_whl, find_path_source))
+
+
+_FONT_PATH = get_font_path()
+
+
+def put_text(img, xy, text, fill, size=20):
+    """Draw text on an image, with Chinese text supported."""
+    img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+    draw = ImageDraw.Draw(img)
+    fontText = ImageFont.truetype(_FONT_PATH, size=size, encoding='utf-8')
+    draw.text(xy, text, fill=fill, font=fontText)
+    img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
+    return img
+
+
+def imshow_label(img,
+                 labels,
+                 text_color='blue',
+                 font_size=20,
+                 thickness=1,
+                 font_scale=0.5,
+                 intervel=5,
+                 show=True,
+                 win_name='',
+                 wait_time=0,
+                 out_file=None):
+    """Draw labels on an image.
+
+    Args:
+        img (str or ndarray): The image to be displayed.
+        labels (str or list[str]): labels of each image.
+        text_color (str or tuple or :obj:`Color`): Color of texts.
+        font_size (int): Size of font.
+        thickness (int): Thickness of lines.
+        font_scale (float): Font scales of texts.
+        intervel (int): interval pixels between multiple labels.
+        show (bool): Whether to show the image.
+        win_name (str): The window name.
+        wait_time (int): Value of waitKey param.
+        out_file (str, optional): The filename to write the image.
+
+    Returns:
+        ndarray: The image with labels drawn on it.
+    """
+    img = mmcv.imread(img)
+    img = np.ascontiguousarray(img)
+    labels = [labels] if isinstance(labels, str) else labels
+
+    cur_height = 0
+    for label in labels:
+        # roughly estimate the proper font size
+        text_size, text_baseline = cv2.getTextSize(label,
+                                                   cv2.FONT_HERSHEY_DUPLEX,
+                                                   font_scale, thickness)
+        org = (text_baseline + text_size[1],
+               text_baseline + text_size[1] + cur_height)
+        # support chinese text
+        # TODO: Unify the font of cv2 and PIL, and auto get font_size according to the font_scale
+        img = put_text(img, org, text=label, fill=text_color, size=font_size)
+        # cv2.putText(img, label, org, cv2.FONT_HERSHEY_DUPLEX, font_scale,
+        #             mmcv.color_val(text_color), thickness)
+        cur_height += text_baseline + text_size[1] + intervel
+
+    if show:
+        mmcv.imshow(img, win_name, wait_time)
+    if out_file is not None:
+        mmcv.imwrite(img, out_file)
+    return img


 def imshow_bboxes(img,
@@ -13,6 +105,7 @@ def imshow_bboxes(img,
                   labels=None,
                   colors='green',
                   text_color='white',
+                  font_size=20,
                   thickness=1,
                   font_scale=0.5,
                   show=True,
@@ -29,6 +122,7 @@ def imshow_bboxes(img,
         labels (str or list[str], optional): labels of each bbox.
         colors (list[str or tuple or :obj:`Color`]): A list of colors.
         text_color (str or tuple or :obj:`Color`): Color of texts.
+        font_size (int): Size of font.
         thickness (int): Thickness of lines.
         font_scale (float): Font scales of texts.
         show (bool): Whether to show the image.
@@ -58,11 +152,10 @@ def imshow_bboxes(img,
             out_file=None)

     if labels is not None:
-        if not isinstance(labels, list):
-            labels = [labels for _ in range(len(bboxes))]
         assert len(labels) == len(bboxes)
         for bbox, label, color in zip(bboxes, labels, colors):
+            label = str(label)
             bbox_int = bbox[0, :4].astype(np.int32)
             # roughly estimate the proper font size
             text_size, text_baseline = cv2.getTextSize(label,
@@ -74,9 +167,17 @@ def imshow_bboxes(img,
             text_y2 = text_y1 + text_size[1] + text_baseline
             cv2.rectangle(img, (text_x1, text_y1), (text_x2, text_y2), color,
                           cv2.FILLED)
-            cv2.putText(img, label, (text_x1, text_y2 - text_baseline),
-                        cv2.FONT_HERSHEY_DUPLEX, font_scale,
-                        mmcv.color_val(text_color), thickness)
+            # cv2.putText(img, label, (text_x1, text_y2 - text_baseline),
+            #             cv2.FONT_HERSHEY_DUPLEX, font_scale,
+            #             mmcv.color_val(text_color), thickness)
+            # support chinese text
+            # TODO: Unify the font of cv2 and PIL, and auto get font_size according to the font_scale
+            img = put_text(
+                img, (text_x1, text_y1),
+                text=label,
+                fill=text_color,
+                size=font_size)

     if show:
         mmcv.imshow(img, win_name, wait_time)
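A quick usage sketch of the new helpers (output file names are arbitrary; the bundled SimHei font is resolved by get_font_path() above):

    import numpy as np

    from easycv.core.visualization import imshow_bboxes, imshow_label

    img = np.zeros((100, 100, 3), dtype=np.uint8)  # dummy BGR image
    # mixed Chinese/English labels are rendered through PIL + simhei.ttf
    imshow_label(img=img, labels=['标签 1', 'label 2'], show=False,
                 out_file='labels_vis.png')

    bboxes = np.array([[10, 10, 30, 30], [10, 50, 30, 80]], dtype=np.float32)
    imshow_bboxes(img=img, bboxes=bboxes, labels=['猫', 'dog'], show=False,
                  out_file='bboxes_vis.png')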

View File

@@ -3,6 +3,7 @@
 import torch
 from PIL import Image

+from easycv.core.visualization.image import imshow_label
 from easycv.datasets.registry import DATASETS
 from easycv.datasets.shared.base import BaseDataset

@@ -59,3 +60,37 @@ class ClsDataset(BaseDataset):

         eval_res = evaluators[0].evaluate(results, gt_labels)
         return eval_res
+
+    def visualize(self, results, vis_num=10, **kwargs):
+        """Visualize the model output on validation data.
+
+        Args:
+            results: A dictionary containing
+                class: List of length number of test images.
+                img_metas: List of length number of test images,
+                    dict of image meta info, containing filename, img_shape,
+                    origin_img_shape and so on.
+            vis_num: number of images to visualize
+
+        Returns: A dictionary containing
+            images: Visualized images, list of np.ndarray.
+            img_metas: List of length number of test images,
+                dict of image meta info, containing filename, img_shape,
+                origin_img_shape and so on.
+        """
+        vis_imgs = []
+
+        # TODO: support img_metas for torch.jit
+        if results.get('img_metas', None) is None:
+            return {}
+
+        img_metas = results['img_metas'][:vis_num]
+        labels = results['class']
+        for i, img_meta in enumerate(img_metas):
+            filename = img_meta['filename']
+            vis_img = imshow_label(img=filename, labels=labels, show=False)
+            vis_imgs.append(vis_img)
+
+        output = {'images': vis_imgs, 'img_metas': img_metas}
+        return output
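A hedged sketch of the `results` dict this method consumes (keys follow the docstring; the file paths are illustrative and must exist on disk):

    results = {
        'class': ['cat', 'dog'],                   # predicted label(s) per image
        'img_metas': [{'filename': 'imgs/0.jpg'},  # meta dict per image
                      {'filename': 'imgs/1.jpg'}],
    }
    output = dataset.visualize(results, vis_num=2)  # dataset: a built ClsDataset
    # output['images'] is a list of np.ndarray rendered by imshow_label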

View File

@@ -9,14 +9,13 @@ import numpy as np
 import torch

 from easycv.datasets.registry import DATASETS, PIPELINES
-from easycv.datasets.shared.base import BaseDataset
 from easycv.utils import build_from_cfg
-from easycv.utils.bbox_util import batched_xyxy2cxcywh_with_shape
 from easycv.utils.bbox_util import xyxy2xywh as xyxy2cxcywh
+from .raw import DetDataset


 @DATASETS.register_module
-class DetImagesMixDataset(BaseDataset):
+class DetImagesMixDataset(DetDataset):
     """A wrapper of multiple images mixed dataset.

     Suitable for training on multiple images mixed data augmentation like
@@ -50,7 +49,7 @@ class DetImagesMixDataset(BaseDataset):
                  label_padding=True):

         super(DetImagesMixDataset, self).__init__(
-            data_source, pipeline, profiling=profiling)
+            data_source, pipeline, profiling=profiling, classes=classes)

         if skip_type_keys is not None:
             assert all([
@@ -70,10 +69,9 @@ class DetImagesMixDataset(BaseDataset):
         else:
             raise TypeError('pipeline must be a dict')

-        self.CLASSES = classes
         if hasattr(self.data_source, 'flag'):
             self.flag = self.data_source.flag
-        self.num_samples = self.data_source.get_length()

         if dynamic_scale is not None:
             assert isinstance(dynamic_scale, tuple)
@@ -83,9 +81,6 @@ class DetImagesMixDataset(BaseDataset):
         self.label_padding = label_padding
         self.max_labels_num = 120

-    def __len__(self):
-        return self.num_samples
-
     def __getitem__(self, idx):
         results = copy.deepcopy(self.data_source.get_sample(idx))
         for (transform, transform_type) in zip(self.pipeline_yolox,
@@ -116,21 +111,8 @@ class DetImagesMixDataset(BaseDataset):
         if 'img_scale' in results:
             results.pop('img_scale')

-        # print(result.keys())
-        # if self.yolo_format:
-        #     # print(type(results['img_metas']), results['img_metas'])
-        #     # print(type(results['img_metas']._data), results['img_metas']._data)
-        #     img_shape = results['img_metas']._data['img_shape'][:2]
-        #     # print(type(results['gt_bboxes']))
-        #     gt_bboxes = xyxy2cxcywh_with_shape(results['gt_bboxes']._data, img_shape)
-        #     results['gt_bboxes'] = gt_bboxes.float()
-
         if self.label_padding:
             cxcywh_gt_bboxes = xyxy2cxcywh(results['gt_bboxes']._data)
-            # cxcywh_gt_bboxes = results['gt_bboxes']._data
             padded_gt_bboxes = torch.zeros((self.max_labels_num, 4),
                                            device=cxcywh_gt_bboxes.device)
             padded_gt_bboxes[range(cxcywh_gt_bboxes.shape[0])[:self.max_labels_num]] = \
@@ -146,9 +128,6 @@ class DetImagesMixDataset(BaseDataset):
             results['gt_bboxes'] = padded_gt_bboxes
             results['gt_labels'] = padded_labels

-        # ['img_metas', 'img', 'gt_bboxes', 'gt_labels']
-        # results.pop('img_metas')
-        # print(results['img_metas'], "hhh", idx)
         return results

     def update_skip_type_keys(self, skip_type_keys):
@@ -240,30 +219,3 @@ class DetImagesMixDataset(BaseDataset):
                 tmp_dir = None
         result_files = self.results2json(results, jsonfile_prefix)
         return result_files, tmp_dir
-
-    def evaluate(self, results, evaluators=None, logger=None):
-        '''results: a dict of list of Tensors, list length equals to number of test images
-        '''
-        eval_result = dict()
-
-        groundtruth_dict = {}
-        groundtruth_dict['groundtruth_boxes'] = [
-            batched_xyxy2cxcywh_with_shape(
-                self.data_source.get_ann_info(idx)['bboxes'],
-                results['img_metas'][idx]['ori_img_shape'])
-            for idx in range(len(results['img_metas']))
-        ]
-        groundtruth_dict['groundtruth_classes'] = [
-            self.data_source.get_ann_info(idx)['labels']
-            for idx in range(len(results['img_metas']))
-        ]
-        groundtruth_dict['groundtruth_is_crowd'] = [
-            self.data_source.get_ann_info(idx)['groundtruth_is_crowd']
-            for idx in range(len(results['img_metas']))
-        ]
-
-        for evaluator in evaluators:
-            eval_result.update(evaluator.evaluate(results, groundtruth_dict))
-        # print(eval_result)
-        return eval_result

View File

@@ -1,8 +1,10 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import numpy as np

+from easycv.core.visualization.image import imshow_bboxes
 from easycv.datasets.registry import DATASETS
 from easycv.datasets.shared.base import BaseDataset
+from easycv.utils.bbox_util import batched_xyxy2cxcywh_with_shape


 @DATASETS.register_module
@@ -19,58 +21,124 @@ class DetDataset(BaseDataset):
             classes: A list of class names, used in evaluation for result and groundtruth visualization
         """
         self.classes = classes
+        self.CLASSES = classes
         super(DetDataset, self).__init__(
             data_source, pipeline, profiling=profiling)
-        self.img_num = self.data_source.get_length()
+        self.num_samples = self.data_source.get_length()
+
+    def __len__(self):
+        return self.num_samples

     def __getitem__(self, idx):
         data_dict = self.data_source.get_sample(idx)
         data_dict = self.pipeline(data_dict)
         return data_dict

-    def evaluate(self, results, evaluators, logger=None):
-        '''results: a dict of list of Tensors, list length equals to number of test images
+    def evaluate(self, results, evaluators=None, logger=None):
+        '''Evaluates the detection boxes.
+        Args:
+            results: A dictionary containing
+                detection_boxes: List of length number of test images.
+                    Float32 numpy array of shape [num_boxes, 4] and
+                    format [ymin, xmin, ymax, xmax] in absolute image coordinates.
+                detection_scores: List of length number of test images,
+                    detection scores for the boxes, float32 numpy array of shape [num_boxes].
+                detection_classes: List of length number of test images,
+                    integer numpy array of shape [num_boxes]
+                    containing 1-indexed detection classes for the boxes.
+                img_metas: List of length number of test images,
+                    dict of image meta info, containing filename, img_shape,
+                    origin_img_shape, scale_factor and so on.
+            evaluators: evaluators to calculate metric with results and groundtruth_dict
         '''
         eval_result = dict()
-        annotations = self.data_source.get_labels()

         groundtruth_dict = {}
         groundtruth_dict['groundtruth_boxes'] = [
-            labels[:,
-                   1:] if len(labels) > 0 else np.array([], dtype=np.float32)
-            for labels in annotations
+            batched_xyxy2cxcywh_with_shape(
+                self.data_source.get_ann_info(idx)['bboxes'],
+                results['img_metas'][idx]['ori_img_shape'])
+            for idx in range(len(results['img_metas']))
         ]
         groundtruth_dict['groundtruth_classes'] = [
-            labels[:, 0] if len(labels) > 0 else np.array([], dtype=np.float32)
-            for labels in annotations
+            self.data_source.get_ann_info(idx)['labels']
+            for idx in range(len(results['img_metas']))
+        ]
+        groundtruth_dict['groundtruth_is_crowd'] = [
+            self.data_source.get_ann_info(idx)['groundtruth_is_crowd']
+            for idx in range(len(results['img_metas']))
         ]
-        # bboxes = [label[:, 1:] for label in annotations]
-        # scores = [label[:, 0] for label in annotations]

         for evaluator in evaluators:
             eval_result.update(evaluator.evaluate(results, groundtruth_dict))
-
-        # eval_res = {'dummy': 1.0}
-        # img = self.data_source.load_ori_img(0)
-        # num_box = results['detection_scores'][0].size(0)
-        # scores = results['detection_scores'][0].detach().cpu().numpy()
-        # bboxes = torch.cat((results['detection_boxes'][0], results['detection_scores'][0].view(num_box, 1)), axis=1).detach().cpu().numpy()
-        # labels = results['detection_classes'][0].detach().cpu().numpy().astype(np.int32)
-        # # draw bounding boxes
-        # score_th = 0.3
-        # indices = scores > score_th
-        # filter_labels = labels[indices]
-        # print([(self.classes[i], score) for i, score in zip(filter_labels, scores)])
-        # mmcv.imshow_det_bboxes(
-        #     img,
-        #     bboxes,
-        #     labels,
-        #     class_names=self.classes,
-        #     score_thr=score_th,
-        #     bbox_color='red',
-        #     text_color='black',
-        #     thickness=1,
-        #     font_scale=0.5,
-        #     show=False,
-        #     wait_time=0,
-        #     out_file='test.jpg')
         return eval_result
+
+    def visualize(self, results, vis_num=10, score_thr=0.3, **kwargs):
+        """Visualize the model output on validation data.
+
+        Args:
+            results: A dictionary containing
+                detection_boxes: List of length number of test images.
+                    Float32 numpy array of shape [num_boxes, 4] and
+                    format [ymin, xmin, ymax, xmax] in absolute image coordinates.
+                detection_scores: List of length number of test images,
+                    detection scores for the boxes, float32 numpy array of shape [num_boxes].
+                detection_classes: List of length number of test images,
+                    integer numpy array of shape [num_boxes]
+                    containing 1-indexed detection classes for the boxes.
+                img_metas: List of length number of test images,
+                    dict of image meta info, containing filename, img_shape,
+                    origin_img_shape, scale_factor and so on.
+            vis_num: number of images to visualize
+            score_thr: the threshold to filter boxes;
+                boxes with scores greater than score_thr will be kept.
+
+        Returns: A dictionary containing
+            images: Visualized images.
+            img_metas: List of length number of test images,
+                dict of image meta info, containing filename, img_shape,
+                origin_img_shape, scale_factor and so on.
+        """
+        class_names = None
+        if hasattr(self.data_source, 'CLASSES'):
+            class_names = self.data_source.CLASSES
+        elif hasattr(self.data_source, 'classes'):
+            class_names = self.data_source.classes
+
+        if class_names is not None:
+            detection_classes = []
+            for classes_id in results['detection_classes']:
+                if classes_id is None:
+                    detection_classes.append(None)
+                else:
+                    detection_classes.append(
+                        np.array([class_names[id] for id in classes_id]))
+            results['detection_classes'] = detection_classes
+
+        vis_imgs = []
+
+        img_metas = results['img_metas'][:vis_num]
+        detection_boxes = results.get('detection_boxes', [])
+        detection_scores = results.get('detection_scores', [])
+        detection_classes = results.get('detection_classes', [])
+
+        for i, img_meta in enumerate(img_metas):
+            filename = img_meta['filename']
+            bboxes = np.array(
+                []) if detection_boxes[i] is None else detection_boxes[i]
+            scores = detection_scores[i]
+            classes = detection_classes[i]
+
+            if scores is not None and score_thr > 0:
+                inds = scores > score_thr
+                bboxes = bboxes[inds]
+                classes = classes[inds]
+
+            vis_img = imshow_bboxes(
+                img=filename, bboxes=bboxes, labels=classes, show=False)
+            vis_imgs.append(vis_img)
+
+        output = {'images': vis_imgs, 'img_metas': img_metas}
+        return output
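For reference, a minimal sketch of feeding visualize() by hand (dummy boxes and scores; the image path is hypothetical and must exist on disk):

    import numpy as np

    results = {
        'detection_boxes': [np.array([[10., 10., 30., 30.]], dtype=np.float32)],
        'detection_scores': [np.array([0.9], dtype=np.float32)],
        'detection_classes': [np.array([0], dtype=np.int32)],
        'img_metas': [{'filename': 'imgs/0.jpg'}],
    }
    # boxes with scores <= score_thr are dropped before imshow_bboxes is called
    output = dataset.visualize(results, vis_num=1, score_thr=0.3)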

View File

@@ -28,3 +28,14 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
     @abstractmethod
     def evaluate(self, results, evaluators, logger=None, **kwargs):
         pass
+
+    def visualize(self, results, **kwargs):
+        """Visualize the model output results on validation data.
+
+        Returns: A dictionary
+            If image visualization is implemented, return dict containing
+                images: List of visualized images.
+                img_metas: List of length number of test images,
+                    dict of image meta info, containing filename, img_shape,
+                    origin_img_shape, scale_factor and so on.
+        """
+        return {}
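Any subclass can opt in by overriding visualize() and returning the two keys the logger hooks look for. A hedged sketch (the class name and the raw_img meta key are illustrative, not part of the commit):

    from easycv.datasets.registry import DATASETS
    from easycv.datasets.shared.base import BaseDataset


    @DATASETS.register_module
    class MyVisDataset(BaseDataset):  # hypothetical subclass
        def evaluate(self, results, evaluators, logger=None, **kwargs):
            return {}

        def visualize(self, results, vis_num=10, **kwargs):
            img_metas = results['img_metas'][:vis_num]
            # each entry must be an (H, W, C) np.ndarray for the logger hooks
            images = [meta['raw_img'] for meta in img_metas]  # assumed meta key
            return {'images': images, 'img_metas': img_metas}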

View File

@@ -18,6 +18,8 @@ from .show_time_hook import TIMEHook
 from .swav_hook import SWAVHook
 from .sync_norm_hook import SyncNormHook
 from .sync_random_size_hook import SyncRandomSizeHook
+from .tensorboard import TensorboardLoggerHookV2
+from .wandb import WandbLoggerHookV2
 from .yolox_lr_hook import YOLOXLrUpdaterHook
 from .yolox_mode_switch_hook import YOLOXModeSwitchHook

View File

@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os.path as osp
+from collections import OrderedDict

 import torch
 from mmcv.runner import Hook
@@ -45,6 +46,9 @@ class EvalHook(Hook):
         self.mode = mode
         self.eval_kwargs = eval_kwargs
+        # hook.evaluate runs every `interval` epochs or iters, so pop these
+        # kwargs once at init
+        self.vis_config = self.eval_kwargs.pop('visualization_config', {})
+        self.gpu_collect = self.eval_kwargs.pop('gpu_collect', None)
         self.flush_buffer = flush_buffer

     def before_run(self, runner):
@@ -68,9 +72,23 @@
             runner.model, self.dataloader, mode=self.mode, show=False)
         self.evaluate(runner, results)

-    def evaluate(self, runner, results):
-        gpu_collect = self.eval_kwargs.pop('gpu_collect', None)
+    def add_visualization_info(self, runner, results):
+        if runner.visualization_buffer.output.get('eval_results',
+                                                  None) is None:
+            runner.visualization_buffer.output['eval_results'] = OrderedDict()
+
+        if isinstance(self.dataloader, DataLoader):
+            dataset_obj = self.dataloader.dataset
+        else:
+            dataset_obj = self.dataloader
+
+        if hasattr(dataset_obj, 'visualize'):
+            runner.visualization_buffer.output['eval_results'].update(
+                dataset_obj.visualize(results, **self.vis_config))
+
+    def evaluate(self, runner, results):
+        self.add_visualization_info(runner, results)
         if isinstance(self.dataloader, DataLoader):
             eval_res = self.dataloader.dataset.evaluate(
                 results, logger=runner.logger, **self.eval_kwargs)
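The round trip per evaluation round is then, roughly (a sketch of the call sequence, not literal code):

    # 1. EvalHook.evaluate() first calls add_visualization_info(runner, results)
    # 2. which calls dataset.visualize(results, **self.vis_config), e.g.
    #        dataset.visualize(results, vis_num=10, score_thr=0.5)
    # 3. the returned {'images': [...], 'img_metas': [...]} is merged into
    #        runner.visualization_buffer.output['eval_results']
    # 4. TensorboardLoggerHookV2 / WandbLoggerHookV2 read that buffer in log()
    #    and clear it in after_train_iter()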

View File

@@ -0,0 +1,55 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import numpy as np
+import torch
+from mmcv.runner.dist_utils import master_only
+from mmcv.runner.hooks import HOOKS
+from mmcv.runner.hooks import TensorboardLoggerHook as _TensorboardLoggerHook
+
+
+@HOOKS.register_module()
+class TensorboardLoggerHookV2(_TensorboardLoggerHook):
+
+    def visualization_log(self, runner):
+        """Images visualization.
+
+        `visualization_buffer` is a dictionary containing:
+            images (list): list of visualized images.
+            img_metas (list of dict, optional): dict containing ori_filename and so on.
+                ori_filename will be displayed as the tag of the image by default.
+        """
+        visual_results = runner.visualization_buffer.output
+        for vis_key, vis_result in visual_results.items():
+            images = vis_result.get('images', [])
+            img_metas = vis_result.get('img_metas', None)
+            if img_metas is not None:
+                assert len(images) == len(
+                    img_metas
+                ), 'Output `images` and `img_metas` must keep the same length!'
+
+            for i, img in enumerate(images):
+                if isinstance(img, np.ndarray):
+                    img = torch.from_numpy(img)
+                else:
+                    assert isinstance(
+                        img, torch.Tensor
+                    ), 'Only support np.ndarray and torch.Tensor type!'
+
+                default_name = 'image_%i' % i
+                filename = img_metas[i].get(
+                    'ori_filename',
+                    default_name) if img_metas is not None else default_name
+                self.writer.add_image(
+                    f'{vis_key}/{filename}',
+                    img,
+                    self.get_iter(runner),
+                    dataformats='HWC')
+
+    @master_only
+    def log(self, runner):
+        self.visualization_log(runner)
+        super(TensorboardLoggerHookV2, self).log(runner)
+
+    def after_train_iter(self, runner):
+        super(TensorboardLoggerHookV2, self).after_train_iter(runner)
+        # Clear visualization_buffer after each iter so each image buffer is
+        # written only once, not re-logged every `interval` iters.
+        runner.visualization_buffer.clear_output()
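Per image, the loop above boils down to a single SummaryWriter call; a standalone sketch with a dummy image (the log dir and tag are arbitrary; `eval_results` is the key EvalHook writes under):

    import numpy as np
    import torch
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir='./tb_logs')
    img = torch.from_numpy(np.zeros((100, 100, 3), dtype=np.uint8))  # HWC uint8
    writer.add_image('eval_results/demo.jpg', img, 0, dataformats='HWC')
    writer.close()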

View File

@@ -0,0 +1,54 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import cv2
+import numpy as np
+from mmcv.runner.dist_utils import master_only
+from mmcv.runner.hooks import HOOKS
+from mmcv.runner.hooks import WandbLoggerHook as _WandbLoggerHook
+from PIL import Image as PILImage
+
+
+@HOOKS.register_module()
+class WandbLoggerHookV2(_WandbLoggerHook):
+
+    def visualization_log(self, runner):
+        """Images visualization.
+
+        `visualization_buffer` is a dictionary containing:
+            images (list): list of visualized images.
+            img_metas (list of dict, optional): dict containing ori_filename and so on.
+                ori_filename will be displayed as the tag of the image by default.
+        """
+        visual_results = runner.visualization_buffer.output
+        for vis_key, vis_result in visual_results.items():
+            images = vis_result.get('images', [])
+            img_metas = vis_result.get('img_metas', None)
+            if img_metas is not None:
+                assert len(images) == len(
+                    img_metas
+                ), 'Output `images` and `img_metas` must keep the same length!'
+
+            examples = []
+            for i, img in enumerate(images):
+                assert isinstance(img, np.ndarray)
+                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+                pil_image = PILImage.fromarray(img, mode='RGB')
+                default_name = 'image_%i' % i
+                filename = img_metas[i].get(
+                    'ori_filename',
+                    default_name) if img_metas is not None else default_name
+                image = self.wandb.Image(pil_image, caption=filename)
+                examples.append(image)
+
+            self.wandb.log({vis_key: examples},
+                           step=self.get_iter(runner),
+                           commit=self.commit)
+
+    @master_only
+    def log(self, runner):
+        self.visualization_log(runner)
+        super(WandbLoggerHookV2, self).log(runner)
+
+    def after_train_iter(self, runner):
+        super(WandbLoggerHookV2, self).after_train_iter(runner)
+        # Clear visualization_buffer after each iter so each image buffer is
+        # written only once, not re-logged every `interval` iters.
+        runner.visualization_buffer.clear_output()
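Since WandbLoggerHookV2 inherits mmcv's WandbLoggerHook, its usual init_kwargs (forwarded to wandb.init) still apply; a hedged config sketch (the project name is an example, not from the commit):

    log_config = dict(
        interval=100,
        hooks=[
            dict(type='TextLoggerHook'),
            dict(
                type='WandbLoggerHookV2',
                init_kwargs=dict(project='easycv-vis-demo')),
        ])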

View File

@@ -5,6 +5,7 @@ from distutils.version import LooseVersion

 import torch
 from mmcv.runner import EpochBasedRunner
+from mmcv.runner.log_buffer import LogBuffer

 from easycv.file import io
 from easycv.utils.checkpoint import load_checkpoint, save_checkpoint
@@ -47,6 +48,7 @@ class EVRunner(EpochBasedRunner):
                          meta)
         self.data_loader = None
         self.fp16_enable = False
+        self.visualization_buffer = LogBuffer()

     def run_iter(self, data_batch, train_mode, **kwargs):
         """ process for each iteration.

BIN resource/simhei.ttf (new file, mode 100644; binary content not shown)

View File

@@ -159,6 +159,7 @@ def pack_resource():
         './thirdparty',
         proj_dir + 'thirdparty',
         ignore=shutil.ignore_patterns('*.pyc'))
+    shutil.copytree('./resource', proj_dir + 'resource')
     shutil.copytree('./requirements', 'package/requirements')
     shutil.copy('./requirements.txt', 'package/requirements.txt')
     shutil.copy('./MANIFEST.in', 'package/MANIFEST.in')
@@ -181,6 +182,7 @@ if __name__ == '__main__':
         keywords='self-supvervised, classification, vision',
         url='https://github.com/alibaba/EasyCV.git',
         packages=find_packages(exclude=('configs', 'tools', 'demo')),
+        include_package_data=True,
         classifiers=[
             'Development Status :: 4 - Beta',
             'License :: OSI Approved :: Apache Software License',

View File

@@ -4,7 +4,8 @@ import unittest

 import numpy as np

-from easycv.core.visualization import imshow_bboxes, imshow_keypoints
+from easycv.core.visualization import (imshow_bboxes, imshow_keypoints,
+                                       imshow_label)


 class ImshowTest(unittest.TestCase):
@@ -32,7 +33,7 @@ class ImshowTest(unittest.TestCase):
         img = np.zeros((100, 100, 3), dtype=np.uint8)
         bboxes = np.array([[10, 10, 30, 30], [10, 50, 30, 80]],
                           dtype=np.float32)
-        labels = ['label 1', 'label 2']
+        labels = ['标签 1', 'label 2']
         colors = ['red', 'green']

         with tempfile.TemporaryDirectory() as tmpdir:
@@ -65,6 +66,17 @@ class ImshowTest(unittest.TestCase):
         else:
             self.fail('ValueError not raised')

+    def test_imshow_label(self):
+        img = np.zeros((100, 100, 3), dtype=np.uint8)
+        labels = ['标签 1', 'label 2']
+        with tempfile.TemporaryDirectory() as tmpdir:
+            _ = imshow_label(
+                img=img,
+                labels=labels,
+                show=False,
+                out_file=f'{tmpdir}/out.png')
+

 if __name__ == '__main__':
     unittest.main()

View File

@@ -13,7 +13,7 @@ class ClsDatasetTest(unittest.TestCase):
     def setUp(self):
         print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

-    def test_default(self):
+    def _get_dataset(self):
         data_root = SMALL_IMAGENET_RAW_LOCAL
         data_train_list = os.path.join(data_root, 'meta/train_labeled_200.txt')
         pipeline = [
@@ -33,12 +33,36 @@ class ClsDatasetTest(unittest.TestCase):
                 pipeline=pipeline))
         dataset = build_dataset(data['train'])
+        return dataset
+
+    def test_default(self):
+        dataset = self._get_dataset()

         for _, batch in enumerate(dataset):
             img, target = batch['img'], batch['gt_labels']
             self.assertEqual(img.shape, torch.Size([3, 224, 224]))
             self.assertIn(target, list(range(1000)))
             break

+    def test_visualize(self):
+        # TODO: add img_metas for classification
+        return
+        dataset = self._get_dataset()
+        count = 5
+        classes = []
+        img_metas = []
+        for i, data in enumerate(dataset):
+            classes.append(data['gt_labels'])
+            img_metas.append(data['img_metas'].data)
+            if i > count:
+                break
+
+        results = {'class': classes, 'img_metas': img_metas}
+        output = dataset.visualize(results, vis_num=2)
+        self.assertEqual(len(output['images']), 2)
+        self.assertEqual(len(output['img_metas']), 2)
+        self.assertEqual(len(output['images'][0].shape), 3)
+        self.assertIn('filename', output['img_metas'][0])
+

 if __name__ == '__main__':
     unittest.main()

View File

@@ -3,6 +3,7 @@ import os
 import time
 import unittest

+import numpy as np
 from tests.ut_config import DET_DATA_RAW_LOCAL, IMG_NORM_CFG_255

 from easycv.datasets.detection import DetDataset
@@ -13,7 +14,7 @@ class DetDatasetTest(unittest.TestCase):
     def setUp(self):
         print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

-    def test_load(self):
+    def _get_dataset(self):
         img_scale = (640, 640)
         data_source_cfg = dict(
             type='DetSourceRaw',
@@ -33,6 +34,11 @@ class DetDatasetTest(unittest.TestCase):
         ]

         dataset = DetDataset(data_source=data_source_cfg, pipeline=pipeline)
+        return dataset
+
+    def test_load(self):
+        dataset = self._get_dataset()
+
         data_num = len(dataset)
         s = time.time()
         for data in dataset:
@@ -47,6 +53,37 @@ class DetDatasetTest(unittest.TestCase):
         self.assertTrue('img_shape' in img_metas)
         self.assertTrue('ori_img_shape' in img_metas)

+    def test_visualize(self):
+        dataset = self._get_dataset()
+        count = 5
+        detection_boxes = []
+        detection_classes = []
+        img_metas = []
+        for i, data in enumerate(dataset):
+            detection_boxes.append(
+                data['gt_bboxes'].data.cpu().detach().numpy())
+            detection_classes.append(
+                data['gt_labels'].data.cpu().detach().numpy())
+            img_metas.append(data['img_metas'].data)
+            if i > count:
+                break
+
+        detection_scores = []
+        for classes in detection_classes:
+            detection_scores.append(0.1 * np.array(range(len(classes))))
+
+        results = {
+            'detection_boxes': detection_boxes,
+            'detection_scores': detection_scores,
+            'detection_classes': detection_classes,
+            'img_metas': img_metas
+        }
+        output = dataset.visualize(results, vis_num=2, score_thr=0.1)
+        self.assertEqual(len(output['images']), 2)
+        self.assertEqual(len(output['img_metas']), 2)
+        self.assertEqual(len(output['images'][0].shape), 3)
+        self.assertIn('filename', output['img_metas'][0])
+

 if __name__ == '__main__':
     unittest.main()

View File

@@ -21,7 +21,7 @@ class SSLSourceImageNetFeatureTest(unittest.TestCase):
         index_list = random.choices(list(range(100)), k=3)
         for idx in index_list:
             results = data_source.get_sample(idx)
-            feat = results['feature']
+            feat = results['img']
             label = results['gt_labels']
             self.assertEqual(feat.shape, (2048, ))
             self.assertIn(label, list(range(1000)))
@@ -37,7 +37,7 @@ class SSLSourceImageNetFeatureTest(unittest.TestCase):
         index_list = random.choices(list(range(100)), k=3)
         for idx in index_list:
             results = data_source.get_sample(idx)
-            feat = results['feature']
+            feat = results['img']
             label = results['gt_labels']
             self.assertEqual(feat.shape, (2048, ))
             self.assertIn(label, list(range(1000)))