27 lines
812 B
Python
27 lines
812 B
Python
|
from __future__ import absolute_import
|
||
|
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
from torch.nn import functional as F
|
||
|
from torch.autograd import Variable
|
||
|
import torchvision
|
||
|
|
||
|
__all__ = ['ResNet50']
|
||
|
|
||
|
class ResNet50(nn.Module):
    """ResNet-50 backbone followed by a linear classification head.

    The torchvision ResNet-50 is truncated before its last two children
    (global average pooling and ``fc``). :meth:`forward` applies global
    average pooling itself and then a freshly created ``nn.Linear`` head.

    In training mode :meth:`forward` returns class logits of shape
    ``(batch, num_classes)``. In eval mode it skips the classifier and
    returns the pooled feature vectors of shape ``(batch, feat_dim)``.

    Args:
        num_classes: Number of output units of the classifier head.
        pretrained: If true (the default, matching the previous hard-coded
            behaviour), ask torchvision for its pretrained backbone weights,
            which may require a download. Pass ``False`` to build a randomly
            initialised backbone without network access.
        **kwargs: Ignored.
    """

    def __init__(self, num_classes, pretrained=True, **kwargs):
        super(ResNet50, self).__init__()
        resnet50 = torchvision.models.resnet50(pretrained=pretrained)
        # Keep everything except the final avgpool and fc children, so
        # ``base`` emits the last convolutional feature map.
        self.base = nn.Sequential(*list(resnet50.children())[:-2])
        # Width of the pooled feature vector, read from the discarded fc
        # layer instead of hard-coding 2048.
        self.feat_dim = resnet50.fc.in_features
        self.classifier = nn.Linear(self.feat_dim, num_classes)
        # NOTE(review): U(-1, 1) is much wider than nn.Linear's default
        # initialisation; kept unchanged, confirm it is intentional.
        self.classifier.weight.data.uniform_(-1, 1)
        self.classifier.bias.data.zero_()

    def forward(self, x):
        """Run the backbone, pool globally, and classify when training.

        Args:
            x: Input batch for the convolutional backbone (presumably
                images shaped ``(batch, 3, H, W)`` -- confirm with callers).

        Returns:
            ``(batch, feat_dim)`` pooled features when ``self.training`` is
            false, otherwise ``(batch, num_classes)`` classifier logits.
        """
        x = self.base(x)
        # Global average pooling over all spatial positions, whatever the
        # input resolution -> (batch, feat_dim, 1, 1).
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        if not self.training:
            # Eval/inference: return the embedding rather than the logits.
            return x
        return self.classifier(x)
|