Fix some browse dataset script bugs and draw textdet gt instance with ignore flags (#1701)

* [Enhancement] textdet draw gt instance with ignore flags

* [Fix] Clarify the key definition: store the LMDB `img_key` under its own `img_key` field so that later uses of `img_path` no longer receive an LMDB-format key and fail to read the image

* [Fix] fix five browse_dataset.py script bugs

* [Fix] Address PR review comments

* [Fix] keep img_path attribute

* [Fix] Cap `font_size` so that when the image width is large but the text is short, the rendered text is not too large to fit (this can happen after a keep-ratio resize followed by padding to a fixed size)
pull/1673/head
Kevin Wang 2023-02-17 15:40:24 +08:00 committed by GitHub
parent e9bf689f74
commit 37c5d371c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 60 additions and 21 deletions

View File

@ -126,7 +126,7 @@ class RecogLMDBDataset(BaseDataset):
"""
data_info = {}
img_key, text = raw_anno_info
data_info['img_path'] = img_key
data_info['img_key'] = img_key
data_info['instances'] = [dict(text=text)]
return data_info
@ -141,7 +141,7 @@ class RecogLMDBDataset(BaseDataset):
"""
data_info = self.get_data_info(idx)
with self.env.begin(write=False) as txn:
img_bytes = txn.get(data_info['img_path'].encode('utf-8'))
img_bytes = txn.get(data_info['img_key'].encode('utf-8'))
if img_bytes is None:
return None
data_info['img'] = mmcv.imfrombytes(

View File

@ -65,11 +65,6 @@ class LoadImageFromFile(MMCV_LoadImageFromFile):
def transform(self, results: dict) -> Optional[dict]:
"""Functions to load image.
Args:
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
"""
"""Functions to load image.
Args:
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
@ -154,8 +149,9 @@ class LoadImageFromNDArray(LoadImageFromFile):
img = img.astype(np.float32)
if self.color_type == 'grayscale':
img = mmcv.image.rgb2gray(img)
results['img_path'] = None
results['img'] = img
if results.get('img_path', None) is None:
results['img_path'] = None
results['img_shape'] = img.shape[:2]
results['ori_shape'] = img.shape[:2]
return results

View File

@ -30,6 +30,12 @@ class TextDetLocalVisualizer(BaseLocalVisualizer):
value, all the lines will have the same colors. Refer to
`matplotlib.colors` for full list of formats that are accepted.
Defaults to 'g'.
gt_ignored_color (Union[str, tuple, list[str], list[tuple]]): The
colors of ignored GT polygons and bboxes. ``colors`` can have
the same length with lines or just single value. If ``colors``
is single value, all the lines will have the same colors. Refer
to `matplotlib.colors` for full list of formats that are accepted.
Defaults to 'b'.
pred_color (Union[str, tuple, list[str], list[tuple]]): The
colors of pred polygons and bboxes. ``colors`` can have the same
length with lines or just single value. If ``colors`` is single
@ -48,6 +54,8 @@ class TextDetLocalVisualizer(BaseLocalVisualizer):
vis_backends: Optional[Dict] = None,
save_dir: Optional[str] = None,
gt_color: Union[str, Tuple, List[str], List[Tuple]] = 'g',
gt_ignored_color: Union[str, Tuple, List[str],
List[Tuple]] = 'b',
pred_color: Union[str, Tuple, List[str], List[Tuple]] = 'r',
line_width: Union[int, float] = 2,
alpha: float = 0.8) -> None:
@ -59,6 +67,7 @@ class TextDetLocalVisualizer(BaseLocalVisualizer):
self.with_poly = with_poly
self.with_bbox = with_bbox
self.gt_color = gt_color
self.gt_ignored_color = gt_ignored_color
self.pred_color = pred_color
self.line_width = line_width
self.alpha = alpha
@ -142,9 +151,22 @@ class TextDetLocalVisualizer(BaseLocalVisualizer):
if data_sample is not None:
if draw_gt and 'gt_instances' in data_sample:
gt_instances = data_sample.gt_instances
gt_img_data = image.copy()
if gt_instances.get('ignored', None) is not None:
ignore_flags = gt_instances.ignored
gt_ignored_instances = gt_instances[ignore_flags]
gt_ignored_polygons = gt_ignored_instances.get(
'polygons', None)
gt_ignored_bboxes = gt_ignored_instances.get(
'bboxes', None)
gt_img_data = self._draw_instances(gt_img_data,
gt_ignored_bboxes,
gt_ignored_polygons,
self.gt_ignored_color)
gt_instances = gt_instances[~ignore_flags]
gt_polygons = gt_instances.get('polygons', None)
gt_bboxes = gt_instances.get('bboxes', None)
gt_img_data = self._draw_instances(image.copy(), gt_bboxes,
gt_img_data = self._draw_instances(gt_img_data, gt_bboxes,
gt_polygons, self.gt_color)
cat_images.append(gt_img_data)
if draw_pred and 'pred_instances' in data_sample:

View File

@ -60,7 +60,7 @@ class TextRecogLocalVisualizer(BaseLocalVisualizer):
height, width = image.shape[:2]
empty_img = np.full_like(image, 255)
self.set_image(empty_img)
font_size = 0.5 * width / (len(text) + 1)
font_size = min(0.5 * width / (len(text) + 1), 0.5 * height)
self.draw_texts(
text,
np.array([width / 2, height / 2]),

View File

@ -109,10 +109,13 @@ def _get_adaptive_scale(img_shape: Tuple[int, int],
return min(max(scale, min_scale), max_scale)
def make_grid(imgs, names):
def make_grid(imgs, infos):
"""Concat list of pictures into a single big picture, align height here."""
visualizer = Visualizer.get_current_instance()
ori_shapes = [img.shape[:2] for img in imgs]
names = [info['name'] for info in infos]
ori_shapes = [
info['dataset_sample'].metainfo['img_shape'] for info in infos
]
max_height = int(max(img.shape[0] for img in imgs) * 1.1)
min_width = min(img.shape[1] for img in imgs)
horizontal_gap = min_width // 10
@ -162,17 +165,13 @@ class InspectCompose(Compose):
self.intermediate_imgs = intermediate_imgs
def __call__(self, data):
if 'img' in data:
self.intermediate_imgs.append({
'name': 'original',
'img': data['img'].copy()
})
self.ptransforms = [
self.transforms[i] for i in range(len(self.transforms) - 1)
]
for t in self.ptransforms:
data = t(data)
# Keep the same meta_keys in the PackDetInputs
# Keep the same meta_keys in the PackTextDetInputs
# or PackTextRecogInputs
self.transforms[-1].meta_keys = [key for key in data]
data_sample = self.transforms[-1](data)
if data is None:
@ -299,7 +298,29 @@ def obtain_dataset_cfg(cfg: Config, phase: str, mode: str, task: str) -> Tuple:
if mode == 'original':
default_cfg = default_cfgs[infer_dataset_task(task, dataset)]
# Image can be stored in other methods, like LMDB,
# which LoadImageFromFile can not handle
if dataset.pipeline is not None:
all_transform_types = [tfm['type'] for tfm in dataset.pipeline]
if any([
tfm_type.startswith('LoadImageFrom')
for tfm_type in all_transform_types
]):
for tfm in dataset.pipeline:
if tfm['type'].startswith('LoadImageFrom'):
# update LoadImageFrom** transform
default_cfg['pipeline'][0] = tfm
dataset.pipeline = default_cfg['pipeline']
else:
# In test_pipeline LoadOCRAnnotations is placed behind
# other transforms. Transform will not be applied on
# gt annotation.
if phase == 'test':
all_transform_types = [tfm['type'] for tfm in dataset.pipeline]
load_ocr_ann_tfm_index = all_transform_types.index(
'LoadOCRAnnotations')
load_ocr_ann_tfm = dataset.pipeline.pop(load_ocr_ann_tfm_index)
dataset.pipeline.insert(1, load_ocr_ann_tfm)
return dataset, visualizer
@ -360,7 +381,8 @@ def main():
result_i = [result['dataset_sample'] for result in intermediate_imgs]
for k, datasample in enumerate(result_i):
image = datasample.img
image = image[..., [2, 1, 0]] # bgr to rgb
if len(image.shape) == 3:
image = image[..., [2, 1, 0]] # bgr to rgb
image_show = visualizer.add_datasample(
'result',
image,
@ -371,8 +393,7 @@ def main():
image_i.append(image_show)
if args.mode == 'pipeline':
image = make_grid([result for result in image_i],
[result['name'] for result in intermediate_imgs])
image = make_grid(image_i, intermediate_imgs)
else:
image = image_i[-1]