Some missed reset_classifier() type annotations
parent 71101ebba0
commit b1a6f4a946
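
For context on the API these hunks annotate, here is a minimal usage sketch (assuming timm is installed; the model name and class count are illustrative, not part of this commit). get_classifier() returns the current classification module and reset_classifier() swaps it out, e.g. when fine-tuning a pretrained backbone on a new label set.

import timm

# Illustrative only: any timm model with a standard head works the same way.
model = timm.create_model('resnet50', pretrained=False)
print(model.get_classifier())                      # current classifier module
model.reset_classifier(num_classes=10, global_pool='avg')
print(model.get_classifier())                      # new 10-way classifier
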
@@ -156,7 +156,7 @@ class EfficientNet(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.classifier

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

@@ -273,7 +273,7 @@ class GhostNet(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.classifier

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         # cannot meaningfully change pooling of efficient head after creation
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)

@@ -739,7 +739,7 @@ class HighResolutionNet(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.classifier

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

@@ -280,7 +280,7 @@ class InceptionV4(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.last_linear

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

@@ -26,9 +26,9 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 from collections import OrderedDict
 from functools import partial
+from typing import Optional

 import torch
 import torch.nn as nn

@@ -548,7 +548,7 @@ class MetaFormer(nn.Module):
         # if using MlpHead, dropout is handled by MlpHead
         if num_classes > 0:
             if self.use_mlp_head:
-                # FIXME hidden size
+                # FIXME not actually returning mlp hidden state right now as pre-logits.
                 final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate)
                 self.head_hidden_size = self.num_features
             else:

@@ -583,7 +583,7 @@ class MetaFormer(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head.fc

-    def reset_classifier(self, num_classes=0, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         if global_pool is not None:
             self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
             self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()

@@ -518,7 +518,7 @@ class NASNetALarge(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.last_linear

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

@@ -307,7 +307,7 @@ class PNASNet5Large(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.last_linear

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

@@ -514,7 +514,7 @@ class RegNet(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head.fc

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, pool_type=global_pool)

     def forward_intermediates(

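Note that models with a head object (RegNet here, plus RexNet, VovNet and XceptionAligned below) now take global_pool as Optional[str] = None: passing None leaves the existing pooling untouched, while passing a pool name rebuilds it. A small hedged sketch of that distinction (the model name is illustrative):

import timm

model = timm.create_model('regnety_032', pretrained=False)
model.reset_classifier(5)                      # new 5-way fc, pooling left as-is
model.reset_classifier(5, global_pool='max')   # new fc, pooling swapped to max
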
@@ -12,6 +12,7 @@ Copyright 2020 Ross Wightman

 from functools import partial
 from math import ceil
+from typing import Optional

 import torch
 import torch.nn as nn

@@ -229,7 +230,7 @@ class RexNet(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head.fc

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)

@@ -161,7 +161,7 @@ class SelecSls(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.fc

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

@@ -337,7 +337,7 @@ class SENet(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.last_linear

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

@@ -381,7 +381,7 @@ class VisionTransformerRelPos(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head

-    def reset_classifier(self, num_classes: int, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('', 'avg', 'token')

@@ -536,7 +536,7 @@ class VisionTransformerSAM(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head

-    def reset_classifier(self, num_classes=0, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, global_pool)

     def forward_intermediates(

@@ -11,7 +11,7 @@ for some reference, rewrote most of the code.
 Hacked together by / Copyright 2020 Ross Wightman
 """

-from typing import List
+from typing import List, Optional

 import torch
 import torch.nn as nn

@@ -134,9 +134,17 @@ class OsaStage(nn.Module):
             else:
                 drop_path = None
             blocks += [OsaBlock(
-                in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise,
-                attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path)
-            ]
+                in_chs,
+                mid_chs,
+                out_chs,
+                layer_per_block,
+                residual=residual and i > 0,
+                depthwise=depthwise,
+                attn=attn if last_block else '',
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                drop_path=drop_path
+            )]
             in_chs = out_chs
         self.blocks = nn.Sequential(*blocks)

@@ -252,8 +260,9 @@ class VovNet(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head.fc

-    def reset_classifier(self, num_classes, global_pool='avg'):
-        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+    def reset_classifier(self, num_classes, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool)

     def forward_features(self, x):
         x = self.stem(x)

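Beyond the annotation, the VovNet hunk also changes behavior: instead of constructing a fresh ClassifierHead, reset_classifier() now delegates to the existing head's reset(). Roughly, a reset of this kind does the following; this is a hedged approximation for illustration, not timm's actual ClassifierHead code, and reset_head with its arguments is hypothetical:

import torch.nn as nn
from timm.layers import SelectAdaptivePool2d

def reset_head(head: nn.Module, in_features: int, num_classes: int, pool_type=None):
    # Hypothetical helper: rebuild the final fc for the new class count and only
    # touch the pooling module when an explicit pool type is given.
    if pool_type is not None:
        head.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True)
    head.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
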
@@ -174,7 +174,7 @@ class Xception(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.fc

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

@@ -274,7 +274,7 @@ class XceptionAligned(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head.fc

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, pool_type=global_pool)

     def forward_features(self, x):

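One more usage note on the same API (a hedged sketch; the model name and input size are illustrative): passing num_classes=0 with an empty pool string is the usual way to strip the head entirely and use the network as a feature extractor.

import torch
import timm

model = timm.create_model('xception41', pretrained=False)
model.reset_classifier(0, '')            # identity classifier, no pooling
features = model(torch.randn(1, 3, 299, 299))
print(features.shape)                    # unpooled feature map, e.g. (1, C, H, W)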