56 lines
1.7 KiB
Python
56 lines
1.7 KiB
Python
import logging
|
|
from abc import ABCMeta, abstractmethod
|
|
|
|
import torch.nn as nn
|
|
from mmcv.runner import load_checkpoint
|
|
|
|
|
|
class BaseBackbone(nn.Module, metaclass=ABCMeta):
    """Base backbone.

    This class defines the basic functions of a backbone.
    Any backbone that inherits this class should at least
    define its own `forward` function.
    """

    def __init__(self):
        super().__init__()

    def init_weights(self, pretrained=None):
        """Init backbone weights.

        Args:
            pretrained (str | None): If pretrained is a string, then it
                initializes backbone weights by loading the pretrained
                checkpoint. If pretrained is None, then it follows default
                initializer or customized initializer in subclasses.

        Raises:
            TypeError: If ``pretrained`` is neither a ``str`` nor ``None``.
        """
        if pretrained is None:
            # Nothing to load here: the default initializer (or a customized
            # initializer provided by a subclass) is responsible for weights.
            return

        if not isinstance(pretrained, str):
            raise TypeError('pretrained must be a str or None.'
                            f' But received {type(pretrained)}.')

        # The checkpoint is loaded with ``strict=False`` and loading messages
        # are routed to the root logger.
        logger = logging.getLogger()
        load_checkpoint(self, pretrained, strict=False, logger=logger)

    @abstractmethod
    def forward(self, x):
        """Forward computation.

        Args:
            x (tensor | tuple[tensor]): x could be a torch.Tensor or a tuple
                of torch.Tensor, containing input data for forward
                computation.
        """

    def train(self, mode=True):
        """Set module status before forward computation.

        Args:
            mode (bool): Whether it is train_mode or test_mode.

        Returns:
            BaseBackbone: This module itself, as returned by
            ``nn.Module.train``, so that calls can be chained.
        """
        # Propagate the parent's return value; dropping it would make
        # ``backbone.train()`` evaluate to None and break call chaining.
        return super().train(mode)
|