mirror of https://github.com/alibaba/EasyCV.git
[Feature]: support image visualization for tensorboard and wandb (#15)
* [Feature]: support image visualization for tensorboard and wandb
pull/21/head^2
parent
3a4a3a9d0c
commit
9a3826f0d2
|
@ -1,2 +1,3 @@
|
||||||
recursive-include easycv/configs *.py
|
recursive-include easycv/configs *.py
|
||||||
recursive-include easycv/tools *.py
|
recursive-include easycv/tools *.py
|
||||||
|
recursive-include easycv/resource/ *.ttf
|
||||||
|
|
|
@ -23,4 +23,5 @@ data = dict(
|
||||||
dict(type='CenterCrop', size=224),
|
dict(type='CenterCrop', size=224),
|
||||||
dict(type='ToTensor'),
|
dict(type='ToTensor'),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
|
dict(type='Collect', keys=['img', 'gt_labels'])
|
||||||
]))
|
]))
|
||||||
|
|
|
@ -130,7 +130,14 @@ custom_hooks = [
|
||||||
]
|
]
|
||||||
|
|
||||||
# evaluation
|
# evaluation
|
||||||
eval_config = dict(interval=10, gpu_collect=False)
|
eval_config = dict(
|
||||||
|
interval=10,
|
||||||
|
gpu_collect=False,
|
||||||
|
visualization_config=dict(
|
||||||
|
vis_num=10,
|
||||||
|
score_thr=0.5,
|
||||||
|
) # show by TensorboardLoggerHookV2 and WandbLoggerHookV2
|
||||||
|
)
|
||||||
eval_pipelines = [
|
eval_pipelines = [
|
||||||
dict(
|
dict(
|
||||||
mode='test',
|
mode='test',
|
||||||
|
@ -168,7 +175,8 @@ log_config = dict(
|
||||||
interval=100,
|
interval=100,
|
||||||
hooks=[
|
hooks=[
|
||||||
dict(type='TextLoggerHook'),
|
dict(type='TextLoggerHook'),
|
||||||
dict(type='TensorboardLoggerHook')
|
dict(type='TensorboardLoggerHookV2'),
|
||||||
|
# dict(type='WandbLoggerHookV2'),
|
||||||
])
|
])
|
||||||
|
|
||||||
export = dict(use_jit=False)
|
export = dict(use_jit=False)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
from .image import imshow_bboxes, imshow_keypoints
|
from .image import imshow_bboxes, imshow_keypoints, imshow_label
|
||||||
|
|
||||||
__all__ = ['imshow_bboxes', 'imshow_keypoints']
|
__all__ = ['imshow_bboxes', 'imshow_keypoints', 'imshow_label']
|
||||||
|
|
|
@ -1,11 +1,103 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
# Adapt from https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/visualization/image.py
|
# Adapt from https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/visualization/image.py
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
|
from os.path import dirname as opd
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mmcv.utils.misc import deprecated_api_warning
|
from mmcv.utils.misc import deprecated_api_warning
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
|
||||||
|
def get_font_path():
    """Locate the ``simhei.ttf`` font file shipped with the package.

    Two locations are probed, in order:

    1. ``<root>/resource/simhei.ttf`` where ``<root>`` is three directory
       levels above this file (layout of an installed wheel).
    2. ``<parent of root>/resource/simhei.ttf`` (layout of a source checkout).

    Returns:
        str: Path of the first candidate that exists.

    Raises:
        ValueError: If neither candidate exists.
    """
    font_rel_path = 'resource/simhei.ttf'
    pkg_root = opd(opd(opd(os.path.realpath(__file__))))

    whl_candidate = os.path.join(pkg_root, font_rel_path)
    source_candidate = os.path.join(opd(pkg_root), font_rel_path)

    # The wheel location takes precedence over the source-tree location.
    for candidate in (whl_candidate, source_candidate):
        if os.path.exists(candidate):
            return candidate

    raise ValueError('Not find font file both in %s and %s' %
                     (whl_candidate, source_candidate))
|
||||||
|
|
||||||
|
|
||||||
|
_FONT_PATH = get_font_path()
|
||||||
|
|
||||||
|
|
||||||
|
def put_text(img, xy, text, fill, size=20):
    """Draw ``text`` on a BGR image using a TrueType font.

    PIL is used instead of ``cv2.putText`` so that Chinese text can be
    rendered with the font at ``_FONT_PATH``.

    Args:
        img (ndarray): Image in BGR channel order.
        xy (tuple): Position handed to ``ImageDraw.text``.
        text (str): Text to draw.
        fill: Text color handed to ``ImageDraw.text``.
        size (int): Font size.

    Returns:
        ndarray: A new BGR image with the text drawn on it.
    """
    # PIL works in RGB, so convert on the way in and back on the way out.
    rgb_canvas = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    painter = ImageDraw.Draw(rgb_canvas)
    font = ImageFont.truetype(_FONT_PATH, size=size, encoding='utf-8')
    painter.text(xy, text, fill=fill, font=font)
    return cv2.cvtColor(np.asarray(rgb_canvas), cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
|
||||||
|
def imshow_label(img,
                 labels,
                 text_color='blue',
                 font_size=20,
                 thickness=1,
                 font_scale=0.5,
                 intervel=5,
                 show=True,
                 win_name='',
                 wait_time=0,
                 out_file=None):
    """Draw one or more text labels on an image, stacked vertically.

    Args:
        img (str or ndarray): The image to be displayed.
        labels (str or scalar or list): Label(s) of the image. A single
            string or scalar is treated as one label. Every label is
            converted with ``str()`` before drawing, so numeric class ids
            are accepted. ``None`` draws nothing.
        text_color (str or tuple or :obj:`Color`): Color of texts.
        font_size (int): Size of the font used for drawing.
        thickness (int): Thickness passed to ``cv2.getTextSize``; only used
            to estimate the height of a text line.
        font_scale (float): Font scale passed to ``cv2.getTextSize``; only
            used to estimate the height of a text line.
        intervel (int): Interval in pixels between consecutive labels.
        show (bool): Whether to show the image.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
        out_file (str, optional): The filename to write the image.

    Returns:
        ndarray: The image with the labels drawn on it.
    """
    img = mmcv.imread(img)
    img = np.ascontiguousarray(img)

    if labels is None:
        labels = []
    elif isinstance(labels, str) or np.isscalar(labels):
        # A single label (text or a numeric class id) becomes a one-item list.
        labels = [labels]

    cur_height = 0
    for label in labels:
        # cv2.getTextSize and the PIL drawing call both need text input.
        label = str(label)

        # roughly estimate the line height from the cv2 font metrics
        text_size, text_baseline = cv2.getTextSize(label,
                                                   cv2.FONT_HERSHEY_DUPLEX,
                                                   font_scale, thickness)
        line_height = text_baseline + text_size[1]
        org = (line_height, line_height + cur_height)

        # support chinese text
        # TODO: Unify the font of cv2 and PIL, and auto get font_size according to the font_scale
        img = put_text(img, org, text=label, fill=text_color, size=font_size)

        cur_height += line_height + intervel

    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    return img
|
||||||
|
|
||||||
|
|
||||||
def imshow_bboxes(img,
|
def imshow_bboxes(img,
|
||||||
|
@ -13,6 +105,7 @@ def imshow_bboxes(img,
|
||||||
labels=None,
|
labels=None,
|
||||||
colors='green',
|
colors='green',
|
||||||
text_color='white',
|
text_color='white',
|
||||||
|
font_size=20,
|
||||||
thickness=1,
|
thickness=1,
|
||||||
font_scale=0.5,
|
font_scale=0.5,
|
||||||
show=True,
|
show=True,
|
||||||
|
@ -29,6 +122,7 @@ def imshow_bboxes(img,
|
||||||
labels (str or list[str], optional): labels of each bbox.
|
labels (str or list[str], optional): labels of each bbox.
|
||||||
colors (list[str or tuple or :obj:`Color`]): A list of colors.
|
colors (list[str or tuple or :obj:`Color`]): A list of colors.
|
||||||
text_color (str or tuple or :obj:`Color`): Color of texts.
|
text_color (str or tuple or :obj:`Color`): Color of texts.
|
||||||
|
font_size (int): Size of font.
|
||||||
thickness (int): Thickness of lines.
|
thickness (int): Thickness of lines.
|
||||||
font_scale (float): Font scales of texts.
|
font_scale (float): Font scales of texts.
|
||||||
show (bool): Whether to show the image.
|
show (bool): Whether to show the image.
|
||||||
|
@ -58,11 +152,10 @@ def imshow_bboxes(img,
|
||||||
out_file=None)
|
out_file=None)
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if not isinstance(labels, list):
|
|
||||||
labels = [labels for _ in range(len(bboxes))]
|
|
||||||
assert len(labels) == len(bboxes)
|
assert len(labels) == len(bboxes)
|
||||||
|
|
||||||
for bbox, label, color in zip(bboxes, labels, colors):
|
for bbox, label, color in zip(bboxes, labels, colors):
|
||||||
|
label = str(label)
|
||||||
bbox_int = bbox[0, :4].astype(np.int32)
|
bbox_int = bbox[0, :4].astype(np.int32)
|
||||||
# roughly estimate the proper font size
|
# roughly estimate the proper font size
|
||||||
text_size, text_baseline = cv2.getTextSize(label,
|
text_size, text_baseline = cv2.getTextSize(label,
|
||||||
|
@ -74,9 +167,17 @@ def imshow_bboxes(img,
|
||||||
text_y2 = text_y1 + text_size[1] + text_baseline
|
text_y2 = text_y1 + text_size[1] + text_baseline
|
||||||
cv2.rectangle(img, (text_x1, text_y1), (text_x2, text_y2), color,
|
cv2.rectangle(img, (text_x1, text_y1), (text_x2, text_y2), color,
|
||||||
cv2.FILLED)
|
cv2.FILLED)
|
||||||
cv2.putText(img, label, (text_x1, text_y2 - text_baseline),
|
# cv2.putText(img, label, (text_x1, text_y2 - text_baseline),
|
||||||
cv2.FONT_HERSHEY_DUPLEX, font_scale,
|
# cv2.FONT_HERSHEY_DUPLEX, font_scale,
|
||||||
mmcv.color_val(text_color), thickness)
|
# mmcv.color_val(text_color), thickness)
|
||||||
|
|
||||||
|
# support chinese text
|
||||||
|
# TODO: Unify the font of cv2 and PIL, and auto get font_size according to the font_scale
|
||||||
|
img = put_text(
|
||||||
|
img, (text_x1, text_y1),
|
||||||
|
text=label,
|
||||||
|
fill=text_color,
|
||||||
|
size=font_size)
|
||||||
|
|
||||||
if show:
|
if show:
|
||||||
mmcv.imshow(img, win_name, wait_time)
|
mmcv.imshow(img, win_name, wait_time)
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from easycv.core.visualization.image import imshow_label
|
||||||
from easycv.datasets.registry import DATASETS
|
from easycv.datasets.registry import DATASETS
|
||||||
from easycv.datasets.shared.base import BaseDataset
|
from easycv.datasets.shared.base import BaseDataset
|
||||||
|
|
||||||
|
@ -59,3 +60,37 @@ class ClsDataset(BaseDataset):
|
||||||
eval_res = evaluators[0].evaluate(results, gt_labels)
|
eval_res = evaluators[0].evaluate(results, gt_labels)
|
||||||
|
|
||||||
return eval_res
|
return eval_res
|
||||||
|
|
||||||
|
def visualize(self, results, vis_num=10, **kwargs):
    """Visualize the model output on validation data.

    Args:
        results: A dictionary containing
            class: List of length number of test images, the predicted
                class of each image.
            img_metas: List of length number of test images,
                dict of image meta info, containing filename, img_shape,
                origin_img_shape and so on.
        vis_num: number of images visualized

    Returns: A dictionary containing
        images: Visualized images, list of np.ndarray.
        img_metas: The image meta dicts of the visualized images
            (the first ``vis_num`` entries of ``results['img_metas']``).
        An empty dict is returned when ``results`` has no ``img_metas``.
    """
    # TODO: support img_metas for torch.jit
    if results.get('img_metas', None) is None:
        return {}

    img_metas = results['img_metas'][:vis_num]
    labels = results['class']

    vis_imgs = []
    for i, img_meta in enumerate(img_metas):
        filename = img_meta['filename']

        # Draw only the prediction that belongs to this image, not the
        # predictions of the whole validation set.
        label = labels[i]
        if hasattr(label, 'tolist'):
            # tensor / ndarray -> plain Python scalar or list
            label = label.tolist()
        if isinstance(label, (list, tuple)):
            label = [str(item) for item in label]
        else:
            label = str(label)

        vis_img = imshow_label(img=filename, labels=label, show=False)
        vis_imgs.append(vis_img)

    return {'images': vis_imgs, 'img_metas': img_metas}
|
||||||
|
|
|
@ -9,14 +9,13 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from easycv.datasets.registry import DATASETS, PIPELINES
|
from easycv.datasets.registry import DATASETS, PIPELINES
|
||||||
from easycv.datasets.shared.base import BaseDataset
|
|
||||||
from easycv.utils import build_from_cfg
|
from easycv.utils import build_from_cfg
|
||||||
from easycv.utils.bbox_util import batched_xyxy2cxcywh_with_shape
|
|
||||||
from easycv.utils.bbox_util import xyxy2xywh as xyxy2cxcywh
|
from easycv.utils.bbox_util import xyxy2xywh as xyxy2cxcywh
|
||||||
|
from .raw import DetDataset
|
||||||
|
|
||||||
|
|
||||||
@DATASETS.register_module
|
@DATASETS.register_module
|
||||||
class DetImagesMixDataset(BaseDataset):
|
class DetImagesMixDataset(DetDataset):
|
||||||
"""A wrapper of multiple images mixed dataset.
|
"""A wrapper of multiple images mixed dataset.
|
||||||
|
|
||||||
Suitable for training on multiple images mixed data augmentation like
|
Suitable for training on multiple images mixed data augmentation like
|
||||||
|
@ -50,7 +49,7 @@ class DetImagesMixDataset(BaseDataset):
|
||||||
label_padding=True):
|
label_padding=True):
|
||||||
|
|
||||||
super(DetImagesMixDataset, self).__init__(
|
super(DetImagesMixDataset, self).__init__(
|
||||||
data_source, pipeline, profiling=profiling)
|
data_source, pipeline, profiling=profiling, classes=classes)
|
||||||
|
|
||||||
if skip_type_keys is not None:
|
if skip_type_keys is not None:
|
||||||
assert all([
|
assert all([
|
||||||
|
@ -70,10 +69,9 @@ class DetImagesMixDataset(BaseDataset):
|
||||||
else:
|
else:
|
||||||
raise TypeError('pipeline must be a dict')
|
raise TypeError('pipeline must be a dict')
|
||||||
|
|
||||||
self.CLASSES = classes
|
|
||||||
if hasattr(self.data_source, 'flag'):
|
if hasattr(self.data_source, 'flag'):
|
||||||
self.flag = self.data_source.flag
|
self.flag = self.data_source.flag
|
||||||
self.num_samples = self.data_source.get_length()
|
|
||||||
if dynamic_scale is not None:
|
if dynamic_scale is not None:
|
||||||
assert isinstance(dynamic_scale, tuple)
|
assert isinstance(dynamic_scale, tuple)
|
||||||
|
|
||||||
|
@ -83,9 +81,6 @@ class DetImagesMixDataset(BaseDataset):
|
||||||
self.label_padding = label_padding
|
self.label_padding = label_padding
|
||||||
self.max_labels_num = 120
|
self.max_labels_num = 120
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.num_samples
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
results = copy.deepcopy(self.data_source.get_sample(idx))
|
results = copy.deepcopy(self.data_source.get_sample(idx))
|
||||||
for (transform, transform_type) in zip(self.pipeline_yolox,
|
for (transform, transform_type) in zip(self.pipeline_yolox,
|
||||||
|
@ -116,21 +111,8 @@ class DetImagesMixDataset(BaseDataset):
|
||||||
if 'img_scale' in results:
|
if 'img_scale' in results:
|
||||||
results.pop('img_scale')
|
results.pop('img_scale')
|
||||||
|
|
||||||
# print(result.keys())
|
|
||||||
|
|
||||||
# if self.yolo_format:
|
|
||||||
# # print(type(results['img_metas']), results['img_metas'])
|
|
||||||
# # print(type(results['img_metas']._data), results['img_metas']._data)
|
|
||||||
# img_shape = results['img_metas']._data['img_shape'][:2]
|
|
||||||
# # print(type(results['gt_bboxes']))
|
|
||||||
# gt_bboxes = xyxy2cxcywh_with_shape(results['gt_bboxes']._data, img_shape)
|
|
||||||
# results['gt_bboxes'] = gt_bboxes.float()
|
|
||||||
|
|
||||||
if self.label_padding:
|
if self.label_padding:
|
||||||
|
|
||||||
cxcywh_gt_bboxes = xyxy2cxcywh(results['gt_bboxes']._data)
|
cxcywh_gt_bboxes = xyxy2cxcywh(results['gt_bboxes']._data)
|
||||||
# cxcywh_gt_bboxes = results['gt_bboxes']._data
|
|
||||||
|
|
||||||
padded_gt_bboxes = torch.zeros((self.max_labels_num, 4),
|
padded_gt_bboxes = torch.zeros((self.max_labels_num, 4),
|
||||||
device=cxcywh_gt_bboxes.device)
|
device=cxcywh_gt_bboxes.device)
|
||||||
padded_gt_bboxes[range(cxcywh_gt_bboxes.shape[0])[:self.max_labels_num]] = \
|
padded_gt_bboxes[range(cxcywh_gt_bboxes.shape[0])[:self.max_labels_num]] = \
|
||||||
|
@ -146,9 +128,6 @@ class DetImagesMixDataset(BaseDataset):
|
||||||
results['gt_bboxes'] = padded_gt_bboxes
|
results['gt_bboxes'] = padded_gt_bboxes
|
||||||
results['gt_labels'] = padded_labels
|
results['gt_labels'] = padded_labels
|
||||||
|
|
||||||
# ['img_metas', 'img', 'gt_bboxes', 'gt_labels']
|
|
||||||
# results.pop('img_metas')
|
|
||||||
# print(results['img_metas'], "hhh", idx)
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def update_skip_type_keys(self, skip_type_keys):
|
def update_skip_type_keys(self, skip_type_keys):
|
||||||
|
@ -240,30 +219,3 @@ class DetImagesMixDataset(BaseDataset):
|
||||||
tmp_dir = None
|
tmp_dir = None
|
||||||
result_files = self.results2json(results, jsonfile_prefix)
|
result_files = self.results2json(results, jsonfile_prefix)
|
||||||
return result_files, tmp_dir
|
return result_files, tmp_dir
|
||||||
|
|
||||||
def evaluate(self, results, evaluators=None, logger=None):
|
|
||||||
'''results: a dict of list of Tensors, list length equals to number of test images
|
|
||||||
'''
|
|
||||||
|
|
||||||
eval_result = dict()
|
|
||||||
|
|
||||||
groundtruth_dict = {}
|
|
||||||
groundtruth_dict['groundtruth_boxes'] = [
|
|
||||||
batched_xyxy2cxcywh_with_shape(
|
|
||||||
self.data_source.get_ann_info(idx)['bboxes'],
|
|
||||||
results['img_metas'][idx]['ori_img_shape'])
|
|
||||||
for idx in range(len(results['img_metas']))
|
|
||||||
]
|
|
||||||
groundtruth_dict['groundtruth_classes'] = [
|
|
||||||
self.data_source.get_ann_info(idx)['labels']
|
|
||||||
for idx in range(len(results['img_metas']))
|
|
||||||
]
|
|
||||||
groundtruth_dict['groundtruth_is_crowd'] = [
|
|
||||||
self.data_source.get_ann_info(idx)['groundtruth_is_crowd']
|
|
||||||
for idx in range(len(results['img_metas']))
|
|
||||||
]
|
|
||||||
|
|
||||||
for evaluator in evaluators:
|
|
||||||
eval_result.update(evaluator.evaluate(results, groundtruth_dict))
|
|
||||||
# print(eval_result)
|
|
||||||
return eval_result
|
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from easycv.core.visualization.image import imshow_bboxes
|
||||||
from easycv.datasets.registry import DATASETS
|
from easycv.datasets.registry import DATASETS
|
||||||
from easycv.datasets.shared.base import BaseDataset
|
from easycv.datasets.shared.base import BaseDataset
|
||||||
|
from easycv.utils.bbox_util import batched_xyxy2cxcywh_with_shape
|
||||||
|
|
||||||
|
|
||||||
@DATASETS.register_module
|
@DATASETS.register_module
|
||||||
|
@ -19,58 +21,124 @@ class DetDataset(BaseDataset):
|
||||||
classes: A list of class names, used in evaluation for result and groundtruth visualization
|
classes: A list of class names, used in evaluation for result and groundtruth visualization
|
||||||
"""
|
"""
|
||||||
self.classes = classes
|
self.classes = classes
|
||||||
|
self.CLASSES = classes
|
||||||
|
|
||||||
super(DetDataset, self).__init__(
|
super(DetDataset, self).__init__(
|
||||||
data_source, pipeline, profiling=profiling)
|
data_source, pipeline, profiling=profiling)
|
||||||
self.img_num = self.data_source.get_length()
|
self.num_samples = self.data_source.get_length()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
data_dict = self.data_source.get_sample(idx)
|
data_dict = self.data_source.get_sample(idx)
|
||||||
data_dict = self.pipeline(data_dict)
|
data_dict = self.pipeline(data_dict)
|
||||||
return data_dict
|
return data_dict
|
||||||
|
|
||||||
def evaluate(self, results, evaluators, logger=None):
|
def evaluate(self, results, evaluators=None, logger=None):
|
||||||
'''results: a dict of list of Tensors, list length equals to number of test images
|
'''Evaluates the detection boxes.
|
||||||
|
Args:
|
||||||
|
results: A dictionary containing
|
||||||
|
detection_boxes: List of length number of test images.
|
||||||
|
Float32 numpy array of shape [num_boxes, 4] and
|
||||||
|
format [ymin, xmin, ymax, xmax] in absolute image coordinates.
|
||||||
|
detection_scores: List of length number of test images,
|
||||||
|
detection scores for the boxes, float32 numpy array of shape [num_boxes].
|
||||||
|
detection_classes: List of length number of test images,
|
||||||
|
integer numpy array of shape [num_boxes]
|
||||||
|
containing 1-indexed detection classes for the boxes.
|
||||||
|
img_metas: List of length number of test images,
|
||||||
|
dict of image meta info, containing filename, img_shape,
|
||||||
|
origin_img_shape, scale_factor and so on.
|
||||||
|
evaluators: evaluators to calculate metric with results and groundtruth_dict
|
||||||
'''
|
'''
|
||||||
|
|
||||||
eval_result = dict()
|
eval_result = dict()
|
||||||
annotations = self.data_source.get_labels()
|
|
||||||
groundtruth_dict = {}
|
groundtruth_dict = {}
|
||||||
groundtruth_dict['groundtruth_boxes'] = [
|
groundtruth_dict['groundtruth_boxes'] = [
|
||||||
labels[:,
|
batched_xyxy2cxcywh_with_shape(
|
||||||
1:] if len(labels) > 0 else np.array([], dtype=np.float32)
|
self.data_source.get_ann_info(idx)['bboxes'],
|
||||||
for labels in annotations
|
results['img_metas'][idx]['ori_img_shape'])
|
||||||
|
for idx in range(len(results['img_metas']))
|
||||||
]
|
]
|
||||||
groundtruth_dict['groundtruth_classes'] = [
|
groundtruth_dict['groundtruth_classes'] = [
|
||||||
labels[:, 0] if len(labels) > 0 else np.array([], dtype=np.float32)
|
self.data_source.get_ann_info(idx)['labels']
|
||||||
for labels in annotations
|
for idx in range(len(results['img_metas']))
|
||||||
]
|
]
|
||||||
# bboxes = [label[:, 1:] for label in annotations]
|
groundtruth_dict['groundtruth_is_crowd'] = [
|
||||||
# scores = [label[:, 0] for label in annotations]
|
self.data_source.get_ann_info(idx)['groundtruth_is_crowd']
|
||||||
|
for idx in range(len(results['img_metas']))
|
||||||
|
]
|
||||||
|
|
||||||
for evaluator in evaluators:
|
for evaluator in evaluators:
|
||||||
eval_result.update(evaluator.evaluate(results, groundtruth_dict))
|
eval_result.update(evaluator.evaluate(results, groundtruth_dict))
|
||||||
# eval_res = {'dummy': 1.0}
|
|
||||||
# img = self.data_source.load_ori_img(0)
|
|
||||||
# num_box = results['detection_scores'][0].size(0)
|
|
||||||
# scores = results['detection_scores'][0].detach().cpu().numpy()
|
|
||||||
# bboxes = torch.cat((results['detection_boxes'][0], results['detection_scores'][0].view(num_box, 1)), axis=1).detach().cpu().numpy()
|
|
||||||
# labels = results['detection_classes'][0].detach().cpu().numpy().astype(np.int32)
|
|
||||||
# # draw bounding boxes
|
|
||||||
# score_th = 0.3
|
|
||||||
# indices = scores > score_th
|
|
||||||
# filter_labels = labels[indices]
|
|
||||||
# print([(self.classes[i], score) for i, score in zip(filter_labels, scores)])
|
|
||||||
# mmcv.imshow_det_bboxes(
|
|
||||||
# img,
|
|
||||||
# bboxes,
|
|
||||||
# labels,
|
|
||||||
# class_names=self.classes,
|
|
||||||
# score_thr=score_th,
|
|
||||||
# bbox_color='red',
|
|
||||||
# text_color='black',
|
|
||||||
# thickness=1,
|
|
||||||
# font_scale=0.5,
|
|
||||||
# show=False,
|
|
||||||
# wait_time=0,
|
|
||||||
# out_file='test.jpg')
|
|
||||||
|
|
||||||
return eval_result
|
return eval_result
|
||||||
|
|
||||||
|
def visualize(self, results, vis_num=10, score_thr=0.3, **kwargs):
    """Visualize the model output on validation data.

    ``results`` is not modified: class ids are mapped to class names on
    local copies only, so the same dict can still be used afterwards
    (e.g. for evaluation).

    Args:
        results: A dictionary containing
            detection_boxes: List of length number of test images.
                Float32 numpy array of shape [num_boxes, 4] and
                format [ymin, xmin, ymax, xmax] in absolute image coordinates.
            detection_scores: List of length number of test images,
                detection scores for the boxes, float32 numpy array of shape [num_boxes].
            detection_classes: List of length number of test images,
                integer numpy array of shape [num_boxes]
                containing 1-indexed detection classes for the boxes.
            img_metas: List of length number of test images,
                dict of image meta info, containing filename, img_shape,
                origin_img_shape, scale_factor and so on.
        vis_num: number of images visualized
        score_thr: The threshold to filter box,
            boxes with scores greater than score_thr will be kept.

    Returns: A dictionary containing
        images: Visualized images.
        img_metas: The image meta dicts of the visualized images
            (the first ``vis_num`` entries of ``results['img_metas']``).
        An empty dict is returned when ``results`` has no ``img_metas``.
    """
    if results.get('img_metas', None) is None:
        return {}

    class_names = None
    if hasattr(self.data_source, 'CLASSES'):
        class_names = self.data_source.CLASSES
    elif hasattr(self.data_source, 'classes'):
        class_names = self.data_source.classes
    has_class_names = class_names is not None and len(class_names) > 0

    img_metas = results['img_metas'][:vis_num]
    detection_boxes = results.get('detection_boxes', [])
    detection_scores = results.get('detection_scores', [])
    detection_classes = results.get('detection_classes', [])

    vis_imgs = []
    for i, img_meta in enumerate(img_metas):
        filename = img_meta['filename']
        has_boxes = detection_boxes[i] is not None
        bboxes = detection_boxes[i] if has_boxes else np.array([])
        scores = detection_scores[i]
        classes = detection_classes[i]

        # Map class ids to readable names on a local array only; writing
        # the names back into `results` would corrupt the data that is
        # evaluated later.
        if classes is not None and has_class_names:
            classes = np.array(
                [class_names[int(class_id)] for class_id in classes])

        if has_boxes and scores is not None and score_thr > 0:
            inds = scores > score_thr
            bboxes = bboxes[inds]
            if classes is not None:
                classes = classes[inds]

        vis_img = imshow_bboxes(
            img=filename, bboxes=bboxes, labels=classes, show=False)
        vis_imgs.append(vis_img)

    return {'images': vis_imgs, 'img_metas': img_metas}
|
||||||
|
|
|
@ -28,3 +28,14 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def evaluate(self, results, evaluators, logger=None, **kwargs):
|
def evaluate(self, results, evaluators, logger=None, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def visualize(self, results, **kwargs):
    """Visualize the model output results on validation data.

    This default implementation draws nothing; subclasses override it to
    add image visualization.

    Returns: A dictionary
        If image visualization is added, the dict should contain
            images: List of visualized images.
            img_metas: List of length number of test images,
                dict of image meta info, containing filename, img_shape,
                origin_img_shape, scale_factor and so on.
        The default is an empty dict.
    """
    nothing_to_show = {}
    return nothing_to_show
|
||||||
|
|
|
@ -18,6 +18,8 @@ from .show_time_hook import TIMEHook
|
||||||
from .swav_hook import SWAVHook
|
from .swav_hook import SWAVHook
|
||||||
from .sync_norm_hook import SyncNormHook
|
from .sync_norm_hook import SyncNormHook
|
||||||
from .sync_random_size_hook import SyncRandomSizeHook
|
from .sync_random_size_hook import SyncRandomSizeHook
|
||||||
|
from .tensorboard import TensorboardLoggerHookV2
|
||||||
|
from .wandb import WandbLoggerHookV2
|
||||||
from .yolox_lr_hook import YOLOXLrUpdaterHook
|
from .yolox_lr_hook import YOLOXLrUpdaterHook
|
||||||
from .yolox_mode_switch_hook import YOLOXModeSwitchHook
|
from .yolox_mode_switch_hook import YOLOXModeSwitchHook
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from mmcv.runner import Hook
|
from mmcv.runner import Hook
|
||||||
|
@ -45,6 +46,9 @@ class EvalHook(Hook):
|
||||||
|
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.eval_kwargs = eval_kwargs
|
self.eval_kwargs = eval_kwargs
|
||||||
|
# hook.evaluate runs every interval epoch or iter, popped at init
|
||||||
|
self.vis_config = self.eval_kwargs.pop('visualization_config', {})
|
||||||
|
self.gpu_collect = self.eval_kwargs.pop('gpu_collect', None)
|
||||||
self.flush_buffer = flush_buffer
|
self.flush_buffer = flush_buffer
|
||||||
|
|
||||||
def before_run(self, runner):
|
def before_run(self, runner):
|
||||||
|
@ -68,9 +72,23 @@ class EvalHook(Hook):
|
||||||
runner.model, self.dataloader, mode=self.mode, show=False)
|
runner.model, self.dataloader, mode=self.mode, show=False)
|
||||||
self.evaluate(runner, results)
|
self.evaluate(runner, results)
|
||||||
|
|
||||||
def evaluate(self, runner, results):
|
def add_visualization_info(self, runner, results):
    """Store the dataset's visualization of ``results`` in the runner.

    The output of ``dataset.visualize(results, **self.vis_config)`` is
    merged into ``runner.visualization_buffer.output['eval_results']``,
    which is created as an ``OrderedDict`` when missing or ``None``.

    Args:
        runner: Runner that owns ``visualization_buffer``.
        results: Model outputs handed to the dataset's ``visualize``.
    """
    vis_output = runner.visualization_buffer.output
    if vis_output.get('eval_results', None) is None:
        vis_output['eval_results'] = OrderedDict()

    # `self.dataloader` may be a DataLoader or the dataset object itself.
    dataset_obj = self.dataloader
    if isinstance(dataset_obj, DataLoader):
        dataset_obj = dataset_obj.dataset

    if not hasattr(dataset_obj, 'visualize'):
        return

    vis_output['eval_results'].update(
        dataset_obj.visualize(results, **self.vis_config))
|
||||||
|
|
||||||
|
def evaluate(self, runner, results):
|
||||||
|
self.add_visualization_info(runner, results)
|
||||||
|
|
||||||
gpu_collect = self.eval_kwargs.pop('gpu_collect', None)
|
|
||||||
if isinstance(self.dataloader, DataLoader):
|
if isinstance(self.dataloader, DataLoader):
|
||||||
eval_res = self.dataloader.dataset.evaluate(
|
eval_res = self.dataloader.dataset.evaluate(
|
||||||
results, logger=runner.logger, **self.eval_kwargs)
|
results, logger=runner.logger, **self.eval_kwargs)
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from mmcv.runner.dist_utils import master_only
|
||||||
|
from mmcv.runner.hooks import HOOKS
|
||||||
|
from mmcv.runner.hooks import TensorboardLoggerHook as _TensorboardLoggerHook
|
||||||
|
|
||||||
|
|
||||||
|
@HOOKS.register_module()
class TensorboardLoggerHookV2(_TensorboardLoggerHook):
    """TensorBoard logger hook that also writes visualized images.

    Besides the logging done by the parent hook, the images stored in
    ``runner.visualization_buffer`` are written with ``writer.add_image``.
    """

    def visualization_log(self, runner):
        """Images Visualization.

        `visualization_buffer` is a dictionary containing:
            images (list): list of visualized images.
            img_metas (list of dict, optional): dict containing ori_filename and so on.
                ori_filename will be displayed as the tag of the image by default.
        """
        visual_results = runner.visualization_buffer.output
        for vis_key, vis_result in visual_results.items():
            images = vis_result.get('images', [])
            img_metas = vis_result.get('img_metas', None)
            if img_metas is not None:
                assert len(images) == len(
                    img_metas
                ), 'Output `images` and `img_metas` must keep the same length!'

            for i, img in enumerate(images):
                if isinstance(img, np.ndarray):
                    # ndarray images are presumably in OpenCV's BGR order
                    # (WandbLoggerHookV2 applies the same BGR->RGB
                    # conversion); TensorBoard displays RGB. The copy is
                    # needed because torch.from_numpy rejects the
                    # negative strides of a reversed view.
                    if img.ndim == 3 and img.shape[-1] == 3:
                        img = np.ascontiguousarray(img[..., ::-1])
                    img = torch.from_numpy(img)
                else:
                    assert isinstance(
                        img, torch.Tensor
                    ), 'Only support np.ndarray and torch.Tensor type!'

                default_name = 'image_%i' % i
                filename = img_metas[i].get(
                    'ori_filename',
                    default_name) if img_metas is not None else default_name
                self.writer.add_image(
                    f'{vis_key}/{filename}',
                    img,
                    self.get_iter(runner),
                    dataformats='HWC')

    @master_only
    def log(self, runner):
        """Write the buffered images, then run the parent's logging."""
        self.visualization_log(runner)
        super(TensorboardLoggerHookV2, self).log(runner)

    def after_train_iter(self, runner):
        """Run the parent's per-iteration logic, then drop buffered images."""
        super(TensorboardLoggerHookV2, self).after_train_iter(runner)
        # clear visualization_buffer after each iter to ensure that it is only written once,
        # avoiding repeated writing of the same image buffer every self.interval
        runner.visualization_buffer.clear_output()
|
|
@ -0,0 +1,54 @@
|
||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from mmcv.runner.dist_utils import master_only
|
||||||
|
from mmcv.runner.hooks import HOOKS
|
||||||
|
from mmcv.runner.hooks import WandbLoggerHook as _WandbLoggerHook
|
||||||
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
|
|
||||||
|
@HOOKS.register_module()
class WandbLoggerHookV2(_WandbLoggerHook):
    """Wandb logger hook that also uploads buffered visualization images."""

    def visualization_log(self, runner):
        """Log the images held in ``runner.visualization_buffer`` to wandb.

        Every value of ``runner.visualization_buffer.output`` is a dict with:
            images (list): visualized images; each must be an ``np.ndarray``
                and is converted with ``cv2.COLOR_BGR2RGB`` before upload.
            img_metas (list of dict, optional): one dict per image. Its
                ``ori_filename`` entry is used as the image caption; when
                the entry (or ``img_metas`` itself) is missing the caption
                falls back to ``image_<index>``.

        Args:
            runner: the runner owning ``visualization_buffer``.
        """
        buffered = runner.visualization_buffer.output
        for group_name, group in buffered.items():
            imgs = group.get('images', [])
            metas = group.get('img_metas', None)
            if metas is not None:
                assert len(imgs) == len(metas), \
                    'Output `images` and `img_metas` must keep the same length!'

            wandb_images = []
            for idx, image in enumerate(imgs):
                assert isinstance(image, np.ndarray)
                rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                pil_image = PILImage.fromarray(rgb, mode='RGB')

                caption = 'image_%i' % idx
                if metas is not None:
                    caption = metas[idx].get('ori_filename', caption)

                wandb_images.append(
                    self.wandb.Image(pil_image, caption=caption))

            # one log call per buffer key, carrying all of its images
            self.wandb.log({group_name: wandb_images},
                           step=self.get_iter(runner),
                           commit=self.commit)

    @master_only
    def log(self, runner):
        """Upload buffered images first, then delegate to the parent ``log``."""
        self.visualization_log(runner)
        super().log(runner)

    def after_train_iter(self, runner):
        """Run the parent hook, then drop the buffered visualization output."""
        super().after_train_iter(runner)
        # Clearing after every iteration ensures a buffered image set is
        # logged only once instead of again at every ``self.interval``.
        # NOTE(review): the buffer lives on the runner, so another hook that
        # clears it earlier in the same iteration would leave nothing to
        # log here -- confirm hook ordering if several V2 loggers are used.
        runner.visualization_buffer.clear_output()
|
|
@ -5,6 +5,7 @@ from distutils.version import LooseVersion
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from mmcv.runner import EpochBasedRunner
|
from mmcv.runner import EpochBasedRunner
|
||||||
|
from mmcv.runner.log_buffer import LogBuffer
|
||||||
|
|
||||||
from easycv.file import io
|
from easycv.file import io
|
||||||
from easycv.utils.checkpoint import load_checkpoint, save_checkpoint
|
from easycv.utils.checkpoint import load_checkpoint, save_checkpoint
|
||||||
|
@ -47,6 +48,7 @@ class EVRunner(EpochBasedRunner):
|
||||||
meta)
|
meta)
|
||||||
self.data_loader = None
|
self.data_loader = None
|
||||||
self.fp16_enable = False
|
self.fp16_enable = False
|
||||||
|
self.visualization_buffer = LogBuffer()
|
||||||
|
|
||||||
def run_iter(self, data_batch, train_mode, **kwargs):
|
def run_iter(self, data_batch, train_mode, **kwargs):
|
||||||
""" process for each iteration.
|
""" process for each iteration.
|
||||||
|
|
Binary file not shown.
2
setup.py
2
setup.py
|
@ -159,6 +159,7 @@ def pack_resource():
|
||||||
'./thirdparty',
|
'./thirdparty',
|
||||||
proj_dir + 'thirdparty',
|
proj_dir + 'thirdparty',
|
||||||
ignore=shutil.ignore_patterns('*.pyc'))
|
ignore=shutil.ignore_patterns('*.pyc'))
|
||||||
|
shutil.copytree('./resource', proj_dir + 'resource')
|
||||||
shutil.copytree('./requirements', 'package/requirements')
|
shutil.copytree('./requirements', 'package/requirements')
|
||||||
shutil.copy('./requirements.txt', 'package/requirements.txt')
|
shutil.copy('./requirements.txt', 'package/requirements.txt')
|
||||||
shutil.copy('./MANIFEST.in', 'package/MANIFEST.in')
|
shutil.copy('./MANIFEST.in', 'package/MANIFEST.in')
|
||||||
|
@ -181,6 +182,7 @@ if __name__ == '__main__':
|
||||||
keywords='self-supvervised, classification, vision',
|
keywords='self-supvervised, classification, vision',
|
||||||
url='https://github.com/alibaba/EasyCV.git',
|
url='https://github.com/alibaba/EasyCV.git',
|
||||||
packages=find_packages(exclude=('configs', 'tools', 'demo')),
|
packages=find_packages(exclude=('configs', 'tools', 'demo')),
|
||||||
|
include_package_data=True,
|
||||||
classifiers=[
|
classifiers=[
|
||||||
'Development Status :: 4 - Beta',
|
'Development Status :: 4 - Beta',
|
||||||
'License :: OSI Approved :: Apache Software License',
|
'License :: OSI Approved :: Apache Software License',
|
||||||
|
|
|
@ -4,7 +4,8 @@ import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from easycv.core.visualization import imshow_bboxes, imshow_keypoints
|
from easycv.core.visualization import (imshow_bboxes, imshow_keypoints,
|
||||||
|
imshow_label)
|
||||||
|
|
||||||
|
|
||||||
class ImshowTest(unittest.TestCase):
|
class ImshowTest(unittest.TestCase):
|
||||||
|
@ -32,7 +33,7 @@ class ImshowTest(unittest.TestCase):
|
||||||
img = np.zeros((100, 100, 3), dtype=np.uint8)
|
img = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||||
bboxes = np.array([[10, 10, 30, 30], [10, 50, 30, 80]],
|
bboxes = np.array([[10, 10, 30, 30], [10, 50, 30, 80]],
|
||||||
dtype=np.float32)
|
dtype=np.float32)
|
||||||
labels = ['label 1', 'label 2']
|
labels = ['标签 1', 'label 2']
|
||||||
colors = ['red', 'green']
|
colors = ['red', 'green']
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
@ -65,6 +66,17 @@ class ImshowTest(unittest.TestCase):
|
||||||
else:
|
else:
|
||||||
self.fail('ValueError not raised')
|
self.fail('ValueError not raised')
|
||||||
|
|
||||||
|
def test_imshow_label(self):
    """Smoke-test ``imshow_label`` with a non-ASCII and an ASCII label."""
    canvas = np.zeros((100, 100, 3), dtype=np.uint8)
    label_texts = ['标签 1', 'label 2']

    with tempfile.TemporaryDirectory() as out_dir:
        # The return value is not inspected; the test only requires the
        # call to complete without raising.
        imshow_label(
            img=canvas,
            labels=label_texts,
            show=False,
            out_file=f'{out_dir}/out.png')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -13,7 +13,7 @@ class ClsDatasetTest(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||||
|
|
||||||
def test_default(self):
|
def _get_dataset(self):
|
||||||
data_root = SMALL_IMAGENET_RAW_LOCAL
|
data_root = SMALL_IMAGENET_RAW_LOCAL
|
||||||
data_train_list = os.path.join(data_root, 'meta/train_labeled_200.txt')
|
data_train_list = os.path.join(data_root, 'meta/train_labeled_200.txt')
|
||||||
pipeline = [
|
pipeline = [
|
||||||
|
@ -33,12 +33,36 @@ class ClsDatasetTest(unittest.TestCase):
|
||||||
pipeline=pipeline))
|
pipeline=pipeline))
|
||||||
dataset = build_dataset(data['train'])
|
dataset = build_dataset(data['train'])
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def test_default(self):
    """Check image shape and label range on the first dataset sample."""
    samples = self._get_dataset()
    for sample in samples:
        image, label = sample['img'], sample['gt_labels']
        self.assertEqual(image.shape, torch.Size([3, 224, 224]))
        self.assertIn(label, list(range(1000)))
        # only the first sample is inspected
        break
|
||||||
|
|
||||||
|
def test_visualize(self):
|
||||||
|
# TODO: add img_metas for classification
|
||||||
|
return
|
||||||
|
dataset = self._get_dataset()
|
||||||
|
count = 5
|
||||||
|
classes = []
|
||||||
|
img_metas = []
|
||||||
|
for i, data in enumerate(dataset):
|
||||||
|
classes.append(data['gt_labels'])
|
||||||
|
img_metas.append(data['img_metas'].data)
|
||||||
|
if i > count:
|
||||||
|
break
|
||||||
|
|
||||||
|
results = {'class': classes, 'img_metas': img_metas}
|
||||||
|
output = dataset.visualize(results, vis_num=2)
|
||||||
|
self.assertEqual(len(output['images']), 2)
|
||||||
|
self.assertEqual(len(output['img_metas']), 2)
|
||||||
|
self.assertEqual(len(output['images'][0].shape), 3)
|
||||||
|
self.assertIn('filename', output['img_metas'][0])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -3,6 +3,7 @@ import os
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from tests.ut_config import DET_DATA_RAW_LOCAL, IMG_NORM_CFG_255
|
from tests.ut_config import DET_DATA_RAW_LOCAL, IMG_NORM_CFG_255
|
||||||
|
|
||||||
from easycv.datasets.detection import DetDataset
|
from easycv.datasets.detection import DetDataset
|
||||||
|
@ -13,7 +14,7 @@ class DetDatasetTest(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||||
|
|
||||||
def test_load(self):
|
def _get_dataset(self):
|
||||||
img_scale = (640, 640)
|
img_scale = (640, 640)
|
||||||
data_source_cfg = dict(
|
data_source_cfg = dict(
|
||||||
type='DetSourceRaw',
|
type='DetSourceRaw',
|
||||||
|
@ -33,6 +34,11 @@ class DetDatasetTest(unittest.TestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
dataset = DetDataset(data_source=data_source_cfg, pipeline=pipeline)
|
dataset = DetDataset(data_source=data_source_cfg, pipeline=pipeline)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def test_load(self):
|
||||||
|
dataset = self._get_dataset()
|
||||||
data_num = len(dataset)
|
data_num = len(dataset)
|
||||||
s = time.time()
|
s = time.time()
|
||||||
for data in dataset:
|
for data in dataset:
|
||||||
|
@ -47,6 +53,37 @@ class DetDatasetTest(unittest.TestCase):
|
||||||
self.assertTrue('img_shape' in img_metas)
|
self.assertTrue('img_shape' in img_metas)
|
||||||
self.assertTrue('ori_img_shape' in img_metas)
|
self.assertTrue('ori_img_shape' in img_metas)
|
||||||
|
|
||||||
|
def test_visualize(self):
    """Check that ``dataset.visualize`` returns ``vis_num`` rendered samples."""
    dataset = self._get_dataset()
    max_index = 5
    boxes_per_img, classes_per_img, metas = [], [], []
    for idx, sample in enumerate(dataset):
        boxes_per_img.append(
            sample['gt_bboxes'].data.cpu().detach().numpy())
        classes_per_img.append(
            sample['gt_labels'].data.cpu().detach().numpy())
        metas.append(sample['img_metas'].data)
        # the check runs after appending, so up to ``max_index + 2``
        # samples are collected before the loop stops
        if idx > max_index:
            break

    # synthetic scores: 0.0, 0.1, 0.2, ... per ground-truth box of an image
    scores_per_img = [
        0.1 * np.array(range(len(img_classes)))
        for img_classes in classes_per_img
    ]

    results = dict(
        detection_boxes=boxes_per_img,
        detection_scores=scores_per_img,
        detection_classes=classes_per_img,
        img_metas=metas)
    output = dataset.visualize(results, vis_num=2, score_thr=0.1)

    rendered, rendered_metas = output['images'], output['img_metas']
    self.assertEqual(len(rendered), 2)
    self.assertEqual(len(rendered_metas), 2)
    self.assertEqual(len(rendered[0].shape), 3)
    self.assertIn('filename', rendered_metas[0])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -21,7 +21,7 @@ class SSLSourceImageNetFeatureTest(unittest.TestCase):
|
||||||
index_list = random.choices(list(range(100)), k=3)
|
index_list = random.choices(list(range(100)), k=3)
|
||||||
for idx in index_list:
|
for idx in index_list:
|
||||||
results = data_source.get_sample(idx)
|
results = data_source.get_sample(idx)
|
||||||
feat = results['feature']
|
feat = results['img']
|
||||||
label = results['gt_labels']
|
label = results['gt_labels']
|
||||||
self.assertEqual(feat.shape, (2048, ))
|
self.assertEqual(feat.shape, (2048, ))
|
||||||
self.assertIn(label, list(range(1000)))
|
self.assertIn(label, list(range(1000)))
|
||||||
|
@ -37,7 +37,7 @@ class SSLSourceImageNetFeatureTest(unittest.TestCase):
|
||||||
index_list = random.choices(list(range(100)), k=3)
|
index_list = random.choices(list(range(100)), k=3)
|
||||||
for idx in index_list:
|
for idx in index_list:
|
||||||
results = data_source.get_sample(idx)
|
results = data_source.get_sample(idx)
|
||||||
feat = results['feature']
|
feat = results['img']
|
||||||
label = results['gt_labels']
|
label = results['gt_labels']
|
||||||
self.assertEqual(feat.shape, (2048, ))
|
self.assertEqual(feat.shape, (2048, ))
|
||||||
self.assertIn(label, list(range(1000)))
|
self.assertIn(label, list(range(1000)))
|
||||||
|
|
Loading…
Reference in New Issue