add nrtr dml distill loss (#9968)

* support min_area_rect crop
* add check_install
* fix requirement.txt
* fix check_install
* add lanms-neo for drrg
* fix
* fix doc
* fix
* support set gpu_id when inference
* fix #8855
* fix #8855
* opt slim doc
* fix doc bug
* add v4_rec_distill config
* delete debug
* fix comment
* fix comment
* add dml nrtr distill loss
parent: 1643f268d3, commit: abc4be007e
@@ -96,6 +96,96 @@ class DistillationDMLLoss(DMLLoss):
                continue
        return new_outs

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            if self.maps_name is None:
                if self.multi_head:
                    loss = super().forward(out1[self.dis_head],
                                           out2[self.dis_head])
                else:
                    loss = super().forward(out1, out2)
                if isinstance(loss, dict):
                    for key in loss:
                        loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
                                                       idx)] = loss[key]
                else:
                    loss_dict["{}_{}".format(self.name, idx)] = loss
            else:
                outs1 = self._slice_out(out1)
                outs2 = self._slice_out(out2)
                for _c, k in enumerate(outs1.keys()):
                    loss = super().forward(outs1[k], outs2[k])
                    if isinstance(loss, dict):
                        for key in loss:
                            loss_dict["{}_{}_{}_{}_{}".format(
                                key, pair[0], pair[1], self.maps_name,
                                idx)] = loss[key]
                    else:
                        loss_dict["{}_{}_{}".format(
                            self.name, self.maps_name[_c], idx)] = loss

        loss_dict = _sum_loss(loss_dict)

        return loss_dict
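Every forward pass in this file funnels its per-pair entries through _sum_loss, which is defined earlier in the file and outside this hunk. As a reading aid, here is a minimal sketch of what that helper does, consistent with how it is used here (an assumption, not the verbatim upstream code):

def _sum_loss(loss_dict):
    # Aggregate every per-pair loss entry into a single "loss" key,
    # which is the value the training loop actually optimizes.
    if "loss" in loss_dict.keys():
        return loss_dict
    loss_dict["loss"] = 0.
    for k, value in loss_dict.items():
        if k != "loss":
            loss_dict["loss"] += value
    return loss_dict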


class DistillationKLDivLoss(KLDivLoss):
    """KL-divergence distillation loss computed between the outputs of
    each (student, teacher) pair listed in model_name_pairs."""

    def __init__(self,
                 model_name_pairs=[],
                 key=None,
                 multi_head=False,
                 dis_head='ctc',
                 maps_name=None,
                 name="kl_div"):
        super().__init__()
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.multi_head = multi_head
        self.dis_head = dis_head
        self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
        self.name = name
        self.maps_name = self._check_maps_name(maps_name)

    def _check_model_name_pairs(self, model_name_pairs):
        # Accept either a single [student, teacher] pair or a list of pairs.
        if not isinstance(model_name_pairs, list):
            return []
        elif isinstance(model_name_pairs[0], list) and isinstance(
                model_name_pairs[0][0], str):
            return model_name_pairs
        else:
            return [model_name_pairs]
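_check_model_name_pairs always returns a list of pairs regardless of which spelling the config uses. A small self-contained check of that normalization (the logic is copied from the method above; the model names are made up):

def _normalize_pairs(model_name_pairs):
    # Same logic as _check_model_name_pairs above.
    if not isinstance(model_name_pairs, list):
        return []
    elif isinstance(model_name_pairs[0], list) and isinstance(
            model_name_pairs[0][0], str):
        return model_name_pairs
    else:
        return [model_name_pairs]

assert _normalize_pairs(["Student", "Teacher"]) == [["Student", "Teacher"]]
assert _normalize_pairs([["Student", "Teacher"]]) == [["Student", "Teacher"]]
assert _normalize_pairs("oops") == []  # non-list input yields no pairs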
    def _check_maps_name(self, maps_name):
        # Normalize maps_name to a list (or None) so _slice_out can iterate.
        if maps_name is None:
            return None
        elif isinstance(maps_name, str):
            return [maps_name]
        elif isinstance(maps_name, list):
            return [maps_name]
        else:
            return None

    def _slice_out(self, outs):
        # Split the stacked DB-style detection maps along the channel axis.
        # "thrink_maps" keeps the historical spelling used by the configs.
        new_outs = {}
        for k in self.maps_name:
            if k == "thrink_maps":
                new_outs[k] = outs[:, 0, :, :]
            elif k == "threshold_maps":
                new_outs[k] = outs[:, 1, :, :]
            elif k == "binary_maps":
                new_outs[k] = outs[:, 2, :, :]
            else:
                continue
        return new_outs
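To make the slicing concrete, here is a small self-contained sketch with a dummy detection-head output; the shape and key names follow _slice_out above, the data is made up:

import paddle

# Dummy stacked detection maps: [batch, 3, H, W], channels ordered as
# shrink / threshold / binary, matching the indices used in _slice_out.
outs = paddle.rand([2, 3, 8, 8])
maps_name = ["thrink_maps", "threshold_maps", "binary_maps"]

new_outs = {}
for i, k in enumerate(maps_name):
    new_outs[k] = outs[:, i, :, :]

print({k: tuple(v.shape) for k, v in new_outs.items()})
# {'thrink_maps': (2, 8, 8), 'threshold_maps': (2, 8, 8), 'binary_maps': (2, 8, 8)}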

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
@@ -141,6 +231,149 @@ class DistillationDMLLoss(DMLLoss):
        return loss_dict


class DistillationDKDLoss(DKDLoss):
    """Decoupled knowledge distillation (DKD) loss computed between the
    outputs of each (student, teacher) pair in model_name_pairs."""

    def __init__(self,
                 model_name_pairs=[],
                 key=None,
                 multi_head=False,
                 dis_head='ctc',
                 maps_name=None,
                 name="dkd",
                 temperature=1.0,
                 alpha=1.0,
                 beta=1.0):
        super().__init__(temperature, alpha, beta)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.multi_head = multi_head
        self.dis_head = dis_head
        self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
        self.name = name
        self.maps_name = self._check_maps_name(maps_name)
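For context: DKDLoss implements decoupled knowledge distillation (Zhao et al., CVPR 2022), which splits the classic KD objective into a target-class term (TCKD) and a non-target-class term (NCKD) and recombines them as alpha * TCKD + beta * NCKD, with temperature softening both distributions. The wrapper above only forwards these three hyperparameters to the base class.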
    def _check_model_name_pairs(self, model_name_pairs):
        if not isinstance(model_name_pairs, list):
            return []
        elif isinstance(model_name_pairs[0], list) and isinstance(
                model_name_pairs[0][0], str):
            return model_name_pairs
        else:
            return [model_name_pairs]

    def _check_maps_name(self, maps_name):
        if maps_name is None:
            return None
        elif isinstance(maps_name, str):
            return [maps_name]
        elif isinstance(maps_name, list):
            return [maps_name]
        else:
            return None

    def _slice_out(self, outs):
        new_outs = {}
        for k in self.maps_name:
            if k == "thrink_maps":
                new_outs[k] = outs[:, 0, :, :]
            elif k == "threshold_maps":
                new_outs[k] = outs[:, 1, :, :]
            elif k == "binary_maps":
                new_outs[k] = outs[:, 2, :, :]
            else:
                continue
        return new_outs

    def forward(self, predicts, batch):
        loss_dict = dict()

        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            if self.maps_name is None:
                if self.multi_head:
                    # for nrtr dml loss: build the flattened target sequence
                    # and a non-padding mask from the label batch
                    max_len = batch[3].max()
                    tgt = batch[2][:, 1:2 +
                                   max_len]  # [batch_size, max_len + 1]

                    tgt = tgt.reshape([-1])  # batch_size * (max_len + 1)
                    non_pad_mask = paddle.not_equal(
                        tgt, paddle.zeros(
                            tgt.shape,
                            dtype=tgt.dtype))  # batch_size * (max_len + 1)

                    loss = super().forward(
                        out1[self.dis_head], out2[self.dis_head], tgt,
                        non_pad_mask)  # logits: [batch_size, max_len + 1, num_char]
                else:
                    loss = super().forward(out1, out2)
                if isinstance(loss, dict):
                    for key in loss:
                        loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
                                                       idx)] = loss[key]
                else:
                    loss_dict["{}_{}".format(self.name, idx)] = loss
            else:
                outs1 = self._slice_out(out1)
                outs2 = self._slice_out(out2)
                for _c, k in enumerate(outs1.keys()):
                    loss = super().forward(outs1[k], outs2[k])
                    if isinstance(loss, dict):
                        for key in loss:
                            loss_dict["{}_{}_{}_{}_{}".format(
                                key, pair[0], pair[1], self.maps_name,
                                idx)] = loss[key]
                    else:
                        loss_dict["{}_{}_{}".format(
                            self.name, self.maps_name[_c], idx)] = loss

        loss_dict = _sum_loss(loss_dict)

        return loss_dict
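The target slicing and masking above can be verified in isolation. A minimal sketch with made-up labels, where 0 is the padding index (as implied by the comparison against paddle.zeros):

import paddle

labels = paddle.to_tensor([[2, 7, 9, 0, 0],
                           [2, 5, 0, 0, 0]])  # stand-in for batch[2]
lengths = paddle.to_tensor([2, 1])            # stand-in for batch[3]

max_len = int(lengths.max())
tgt = labels[:, 1:2 + max_len]  # drop the start token, keep max_len + 1 steps
tgt = tgt.reshape([-1])
non_pad_mask = paddle.not_equal(
    tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype))
print(non_pad_mask.numpy())  # [ True  True False  True False False]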


class DistillationNRTRDMLLoss(DistillationDMLLoss):
    """DML distillation loss for the NRTR recognition head: identical to
    DistillationDMLLoss except that padded time steps are masked out
    before the mutual-learning loss is computed."""

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]

            if self.multi_head:
                # for nrtr dml loss: mask out padded positions so that
                # padding does not contribute to the distillation signal
                max_len = batch[3].max()
                tgt = batch[2][:, 1:2 + max_len]
                tgt = tgt.reshape([-1])
                non_pad_mask = paddle.not_equal(
                    tgt, paddle.zeros(
                        tgt.shape, dtype=tgt.dtype))
                loss = super().forward(out1[self.dis_head],
                                       out2[self.dis_head], non_pad_mask)
            else:
                loss = super().forward(out1, out2)
            if isinstance(loss, dict):
                for key in loss:
                    loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
                                                   idx)] = loss[key]
            else:
                loss_dict["{}_{}".format(self.name, idx)] = loss

        loss_dict = _sum_loss(loss_dict)

        return loss_dict
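A loss like this is normally wired up from a distillation config rather than instantiated by hand. As a rough, hypothetical illustration, expressed as a Python dict mirroring how PaddleOCR's CombinedLoss lists its components ("weight" is consumed by the combining wrapper; the remaining fields map onto the constructor above; values such as "head_out", "gtc", and the pair names are assumptions, not taken from this diff):

nrtr_dml_entry = {
    "DistillationNRTRDMLLoss": {
        "weight": 1.0,                                # hypothetical weight
        "model_name_pairs": [["Student", "Teacher"]],  # hypothetical names
        "key": "head_out",                             # assumed output key
        "multi_head": True,
        "dis_head": "gtc",                             # assumed head name
        "name": "dml_gtc",
    }
}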