add nrtr dml distill loss (#9968)
* support min_area_rect crop * add check_install * fix requirement.txt * fix check_install * add lanms-neo for drrg * fix * fix doc * fix * support set gpu_id when inference * fix #8855 * fix #8855 * opt slim doc * fix doc bug * add v4_rec_distill config * delete debug * fix comment * fix comment * add dml nrtr distill losspull/9874/head^2
parent
1643f268d3
commit
abc4be007e
|
@ -96,6 +96,96 @@ class DistillationDMLLoss(DMLLoss):
|
||||||
continue
|
continue
|
||||||
return new_outs
|
return new_outs
|
||||||
|
|
||||||
|
def forward(self, predicts, batch):
    """Compute DML distillation losses for every configured model pair.

    Args:
        predicts (dict): outputs of all sub-models, keyed by sub-model name.
        batch: ground-truth batch; unused here, kept for the uniform loss API.

    Returns:
        dict: one entry per pair (and per map, when ``maps_name`` is set),
        plus their sum added by ``_sum_loss`` under the key ``"loss"``.
    """
    loss_dict = dict()
    for idx, pair in enumerate(self.model_name_pairs):
        out1 = predicts[pair[0]]
        out2 = predicts[pair[1]]
        if self.key is not None:
            out1 = out1[self.key]
            out2 = out2[self.key]

        if self.maps_name is None:
            # Plain DML between the two outputs (or one named head of each).
            if self.multi_head:
                loss = super().forward(out1[self.dis_head],
                                       out2[self.dis_head])
            else:
                loss = super().forward(out1, out2)
            if isinstance(loss, dict):
                for key in loss:
                    loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
                                                   idx)] = loss[key]
            else:
                loss_dict["{}_{}".format(self.name, idx)] = loss
        else:
            # Detection-style distillation: slice the stacked prediction
            # into named maps and distill each map separately.
            outs1 = self._slice_out(out1)
            outs2 = self._slice_out(out2)
            for _c, k in enumerate(outs1.keys()):
                loss = super().forward(outs1[k], outs2[k])
                if isinstance(loss, dict):
                    for key in loss:
                        # Fix: key on the per-map name (maps_name[_c]),
                        # matching the scalar branch below. The original
                        # formatted the whole maps_name list, so every map
                        # produced identical keys and earlier map losses
                        # were silently overwritten.
                        loss_dict["{}_{}_{}_{}_{}".format(
                            key, pair[0], pair[1], self.maps_name[_c],
                            idx)] = loss[key]
                else:
                    loss_dict["{}_{}_{}".format(self.name, self.maps_name[
                        _c], idx)] = loss

    loss_dict = _sum_loss(loss_dict)

    return loss_dict
|
||||||
|
|
||||||
|
|
||||||
|
class DistillationKLDivLoss(KLDivLoss):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
             model_name_pairs=[],
             key=None,
             multi_head=False,
             dis_head='ctc',
             maps_name=None,
             name="kl_div"):
    """KL-divergence distillation loss applied to pairs of sub-model outputs.

    Args:
        model_name_pairs: a single [student, teacher] pair or a list of
            such pairs.
        key: optional key used to index into each sub-model's output dict.
        multi_head: when True, distill only the ``dis_head`` branch.
        dis_head: head name used when ``multi_head`` is True.
        maps_name: optional map name(s) for detection-style distillation.
        name: prefix for entries in the returned loss dict.
    """
    super().__init__()
    assert isinstance(model_name_pairs, list)
    # Normalize the pair/map configuration before storing it.
    self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
    self.maps_name = self._check_maps_name(maps_name)
    self.key = key
    self.multi_head = multi_head
    self.dis_head = dis_head
    self.name = name
|
||||||
|
|
||||||
|
def _check_model_name_pairs(self, model_name_pairs):
|
||||||
|
if not isinstance(model_name_pairs, list):
|
||||||
|
return []
|
||||||
|
elif isinstance(model_name_pairs[0], list) and isinstance(
|
||||||
|
model_name_pairs[0][0], str):
|
||||||
|
return model_name_pairs
|
||||||
|
else:
|
||||||
|
return [model_name_pairs]
|
||||||
|
|
||||||
|
def _check_maps_name(self, maps_name):
|
||||||
|
if maps_name is None:
|
||||||
|
return None
|
||||||
|
elif type(maps_name) == str:
|
||||||
|
return [maps_name]
|
||||||
|
elif type(maps_name) == list:
|
||||||
|
return [maps_name]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _slice_out(self, outs):
|
||||||
|
new_outs = {}
|
||||||
|
for k in self.maps_name:
|
||||||
|
if k == "thrink_maps":
|
||||||
|
new_outs[k] = outs[:, 0, :, :]
|
||||||
|
elif k == "threshold_maps":
|
||||||
|
new_outs[k] = outs[:, 1, :, :]
|
||||||
|
elif k == "binary_maps":
|
||||||
|
new_outs[k] = outs[:, 2, :, :]
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
return new_outs
|
||||||
|
|
||||||
def forward(self, predicts, batch):
|
def forward(self, predicts, batch):
|
||||||
loss_dict = dict()
|
loss_dict = dict()
|
||||||
for idx, pair in enumerate(self.model_name_pairs):
|
for idx, pair in enumerate(self.model_name_pairs):
|
||||||
|
@ -141,6 +231,149 @@ class DistillationDMLLoss(DMLLoss):
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
||||||
|
|
||||||
|
class DistillationDKDLoss(DKDLoss):
    """Decoupled knowledge distillation (DKD) loss between sub-model pairs.

    Wraps ``DKDLoss`` so it can be applied to the outputs of one or more
    student/teacher pairs produced by a distillation model. For multi-head
    recognition models the loss is computed on ``dis_head`` with padded
    label positions masked out; with ``maps_name`` set, each named map is
    distilled separately.
    """

    def __init__(self,
                 model_name_pairs=[],
                 key=None,
                 multi_head=False,
                 dis_head='ctc',
                 maps_name=None,
                 name="dkd",
                 temperature=1.0,
                 alpha=1.0,
                 beta=1.0):
        """See class docstring; temperature/alpha/beta go to DKDLoss."""
        super().__init__(temperature, alpha, beta)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.multi_head = multi_head
        self.dis_head = dis_head
        self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
        self.name = name
        self.maps_name = self._check_maps_name(maps_name)

    def _check_model_name_pairs(self, model_name_pairs):
        """Normalize a single [student, teacher] pair into a list of pairs."""
        # Fix: an empty list previously fell through the isinstance guard
        # and raised IndexError on model_name_pairs[0].
        if not isinstance(model_name_pairs, list) or not model_name_pairs:
            return []
        elif isinstance(model_name_pairs[0], list) and isinstance(
                model_name_pairs[0][0], str):
            return model_name_pairs
        else:
            return [model_name_pairs]

    def _check_maps_name(self, maps_name):
        """Normalize ``maps_name`` to a list of strings, or None."""
        if maps_name is None:
            return None
        elif type(maps_name) == str:
            return [maps_name]
        elif type(maps_name) == list:
            # Fix: return the list itself; wrapping it ([maps_name]) made
            # _slice_out iterate over one inner list and slice nothing.
            return maps_name
        else:
            return None

    def _slice_out(self, outs):
        """Split a stacked prediction into named single-channel maps."""
        new_outs = {}
        for k in self.maps_name:
            if k == "thrink_maps":
                new_outs[k] = outs[:, 0, :, :]
            elif k == "threshold_maps":
                new_outs[k] = outs[:, 1, :, :]
            elif k == "binary_maps":
                new_outs[k] = outs[:, 2, :, :]
            else:
                continue
        return new_outs

    def forward(self, predicts, batch):
        """Compute DKD losses for every configured model pair.

        Args:
            predicts (dict): outputs of all sub-models, keyed by name.
            batch: ground-truth batch; batch[2] holds label ids and
                batch[3] the label lengths (used to build the pad mask).

        Returns:
            dict: per-pair losses plus their sum under the key "loss".
        """
        loss_dict = dict()

        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            if self.maps_name is None:
                if self.multi_head:
                    # For NRTR-style heads: build a mask that drops padded
                    # label positions (label id 0 is padding).
                    max_len = batch[3].max()
                    tgt = batch[2][:, 1:2 +
                                   max_len]  # [batch_size, max_len + 1]
                    tgt = tgt.reshape([-1])  # batch_size * (max_len + 1)
                    non_pad_mask = paddle.not_equal(
                        tgt, paddle.zeros(
                            tgt.shape,
                            dtype=tgt.dtype))  # batch_size * (max_len + 1)
                    loss = super().forward(
                        out1[self.dis_head], out2[self.dis_head], tgt,
                        non_pad_mask)  # [batch_size, max_len + 1, num_char]
                else:
                    loss = super().forward(out1, out2)
                if isinstance(loss, dict):
                    for key in loss:
                        loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
                                                       idx)] = loss[key]
                else:
                    loss_dict["{}_{}".format(self.name, idx)] = loss
            else:
                outs1 = self._slice_out(out1)
                outs2 = self._slice_out(out2)
                for _c, k in enumerate(outs1.keys()):
                    loss = super().forward(outs1[k], outs2[k])
                    if isinstance(loss, dict):
                        for key in loss:
                            # Fix: key on the per-map name so keys from
                            # different maps do not collide; the original
                            # formatted the whole maps_name list here.
                            loss_dict["{}_{}_{}_{}_{}".format(
                                key, pair[0], pair[1], self.maps_name[_c],
                                idx)] = loss[key]
                    else:
                        loss_dict["{}_{}_{}".format(self.name, self.maps_name[
                            _c], idx)] = loss

        loss_dict = _sum_loss(loss_dict)

        return loss_dict
|
||||||
|
|
||||||
|
|
||||||
|
class DistillationNRTRDMLLoss(DistillationDMLLoss):
    """DML distillation loss for NRTR/GTC-style recognition heads.

    Behaves like ``DistillationDMLLoss`` except that, for multi-head
    models, padded label positions (label id 0, per batch[2]/batch[3])
    are masked out before the DML loss is computed on ``dis_head``.
    """

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            student_out = predicts[pair[0]]
            teacher_out = predicts[pair[1]]
            if self.key is not None:
                student_out = student_out[self.key]
                teacher_out = teacher_out[self.key]

            if self.multi_head:
                # Build a mask selecting the non-padded label positions.
                max_len = batch[3].max()
                tgt = batch[2][:, 1:2 + max_len]
                tgt = tgt.reshape([-1])
                non_pad_mask = paddle.not_equal(
                    tgt, paddle.zeros(
                        tgt.shape, dtype=tgt.dtype))
                loss = super().forward(student_out[self.dis_head],
                                       teacher_out[self.dis_head],
                                       non_pad_mask)
            else:
                loss = super().forward(student_out, teacher_out)

            if isinstance(loss, dict):
                for key in loss:
                    loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
                                                   idx)] = loss[key]
            else:
                loss_dict["{}_{}".format(self.name, idx)] = loss

        loss_dict = _sum_loss(loss_dict)

        return loss_dict
|
||||||
|
|
||||||
|
|
||||||
class DistillationKLDivLoss(KLDivLoss):
|
class DistillationKLDivLoss(KLDivLoss):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue