mirror of https://github.com/alibaba/EasyCV.git
[Feature]: support image visualization for tensorboard and wandb (#15)
* [Feature]: support image visualization for tensorboard and wandbpull/21/head^2
parent
3a4a3a9d0c
commit
9a3826f0d2
|
@ -1,2 +1,3 @@
|
|||
recursive-include easycv/configs *.py
|
||||
recursive-include easycv/tools *.py
|
||||
recursive-include easycv/resource/ *.ttf
|
||||
|
|
|
@ -23,4 +23,5 @@ data = dict(
|
|||
dict(type='CenterCrop', size=224),
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Collect', keys=['img', 'gt_labels'])
|
||||
]))
|
||||
|
|
|
@ -130,7 +130,14 @@ custom_hooks = [
|
|||
]
|
||||
|
||||
# evaluation
|
||||
eval_config = dict(interval=10, gpu_collect=False)
|
||||
eval_config = dict(
|
||||
interval=10,
|
||||
gpu_collect=False,
|
||||
visualization_config=dict(
|
||||
vis_num=10,
|
||||
score_thr=0.5,
|
||||
) # show by TensorboardLoggerHookV2 and WandbLoggerHookV2
|
||||
)
|
||||
eval_pipelines = [
|
||||
dict(
|
||||
mode='test',
|
||||
|
@ -168,7 +175,8 @@ log_config = dict(
|
|||
interval=100,
|
||||
hooks=[
|
||||
dict(type='TextLoggerHook'),
|
||||
dict(type='TensorboardLoggerHook')
|
||||
dict(type='TensorboardLoggerHookV2'),
|
||||
# dict(type='WandbLoggerHookV2'),
|
||||
])
|
||||
|
||||
export = dict(use_jit=False)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from .image import imshow_bboxes, imshow_keypoints
|
||||
from .image import imshow_bboxes, imshow_keypoints, imshow_label
|
||||
|
||||
__all__ = ['imshow_bboxes', 'imshow_keypoints']
|
||||
__all__ = ['imshow_bboxes', 'imshow_keypoints', 'imshow_label']
|
||||
|
|
|
@ -1,11 +1,103 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Adapt from https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/visualization/image.py
|
||||
import math
|
||||
import os
|
||||
from os.path import dirname as opd
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.utils.misc import deprecated_api_warning
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
|
||||
def get_font_path():
|
||||
root_path = opd(opd(opd(os.path.realpath(__file__))))
|
||||
# find in whl
|
||||
find_path_whl = os.path.join(root_path, 'resource/simhei.ttf')
|
||||
# find in source code
|
||||
find_path_source = os.path.join(opd(root_path), 'resource/simhei.ttf')
|
||||
if os.path.exists(find_path_whl):
|
||||
return find_path_whl
|
||||
elif os.path.exists(find_path_source):
|
||||
return find_path_source
|
||||
else:
|
||||
raise ValueError('Not find font file both in %s and %s' %
|
||||
(find_path_whl, find_path_source))
|
||||
|
||||
|
||||
_FONT_PATH = get_font_path()
|
||||
|
||||
|
||||
def put_text(img, xy, text, fill, size=20):
|
||||
"""support chinese text
|
||||
"""
|
||||
img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||
draw = ImageDraw.Draw(img)
|
||||
fontText = ImageFont.truetype(_FONT_PATH, size=size, encoding='utf-8')
|
||||
draw.text(xy, text, fill=fill, font=fontText)
|
||||
img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
|
||||
return img
|
||||
|
||||
|
||||
def imshow_label(img,
|
||||
labels,
|
||||
text_color='blue',
|
||||
font_size=20,
|
||||
thickness=1,
|
||||
font_scale=0.5,
|
||||
intervel=5,
|
||||
show=True,
|
||||
win_name='',
|
||||
wait_time=0,
|
||||
out_file=None):
|
||||
"""Draw images with labels on an image.
|
||||
|
||||
Args:
|
||||
img (str or ndarray): The image to be displayed.
|
||||
labels (str or list[str]): labels of each image.
|
||||
text_color (str or tuple or :obj:`Color`): Color of texts.
|
||||
font_size (int): Size of font.
|
||||
thickness (int): Thickness of lines.
|
||||
font_scale (float): Font scales of texts.
|
||||
intervel(int): interval pixels between multiple labels
|
||||
show (bool): Whether to show the image.
|
||||
win_name (str): The window name.
|
||||
wait_time (int): Value of waitKey param.
|
||||
out_file (str, optional): The filename to write the image.
|
||||
|
||||
Returns:
|
||||
ndarray: The image with bboxes drawn on it.
|
||||
"""
|
||||
img = mmcv.imread(img)
|
||||
img = np.ascontiguousarray(img)
|
||||
labels = [labels] if isinstance(labels, str) else labels
|
||||
|
||||
cur_height = 0
|
||||
for label in labels:
|
||||
# roughly estimate the proper font size
|
||||
text_size, text_baseline = cv2.getTextSize(label,
|
||||
cv2.FONT_HERSHEY_DUPLEX,
|
||||
font_scale, thickness)
|
||||
|
||||
org = (text_baseline + text_size[1],
|
||||
text_baseline + text_size[1] + cur_height)
|
||||
|
||||
# support chinese text
|
||||
# TODO: Unify the font of cv2 and PIL, and auto get font_size according to the font_scale
|
||||
img = put_text(img, org, text=label, fill=text_color, size=font_size)
|
||||
|
||||
# cv2.putText(img, label, org, cv2.FONT_HERSHEY_DUPLEX, font_scale,
|
||||
# mmcv.color_val(text_color), thickness)
|
||||
|
||||
cur_height += text_baseline + text_size[1] + intervel
|
||||
|
||||
if show:
|
||||
mmcv.imshow(img, win_name, wait_time)
|
||||
if out_file is not None:
|
||||
mmcv.imwrite(img, out_file)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def imshow_bboxes(img,
|
||||
|
@ -13,6 +105,7 @@ def imshow_bboxes(img,
|
|||
labels=None,
|
||||
colors='green',
|
||||
text_color='white',
|
||||
font_size=20,
|
||||
thickness=1,
|
||||
font_scale=0.5,
|
||||
show=True,
|
||||
|
@ -29,6 +122,7 @@ def imshow_bboxes(img,
|
|||
labels (str or list[str], optional): labels of each bbox.
|
||||
colors (list[str or tuple or :obj:`Color`]): A list of colors.
|
||||
text_color (str or tuple or :obj:`Color`): Color of texts.
|
||||
font_size (int): Size of font.
|
||||
thickness (int): Thickness of lines.
|
||||
font_scale (float): Font scales of texts.
|
||||
show (bool): Whether to show the image.
|
||||
|
@ -58,11 +152,10 @@ def imshow_bboxes(img,
|
|||
out_file=None)
|
||||
|
||||
if labels is not None:
|
||||
if not isinstance(labels, list):
|
||||
labels = [labels for _ in range(len(bboxes))]
|
||||
assert len(labels) == len(bboxes)
|
||||
|
||||
for bbox, label, color in zip(bboxes, labels, colors):
|
||||
label = str(label)
|
||||
bbox_int = bbox[0, :4].astype(np.int32)
|
||||
# roughly estimate the proper font size
|
||||
text_size, text_baseline = cv2.getTextSize(label,
|
||||
|
@ -74,9 +167,17 @@ def imshow_bboxes(img,
|
|||
text_y2 = text_y1 + text_size[1] + text_baseline
|
||||
cv2.rectangle(img, (text_x1, text_y1), (text_x2, text_y2), color,
|
||||
cv2.FILLED)
|
||||
cv2.putText(img, label, (text_x1, text_y2 - text_baseline),
|
||||
cv2.FONT_HERSHEY_DUPLEX, font_scale,
|
||||
mmcv.color_val(text_color), thickness)
|
||||
# cv2.putText(img, label, (text_x1, text_y2 - text_baseline),
|
||||
# cv2.FONT_HERSHEY_DUPLEX, font_scale,
|
||||
# mmcv.color_val(text_color), thickness)
|
||||
|
||||
# support chinese text
|
||||
# TODO: Unify the font of cv2 and PIL, and auto get font_size according to the font_scale
|
||||
img = put_text(
|
||||
img, (text_x1, text_y1),
|
||||
text=label,
|
||||
fill=text_color,
|
||||
size=font_size)
|
||||
|
||||
if show:
|
||||
mmcv.imshow(img, win_name, wait_time)
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from easycv.core.visualization.image import imshow_label
|
||||
from easycv.datasets.registry import DATASETS
|
||||
from easycv.datasets.shared.base import BaseDataset
|
||||
|
||||
|
@ -59,3 +60,37 @@ class ClsDataset(BaseDataset):
|
|||
eval_res = evaluators[0].evaluate(results, gt_labels)
|
||||
|
||||
return eval_res
|
||||
|
||||
def visualize(self, results, vis_num=10, **kwargs):
|
||||
"""Visulaize the model output on validation data.
|
||||
Args:
|
||||
results: A dictionary containing
|
||||
class: List of length number of test images.
|
||||
img_metas: List of length number of test images,
|
||||
dict of image meta info, containing filename, img_shape,
|
||||
origin_img_shape and so on.
|
||||
vis_num: number of images visualized
|
||||
Returns: A dictionary containing
|
||||
images: Visulaized images, list of np.ndarray.
|
||||
img_metas: List of length number of test images,
|
||||
dict of image meta info, containing filename, img_shape,
|
||||
origin_img_shape and so on.
|
||||
"""
|
||||
vis_imgs = []
|
||||
|
||||
# TODO: support img_metas for torch.jit
|
||||
if results.get('img_metas', None) is None:
|
||||
return {}
|
||||
|
||||
img_metas = results['img_metas'][:vis_num]
|
||||
labels = results['class']
|
||||
|
||||
for i, img_meta in enumerate(img_metas):
|
||||
filename = img_meta['filename']
|
||||
|
||||
vis_img = imshow_label(img=filename, labels=labels, show=False)
|
||||
vis_imgs.append(vis_img)
|
||||
|
||||
output = {'images': vis_imgs, 'img_metas': img_metas}
|
||||
|
||||
return output
|
||||
|
|
|
@ -9,14 +9,13 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from easycv.datasets.registry import DATASETS, PIPELINES
|
||||
from easycv.datasets.shared.base import BaseDataset
|
||||
from easycv.utils import build_from_cfg
|
||||
from easycv.utils.bbox_util import batched_xyxy2cxcywh_with_shape
|
||||
from easycv.utils.bbox_util import xyxy2xywh as xyxy2cxcywh
|
||||
from .raw import DetDataset
|
||||
|
||||
|
||||
@DATASETS.register_module
|
||||
class DetImagesMixDataset(BaseDataset):
|
||||
class DetImagesMixDataset(DetDataset):
|
||||
"""A wrapper of multiple images mixed dataset.
|
||||
|
||||
Suitable for training on multiple images mixed data augmentation like
|
||||
|
@ -50,7 +49,7 @@ class DetImagesMixDataset(BaseDataset):
|
|||
label_padding=True):
|
||||
|
||||
super(DetImagesMixDataset, self).__init__(
|
||||
data_source, pipeline, profiling=profiling)
|
||||
data_source, pipeline, profiling=profiling, classes=classes)
|
||||
|
||||
if skip_type_keys is not None:
|
||||
assert all([
|
||||
|
@ -70,10 +69,9 @@ class DetImagesMixDataset(BaseDataset):
|
|||
else:
|
||||
raise TypeError('pipeline must be a dict')
|
||||
|
||||
self.CLASSES = classes
|
||||
if hasattr(self.data_source, 'flag'):
|
||||
self.flag = self.data_source.flag
|
||||
self.num_samples = self.data_source.get_length()
|
||||
|
||||
if dynamic_scale is not None:
|
||||
assert isinstance(dynamic_scale, tuple)
|
||||
|
||||
|
@ -83,9 +81,6 @@ class DetImagesMixDataset(BaseDataset):
|
|||
self.label_padding = label_padding
|
||||
self.max_labels_num = 120
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
results = copy.deepcopy(self.data_source.get_sample(idx))
|
||||
for (transform, transform_type) in zip(self.pipeline_yolox,
|
||||
|
@ -116,21 +111,8 @@ class DetImagesMixDataset(BaseDataset):
|
|||
if 'img_scale' in results:
|
||||
results.pop('img_scale')
|
||||
|
||||
# print(result.keys())
|
||||
|
||||
# if self.yolo_format:
|
||||
# # print(type(results['img_metas']), results['img_metas'])
|
||||
# # print(type(results['img_metas']._data), results['img_metas']._data)
|
||||
# img_shape = results['img_metas']._data['img_shape'][:2]
|
||||
# # print(type(results['gt_bboxes']))
|
||||
# gt_bboxes = xyxy2cxcywh_with_shape(results['gt_bboxes']._data, img_shape)
|
||||
# results['gt_bboxes'] = gt_bboxes.float()
|
||||
|
||||
if self.label_padding:
|
||||
|
||||
cxcywh_gt_bboxes = xyxy2cxcywh(results['gt_bboxes']._data)
|
||||
# cxcywh_gt_bboxes = results['gt_bboxes']._data
|
||||
|
||||
padded_gt_bboxes = torch.zeros((self.max_labels_num, 4),
|
||||
device=cxcywh_gt_bboxes.device)
|
||||
padded_gt_bboxes[range(cxcywh_gt_bboxes.shape[0])[:self.max_labels_num]] = \
|
||||
|
@ -146,9 +128,6 @@ class DetImagesMixDataset(BaseDataset):
|
|||
results['gt_bboxes'] = padded_gt_bboxes
|
||||
results['gt_labels'] = padded_labels
|
||||
|
||||
# ['img_metas', 'img', 'gt_bboxes', 'gt_labels']
|
||||
# results.pop('img_metas')
|
||||
# print(results['img_metas'], "hhh", idx)
|
||||
return results
|
||||
|
||||
def update_skip_type_keys(self, skip_type_keys):
|
||||
|
@ -240,30 +219,3 @@ class DetImagesMixDataset(BaseDataset):
|
|||
tmp_dir = None
|
||||
result_files = self.results2json(results, jsonfile_prefix)
|
||||
return result_files, tmp_dir
|
||||
|
||||
def evaluate(self, results, evaluators=None, logger=None):
|
||||
'''results: a dict of list of Tensors, list length equals to number of test images
|
||||
'''
|
||||
|
||||
eval_result = dict()
|
||||
|
||||
groundtruth_dict = {}
|
||||
groundtruth_dict['groundtruth_boxes'] = [
|
||||
batched_xyxy2cxcywh_with_shape(
|
||||
self.data_source.get_ann_info(idx)['bboxes'],
|
||||
results['img_metas'][idx]['ori_img_shape'])
|
||||
for idx in range(len(results['img_metas']))
|
||||
]
|
||||
groundtruth_dict['groundtruth_classes'] = [
|
||||
self.data_source.get_ann_info(idx)['labels']
|
||||
for idx in range(len(results['img_metas']))
|
||||
]
|
||||
groundtruth_dict['groundtruth_is_crowd'] = [
|
||||
self.data_source.get_ann_info(idx)['groundtruth_is_crowd']
|
||||
for idx in range(len(results['img_metas']))
|
||||
]
|
||||
|
||||
for evaluator in evaluators:
|
||||
eval_result.update(evaluator.evaluate(results, groundtruth_dict))
|
||||
# print(eval_result)
|
||||
return eval_result
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import numpy as np
|
||||
|
||||
from easycv.core.visualization.image import imshow_bboxes
|
||||
from easycv.datasets.registry import DATASETS
|
||||
from easycv.datasets.shared.base import BaseDataset
|
||||
from easycv.utils.bbox_util import batched_xyxy2cxcywh_with_shape
|
||||
|
||||
|
||||
@DATASETS.register_module
|
||||
|
@ -19,58 +21,124 @@ class DetDataset(BaseDataset):
|
|||
classes: A list of class names, used in evaluation for result and groundtruth visualization
|
||||
"""
|
||||
self.classes = classes
|
||||
self.CLASSES = classes
|
||||
|
||||
super(DetDataset, self).__init__(
|
||||
data_source, pipeline, profiling=profiling)
|
||||
self.img_num = self.data_source.get_length()
|
||||
self.num_samples = self.data_source.get_length()
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
data_dict = self.data_source.get_sample(idx)
|
||||
data_dict = self.pipeline(data_dict)
|
||||
return data_dict
|
||||
|
||||
def evaluate(self, results, evaluators, logger=None):
|
||||
'''results: a dict of list of Tensors, list length equals to number of test images
|
||||
def evaluate(self, results, evaluators=None, logger=None):
|
||||
'''Evaluates the detection boxes.
|
||||
Args:
|
||||
results: A dictionary containing
|
||||
detection_boxes: List of length number of test images.
|
||||
Float32 numpy array of shape [num_boxes, 4] and
|
||||
format [ymin, xmin, ymax, xmax] in absolute image coordinates.
|
||||
detection_scores: List of length number of test images,
|
||||
detection scores for the boxes, float32 numpy array of shape [num_boxes].
|
||||
detection_classes: List of length number of test images,
|
||||
integer numpy array of shape [num_boxes]
|
||||
containing 1-indexed detection classes for the boxes.
|
||||
img_metas: List of length number of test images,
|
||||
dict of image meta info, containing filename, img_shape,
|
||||
origin_img_shape, scale_factor and so on.
|
||||
evaluators: evaluators to calculate metric with results and groundtruth_dict
|
||||
'''
|
||||
|
||||
eval_result = dict()
|
||||
annotations = self.data_source.get_labels()
|
||||
|
||||
groundtruth_dict = {}
|
||||
groundtruth_dict['groundtruth_boxes'] = [
|
||||
labels[:,
|
||||
1:] if len(labels) > 0 else np.array([], dtype=np.float32)
|
||||
for labels in annotations
|
||||
batched_xyxy2cxcywh_with_shape(
|
||||
self.data_source.get_ann_info(idx)['bboxes'],
|
||||
results['img_metas'][idx]['ori_img_shape'])
|
||||
for idx in range(len(results['img_metas']))
|
||||
]
|
||||
groundtruth_dict['groundtruth_classes'] = [
|
||||
labels[:, 0] if len(labels) > 0 else np.array([], dtype=np.float32)
|
||||
for labels in annotations
|
||||
self.data_source.get_ann_info(idx)['labels']
|
||||
for idx in range(len(results['img_metas']))
|
||||
]
|
||||
# bboxes = [label[:, 1:] for label in annotations]
|
||||
# scores = [label[:, 0] for label in annotations]
|
||||
groundtruth_dict['groundtruth_is_crowd'] = [
|
||||
self.data_source.get_ann_info(idx)['groundtruth_is_crowd']
|
||||
for idx in range(len(results['img_metas']))
|
||||
]
|
||||
|
||||
for evaluator in evaluators:
|
||||
eval_result.update(evaluator.evaluate(results, groundtruth_dict))
|
||||
# eval_res = {'dummy': 1.0}
|
||||
# img = self.data_source.load_ori_img(0)
|
||||
# num_box = results['detection_scores'][0].size(0)
|
||||
# scores = results['detection_scores'][0].detach().cpu().numpy()
|
||||
# bboxes = torch.cat((results['detection_boxes'][0], results['detection_scores'][0].view(num_box, 1)), axis=1).detach().cpu().numpy()
|
||||
# labels = results['detection_classes'][0].detach().cpu().numpy().astype(np.int32)
|
||||
# # draw bounding boxes
|
||||
# score_th = 0.3
|
||||
# indices = scores > score_th
|
||||
# filter_labels = labels[indices]
|
||||
# print([(self.classes[i], score) for i, score in zip(filter_labels, scores)])
|
||||
# mmcv.imshow_det_bboxes(
|
||||
# img,
|
||||
# bboxes,
|
||||
# labels,
|
||||
# class_names=self.classes,
|
||||
# score_thr=score_th,
|
||||
# bbox_color='red',
|
||||
# text_color='black',
|
||||
# thickness=1,
|
||||
# font_scale=0.5,
|
||||
# show=False,
|
||||
# wait_time=0,
|
||||
# out_file='test.jpg')
|
||||
|
||||
return eval_result
|
||||
|
||||
def visualize(self, results, vis_num=10, score_thr=0.3, **kwargs):
|
||||
"""Visulaize the model output on validation data.
|
||||
Args:
|
||||
results: A dictionary containing
|
||||
detection_boxes: List of length number of test images.
|
||||
Float32 numpy array of shape [num_boxes, 4] and
|
||||
format [ymin, xmin, ymax, xmax] in absolute image coordinates.
|
||||
detection_scores: List of length number of test images,
|
||||
detection scores for the boxes, float32 numpy array of shape [num_boxes].
|
||||
detection_classes: List of length number of test images,
|
||||
integer numpy array of shape [num_boxes]
|
||||
containing 1-indexed detection classes for the boxes.
|
||||
img_metas: List of length number of test images,
|
||||
dict of image meta info, containing filename, img_shape,
|
||||
origin_img_shape, scale_factor and so on.
|
||||
vis_num: number of images visualized
|
||||
score_thr: The threshold to filter box,
|
||||
boxes with scores greater than score_thr will be kept.
|
||||
Returns: A dictionary containing
|
||||
images: Visulaized images.
|
||||
img_metas: List of length number of test images,
|
||||
dict of image meta info, containing filename, img_shape,
|
||||
origin_img_shape, scale_factor and so on.
|
||||
"""
|
||||
class_names = None
|
||||
if hasattr(self.data_source, 'CLASSES'):
|
||||
class_names = self.data_source.CLASSES
|
||||
elif hasattr(self.data_source, 'classes'):
|
||||
class_names = self.data_source.classes
|
||||
|
||||
if class_names is not None:
|
||||
detection_classes = []
|
||||
for classes_id in results['detection_classes']:
|
||||
if classes_id is None:
|
||||
detection_classes.append(None)
|
||||
else:
|
||||
detection_classes.append(
|
||||
np.array([class_names[id] for id in classes_id]))
|
||||
results['detection_classes'] = detection_classes
|
||||
|
||||
vis_imgs = []
|
||||
|
||||
img_metas = results['img_metas'][:vis_num]
|
||||
detection_boxes = results.get('detection_boxes', [])
|
||||
detection_scores = results.get('detection_scores', [])
|
||||
detection_classes = results.get('detection_classes', [])
|
||||
|
||||
for i, img_meta in enumerate(img_metas):
|
||||
filename = img_meta['filename']
|
||||
bboxes = np.array(
|
||||
[]) if detection_boxes[i] is None else detection_boxes[i]
|
||||
scores = detection_scores[i]
|
||||
classes = detection_classes[i]
|
||||
|
||||
if scores is not None and score_thr > 0:
|
||||
inds = scores > score_thr
|
||||
bboxes = bboxes[inds]
|
||||
classes = classes[inds]
|
||||
|
||||
vis_img = imshow_bboxes(
|
||||
img=filename, bboxes=bboxes, labels=classes, show=False)
|
||||
vis_imgs.append(vis_img)
|
||||
|
||||
output = {'images': vis_imgs, 'img_metas': img_metas}
|
||||
|
||||
return output
|
||||
|
|
|
@ -28,3 +28,14 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
|
|||
@abstractmethod
|
||||
def evaluate(self, results, evaluators, logger=None, **kwargs):
|
||||
pass
|
||||
|
||||
def visualize(self, results, **kwargs):
|
||||
"""Visulaize the model output results on validation data.
|
||||
Returns: A dictionary
|
||||
If add image visualization, return dict containing
|
||||
images: List of visulaized images.
|
||||
img_metas: List of length number of test images,
|
||||
dict of image meta info, containing filename, img_shape,
|
||||
origin_img_shape, scale_factor and so on.
|
||||
"""
|
||||
return {}
|
||||
|
|
|
@ -18,6 +18,8 @@ from .show_time_hook import TIMEHook
|
|||
from .swav_hook import SWAVHook
|
||||
from .sync_norm_hook import SyncNormHook
|
||||
from .sync_random_size_hook import SyncRandomSizeHook
|
||||
from .tensorboard import TensorboardLoggerHookV2
|
||||
from .wandb import WandbLoggerHookV2
|
||||
from .yolox_lr_hook import YOLOXLrUpdaterHook
|
||||
from .yolox_mode_switch_hook import YOLOXModeSwitchHook
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
from mmcv.runner import Hook
|
||||
|
@ -45,6 +46,9 @@ class EvalHook(Hook):
|
|||
|
||||
self.mode = mode
|
||||
self.eval_kwargs = eval_kwargs
|
||||
# hook.evaluate runs every interval epoch or iter, popped at init
|
||||
self.vis_config = self.eval_kwargs.pop('visualization_config', {})
|
||||
self.gpu_collect = self.eval_kwargs.pop('gpu_collect', None)
|
||||
self.flush_buffer = flush_buffer
|
||||
|
||||
def before_run(self, runner):
|
||||
|
@ -68,9 +72,23 @@ class EvalHook(Hook):
|
|||
runner.model, self.dataloader, mode=self.mode, show=False)
|
||||
self.evaluate(runner, results)
|
||||
|
||||
def evaluate(self, runner, results):
|
||||
def add_visualization_info(self, runner, results):
|
||||
if runner.visualization_buffer.output.get('eval_results',
|
||||
None) is None:
|
||||
runner.visualization_buffer.output['eval_results'] = OrderedDict()
|
||||
|
||||
if isinstance(self.dataloader, DataLoader):
|
||||
dataset_obj = self.dataloader.dataset
|
||||
else:
|
||||
dataset_obj = self.dataloader
|
||||
|
||||
if hasattr(dataset_obj, 'visualize'):
|
||||
runner.visualization_buffer.output['eval_results'].update(
|
||||
dataset_obj.visualize(results, **self.vis_config))
|
||||
|
||||
def evaluate(self, runner, results):
|
||||
self.add_visualization_info(runner, results)
|
||||
|
||||
gpu_collect = self.eval_kwargs.pop('gpu_collect', None)
|
||||
if isinstance(self.dataloader, DataLoader):
|
||||
eval_res = self.dataloader.dataset.evaluate(
|
||||
results, logger=runner.logger, **self.eval_kwargs)
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.runner.dist_utils import master_only
|
||||
from mmcv.runner.hooks import HOOKS
|
||||
from mmcv.runner.hooks import TensorboardLoggerHook as _TensorboardLoggerHook
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class TensorboardLoggerHookV2(_TensorboardLoggerHook):
|
||||
|
||||
def visualization_log(self, runner):
|
||||
"""Images Visulization.
|
||||
`visualization_buffer` is a dictionary containing:
|
||||
images (list): list of visulaized images.
|
||||
img_metas (list of dict, optional): dict containing ori_filename and so on.
|
||||
ori_filename will be displayed as the tag of the image by default.
|
||||
"""
|
||||
visual_results = runner.visualization_buffer.output
|
||||
for vis_key, vis_result in visual_results.items():
|
||||
images = vis_result.get('images', [])
|
||||
img_metas = vis_result.get('img_metas', None)
|
||||
if img_metas is not None:
|
||||
assert len(images) == len(
|
||||
img_metas
|
||||
), 'Output `images` and `img_metas` must keep the same length!'
|
||||
|
||||
for i, img in enumerate(images):
|
||||
if isinstance(img, np.ndarray):
|
||||
img = torch.from_numpy(img)
|
||||
else:
|
||||
assert isinstance(
|
||||
img, torch.Tensor
|
||||
), 'Only support np.ndarray and torch.Tensor type!'
|
||||
|
||||
default_name = 'image_%i' % i
|
||||
filename = img_metas[i].get(
|
||||
'ori_filename',
|
||||
default_name) if img_metas is not None else default_name
|
||||
self.writer.add_image(
|
||||
f'{vis_key}/{filename}',
|
||||
img,
|
||||
self.get_iter(runner),
|
||||
dataformats='HWC')
|
||||
|
||||
@master_only
|
||||
def log(self, runner):
|
||||
self.visualization_log(runner)
|
||||
super(TensorboardLoggerHookV2, self).log(runner)
|
||||
|
||||
def after_train_iter(self, runner):
|
||||
super(TensorboardLoggerHookV2, self).after_train_iter(runner)
|
||||
# clear visualization_buffer after each iter to ensure that it is only written once,
|
||||
# avoiding repeated writing of the same image buffer every self.interval
|
||||
runner.visualization_buffer.clear_output()
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import cv2
|
||||
import numpy as np
|
||||
from mmcv.runner.dist_utils import master_only
|
||||
from mmcv.runner.hooks import HOOKS
|
||||
from mmcv.runner.hooks import WandbLoggerHook as _WandbLoggerHook
|
||||
from PIL import Image as PILImage
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class WandbLoggerHookV2(_WandbLoggerHook):
|
||||
|
||||
def visualization_log(self, runner):
|
||||
"""Images Visulization.
|
||||
`visualization_buffer` is a dictionary containing:
|
||||
images (list): list of visulaized images.
|
||||
img_metas (list of dict, optional): dict containing ori_filename and so on.
|
||||
ori_filename will be displayed as the tag of the image by default.
|
||||
"""
|
||||
visual_results = runner.visualization_buffer.output
|
||||
for vis_key, vis_result in visual_results.items():
|
||||
images = vis_result.get('images', [])
|
||||
img_metas = vis_result.get('img_metas', None)
|
||||
if img_metas is not None:
|
||||
assert len(images) == len(
|
||||
img_metas
|
||||
), 'Output `images` and `img_metas` must keep the same length!'
|
||||
|
||||
examples = []
|
||||
for i, img in enumerate(images):
|
||||
assert isinstance(img, np.ndarray)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
pil_image = PILImage.fromarray(img, mode='RGB')
|
||||
default_name = 'image_%i' % i
|
||||
filename = img_metas[i].get(
|
||||
'ori_filename',
|
||||
default_name) if img_metas is not None else default_name
|
||||
image = self.wandb.Image(pil_image, caption=filename)
|
||||
examples.append(image)
|
||||
|
||||
self.wandb.log({vis_key: examples},
|
||||
step=self.get_iter(runner),
|
||||
commit=self.commit)
|
||||
|
||||
@master_only
|
||||
def log(self, runner):
|
||||
self.visualization_log(runner)
|
||||
super(WandbLoggerHookV2, self).log(runner)
|
||||
|
||||
def after_train_iter(self, runner):
|
||||
super(WandbLoggerHookV2, self).after_train_iter(runner)
|
||||
# clear visualization_buffer after each iter to ensure that it is only written once,
|
||||
# avoiding repeated writing of the same image buffer every self.interval
|
||||
runner.visualization_buffer.clear_output()
|
|
@ -5,6 +5,7 @@ from distutils.version import LooseVersion
|
|||
|
||||
import torch
|
||||
from mmcv.runner import EpochBasedRunner
|
||||
from mmcv.runner.log_buffer import LogBuffer
|
||||
|
||||
from easycv.file import io
|
||||
from easycv.utils.checkpoint import load_checkpoint, save_checkpoint
|
||||
|
@ -47,6 +48,7 @@ class EVRunner(EpochBasedRunner):
|
|||
meta)
|
||||
self.data_loader = None
|
||||
self.fp16_enable = False
|
||||
self.visualization_buffer = LogBuffer()
|
||||
|
||||
def run_iter(self, data_batch, train_mode, **kwargs):
|
||||
""" process for each iteration.
|
||||
|
|
Binary file not shown.
2
setup.py
2
setup.py
|
@ -159,6 +159,7 @@ def pack_resource():
|
|||
'./thirdparty',
|
||||
proj_dir + 'thirdparty',
|
||||
ignore=shutil.ignore_patterns('*.pyc'))
|
||||
shutil.copytree('./resource', proj_dir + 'resource')
|
||||
shutil.copytree('./requirements', 'package/requirements')
|
||||
shutil.copy('./requirements.txt', 'package/requirements.txt')
|
||||
shutil.copy('./MANIFEST.in', 'package/MANIFEST.in')
|
||||
|
@ -181,6 +182,7 @@ if __name__ == '__main__':
|
|||
keywords='self-supvervised, classification, vision',
|
||||
url='https://github.com/alibaba/EasyCV.git',
|
||||
packages=find_packages(exclude=('configs', 'tools', 'demo')),
|
||||
include_package_data=True,
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
|
|
|
@ -4,7 +4,8 @@ import unittest
|
|||
|
||||
import numpy as np
|
||||
|
||||
from easycv.core.visualization import imshow_bboxes, imshow_keypoints
|
||||
from easycv.core.visualization import (imshow_bboxes, imshow_keypoints,
|
||||
imshow_label)
|
||||
|
||||
|
||||
class ImshowTest(unittest.TestCase):
|
||||
|
@ -32,7 +33,7 @@ class ImshowTest(unittest.TestCase):
|
|||
img = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
bboxes = np.array([[10, 10, 30, 30], [10, 50, 30, 80]],
|
||||
dtype=np.float32)
|
||||
labels = ['label 1', 'label 2']
|
||||
labels = ['标签 1', 'label 2']
|
||||
colors = ['red', 'green']
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
|
@ -65,6 +66,17 @@ class ImshowTest(unittest.TestCase):
|
|||
else:
|
||||
self.fail('ValueError not raised')
|
||||
|
||||
def test_imshow_label(self):
|
||||
img = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
labels = ['标签 1', 'label 2']
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
_ = imshow_label(
|
||||
img=img,
|
||||
labels=labels,
|
||||
show=False,
|
||||
out_file=f'{tmpdir}/out.png')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -13,7 +13,7 @@ class ClsDatasetTest(unittest.TestCase):
|
|||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
|
||||
def test_default(self):
|
||||
def _get_dataset(self):
|
||||
data_root = SMALL_IMAGENET_RAW_LOCAL
|
||||
data_train_list = os.path.join(data_root, 'meta/train_labeled_200.txt')
|
||||
pipeline = [
|
||||
|
@ -33,12 +33,36 @@ class ClsDatasetTest(unittest.TestCase):
|
|||
pipeline=pipeline))
|
||||
dataset = build_dataset(data['train'])
|
||||
|
||||
return dataset
|
||||
|
||||
def test_default(self):
|
||||
dataset = self._get_dataset()
|
||||
for _, batch in enumerate(dataset):
|
||||
img, target = batch['img'], batch['gt_labels']
|
||||
self.assertEqual(img.shape, torch.Size([3, 224, 224]))
|
||||
self.assertIn(target, list(range(1000)))
|
||||
break
|
||||
|
||||
def test_visualize(self):
|
||||
# TODO: add img_metas for classification
|
||||
return
|
||||
dataset = self._get_dataset()
|
||||
count = 5
|
||||
classes = []
|
||||
img_metas = []
|
||||
for i, data in enumerate(dataset):
|
||||
classes.append(data['gt_labels'])
|
||||
img_metas.append(data['img_metas'].data)
|
||||
if i > count:
|
||||
break
|
||||
|
||||
results = {'class': classes, 'img_metas': img_metas}
|
||||
output = dataset.visualize(results, vis_num=2)
|
||||
self.assertEqual(len(output['images']), 2)
|
||||
self.assertEqual(len(output['img_metas']), 2)
|
||||
self.assertEqual(len(output['images'][0].shape), 3)
|
||||
self.assertIn('filename', output['img_metas'][0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -3,6 +3,7 @@ import os
|
|||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from tests.ut_config import DET_DATA_RAW_LOCAL, IMG_NORM_CFG_255
|
||||
|
||||
from easycv.datasets.detection import DetDataset
|
||||
|
@ -13,7 +14,7 @@ class DetDatasetTest(unittest.TestCase):
|
|||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
|
||||
def test_load(self):
|
||||
def _get_dataset(self):
|
||||
img_scale = (640, 640)
|
||||
data_source_cfg = dict(
|
||||
type='DetSourceRaw',
|
||||
|
@ -33,6 +34,11 @@ class DetDatasetTest(unittest.TestCase):
|
|||
]
|
||||
|
||||
dataset = DetDataset(data_source=data_source_cfg, pipeline=pipeline)
|
||||
|
||||
return dataset
|
||||
|
||||
def test_load(self):
|
||||
dataset = self._get_dataset()
|
||||
data_num = len(dataset)
|
||||
s = time.time()
|
||||
for data in dataset:
|
||||
|
@ -47,6 +53,37 @@ class DetDatasetTest(unittest.TestCase):
|
|||
self.assertTrue('img_shape' in img_metas)
|
||||
self.assertTrue('ori_img_shape' in img_metas)
|
||||
|
||||
def test_visualize(self):
|
||||
dataset = self._get_dataset()
|
||||
count = 5
|
||||
detection_boxes = []
|
||||
detection_classes = []
|
||||
img_metas = []
|
||||
for i, data in enumerate(dataset):
|
||||
detection_boxes.append(
|
||||
data['gt_bboxes'].data.cpu().detach().numpy())
|
||||
detection_classes.append(
|
||||
data['gt_labels'].data.cpu().detach().numpy())
|
||||
img_metas.append(data['img_metas'].data)
|
||||
if i > count:
|
||||
break
|
||||
|
||||
detection_scores = []
|
||||
for classes in detection_classes:
|
||||
detection_scores.append(0.1 * np.array(range(len(classes))))
|
||||
|
||||
results = {
|
||||
'detection_boxes': detection_boxes,
|
||||
'detection_scores': detection_scores,
|
||||
'detection_classes': detection_classes,
|
||||
'img_metas': img_metas
|
||||
}
|
||||
output = dataset.visualize(results, vis_num=2, score_thr=0.1)
|
||||
self.assertEqual(len(output['images']), 2)
|
||||
self.assertEqual(len(output['img_metas']), 2)
|
||||
self.assertEqual(len(output['images'][0].shape), 3)
|
||||
self.assertIn('filename', output['img_metas'][0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -21,7 +21,7 @@ class SSLSourceImageNetFeatureTest(unittest.TestCase):
|
|||
index_list = random.choices(list(range(100)), k=3)
|
||||
for idx in index_list:
|
||||
results = data_source.get_sample(idx)
|
||||
feat = results['feature']
|
||||
feat = results['img']
|
||||
label = results['gt_labels']
|
||||
self.assertEqual(feat.shape, (2048, ))
|
||||
self.assertIn(label, list(range(1000)))
|
||||
|
@ -37,7 +37,7 @@ class SSLSourceImageNetFeatureTest(unittest.TestCase):
|
|||
index_list = random.choices(list(range(100)), k=3)
|
||||
for idx in index_list:
|
||||
results = data_source.get_sample(idx)
|
||||
feat = results['feature']
|
||||
feat = results['img']
|
||||
label = results['gt_labels']
|
||||
self.assertEqual(feat.shape, (2048, ))
|
||||
self.assertIn(label, list(range(1000)))
|
||||
|
|
Loading…
Reference in New Issue