mirror of https://github.com/huggingface/pytorch-image-models.git
Add missing deprecation mapping for a densenet and xcit model. Fix #2086. Tweak xcit pos embed use of arange for better low prec safety.
parent 809a9e14e2
commit 3234daf783
timm/models/densenet.py
@@ -15,7 +15,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
 from ._builder import build_model_with_cfg
 from ._manipulate import MATCH_PREV_GROUP
-from ._registry import register_model, generate_default_cfgs
+from ._registry import register_model, generate_default_cfgs, register_model_deprecations
 
 __all__ = ['DenseNet']
 
@@ -415,3 +415,7 @@ def densenet264d(pretrained=False, **kwargs) -> DenseNet:
     model = _create_densenet('densenet264d', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
+
+register_model_deprecations(__name__, {
+    'tv_densenet121': 'densenet121.tv_in1k',
+})
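The mapping above keeps the legacy `tv_densenet121` name usable. A minimal sketch of the intended behavior, assuming the standard `timm.create_model` entry point (the deprecation warning text is paraphrased, not verbatim):

import timm

# The deprecated alias still resolves, emitting a deprecation warning
# that points at the replacement name:
model = timm.create_model('tv_densenet121', pretrained=False)

# ...which is now equivalent to the canonical architecture.tag spelling:
model = timm.create_model('densenet121.tv_in1k', pretrained=False)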
timm/models/xcit.py
@@ -48,18 +48,19 @@ class PositionalEncodingFourier(nn.Module):
 
     def forward(self, B: int, H: int, W: int):
         device = self.token_projection.weight.device
-        y_embed = torch.arange(1, H+1, dtype=torch.float32, device=device).unsqueeze(1).repeat(1, 1, W)
-        x_embed = torch.arange(1, W+1, dtype=torch.float32, device=device).repeat(1, H, 1)
+        dtype = self.token_projection.weight.dtype
+        y_embed = torch.arange(1, H + 1, device=device).to(torch.float32).unsqueeze(1).repeat(1, 1, W)
+        x_embed = torch.arange(1, W + 1, device=device).to(torch.float32).repeat(1, H, 1)
         y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
         x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
-        dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=device)
+        dim_t = torch.arange(self.hidden_dim, device=device).to(torch.float32)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
         pos_x = torch.stack([pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], dim=4).flatten(3)
         pos_y = torch.stack([pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()], dim=4).flatten(3)
         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
-        pos = self.token_projection(pos)
+        pos = self.token_projection(pos.to(dtype))
         return pos.repeat(B, 1, 1, 1)  # (B, C, H, W)
 
 
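This hunk makes two low-precision tweaks: position ranges are built as integer `arange` and only then cast to float32, and the finished positional tensor is cast to the projection weight's dtype before the 1x1 conv. A minimal sketch of both hazards, assuming bfloat16 as the low-precision dtype (the conv and sizes here are illustrative, not from the commit; fp16 on GPU behaves analogously):

import torch
import torch.nn as nn

# 1) Exactness: float16 cannot represent every integer above 2048, so a
#    float16 arange over a large range produces colliding positions; an
#    integer arange cast to float32 afterwards stays exact.
h = 4096
lossy = torch.arange(1, h + 1, dtype=torch.float16)
exact = torch.arange(1, h + 1).to(torch.float32)
print((lossy[1:] == lossy[:-1]).any())   # tensor(True): duplicated positions
print((exact[1:] == exact[:-1]).any())   # tensor(False): all distinct

# 2) Dtype mismatch: once a model is converted with .to(torch.bfloat16),
#    feeding the float32 sin/cos tensor straight into the projection
#    raises a dtype-mismatch RuntimeError; casting the input to the
#    weight dtype first (as pos.to(dtype) does in the diff) is valid.
proj = nn.Conv2d(64, 128, kernel_size=1).to(torch.bfloat16)
pos = torch.randn(1, 64, 7, 7)              # computed in float32
out = proj(pos.to(proj.weight.dtype))       # mirrors pos.to(dtype) above
print(out.dtype)                            # torch.bfloat16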
@@ -890,6 +891,7 @@ register_model_deprecations(__name__, {
     'xcit_small_12_p16_224_dist': 'xcit_small_12_p16_224.fb_dist_in1k',
     'xcit_small_12_p16_384_dist': 'xcit_small_12_p16_384.fb_dist_in1k',
     'xcit_small_24_p16_224_dist': 'xcit_small_24_p16_224.fb_dist_in1k',
+    'xcit_small_24_p16_384_dist': 'xcit_small_24_p16_384.fb_dist_in1k',
     'xcit_medium_24_p16_224_dist': 'xcit_medium_24_p16_224.fb_dist_in1k',
     'xcit_medium_24_p16_384_dist': 'xcit_medium_24_p16_384.fb_dist_in1k',
     'xcit_large_24_p16_224_dist': 'xcit_large_24_p16_224.fb_dist_in1k',
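Per the commit message, this entry is the missing xcit mapping behind #2086: presumably the legacy `xcit_small_24_p16_384_dist` name failed model lookup without it. A short sketch of the restored behavior (warning text paraphrased):

import timm

# Previously unresolvable; now maps, with a deprecation warning, to
# 'xcit_small_24_p16_384.fb_dist_in1k':
model = timm.create_model('xcit_small_24_p16_384_dist', pretrained=False)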