Add missing deprecation mappings for a densenet and an xcit model. Fix #2086. Tweak xcit pos embed use of arange for better low prec safety.

pull/2089/head
Ross Wightman 2024-01-24 22:04:04 -08:00
parent 809a9e14e2
commit 3234daf783
2 changed files with 11 additions and 5 deletions

timm/models/densenet.py

@@ -15,7 +15,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
 from ._builder import build_model_with_cfg
 from ._manipulate import MATCH_PREV_GROUP
-from ._registry import register_model, generate_default_cfgs
+from ._registry import register_model, generate_default_cfgs, register_model_deprecations

 __all__ = ['DenseNet']
@@ -415,3 +415,7 @@ def densenet264d(pretrained=False, **kwargs) -> DenseNet:
     model = _create_densenet('densenet264d', pretrained=pretrained, **dict(model_args, **kwargs))
     return model

+
+register_model_deprecations(__name__, {
+    'tv_densenet121': 'densenet121.tv_in1k',
+})
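With this mapping registered, the legacy torchvision-style name should resolve through timm's deprecation handling instead of raising an unknown-model error. A minimal sketch of the expected behavior (assuming the usual create_model lookup path, not verified against this exact revision):

import timm

# The old alias should map to the new architecture + pretrained-tag form
# (densenet121.tv_in1k) with a deprecation warning, rather than failing.
model = timm.create_model('tv_densenet121', pretrained=False)
print(type(model).__name__)  # expect: DenseNet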

timm/models/xcit.py

@@ -48,18 +48,19 @@ class PositionalEncodingFourier(nn.Module):
     def forward(self, B: int, H: int, W: int):
         device = self.token_projection.weight.device
-        y_embed = torch.arange(1, H+1, dtype=torch.float32, device=device).unsqueeze(1).repeat(1, 1, W)
-        x_embed = torch.arange(1, W+1, dtype=torch.float32, device=device).repeat(1, H, 1)
+        dtype = self.token_projection.weight.dtype
+        y_embed = torch.arange(1, H + 1, device=device).to(torch.float32).unsqueeze(1).repeat(1, 1, W)
+        x_embed = torch.arange(1, W + 1, device=device).to(torch.float32).repeat(1, H, 1)
         y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
         x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
-        dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=device)
+        dim_t = torch.arange(self.hidden_dim, device=device).to(torch.float32)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
         pos_x = torch.stack([pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], dim=4).flatten(3)
         pos_y = torch.stack([pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()], dim=4).flatten(3)
         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
-        pos = self.token_projection(pos)
+        pos = self.token_projection(pos.to(dtype))
         return pos.repeat(B, 1, 1, 1)  # (B, C, H, W)
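The arange change looks cosmetic but targets low-precision safety: building the index tensor in the default integer dtype and casting to float32 keeps the values exact, whereas an arange materialized directly in a float dtype can be silently downcast to fp16/bf16 by dtype-rewriting wrappers, and fp16 cannot represent consecutive integers beyond 2048. Casting pos to the projection weight's dtype only at the end keeps the sin/cos math in float32 even when the module itself runs in half precision. A standalone sketch of the failure mode (illustrative, not from the patch):

import torch

# fp16 has a 10-bit mantissa, so consecutive integers are exact only up
# to 2048; above that, odd values round to the nearest even integer.
n = 4096
direct = torch.arange(n, dtype=torch.float16).to(torch.float32)  # lossy path
safe = torch.arange(n).to(torch.float32)  # int64 -> float32, exact here
print((direct - safe).abs().max())  # tensor(1.) once n exceeds 2048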
@@ -890,6 +891,7 @@ register_model_deprecations(__name__, {
     'xcit_small_12_p16_224_dist': 'xcit_small_12_p16_224.fb_dist_in1k',
     'xcit_small_12_p16_384_dist': 'xcit_small_12_p16_384.fb_dist_in1k',
     'xcit_small_24_p16_224_dist': 'xcit_small_24_p16_224.fb_dist_in1k',
+    'xcit_small_24_p16_384_dist': 'xcit_small_24_p16_384.fb_dist_in1k',
     'xcit_medium_24_p16_224_dist': 'xcit_medium_24_p16_224.fb_dist_in1k',
     'xcit_medium_24_p16_384_dist': 'xcit_medium_24_p16_384.fb_dist_in1k',
     'xcit_large_24_p16_224_dist': 'xcit_large_24_p16_224.fb_dist_in1k',
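A quick regression check for the alias this hunk adds (assumed behavior of the deprecation registry: the old name resolves with a warning rather than raising):

import timm

# Previously this name hit an unknown-model error; with the mapping it
# should resolve to xcit_small_24_p16_384.fb_dist_in1k.
model = timm.create_model('xcit_small_24_p16_384_dist', pretrained=False)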