mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
add PP-OCRv4 det code (#9766)
* add ppocrv4 det student and teacher model * update head and config, refine details * refine config and head details * refine config and head details * refine details * refine details * remove application * refine fpn * fix bug * update code * fix bug * align lcnet to rec * align hgnet to rec * refine make shrink * remove theseus layer
This commit is contained in:
parent
7710ee04c6
commit
ca8c8200ba
172
configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml
Normal file
172
configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml
Normal file
@ -0,0 +1,172 @@
|
||||
Global:
|
||||
debug: false
|
||||
use_gpu: true
|
||||
epoch_num: 500
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 100
|
||||
save_model_dir: ./output/ch_PP-OCRv3_mv3_cbnlocal_shrink/
|
||||
save_epoch_step: 10
|
||||
eval_batch_step:
|
||||
- 0
|
||||
- 1500
|
||||
cal_metric_during_train: false
|
||||
checkpoints:
|
||||
pretrained_model:
|
||||
save_inference_dir: null
|
||||
use_visualdl: false
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./checkpoints/det_db/predicts_db.txt
|
||||
distributed: true
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform: null
|
||||
Backbone:
|
||||
name: LCNetv3
|
||||
scale: 0.75
|
||||
det: True
|
||||
Neck:
|
||||
name: RSEFPN
|
||||
out_channels: 96
|
||||
shortcut: True
|
||||
Head:
|
||||
name: CBNHeadLocal
|
||||
k: 50
|
||||
mode: "small"
|
||||
|
||||
Loss:
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001 #(8*8c)
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: L2
|
||||
factor: 5.0e-05
|
||||
|
||||
PostProcess:
|
||||
name: DBPostProcess
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DetMetric
|
||||
main_indicator: hmean
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- DetLabelEncode: null
|
||||
- CopyPaste: null
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- type: Fliplr
|
||||
args:
|
||||
p: 0.5
|
||||
- type: Affine
|
||||
args:
|
||||
rotate:
|
||||
- -10
|
||||
- 10
|
||||
- type: Resize
|
||||
args:
|
||||
size:
|
||||
- 0.5
|
||||
- 3
|
||||
- EastRandomCropData:
|
||||
size:
|
||||
- 640
|
||||
- 640
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
total_epoch: 500
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
total_epoch: 500
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean:
|
||||
- 0.485
|
||||
- 0.456
|
||||
- 0.406
|
||||
std:
|
||||
- 0.229
|
||||
- 0.224
|
||||
- 0.225
|
||||
order: hwc
|
||||
- ToCHWImage: null
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- threshold_map
|
||||
- threshold_mask
|
||||
- shrink_map
|
||||
- shrink_mask
|
||||
loader:
|
||||
shuffle: true
|
||||
drop_last: false
|
||||
batch_size_per_card: 8
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- DetLabelEncode: null
|
||||
- DetResizeForTest:
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean:
|
||||
- 0.485
|
||||
- 0.456
|
||||
- 0.406
|
||||
std:
|
||||
- 0.229
|
||||
- 0.224
|
||||
- 0.225
|
||||
order: hwc
|
||||
- ToCHWImage: null
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- shape
|
||||
- polys
|
||||
- ignore_tags
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 1
|
||||
num_workers: 2
|
||||
profiler_options: null
|
172
configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_teacher.yml
Normal file
172
configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_teacher.yml
Normal file
@ -0,0 +1,172 @@
|
||||
Global:
|
||||
debug: false
|
||||
use_gpu: true
|
||||
epoch_num: 500
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 100
|
||||
save_model_dir: ./output/ch_PP-OCRv3_mv3_cbnlocal_shrink/
|
||||
save_epoch_step: 10
|
||||
eval_batch_step:
|
||||
- 0
|
||||
- 1500
|
||||
cal_metric_during_train: false
|
||||
checkpoints:
|
||||
pretrained_model:
|
||||
save_inference_dir: null
|
||||
use_visualdl: false
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./checkpoints/det_db/predicts_db.txt
|
||||
distributed: true
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform: null
|
||||
Backbone:
|
||||
name: PPHGNet_small
|
||||
det: True
|
||||
Neck:
|
||||
name: LKPAN
|
||||
out_channels: 256
|
||||
intracl: true
|
||||
Head:
|
||||
name: CBNHeadLocal
|
||||
k: 50
|
||||
mode: "large"
|
||||
|
||||
|
||||
Loss:
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001 #(8*8c)
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: L2
|
||||
factor: 1e-6
|
||||
|
||||
PostProcess:
|
||||
name: DBPostProcess
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DetMetric
|
||||
main_indicator: hmean
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- DetLabelEncode: null
|
||||
- CopyPaste: null
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- type: Fliplr
|
||||
args:
|
||||
p: 0.5
|
||||
- type: Affine
|
||||
args:
|
||||
rotate:
|
||||
- -10
|
||||
- 10
|
||||
- type: Resize
|
||||
args:
|
||||
size:
|
||||
- 0.5
|
||||
- 3
|
||||
- EastRandomCropData:
|
||||
size:
|
||||
- 640
|
||||
- 640
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
total_epoch: 500
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
total_epoch: 500
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean:
|
||||
- 0.485
|
||||
- 0.456
|
||||
- 0.406
|
||||
std:
|
||||
- 0.229
|
||||
- 0.224
|
||||
- 0.225
|
||||
order: hwc
|
||||
- ToCHWImage: null
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- threshold_map
|
||||
- threshold_mask
|
||||
- shrink_map
|
||||
- shrink_mask
|
||||
loader:
|
||||
shuffle: true
|
||||
drop_last: false
|
||||
batch_size_per_card: 8
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- DetLabelEncode: null
|
||||
- DetResizeForTest:
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean:
|
||||
- 0.485
|
||||
- 0.456
|
||||
- 0.406
|
||||
std:
|
||||
- 0.229
|
||||
- 0.224
|
||||
- 0.225
|
||||
order: hwc
|
||||
- ToCHWImage: null
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- shape
|
||||
- polys
|
||||
- ignore_tags
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 1
|
||||
num_workers: 2
|
||||
profiler_options: null
|
@ -44,6 +44,10 @@ class MakeBorderMap(object):
|
||||
self.shrink_ratio = shrink_ratio
|
||||
self.thresh_min = thresh_min
|
||||
self.thresh_max = thresh_max
|
||||
if 'total_epoch' in kwargs and 'epoch' in kwargs and kwargs[
|
||||
'epoch'] != "None":
|
||||
self.shrink_ratio = self.shrink_ratio + 0.2 * kwargs[
|
||||
'epoch'] / float(kwargs['total_epoch'])
|
||||
|
||||
def __call__(self, data):
|
||||
|
||||
|
@ -38,6 +38,10 @@ class MakeShrinkMap(object):
|
||||
def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
|
||||
self.min_text_size = min_text_size
|
||||
self.shrink_ratio = shrink_ratio
|
||||
if 'total_epoch' in kwargs and 'epoch' in kwargs and kwargs[
|
||||
'epoch'] != "None":
|
||||
self.shrink_ratio = self.shrink_ratio + 0.2 * kwargs[
|
||||
'epoch'] / float(kwargs['total_epoch'])
|
||||
|
||||
def __call__(self, data):
|
||||
image = data['image']
|
||||
|
@ -48,11 +48,25 @@ class SimpleDataSet(Dataset):
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
if self.mode == "train" and self.do_shuffle:
|
||||
self.shuffle_data_random()
|
||||
|
||||
self.set_epoch_as_seed(self.seed)
|
||||
|
||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||
self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
|
||||
2)
|
||||
self.need_reset = True in [x < 1 for x in ratio_list]
|
||||
|
||||
def set_epoch_as_seed(self, seed):
|
||||
if self.mode is 'train':
|
||||
try:
|
||||
dataset_config['transforms'][5]['MakeBorderMap'][
|
||||
'epoch'] = seed if seed is not None else 0
|
||||
dataset_config['transforms'][6]['MakeShrinkMap'][
|
||||
'epoch'] = seed if seed is not None else 0
|
||||
except Exception as E:
|
||||
logger.info(E)
|
||||
return
|
||||
|
||||
def get_image_info_list(self, file_list, ratio_list):
|
||||
if isinstance(file_list, str):
|
||||
file_list = [file_list]
|
||||
|
@ -20,6 +20,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
|
||||
@ -66,11 +67,21 @@ class DBLoss(nn.Layer):
|
||||
label_shrink_mask)
|
||||
loss_shrink_maps = self.alpha * loss_shrink_maps
|
||||
loss_threshold_maps = self.beta * loss_threshold_maps
|
||||
# CBN loss
|
||||
if 'distance_maps' in predicts.keys():
|
||||
distance_maps = predicts['distance_maps']
|
||||
cbn_maps = predicts['cbn_maps']
|
||||
cbn_loss = self.bce_loss(cbn_maps[:, 0, :, :], label_shrink_map,
|
||||
label_shrink_mask)
|
||||
else:
|
||||
dis_loss = paddle.to_tensor([0.])
|
||||
cbn_loss = paddle.to_tensor([0.])
|
||||
|
||||
loss_all = loss_shrink_maps + loss_threshold_maps \
|
||||
+ loss_binary_maps
|
||||
losses = {'loss': loss_all, \
|
||||
losses = {'loss': loss_all+ cbn_loss, \
|
||||
"loss_shrink_maps": loss_shrink_maps, \
|
||||
"loss_threshold_maps": loss_threshold_maps, \
|
||||
"loss_binary_maps": loss_binary_maps}
|
||||
"loss_binary_maps": loss_binary_maps, \
|
||||
"loss_cbn": cbn_loss}
|
||||
return losses
|
||||
|
@ -22,8 +22,11 @@ def build_backbone(config, model_type):
|
||||
from .det_resnet_vd import ResNet_vd
|
||||
from .det_resnet_vd_sast import ResNet_SAST
|
||||
from .det_pp_lcnet import PPLCNet
|
||||
from .rec_lcnetv3 import LCNetv3
|
||||
from .rec_hgnet import PPHGNet_small
|
||||
support_dict = [
|
||||
"MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet"
|
||||
"MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet",
|
||||
"LCNetv3", "PPHGNet_small"
|
||||
]
|
||||
if model_type == "table":
|
||||
from .table_master_resnet import TableResNetExtra
|
||||
|
@ -188,8 +188,19 @@ class PPHGNet(nn.Layer):
|
||||
model: nn.Layer. Specific PPHGNet model depends on args.
|
||||
"""
|
||||
|
||||
def __init__(self, stem_channels, stage_config, layer_num, in_channels=3):
|
||||
def __init__(
|
||||
self,
|
||||
stem_channels,
|
||||
stage_config,
|
||||
layer_num,
|
||||
in_channels=3,
|
||||
det=False,
|
||||
out_indices=None, ):
|
||||
super().__init__()
|
||||
self.det = det
|
||||
self.out_indices = out_indices if out_indices is not None else [
|
||||
0, 1, 2, 3
|
||||
]
|
||||
|
||||
# stem
|
||||
stem_channels.insert(0, in_channels)
|
||||
@ -202,16 +213,23 @@ class PPHGNet(nn.Layer):
|
||||
len(stem_channels) - 1)
|
||||
])
|
||||
|
||||
if self.det:
|
||||
self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
|
||||
# stages
|
||||
self.stages = nn.LayerList()
|
||||
for k in stage_config:
|
||||
self.out_channels = []
|
||||
for block_id, k in enumerate(stage_config):
|
||||
in_channels, mid_channels, out_channels, block_num, downsample, stride = stage_config[
|
||||
k]
|
||||
self.stages.append(
|
||||
HG_Stage(in_channels, mid_channels, out_channels, block_num,
|
||||
layer_num, downsample, stride))
|
||||
if block_id in self.out_indices:
|
||||
self.out_channels.append(out_channels)
|
||||
|
||||
if not self.det:
|
||||
self.out_channels = stage_config["stage4"][2]
|
||||
|
||||
self.out_channels = stage_config["stage4"][2]
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
@ -226,8 +244,17 @@ class PPHGNet(nn.Layer):
|
||||
|
||||
def forward(self, x):
|
||||
x = self.stem(x)
|
||||
for stage in self.stages:
|
||||
if self.det:
|
||||
x = self.pool(x)
|
||||
|
||||
out = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x = stage(x)
|
||||
if self.det and i in self.out_indices:
|
||||
out.append(x)
|
||||
if self.det:
|
||||
return out
|
||||
|
||||
if self.training:
|
||||
x = F.adaptive_avg_pool2d(x, [1, 40])
|
||||
else:
|
||||
@ -261,7 +288,7 @@ def PPHGNet_tiny(pretrained=False, use_ssld=False, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs):
|
||||
def PPHGNet_small(pretrained=False, use_ssld=False, det=False, **kwargs):
|
||||
"""
|
||||
PPHGNet_small
|
||||
Args:
|
||||
@ -271,7 +298,15 @@ def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs):
|
||||
Returns:
|
||||
model: nn.Layer. Specific `PPHGNet_small` model depends on args.
|
||||
"""
|
||||
stage_config = {
|
||||
stage_config_det = {
|
||||
# in_channels, mid_channels, out_channels, blocks, downsample
|
||||
"stage1": [128, 128, 256, 1, False, 2],
|
||||
"stage2": [256, 160, 512, 1, True, 2],
|
||||
"stage3": [512, 192, 768, 2, True, 2],
|
||||
"stage4": [768, 224, 1024, 1, True, 2],
|
||||
}
|
||||
|
||||
stage_config_rec = {
|
||||
# in_channels, mid_channels, out_channels, blocks, downsample
|
||||
"stage1": [128, 128, 256, 1, True, [2, 1]],
|
||||
"stage2": [256, 160, 512, 1, True, [1, 2]],
|
||||
@ -281,8 +316,9 @@ def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs):
|
||||
|
||||
model = PPHGNet(
|
||||
stem_channels=[64, 64, 128],
|
||||
stage_config=stage_config,
|
||||
stage_config=stage_config_det if det else stage_config_rec,
|
||||
layer_num=6,
|
||||
det=det,
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
@ -24,7 +24,20 @@ from paddle.nn.initializer import Constant, KaimingNormal
|
||||
from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Dropout, Hardsigmoid, Hardswish, Identity, Linear, ReLU
|
||||
from paddle.regularizer import L2Decay
|
||||
|
||||
NET_CONFIG = {
|
||||
NET_CONFIG_det = {
|
||||
"blocks2":
|
||||
#k, in_c, out_c, s, use_se
|
||||
[[3, 16, 32, 1, False]],
|
||||
"blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
|
||||
"blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
|
||||
"blocks5":
|
||||
[[3, 128, 256, 2, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False],
|
||||
[5, 256, 256, 1, False], [5, 256, 256, 1, False]],
|
||||
"blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True],
|
||||
[5, 512, 512, 1, False], [5, 512, 512, 1, False]]
|
||||
}
|
||||
|
||||
NET_CONFIG_rec = {
|
||||
"blocks2":
|
||||
#k, in_c, out_c, s, use_se
|
||||
[[3, 16, 32, 1, False]],
|
||||
@ -335,11 +348,14 @@ class PPLCNetV3(nn.Layer):
|
||||
conv_kxk_num=4,
|
||||
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
|
||||
lab_lr=0.1,
|
||||
det=False,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.lr_mult_list = lr_mult_list
|
||||
self.net_config = NET_CONFIG
|
||||
self.det = det
|
||||
|
||||
self.net_config = NET_CONFIG_det if self.det else NET_CONFIG_rec
|
||||
|
||||
assert isinstance(self.lr_mult_list, (
|
||||
list, tuple
|
||||
@ -365,8 +381,9 @@ class PPLCNetV3(nn.Layer):
|
||||
use_se=se,
|
||||
conv_kxk_num=conv_kxk_num,
|
||||
lr_mult=self.lr_mult_list[1],
|
||||
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate(
|
||||
self.net_config["blocks2"])
|
||||
lab_lr=lab_lr)
|
||||
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[
|
||||
"blocks2"])
|
||||
])
|
||||
|
||||
self.blocks3 = nn.Sequential(* [
|
||||
@ -378,8 +395,9 @@ class PPLCNetV3(nn.Layer):
|
||||
use_se=se,
|
||||
conv_kxk_num=conv_kxk_num,
|
||||
lr_mult=self.lr_mult_list[2],
|
||||
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate(
|
||||
self.net_config["blocks3"])
|
||||
lab_lr=lab_lr)
|
||||
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[
|
||||
"blocks3"])
|
||||
])
|
||||
|
||||
self.blocks4 = nn.Sequential(* [
|
||||
@ -391,8 +409,9 @@ class PPLCNetV3(nn.Layer):
|
||||
use_se=se,
|
||||
conv_kxk_num=conv_kxk_num,
|
||||
lr_mult=self.lr_mult_list[3],
|
||||
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate(
|
||||
self.net_config["blocks4"])
|
||||
lab_lr=lab_lr)
|
||||
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[
|
||||
"blocks4"])
|
||||
])
|
||||
|
||||
self.blocks5 = nn.Sequential(* [
|
||||
@ -404,8 +423,9 @@ class PPLCNetV3(nn.Layer):
|
||||
use_se=se,
|
||||
conv_kxk_num=conv_kxk_num,
|
||||
lr_mult=self.lr_mult_list[4],
|
||||
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate(
|
||||
self.net_config["blocks5"])
|
||||
lab_lr=lab_lr)
|
||||
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[
|
||||
"blocks5"])
|
||||
])
|
||||
|
||||
self.blocks6 = nn.Sequential(* [
|
||||
@ -417,19 +437,52 @@ class PPLCNetV3(nn.Layer):
|
||||
use_se=se,
|
||||
conv_kxk_num=conv_kxk_num,
|
||||
lr_mult=self.lr_mult_list[5],
|
||||
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate(
|
||||
self.net_config["blocks6"])
|
||||
lab_lr=lab_lr)
|
||||
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config[
|
||||
"blocks6"])
|
||||
])
|
||||
self.out_channels = make_divisible(512 * scale)
|
||||
|
||||
if self.det:
|
||||
mv_c = [16, 24, 56, 480]
|
||||
self.out_channels = [
|
||||
make_divisible(self.net_config["blocks3"][-1][2] * scale),
|
||||
make_divisible(self.net_config["blocks4"][-1][2] * scale),
|
||||
make_divisible(self.net_config["blocks5"][-1][2] * scale),
|
||||
make_divisible(self.net_config["blocks6"][-1][2] * scale),
|
||||
]
|
||||
|
||||
self.layer_list = nn.LayerList([
|
||||
nn.Conv2D(self.out_channels[0], int(mv_c[0] * scale), 1, 1, 0),
|
||||
nn.Conv2D(self.out_channels[1], int(mv_c[1] * scale), 1, 1, 0),
|
||||
nn.Conv2D(self.out_channels[2], int(mv_c[2] * scale), 1, 1, 0),
|
||||
nn.Conv2D(self.out_channels[3], int(mv_c[3] * scale), 1, 1, 0)
|
||||
])
|
||||
self.out_channels = [
|
||||
int(mv_c[0] * scale), int(mv_c[1] * scale),
|
||||
int(mv_c[2] * scale), int(mv_c[3] * scale)
|
||||
]
|
||||
|
||||
def forward(self, x):
|
||||
out_list = []
|
||||
x = self.conv1(x)
|
||||
|
||||
x = self.blocks2(x)
|
||||
x = self.blocks3(x)
|
||||
out_list.append(x)
|
||||
x = self.blocks4(x)
|
||||
out_list.append(x)
|
||||
x = self.blocks5(x)
|
||||
out_list.append(x)
|
||||
x = self.blocks6(x)
|
||||
out_list.append(x)
|
||||
|
||||
if self.det:
|
||||
out_list[0] = self.layer_list[0](out_list[0])
|
||||
out_list[1] = self.layer_list[1](out_list[1])
|
||||
out_list[2] = self.layer_list[2](out_list[2])
|
||||
out_list[3] = self.layer_list[3](out_list[3])
|
||||
return out_list
|
||||
|
||||
if self.training:
|
||||
x = F.adaptive_avg_pool2d(x, [1, 40])
|
||||
@ -438,6 +491,6 @@ class PPLCNetV3(nn.Layer):
|
||||
return x
|
||||
|
||||
|
||||
def LCNetv3(pretrained=False, use_ssld=False, **kwargs):
|
||||
model = PPLCNetV3(scale=0.95, conv_kxk_num=4, **kwargs)
|
||||
def LCNetv3(scale=0.95, **kwargs):
|
||||
model = PPLCNetV3(scale=scale, conv_kxk_num=4, **kwargs)
|
||||
return model
|
||||
|
@ -17,14 +17,13 @@ __all__ = ['build_head']
|
||||
|
||||
def build_head(config):
|
||||
# det head
|
||||
from .det_db_head import DBHead
|
||||
from .det_db_head import DBHead, CBNHeadLocal
|
||||
from .det_east_head import EASTHead
|
||||
from .det_sast_head import SASTHead
|
||||
from .det_pse_head import PSEHead
|
||||
from .det_fce_head import FCEHead
|
||||
from .e2e_pg_head import PGHead
|
||||
from .det_ct_head import CT_Head
|
||||
|
||||
# rec head
|
||||
from .rec_ctc_head import CTCHead
|
||||
from .rec_att_head import AttentionHead
|
||||
@ -57,7 +56,7 @@ def build_head(config):
|
||||
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
|
||||
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
|
||||
'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
|
||||
'DRRGHead', 'CANHead', 'SATRNHead'
|
||||
'DRRGHead', 'CANHead', 'SATRNHead', 'CBNHeadLocal'
|
||||
]
|
||||
|
||||
if config['name'] == 'DRRGHead':
|
||||
|
@ -21,6 +21,7 @@ import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
from ppocr.modeling.backbones.det_mobilenet_v3 import ConvBNLayer
|
||||
|
||||
|
||||
def get_bias_attr(k):
|
||||
@ -48,6 +49,7 @@ class Head(nn.Layer):
|
||||
bias_attr=ParamAttr(
|
||||
initializer=paddle.nn.initializer.Constant(value=1e-4)),
|
||||
act='relu')
|
||||
|
||||
self.conv2 = nn.Conv2DTranspose(
|
||||
in_channels=in_channels // 4,
|
||||
out_channels=in_channels // 4,
|
||||
@ -72,13 +74,17 @@ class Head(nn.Layer):
|
||||
initializer=paddle.nn.initializer.KaimingUniform()),
|
||||
bias_attr=get_bias_attr(in_channels // 4), )
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, return_f=False):
|
||||
x = self.conv1(x)
|
||||
x = self.conv_bn1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv_bn2(x)
|
||||
if return_f is True:
|
||||
f = x
|
||||
x = self.conv3(x)
|
||||
x = F.sigmoid(x)
|
||||
if return_f is True:
|
||||
return x, f
|
||||
return x
|
||||
|
||||
|
||||
@ -108,3 +114,41 @@ class DBHead(nn.Layer):
|
||||
binary_maps = self.step_function(shrink_maps, threshold_maps)
|
||||
y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)
|
||||
return {'maps': y}
|
||||
|
||||
|
||||
class LocalModule(nn.Layer):
|
||||
def __init__(self, in_c, mid_c, use_distance=True):
|
||||
super(self.__class__, self).__init__()
|
||||
self.last_3 = ConvBNLayer(in_c + 1, mid_c, 3, 1, 1, act='relu')
|
||||
self.last_1 = nn.Conv2D(mid_c, 1, 1, 1, 0)
|
||||
|
||||
def forward(self, x, init_map, distance_map):
|
||||
outf = paddle.concat([init_map, x], axis=1)
|
||||
# last Conv
|
||||
out = self.last_1(self.last_3(outf))
|
||||
return out
|
||||
|
||||
|
||||
class CBNHeadLocal(DBHead):
|
||||
def __init__(self, in_channels, k=50, mode='small', **kwargs):
|
||||
super(CBNHeadLocal, self).__init__(in_channels, k, **kwargs)
|
||||
self.mode = mode
|
||||
|
||||
self.up_conv = nn.Upsample(scale_factor=2, mode="nearest", align_mode=1)
|
||||
if self.mode == 'large':
|
||||
self.cbn_layer = LocalModule(in_channels // 4, in_channels // 4)
|
||||
elif self.mode == 'small':
|
||||
self.cbn_layer = LocalModule(in_channels // 4, in_channels // 8)
|
||||
|
||||
def forward(self, x, targets=None):
|
||||
shrink_maps, f = self.binarize(x, return_f=True)
|
||||
base_maps = shrink_maps
|
||||
cbn_maps = self.cbn_layer(self.up_conv(f), shrink_maps, None)
|
||||
cbn_maps = F.sigmoid(cbn_maps)
|
||||
if not self.training:
|
||||
return {'maps': 0.5 * (base_maps + cbn_maps), 'cbn_maps': cbn_maps}
|
||||
|
||||
threshold_maps = self.thresh(x)
|
||||
binary_maps = self.step_function(shrink_maps, threshold_maps)
|
||||
y = paddle.concat([cbn_maps, threshold_maps, binary_maps], axis=1)
|
||||
return {'maps': y, 'distance_maps': cbn_maps, 'cbn_maps': binary_maps}
|
||||
|
@ -22,6 +22,7 @@ import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
import os
|
||||
import sys
|
||||
from ppocr.modeling.necks.intracl import IntraCLBlock
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
@ -228,6 +229,13 @@ class RSEFPN(nn.Layer):
|
||||
self.out_channels = out_channels
|
||||
self.ins_conv = nn.LayerList()
|
||||
self.inp_conv = nn.LayerList()
|
||||
self.intracl = False
|
||||
if 'intracl' in kwargs.keys() and kwargs['intracl'] is True:
|
||||
self.intracl = kwargs['intracl']
|
||||
self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
|
||||
self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
|
||||
self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
|
||||
self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
|
||||
|
||||
for i in range(len(in_channels)):
|
||||
self.ins_conv.append(
|
||||
@ -263,6 +271,12 @@ class RSEFPN(nn.Layer):
|
||||
p3 = self.inp_conv[1](out3)
|
||||
p2 = self.inp_conv[0](out2)
|
||||
|
||||
if self.intracl is True:
|
||||
p5 = self.incl4(p5)
|
||||
p4 = self.incl3(p4)
|
||||
p3 = self.incl2(p3)
|
||||
p2 = self.incl1(p2)
|
||||
|
||||
p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
|
||||
p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
|
||||
p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
|
||||
@ -329,6 +343,14 @@ class LKPAN(nn.Layer):
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False))
|
||||
|
||||
self.intracl = False
|
||||
if 'intracl' in kwargs.keys() and kwargs['intracl'] is True:
|
||||
self.intracl = kwargs['intracl']
|
||||
self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
|
||||
self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
|
||||
self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
|
||||
self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
|
||||
|
||||
def forward(self, x):
|
||||
c2, c3, c4, c5 = x
|
||||
|
||||
@ -358,6 +380,12 @@ class LKPAN(nn.Layer):
|
||||
p4 = self.pan_lat_conv[2](pan4)
|
||||
p5 = self.pan_lat_conv[3](pan5)
|
||||
|
||||
if self.intracl is True:
|
||||
p5 = self.incl4(p5)
|
||||
p4 = self.incl3(p4)
|
||||
p3 = self.incl2(p3)
|
||||
p2 = self.incl1(p2)
|
||||
|
||||
p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
|
||||
p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
|
||||
p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
|
||||
@ -424,4 +452,4 @@ class ASFBlock(nn.Layer):
|
||||
out_list = []
|
||||
for i in range(self.out_features_num):
|
||||
out_list.append(attention_scores[:, i:i + 1] * features_list[i])
|
||||
return paddle.concat(out_list, axis=1)
|
||||
return paddle.concat(out_list, axis=1)
|
118
ppocr/modeling/necks/intracl.py
Normal file
118
ppocr/modeling/necks/intracl.py
Normal file
@ -0,0 +1,118 @@
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
# refer from: https://github.com/ViTAE-Transformer/I3CL/blob/736c80237f66d352d488e83b05f3e33c55201317/mmdet/models/detectors/intra_cl_module.py
|
||||
|
||||
|
||||
class IntraCLBlock(nn.Layer):
|
||||
def __init__(self, in_channels=96, reduce_factor=4):
|
||||
super(IntraCLBlock, self).__init__()
|
||||
self.channels = in_channels
|
||||
self.rf = reduce_factor
|
||||
weight_attr = paddle.nn.initializer.KaimingUniform()
|
||||
self.conv1x1_reduce_channel = nn.Conv2D(
|
||||
self.channels,
|
||||
self.channels // self.rf,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.conv1x1_return_channel = nn.Conv2D(
|
||||
self.channels // self.rf,
|
||||
self.channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
self.v_layer_7x1 = nn.Conv2D(
|
||||
self.channels // self.rf,
|
||||
self.channels // self.rf,
|
||||
kernel_size=(7, 1),
|
||||
stride=(1, 1),
|
||||
padding=(3, 0))
|
||||
self.v_layer_5x1 = nn.Conv2D(
|
||||
self.channels // self.rf,
|
||||
self.channels // self.rf,
|
||||
kernel_size=(5, 1),
|
||||
stride=(1, 1),
|
||||
padding=(2, 0))
|
||||
self.v_layer_3x1 = nn.Conv2D(
|
||||
self.channels // self.rf,
|
||||
self.channels // self.rf,
|
||||
kernel_size=(3, 1),
|
||||
stride=(1, 1),
|
||||
padding=(1, 0))
|
||||
|
||||
self.q_layer_1x7 = nn.Conv2D(
|
||||
self.channels // self.rf,
|
||||
self.channels // self.rf,
|
||||
kernel_size=(1, 7),
|
||||
stride=(1, 1),
|
||||
padding=(0, 3))
|
||||
self.q_layer_1x5 = nn.Conv2D(
|
||||
self.channels // self.rf,
|
||||
self.channels // self.rf,
|
||||
kernel_size=(1, 5),
|
||||
stride=(1, 1),
|
||||
padding=(0, 2))
|
||||
self.q_layer_1x3 = nn.Conv2D(
|
||||
self.channels // self.rf,
|
||||
self.channels // self.rf,
|
||||
kernel_size=(1, 3),
|
||||
stride=(1, 1),
|
||||
padding=(0, 1))
|
||||
|
||||
# base
|
||||
self.c_layer_7x7 = nn.Conv2D(
|
||||
self.channels // self.rf,
|
||||
self.channels // self.rf,
|
||||
kernel_size=(7, 7),
|
||||
stride=(1, 1),
|
||||
padding=(3, 3))
|
||||
self.c_layer_5x5 = nn.Conv2D(
|
||||
self.channels // self.rf,
|
||||
self.channels // self.rf,
|
||||
kernel_size=(5, 5),
|
||||
stride=(1, 1),
|
||||
padding=(2, 2))
|
||||
self.c_layer_3x3 = nn.Conv2D(
|
||||
self.channels // self.rf,
|
||||
self.channels // self.rf,
|
||||
kernel_size=(3, 3),
|
||||
stride=(1, 1),
|
||||
padding=(1, 1))
|
||||
|
||||
self.bn = nn.BatchNorm2D(self.channels)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
x_new = self.conv1x1_reduce_channel(x)
|
||||
|
||||
x_7_c = self.c_layer_7x7(x_new)
|
||||
x_7_v = self.v_layer_7x1(x_new)
|
||||
x_7_q = self.q_layer_1x7(x_new)
|
||||
x_7 = x_7_c + x_7_v + x_7_q
|
||||
|
||||
x_5_c = self.c_layer_5x5(x_7)
|
||||
x_5_v = self.v_layer_5x1(x_7)
|
||||
x_5_q = self.q_layer_1x5(x_7)
|
||||
x_5 = x_5_c + x_5_v + x_5_q
|
||||
|
||||
x_3_c = self.c_layer_3x3(x_5)
|
||||
x_3_v = self.v_layer_3x1(x_5)
|
||||
x_3_q = self.q_layer_1x3(x_5)
|
||||
x_3 = x_3_c + x_3_v + x_3_q
|
||||
|
||||
x_relation = self.conv1x1_return_channel(x_3)
|
||||
|
||||
x_relation = self.bn(x_relation)
|
||||
x_relation = self.relu(x_relation)
|
||||
|
||||
return x + x_relation
|
||||
|
||||
|
||||
def build_intraclblock_list(num_block):
|
||||
IntraCLBlock_list = nn.LayerList()
|
||||
for i in range(num_block):
|
||||
IntraCLBlock_list.append(IntraCLBlock())
|
||||
|
||||
return IntraCLBlock_list
|
Loading…
x
Reference in New Issue
Block a user