mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add typing to reset_classifier() on other models
This commit is contained in:
parent
3e03b2bf3f
commit
c838c4233f
@ -395,7 +395,7 @@ class Beit(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
|
@ -331,7 +331,7 @@ class Cait(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
assert global_pool in ('', 'token', 'avg')
|
||||
|
@ -7,8 +7,7 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
|
||||
|
||||
Modified from timm/models/vision_transformer.py
|
||||
"""
|
||||
from functools import partial
|
||||
from typing import Tuple, List, Union
|
||||
from typing import List, Optional, Union, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -560,7 +559,7 @@ class CoaT(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
assert global_pool in ('token', 'avg')
|
||||
|
@ -21,8 +21,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
|
||||
'''These modules are adapted from those of timm, see
|
||||
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
'''
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -349,7 +348,7 @@ class ConVit(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
assert global_pool in ('', 'token', 'avg')
|
||||
|
@ -1,6 +1,8 @@
|
||||
""" ConvMixer
|
||||
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -75,7 +77,7 @@ class ConvMixer(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
|
||||
|
@ -37,7 +37,6 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
|
||||
# LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
|
||||
# No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
|
||||
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
|
@ -25,8 +25,7 @@ Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master
|
||||
|
||||
"""
|
||||
from functools import partial
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.hub
|
||||
@ -419,7 +418,7 @@ class CrossVit(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
assert global_pool in ('token', 'avg')
|
||||
|
@ -12,7 +12,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the MIT license
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -568,7 +568,7 @@ class DaVit(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
|
@ -11,7 +11,7 @@ Modifications copyright 2021, Ross Wightman
|
||||
# Copyright (c) 2015-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
from functools import partial
|
||||
from typing import Sequence, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
@ -20,7 +20,6 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import resample_abs_pos_embed
|
||||
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
|
||||
__all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this
|
||||
@ -64,7 +63,7 @@ class VisionTransformerDistilled(VisionTransformer):
|
||||
def get_classifier(self):
|
||||
return self.head, self.head_dist
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
@ -8,7 +8,6 @@ Original code and weights from https://github.com/mmaaz60/EdgeNeXt
|
||||
Modifications and additions for timm by / Copyright 2022, Ross Wightman
|
||||
"""
|
||||
import math
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
|
||||
@ -17,7 +16,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d, \
|
||||
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
|
||||
use_fused_attn, NormMlpClassifierHead, ClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_module
|
||||
|
@ -449,7 +449,7 @@ class EfficientFormer(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head, self.head_dist
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
|
@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2023, Ross Wightman
|
||||
"""
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -612,7 +612,7 @@ class EfficientFormerV2(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head, self.head_dist
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
|
@ -13,7 +13,6 @@ from functools import partial
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
|
||||
@ -740,7 +739,7 @@ class EfficientVit(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.classifier[-1]
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
@ -858,7 +857,7 @@ class EfficientVitLarge(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.classifier[-1]
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
|
@ -9,7 +9,7 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/Effic
|
||||
__all__ = ['EfficientVitMsra']
|
||||
import itertools
|
||||
from collections import OrderedDict
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -464,7 +464,7 @@ class EfficientVitMsra(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.linear
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
if global_pool == 'avg':
|
||||
|
@ -539,7 +539,7 @@ class Eva(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
|
@ -396,7 +396,7 @@ class ReparamLargeKernelConv(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def _fuse_bn(
|
||||
conv: torch.Tensor, bn: nn.BatchNorm2d
|
||||
conv: nn.Conv2d, bn: nn.BatchNorm2d
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Method to fuse batchnorm layer with conv layer.
|
||||
|
||||
@ -1232,7 +1232,7 @@ class FastVit(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
|
@ -454,7 +454,7 @@ class FocalNet(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
|
@ -489,7 +489,7 @@ class GlobalContextVit(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is None:
|
||||
global_pool = self.head.global_pool.pool_type
|
||||
|
@ -628,7 +628,7 @@ class Levit(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None, distillation=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
@ -730,7 +730,7 @@ class LevitDistilled(Levit):
|
||||
def get_classifier(self):
|
||||
return self.head, self.head_dist
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None, distillation=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
|
@ -1248,7 +1248,7 @@ class MaxxVit(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
|
@ -255,7 +255,7 @@ class MlpMixer(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
assert global_pool in ('', 'avg')
|
||||
|
@ -825,7 +825,7 @@ class MultiScaleVit(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
|
@ -6,6 +6,7 @@ Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-V
|
||||
"""
|
||||
# Copyright (c) ByteDance Inc. All rights reserved.
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -553,7 +554,7 @@ class NextViT(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
|
@ -14,13 +14,13 @@ Modifications for timm by / Copyright 2020 Ross Wightman
|
||||
import math
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import Sequence, Tuple
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import trunc_normal_, to_2tuple, LayerNorm
|
||||
from timm.layers import trunc_normal_, to_2tuple
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
from .vision_transformer import Block
|
||||
@ -246,7 +246,7 @@ class PoolingVisionTransformer(nn.Module):
|
||||
else:
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
if self.head_dist is not None:
|
||||
|
@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2022, Ross Wightman
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Tuple, List, Callable, Union
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -379,7 +379,7 @@ class PyramidVisionTransformerV2(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
assert global_pool in ('avg', '')
|
||||
|
@ -16,15 +16,16 @@ Adapted from official impl at https://github.com/jameslahm/RepViT
|
||||
"""
|
||||
|
||||
__all__ = ['RepVit']
|
||||
|
||||
import torch.nn as nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
from ._builder import build_model_with_cfg
|
||||
from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
|
||||
from ._manipulate import checkpoint_seq
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
|
||||
|
||||
class ConvNorm(nn.Sequential):
|
||||
@ -322,7 +323,7 @@ class RepVit(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None, distillation=False):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=False):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
|
@ -9,7 +9,7 @@ Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2
|
||||
import math
|
||||
from functools import partial
|
||||
from itertools import accumulate
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -419,7 +419,7 @@ class Sequencer2d(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
|
@ -604,7 +604,7 @@ class SwinTransformer(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
|
@ -605,7 +605,7 @@ class SwinTransformerV2(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
|
@ -8,10 +8,9 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/TinyV
|
||||
|
||||
__all__ = ['TinyVit']
|
||||
|
||||
import math
|
||||
import itertools
|
||||
from functools import partial
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -533,7 +532,7 @@ class TinyVit(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
|
@ -7,6 +7,7 @@ The official mindspore code is released and available at
|
||||
https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
|
||||
"""
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -298,7 +299,7 @@ class TNT(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
assert global_pool in ('', 'token', 'avg')
|
||||
|
@ -7,6 +7,7 @@ Original model: https://github.com/mrT23/TResNet
|
||||
"""
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -233,7 +234,7 @@ class TResNet(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
|
@ -382,7 +382,7 @@ class Twins(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
assert global_pool in ('', 'avg')
|
||||
|
@ -2374,7 +2374,6 @@ def vit_base_patch16_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> V
|
||||
def vit_large_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
""" ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act
|
||||
"""
|
||||
from timm.layers import get_act_layer
|
||||
model_args = dict(
|
||||
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
|
||||
norm_layer=nn.LayerNorm, act_layer='quick_gelu')
|
||||
|
@ -622,7 +622,7 @@ class VOLO(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
|
Loading…
x
Reference in New Issue
Block a user