mirror of https://github.com/open-mmlab/mmocr.git
Fix some browse dataset script bugs and draw textdet gt instance with ignore flags (#1701)
* [Enhancement] textdet draw gt instance with ignore flags * [Fix] Clarify the key definitions so that later uses of img_path do not receive the LMDB-format img_key, which made images unreadable * [Fix] fix five browse_dataset.py script bugs * [Fix] fix some pr problems * [Fix] keep img_path attribute * [Fix] Prevent font_size from becoming too large to display fully when width is large but the text is short (this can happen when a keep_ratio resize is followed by padding to a fixed size)pull/1673/head
parent
e9bf689f74
commit
37c5d371c7
|
@ -126,7 +126,7 @@ class RecogLMDBDataset(BaseDataset):
|
|||
"""
|
||||
data_info = {}
|
||||
img_key, text = raw_anno_info
|
||||
data_info['img_path'] = img_key
|
||||
data_info['img_key'] = img_key
|
||||
data_info['instances'] = [dict(text=text)]
|
||||
return data_info
|
||||
|
||||
|
@ -141,7 +141,7 @@ class RecogLMDBDataset(BaseDataset):
|
|||
"""
|
||||
data_info = self.get_data_info(idx)
|
||||
with self.env.begin(write=False) as txn:
|
||||
img_bytes = txn.get(data_info['img_path'].encode('utf-8'))
|
||||
img_bytes = txn.get(data_info['img_key'].encode('utf-8'))
|
||||
if img_bytes is None:
|
||||
return None
|
||||
data_info['img'] = mmcv.imfrombytes(
|
||||
|
|
|
@ -65,11 +65,6 @@ class LoadImageFromFile(MMCV_LoadImageFromFile):
|
|||
def transform(self, results: dict) -> Optional[dict]:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
"""
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
|
@ -154,8 +149,9 @@ class LoadImageFromNDArray(LoadImageFromFile):
|
|||
img = img.astype(np.float32)
|
||||
if self.color_type == 'grayscale':
|
||||
img = mmcv.image.rgb2gray(img)
|
||||
results['img_path'] = None
|
||||
results['img'] = img
|
||||
if results.get('img_path', None) is None:
|
||||
results['img_path'] = None
|
||||
results['img_shape'] = img.shape[:2]
|
||||
results['ori_shape'] = img.shape[:2]
|
||||
return results
|
||||
|
|
|
@ -30,6 +30,12 @@ class TextDetLocalVisualizer(BaseLocalVisualizer):
|
|||
value, all the lines will have the same colors. Refer to
|
||||
`matplotlib.colors` for full list of formats that are accepted.
|
||||
Defaults to 'g'.
|
||||
gt_ignored_color (Union[str, tuple, list[str], list[tuple]]): The
|
||||
colors of ignored GT polygons and bboxes. ``colors`` can have
|
||||
the same length with lines or just single value. If ``colors``
|
||||
is single value, all the lines will have the same colors. Refer
|
||||
to `matplotlib.colors` for full list of formats that are accepted.
|
||||
Defaults to 'b'.
|
||||
pred_color (Union[str, tuple, list[str], list[tuple]]): The
|
||||
colors of pred polygons and bboxes. ``colors`` can have the same
|
||||
length with lines or just single value. If ``colors`` is single
|
||||
|
@ -48,6 +54,8 @@ class TextDetLocalVisualizer(BaseLocalVisualizer):
|
|||
vis_backends: Optional[Dict] = None,
|
||||
save_dir: Optional[str] = None,
|
||||
gt_color: Union[str, Tuple, List[str], List[Tuple]] = 'g',
|
||||
gt_ignored_color: Union[str, Tuple, List[str],
|
||||
List[Tuple]] = 'b',
|
||||
pred_color: Union[str, Tuple, List[str], List[Tuple]] = 'r',
|
||||
line_width: Union[int, float] = 2,
|
||||
alpha: float = 0.8) -> None:
|
||||
|
@ -59,6 +67,7 @@ class TextDetLocalVisualizer(BaseLocalVisualizer):
|
|||
self.with_poly = with_poly
|
||||
self.with_bbox = with_bbox
|
||||
self.gt_color = gt_color
|
||||
self.gt_ignored_color = gt_ignored_color
|
||||
self.pred_color = pred_color
|
||||
self.line_width = line_width
|
||||
self.alpha = alpha
|
||||
|
@ -142,9 +151,22 @@ class TextDetLocalVisualizer(BaseLocalVisualizer):
|
|||
if data_sample is not None:
|
||||
if draw_gt and 'gt_instances' in data_sample:
|
||||
gt_instances = data_sample.gt_instances
|
||||
gt_img_data = image.copy()
|
||||
if gt_instances.get('ignored', None) is not None:
|
||||
ignore_flags = gt_instances.ignored
|
||||
gt_ignored_instances = gt_instances[ignore_flags]
|
||||
gt_ignored_polygons = gt_ignored_instances.get(
|
||||
'polygons', None)
|
||||
gt_ignored_bboxes = gt_ignored_instances.get(
|
||||
'bboxes', None)
|
||||
gt_img_data = self._draw_instances(gt_img_data,
|
||||
gt_ignored_bboxes,
|
||||
gt_ignored_polygons,
|
||||
self.gt_ignored_color)
|
||||
gt_instances = gt_instances[~ignore_flags]
|
||||
gt_polygons = gt_instances.get('polygons', None)
|
||||
gt_bboxes = gt_instances.get('bboxes', None)
|
||||
gt_img_data = self._draw_instances(image.copy(), gt_bboxes,
|
||||
gt_img_data = self._draw_instances(gt_img_data, gt_bboxes,
|
||||
gt_polygons, self.gt_color)
|
||||
cat_images.append(gt_img_data)
|
||||
if draw_pred and 'pred_instances' in data_sample:
|
||||
|
|
|
@ -60,7 +60,7 @@ class TextRecogLocalVisualizer(BaseLocalVisualizer):
|
|||
height, width = image.shape[:2]
|
||||
empty_img = np.full_like(image, 255)
|
||||
self.set_image(empty_img)
|
||||
font_size = 0.5 * width / (len(text) + 1)
|
||||
font_size = min(0.5 * width / (len(text) + 1), 0.5 * height)
|
||||
self.draw_texts(
|
||||
text,
|
||||
np.array([width / 2, height / 2]),
|
||||
|
|
|
@ -109,10 +109,13 @@ def _get_adaptive_scale(img_shape: Tuple[int, int],
|
|||
return min(max(scale, min_scale), max_scale)
|
||||
|
||||
|
||||
def make_grid(imgs, names):
|
||||
def make_grid(imgs, infos):
|
||||
"""Concat list of pictures into a single big picture, align height here."""
|
||||
visualizer = Visualizer.get_current_instance()
|
||||
ori_shapes = [img.shape[:2] for img in imgs]
|
||||
names = [info['name'] for info in infos]
|
||||
ori_shapes = [
|
||||
info['dataset_sample'].metainfo['img_shape'] for info in infos
|
||||
]
|
||||
max_height = int(max(img.shape[0] for img in imgs) * 1.1)
|
||||
min_width = min(img.shape[1] for img in imgs)
|
||||
horizontal_gap = min_width // 10
|
||||
|
@ -162,17 +165,13 @@ class InspectCompose(Compose):
|
|||
self.intermediate_imgs = intermediate_imgs
|
||||
|
||||
def __call__(self, data):
|
||||
if 'img' in data:
|
||||
self.intermediate_imgs.append({
|
||||
'name': 'original',
|
||||
'img': data['img'].copy()
|
||||
})
|
||||
self.ptransforms = [
|
||||
self.transforms[i] for i in range(len(self.transforms) - 1)
|
||||
]
|
||||
for t in self.ptransforms:
|
||||
data = t(data)
|
||||
# Keep the same meta_keys in the PackDetInputs
|
||||
# Keep the same meta_keys in the PackTextDetInputs
|
||||
# or PackTextRecogInputs
|
||||
self.transforms[-1].meta_keys = [key for key in data]
|
||||
data_sample = self.transforms[-1](data)
|
||||
if data is None:
|
||||
|
@ -299,7 +298,29 @@ def obtain_dataset_cfg(cfg: Config, phase: str, mode: str, task: str) -> Tuple:
|
|||
|
||||
if mode == 'original':
|
||||
default_cfg = default_cfgs[infer_dataset_task(task, dataset)]
|
||||
# Image can be stored in other methods, like LMDB,
|
||||
# which LoadImageFromFile can not handle
|
||||
if dataset.pipeline is not None:
|
||||
all_transform_types = [tfm['type'] for tfm in dataset.pipeline]
|
||||
if any([
|
||||
tfm_type.startswith('LoadImageFrom')
|
||||
for tfm_type in all_transform_types
|
||||
]):
|
||||
for tfm in dataset.pipeline:
|
||||
if tfm['type'].startswith('LoadImageFrom'):
|
||||
# update LoadImageFrom** transform
|
||||
default_cfg['pipeline'][0] = tfm
|
||||
dataset.pipeline = default_cfg['pipeline']
|
||||
else:
|
||||
# In test_pipeline LoadOCRAnnotations is placed behind
|
||||
# other transforms. Transform will not be applied on
|
||||
# gt annotation.
|
||||
if phase == 'test':
|
||||
all_transform_types = [tfm['type'] for tfm in dataset.pipeline]
|
||||
load_ocr_ann_tfm_index = all_transform_types.index(
|
||||
'LoadOCRAnnotations')
|
||||
load_ocr_ann_tfm = dataset.pipeline.pop(load_ocr_ann_tfm_index)
|
||||
dataset.pipeline.insert(1, load_ocr_ann_tfm)
|
||||
|
||||
return dataset, visualizer
|
||||
|
||||
|
@ -360,7 +381,8 @@ def main():
|
|||
result_i = [result['dataset_sample'] for result in intermediate_imgs]
|
||||
for k, datasample in enumerate(result_i):
|
||||
image = datasample.img
|
||||
image = image[..., [2, 1, 0]] # bgr to rgb
|
||||
if len(image.shape) == 3:
|
||||
image = image[..., [2, 1, 0]] # bgr to rgb
|
||||
image_show = visualizer.add_datasample(
|
||||
'result',
|
||||
image,
|
||||
|
@ -371,8 +393,7 @@ def main():
|
|||
image_i.append(image_show)
|
||||
|
||||
if args.mode == 'pipeline':
|
||||
image = make_grid([result for result in image_i],
|
||||
[result['name'] for result in intermediate_imgs])
|
||||
image = make_grid(image_i, intermediate_imgs)
|
||||
else:
|
||||
image = image_i[-1]
|
||||
|
||||
|
|
Loading…
Reference in New Issue