diff --git a/README.md b/README.md
index 4cf6f86b..d6033fa1 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,10 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
 * The Hugging Face Hub (https://huggingface.co/timm) is now the primary source for `timm` weights. Model cards include link to papers, original source, license.
 * Previous 0.6.x can be cloned from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch or installed via pip with version.
 
+### Sep 1, 2023
+* TinyViT added by [SeeFun](https://github.com/seefun)
+* Fix EfficientViT (MIT) to use torch.autocast so it works back to PT 1.10
+
 ### Aug 28, 2023
 * Add dynamic img size support to models in `vision_transformer.py`, `vision_transformer_hybrid.py`, `deit.py`, and `eva.py` w/o breaking backward compat.
 * Add `dynamic_img_size=True` to args at model creation time to allow changing the grid size (interpolate abs and/or ROPE pos embed each forward pass).
diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py
index 0ebe3dd2..6fe444ad 100644
--- a/timm/models/efficientvit_mit.py
+++ b/timm/models/efficientvit_mit.py
@@ -508,11 +508,10 @@ class EfficientVit(nn.Module):
 
         # stages
         self.feature_info = []
-        stages = []
-        stage_idx = 0
+        self.stages = nn.Sequential()
         in_channels = widths[0]
         for i, (w, d) in enumerate(zip(widths[1:], depths[1:])):
-            stages.append(EfficientVitStage(
+            self.stages.append(EfficientVitStage(
                 in_channels,
                 w,
                 depth=d,
@@ -524,10 +523,8 @@ class EfficientVit(nn.Module):
             ))
             stride *= 2
             in_channels = w
-            self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{stage_idx}')]
-            stage_idx += 1
+            self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')]
 
-        self.stages = nn.Sequential(*stages)
         self.num_features = in_channels
         self.head_widths = head_widths
         self.head_dropout = drop_rate
@@ -548,8 +545,11 @@ class EfficientVit(nn.Module):
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
-            stem=r'^stem',  # stem and embed
-            blocks=[(r'^stages\.(\d+)', None)]
+            stem=r'^stem',
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
+            ]
         )
         return matcher
 
diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py
index 421d475b..1b7f52a0 100644
--- a/timm/models/efficientvit_msra.py
+++ b/timm/models/efficientvit_msra.py
@@ -441,11 +441,18 @@ class EfficientVitMsra(nn.Module):
         self.head = NormLinear(
             self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
 
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {x for x in self.state_dict().keys() if 'attention_biases' in x}
+
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^patch_embed',
-            blocks=[(r'^stages\.(\d+)', None)]
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
+            ]
         )
         return matcher
 
@@ -455,7 +462,7 @@
 
     @torch.jit.ignore
     def get_classifier(self):
-        return self.head
+        return self.head.linear
 
     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py
index b3a6009c..4b583658 100644
--- a/timm/models/tiny_vit.py
+++ b/timm/models/tiny_vit.py
@@ -509,11 +509,18 @@ class TinyVit(nn.Module):
     def no_weight_decay_keywords(self):
         return {'attention_biases'}
 
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {x for x in self.state_dict().keys() if 'attention_biases' in x}
+
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^patch_embed',
-            blocks=[(r'^stages\.(\d+)', None)]
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
+            ]
         )
         return matcher
 
@@ -523,7 +530,7 @@ class TinyVit(nn.Module):
 
     @torch.jit.ignore
     def get_classifier(self):
-        return self.head
+        return self.head.fc
 
     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
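
A few usage notes on the patch above. First, the `EfficientVit` (MIT) constructor now appends stages into an `nn.Sequential` directly, so the `feature_info` entries can reuse the loop index `i` instead of carrying a separate `stage_idx` counter. The recorded feature taps are easy to inspect; a minimal sketch, assuming `efficientvit_b0` is a registered model name in your timm build:

```python
import timm

# Build without weights; feature_info is populated in __init__.
model = timm.create_model('efficientvit_b0', pretrained=False)
for info in model.feature_info:
    # Each entry names its stage module directly, e.g.
    # {'num_chs': ..., 'reduction': ..., 'module': 'stages.0'}
    print(info)
```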
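The fine-grained `group_matcher` branch follows timm's (regex, ordinal-spec) convention: the first pattern pins each stage's downsample parameters to the start of that stage's span via `(0,)`, while `None` lets the captured integers order the remaining blocks, which is what layer-wise LR decay keys off. Below is a self-contained sketch of which pattern claims which parameter name; the sample names and the matching loop are illustrative only, not timm's actual grouping code:

```python
import re

# Fine-grained matcher, as returned by group_matcher(coarse=False) above.
blocks = [
    (r'^stages\.(\d+).downsample', (0,)),
    (r'^stages\.(\d+)\.\w+\.(\d+)', None),
]

# Hypothetical parameter names, just to show which pattern fires first.
names = [
    'stages.1.downsample.conv.weight',
    'stages.1.blocks.0.attn.qkv.weight',
    'stages.2.blocks.3.mlp.fc1.bias',
    'head.linear.weight',
]

for name in names:
    for pattern, spec in blocks:
        m = re.match(pattern, name)
        if m:
            print(f'{name:40} -> {pattern!r} captures {m.groups()}')
            break
    else:
        # Unmatched names (e.g. the classifier) fall into the final group.
        print(f'{name:40} -> unmatched, tail group')
```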
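The new `no_weight_decay()` overrides return fully-qualified state-dict keys for every `attention_biases` tensor, so optimizer setup can exclude those learned relative-position biases from weight decay. A minimal sketch of consuming it when building param groups by hand, assuming `tiny_vit_21m_224` is a registered name (timm's optimizer factory should consult this method automatically when handed the model):

```python
import timm
import torch

model = timm.create_model('tiny_vit_21m_224', pretrained=False)
skip = model.no_weight_decay()  # full keys, e.g. 'stages.1.blocks.0.attn.attention_biases'

decay, no_decay = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    # Exclude attention biases (and, conventionally, other 1-d tensors) from decay.
    if name in skip or param.ndim <= 1:
        no_decay.append(param)
    else:
        decay.append(param)

optimizer = torch.optim.AdamW(
    [
        {'params': decay, 'weight_decay': 0.05},
        {'params': no_decay, 'weight_decay': 0.0},
    ],
    lr=1e-3,
)
```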
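Finally, the `get_classifier()` fixes matter because both models wrap their final projection in a composite head (`head.linear` for the MSRA variant's `NormLinear`, `head.fc` for TinyViT): returning `self.head` handed callers the whole wrapper rather than the `nn.Linear` that fine-tuning code typically inspects or replaces. A short sketch of the intended round-trip, model name assumed as above:

```python
import timm

model = timm.create_model('tiny_vit_21m_224', pretrained=False)

fc = model.get_classifier()  # now the inner projection, i.e. model.head.fc
print(type(fc), getattr(fc, 'out_features', None))

# Re-target the head for a 10-class task, then fetch the fresh projection.
model.reset_classifier(num_classes=10)
print(model.get_classifier().out_features)  # 10
```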