mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
* Move 'test time pool' to Module that can be used by any model, remove from DPN * Remove ResNext model file and combine with ResNet * Remove fbresnet200 as it was an old conversion and pretrained performance not worth param count * Cleanup adaptive avgmax pooling and add back concat variant * Factor out checkpoint load fn
28 lines
1.0 KiB
Python
28 lines
1.0 KiB
Python
from torch import nn
|
|
import torch.nn.functional as F
|
|
from models.adaptive_avgmax_pool import adaptive_avgmax_pool2d
|
|
|
|
|
|
class TestTimePoolHead(nn.Module):
    """Test-time pooling head wrapped around an existing classification model.

    The wrapped model's classifier is re-expressed as a 1x1 convolution so it
    can be applied at every spatial position of the un-pooled feature map.
    The per-position outputs are then reduced to one prediction per sample
    with ``adaptive_avgmax_pool2d``.

    Args:
        base: Model to wrap. It must provide ``get_classifier()``,
            ``reset_classifier(num_classes)``,
            ``forward_features(x, pool=False)`` and the attributes
            ``num_features`` and ``num_classes``.
        original_pool: Kernel size of the stride-1 average pooling applied to
            the feature map before classification. Presumably the spatial
            size of the feature map the model pooled over originally
            (verify against the wrapped model).
    """

    def __init__(self, base, original_pool=7):
        super(TestTimePoolHead, self).__init__()
        self.base = base
        self.original_pool = original_pool
        base_fc = self.base.get_classifier()
        if isinstance(base_fc, nn.Conv2d):
            # Already convolutional: reuse the module (and its weights) as-is.
            self.fc = base_fc
        else:
            # A classifier created with bias=False exposes ``bias`` as None;
            # mirror that here instead of dereferencing a missing tensor.
            has_bias = getattr(base_fc, 'bias', None) is not None
            self.fc = nn.Conv2d(
                self.base.num_features, self.base.num_classes, kernel_size=1, bias=has_bias)
            # Reshape the classifier weights into the 1x1 conv kernel layout.
            self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))
            if has_bias:
                self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))
        # Must run after the weights were copied: this drops the original
        # classifier from the wrapped model.
        self.base.reset_classifier(0)

    def forward(self, x):
        """Return one row of class scores per input sample.

        Args:
            x: Input batch forwarded to ``base.forward_features``; assumed to
                yield a feature map accepted by ``F.avg_pool2d`` (TODO
                confirm layout against the wrapped model).

        Returns:
            Tensor reshaped to ``(x.size(0), -1)`` after pooling.
        """
        x = self.base.forward_features(x, pool=False)
        # Stride 1 keeps a score map when the input is larger than the
        # original pooling window.
        x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
        x = self.fc(x)
        # Reduce the spatial score map to output size 1.
        x = adaptive_avgmax_pool2d(x, 1)
        return x.view(x.size(0), -1)
|
|
|