Update README. Fine-grained layer-wise lr decay working for tiny_vit and both efficientvits. Minor fixes.

Ross Wightman 2023-09-01 15:05:29 -07:00
parent 2f0fbb59b3
commit 0d124ffd4f
4 changed files with 30 additions and 12 deletions
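The headline change is fine-grained layer-wise LR decay for `tiny_vit` and both `efficientvit` variants, driven by the `group_matcher` updates in the diffs below. A minimal sketch (not part of the commit) of exercising it, assuming `timm.optim.create_optimizer_v2` accepts a `layer_decay` argument (as in recent timm releases) and that `tiny_vit_5m_224` is a registered model name:

```python
# Minimal sketch; create_optimizer_v2's layer_decay argument and the
# 'tiny_vit_5m_224' model name are assumptions.
import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('tiny_vit_5m_224', pretrained=False)

# layer_decay groups parameters via the model's group_matcher(); with the
# fine-grained matchers below, each block (not just each stage) can get its
# own decayed learning-rate scale.
optimizer = create_optimizer_v2(
    model,
    opt='adamw',
    lr=1e-3,
    weight_decay=0.05,
    layer_decay=0.75,
)

for group in optimizer.param_groups[:4]:
    # 'lr_scale' is attached per group and applied by timm's LR schedulers.
    print(group.get('lr_scale', 1.0), group['weight_decay'], len(group['params']))
```

The equivalent training-script path is the `--layer-decay` argument of `train.py`.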


@@ -35,6 +35,10 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
 * The Hugging Face Hub (https://huggingface.co/timm) is now the primary source for `timm` weights. Model cards include link to papers, original source, license.
 * Previous 0.6.x can be cloned from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch or installed via pip with version.
+### Sep 1, 2023
+* TinyViT added by [SeeFun](https://github.com/seefun)
+* Fix EfficientViT (MIT) to use torch.autocast so it works back to PT 1.10
 ### Aug 28, 2023
 * Add dynamic img size support to models in `vision_transformer.py`, `vision_transformer_hybrid.py`, `deit.py`, and `eva.py` w/o breaking backward compat.
 * Add `dynamic_img_size=True` to args at model creation time to allow changing the grid size (interpolate abs and/or ROPE pos embed each forward pass).
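A hedged sketch of the two new README items, assuming the model names `tiny_vit_5m_224` and `efficientvit_b0` are registered in this version of timm:

```python
# Sketch only; model names are assumptions.
import torch
import timm

x = torch.randn(1, 3, 224, 224)

# TinyViT, newly added by SeeFun.
tiny = timm.create_model('tiny_vit_5m_224', pretrained=False)
print(tiny(x).shape)  # torch.Size([1, 1000])

# EfficientViT (MIT); the torch.autocast fix means mixed precision also works
# on older PyTorch (back to 1.10), not only via torch.cuda.amp.
evit = timm.create_model('efficientvit_b0', pretrained=False)
with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    print(evit(x).shape)  # torch.Size([1, 1000])
```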


@@ -508,11 +508,10 @@ class EfficientVit(nn.Module):
         # stages
         self.feature_info = []
-        stages = []
-        stage_idx = 0
+        self.stages = nn.Sequential()
         in_channels = widths[0]
         for i, (w, d) in enumerate(zip(widths[1:], depths[1:])):
-            stages.append(EfficientVitStage(
+            self.stages.append(EfficientVitStage(
                 in_channels,
                 w,
                 depth=d,
@@ -524,10 +523,8 @@ class EfficientVit(nn.Module):
             ))
             stride *= 2
             in_channels = w
-            self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{stage_idx}')]
-            stage_idx += 1
+            self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')]
-        self.stages = nn.Sequential(*stages)
         self.num_features = in_channels
         self.head_widths = head_widths
         self.head_dropout = drop_rate
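The `feature_info` entries now index directly into the `self.stages` `nn.Sequential` as `stages.{i}`, which is what timm's feature-extraction wrapper keys on. A sketch, assuming `efficientvit_b0` is a registered model name and that `features_only` extraction is wired up for it:

```python
# Sketch only; the model name and features_only support are assumptions.
import torch
import timm

model = timm.create_model('efficientvit_b0', features_only=True, pretrained=False)
print(model.feature_info.channels())   # channels of each returned feature map
print(model.feature_info.reduction())  # strides recorded via the loop above

for f in model(torch.randn(1, 3, 224, 224)):
    print(f.shape)
```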
@@ -548,8 +545,11 @@ class EfficientVit(nn.Module):
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
-            stem=r'^stem', # stem and embed
-            blocks=[(r'^stages\.(\d+)', None)]
+            stem=r'^stem',
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
+            ]
         )
         return matcher
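The old matcher put an entire stage into one group; the new one (with `coarse=False`) separates each stage's downsample and each block within a stage, which is what makes fine-grained layer-wise LR decay possible. A rough illustration with plain `re`, assuming `efficientvit_b0` is a valid model name:

```python
# Rough stand-in for timm's internal grouping; the model name is an assumption.
import re
import timm

model = timm.create_model('efficientvit_b0', pretrained=False)
patterns = model.group_matcher(coarse=False)['blocks']  # list of (regex, ordinal) pairs

group_keys = set()
for name, _ in model.named_parameters():
    for pattern, _ordinal in patterns:
        m = re.match(pattern, name)
        if m:
            group_keys.add((pattern, m.groups()))
            break

# Far more keys than stages: roughly one per matched block rather than one per stage.
print(len(group_keys), 'fine-grained block groups')
```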


@@ -441,11 +441,18 @@ class EfficientVitMsra(nn.Module):
         self.head = NormLinear(
             self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {x for x in self.state_dict().keys() if 'attention_biases' in x}
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^patch_embed',
-            blocks=[(r'^stages\.(\d+)', None)]
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
+            ]
         )
         return matcher
@@ -455,7 +462,7 @@ class EfficientVitMsra(nn.Module):
     @torch.jit.ignore
     def get_classifier(self):
-        return self.head
+        return self.head.linear
     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
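With the head wrapped in `NormLinear`, returning `self.head.linear` means `get_classifier()` hands back the inner `nn.Linear`, so generic code can inspect or replace it directly. A small sketch, assuming `efficientvit_m0` is a registered model name:

```python
# Sketch only; the model name is an assumption.
import timm

model = timm.create_model('efficientvit_m0', pretrained=False)
classifier = model.get_classifier()  # the nn.Linear inside the NormLinear head
print(classifier.in_features, classifier.out_features)  # (embed dim, num_classes)
```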


@@ -509,11 +509,18 @@ class TinyVit(nn.Module):
     def no_weight_decay_keywords(self):
         return {'attention_biases'}
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {x for x in self.state_dict().keys() if 'attention_biases' in x}
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^patch_embed',
-            blocks=[(r'^stages\.(\d+)', None)]
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
+            ]
         )
         return matcher
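The new `no_weight_decay()` mirrors the existing keyword variant but returns fully qualified parameter names, which is how optimizer factories typically exclude the 2-D `attention_biases` tables from weight decay. A sketch using plain PyTorch, with `tiny_vit_5m_224` as an assumed model name:

```python
# Sketch only; the model name is an assumption.
import timm
import torch

model = timm.create_model('tiny_vit_5m_224', pretrained=False)
skip = model.no_weight_decay()  # parameter names containing 'attention_biases'

decay, no_decay = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    # attention_biases are 2-D, so an ndim check alone would miss them.
    if name in skip or param.ndim <= 1:
        no_decay.append(param)
    else:
        decay.append(param)

optimizer = torch.optim.AdamW([
    {'params': decay, 'weight_decay': 0.05},
    {'params': no_decay, 'weight_decay': 0.0},
], lr=1e-3)
```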
@@ -523,7 +530,7 @@ class TinyVit(nn.Module):
     @torch.jit.ignore
     def get_classifier(self):
-        return self.head
+        return self.head.fc
     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
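The `head.fc` change plus `reset_classifier` give the usual timm head workflow on TinyViT. A sketch, assuming `tiny_vit_5m_224` is a registered model name and that `num_classes=0` follows the usual timm convention of dropping the classifier and returning pooled features:

```python
# Sketch only; the model name and num_classes=0 behavior are assumptions.
import timm
import torch

model = timm.create_model('tiny_vit_5m_224', pretrained=False)
print(type(model.get_classifier()))  # the head's fc layer, per the change above

model.reset_classifier(10)  # swap in a fresh 10-class head
print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 10])

model.reset_classifier(0)   # remove the classifier, keep pooled features
print(model(torch.randn(1, 3, 224, 224)).shape)  # (1, embed_dim)
```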