Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Add naflex loader support to validate.py, fix bug in naflex pos embed add, classic vit weight loading for naflex model
This commit is contained in:
parent c527c37969
commit 3dc90ed7a7
@@ -760,7 +760,6 @@ def patchify(
    nh, nw = h // ph, w // pw
    # Reshape image to patches [nh, nw, ph, pw, c]
    patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0).reshape(nh * nw, ph * pw * c)
-
    if include_info:
        # Create coordinate indices
        y_idx, x_idx = torch.meshgrid(torch.arange(nh), torch.arange(nw), indexing='ij')
|
@@ -318,7 +318,7 @@ def transforms_imagenet_eval(
        tfl += [ResizeToSequence(
            patch_size=patch_size,
            max_seq_len=max_seq_len,
-            interpolation=interpolation
+            interpolation=interpolation,
        )]
    else:
        if crop_mode == 'squash':
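ResizeToSequence itself is not shown in this diff. A rough sketch of what such a transform plausibly does (assumed behavior, not the timm implementation): scale the image, preserving aspect ratio, so that its patch grid roughly fits the max_seq_len token budget, then snap each side to a multiple of the patch size.

import math
from PIL import Image

def resize_to_sequence(img, patch_size=16, max_seq_len=576, interpolation=Image.BICUBIC):
    # Hypothetical helper, for illustration only.
    w, h = img.size
    # Uniform scale so that (w*s/patch_size) * (h*s/patch_size) is approximately max_seq_len.
    scale = math.sqrt(max_seq_len * patch_size ** 2 / (w * h))
    new_w = max(patch_size, int(w * scale / patch_size) * patch_size)
    new_h = max(patch_size, int(h * scale / patch_size) * patch_size)
    return img.resize((new_w, new_h), interpolation)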
|
@@ -52,6 +52,7 @@ def batch_patchify(

    nh, nw = H // ph, W // pw
    patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
+    # FIXME confirm we want 'channels last' in the patch channel layout, e.g. ph, pw, C instead of C, ph, pw

    return patches, (nh, nw)

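The patchify/batch_patchify hunks above use the same view/permute/reshape trick. A small standalone check, not part of the commit, that this yields the channels-last (ph, pw, C) per-patch layout the FIXME asks about and agrees with torch.nn.functional.unfold (shapes are illustrative):

import torch
import torch.nn.functional as F

B, C, H, W = 2, 3, 32, 48
ph = pw = 16
x = torch.randn(B, C, H, W)

nh, nw = H // ph, W // pw
patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)

# F.unfold yields (B, C*ph*pw, L) with channels-first patches; rearrange to channels-last to compare.
ref = F.unfold(x, kernel_size=(ph, pw), stride=(ph, pw))           # (B, C*ph*pw, nh*nw)
ref = ref.transpose(1, 2).reshape(B, nh * nw, C, ph, pw)
ref = ref.permute(0, 1, 3, 4, 2).reshape(B, nh * nw, ph * pw * C)  # (B, L, ph*pw*C)

assert torch.allclose(patches, ref)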
@@ -297,7 +298,7 @@ class FlexEmbeds(nn.Module):
                size_to_indices[k].append(bi)

            # Handle each batch element separately with its own grid size
-            pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2)  # B,C,H,W
+            pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W
            for k, batch_indices in size_to_indices.items():
                h, w = k
                #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
@@ -312,9 +313,14 @@ class FlexEmbeds(nn.Module):
                    align_corners=False,
                    antialias=True,
                ).flatten(2).transpose(1, 2)
+                pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)

                seq_len = min(x.shape[1], pos_embed_flat.shape[1])
-                x[batch_indices, :seq_len].add_(pos_embed_flat[:, :seq_len])
+                x[:, :seq_len].index_add_(
+                    0,
+                    torch.as_tensor(batch_indices, device=x.device),
+                    pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1)
+                )

    def _apply_learned_pos_embed(
            self,
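This looks like the 'fix bug in naflex pos embed add' from the commit message: indexing x with a Python list of batch indices returns a copy, so the old in-place .add_() modified that copy and never wrote back into x, while index_add_ on the basic-slice view x[:, :seq_len] writes into x directly. A minimal sketch of the difference with made-up shapes:

import torch

x = torch.zeros(4, 6, 8)        # (B, N, C) token features
pos = torch.ones(1, 6, 8)       # flattened pos embed for one grid size
batch_indices = [1, 3]          # batch elements that share this grid size
seq_len = 6

# Old pattern: advanced indexing with a list makes a copy, so the add is silently lost.
x[batch_indices, :seq_len].add_(pos[:, :seq_len])
print(x.abs().sum())            # tensor(0.)

# New pattern: index_add_ writes in place along dim 0 of the view.
x[:, :seq_len].index_add_(
    0,
    torch.as_tensor(batch_indices, device=x.device),
    pos[:, :seq_len].expand(len(batch_indices), -1, -1),
)
print(x[1].sum(), x[0].sum())   # non-zero for batch index 1, still zero for index 0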
@@ -328,12 +334,13 @@ class FlexEmbeds(nn.Module):
        else:
            # Resize if needed - directly using F.interpolate
            pos_embed_flat = F.interpolate(
-                self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
+                self.pos_embed.permute(0, 3, 1, 2).float(),  # B,C,H,W
                size=grid_size,
                mode=self.pos_embed_interp_mode,
                align_corners=False,
                antialias=True,
            ).flatten(2).transpose(1, 2)
+            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)

        x.add_(pos_embed_flat)

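The .float() / .to(dtype=x.dtype) pair added in both interpolation paths appears to keep the position-embedding resize in float32 (bicubic resizing with antialias can be unsupported or lossy in half precision) and only cast back when adding to the tokens. A minimal sketch with illustrative shapes and an assumed 'bicubic' mode (the model actually uses self.pos_embed_interp_mode):

import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 14, 14, 768, dtype=torch.bfloat16)    # stored as (1, H, W, C)
x = torch.randn(1, 16 * 24, 768, dtype=torch.bfloat16)           # tokens for a 16x24 target grid
grid_size = (16, 24)

pos_embed_flat = F.interpolate(
    pos_embed.permute(0, 3, 1, 2).float(),  # to B,C,H,W and float32 for the resize
    size=grid_size,
    mode='bicubic',
    align_corners=False,
    antialias=True,
).flatten(2).transpose(1, 2)                # back to (1, H*W, C)
pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)  # match activation dtype before the add

x = x + pos_embed_flat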
@@ -806,21 +813,20 @@ class VisionTransformerFlex(nn.Module):
        # Apply the mask to extract only valid tokens
        x = x[:, self.num_prefix_tokens:]  # prefix tokens not included in pooling

+        patch_valid_float = patch_valid.to(x.dtype)
        if pool_type == 'avg':
-            # Compute masked average pooling
-            # Sum valid tokens and divide by count of valid tokens
-            masked_sums = (x * patch_valid.unsqueeze(-1).float()).sum(dim=1)
-            valid_counts = patch_valid.float().sum(dim=1, keepdim=True).clamp(min=1)
+            # Compute masked average pooling, sum valid tokens and divide by count of valid tokens
+            masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+            valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
            pooled = masked_sums / valid_counts
            return pooled
        elif pool_type == 'avgmax':
            # For avgmax, compute masked average and masked max
-            # For max, we set masked positions to large negative value
-            masked_sums = (x * patch_valid.unsqueeze(-1).float()).sum(dim=1)
-            valid_counts = patch_valid.float().sum(dim=1, keepdim=True).clamp(min=1)
+            masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+            valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
            masked_avg = masked_sums / valid_counts

-            # For max pooling with mask
+            # For max pooling we set masked positions to large negative value
            masked_x = x.clone()
            masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
            masked_max = masked_x.max(dim=1)[0]
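A minimal sketch (illustrative tensors) of the masked average pooling above for padded NaFlex sequences: padding tokens are zeroed out of the sum and excluded from the count, so the result matches pooling only the valid tokens.

import torch

x = torch.randn(2, 5, 4)                       # (B, N, C) patch tokens after prefix removal
patch_valid = torch.tensor([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 1]], dtype=torch.bool)

patch_valid_float = patch_valid.to(x.dtype)
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
pooled = masked_sums / valid_counts            # (B, C)

# Sample 0 has 3 valid tokens, so the masked average equals a plain mean over them.
assert torch.allclose(pooled[0], x[0, :3].mean(dim=0), atol=1e-6)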
@@ -915,6 +921,82 @@ def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
    return init_weights_vit_timm


+def checkpoint_filter_fn(state_dict, model):
+    """Handle state dict conversion from original ViT to the new version with combined embedding."""
+    from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
+
+    # Handle CombinedEmbed module pattern
+    out_dict = {}
+    for k, v in state_dict.items():
+        # Convert tokens and embeddings to combined_embed structure
+        if k == 'pos_embed':
+            # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
+            if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
+                num_cls_token = 0
+                num_reg_token = 0
+                if 'reg_token' in state_dict:
+                    num_reg_token = state_dict['reg_token'].shape[1]
+                if 'cls_token' in state_dict:
+                    num_cls_token = state_dict['cls_token'].shape[1]
+                num_prefix_tokens = num_cls_token + num_reg_token
+
+                # Original format is (1, N, C), need to reshape to (1, H, W, C)
+                num_patches = v.shape[1]
+                num_patches_no_prefix = num_patches - num_prefix_tokens
+                grid_size_no_prefix = math.sqrt(num_patches_no_prefix)
+                grid_size = math.sqrt(num_patches)
+                if (grid_size_no_prefix != grid_size and (
+                        grid_size_no_prefix.is_integer() and not grid_size.is_integer())):
+                    # make a decision, did the pos_embed of the original include the prefix tokens?
+                    num_patches = num_patches_no_prefix
+                    cls_token_emb = v[:, 0:num_cls_token]
+                    if cls_token_emb.numel():
+                        state_dict['cls_token'] += cls_token_emb
+                    reg_token_emb = v[:, num_cls_token:num_reg_token]
+                    if reg_token_emb.numel():
+                        state_dict['reg_token'] += reg_token_emb
+                    v = v[:, num_prefix_tokens:]
+                    grid_size = grid_size_no_prefix
+                grid_size = int(grid_size)
+
+                # Check if it's a perfect square for a standard grid
+                if grid_size * grid_size == num_patches:
+                    # Reshape from (1, N, C) to (1, H, W, C)
+                    v = v.reshape(1, grid_size, grid_size, v.shape[2])
+                else:
+                    # Not a square grid, we need to get the actual dimensions
+                    if hasattr(model.embeds.patch_embed, 'grid_size'):
+                        h, w = model.embeds.patch_embed.grid_size
+                        if h * w == num_patches:
+                            # We have the right dimensions
+                            v = v.reshape(1, h, w, v.shape[2])
+                        else:
+                            # Dimensions don't match, use interpolation
+                            _logger.warning(
+                                f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
+                                f"Using default initialization and will resize in forward pass."
+                            )
+                            # Keep v as is, the forward pass will handle resizing
+
+            out_dict['embeds.pos_embed'] = v
+        elif k == 'cls_token':
+            out_dict['embeds.cls_token'] = v
+        elif k == 'reg_token':
+            out_dict['embeds.reg_token'] = v
+        # Convert patch_embed.X to embeds.patch_embed.X
+        elif k.startswith('patch_embed.'):
+            suffix = k[12:]
+            if suffix == 'proj.weight':
+                # FIXME confirm patchify memory layout across use cases
+                v = v.permute(0, 2, 3, 1).flatten(1)
+            new_key = 'embeds.' + suffix
+            out_dict[new_key] = v
+        else:
+            out_dict[k] = v
+
+    return out_dict
+
+
def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
    return {
        'url': url,
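The proj.weight remap above is what makes 'classic vit weight loading' work: a non-overlapping Conv2d patch embed is just a linear map, and permute(0, 2, 3, 1).flatten(1) reorders its kernel to the channels-last (ph, pw, C) patch layout produced by the NaFlex patchify. A standalone check, not from the commit, with illustrative sizes:

import torch
import torch.nn.functional as F

C, D, ph, pw = 3, 8, 16, 16
conv_w = torch.randn(D, C, ph, pw)              # classic ViT patch_embed.proj.weight
lin_w = conv_w.permute(0, 2, 3, 1).flatten(1)   # (D, ph*pw*C), channels-last patch layout

img = torch.randn(1, C, 2 * ph, 3 * pw)
conv_out = F.conv2d(img, conv_w, stride=(ph, pw))      # (1, D, 2, 3)
conv_out = conv_out.flatten(2).transpose(1, 2)         # (1, 6, D)

# Channels-last patchify as in batch_patchify above.
B, _, H, W = img.shape
nh, nw = H // ph, W // pw
patches = img.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
lin_out = patches @ lin_w.T                            # (1, 6, D)

assert torch.allclose(conv_out, lin_out, atol=1e-4)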
@@ -936,6 +1018,26 @@ default_cfgs = generate_default_cfgs({
    'vit_naflex_base_patch16': _cfg(),
    'vit_naflex_base_patch16_gap': _cfg(),
    'vit_naflex_base_patch16_map': _cfg(),
+
+    # sbb model testing
+    'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r256_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 256, 256), crop_pct=1.0),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r384_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r448_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash'),
+
+    # traditional vit testing
+    'vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'),
+    'vit_naflex_base_patch8.augreg2_r224_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'),
})


@@ -948,10 +1050,22 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
    return model


+@register_model
+def vit_naflex_mediumd_patch16_reg4_gap(pretrained=False, **kwargs):
+    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
+        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
+    model = _create_vision_transformer_flex(
+        'vit_naflex_mediumd_patch16_reg4_gap', pretrained=pretrained, **model_args)
+    return model
+
+
@register_model
def vit_naflex_base_patch16(pretrained=False, **kwargs):
    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.

    This model supports:
    1. Variable aspect ratios and resolutions via patch coordinates
    2. Position embedding interpolation for arbitrary grid sizes
@@ -987,54 +1101,28 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
    return model


-def checkpoint_filter_fn(state_dict, model):
-    """Handle state dict conversion from original ViT to the new version with combined embedding."""
-    from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
-
-    # FIXME conversion of existing vit checkpoints has not been finished or tested
-
-    # Handle CombinedEmbed module pattern
-    out_dict = {}
-    for k, v in state_dict.items():
-        # Convert tokens and embeddings to combined_embed structure
-        if k == 'pos_embed':
-            # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
-            if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
-                # Original format is (1, N, C) - need to reshape to (1, H, W, C)
-                num_patches = v.shape[1]
-                grid_size = int(math.sqrt(num_patches))
-
-                # Check if it's a perfect square for a standard grid
-                if grid_size * grid_size == num_patches:
-                    # Reshape from (1, N, C) to (1, H, W, C)
-                    v = v.reshape(1, grid_size, grid_size, v.shape[2])
-                else:
-                    # Not a square grid, we need to get the actual dimensions
-                    if hasattr(model.embeds.patch_embed, 'grid_size'):
-                        h, w = model.embeds.patch_embed.grid_size
-                        if h * w == num_patches:
-                            # We have the right dimensions
-                            v = v.reshape(1, h, w, v.shape[2])
-                        else:
-                            # Dimensions don't match, use interpolation
-                            _logger.warning(
-                                f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
-                                f"Using default initialization and will resize in forward pass."
-                            )
-                            # Keep v as is, the forward pass will handle resizing
-
-            out_dict['embeds.pos_embed'] = v
-
-        elif k == 'cls_token':
-            out_dict['embeds.cls_token'] = v
-        elif k == 'reg_token':
-            out_dict['embeds.reg_token'] = v
-        # Convert patch_embed.X to embeds.patch_embed.X
-        elif k.startswith('patch_embed.'):
-            new_key = 'embeds.' + k[12:]
-            out_dict[new_key] = v
-        else:
-            out_dict[k] = v
-
-    # Call the original filter function to handle other patterns
-    return orig_filter_fn(out_dict, model)
+@register_model
+def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
+    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
+
+    This model supports:
+    1. Variable aspect ratios and resolutions via patch coordinates
+    2. Position embedding interpolation for arbitrary grid sizes
+    3. Explicit patch coordinates and valid token masking
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
+        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True, **kwargs)
+    model = _create_vision_transformer_flex(
+        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
+    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
+    model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
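Hedged usage sketch: once registered, the new variants resolve through timm.create_model like any other model name, and for the classic-ViT configs above pretrained weights would be remapped by the new checkpoint_filter_fn. The exact NaFlex forward-input format is defined elsewhere in the model file and is not shown here.

import timm

# Instantiate one of the newly registered NaFlex variants; pretrained=False keeps this
# sketch independent of hub weights.
model = timm.create_model('vit_naflex_base_patch16', pretrained=False)
model.eval()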
validate.py | 69 lines changed
@@ -158,6 +158,12 @@ parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
parser.add_argument('--retry', default=False, action='store_true',
                    help='Enable batch size decay & retry for single model validation')

+# NaFlex loader arguments
+parser.add_argument('--naflex-loader', action='store_true', default=False,
+                    help='Use NaFlex loader (Requires NaFlex compatible model)')
+parser.add_argument('--naflex-max-seq-len', type=int, default=576,
+                    help='Fixed maximum sequence length for NaFlex loader (validation)')
+

def validate(args):
    # might as well try to validate something
@@ -293,23 +299,43 @@ def validate(args):
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
-    loader = create_loader(
-        dataset,
-        input_size=data_config['input_size'],
-        batch_size=args.batch_size,
-        use_prefetcher=args.prefetcher,
-        interpolation=data_config['interpolation'],
-        mean=data_config['mean'],
-        std=data_config['std'],
-        num_workers=args.workers,
-        crop_pct=crop_pct,
-        crop_mode=data_config['crop_mode'],
-        crop_border_pixels=args.crop_border_pixels,
-        pin_memory=args.pin_mem,
-        device=device,
-        img_dtype=model_dtype or torch.float32,
-        tf_preprocessing=args.tf_preprocessing,
-    )
+    if args.naflex_loader:
+        from timm.data import create_naflex_loader
+        loader = create_naflex_loader(
+            dataset,
+            batch_size=args.batch_size,
+            use_prefetcher=args.prefetcher,
+            interpolation=data_config['interpolation'],
+            mean=data_config['mean'],
+            std=data_config['std'],
+            num_workers=args.workers,
+            crop_pct=crop_pct,
+            crop_mode=data_config['crop_mode'],
+            crop_border_pixels=args.crop_border_pixels,
+            pin_memory=args.pin_mem,
+            device=device,
+            img_dtype=model_dtype or torch.float32,
+            patch_size=16,  # Could be derived from model config
+            max_seq_len=args.naflex_max_seq_len,
+        )
+    else:
+        loader = create_loader(
+            dataset,
+            input_size=data_config['input_size'],
+            batch_size=args.batch_size,
+            use_prefetcher=args.prefetcher,
+            interpolation=data_config['interpolation'],
+            mean=data_config['mean'],
+            std=data_config['std'],
+            num_workers=args.workers,
+            crop_pct=crop_pct,
+            crop_mode=data_config['crop_mode'],
+            crop_border_pixels=args.crop_border_pixels,
+            pin_memory=args.pin_mem,
+            device=device,
+            img_dtype=model_dtype or torch.float32,
+            tf_preprocessing=args.tf_preprocessing,
+        )

    batch_time = AverageMeter()
    losses = AverageMeter()
@@ -345,10 +371,11 @@ def validate(args):
                real_labels.add_result(output)

            # measure accuracy and record loss
+            batch_size = output.shape[0]
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
-            losses.update(loss.item(), input.size(0))
-            top1.update(acc1.item(), input.size(0))
-            top5.update(acc5.item(), input.size(0))
+            losses.update(loss.item(), batch_size)
+            top1.update(acc1.item(), batch_size)
+            top5.update(acc5.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
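The switch from input.size(0) to output.shape[0] in the meters suggests the NaFlex loader yields dict-style inputs rather than a single image tensor, so the batch size is read from the model output instead. A minimal sketch under that assumption (key names are illustrative, not the timm API):

import torch

# Assumed shape of a NaFlex-style batch: padded patch sequences plus coordinate and
# validity tensors instead of a single (B, C, H, W) image tensor.
input = {
    'patches': torch.randn(8, 576, 768),
    'patch_coord': torch.zeros(8, 576, 2, dtype=torch.long),
    'patch_valid': torch.ones(8, 576, dtype=torch.bool),
}
output = torch.randn(8, 1000)   # stand-in for model(input) logits

batch_size = output.shape[0]    # valid whether the input is a tensor or a dict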
@@ -364,7 +391,7 @@ def validate(args):
                    batch_idx,
                    len(loader),
                    batch_time=batch_time,
-                    rate_avg=input.size(0) / batch_time.avg,
+                    rate_avg=batch_size / batch_time.avg,
                    loss=losses,
                    top1=top1,
                    top5=top5