mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Fix]: Fix dim mismatch in batch test/inference of DBNet (#383)
Modification 1.change the dbnet output dim 2.change the input dim for DBloss
This commit is contained in:
parent
02d657d141
commit
200dfe5fe2
1
.gitignore
vendored
1
.gitignore
vendored
@ -138,3 +138,4 @@ workspace.code-workspace
|
|||||||
results
|
results
|
||||||
mmocr/core/font.TTF
|
mmocr/core/font.TTF
|
||||||
workdirs/
|
workdirs/
|
||||||
|
.history/
|
||||||
|
@ -66,7 +66,7 @@ class DBHead(HeadMixin, BaseModule):
|
|||||||
thr_map = self.threshold(inputs)
|
thr_map = self.threshold(inputs)
|
||||||
binary_map = self.diff_binarize(prob_map, thr_map, k=50)
|
binary_map = self.diff_binarize(prob_map, thr_map, k=50)
|
||||||
outputs = torch.cat((prob_map, thr_map, binary_map), dim=1)
|
outputs = torch.cat((prob_map, thr_map, binary_map), dim=1)
|
||||||
return (outputs, )
|
return outputs
|
||||||
|
|
||||||
def _init_thr(self, inner_channels, bias=False):
|
def _init_thr(self, inner_channels, bias=False):
|
||||||
in_channels = inner_channels
|
in_channels = inner_channels
|
||||||
|
@ -133,8 +133,6 @@ class DBLoss(nn.Module):
|
|||||||
assert isinstance(gt_thr, list)
|
assert isinstance(gt_thr, list)
|
||||||
assert isinstance(gt_thr_mask, list)
|
assert isinstance(gt_thr_mask, list)
|
||||||
|
|
||||||
preds = preds[0]
|
|
||||||
|
|
||||||
pred_prob = preds[:, 0, :, :]
|
pred_prob = preds[:, 0, :, :]
|
||||||
pred_thr = preds[:, 1, :, :]
|
pred_thr = preds[:, 1, :, :]
|
||||||
pred_db = preds[:, 2, :, :]
|
pred_db = preds[:, 2, :, :]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user