mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Small TResNet simplification, just use SelectAdaptivePool, don't notice any perf difference
This commit is contained in:
parent
e3a98171b2
commit
be7c784d21
@ -229,9 +229,9 @@ class TResNet(nn.Module):
|
|||||||
|
|
||||||
# head
|
# head
|
||||||
self.num_features = (self.planes * 8) * Bottleneck.expansion
|
self.num_features = (self.planes * 8) * Bottleneck.expansion
|
||||||
self.global_pool = None
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
|
||||||
self.head = None
|
self.head = nn.Sequential(OrderedDict([
|
||||||
self.reset_classifier(num_classes, global_pool)
|
('fc', nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes))]))
|
||||||
|
|
||||||
# model initilization
|
# model initilization
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
@ -273,11 +273,8 @@ class TResNet(nn.Module):
|
|||||||
return self.head.fc
|
return self.head.fc
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||||
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
if global_pool == 'avg':
|
|
||||||
self.global_pool = FastGlobalAvgPool2d(flatten=True)
|
|
||||||
else:
|
|
||||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
|
|
||||||
self.head = None
|
self.head = None
|
||||||
if num_classes:
|
if num_classes:
|
||||||
self.head = nn.Sequential(OrderedDict([
|
self.head = nn.Sequential(OrderedDict([
|
||||||
|
Loading…
x
Reference in New Issue
Block a user