mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update README. Fine-grained layer-wise lr decay working for tiny_vit and both efficientvits. Minor fixes.
This commit is contained in:
parent
2f0fbb59b3
commit
0d124ffd4f
@ -35,6 +35,10 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
|
||||
* The Hugging Face Hub (https://huggingface.co/timm) is now the primary source for `timm` weights. Model cards include link to papers, original source, license.
|
||||
* Previous 0.6.x can be cloned from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch or installed via pip with version.
|
||||
|
||||
### Sep 1, 2023
|
||||
* TinyViT added by [SeeFun](https://github.com/seefun)
|
||||
* Fix EfficientViT (MIT) to use torch.autocast so it works back to PT 1.10
|
||||
|
||||
### Aug 28, 2023
|
||||
* Add dynamic img size support to models in `vision_transformer.py`, `vision_transformer_hybrid.py`, `deit.py`, and `eva.py` w/o breaking backward compat.
|
||||
* Add `dynamic_img_size=True` to args at model creation time to allow changing the grid size (interpolate abs and/or ROPE pos embed each forward pass).
|
||||
|
@ -508,11 +508,10 @@ class EfficientVit(nn.Module):
|
||||
|
||||
# stages
|
||||
self.feature_info = []
|
||||
stages = []
|
||||
stage_idx = 0
|
||||
self.stages = nn.Sequential()
|
||||
in_channels = widths[0]
|
||||
for i, (w, d) in enumerate(zip(widths[1:], depths[1:])):
|
||||
stages.append(EfficientVitStage(
|
||||
self.stages.append(EfficientVitStage(
|
||||
in_channels,
|
||||
w,
|
||||
depth=d,
|
||||
@ -524,10 +523,8 @@ class EfficientVit(nn.Module):
|
||||
))
|
||||
stride *= 2
|
||||
in_channels = w
|
||||
self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{stage_idx}')]
|
||||
stage_idx += 1
|
||||
self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')]
|
||||
|
||||
self.stages = nn.Sequential(*stages)
|
||||
self.num_features = in_channels
|
||||
self.head_widths = head_widths
|
||||
self.head_dropout = drop_rate
|
||||
@ -548,8 +545,11 @@ class EfficientVit(nn.Module):
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse=False):
|
||||
matcher = dict(
|
||||
stem=r'^stem', # stem and embed
|
||||
blocks=[(r'^stages\.(\d+)', None)]
|
||||
stem=r'^stem',
|
||||
blocks=r'^stages\.(\d+)' if coarse else [
|
||||
(r'^stages\.(\d+).downsample', (0,)),
|
||||
(r'^stages\.(\d+)\.\w+\.(\d+)', None),
|
||||
]
|
||||
)
|
||||
return matcher
|
||||
|
||||
|
@ -441,11 +441,18 @@ class EfficientVitMsra(nn.Module):
|
||||
self.head = NormLinear(
|
||||
self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
|
||||
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse=False):
|
||||
matcher = dict(
|
||||
stem=r'^patch_embed',
|
||||
blocks=[(r'^stages\.(\d+)', None)]
|
||||
blocks=r'^stages\.(\d+)' if coarse else [
|
||||
(r'^stages\.(\d+).downsample', (0,)),
|
||||
(r'^stages\.(\d+)\.\w+\.(\d+)', None),
|
||||
]
|
||||
)
|
||||
return matcher
|
||||
|
||||
@ -455,7 +462,7 @@ class EfficientVitMsra(nn.Module):
|
||||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
return self.head.linear
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.num_classes = num_classes
|
||||
|
@ -509,11 +509,18 @@ class TinyVit(nn.Module):
|
||||
def no_weight_decay_keywords(self):
|
||||
return {'attention_biases'}
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
|
||||
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse=False):
|
||||
matcher = dict(
|
||||
stem=r'^patch_embed',
|
||||
blocks=[(r'^stages\.(\d+)', None)]
|
||||
blocks=r'^stages\.(\d+)' if coarse else [
|
||||
(r'^stages\.(\d+).downsample', (0,)),
|
||||
(r'^stages\.(\d+)\.\w+\.(\d+)', None),
|
||||
]
|
||||
)
|
||||
return matcher
|
||||
|
||||
@ -523,7 +530,7 @@ class TinyVit(nn.Module):
|
||||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.num_classes = num_classes
|
||||
|
Loading…
x
Reference in New Issue
Block a user