mirror of https://github.com/alibaba/EasyCV.git
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from easycv.models import builder
from easycv.models.base import BaseModel
from easycv.models.builder import MODELS
from easycv.models.ocr.postprocess.rec_postprocess import CTCLabelDecode
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger


@MODELS.register_module()
class OCRRecNet(BaseModel):
    """Text recognition model.

    Builds a recognition backbone, an optional neck, a prediction head, an
    optional training loss and a post-process op (e.g. ``CTCLabelDecode``)
    from their configs.

    Args:
        backbone (dict): Backbone config.
        head (dict): Head config.
        postprocess (dict): Post-process config; ``postprocess.type`` must
            name a class visible in this module, e.g. ``CTCLabelDecode``.
        neck (dict, optional): Neck config.
        loss (dict, optional): Loss config.
        pretrained (str, optional): Path or URL of a pretrained checkpoint.
    """

    def __init__(
        self,
        backbone,
        head,
        postprocess,
        neck=None,
        loss=None,
        pretrained=None,
        **kwargs,
    ):
        super(OCRRecNet, self).__init__()

        self.pretrained = pretrained

        self.backbone = builder.build_backbone(backbone)
        self.neck = builder.build_neck(neck) if neck else None
        self.head = builder.build_head(head)
        self.loss = builder.build_loss(loss) if loss else None
        # Resolve the post-process class by name (e.g. the imported
        # CTCLabelDecode) and instantiate it with the config values.
        self.postprocess_op = eval(postprocess.type)(**postprocess)
        self.init_weights()

    def init_weights(self):
        logger = get_root_logger()
        if self.pretrained:
            load_checkpoint(self, self.pretrained, strict=False, logger=logger)
        else:
            # weight initialization
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(
                        m, nn.ConvTranspose2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.ones_(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

    def extract_feat(self, x, label=None, valid_ratios=None):
        y = dict()
        x = self.backbone(x)
        y['backbone_out'] = x
        if self.neck:
            x = self.neck(x)
            y['neck_out'] = x
        x = self.head(x, label=label, valid_ratios=valid_ratios)
        # for multi head, save ctc neck out for udml
        if isinstance(x, dict) and 'ctc_neck' in x.keys():
            y['neck_out'] = x['ctc_neck']
            y['head_out'] = x
        elif isinstance(x, dict):
            y.update(x)
        else:
            y['head_out'] = x
        return y

    def forward_train(self, img, **kwargs):
        label_ctc = kwargs.get('label_ctc', None)
        label_sar = kwargs.get('label_sar', None)
        length = kwargs.get('length', None)
        valid_ratio = kwargs.get('valid_ratio', None)
        predicts = self.extract_feat(
            img, label=label_sar, valid_ratios=valid_ratio)
        loss = self.loss(
            predicts, label_ctc=label_ctc, label_sar=label_sar, length=length)
        return loss

    def forward_test(self, img, **kwargs):
        label_ctc = kwargs.get('label_ctc', None)
        result = {}
        with torch.no_grad():
            preds = self.extract_feat(img)
            if label_ctc is None:
                preds_text = self.postprocess(preds)
            else:
                preds_text, label_text = self.postprocess(preds, label_ctc)
                result['label_text'] = label_text
            result['preds_text'] = preds_text
        return result

    def postprocess(self, preds, label=None):
        if isinstance(preds, dict):
            preds = preds['head_out']
        if isinstance(preds, list):
            preds = [v.cpu().detach().numpy() for v in preds]
        else:
            preds = preds.cpu().detach().numpy()
        label = label.cpu().detach().numpy() if label is not None else label
        text_out = self.postprocess_op(preds, label)

        return text_out
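

# ---------------------------------------------------------------------------
# Usage sketch (editor's addition, not part of the upstream file). It shows
# the intended config-driven construction. Every ``type`` value except
# 'OCRRecNet' and 'CTCLabelDecode' is a placeholder to be replaced with a
# component actually registered in your EasyCV installation, and
# ``MODELS.build`` assumes the mmcv-style registry interface. In practice the
# cfg comes from an mmcv Config file, whose nested dicts support attribute
# access (needed for ``postprocess.type`` in __init__ above).
#
#   cfg = dict(
#       type='OCRRecNet',
#       backbone=dict(type='<registered rec backbone>'),   # placeholder
#       neck=dict(type='<registered rec neck>'),           # placeholder
#       head=dict(type='<registered rec head>'),           # placeholder
#       loss=dict(type='<registered rec loss>'),           # placeholder
#       postprocess=dict(type='CTCLabelDecode'),           # resolved via eval()
#       pretrained=None)
#   model = MODELS.build(cfg)
#   model.eval()
#   with torch.no_grad():
#       out = model.forward_test(torch.randn(1, 3, 32, 320))
#   # out is a dict with 'preds_text' (and 'label_text' when labels are given).
# ---------------------------------------------------------------------------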