mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
commit
52763f1cf6
@ -94,14 +94,11 @@ Loss:
|
||||
- ["Student", "Student2"]
|
||||
maps_name: "thrink_maps"
|
||||
weight: 1.0
|
||||
# act: None
|
||||
model_name_pairs: ["Student", "Student2"]
|
||||
key: maps
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Student2"]
|
||||
# key: maps
|
||||
# name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
@ -191,7 +188,6 @@ Eval:
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
|
@ -24,6 +24,7 @@ Architecture:
|
||||
model_type: det
|
||||
Models:
|
||||
Student:
|
||||
pretrained:
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform: null
|
||||
@ -40,6 +41,7 @@ Architecture:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Student2:
|
||||
pretrained:
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform: null
|
||||
@ -56,6 +58,7 @@ Architecture:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Teacher:
|
||||
pretrained:
|
||||
freeze_params: true
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
@ -91,14 +94,11 @@ Loss:
|
||||
- ["Student", "Student2"]
|
||||
maps_name: "thrink_maps"
|
||||
weight: 1.0
|
||||
# act: None
|
||||
model_name_pairs: ["Student", "Student2"]
|
||||
key: maps
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Student2"]
|
||||
# key: maps
|
||||
# name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
@ -204,31 +204,21 @@ Eval:
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- DetLabelEncode: null
|
||||
- DetResizeForTest: null
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean:
|
||||
- 0.485
|
||||
- 0.456
|
||||
- 0.406
|
||||
std:
|
||||
- 0.229
|
||||
- 0.224
|
||||
- 0.225
|
||||
order: hwc
|
||||
- ToCHWImage: null
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- shape
|
||||
- polys
|
||||
- ignore_tags
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 1
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
||||
|
@ -60,19 +60,19 @@ class KLJSLoss(object):
|
||||
], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
|
||||
self.mode = mode
|
||||
|
||||
def __call__(self, p1, p2, reduction="mean"):
|
||||
def __call__(self, p1, p2, reduction="mean", eps=1e-5):
|
||||
|
||||
if self.mode.lower() == 'kl':
|
||||
loss = paddle.multiply(p2,
|
||||
paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
|
||||
paddle.log((p2 + eps) / (p1 + eps) + eps))
|
||||
loss += paddle.multiply(
|
||||
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
|
||||
p1, paddle.log((p1 + eps) / (p2 + eps) + eps))
|
||||
loss *= 0.5
|
||||
elif self.mode.lower() == "js":
|
||||
loss = paddle.multiply(
|
||||
p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
|
||||
p2, paddle.log((2 * p2 + eps) / (p1 + p2 + eps) + eps))
|
||||
loss += paddle.multiply(
|
||||
p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
|
||||
p1, paddle.log((2 * p1 + eps) / (p1 + p2 + eps) + eps))
|
||||
loss *= 0.5
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -125,7 +125,7 @@ class DMLLoss(nn.Layer):
|
||||
loss = (
|
||||
self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
|
||||
else:
|
||||
# for detection distillation log is not needed
|
||||
# distillation log is not needed for detection
|
||||
loss = self.jskl_loss(out1, out2)
|
||||
return loss
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user