# Mirror of https://github.com/open-mmlab/mmpretrain.git
# Synced 2025-06-03 14:59:18 +08:00
# 32 lines, 1000 B, Python
|
from mmcls.models import ResNet
|
||
|
from mmcls.registry import MODELS
|
||
|
|
||
|
|
||
|
# Add this backbone to the `MODELS` registry via its `register_module()` decorator.
@MODELS.register_module()
class ExampleNet(ResNet):
    """Example backbone that extends :class:`ResNet`.

    Apart from printing a greeting banner when it is constructed, it behaves
    exactly like ``ResNet``. Use it as a template: implement a backbone the
    same way you would implement an ordinary PyTorch network.

    Args:
        **kwargs: Passed through unchanged to ``ResNet.__init__``.
    """

    def __init__(self, **kwargs) -> None:
        # Demo side effect: announce that the custom backbone is being built.
        banner = ('#############################\n'
                  '# Hello MMClassification! #\n'
                  '#############################')
        print(banner)
        super().__init__(**kwargs)

    def forward(self, x):
        """Feed an image batch through the network.

        This example reuses the parent ``ResNet.forward`` unchanged.

        Args:
            x (torch.Tensor): Image batch of shape
                ``(batch_size, num_channels, height, width)``.

        Returns:
            Tuple[torch.Tensor]: Feature maps, one tensor per requested
                scale. When writing your own backbone, still return a tuple
                even if you only produce the final feature map (i.e. a
                one-item tuple).
        """
        feats = super().forward(x)
        return feats
|