mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Fix]: Fix dim mismatch in batch test/inference of DBNet (#383)
Modification 1.change the dbnet output dim 2.change the input dim for DBloss
This commit is contained in:
parent
02d657d141
commit
200dfe5fe2
1
.gitignore
vendored
1
.gitignore
vendored
@ -138,3 +138,4 @@ workspace.code-workspace
|
||||
results
|
||||
mmocr/core/font.TTF
|
||||
workdirs/
|
||||
.history/
|
||||
|
@ -66,7 +66,7 @@ class DBHead(HeadMixin, BaseModule):
|
||||
thr_map = self.threshold(inputs)
|
||||
binary_map = self.diff_binarize(prob_map, thr_map, k=50)
|
||||
outputs = torch.cat((prob_map, thr_map, binary_map), dim=1)
|
||||
return (outputs, )
|
||||
return outputs
|
||||
|
||||
def _init_thr(self, inner_channels, bias=False):
|
||||
in_channels = inner_channels
|
||||
|
@ -133,8 +133,6 @@ class DBLoss(nn.Module):
|
||||
assert isinstance(gt_thr, list)
|
||||
assert isinstance(gt_thr_mask, list)
|
||||
|
||||
preds = preds[0]
|
||||
|
||||
pred_prob = preds[:, 0, :, :]
|
||||
pred_thr = preds[:, 1, :, :]
|
||||
pred_db = preds[:, 2, :, :]
|
||||
|
Loading…
x
Reference in New Issue
Block a user