mmocr/projects/SPTS/spts/model/spts_encoder.py
Tong Gao 2d743cfa19
[Model] SPTS (#1696)
* [Feature] Add RepeatAugSampler

* initial commit

* spts inference done

* merge repeat_aug (bug in multi-node?)

* fix inference

* train done

* rm readme

* Revert "merge repeat_aug (bug in multi-node?)"

This reverts commit 393506a97cbe6d75ad1f28611ea10eba6b8fa4b3.

* Revert "[Feature] Add RepeatAugSampler"

This reverts commit 2089b02b4844157670033766f257b5d1bca452ce.

* remove utils

* readme & conversion script

* update readme

* fix

* optimize

* rename cfg & del compose

* fix

* fix
2023-02-01 11:58:03 +08:00

45 lines
1.4 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Union

import torch.nn as nn
from torch import Tensor

from mmocr.models.textrecog.encoders import BaseEncoder
from mmocr.registry import MODELS
from mmocr.structures import TextSpottingDataSample
@MODELS.register_module()
class SPTSEncoder(BaseEncoder):
    """SPTS Encoder.

    Projects a backbone feature map from ``d_backbone`` channels to
    ``d_model`` channels with a 1x1 convolution. The spatial size of the
    feature map is left unchanged.

    Args:
        d_backbone (int): Backbone output dimension, i.e. the number of
            channels of the input feature map. Defaults to 2048.
        d_model (int): Model output dimension, i.e. the number of channels
            of the projected feature map. Defaults to 256.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 d_backbone: int = 2048,
                 d_model: int = 256,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        # 1x1 conv with default stride/padding: changes the channel count
        # only, so H and W of the input are preserved.
        self.input_proj = nn.Conv2d(d_backbone, d_model, kernel_size=1)

    def forward(self,
                feat: Union[Tensor, Sequence[Tensor]],
                data_samples: Optional[List[TextSpottingDataSample]] = None
                ) -> Tensor:
        """Forward propagation of encoder.

        Args:
            feat (Tensor or Sequence[Tensor]): Either a single feature map of
                shape :math:`(N, D_b, H, W)`, or a sequence of such feature
                maps (presumably multi-level backbone outputs), in which case
                only the first element is used. ``D_b`` must equal
                ``d_backbone``.
            data_samples (list[TextSpottingDataSample], optional): Batch of
                TextSpottingDataSample. Not used by this encoder.
                Defaults to None.

        Returns:
            Tensor: The projected feature map of shape
            :math:`(N, D_m, H, W)`, where ``D_m`` is ``d_model``.
        """
        # Accept a bare tensor as well as a sequence: indexing a 4D tensor
        # with ``[0]`` would pick the first *sample* of the batch rather than
        # a feature level, silently producing an unbatched (3D) result.
        x = feat if isinstance(feat, Tensor) else feat[0]
        return self.input_proj(x)