Use mmcv instead of cv2 for image resizing and padding (#237)

* Use mmcv (imresize / impad) instead of cv2 (resize / copyMakeBorder)

* update
pull/242/head
Hongbin Sun 2021-05-25 19:58:32 +08:00 committed by GitHub
parent 36e92ebe70
commit b863bbca5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 22 additions and 21 deletions

View File

@ -1,6 +1,6 @@
import cv2
import imgaug
import imgaug.augmenters as iaa
import mmcv
import numpy as np
from mmdet.core.mask import PolygonMasks
@ -145,7 +145,7 @@ class EastRandomCrop:
padded_img = np.zeros(
(self.target_size[1], self.target_size[0], img.shape[2]),
img.dtype)
padded_img[:h, :w] = cv2.resize(
padded_img[:h, :w] = mmcv.imresize(
img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
# for bboxes

View File

@ -1,6 +1,5 @@
import math
import cv2
import mmcv
import numpy as np
import torch
@ -91,8 +90,8 @@ class ResizeOCR:
if dst_max_width is not None:
valid_ratio = min(1.0, 1.0 * new_width / dst_max_width)
resize_width = min(dst_max_width, new_width)
img_resize = cv2.resize(results['img'],
(resize_width, dst_height))
img_resize = mmcv.imresize(results['img'],
(resize_width, dst_height))
resize_shape = img_resize.shape
pad_shape = img_resize.shape
if new_width < dst_max_width:
@ -102,13 +101,13 @@ class ResizeOCR:
pad_val=self.img_pad_value)
pad_shape = img_resize.shape
else:
img_resize = cv2.resize(results['img'],
(new_width, dst_height))
img_resize = mmcv.imresize(results['img'],
(new_width, dst_height))
resize_shape = img_resize.shape
pad_shape = img_resize.shape
else:
img_resize = cv2.resize(results['img'],
(dst_max_width, dst_height))
img_resize = mmcv.imresize(results['img'],
(dst_max_width, dst_height))
resize_shape = img_resize.shape
pad_shape = img_resize.shape
@ -286,10 +285,10 @@ class RandomPaddingOCR:
random_padding_bottom = round(
np.random.uniform(0, self.max_ratio[3]) * ori_height)
img = np.copy(results['img'])
img = cv2.copyMakeBorder(img, random_padding_top,
random_padding_bottom, random_padding_left,
random_padding_right, cv2.BORDER_REPLICATE)
padding = (random_padding_left, random_padding_top,
random_padding_right, random_padding_bottom)
img = mmcv.impad(results['img'], padding=padding, padding_mode='edge')
results['img'] = img
results['img_shape'] = img.shape

View File

@ -1,6 +1,7 @@
import math
import cv2
import mmcv
import numpy as np
import Polygon as plg
import torchvision.transforms as transforms
@ -587,7 +588,7 @@ class RandomRotatePolyInstances:
(h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
np.random.randint(0, w * 7 // 8))
img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0]))
img_cut = mmcv.imresize(img_cut, (canvas_size[1], canvas_size[0]))
mask = cv2.warpAffine(
mask,
rotation_matrix, (canvas_size[1], canvas_size[0]),
@ -670,7 +671,7 @@ class SquareResizePad:
t_w = self.target_size if h <= w else int(w * self.target_size / h)
else:
t_h = t_w = self.target_size
img = cv2.resize(img, (t_w, t_h))
img = mmcv.imresize(img, (t_w, t_h))
return img, (t_h, t_w)
def square_pad(self, img):
@ -685,7 +686,7 @@ class SquareResizePad:
(h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
np.random.randint(0, w * 7 // 8))
img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
expand_img = cv2.resize(img_cut, (pad_size, pad_size))
expand_img = mmcv.imresize(img_cut, (pad_size, pad_size))
if h > w:
y0, x0 = 0, (h - w) // 2
else:
@ -758,7 +759,7 @@ class RandomScaling:
scales = self.size * 1.0 / max(h, w) * aspect_ratio
scales = np.array([scales, scales])
out_size = (int(h * scales[1]), int(w * scales[0]))
image = cv2.resize(image, out_size[::-1])
image = mmcv.imresize(image, out_size[::-1])
results['img'] = image
results['img_shape'] = image.shape

View File

@ -1,6 +1,7 @@
import numpy as np
from mmocr.models.builder import CONVERTORS
from mmocr.utils import list_from_file
@CONVERTORS.register_module()
@ -36,10 +37,10 @@ class NerConvertor:
assert self.max_len > 2
assert self.annotation_type in ['bio', 'bioes']
lines = open(vocab_file, encoding='utf-8').readlines()
self.vocab_size = len(lines)
for i in range(len(lines)):
self.word2ids.update({lines[i].rstrip(): i})
vocabs = list_from_file(vocab_file)
self.vocab_size = len(vocabs)
for idx, vocab in enumerate(vocabs):
self.word2ids.update({vocab: idx})
if self.annotation_type == 'bio':
self.label2id_dict, self.id2label, self.ignore_id = \