[Feature]: support image visualization for tensorboard and wandb (#15)

* [Feature]: support image visualization for tensorboard and wandb
pull/21/head^2
Cathy0908 2022-04-21 20:48:58 +08:00 committed by GitHub
parent 3a4a3a9d0c
commit 9a3826f0d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 487 additions and 104 deletions

View File

@ -1,2 +1,3 @@
recursive-include easycv/configs *.py
recursive-include easycv/tools *.py
recursive-include easycv/resource/ *.ttf

View File

@ -23,4 +23,5 @@ data = dict(
dict(type='CenterCrop', size=224),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Collect', keys=['img', 'gt_labels'])
]))

View File

@ -130,7 +130,14 @@ custom_hooks = [
]
# evaluation
eval_config = dict(interval=10, gpu_collect=False)
eval_config = dict(
interval=10,
gpu_collect=False,
visualization_config=dict(
vis_num=10,
score_thr=0.5,
) # show by TensorboardLoggerHookV2 and WandbLoggerHookV2
)
eval_pipelines = [
dict(
mode='test',
@ -168,7 +175,8 @@ log_config = dict(
interval=100,
hooks=[
dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')
dict(type='TensorboardLoggerHookV2'),
# dict(type='WandbLoggerHookV2'),
])
export = dict(use_jit=False)

View File

@ -1,4 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .image import imshow_bboxes, imshow_keypoints
from .image import imshow_bboxes, imshow_keypoints, imshow_label
__all__ = ['imshow_bboxes', 'imshow_keypoints']
__all__ = ['imshow_bboxes', 'imshow_keypoints', 'imshow_label']

View File

@ -1,11 +1,103 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Adapt from https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/visualization/image.py
import math
import os
from os.path import dirname as opd
import cv2
import mmcv
import numpy as np
from mmcv.utils.misc import deprecated_api_warning
from PIL import Image, ImageDraw, ImageFont
def get_font_path():
root_path = opd(opd(opd(os.path.realpath(__file__))))
# find in whl
find_path_whl = os.path.join(root_path, 'resource/simhei.ttf')
# find in source code
find_path_source = os.path.join(opd(root_path), 'resource/simhei.ttf')
if os.path.exists(find_path_whl):
return find_path_whl
elif os.path.exists(find_path_source):
return find_path_source
else:
raise ValueError('Not find font file both in %s and %s' %
(find_path_whl, find_path_source))
_FONT_PATH = get_font_path()
def put_text(img, xy, text, fill, size=20):
"""support chinese text
"""
img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(img)
fontText = ImageFont.truetype(_FONT_PATH, size=size, encoding='utf-8')
draw.text(xy, text, fill=fill, font=fontText)
img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
return img
def imshow_label(img,
labels,
text_color='blue',
font_size=20,
thickness=1,
font_scale=0.5,
intervel=5,
show=True,
win_name='',
wait_time=0,
out_file=None):
"""Draw images with labels on an image.
Args:
img (str or ndarray): The image to be displayed.
labels (str or list[str]): labels of each image.
text_color (str or tuple or :obj:`Color`): Color of texts.
font_size (int): Size of font.
thickness (int): Thickness of lines.
font_scale (float): Font scales of texts.
intervelint): interval pixels between multiple labels
show (bool): Whether to show the image.
win_name (str): The window name.
wait_time (int): Value of waitKey param.
out_file (str, optional): The filename to write the image.
Returns:
ndarray: The image with bboxes drawn on it.
"""
img = mmcv.imread(img)
img = np.ascontiguousarray(img)
labels = [labels] if isinstance(labels, str) else labels
cur_height = 0
for label in labels:
# roughly estimate the proper font size
text_size, text_baseline = cv2.getTextSize(label,
cv2.FONT_HERSHEY_DUPLEX,
font_scale, thickness)
org = (text_baseline + text_size[1],
text_baseline + text_size[1] + cur_height)
# support chinese text
# TODO: Unify the font of cv2 and PIL, and auto get font_size according to the font_scale
img = put_text(img, org, text=label, fill=text_color, size=font_size)
# cv2.putText(img, label, org, cv2.FONT_HERSHEY_DUPLEX, font_scale,
# mmcv.color_val(text_color), thickness)
cur_height += text_baseline + text_size[1] + intervel
if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)
return img
def imshow_bboxes(img,
@ -13,6 +105,7 @@ def imshow_bboxes(img,
labels=None,
colors='green',
text_color='white',
font_size=20,
thickness=1,
font_scale=0.5,
show=True,
@ -29,6 +122,7 @@ def imshow_bboxes(img,
labels (str or list[str], optional): labels of each bbox.
colors (list[str or tuple or :obj:`Color`]): A list of colors.
text_color (str or tuple or :obj:`Color`): Color of texts.
font_size (int): Size of font.
thickness (int): Thickness of lines.
font_scale (float): Font scales of texts.
show (bool): Whether to show the image.
@ -58,11 +152,10 @@ def imshow_bboxes(img,
out_file=None)
if labels is not None:
if not isinstance(labels, list):
labels = [labels for _ in range(len(bboxes))]
assert len(labels) == len(bboxes)
for bbox, label, color in zip(bboxes, labels, colors):
label = str(label)
bbox_int = bbox[0, :4].astype(np.int32)
# roughly estimate the proper font size
text_size, text_baseline = cv2.getTextSize(label,
@ -74,9 +167,17 @@ def imshow_bboxes(img,
text_y2 = text_y1 + text_size[1] + text_baseline
cv2.rectangle(img, (text_x1, text_y1), (text_x2, text_y2), color,
cv2.FILLED)
cv2.putText(img, label, (text_x1, text_y2 - text_baseline),
cv2.FONT_HERSHEY_DUPLEX, font_scale,
mmcv.color_val(text_color), thickness)
# cv2.putText(img, label, (text_x1, text_y2 - text_baseline),
# cv2.FONT_HERSHEY_DUPLEX, font_scale,
# mmcv.color_val(text_color), thickness)
# support chinese text
# TODO: Unify the font of cv2 and PIL, and auto get font_size according to the font_scale
img = put_text(
img, (text_x1, text_y1),
text=label,
fill=text_color,
size=font_size)
if show:
mmcv.imshow(img, win_name, wait_time)

View File

@ -3,6 +3,7 @@
import torch
from PIL import Image
from easycv.core.visualization.image import imshow_label
from easycv.datasets.registry import DATASETS
from easycv.datasets.shared.base import BaseDataset
@ -59,3 +60,37 @@ class ClsDataset(BaseDataset):
eval_res = evaluators[0].evaluate(results, gt_labels)
return eval_res
def visualize(self, results, vis_num=10, **kwargs):
"""Visulaize the model output on validation data.
Args:
results: A dictionary containing
class: List of length number of test images.
img_metas: List of length number of test images,
dict of image meta info, containing filename, img_shape,
origin_img_shape and so on.
vis_num: number of images visualized
Returns: A dictionary containing
images: Visulaized images, list of np.ndarray.
img_metas: List of length number of test images,
dict of image meta info, containing filename, img_shape,
origin_img_shape and so on.
"""
vis_imgs = []
# TODO: support img_metas for torch.jit
if results.get('img_metas', None) is None:
return {}
img_metas = results['img_metas'][:vis_num]
labels = results['class']
for i, img_meta in enumerate(img_metas):
filename = img_meta['filename']
vis_img = imshow_label(img=filename, labels=labels, show=False)
vis_imgs.append(vis_img)
output = {'images': vis_imgs, 'img_metas': img_metas}
return output

View File

@ -9,14 +9,13 @@ import numpy as np
import torch
from easycv.datasets.registry import DATASETS, PIPELINES
from easycv.datasets.shared.base import BaseDataset
from easycv.utils import build_from_cfg
from easycv.utils.bbox_util import batched_xyxy2cxcywh_with_shape
from easycv.utils.bbox_util import xyxy2xywh as xyxy2cxcywh
from .raw import DetDataset
@DATASETS.register_module
class DetImagesMixDataset(BaseDataset):
class DetImagesMixDataset(DetDataset):
"""A wrapper of multiple images mixed dataset.
Suitable for training on multiple images mixed data augmentation like
@ -50,7 +49,7 @@ class DetImagesMixDataset(BaseDataset):
label_padding=True):
super(DetImagesMixDataset, self).__init__(
data_source, pipeline, profiling=profiling)
data_source, pipeline, profiling=profiling, classes=classes)
if skip_type_keys is not None:
assert all([
@ -70,10 +69,9 @@ class DetImagesMixDataset(BaseDataset):
else:
raise TypeError('pipeline must be a dict')
self.CLASSES = classes
if hasattr(self.data_source, 'flag'):
self.flag = self.data_source.flag
self.num_samples = self.data_source.get_length()
if dynamic_scale is not None:
assert isinstance(dynamic_scale, tuple)
@ -83,9 +81,6 @@ class DetImagesMixDataset(BaseDataset):
self.label_padding = label_padding
self.max_labels_num = 120
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
results = copy.deepcopy(self.data_source.get_sample(idx))
for (transform, transform_type) in zip(self.pipeline_yolox,
@ -116,21 +111,8 @@ class DetImagesMixDataset(BaseDataset):
if 'img_scale' in results:
results.pop('img_scale')
# print(result.keys())
# if self.yolo_format:
# # print(type(results['img_metas']), results['img_metas'])
# # print(type(results['img_metas']._data), results['img_metas']._data)
# img_shape = results['img_metas']._data['img_shape'][:2]
# # print(type(results['gt_bboxes']))
# gt_bboxes = xyxy2cxcywh_with_shape(results['gt_bboxes']._data, img_shape)
# results['gt_bboxes'] = gt_bboxes.float()
if self.label_padding:
cxcywh_gt_bboxes = xyxy2cxcywh(results['gt_bboxes']._data)
# cxcywh_gt_bboxes = results['gt_bboxes']._data
padded_gt_bboxes = torch.zeros((self.max_labels_num, 4),
device=cxcywh_gt_bboxes.device)
padded_gt_bboxes[range(cxcywh_gt_bboxes.shape[0])[:self.max_labels_num]] = \
@ -146,9 +128,6 @@ class DetImagesMixDataset(BaseDataset):
results['gt_bboxes'] = padded_gt_bboxes
results['gt_labels'] = padded_labels
# ['img_metas', 'img', 'gt_bboxes', 'gt_labels']
# results.pop('img_metas')
# print(results['img_metas'], "hhh", idx)
return results
def update_skip_type_keys(self, skip_type_keys):
@ -240,30 +219,3 @@ class DetImagesMixDataset(BaseDataset):
tmp_dir = None
result_files = self.results2json(results, jsonfile_prefix)
return result_files, tmp_dir
def evaluate(self, results, evaluators=None, logger=None):
'''results: a dict of list of Tensors, list length equals to number of test images
'''
eval_result = dict()
groundtruth_dict = {}
groundtruth_dict['groundtruth_boxes'] = [
batched_xyxy2cxcywh_with_shape(
self.data_source.get_ann_info(idx)['bboxes'],
results['img_metas'][idx]['ori_img_shape'])
for idx in range(len(results['img_metas']))
]
groundtruth_dict['groundtruth_classes'] = [
self.data_source.get_ann_info(idx)['labels']
for idx in range(len(results['img_metas']))
]
groundtruth_dict['groundtruth_is_crowd'] = [
self.data_source.get_ann_info(idx)['groundtruth_is_crowd']
for idx in range(len(results['img_metas']))
]
for evaluator in evaluators:
eval_result.update(evaluator.evaluate(results, groundtruth_dict))
# print(eval_result)
return eval_result

View File

@ -1,8 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
from easycv.core.visualization.image import imshow_bboxes
from easycv.datasets.registry import DATASETS
from easycv.datasets.shared.base import BaseDataset
from easycv.utils.bbox_util import batched_xyxy2cxcywh_with_shape
@DATASETS.register_module
@ -19,58 +21,124 @@ class DetDataset(BaseDataset):
classes: A list of class names, used in evaluation for result and groundtruth visualization
"""
self.classes = classes
self.CLASSES = classes
super(DetDataset, self).__init__(
data_source, pipeline, profiling=profiling)
self.img_num = self.data_source.get_length()
self.num_samples = self.data_source.get_length()
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
data_dict = self.data_source.get_sample(idx)
data_dict = self.pipeline(data_dict)
return data_dict
def evaluate(self, results, evaluators, logger=None):
'''results: a dict of list of Tensors, list length equals to number of test images
def evaluate(self, results, evaluators=None, logger=None):
'''Evaluates the detection boxes.
Args:
results: A dictionary containing
detection_boxes: List of length number of test images.
Float32 numpy array of shape [num_boxes, 4] and
format [ymin, xmin, ymax, xmax] in absolute image coordinates.
detection_scores: List of length number of test images,
detection scores for the boxes, float32 numpy array of shape [num_boxes].
detection_classes: List of length number of test images,
integer numpy array of shape [num_boxes]
containing 1-indexed detection classes for the boxes.
img_metas: List of length number of test images,
dict of image meta info, containing filename, img_shape,
origin_img_shape, scale_factor and so on.
evaluators: evaluators to calculate metric with results and groundtruth_dict
'''
eval_result = dict()
annotations = self.data_source.get_labels()
groundtruth_dict = {}
groundtruth_dict['groundtruth_boxes'] = [
labels[:,
1:] if len(labels) > 0 else np.array([], dtype=np.float32)
for labels in annotations
batched_xyxy2cxcywh_with_shape(
self.data_source.get_ann_info(idx)['bboxes'],
results['img_metas'][idx]['ori_img_shape'])
for idx in range(len(results['img_metas']))
]
groundtruth_dict['groundtruth_classes'] = [
labels[:, 0] if len(labels) > 0 else np.array([], dtype=np.float32)
for labels in annotations
self.data_source.get_ann_info(idx)['labels']
for idx in range(len(results['img_metas']))
]
# bboxes = [label[:, 1:] for label in annotations]
# scores = [label[:, 0] for label in annotations]
groundtruth_dict['groundtruth_is_crowd'] = [
self.data_source.get_ann_info(idx)['groundtruth_is_crowd']
for idx in range(len(results['img_metas']))
]
for evaluator in evaluators:
eval_result.update(evaluator.evaluate(results, groundtruth_dict))
# eval_res = {'dummy': 1.0}
# img = self.data_source.load_ori_img(0)
# num_box = results['detection_scores'][0].size(0)
# scores = results['detection_scores'][0].detach().cpu().numpy()
# bboxes = torch.cat((results['detection_boxes'][0], results['detection_scores'][0].view(num_box, 1)), axis=1).detach().cpu().numpy()
# labels = results['detection_classes'][0].detach().cpu().numpy().astype(np.int32)
# # draw bounding boxes
# score_th = 0.3
# indices = scores > score_th
# filter_labels = labels[indices]
# print([(self.classes[i], score) for i, score in zip(filter_labels, scores)])
# mmcv.imshow_det_bboxes(
# img,
# bboxes,
# labels,
# class_names=self.classes,
# score_thr=score_th,
# bbox_color='red',
# text_color='black',
# thickness=1,
# font_scale=0.5,
# show=False,
# wait_time=0,
# out_file='test.jpg')
return eval_result
def visualize(self, results, vis_num=10, score_thr=0.3, **kwargs):
"""Visulaize the model output on validation data.
Args:
results: A dictionary containing
detection_boxes: List of length number of test images.
Float32 numpy array of shape [num_boxes, 4] and
format [ymin, xmin, ymax, xmax] in absolute image coordinates.
detection_scores: List of length number of test images,
detection scores for the boxes, float32 numpy array of shape [num_boxes].
detection_classes: List of length number of test images,
integer numpy array of shape [num_boxes]
containing 1-indexed detection classes for the boxes.
img_metas: List of length number of test images,
dict of image meta info, containing filename, img_shape,
origin_img_shape, scale_factor and so on.
vis_num: number of images visualized
score_thr: The threshold to filter box,
boxes with scores greater than score_thr will be kept.
Returns: A dictionary containing
images: Visulaized images.
img_metas: List of length number of test images,
dict of image meta info, containing filename, img_shape,
origin_img_shape, scale_factor and so on.
"""
class_names = None
if hasattr(self.data_source, 'CLASSES'):
class_names = self.data_source.CLASSES
elif hasattr(self.data_source, 'classes'):
class_names = self.data_source.classes
if class_names is not None:
detection_classes = []
for classes_id in results['detection_classes']:
if classes_id is None:
detection_classes.append(None)
else:
detection_classes.append(
np.array([class_names[id] for id in classes_id]))
results['detection_classes'] = detection_classes
vis_imgs = []
img_metas = results['img_metas'][:vis_num]
detection_boxes = results.get('detection_boxes', [])
detection_scores = results.get('detection_scores', [])
detection_classes = results.get('detection_classes', [])
for i, img_meta in enumerate(img_metas):
filename = img_meta['filename']
bboxes = np.array(
[]) if detection_boxes[i] is None else detection_boxes[i]
scores = detection_scores[i]
classes = detection_classes[i]
if scores is not None and score_thr > 0:
inds = scores > score_thr
bboxes = bboxes[inds]
classes = classes[inds]
vis_img = imshow_bboxes(
img=filename, bboxes=bboxes, labels=classes, show=False)
vis_imgs.append(vis_img)
output = {'images': vis_imgs, 'img_metas': img_metas}
return output

View File

@ -28,3 +28,14 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
@abstractmethod
def evaluate(self, results, evaluators, logger=None, **kwargs):
pass
def visualize(self, results, **kwargs):
"""Visulaize the model output results on validation data.
Returns: A dictionary
If add image visualization, return dict containing
images: List of visulaized images.
img_metas: List of length number of test images,
dict of image meta info, containing filename, img_shape,
origin_img_shape, scale_factor and so on.
"""
return {}

View File

@ -18,6 +18,8 @@ from .show_time_hook import TIMEHook
from .swav_hook import SWAVHook
from .sync_norm_hook import SyncNormHook
from .sync_random_size_hook import SyncRandomSizeHook
from .tensorboard import TensorboardLoggerHookV2
from .wandb import WandbLoggerHookV2
from .yolox_lr_hook import YOLOXLrUpdaterHook
from .yolox_mode_switch_hook import YOLOXModeSwitchHook

View File

@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from collections import OrderedDict
import torch
from mmcv.runner import Hook
@ -45,6 +46,9 @@ class EvalHook(Hook):
self.mode = mode
self.eval_kwargs = eval_kwargs
# hook.evaluate runs every interval epoch or iter, popped at init
self.vis_config = self.eval_kwargs.pop('visualization_config', {})
self.gpu_collect = self.eval_kwargs.pop('gpu_collect', None)
self.flush_buffer = flush_buffer
def before_run(self, runner):
@ -68,9 +72,23 @@ class EvalHook(Hook):
runner.model, self.dataloader, mode=self.mode, show=False)
self.evaluate(runner, results)
def evaluate(self, runner, results):
def add_visualization_info(self, runner, results):
if runner.visualization_buffer.output.get('eval_results',
None) is None:
runner.visualization_buffer.output['eval_results'] = OrderedDict()
if isinstance(self.dataloader, DataLoader):
dataset_obj = self.dataloader.dataset
else:
dataset_obj = self.dataloader
if hasattr(dataset_obj, 'visualize'):
runner.visualization_buffer.output['eval_results'].update(
dataset_obj.visualize(results, **self.vis_config))
def evaluate(self, runner, results):
self.add_visualization_info(runner, results)
gpu_collect = self.eval_kwargs.pop('gpu_collect', None)
if isinstance(self.dataloader, DataLoader):
eval_res = self.dataloader.dataset.evaluate(
results, logger=runner.logger, **self.eval_kwargs)

View File

@ -0,0 +1,55 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
from mmcv.runner.dist_utils import master_only
from mmcv.runner.hooks import HOOKS
from mmcv.runner.hooks import TensorboardLoggerHook as _TensorboardLoggerHook
@HOOKS.register_module()
class TensorboardLoggerHookV2(_TensorboardLoggerHook):
def visualization_log(self, runner):
"""Images Visulization.
`visualization_buffer` is a dictionary containing:
images (list): list of visulaized images.
img_metas (list of dict, optional): dict containing ori_filename and so on.
ori_filename will be displayed as the tag of the image by default.
"""
visual_results = runner.visualization_buffer.output
for vis_key, vis_result in visual_results.items():
images = vis_result.get('images', [])
img_metas = vis_result.get('img_metas', None)
if img_metas is not None:
assert len(images) == len(
img_metas
), 'Output `images` and `img_metas` must keep the same length!'
for i, img in enumerate(images):
if isinstance(img, np.ndarray):
img = torch.from_numpy(img)
else:
assert isinstance(
img, torch.Tensor
), 'Only support np.ndarray and torch.Tensor type!'
default_name = 'image_%i' % i
filename = img_metas[i].get(
'ori_filename',
default_name) if img_metas is not None else default_name
self.writer.add_image(
f'{vis_key}/{filename}',
img,
self.get_iter(runner),
dataformats='HWC')
@master_only
def log(self, runner):
self.visualization_log(runner)
super(TensorboardLoggerHookV2, self).log(runner)
def after_train_iter(self, runner):
super(TensorboardLoggerHookV2, self).after_train_iter(runner)
# clear visualization_buffer after each iter to ensure that it is only written once,
# avoiding repeated writing of the same image buffer every self.interval
runner.visualization_buffer.clear_output()

View File

@ -0,0 +1,54 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import numpy as np
from mmcv.runner.dist_utils import master_only
from mmcv.runner.hooks import HOOKS
from mmcv.runner.hooks import WandbLoggerHook as _WandbLoggerHook
from PIL import Image as PILImage
@HOOKS.register_module()
class WandbLoggerHookV2(_WandbLoggerHook):
def visualization_log(self, runner):
"""Images Visulization.
`visualization_buffer` is a dictionary containing:
images (list): list of visulaized images.
img_metas (list of dict, optional): dict containing ori_filename and so on.
ori_filename will be displayed as the tag of the image by default.
"""
visual_results = runner.visualization_buffer.output
for vis_key, vis_result in visual_results.items():
images = vis_result.get('images', [])
img_metas = vis_result.get('img_metas', None)
if img_metas is not None:
assert len(images) == len(
img_metas
), 'Output `images` and `img_metas` must keep the same length!'
examples = []
for i, img in enumerate(images):
assert isinstance(img, np.ndarray)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
pil_image = PILImage.fromarray(img, mode='RGB')
default_name = 'image_%i' % i
filename = img_metas[i].get(
'ori_filename',
default_name) if img_metas is not None else default_name
image = self.wandb.Image(pil_image, caption=filename)
examples.append(image)
self.wandb.log({vis_key: examples},
step=self.get_iter(runner),
commit=self.commit)
@master_only
def log(self, runner):
self.visualization_log(runner)
super(WandbLoggerHookV2, self).log(runner)
def after_train_iter(self, runner):
super(WandbLoggerHookV2, self).after_train_iter(runner)
# clear visualization_buffer after each iter to ensure that it is only written once,
# avoiding repeated writing of the same image buffer every self.interval
runner.visualization_buffer.clear_output()

View File

@ -5,6 +5,7 @@ from distutils.version import LooseVersion
import torch
from mmcv.runner import EpochBasedRunner
from mmcv.runner.log_buffer import LogBuffer
from easycv.file import io
from easycv.utils.checkpoint import load_checkpoint, save_checkpoint
@ -47,6 +48,7 @@ class EVRunner(EpochBasedRunner):
meta)
self.data_loader = None
self.fp16_enable = False
self.visualization_buffer = LogBuffer()
def run_iter(self, data_batch, train_mode, **kwargs):
""" process for each iteration.

BIN
resource/simhei.ttf 100644

Binary file not shown.

View File

@ -159,6 +159,7 @@ def pack_resource():
'./thirdparty',
proj_dir + 'thirdparty',
ignore=shutil.ignore_patterns('*.pyc'))
shutil.copytree('./resource', proj_dir + 'resource')
shutil.copytree('./requirements', 'package/requirements')
shutil.copy('./requirements.txt', 'package/requirements.txt')
shutil.copy('./MANIFEST.in', 'package/MANIFEST.in')
@ -181,6 +182,7 @@ if __name__ == '__main__':
keywords='self-supvervised, classification, vision',
url='https://github.com/alibaba/EasyCV.git',
packages=find_packages(exclude=('configs', 'tools', 'demo')),
include_package_data=True,
classifiers=[
'Development Status :: 4 - Beta',
'License :: OSI Approved :: Apache Software License',

View File

@ -4,7 +4,8 @@ import unittest
import numpy as np
from easycv.core.visualization import imshow_bboxes, imshow_keypoints
from easycv.core.visualization import (imshow_bboxes, imshow_keypoints,
imshow_label)
class ImshowTest(unittest.TestCase):
@ -32,7 +33,7 @@ class ImshowTest(unittest.TestCase):
img = np.zeros((100, 100, 3), dtype=np.uint8)
bboxes = np.array([[10, 10, 30, 30], [10, 50, 30, 80]],
dtype=np.float32)
labels = ['label 1', 'label 2']
labels = ['标签 1', 'label 2']
colors = ['red', 'green']
with tempfile.TemporaryDirectory() as tmpdir:
@ -65,6 +66,17 @@ class ImshowTest(unittest.TestCase):
else:
self.fail('ValueError not raised')
def test_imshow_label(self):
img = np.zeros((100, 100, 3), dtype=np.uint8)
labels = ['标签 1', 'label 2']
with tempfile.TemporaryDirectory() as tmpdir:
_ = imshow_label(
img=img,
labels=labels,
show=False,
out_file=f'{tmpdir}/out.png')
if __name__ == '__main__':
unittest.main()

View File

@ -13,7 +13,7 @@ class ClsDatasetTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_default(self):
def _get_dataset(self):
data_root = SMALL_IMAGENET_RAW_LOCAL
data_train_list = os.path.join(data_root, 'meta/train_labeled_200.txt')
pipeline = [
@ -33,12 +33,36 @@ class ClsDatasetTest(unittest.TestCase):
pipeline=pipeline))
dataset = build_dataset(data['train'])
return dataset
def test_default(self):
dataset = self._get_dataset()
for _, batch in enumerate(dataset):
img, target = batch['img'], batch['gt_labels']
self.assertEqual(img.shape, torch.Size([3, 224, 224]))
self.assertIn(target, list(range(1000)))
break
def test_visualize(self):
# TODO: add img_metas for classification
return
dataset = self._get_dataset()
count = 5
classes = []
img_metas = []
for i, data in enumerate(dataset):
classes.append(data['gt_labels'])
img_metas.append(data['img_metas'].data)
if i > count:
break
results = {'class': classes, 'img_metas': img_metas}
output = dataset.visualize(results, vis_num=2)
self.assertEqual(len(output['images']), 2)
self.assertEqual(len(output['img_metas']), 2)
self.assertEqual(len(output['images'][0].shape), 3)
self.assertIn('filename', output['img_metas'][0])
if __name__ == '__main__':
unittest.main()

View File

@ -3,6 +3,7 @@ import os
import time
import unittest
import numpy as np
from tests.ut_config import DET_DATA_RAW_LOCAL, IMG_NORM_CFG_255
from easycv.datasets.detection import DetDataset
@ -13,7 +14,7 @@ class DetDatasetTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_load(self):
def _get_dataset(self):
img_scale = (640, 640)
data_source_cfg = dict(
type='DetSourceRaw',
@ -33,6 +34,11 @@ class DetDatasetTest(unittest.TestCase):
]
dataset = DetDataset(data_source=data_source_cfg, pipeline=pipeline)
return dataset
def test_load(self):
dataset = self._get_dataset()
data_num = len(dataset)
s = time.time()
for data in dataset:
@ -47,6 +53,37 @@ class DetDatasetTest(unittest.TestCase):
self.assertTrue('img_shape' in img_metas)
self.assertTrue('ori_img_shape' in img_metas)
def test_visualize(self):
dataset = self._get_dataset()
count = 5
detection_boxes = []
detection_classes = []
img_metas = []
for i, data in enumerate(dataset):
detection_boxes.append(
data['gt_bboxes'].data.cpu().detach().numpy())
detection_classes.append(
data['gt_labels'].data.cpu().detach().numpy())
img_metas.append(data['img_metas'].data)
if i > count:
break
detection_scores = []
for classes in detection_classes:
detection_scores.append(0.1 * np.array(range(len(classes))))
results = {
'detection_boxes': detection_boxes,
'detection_scores': detection_scores,
'detection_classes': detection_classes,
'img_metas': img_metas
}
output = dataset.visualize(results, vis_num=2, score_thr=0.1)
self.assertEqual(len(output['images']), 2)
self.assertEqual(len(output['img_metas']), 2)
self.assertEqual(len(output['images'][0].shape), 3)
self.assertIn('filename', output['img_metas'][0])
if __name__ == '__main__':
unittest.main()

View File

@ -21,7 +21,7 @@ class SSLSourceImageNetFeatureTest(unittest.TestCase):
index_list = random.choices(list(range(100)), k=3)
for idx in index_list:
results = data_source.get_sample(idx)
feat = results['feature']
feat = results['img']
label = results['gt_labels']
self.assertEqual(feat.shape, (2048, ))
self.assertIn(label, list(range(1000)))
@ -37,7 +37,7 @@ class SSLSourceImageNetFeatureTest(unittest.TestCase):
index_list = random.choices(list(range(100)), k=3)
for idx in index_list:
results = data_source.get_sample(idx)
feat = results['feature']
feat = results['img']
label = results['gt_labels']
self.assertEqual(feat.shape, (2048, ))
self.assertIn(label, list(range(1000)))