Mirror of https://github.com/open-mmlab/mmocr.git
* [Feature] Add RepeatAugSampler
* initial commit
* spts inference done
* merge repeat_aug (bug in multi-node?)
* fix inference
* train done
* rm readme
* Revert "merge repeat_aug (bug in multi-node?)"
  This reverts commit 393506a97cbe6d75ad1f28611ea10eba6b8fa4b3.
* Revert "[Feature] Add RepeatAugSampler"
  This reverts commit 2089b02b4844157670033766f257b5d1bca452ce.
* remove utils
* readme & conversion script
* update readme
* fix
* optimize
* rename cfg & del compose
* fix
* fix
45 lines
1.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch.nn as nn
from torch import Tensor

from mmocr.models.textrecog.encoders import BaseEncoder
from mmocr.registry import MODELS
from mmocr.structures import TextSpottingDataSample


@MODELS.register_module()
class SPTSEncoder(BaseEncoder):
    """SPTS Encoder.

    Args:
        d_backbone (int): Backbone output dimension.
        d_model (int): Model output dimension.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 d_backbone: int = 2048,
                 d_model: int = 256,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.input_proj = nn.Conv2d(d_backbone, d_model, kernel_size=1)

    def forward(
        self,
        feat: Tensor,
        data_samples: Optional[List[TextSpottingDataSample]] = None
    ) -> Tensor:
        """Forward propagation of encoder.

        Args:
            feat (Tensor): Backbone output features; ``feat[0]`` is a tensor
                of shape :math:`(N, D_{backbone}, H, W)`.
            data_samples (list[TextSpottingDataSample]): Batch of
                TextSpottingDataSample. Defaults to None.

        Returns:
            Tensor: A feature map of shape :math:`(N, D_m, H, W)`.
        """
        return self.input_proj(feat[0])
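

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream module): a minimal, hypothetical
# smoke test, assuming a ResNet-style backbone whose last stage outputs 2048
# channels. It only exercises what the code above defines: a 1x1 convolution
# that projects backbone features from ``d_backbone`` to ``d_model`` channels.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch

    # Direct instantiation with the default dimensions.
    encoder = SPTSEncoder(d_backbone=2048, d_model=256)

    # ``forward`` indexes ``feat[0]``, so pass a tuple/list of feature maps
    # as returned by MMOCR backbones.
    feats = (torch.rand(2, 2048, 20, 20), )
    out = encoder(feats)
    print(out.shape)  # torch.Size([2, 256, 20, 20])

    # Equivalent registry-based construction, as it would appear in a config.
    encoder = MODELS.build(
        dict(type='SPTSEncoder', d_backbone=2048, d_model=256))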