# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from easycv.models import builder
from easycv.models.base import BaseModel
from easycv.models.builder import MODELS
from easycv.models.ocr.postprocess.db_postprocess import DBPostProcess
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger

@MODELS.register_module()
class DBNet(BaseModel):
    """DBNet for text detection.

    Builds the backbone, neck, head, optional loss and the DB
    post-processing step from config dicts and wires them into a single
    detector.

    Args:
        backbone (dict): Config for the feature-extraction backbone.
        neck (dict): Config for the neck that fuses backbone features.
        head (dict): Config for the detection head.
        postprocess (dict): Keyword arguments passed to ``DBPostProcess``.
        loss (dict, optional): Config for the training loss.
        pretrained (str, optional): Path or URL of a pretrained checkpoint.
    """

    def __init__(
        self,
        backbone,
        neck,
        head,
        postprocess,
        loss=None,
        pretrained=None,
        **kwargs,
    ):
        super(DBNet, self).__init__()

        self.pretrained = pretrained

        self.backbone = builder.build_backbone(backbone)
        self.neck = builder.build_neck(neck)
        self.head = builder.build_head(head)
        self.loss = builder.build_loss(loss) if loss else None
        self.postprocess_op = DBPostProcess(**postprocess)
        self.init_weights()

    def init_weights(self):
        logger = get_root_logger()
        if self.pretrained:
            load_checkpoint(self, self.pretrained, strict=False, logger=logger)
        else:
            # weight initialization
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(
                        m, nn.ConvTranspose2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.ones_(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

    def extract_feat(self, x):
        """Run the backbone, neck and head in sequence."""
        x = self.backbone(x)
        x = self.neck(x)
        x = self.head(x)
        return x
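
    # The ground-truth targets for ``forward_train`` arrive through
    # ``kwargs`` from the data pipeline (for DB-style losses, typically
    # shrink and threshold maps plus their masks; the exact key names
    # depend on the configured loss and are an assumption here, not
    # fixed by this class).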

    def forward_train(self, img, **kwargs):
        predicts = self.extract_feat(img)
        loss = self.loss(predicts, kwargs)
        return loss

    def forward_test(self, img, **kwargs):
        shape_list = [
            img_meta['ori_img_shape'] for img_meta in kwargs['img_metas']
        ]
        with torch.no_grad():
            preds = self.extract_feat(img)
            post_results = self.postprocess_op(preds, shape_list)
        # Forward ground-truth annotations, when present, so that
        # evaluation hooks can compare predictions against them.
        if 'ignore_tags' in kwargs['img_metas'][0]:
            ignore_tags = [
                img_meta['ignore_tags'] for img_meta in kwargs['img_metas']
            ]
            post_results['ignore_tags'] = ignore_tags
        if 'polys' in kwargs['img_metas'][0]:
            polys = [img_meta['polys'] for img_meta in kwargs['img_metas']]
            post_results['polys'] = polys
        return post_results

    def postprocess(self, preds, shape_list):
        """Convert raw predictions into filtered, clockwise-ordered boxes."""
        post_results = self.postprocess_op(preds, shape_list)
        points_results = post_results['points']
        dt_boxes = []
        for idx in range(len(points_results)):
            dt_box = points_results[idx]
            dt_box = self.filter_tag_det_res(dt_box, shape_list[idx])
            dt_boxes.append(dt_box)
        return dt_boxes

    def filter_tag_det_res(self, dt_boxes, image_shape):
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            # Discard degenerate boxes of 3 pixels or less on a side.
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def order_points_clockwise(self, pts):
        """Sort four corner points into top-left, top-right, bottom-right,
        bottom-left order.

        Reference: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
        """
        # sort the points based on their x-coordinates
        xSorted = pts[np.argsort(pts[:, 0]), :]

        # grab the left-most and right-most points from the sorted
        # x-coordinate points
        leftMost = xSorted[:2, :]
        rightMost = xSorted[2:, :]

        # now, sort the left-most coordinates according to their
        # y-coordinates so we can grab the top-left and bottom-left
        # points, respectively
        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
        (tl, bl) = leftMost

        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
        (tr, br) = rightMost

        rect = np.array([tl, tr, br, bl], dtype='float32')
        return rect
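
    # A worked example (illustrative): given the corners of an axis-aligned
    # box in arbitrary order,
    #   pts = np.array([[10, 60], [10, 20], [50, 20], [50, 60]], dtype='float32')
    # order_points_clockwise returns
    #   [[10, 20], [50, 20], [50, 60], [10, 60]]
    # i.e. top-left, top-right, bottom-right, bottom-left.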

    def clip_det_res(self, points, img_height, img_width):
        # Clamp box corners so they lie inside the image bounds.
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points
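

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not an exact EasyCV config): the
# 'type' names and hyper-parameter values below are placeholders and
# assumptions; see the project's OCR configs for real DBNet settings.
#
#   model = DBNet(
#       backbone=dict(type='...'),        # e.g. a ResNet-style backbone
#       neck=dict(type='...'),            # FPN-style feature fusion
#       head=dict(type='...'),            # DB probability/threshold head
#       postprocess=dict(                 # assumed DBPostProcess arguments
#           thresh=0.3,
#           box_thresh=0.6,
#           max_candidates=1000,
#           unclip_ratio=1.5),
#       pretrained=None)
#   model.eval()
#   results = model.forward_test(img, img_metas=img_metas)
#   # results['points'] holds one array of detected text boxes per image.
# ---------------------------------------------------------------------------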