"""ResNet-50 backbone with a linear classification head."""
from __future__ import absolute_import

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import torchvision

__all__ = ['ResNet50']


class ResNet50(nn.Module):
    """ResNet-50 feature extractor followed by a linear classifier.

    The torchvision ResNet-50 is truncated by dropping its last two child
    modules (in torchvision these are the global pooling layer and the
    ``fc`` layer), so ``self.base`` produces a spatial feature map. The
    map is average-pooled over its full spatial extent and flattened to
    one feature vector per sample.

    In training mode ``forward`` returns the classifier output; in eval
    mode it returns the pooled feature vector instead.

    Args:
        num_classes: Number of outputs of the linear classifier.
        pretrained: Forwarded to ``torchvision.models.resnet50``. Defaults
            to ``True`` (the previous hard-coded behaviour). Pass
            ``False`` to build the network without fetching weights, e.g.
            when a checkpoint is loaded right after construction.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, num_classes, pretrained=True, **kwargs):
        super(ResNet50, self).__init__()
        resnet50 = torchvision.models.resnet50(pretrained=pretrained)
        # Keep every child except the last two so the backbone outputs an
        # un-pooled feature map.
        self.base = nn.Sequential(*list(resnet50.children())[:-2])
        # Size the head from the layer it replaces rather than a literal
        # (this is 2048 for ResNet-50).
        self.classifier = nn.Linear(resnet50.fc.in_features, num_classes)
        # NOTE(review): U(-1, 1) is an unusually wide init for a layer with
        # this fan-in; kept unchanged to preserve training behaviour.
        self.classifier.weight.data.uniform_(-1, 1)
        self.classifier.bias.data.zero_()

    def forward(self, x):
        """Run the backbone and, in training mode, the classifier.

        Args:
            x: Input batch. Presumably images shaped
                ``(batch, 3, height, width)`` -- confirm against callers.

        Returns:
            In eval mode, the pooled features flattened to
            ``(batch, features)``. In training mode, the classifier
            output for those features.
        """
        x = self.base(x)
        # Kernel equal to the full spatial size: global average pooling.
        x = F.avg_pool2d(x, x.size()[2:])
        x = x.view(x.size(0), -1)
        if not self.training:
            return x
        x = self.classifier(x)
        return x