Add typing to reset_classifier() on other models

This commit is contained in:
Ross Wightman 2024-05-12 11:12:00 -07:00
parent 3e03b2bf3f
commit c838c4233f
35 changed files with 61 additions and 64 deletions

View File

@ -395,7 +395,7 @@ class Beit(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -331,7 +331,7 @@ class Cait(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'token', 'avg')

View File

@ -7,8 +7,7 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
Modified from timm/models/vision_transformer.py
"""
from functools import partial
from typing import Tuple, List, Union
from typing import List, Optional, Union, Tuple
import torch
import torch.nn as nn
@ -560,7 +559,7 @@ class CoaT(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('token', 'avg')

View File

@ -21,8 +21,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
'''These modules are adapted from those of timm, see
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
'''
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
@ -349,7 +348,7 @@ class ConVit(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'token', 'avg')

View File

@ -1,6 +1,8 @@
""" ConvMixer
"""
from typing import Optional
import torch
import torch.nn as nn
@ -75,7 +77,7 @@ class ConvMixer(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)

View File

@ -37,7 +37,6 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
# LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
# No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
from collections import OrderedDict
from functools import partial
from typing import Callable, List, Optional, Tuple, Union

View File

@ -25,8 +25,7 @@ Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master
"""
from functools import partial
from typing import List
from typing import Tuple
from typing import List, Optional, Tuple
import torch
import torch.hub
@ -419,7 +418,7 @@ class CrossVit(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('token', 'avg')

View File

@ -12,7 +12,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# All rights reserved.
# This source code is licensed under the MIT license
from functools import partial
from typing import Tuple
from typing import Optional, Tuple
import torch
import torch.nn as nn
@ -568,7 +568,7 @@ class DaVit(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, global_pool)
def forward_features(self, x):

View File

@ -11,7 +11,7 @@ Modifications copyright 2021, Ross Wightman
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from functools import partial
from typing import Sequence, Union
from typing import Optional
import torch
from torch import nn as nn
@ -20,7 +20,6 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import resample_abs_pos_embed
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
__all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this
@ -64,7 +63,7 @@ class VisionTransformerDistilled(VisionTransformer):
def get_classifier(self):
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

View File

@ -8,7 +8,6 @@ Original code and weights from https://github.com/mmaaz60/EdgeNeXt
Modifications and additions for timm by / Copyright 2022, Ross Wightman
"""
import math
from collections import OrderedDict
from functools import partial
from typing import Tuple
@ -17,7 +16,7 @@ import torch.nn.functional as F
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d, \
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
use_fused_attn, NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module

View File

@ -449,7 +449,7 @@ class EfficientFormer(nn.Module):
def get_classifier(self):
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2023, Ross Wightman
"""
import math
from functools import partial
from typing import Dict
from typing import Dict, Optional
import torch
import torch.nn as nn
@ -612,7 +612,7 @@ class EfficientFormerV2(nn.Module):
def get_classifier(self):
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -13,7 +13,6 @@ from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
@ -740,7 +739,7 @@ class EfficientVit(nn.Module):
def get_classifier(self):
return self.head.classifier[-1]
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
@ -858,7 +857,7 @@ class EfficientVitLarge(nn.Module):
def get_classifier(self):
return self.head.classifier[-1]
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -9,7 +9,7 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/Effic
__all__ = ['EfficientVitMsra']
import itertools
from collections import OrderedDict
from typing import Dict
from typing import Dict, Optional
import torch
import torch.nn as nn
@ -464,7 +464,7 @@ class EfficientVitMsra(nn.Module):
def get_classifier(self):
return self.head.linear
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
if global_pool == 'avg':

View File

@ -539,7 +539,7 @@ class Eva(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -396,7 +396,7 @@ class ReparamLargeKernelConv(nn.Module):
@staticmethod
def _fuse_bn(
conv: torch.Tensor, bn: nn.BatchNorm2d
conv: nn.Conv2d, bn: nn.BatchNorm2d
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Method to fuse batchnorm layer with conv layer.
@ -1232,7 +1232,7 @@ class FastVit(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

View File

@ -454,7 +454,7 @@ class FocalNet(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, pool_type=global_pool)
def forward_features(self, x):

View File

@ -489,7 +489,7 @@ class GlobalContextVit(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is None:
global_pool = self.head.global_pool.pool_type

View File

@ -628,7 +628,7 @@ class Levit(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None, distillation=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
@ -730,7 +730,7 @@ class LevitDistilled(Levit):
def get_classifier(self):
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=None, distillation=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -1248,7 +1248,7 @@ class MaxxVit(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

View File

@ -255,7 +255,7 @@ class MlpMixer(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg')

View File

@ -825,7 +825,7 @@ class MultiScaleVit(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -6,6 +6,7 @@ Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-V
"""
# Copyright (c) ByteDance Inc. All rights reserved.
from functools import partial
from typing import Optional
import torch
import torch.nn.functional as F
@ -553,7 +554,7 @@ class NextViT(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, pool_type=global_pool)
def forward_features(self, x):

View File

@ -14,13 +14,13 @@ Modifications for timm by / Copyright 2020 Ross Wightman
import math
import re
from functools import partial
from typing import Sequence, Tuple
from typing import Optional, Sequence, Tuple
import torch
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, to_2tuple, LayerNorm
from timm.layers import trunc_normal_, to_2tuple
from ._builder import build_model_with_cfg
from ._registry import register_model, generate_default_cfgs
from .vision_transformer import Block
@ -246,7 +246,7 @@ class PoolingVisionTransformer(nn.Module):
else:
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if self.head_dist is not None:

View File

@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2022, Ross Wightman
"""
import math
from typing import Tuple, List, Callable, Union
from typing import Callable, List, Optional, Union
import torch
import torch.nn as nn
@ -379,7 +379,7 @@ class PyramidVisionTransformerV2(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('avg', '')

View File

@ -16,15 +16,16 @@ Adapted from official impl at https://github.com/jameslahm/RepViT
"""
__all__ = ['RepVit']
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ._registry import register_model, generate_default_cfgs
from ._builder import build_model_with_cfg
from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
from ._manipulate import checkpoint_seq
from typing import Optional
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
class ConvNorm(nn.Sequential):
@ -322,7 +323,7 @@ class RepVit(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None, distillation=False):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=False):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -9,7 +9,7 @@ Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2
import math
from functools import partial
from itertools import accumulate
from typing import Tuple
from typing import Optional, Tuple
import torch
import torch.nn as nn
@ -419,7 +419,7 @@ class Sequencer2d(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)

View File

@ -604,7 +604,7 @@ class SwinTransformer(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)

View File

@ -605,7 +605,7 @@ class SwinTransformerV2(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

View File

@ -8,10 +8,9 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/TinyV
__all__ = ['TinyVit']
import math
import itertools
from functools import partial
from typing import Dict
from typing import Dict, Optional
import torch
import torch.nn as nn
@ -533,7 +532,7 @@ class TinyVit(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)

View File

@ -7,6 +7,7 @@ The official mindspore code is released and available at
https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
"""
import math
from typing import Optional
import torch
import torch.nn as nn
@ -298,7 +299,7 @@ class TNT(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'token', 'avg')

View File

@ -7,6 +7,7 @@ Original model: https://github.com/mrT23/TResNet
"""
from collections import OrderedDict
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
@ -233,7 +234,7 @@ class TResNet(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, pool_type=global_pool)
def forward_features(self, x):

View File

@ -382,7 +382,7 @@ class Twins(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg')

View File

@ -2374,7 +2374,6 @@ def vit_base_patch16_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> V
def vit_large_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act
"""
from timm.layers import get_act_layer
model_args = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
norm_layer=nn.LayerNorm, act_layer='quick_gelu')

View File

@ -622,7 +622,7 @@ class VOLO(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool