mmclassification/mmcls/models/backbones/base_backbone.py

56 lines
1.7 KiB
Python

import logging
from abc import ABCMeta, abstractmethod
import torch.nn as nn
from mmcv.runner import load_checkpoint
class BaseBackbone(nn.Module, metaclass=ABCMeta):
    """Abstract base class for backbones.

    This class defines the basic functions of a backbone. It is abstract:
    any backbone that inherits from it must at least define its own
    ``forward`` function before it can be instantiated.
    """

    def __init__(self):
        super().__init__()

    def init_weights(self, pretrained=None):
        """Initialize the backbone weights.

        Args:
            pretrained (str | None): If a string, it is passed to
                ``load_checkpoint`` to load weights into this module
                (with ``strict=False``). If None, nothing is done here, so
                the default initializer or a customized initializer in a
                subclass applies.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if pretrained is None:
            # Use default initializer or customized initializer in subclasses.
            return

        if not isinstance(pretrained, str):
            raise TypeError('pretrained must be a str or None.'
                            f' But received {type(pretrained)}.')

        # The root logger is handed to the checkpoint loader.
        logger = logging.getLogger()
        load_checkpoint(self, pretrained, strict=False, logger=logger)

    @abstractmethod
    def forward(self, x):
        """Run the forward computation.

        Subclasses must override this method.

        Args:
            x (tensor | tuple[tensor]): x could be a torch.Tensor or a tuple
                of torch.Tensor, containing input data for the forward
                computation.
        """

    def train(self, mode=True):
        """Set the module status before forward computation.

        Args:
            mode (bool): Whether to set train mode (True) or test mode
                (False).

        Returns:
            The value returned by ``nn.Module.train`` (the module itself),
            so that calls such as ``model.train().cuda()`` and
            ``model.eval()`` keep returning the module instead of None.
        """
        # Propagate the parent's return value; dropping it would make
        # train()/eval() return None and break call chaining.
        return super().train(mode)