mmpretrain/mmcls/models/backbones/base_backbone.py

28 lines
709 B
Python
Raw Normal View History

2020-06-03 15:51:17 +08:00
import logging
2020-06-03 15:57:07 +08:00
from abc import ABCMeta, abstractmethod
2020-06-03 15:51:17 +08:00
import torch.nn as nn
from mmcv.runner import load_checkpoint
class BaseBackbone(nn.Module, metaclass=ABCMeta):
    """Base class for backbones.

    :meth:`forward` is declared abstract, so a subclass must override it
    before it can be instantiated. Weights can optionally be loaded from a
    checkpoint via :meth:`init_weights`.
    """

    def __init__(self):
        super(BaseBackbone, self).__init__()

    def init_weights(self, pretrained=None):
        """Initialize the backbone, optionally from a checkpoint.

        Args:
            pretrained (str | None): Checkpoint filename (anything accepted
                by mmcv's ``load_checkpoint``) to load into this module, or
                ``None`` to leave the current weights untouched.

        Raises:
            TypeError: If ``pretrained`` is neither a ``str`` nor ``None``.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            # strict=False: per mmcv's load_checkpoint, mismatched keys
            # between the checkpoint and this module are reported rather
            # than raising.
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            # Nothing to load; keep whatever initialization already exists.
            pass
        else:
            raise TypeError('pretrained must be a str or None')

    @abstractmethod
    def forward(self, x):
        """Compute the backbone output for input ``x``.

        Must be overridden by subclasses.
        """
        pass

    def train(self, mode=True):
        """Set training (``mode=True``) or evaluation mode and return ``self``.

        This honors the ``nn.Module.train`` contract of returning the module,
        which ``nn.Module.eval`` (``return self.train(False)``) and call
        chains such as ``model.train().to(device)`` rely on. Subclasses that
        extend this hook should also return ``self``.
        """
        super(BaseBackbone, self).train(mode)
        return self