mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
45 lines
1.4 KiB
Python
45 lines
1.4 KiB
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
from typing import Dict, List, Optional, Union
|
||
|
|
||
|
import torch.nn as nn
|
||
|
from torch import Tensor
|
||
|
|
||
|
from mmocr.models.textrecog.encoders import BaseEncoder
|
||
|
from mmocr.registry import MODELS
|
||
|
from mmocr.structures import TextSpottingDataSample
|
||
|
|
||
|
|
||
|
@MODELS.register_module()
class SPTSEncoder(BaseEncoder):
    """Encoder of SPTS: a single 1x1 convolutional projection.

    The module owns one layer, ``input_proj``, which maps feature maps with
    ``d_backbone`` channels to feature maps with ``d_model`` channels.

    Args:
        d_backbone (int): Number of input channels expected by the
            projection, i.e. the backbone output dimension.
            Defaults to 2048.
        d_model (int): Number of output channels produced by the
            projection, i.e. the model dimension. Defaults to 256.
        init_cfg (dict or list[dict], optional): Initialization configs,
            forwarded unchanged to the parent class. Defaults to None.
    """

    def __init__(
            self,
            d_backbone: int = 2048,
            d_model: int = 256,
            init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        # A kernel of size 1 only remixes channels; with the default stride
        # and padding the spatial size of the input is left untouched.
        self.input_proj = nn.Conv2d(
            in_channels=d_backbone, out_channels=d_model, kernel_size=1)

    def forward(self,
                feat: Tensor,
                data_samples: List[TextSpottingDataSample] = None) -> Tensor:
        """Project the first entry of ``feat`` to ``d_model`` channels.

        Args:
            feat (Tensor): Backbone features. Only ``feat[0]`` is used.
                NOTE(review): the annotation says ``Tensor``, but the
                ``[0]`` indexing suggests callers may actually pass a
                sequence of feature maps whose first element has shape
                :math:`(N, D_b, H, W)` - confirm against the caller.
            data_samples (list[TextSpottingDataSample]): Batch of
                TextSpottingDataSample. Not read by this method.
                Defaults to None.

        Returns:
            Tensor: The output of the 1x1 convolution applied to
            ``feat[0]``; its channel dimension is ``d_model``.
            NOTE(review): presumably :math:`(N, D_m, H, W)` - confirm.
        """
        # Select the single feature entry consumed by the encoder, then
        # apply the channel projection to it.
        selected = feat[0]
        projected = self.input_proj(selected)
        return projected
|