cae config fix

parent fb763b7096
commit f24e1f9bcf
@@ -630,17 +630,17 @@ def _load_pretrained(pretrained,
                      model,
                      model_keys,
                      model_ema_configs,
-                     abs_pos_emb,
-                     rel_pos_bias,
+                     use_abs_pos_emb,
+                     use_rel_pos_bias,
                      use_ssld=False):
     if pretrained is False:
-        pass
+        return
     elif pretrained is True:
         local_weight_path = get_weights_path_from_url(pretrained_url).replace(
             ".pdparams", "")
         checkpoint = paddle.load(local_weight_path + ".pdparams")
     elif isinstance(pretrained, str):
-        checkpoint = paddle.load(local_weight_path + ".pdparams")
+        checkpoint = paddle.load(pretrained + ".pdparams")
 
     checkpoint_model = None
     for model_key in model_keys.split('|'):
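
This hunk fixes two things: the `pretrained is False` branch now returns early instead of falling through with `pass`, and the `isinstance(pretrained, str)` branch loads from the caller-supplied path instead of `local_weight_path`, a variable only bound in the `pretrained is True` branch (referencing it here raised UnboundLocalError). A minimal sketch of the corrected dispatch, using a hypothetical `load_checkpoint` helper; `get_weights_path_from_url` is assumed to be Paddle's download utility, as in the surrounding code:

    import paddle
    from paddle.utils.download import get_weights_path_from_url

    def load_checkpoint(pretrained, pretrained_url):
        # Sketch of the fixed dispatch; returns None when no weights are wanted.
        if pretrained is False:
            return None  # early return instead of the old fall-through `pass`
        if pretrained is True:
            # download (or reuse the cached copy of) the hosted weights
            local_weight_path = get_weights_path_from_url(pretrained_url).replace(
                ".pdparams", "")
            return paddle.load(local_weight_path + ".pdparams")
        if isinstance(pretrained, str):
            # the fix: load from the given path; `local_weight_path` is
            # unbound on this branch, so the old line raised at runtime
            return paddle.load(pretrained + ".pdparams")
        raise ValueError("pretrained must be a bool or a checkpoint path")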
@@ -693,10 +693,10 @@ def _load_pretrained(pretrained,
         if "relative_position_index" in key:
             checkpoint_model.pop(key)
 
-        if "relative_position_bias_table" in key and rel_pos_bias:
+        if "relative_position_bias_table" in key and use_rel_pos_bias:
             rel_pos_bias = checkpoint_model[key]
-            src_num_pos, num_attn_heads = rel_pos_bias.size()
-            dst_num_pos, _ = model.state_dict()[key].size()
+            src_num_pos, num_attn_heads = rel_pos_bias.shape
+            dst_num_pos, _ = model.state_dict()[key].shape
             dst_patch_shape = model.patch_embed.patch_shape
             if dst_patch_shape[0] != dst_patch_shape[1]:
                 raise NotImplementedError()
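
The `.size()` to `.shape` edits are a torch-to-paddle port fix: in PyTorch, `tensor.size()` is a method returning a `torch.Size`, while a paddle `Tensor` exposes `.shape` as a plain list attribute (its `.size` is the total element count). A small illustration:

    import paddle

    table = paddle.rand([732, 12])             # stand-in relative_position_bias_table
    src_num_pos, num_attn_heads = table.shape  # .shape is a list attribute in paddle
    print(src_num_pos, num_attn_heads)         # 732 12
    print(table.size)                          # 8784: element count, not a method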
@@ -742,8 +742,8 @@ def _load_pretrained(pretrained,
                         src_size).float().numpy()
                     f = interpolate.interp2d(x, y, z, kind='cubic')
                     all_rel_pos_bias.append(
-                        paddle.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(
-                            rel_pos_bias.device))
+                        paddle.Tensor(f(dx, dy)).astype('float32').reshape(
+                            [-1, 1]))
 
                 rel_pos_bias = paddle.concat(all_rel_pos_bias, axis=-1)
 
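
The interpolation itself is unchanged (SciPy cubic `interp2d` per attention head); the fix swaps the torch-only chain `contiguous().view(-1, 1).to(device)` for paddle's `astype`/`reshape`. A hedged sketch of resizing one head's bias grid; the simplified `x`, `y`, `dx`, `dy` coordinates stand in for the geometric-progression coordinates computed earlier in `_load_pretrained`, and `paddle.to_tensor` is used in place of the `paddle.Tensor(...)` constructor seen in the hunk:

    import numpy as np
    import paddle
    from scipy import interpolate

    src_size, dst_size = 13, 27                       # source and target grid sides
    x = y = np.arange(src_size, dtype=np.float64)     # source coordinates (simplified)
    dx = dy = np.linspace(0, src_size - 1, dst_size)  # target coordinates

    z = np.random.rand(src_size, src_size)            # one head's bias values
    f = interpolate.interp2d(x, y, z, kind='cubic')   # as in the hunk (removed in SciPy 1.14)
    resized = paddle.to_tensor(f(dx, dy)).astype('float32').reshape([-1, 1])
    print(resized.shape)                              # [729, 1] == [dst_size**2, 1]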
@@ -752,7 +752,7 @@ def _load_pretrained(pretrained,
             checkpoint_model[key] = new_rel_pos_bias
 
     # interpolate position embedding
-    if 'pos_embed' in checkpoint_model and abs_pos_emb:
+    if 'pos_embed' in checkpoint_model and use_abs_pos_emb:
         pos_embed_checkpoint = checkpoint_model['pos_embed']
         embedding_size = pos_embed_checkpoint.shape[-1]
         num_patches = model.patch_embed.num_patches
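
Only the guard flag is renamed here; the branch performs the usual absolute position-embedding interpolation (keep the class/extra tokens, resize the patch tokens as a 2D grid). A minimal paddle sketch of that standard procedure, with a hypothetical `resize_pos_embed` helper and assumed shapes:

    import paddle
    import paddle.nn.functional as F

    def resize_pos_embed(pos_embed, num_extra_tokens, dst_size):
        # pos_embed: [1, num_extra_tokens + src_size**2, dim]
        dim = pos_embed.shape[-1]
        src_size = int((pos_embed.shape[1] - num_extra_tokens) ** 0.5)
        extra = pos_embed[:, :num_extra_tokens]                    # e.g. the cls token
        grid = pos_embed[:, num_extra_tokens:].reshape(
            [1, src_size, src_size, dim]).transpose([0, 3, 1, 2])  # NHWC -> NCHW
        grid = F.interpolate(grid, size=(dst_size, dst_size),
                             mode='bicubic', align_corners=False)
        grid = grid.transpose([0, 2, 3, 1]).reshape([1, dst_size * dst_size, dim])
        return paddle.concat([extra, grid], axis=1)

    new_pe = resize_pos_embed(paddle.rand([1, 1 + 14 * 14, 768]), 1, 16)
    print(new_pe.shape)  # [1, 257, 768]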
@@ -791,8 +791,8 @@ def cae_base_patch16_224(pretrained=True, use_ssld=False, **kwargs):
     enable_linear_eval = config.pop('enable_linear_eval')
     model_keys = config.pop('model_key')
     model_ema_configs = config.pop('model_ema')
-    abs_pos_emb = config.pop('abs_pos_emb')
-    rel_pos_bias = config.pop('rel_pos_bias')
+    use_abs_pos_emb = config.get('use_abs_pos_emb', False)
+    use_rel_pos_bias = config.get('use_rel_pos_bias', True)
     if pretrained in config:
         pretrained = config.pop('pretrained')
 
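
The builder now reads the flags with `config.get(...)` and defaults (`use_abs_pos_emb=False`, `use_rel_pos_bias=True`) instead of `config.pop(...)`, which is what lets the config hunks below drop the corresponding YAML keys. One oddity survives as context: `if pretrained in config:` tests whether the value of `pretrained` is among the dict keys; the intent is presumably the string test `if 'pretrained' in config:`. A small sketch of the new plumbing, assuming a plain dict config:

    config = {'model_key': 'model|module|state_dict'}  # flags absent from the YAML

    use_abs_pos_emb = config.get('use_abs_pos_emb', False)  # safe default
    use_rel_pos_bias = config.get('use_rel_pos_bias', True)
    assert (use_abs_pos_emb, use_rel_pos_bias) == (False, True)

    # The intended guard is a string-membership test:
    if 'pretrained' in config:
        pretrained = config.pop('pretrained')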
@@ -816,8 +816,8 @@ def cae_base_patch16_224(pretrained=True, use_ssld=False, **kwargs):
         model,
         model_keys,
         model_ema_configs,
-        abs_pos_emb,
-        rel_pos_bias,
+        use_abs_pos_emb,
+        use_rel_pos_bias,
         use_ssld=False)
 
     return model
@@ -828,8 +828,8 @@ def cae_large_patch16_224(pretrained=True, use_ssld=False, **kwargs):
     enable_linear_eval = config.pop('enable_linear_eval')
     model_keys = config.pop('model_key')
     model_ema_configs = config.pop('model_ema')
-    abs_pos_emb = config.pop('abs_pos_emb')
-    rel_pos_bias = config.pop('rel_pos_bias')
+    use_abs_pos_emb = config.get('use_abs_pos_emb', False)
+    use_rel_pos_bias = config.get('use_rel_pos_bias', True)
     if pretrained in config:
         pretrained = config.pop('pretrained')
 
@@ -853,8 +853,8 @@ def cae_large_patch16_224(pretrained=True, use_ssld=False, **kwargs):
         model,
         model_keys,
         model_ema_configs,
-        abs_pos_emb,
-        rel_pos_bias,
+        use_abs_pos_emb,
+        use_rel_pos_bias,
         use_ssld=False)
 
     return model

@@ -31,10 +31,8 @@ Arch:
 
     sin_pos_emb: True
 
-    abs_pos_emb: False
     enable_linear_eval: False
     model_key: model|module|state_dict
-    rel_pos_bias: True
     model_ema:
       enable_model_ema: False
       model_ema_decay: 0.9999
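
Dropping `abs_pos_emb` and `rel_pos_bias` from the YAML is only safe together with the code change above: `dict.pop` on a missing key raises `KeyError`, while the new `config.get` call falls back to its default. A minimal illustration:

    config = {'model_key': 'model|module|state_dict'}  # keys removed from the YAML

    try:
        config.pop('abs_pos_emb')       # old reader: breaks on the trimmed config
    except KeyError:
        pass

    assert config.get('use_abs_pos_emb', False) is False  # new reader: default wins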
@@ -83,23 +81,27 @@ DataLoader:
         - DecodeImage:
             to_rgb: True
             channel_first: False
-        - RandCropImage:
+        - RandomResizedCrop:
             size: 224
-        - RandFlipImage:
-            flip_code: 1
-        - RandAugment:
+            interpolation: bilinear
+        - RandomHorizontalFlip:
+            prob: 0.5
+        - TimmAutoAugment:
+            config_str: rand-m9-mstd0.5-inc1
+            interpolation: bicubic
+            img_size: 224
         - NormalizeImage:
             scale: 1.0/255.0
-            mean: [0.485, 0.456, 0.406]
-            std: [0.229, 0.224, 0.225]
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
             order: ''
         - RandomErasing:
-            EPSILON: 0.5
+            EPSILON: 0.25
             sl: 0.02
-            sh: 0.3
+            sh: 1.0/3.0
             r1: 0.3
+            attempt: 10
+            use_log_aspect: True
+            mode: pixel
 
     sampler:
       name: DistributedBatchSampler
       batch_size: 16
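
The training pipeline moves from the previous operators (RandCropImage, RandFlipImage, RandAugment, ImageNet mean/std) to a timm-style finetuning recipe: bilinear RandomResizedCrop, horizontal flip with p=0.5, `rand-m9-mstd0.5-inc1` auto-augment, 0.5/0.5 normalization, and pixel-mode RandomErasing with p=0.25 and a max erased-area ratio of 1/3. For orientation, a rough timm equivalent (an illustrative swap; PaddleClas builds these ops directly from the YAML):

    from timm.data import create_transform

    transform = create_transform(
        input_size=224,
        is_training=True,
        hflip=0.5,                            # RandomHorizontalFlip prob
        auto_augment='rand-m9-mstd0.5-inc1',  # TimmAutoAugment config_str
        interpolation='bicubic',
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        re_prob=0.25,                         # RandomErasing EPSILON
        re_mode='pixel',                      # RandomErasing mode
    )
    print(transform)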
@@ -110,7 +112,7 @@ DataLoader:
       use_shared_memory: True
 
   Eval:
-    dataset:
+    dataset:
       name: ImageNetDataset
       image_root: ./dataset/flowers102/
       cls_label_path: ./dataset/flowers102/val_list.txt
@@ -124,8 +126,8 @@ DataLoader:
             size: 224
         - NormalizeImage:
             scale: 1.0/255.0
-            mean: [0.485, 0.456, 0.406]
-            std: [0.229, 0.224, 0.225]
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
             order: ''
     sampler:
       name: DistributedBatchSampler

@@ -31,10 +31,8 @@ Arch:
 
     sin_pos_emb: True
 
-    abs_pos_emb: False
     enable_linear_eval: False
     model_key: model|module|state_dict
-    rel_pos_bias: True
     model_ema:
       enable_model_ema: False
       model_ema_decay: 0.9999
@@ -83,23 +81,27 @@ DataLoader:
         - DecodeImage:
             to_rgb: True
             channel_first: False
-        - RandCropImage:
+        - RandomResizedCrop:
             size: 224
-        - RandFlipImage:
-            flip_code: 1
-        - RandAugment:
+            interpolation: bilinear
+        - RandomHorizontalFlip:
+            prob: 0.5
+        - TimmAutoAugment:
+            config_str: rand-m9-mstd0.5-inc1
+            interpolation: bicubic
+            img_size: 224
         - NormalizeImage:
             scale: 1.0/255.0
-            mean: [0.485, 0.456, 0.406]
-            std: [0.229, 0.224, 0.225]
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
             order: ''
         - RandomErasing:
-            EPSILON: 0.5
+            EPSILON: 0.25
             sl: 0.02
-            sh: 0.3
+            sh: 1.0/3.0
             r1: 0.3
+            attempt: 10
+            use_log_aspect: True
+            mode: pixel
 
     sampler:
       name: DistributedBatchSampler
       batch_size: 16
@@ -110,7 +112,7 @@ DataLoader:
       use_shared_memory: True
 
   Eval:
-    dataset:
+    dataset:
       name: ImageNetDataset
       image_root: ./dataset/flowers102/
       cls_label_path: ./dataset/flowers102/val_list.txt
@@ -124,8 +126,8 @@ DataLoader:
             size: 224
         - NormalizeImage:
             scale: 1.0/255.0
-            mean: [0.485, 0.456, 0.406]
-            std: [0.229, 0.224, 0.225]
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
             order: ''
     sampler:
       name: DistributedBatchSampler