mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Detector] refactor basedetector
This commit is contained in:
parent
d2e8e79df1
commit
b955df9904
@ -21,6 +21,6 @@ model = dict(
|
|||||||
det_head=dict(
|
det_head=dict(
|
||||||
type='DBHead',
|
type='DBHead',
|
||||||
in_channels=256,
|
in_channels=256,
|
||||||
loss=dict(type='DBLoss'),
|
loss_module=dict(type='DBLoss'),
|
||||||
postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')),
|
postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')),
|
||||||
preprocess_cfg=preprocess_cfg)
|
preprocess_cfg=preprocess_cfg)
|
||||||
|
@ -23,6 +23,6 @@ model = dict(
|
|||||||
det_head=dict(
|
det_head=dict(
|
||||||
type='DBHead',
|
type='DBHead',
|
||||||
in_channels=256,
|
in_channels=256,
|
||||||
loss=dict(type='DBLoss'),
|
loss_module=dict(type='DBLoss'),
|
||||||
postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')),
|
postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')),
|
||||||
preprocess_cfg=preprocess_cfg)
|
preprocess_cfg=preprocess_cfg)
|
||||||
|
@ -20,7 +20,7 @@ model = dict(
|
|||||||
bbox_head=dict(
|
bbox_head=dict(
|
||||||
type='DBHead',
|
type='DBHead',
|
||||||
in_channels=256,
|
in_channels=256,
|
||||||
loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True),
|
loss_module=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True),
|
||||||
postprocessor=dict(
|
postprocessor=dict(
|
||||||
type='DBPostprocessor', text_repr_type='quad',
|
type='DBPostprocessor', text_repr_type='quad',
|
||||||
epsilon_ratio=0.002)),
|
epsilon_ratio=0.002)),
|
||||||
|
@ -17,5 +17,5 @@ model = dict(
|
|||||||
in_channels=32,
|
in_channels=32,
|
||||||
text_region_thr=0.3,
|
text_region_thr=0.3,
|
||||||
center_region_thr=0.4,
|
center_region_thr=0.4,
|
||||||
loss=dict(type='DRRGLoss'),
|
loss_module=dict(type='DRRGLoss'),
|
||||||
postprocessor=dict(type='DRRGPostprocessor', link_thr=0.80)))
|
postprocessor=dict(type='DRRGPostprocessor', link_thr=0.80)))
|
||||||
|
@ -23,7 +23,7 @@ model = dict(
|
|||||||
in_channels=256,
|
in_channels=256,
|
||||||
scales=(8, 16, 32),
|
scales=(8, 16, 32),
|
||||||
fourier_degree=5,
|
fourier_degree=5,
|
||||||
loss=dict(type='FCELoss', num_sample=50),
|
loss_module=dict(type='FCELoss', num_sample=50),
|
||||||
postprocessor=dict(
|
postprocessor=dict(
|
||||||
type='FCEPostprocessor',
|
type='FCEPostprocessor',
|
||||||
text_repr_type='quad',
|
text_repr_type='quad',
|
||||||
|
@ -25,7 +25,7 @@ model = dict(
|
|||||||
in_channels=256,
|
in_channels=256,
|
||||||
scales=(8, 16, 32),
|
scales=(8, 16, 32),
|
||||||
fourier_degree=5,
|
fourier_degree=5,
|
||||||
loss=dict(type='FCELoss', num_sample=50),
|
loss_module=dict(type='FCELoss', num_sample=50),
|
||||||
postprocessor=dict(
|
postprocessor=dict(
|
||||||
type='FCEPostprocessor',
|
type='FCEPostprocessor',
|
||||||
text_repr_type='poly',
|
text_repr_type='poly',
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
preprocess_cfg = dict(
|
|
||||||
mean=[123.675, 116.28, 103.53],
|
|
||||||
std=[58.395, 57.12, 57.375],
|
|
||||||
to_rgb=True,
|
|
||||||
pad_size_divisor=32)
|
|
||||||
|
|
||||||
# BasicBlock has a little difference from official PANet
|
# BasicBlock has a little difference from official PANet
|
||||||
# BasicBlock in mmdet lacks RELU in the last convolution.
|
# BasicBlock in mmdet lacks RELU in the last convolution.
|
||||||
model = dict(
|
model = dict(
|
||||||
type='PANet',
|
type='PANet',
|
||||||
|
data_preprocessor=dict(
|
||||||
|
type='TextDetDataPreprocessor',
|
||||||
|
mean=[123.675, 116.28, 103.53],
|
||||||
|
std=[58.395, 57.12, 57.375],
|
||||||
|
bgr_to_rgb=True,
|
||||||
|
pad_size_divisor=32),
|
||||||
backbone=dict(
|
backbone=dict(
|
||||||
type='mmdet.ResNet',
|
type='mmdet.ResNet',
|
||||||
depth=18,
|
depth=18,
|
||||||
@ -25,10 +25,9 @@ model = dict(
|
|||||||
in_channels=[128, 128, 128, 128],
|
in_channels=[128, 128, 128, 128],
|
||||||
hidden_dim=128,
|
hidden_dim=128,
|
||||||
out_channel=6,
|
out_channel=6,
|
||||||
loss=dict(
|
loss_module=dict(
|
||||||
type='PANLoss',
|
type='PANLoss',
|
||||||
loss_text=dict(type='MaskedSquareDiceLoss'),
|
loss_text=dict(type='MaskedSquareDiceLoss'),
|
||||||
loss_kernel=dict(type='MaskedSquareDiceLoss'),
|
loss_kernel=dict(type='MaskedSquareDiceLoss'),
|
||||||
),
|
),
|
||||||
postprocessor=dict(type='PANPostprocessor', text_repr_type='quad')),
|
postprocessor=dict(type='PANPostprocessor', text_repr_type='quad')))
|
||||||
preprocess_cfg=preprocess_cfg)
|
|
||||||
|
@ -15,7 +15,7 @@ model = dict(
|
|||||||
type='PANHead',
|
type='PANHead',
|
||||||
in_channels=[128, 128, 128, 128],
|
in_channels=[128, 128, 128, 128],
|
||||||
out_channels=6,
|
out_channels=6,
|
||||||
loss=dict(type='PANLoss', speedup_bbox_thr=32),
|
loss_module=dict(type='PANLoss', speedup_bbox_thr=32),
|
||||||
postprocessor=dict(type='PANPostprocessor', text_repr_type='poly')),
|
postprocessor=dict(type='PANPostprocessor', text_repr_type='poly')),
|
||||||
train_cfg=None,
|
train_cfg=None,
|
||||||
test_cfg=None)
|
test_cfg=None)
|
||||||
|
@ -26,7 +26,7 @@ model_poly = dict(
|
|||||||
in_channels=[256],
|
in_channels=[256],
|
||||||
hidden_dim=256,
|
hidden_dim=256,
|
||||||
out_channel=7,
|
out_channel=7,
|
||||||
loss=dict(type='PSELoss'),
|
loss_module=dict(type='PSELoss'),
|
||||||
postprocessor=dict(type='PSEPostprocessor', text_repr_type='poly')),
|
postprocessor=dict(type='PSEPostprocessor', text_repr_type='poly')),
|
||||||
preprocess_cfg=preprocess_cfg)
|
preprocess_cfg=preprocess_cfg)
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ model = dict(
|
|||||||
det_head=dict(
|
det_head=dict(
|
||||||
type='TextSnakeHead',
|
type='TextSnakeHead',
|
||||||
in_channels=32,
|
in_channels=32,
|
||||||
loss=dict(type='TextSnakeLoss'),
|
loss_module=dict(type='TextSnakeLoss'),
|
||||||
postprocessor=dict(
|
postprocessor=dict(
|
||||||
type='TextSnakePostprocessor', text_repr_type='poly')),
|
type='TextSnakePostprocessor', text_repr_type='poly')),
|
||||||
preprocess_cfg=preprocess_cfg)
|
preprocess_cfg=preprocess_cfg)
|
||||||
|
@ -12,7 +12,7 @@ model = dict(
|
|||||||
type='CRNNDecoder',
|
type='CRNNDecoder',
|
||||||
in_channels=512,
|
in_channels=512,
|
||||||
rnn_flag=True,
|
rnn_flag=True,
|
||||||
loss=dict(type='CTCLoss', letter_case='lower'),
|
loss_module=dict(type='CTCLoss', letter_case='lower'),
|
||||||
postprocessor=dict(type='CTCPostProcessor')),
|
postprocessor=dict(type='CTCPostProcessor')),
|
||||||
dictionary=dictionary,
|
dictionary=dictionary,
|
||||||
data_preprocessor=dict(
|
data_preprocessor=dict(
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from . import detectors, heads, losses, necks, postprocessors
|
from . import (data_preprocessors, detectors, heads, losses, necks,
|
||||||
|
postprocessors)
|
||||||
|
from .data_preprocessors import * # NOQA
|
||||||
from .detectors import * # NOQA
|
from .detectors import * # NOQA
|
||||||
from .heads import * # NOQA
|
from .heads import * # NOQA
|
||||||
from .losses import * # NOQA
|
from .losses import * # NOQA
|
||||||
@ -8,4 +10,4 @@ from .postprocessors import * # NOQA
|
|||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
heads.__all__ + detectors.__all__ + losses.__all__ + necks.__all__ +
|
heads.__all__ + detectors.__all__ + losses.__all__ + necks.__all__ +
|
||||||
postprocessors.__all__)
|
postprocessors.__all__ + data_preprocessors.__all__)
|
||||||
|
4
mmocr/models/textdet/data_preprocessors/__init__.py
Normal file
4
mmocr/models/textdet/data_preprocessors/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from .data_preprocessor import TextDetDataPreprocessor
|
||||||
|
|
||||||
|
__all__ = ['TextDetDataPreprocessor']
|
104
mmocr/models/textdet/data_preprocessors/data_preprocessor.py
Normal file
104
mmocr/models/textdet/data_preprocessors/data_preprocessor.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from numbers import Number
|
||||||
|
from typing import List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from mmengine.model import ImgDataPreprocessor
|
||||||
|
|
||||||
|
from mmocr.registry import MODELS
|
||||||
|
|
||||||
|
|
||||||
|
@MODELS.register_module()
|
||||||
|
class TextDetDataPreprocessor(ImgDataPreprocessor):
|
||||||
|
"""Image pre-processor for detection tasks.
|
||||||
|
|
||||||
|
Comparing with the :class:`mmengine.ImgDataPreprocessor`,
|
||||||
|
|
||||||
|
1. It supports batch augmentations.
|
||||||
|
2. It will additionally append batch_input_shape and pad_shape
|
||||||
|
to data_samples considering the object detection task.
|
||||||
|
|
||||||
|
It provides the data pre-processing as follows
|
||||||
|
|
||||||
|
- Collate and move data to the target device.
|
||||||
|
- Pad inputs to the maximum size of current batch with defined
|
||||||
|
``pad_value``. The padding size can be divisible by a defined
|
||||||
|
``pad_size_divisor``
|
||||||
|
- Stack inputs to batch_inputs.
|
||||||
|
- Convert inputs from bgr to rgb if the shape of input is (3, H, W).
|
||||||
|
- Normalize image with defined std and mean.
|
||||||
|
- Do batch augmentations during training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
|
||||||
|
Defaults to None.
|
||||||
|
std (Sequence[Number], optional): The pixel standard deviation of
|
||||||
|
R, G, B channels. Defaults to None.
|
||||||
|
pad_size_divisor (int): The size of padded image should be
|
||||||
|
divisible by ``pad_size_divisor``. Defaults to 1.
|
||||||
|
pad_value (Number): The padded pixel value. Defaults to 0.
|
||||||
|
pad_mask (bool): Whether to pad instance masks. Defaults to False.
|
||||||
|
mask_pad_value (int): The padded pixel value for instance masks.
|
||||||
|
Defaults to 0.
|
||||||
|
pad_seg (bool): Whether to pad semantic segmentation maps.
|
||||||
|
Defaults to False.
|
||||||
|
seg_pad_value (int): The padded pixel value for semantic
|
||||||
|
segmentation maps. Defaults to 255.
|
||||||
|
bgr_to_rgb (bool): whether to convert image from BGR to RGB.
|
||||||
|
Defaults to False.
|
||||||
|
rgb_to_bgr (bool): whether to convert image from RGB to BGR.
|
||||||
|
Defaults to False.
|
||||||
|
batch_augments (list[dict], optional): Batch-level augmentations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
mean: Sequence[Number] = None,
|
||||||
|
std: Sequence[Number] = None,
|
||||||
|
pad_size_divisor: int = 1,
|
||||||
|
pad_value: Union[float, int] = 0,
|
||||||
|
bgr_to_rgb: bool = False,
|
||||||
|
rgb_to_bgr: bool = False,
|
||||||
|
batch_augments: Optional[List[dict]] = None):
|
||||||
|
super().__init__(
|
||||||
|
mean=mean,
|
||||||
|
std=std,
|
||||||
|
pad_size_divisor=pad_size_divisor,
|
||||||
|
pad_value=pad_value,
|
||||||
|
bgr_to_rgb=bgr_to_rgb,
|
||||||
|
rgb_to_bgr=rgb_to_bgr)
|
||||||
|
if batch_augments is not None:
|
||||||
|
self.batch_augments = nn.ModuleList(
|
||||||
|
[MODELS.build(aug) for aug in batch_augments])
|
||||||
|
else:
|
||||||
|
self.batch_augments = None
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
data: Sequence[dict],
|
||||||
|
training: bool = False) -> Tuple[torch.Tensor, Optional[list]]:
|
||||||
|
"""Perform normalization, padding and bgr2rgb conversion based on
|
||||||
|
``BaseDataPreprocessor``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (Sequence[dict]): data sampled from dataloader.
|
||||||
|
training (bool): Whether to enable training time augmentation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
|
||||||
|
model input.
|
||||||
|
"""
|
||||||
|
batch_inputs, batch_data_samples = super().forward(
|
||||||
|
data=data, training=training)
|
||||||
|
|
||||||
|
if batch_data_samples is not None:
|
||||||
|
batch_input_shape = tuple(batch_inputs[0].size()[-2:])
|
||||||
|
for data_samples in batch_data_samples:
|
||||||
|
data_samples.set_metainfo(
|
||||||
|
{'batch_input_shape': batch_input_shape})
|
||||||
|
|
||||||
|
if training and self.batch_augments is not None:
|
||||||
|
for batch_aug in self.batch_augments:
|
||||||
|
batch_inputs, batch_data_samples = batch_aug(
|
||||||
|
batch_inputs, batch_data_samples)
|
||||||
|
|
||||||
|
return batch_inputs, batch_data_samples
|
@ -2,7 +2,6 @@
|
|||||||
from typing import Dict, Optional, Sequence
|
from typing import Dict, Optional, Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from mmcv.runner import auto_fp16
|
|
||||||
from mmdet.models.detectors.base import BaseDetector as MMDET_BaseDetector
|
from mmdet.models.detectors.base import BaseDetector as MMDET_BaseDetector
|
||||||
|
|
||||||
from mmocr.core.data_structures import TextDetDataSample
|
from mmocr.core.data_structures import TextDetDataSample
|
||||||
@ -21,7 +20,7 @@ class SingleStageTextDetector(MMDET_BaseDetector):
|
|||||||
neck (dict, optional): Neck config. If None, the output from backbone
|
neck (dict, optional): Neck config. If None, the output from backbone
|
||||||
will be directly fed into ``det_head``.
|
will be directly fed into ``det_head``.
|
||||||
det_head (dict): Head config.
|
det_head (dict): Head config.
|
||||||
preprocess_cfg (dict, optional): Model preprocessing config
|
data_preprocessor (dict, optional): Model preprocessing config
|
||||||
for processing the input image data. Keys allowed are
|
for processing the input image data. Keys allowed are
|
||||||
``to_rgb``(bool), ``pad_size_divisor``(int), ``pad_value``(int or
|
``bgr_to_rgb``(bool), ``pad_size_divisor``(int), ``pad_value``(int or
|
||||||
float), ``mean``(int or float) and ``std``(int or float).
|
float), ``mean``(int or float) and ``std``(int or float).
|
||||||
@ -35,83 +34,96 @@ class SingleStageTextDetector(MMDET_BaseDetector):
|
|||||||
backbone: Dict,
|
backbone: Dict,
|
||||||
det_head: Dict,
|
det_head: Dict,
|
||||||
neck: Optional[Dict] = None,
|
neck: Optional[Dict] = None,
|
||||||
preprocess_cfg: Optional[Dict] = None,
|
data_preprocessor: Optional[Dict] = None,
|
||||||
init_cfg: Optional[Dict] = None) -> None:
|
init_cfg: Optional[Dict] = None) -> None:
|
||||||
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
|
super().__init__(
|
||||||
|
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
|
||||||
assert det_head is not None, 'det_head cannot be None!'
|
assert det_head is not None, 'det_head cannot be None!'
|
||||||
self.backbone = MODELS.build(backbone)
|
self.backbone = MODELS.build(backbone)
|
||||||
if neck is not None:
|
if neck is not None:
|
||||||
self.neck = MODELS.build(neck)
|
self.neck = MODELS.build(neck)
|
||||||
self.det_head = MODELS.build(det_head)
|
self.det_head = MODELS.build(det_head)
|
||||||
|
|
||||||
def extract_feat(self, img: torch.Tensor) -> torch.Tensor:
|
def extract_feat(self, batch_inputs: torch.Tensor) -> torch.Tensor:
|
||||||
"""Directly extract features from the backbone+neck."""
|
"""Extract features.
|
||||||
x = self.backbone(img)
|
|
||||||
if self.with_neck:
|
|
||||||
x = self.neck(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward_train(self, img: torch.Tensor,
|
|
||||||
data_samples: Sequence[TextDetDataSample]) -> Dict:
|
|
||||||
"""
|
|
||||||
Args:
|
Args:
|
||||||
img (torch.Tensor): Input images of shape (N, C, H, W).
|
batch_inputs (Tensor): Image tensor with shape (N, C, H, W).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor or tuple[Tensor]: Multi-level features that may have
|
||||||
|
different resolutions.
|
||||||
|
"""
|
||||||
|
batch_inputs = self.backbone(batch_inputs)
|
||||||
|
if self.with_neck:
|
||||||
|
batch_inputs = self.neck(batch_inputs)
|
||||||
|
return batch_inputs
|
||||||
|
|
||||||
|
def loss(self, batch_inputs: torch.Tensor,
|
||||||
|
batch_data_samples: Sequence[TextDetDataSample]) -> Dict:
|
||||||
|
"""Calculate losses from a batch of inputs and data samples.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_inputs (torch.Tensor): Input images of shape (N, C, H, W).
|
||||||
Typically these should be mean centered and std scaled.
|
Typically these should be mean centered and std scaled.
|
||||||
data_samples (list[TextDetDataSample]): A list of N datasamples,
|
batch_data_samples (list[TextDetDataSample]): A list of N
|
||||||
containing meta information and gold annotations for each of
|
datasamples, containing meta information and gold annotations
|
||||||
the images.
|
for each of the images.
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, Tensor]: A dictionary of loss components.
|
dict[str, Tensor]: A dictionary of loss components.
|
||||||
"""
|
"""
|
||||||
x = self.extract_feat(img)
|
batch_inputs = self.extract_feat(batch_inputs)
|
||||||
preds = self.det_head(x, data_samples)
|
return self.det_head.loss(batch_inputs, batch_data_samples)
|
||||||
losses = self.det_head.loss(preds, data_samples)
|
|
||||||
return losses
|
|
||||||
|
|
||||||
def simple_test(self, img: torch.Tensor,
|
def predict(
|
||||||
data_samples: Sequence[TextDetDataSample]
|
self, batch_inputs: torch.Tensor,
|
||||||
) -> Sequence[TextDetDataSample]:
|
batch_data_samples: Sequence[TextDetDataSample]
|
||||||
"""Test function without test-time augmentation.
|
) -> Sequence[TextDetDataSample]:
|
||||||
|
"""Predict results from a batch of inputs and data samples with post-
|
||||||
|
processing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
img (torch.Tensor): Images of shape (N, C, H, W).
|
batch_inputs (torch.Tensor): Images of shape (N, C, H, W).
|
||||||
data_samples (list[TextDetDataSample]): A list of N datasamples,
|
batch_data_samples (list[TextDetDataSample]): A list of N
|
||||||
containing meta information and gold annotations for each of
|
datasamples, containing meta information and gold annotations
|
||||||
the images.
|
for each of the images.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[TextDetDataSample]: A list of N datasamples of prediction
|
list[TextDetDataSample]: A list of N datasamples of prediction
|
||||||
results. Results are stored in ``pred_instances``.
|
results. Each TextDetDataSample usually contains
|
||||||
|
'pred_instances'. And the ``pred_instances`` usually
|
||||||
|
contains following keys.
|
||||||
|
|
||||||
|
- scores (Tensor): Classification scores, has a shape
|
||||||
|
(num_instance, )
|
||||||
|
- labels (Tensor): Labels of bboxes, has a shape
|
||||||
|
(num_instances, ).
|
||||||
|
- bboxes (Tensor): Has a shape (num_instances, 4),
|
||||||
|
the last dimension 4 arrange as (x1, y1, x2, y2).
|
||||||
|
- polygons (list[np.ndarray]): The length is num_instances.
|
||||||
|
Each element represents the polygon of the
|
||||||
|
instance, in (xn, yn) order.
|
||||||
"""
|
"""
|
||||||
x = self.extract_feat(img)
|
x = self.extract_feat(batch_inputs)
|
||||||
preds = self.det_head(x, data_samples)
|
return self.det_head.predict(x, batch_data_samples)
|
||||||
return self.det_head.postprocessor(preds, data_samples)
|
|
||||||
|
|
||||||
def aug_test(
|
def _forward(self,
|
||||||
self, imgs: Sequence[torch.Tensor],
|
batch_inputs: torch.Tensor,
|
||||||
data_samples: Sequence[Sequence[TextDetDataSample]]
|
batch_data_samples: Optional[
|
||||||
) -> Sequence[Sequence[TextDetDataSample]]:
|
Sequence[TextDetDataSample]] = None,
|
||||||
"""Test function with test time augmentation."""
|
**kwargs) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
"""Network forward process. Usually includes backbone, neck and head
|
||||||
|
forward without any post-processing.
|
||||||
|
|
||||||
@auto_fp16(apply_to=('imgs', ))
|
Args:
|
||||||
def forward_simple_test(self, imgs: torch.Tensor,
|
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
|
||||||
data_samples: Sequence[TextDetDataSample]
|
batch_data_samples (list[TextDetDataSample]): A list of N
|
||||||
) -> Sequence[TextDetDataSample]:
|
datasamples, containing meta information and gold annotations
|
||||||
"""Test forward function called by self.forward() when running in test
|
for each of the images.
|
||||||
mode without test time augmentation.
|
|
||||||
|
|
||||||
Though not useful in MMOCR, it has been kept to maintain the maximum
|
|
||||||
compatibility with MMDetection's BaseDetector.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
img (torch.Tensor): Images of shape (N, C, H, W).
|
|
||||||
data_samples (list[TextDetDataSample]): A list of N datasamples,
|
|
||||||
containing meta information and gold annotations for each of
|
|
||||||
the images.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[TextDetDataSample]: A list of N datasamples of prediction
|
Tensor or tuple[Tensor]: A tuple of features from ``det_head``
|
||||||
results. Results are stored in ``pred_instances``.
|
forward.
|
||||||
"""
|
"""
|
||||||
return self.simple_test(imgs, data_samples)
|
x = self.extract_feat(batch_inputs)
|
||||||
|
return self.det_head(x, batch_data_samples)
|
||||||
|
@ -1,15 +1,52 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
from mmcv.runner import BaseModule
|
from mmcv.runner import BaseModule
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from mmocr.core.data_structures import TextDetDataSample
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
|
|
||||||
|
SampleList = List[TextDetDataSample]
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
class BaseTextDetHead(BaseModule):
|
class BaseTextDetHead(BaseModule):
|
||||||
"""Base head for text detection, build the loss and postprocessor.
|
"""Base head for text detection, build the loss and postprocessor.
|
||||||
|
|
||||||
|
1. The ``init_weights`` method is used to initialize head's
|
||||||
|
model parameters. After detector initialization, ``init_weights``
|
||||||
|
is triggered when ``detector.init_weights()`` is called externally.
|
||||||
|
|
||||||
|
2. The ``loss`` method is used to calculate the loss of head,
|
||||||
|
which includes two steps: (1) the head model performs forward
|
||||||
|
propagation to obtain the feature maps (2) The ``loss_module`` method
|
||||||
|
is called based on the feature maps to calculate the loss.
|
||||||
|
|
||||||
|
.. code:: text
|
||||||
|
|
||||||
|
loss(): forward() -> loss_module()
|
||||||
|
|
||||||
|
3. The ``predict`` method is used to predict detection results,
|
||||||
|
which includes two steps: (1) the head model performs forward
|
||||||
|
propagation to obtain the feature maps (2) The ``postprocessor`` method
|
||||||
|
is called based on the feature maps to predict detection results including
|
||||||
|
post-processing.
|
||||||
|
|
||||||
|
.. code:: text
|
||||||
|
|
||||||
|
predict(): forward() -> postprocessor()
|
||||||
|
|
||||||
|
4. The ``loss_and_predict`` method is used to return loss and detection
|
||||||
|
results at the same time. It will call head's ``forward``,
|
||||||
|
``loss_module`` and ``postprocessor`` methods in order.
|
||||||
|
|
||||||
|
.. code:: text
|
||||||
|
|
||||||
|
loss_and_predict(): forward() -> loss_module() -> postprocessor()
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
loss (dict): Config to build loss.
|
loss_module (dict): Config to build loss.
|
||||||
postprocessor (dict): Config to build postprocessor.
|
postprocessor (dict): Config to build postprocessor.
|
||||||
@ -18,12 +55,75 @@ class BaseTextDetHead(BaseModule):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
loss: Dict,
|
loss_module: Dict,
|
||||||
postprocessor: Dict,
|
postprocessor: Dict,
|
||||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
||||||
super().__init__(init_cfg=init_cfg)
|
super().__init__(init_cfg=init_cfg)
|
||||||
assert isinstance(loss, dict)
|
assert isinstance(loss_module, dict)
|
||||||
assert isinstance(postprocessor, dict)
|
assert isinstance(postprocessor, dict)
|
||||||
|
|
||||||
self.loss = MODELS.build(loss)
|
self.loss_module = MODELS.build(loss_module)
|
||||||
self.postprocessor = MODELS.build(postprocessor)
|
self.postprocessor = MODELS.build(postprocessor)
|
||||||
|
|
||||||
|
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
|
||||||
|
"""Perform forward propagation and loss calculation of the detection
|
||||||
|
head on the features of the upstream network.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (tuple[Tensor]): Features from the upstream network, each is
|
||||||
|
a 4D-tensor.
|
||||||
|
batch_data_samples (List[:obj:`DetDataSample`]): The Data
|
||||||
|
Samples. It usually includes information such as
|
||||||
|
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary of loss components.
|
||||||
|
"""
|
||||||
|
outs = self(x)
|
||||||
|
losses = self.loss_module(outs, batch_data_samples)
|
||||||
|
return losses
|
||||||
|
|
||||||
|
def loss_and_predict(self, x: Tuple[Tensor], batch_data_samples: SampleList
|
||||||
|
) -> Tuple[dict, SampleList]:
|
||||||
|
"""Perform forward propagation of the head, then calculate loss and
|
||||||
|
predictions from the features and data samples.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (tuple[Tensor]): Features from FPN.
|
||||||
|
batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
|
||||||
|
the meta information of each image and corresponding
|
||||||
|
annotations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: the return value is a tuple contains:
|
||||||
|
|
||||||
|
- losses: (dict[str, Tensor]): A dictionary of loss components.
|
||||||
|
- predictions (list[:obj:`InstanceData`]): Detection
|
||||||
|
results of each image after the post process.
|
||||||
|
"""
|
||||||
|
outs = self(x)
|
||||||
|
losses = self.loss_module(outs, batch_data_samples)
|
||||||
|
|
||||||
|
predictions = self.postprocessor(outs, batch_data_samples)
|
||||||
|
return losses, predictions
|
||||||
|
|
||||||
|
def predict(self, x: torch.Tensor,
|
||||||
|
batch_data_samples: SampleList) -> SampleList:
|
||||||
|
"""Perform forward propagation of the detection head and predict
|
||||||
|
detection results on the features of the upstream network.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (tuple[Tensor]): Multi-level features from the
|
||||||
|
upstream network, each is a 4D-tensor.
|
||||||
|
batch_data_samples (List[:obj:`DetDataSample`]): The Data
|
||||||
|
Samples. It usually includes information such as
|
||||||
|
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SampleList: Detection results of each image
|
||||||
|
after the post process.
|
||||||
|
"""
|
||||||
|
outs = self(x)
|
||||||
|
|
||||||
|
predictions = self.postprocessor(outs, batch_data_samples)
|
||||||
|
return predictions
|
||||||
|
@ -28,7 +28,7 @@ class DBHead(BaseTextDetHead):
|
|||||||
self,
|
self,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
with_bias: bool = False,
|
with_bias: bool = False,
|
||||||
loss: Dict = dict(type='DBLoss'),
|
loss_module: Dict = dict(type='DBLoss'),
|
||||||
postprocessor: Dict = dict(
|
postprocessor: Dict = dict(
|
||||||
type='DBPostprocessor', text_repr_type='quad'),
|
type='DBPostprocessor', text_repr_type='quad'),
|
||||||
init_cfg: Optional[Union[Dict, List[Dict]]] = [
|
init_cfg: Optional[Union[Dict, List[Dict]]] = [
|
||||||
@ -37,7 +37,9 @@ class DBHead(BaseTextDetHead):
|
|||||||
]
|
]
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
|
loss_module=loss_module,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
init_cfg=init_cfg)
|
||||||
|
|
||||||
assert isinstance(in_channels, int)
|
assert isinstance(in_channels, int)
|
||||||
assert isinstance(with_bias, bool)
|
assert isinstance(with_bias, bool)
|
||||||
|
@ -30,7 +30,7 @@ class FCEHead(BaseTextDetHead):
|
|||||||
self,
|
self,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
fourier_degree: int = 5,
|
fourier_degree: int = 5,
|
||||||
loss: Dict = dict(type='FCELoss', num_sample=50),
|
loss_module: Dict = dict(type='FCELoss', num_sample=50),
|
||||||
postprocessor: Dict = dict(
|
postprocessor: Dict = dict(
|
||||||
type='FCEPostprocessor',
|
type='FCEPostprocessor',
|
||||||
text_repr_type='poly',
|
text_repr_type='poly',
|
||||||
@ -45,10 +45,12 @@ class FCEHead(BaseTextDetHead):
|
|||||||
override=[dict(name='out_conv_cls'),
|
override=[dict(name='out_conv_cls'),
|
||||||
dict(name='out_conv_reg')])
|
dict(name='out_conv_reg')])
|
||||||
) -> None:
|
) -> None:
|
||||||
loss['fourier_degree'] = fourier_degree
|
loss_module['fourier_degree'] = fourier_degree
|
||||||
postprocessor['fourier_degree'] = fourier_degree
|
postprocessor['fourier_degree'] = fourier_degree
|
||||||
super().__init__(
|
super().__init__(
|
||||||
loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
|
loss_module=loss_module,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
init_cfg=init_cfg)
|
||||||
|
|
||||||
assert isinstance(in_channels, int)
|
assert isinstance(in_channels, int)
|
||||||
assert isinstance(fourier_degree, int)
|
assert isinstance(fourier_degree, int)
|
||||||
|
@ -33,7 +33,7 @@ class PANHead(BaseTextDetHead):
|
|||||||
in_channels: List[int],
|
in_channels: List[int],
|
||||||
hidden_dim: int,
|
hidden_dim: int,
|
||||||
out_channel: int,
|
out_channel: int,
|
||||||
loss=dict(type='PANLoss'),
|
loss_module=dict(type='PANLoss'),
|
||||||
postprocessor=dict(type='PANPostprocessor', text_repr_type='poly'),
|
postprocessor=dict(type='PANPostprocessor', text_repr_type='poly'),
|
||||||
init_cfg=[
|
init_cfg=[
|
||||||
dict(type='Normal', mean=0, std=0.01, layer='Conv2d'),
|
dict(type='Normal', mean=0, std=0.01, layer='Conv2d'),
|
||||||
@ -41,7 +41,9 @@ class PANHead(BaseTextDetHead):
|
|||||||
]
|
]
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
|
loss_module=loss_module,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
init_cfg=init_cfg)
|
||||||
|
|
||||||
assert check_argument.is_type_list(in_channels, int)
|
assert check_argument.is_type_list(in_channels, int)
|
||||||
assert isinstance(out_channel, int)
|
assert isinstance(out_channel, int)
|
||||||
|
@ -24,7 +24,7 @@ class PSEHead(PANHead):
|
|||||||
in_channels: List[int],
|
in_channels: List[int],
|
||||||
hidden_dim: int,
|
hidden_dim: int,
|
||||||
out_channel: int,
|
out_channel: int,
|
||||||
loss: Dict = dict(type='PSELoss'),
|
loss_module: Dict = dict(type='PSELoss'),
|
||||||
postprocessor: Dict = dict(
|
postprocessor: Dict = dict(
|
||||||
type='PSEPostprocessor', text_repr_type='poly'),
|
type='PSEPostprocessor', text_repr_type='poly'),
|
||||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
||||||
@ -33,6 +33,6 @@ class PSEHead(PANHead):
|
|||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
hidden_dim=hidden_dim,
|
hidden_dim=hidden_dim,
|
||||||
out_channel=out_channel,
|
out_channel=out_channel,
|
||||||
loss=loss,
|
loss_module=loss_module,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
init_cfg=init_cfg)
|
init_cfg=init_cfg)
|
||||||
|
@ -31,14 +31,16 @@ class TextSnakeHead(BaseTextDetHead):
|
|||||||
in_channels: int,
|
in_channels: int,
|
||||||
out_channels: int = 5,
|
out_channels: int = 5,
|
||||||
downsample_ratio: float = 1.0,
|
downsample_ratio: float = 1.0,
|
||||||
loss: Dict = dict(type='TextSnakeLoss'),
|
loss_module: Dict = dict(type='TextSnakeLoss'),
|
||||||
postprocessor: Dict = dict(
|
postprocessor: Dict = dict(
|
||||||
type='TextSnakePostprocessor', text_repr_type='poly'),
|
type='TextSnakePostprocessor', text_repr_type='poly'),
|
||||||
init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
|
init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
|
||||||
type='Normal', override=dict(name='out_conv'), mean=0, std=0.01)
|
type='Normal', override=dict(name='out_conv'), mean=0, std=0.01)
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
|
loss_module=loss_module,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
init_cfg=init_cfg)
|
||||||
assert isinstance(in_channels, int)
|
assert isinstance(in_channels, int)
|
||||||
assert isinstance(out_channels, int)
|
assert isinstance(out_channels, int)
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
@ -0,0 +1,106 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from unittest import TestCase
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmocr.core import TextDetDataSample
|
||||||
|
from mmocr.models.textdet.data_preprocessors import TextDetDataPreprocessor
|
||||||
|
from mmocr.registry import MODELS
|
||||||
|
|
||||||
|
|
||||||
|
@MODELS.register_module()
class TDAugment(torch.nn.Module):
    """Pass-through batch augment used by the data preprocessor tests.

    It is registered in ``MODELS`` so that a config such as
    ``dict(type='TDAugment')`` can be resolved by name.
    """

    def forward(self, batch_inputs, batch_data_samples):
        """Return both arguments untouched, packed as a tuple."""
        passthrough = (batch_inputs, batch_data_samples)
        return passthrough
|
||||||
|
|
||||||
|
|
||||||
|
class TestTextDetDataPreprocessor(TestCase):
    """Unit tests for ``TextDetDataPreprocessor``."""

    def test_init(self):
        """Check constructor argument handling and validation."""
        # Without mean/std, normalization is disabled and no `mean`
        # attribute is created.
        processor = TextDetDataPreprocessor()
        self.assertFalse(hasattr(processor, 'mean'))
        self.assertIs(processor._enable_normalize, False)

        # With both mean and std, normalization is enabled.
        processor = TextDetDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])
        self.assertTrue(hasattr(processor, 'mean'))
        self.assertTrue(hasattr(processor, 'std'))
        self.assertTrue(processor._enable_normalize)

        # mean and std must be specified together.
        with self.assertRaises(AssertionError):
            TextDetDataPreprocessor(mean=[0, 0, 0])

        # bgr_to_rgb and rgb_to_bgr cannot be set to True at the same time.
        with self.assertRaises(AssertionError):
            TextDetDataPreprocessor(bgr_to_rgb=True, rgb_to_bgr=True)

        # batch_augments is None by default and is built into a ModuleList
        # when a config list is given.
        aug_cfg = [dict(type='TDAugment')]
        processor = TextDetDataPreprocessor()
        self.assertIsNone(processor.batch_augments)
        processor = TextDetDataPreprocessor(batch_augments=aug_cfg)
        self.assertIsInstance(processor.batch_augments, torch.nn.ModuleList)
        self.assertIsInstance(processor.batch_augments[0], TDAugment)

    def test_forward(self):
        """Check output shapes for plain, converted and padded batches."""
        processor = TextDetDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])

        data = [{
            'inputs':
            torch.randint(0, 256, (3, 11, 10)),
            'data_sample':
            TextDetDataSample(
                metainfo=dict(img_shape=(11, 10), valid_ratio=1.0))
        }]
        inputs, data_samples = processor(data)
        self.assertEqual(inputs.shape, (1, 3, 11, 10))
        self.assertEqual(len(data_samples), 1)

        # test channel_conversion
        processor = TextDetDataPreprocessor(
            mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
        inputs, data_samples = processor(data)
        self.assertEqual(inputs.shape, (1, 3, 11, 10))
        self.assertEqual(len(data_samples), 1)

        # test padding: inputs of different sizes and no data samples
        data = [{
            'inputs': torch.randint(0, 256, (3, 10, 11))
        }, {
            'inputs': torch.randint(0, 256, (3, 9, 14))
        }]
        processor = TextDetDataPreprocessor(
            mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
        inputs, data_samples = processor(data)
        self.assertEqual(inputs.shape, (2, 3, 10, 14))
        self.assertIsNone(data_samples)

        # test pad_size_divisor together with batch augments
        data = [{
            'inputs':
            torch.randint(0, 256, (3, 10, 11)),
            'data_sample':
            TextDetDataSample(
                metainfo=dict(img_shape=(10, 11), valid_ratio=1.0))
        }, {
            'inputs':
            torch.randint(0, 256, (3, 9, 24)),
            'data_sample':
            TextDetDataSample(
                metainfo=dict(img_shape=(9, 24), valid_ratio=1.0))
        }]
        aug_cfg = [dict(type='TDAugment')]
        processor = TextDetDataPreprocessor(
            mean=[0., 0., 0.],
            std=[1., 1., 1.],
            pad_size_divisor=5,
            batch_augments=aug_cfg)
        inputs, data_samples = processor(data, training=True)
        self.assertEqual(inputs.shape, (2, 3, 10, 25))
        self.assertEqual(len(data_samples), 2)
        # Every sample is expected to report the same padded batch shape.
        for data_sample, expected_shape in zip(data_samples, [(10, 25),
                                                              (10, 25)]):
            self.assertEqual(data_sample.batch_input_shape, expected_shape)
|
59
tests/test_models/test_textdet/test_heads/test_base_head.py
Normal file
59
tests/test_models/test_textdet/test_heads/test_base_head.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from unittest import TestCase, mock
|
||||||
|
|
||||||
|
from mmocr.models.textdet import BaseTextDetHead
|
||||||
|
from mmocr.registry import MODELS
|
||||||
|
|
||||||
|
|
||||||
|
@MODELS.register_module()
class FakeModule:
    """Registry stub whose ``get_targets`` and ``__call__`` return ``None``.

    It is registered in ``MODELS`` so that ``dict(type='FakeModule')`` can be
    resolved by name in the tests below.
    """

    def __init__(self) -> None:
        """Take no arguments and keep no state."""

    def get_targets(self, datasamples):
        """Ignore ``datasamples`` and return ``None``."""
        return

    def __call__(self, *args):
        """Accept any positional arguments and return ``None``."""
        return
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseTextDetHead(TestCase):
    """Unit tests for ``BaseTextDetHead``."""

    def test_init(self):
        """Both constructor arguments must pass the head's type assertion."""
        cfg = dict(type='FakeModule')

        # A list in either position trips the head's assertion.
        with self.assertRaises(AssertionError):
            BaseTextDetHead([], cfg)
        with self.assertRaises(AssertionError):
            BaseTextDetHead(cfg, [])

        head = BaseTextDetHead(cfg, cfg)
        self.assertIsInstance(head.loss_module, FakeModule)
        self.assertIsInstance(head.postprocessor, FakeModule)

    @mock.patch(f'{__name__}.BaseTextDetHead.forward')
    def test_forward(self, mock_forward):
        """Exercise loss / predict / forward with ``forward`` patched out."""
        # Configure the injected mock itself. Defining a local function with
        # the same name would shadow it and leave the patch returning a
        # default (truthy) MagicMock, making the forward check vacuous.
        mock_forward.return_value = True

        cfg = dict(type='FakeModule')
        head = BaseTextDetHead(cfg, cfg)

        # test loss
        loss = head.loss(None, None)
        self.assertIsNone(loss)

        # test predict
        predict = head.predict(None, None)
        self.assertIsNone(predict)

        # test forward: the call must surface the patched return value
        tensor = head(None, None)
        self.assertIs(tensor, True)

        loss, predict = head.loss_and_predict(None, None)
        self.assertIsNone(loss)
        self.assertIsNone(predict)
|
Loading…
x
Reference in New Issue
Block a user