mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
A few more CrossViT tweaks: fix training with no_weight_decay names, add a crop option for scaling, and adjust the default crop_pct for large image sizes to 1.0 for better results
This commit is contained in:
parent
7ab2491ab7
commit
f8a215cfe6
@ -40,7 +40,7 @@ from .vision_transformer import Mlp, Block
|
|||||||
def _cfg(url='', **kwargs):
|
def _cfg(url='', **kwargs):
|
||||||
return {
|
return {
|
||||||
'url': url,
|
'url': url,
|
||||||
'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None,
|
'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875,
|
||||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
|
||||||
'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
|
'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
|
||||||
'classifier': ('head.0', 'head.1'),
|
'classifier': ('head.0', 'head.1'),
|
||||||
@ -56,7 +56,7 @@ default_cfgs = {
|
|||||||
),
|
),
|
||||||
'crossvit_15_dagger_408': _cfg(
|
'crossvit_15_dagger_408': _cfg(
|
||||||
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth',
|
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth',
|
||||||
input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
|
input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
|
||||||
),
|
),
|
||||||
'crossvit_18_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'),
|
'crossvit_18_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'),
|
||||||
'crossvit_18_dagger_240': _cfg(
|
'crossvit_18_dagger_240': _cfg(
|
||||||
@ -65,7 +65,7 @@ default_cfgs = {
|
|||||||
),
|
),
|
||||||
'crossvit_18_dagger_408': _cfg(
|
'crossvit_18_dagger_408': _cfg(
|
||||||
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth',
|
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth',
|
||||||
input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
|
input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
|
||||||
),
|
),
|
||||||
'crossvit_9_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'),
|
'crossvit_9_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'),
|
||||||
'crossvit_9_dagger_240': _cfg(
|
'crossvit_9_dagger_240': _cfg(
|
||||||
@ -263,7 +263,7 @@ class CrossViT(nn.Module):
|
|||||||
self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000,
|
self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000,
|
||||||
embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.),
|
embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.),
|
||||||
qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
||||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False, crop_scale=False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -271,6 +271,7 @@ class CrossViT(nn.Module):
|
|||||||
self.img_size = to_2tuple(img_size)
|
self.img_size = to_2tuple(img_size)
|
||||||
img_scale = to_2tuple(img_scale)
|
img_scale = to_2tuple(img_scale)
|
||||||
self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
|
self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
|
||||||
|
self.crop_scale = crop_scale # crop instead of interpolate for scale
|
||||||
num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
|
num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
|
||||||
self.num_branches = len(patch_size)
|
self.num_branches = len(patch_size)
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
@ -307,8 +308,7 @@ class CrossViT(nn.Module):
|
|||||||
for i in range(self.num_branches)])
|
for i in range(self.num_branches)])
|
||||||
|
|
||||||
for i in range(self.num_branches):
|
for i in range(self.num_branches):
|
||||||
if hasattr(self, f'pos_embed_{i}'):
|
trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
|
||||||
trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
|
|
||||||
trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
|
trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
|
||||||
|
|
||||||
self.apply(self._init_weights)
|
self.apply(self._init_weights)
|
||||||
@ -324,9 +324,12 @@ class CrossViT(nn.Module):
|
|||||||
|
|
||||||
@torch.jit.ignore
|
@torch.jit.ignore
|
||||||
def no_weight_decay(self):
|
def no_weight_decay(self):
|
||||||
out = {'cls_token'}
|
out = set()
|
||||||
if self.pos_embed[0].requires_grad:
|
for i in range(self.num_branches):
|
||||||
out.add('pos_embed')
|
out.add(f'cls_token_{i}')
|
||||||
|
pe = getattr(self, f'pos_embed_{i}', None)
|
||||||
|
if pe is not None and pe.requires_grad:
|
||||||
|
out.add(f'pos_embed_{i}')
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def get_classifier(self):
|
def get_classifier(self):
|
||||||
@ -342,23 +345,29 @@ class CrossViT(nn.Module):
|
|||||||
B, C, H, W = x.shape
|
B, C, H, W = x.shape
|
||||||
xs = []
|
xs = []
|
||||||
for i, patch_embed in enumerate(self.patch_embed):
|
for i, patch_embed in enumerate(self.patch_embed):
|
||||||
|
x_ = x
|
||||||
ss = self.img_size_scaled[i]
|
ss = self.img_size_scaled[i]
|
||||||
x_ = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False) if H != ss[0] else x
|
if H != ss[0] or W != ss[1]:
|
||||||
tmp = patch_embed(x_)
|
if self.crop_scale and ss[0] <= H and ss[1] <= W:
|
||||||
|
cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
|
||||||
|
x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]]
|
||||||
|
else:
|
||||||
|
x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False)
|
||||||
|
x_ = patch_embed(x_)
|
||||||
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
|
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
|
||||||
cls_tokens = cls_tokens.expand(B, -1, -1)
|
cls_tokens = cls_tokens.expand(B, -1, -1)
|
||||||
tmp = torch.cat((cls_tokens, tmp), dim=1)
|
x_ = torch.cat((cls_tokens, x_), dim=1)
|
||||||
pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1 # hard-coded for torch jit script
|
pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1 # hard-coded for torch jit script
|
||||||
tmp = tmp + pos_embed
|
x_ = x_ + pos_embed
|
||||||
tmp = self.pos_drop(tmp)
|
x_ = self.pos_drop(x_)
|
||||||
xs.append(tmp)
|
xs.append(x_)
|
||||||
|
|
||||||
for i, blk in enumerate(self.blocks):
|
for i, blk in enumerate(self.blocks):
|
||||||
xs = blk(xs)
|
xs = blk(xs)
|
||||||
|
|
||||||
# NOTE: was before branch token section, move to here to assure all branch token are before layer norm
|
# NOTE: was before branch token section, move to here to assure all branch token are before layer norm
|
||||||
xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
|
xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
|
||||||
return [x[:, 0] for x in xs]
|
return [xo[:, 0] for xo in xs]
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
xs = self.forward_features(x)
|
xs = self.forward_features(x)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user