Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Remove padding calc from pack, minor fixes
This commit is contained in:
commit f93083e2b2
parent d81f75b461
@@ -84,13 +84,21 @@ class PackedSequence:
         self.total_len += seq_len
         self.num_images += 1
 
-    def to_tensors(self, max_len, max_packed, return_mask=True):
+    def to_tensors(self, max_seq_len, max_num_seq):
+        """
+        Args:
+            max_seq_len: maximum sequence length (pad to this)
+            max_num_seq: maximum # of sequences (images) packed into one sequence (across the batch)
+
+        Returns:
+            Tuple of tensors for packed batch of images
+        """
         assert self.total_len > 0
-        assert max_len >= self.total_len
+        assert max_seq_len >= self.total_len
         device = self.tokens[-1].device
         dim = self.tokens[-1].shape[-1]
-        pad_len = max_len - self.total_len
-        seq_pad = max(0, max_packed - len(self.seq_lens))
+        pad_len = max_seq_len - self.total_len
+        seq_pad = max(0, max_num_seq - len(self.seq_lens))
         seq_lens = self.seq_lens + [0] * seq_pad if seq_pad else self.seq_lens
         seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=device)
         if pad_len:
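For readers skimming the hunk: the new signature drops `return_mask` and renames the limits to `max_seq_len` / `max_num_seq`. Below is a minimal, self-contained sketch of what the padding amounts to, not the class itself: `pos_indices` is omitted, the helper name `pad_pack` is invented, and zero-padding of the token tensor is an assumption based on the visible lines.

import torch

def pad_pack(tokens_list, seq_lens, max_seq_len, max_num_seq):
    # tokens_list: one (seq_len, dim) tensor per image already packed into this row
    total_len = sum(seq_lens)
    assert max_seq_len >= total_len
    device = tokens_list[-1].device
    dim = tokens_list[-1].shape[-1]

    # seq_ids 1..N for real images, 0 reserved for padding (matches `mask = seq_ids != 0` below)
    seq_ids = [torch.full((n,), i + 1, dtype=torch.int64, device=device)
               for i, n in enumerate(seq_lens)]

    pad_len = max_seq_len - total_len
    seq_pad = max(0, max_num_seq - len(seq_lens))
    seq_lens = seq_lens + [0] * seq_pad if seq_pad else seq_lens
    if pad_len:
        tokens_list = tokens_list + [torch.zeros(pad_len, dim, device=device)]
        seq_ids = seq_ids + [torch.zeros(pad_len, dtype=torch.int64, device=device)]

    tokens = torch.concat(tokens_list)                                   # (max_seq_len, dim)
    seq_ids = torch.concat(seq_ids)                                      # (max_seq_len,)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=device)  # (max_num_seq,)
    return tokens, seq_ids, seq_lens

# e.g. two images of 5 and 3 tokens padded to a 12-token row with room for 4 sequences
tokens, seq_ids, seq_lens = pad_pack([torch.randn(5, 8), torch.randn(3, 8)], [5, 3],
                                     max_seq_len=12, max_num_seq=4)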
@@ -104,9 +112,6 @@ class PackedSequence:
         tokens = torch.concat(tokens)
         pos_indices = torch.concat(pos_indices)
         seq_ids = torch.concat(seq_ids)
-        if return_mask:
-            mask = seq_ids != 0
-            return tokens, pos_indices, seq_ids, seq_lens, mask
         return tokens, pos_indices, seq_ids, seq_lens
 
 
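With the mask branch gone, a caller that still needs the padding mask can rebuild it from `seq_ids`, since id 0 marks pad positions; this is exactly what the deleted lines computed.

# assuming `packed` is a PackedSequence and the two limits are in scope
tokens, pos_indices, seq_ids, seq_lens = packed.to_tensors(max_seq_len, max_num_seq)
padding_mask = seq_ids != 0   # True at real tokens, False at padding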
@@ -173,7 +178,7 @@ def pack_images(
         max_packed = max(sequence.num_images, max_packed)
         next_pos += 1
 
-    tensors = [p.to_tensors(max_len=max_seq_len, max_packed=max_packed) for p in packed_sequences]
+    tensors = [p.to_tensors(max_seq_len=max_seq_len, max_num_seq=max_packed) for p in packed_sequences]
     o = [torch.stack(t) for t in zip(*tensors)]
     return tuple(o)
 
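The unchanged collation under the edited line is the usual transpose-and-stack idiom: each pack yields a tuple of equally padded tensors, `zip(*tensors)` regroups them per field, and `torch.stack` adds the batch dimension. A toy, self-contained illustration with invented shapes:

import torch

# two packs, each returning (tokens, seq_ids) padded to the same shapes
pack_a = (torch.zeros(16, 8), torch.zeros(16, dtype=torch.int64))
pack_b = (torch.ones(16, 8), torch.ones(16, dtype=torch.int64))

tensors = [pack_a, pack_b]
o = [torch.stack(t) for t in zip(*tensors)]   # regroup per field, then stack
print(o[0].shape, o[1].shape)                 # torch.Size([2, 16, 8]) torch.Size([2, 16])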
@@ -655,12 +660,12 @@ class VisionTransformerPacked(nn.Module):
 
     @torch.jit.ignore
     def no_weight_decay(self):
-        return {'embeds.pos_embed', 'embeds.cls_token'}
+        return {'pos_embed_h', 'pos_embed_w'}
 
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         return dict(
-            stem=r'^embeds',  # stem and embed
+            stem=r'^embeds',  # stem and embed  # FIXME correct when design finalized
             blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
         )
 
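For orientation, `group_matcher` is consumed by timm's parameter grouping: each regex buckets parameter names, and the `(99999,)` entry pins the final norm to the last group. A standalone sketch of that bucketing, not timm's actual matcher code, with illustrative parameter names:

import re

def group_of(name):
    if re.match(r'^embeds', name):
        return 'stem'
    m = re.match(r'^blocks\.(\d+)', name)
    if m:
        return f'block {m.group(1)}'
    if re.match(r'^norm', name):
        return 'final norm'   # the (99999,) entry keeps this after all blocks
    return 'other'

for p in ['embeds.proj.weight', 'blocks.0.attn.qkv.weight', 'blocks.11.mlp.fc1.bias', 'norm.weight']:
    print(p, '->', group_of(p))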
@@ -675,7 +680,7 @@ class VisionTransformerPacked(nn.Module):
     def reset_classifier(self, num_classes: int, global_pool=None):
         self.num_classes = num_classes
         if global_pool is not None:
-            assert global_pool in ('', 'avg', 'token')
+            assert global_pool in ('', 'avg', 'attn')
             self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
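After this change the head can be swapped with the new pool options; a brief usage sketch, assuming `model` is an already constructed `VisionTransformerPacked`:

model.reset_classifier(num_classes=10, global_pool='attn')   # 'token' is no longer accepted
model.reset_classifier(num_classes=0)                        # drop the head: becomes nn.Identity()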
@@ -693,7 +698,7 @@ class VisionTransformerPacked(nn.Module):
             tokens = tokens.unbind(0)
 
         if isinstance(tokens, (list, tuple)):
-            tokens, pos_indices, seq_ids, seq_lens, padding_mask = pack_images(
+            tokens, pos_indices, seq_ids, seq_lens = pack_images(
                 tokens,
                 self.patch_size,
                 max_grid_size=self.grid_size,
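Since `pack_images` no longer hands back a `padding_mask`, whatever the attention path needs is presumably derived from `seq_ids` downstream. One common construction for packed sequences, shown purely as an illustration of the technique rather than this model's code, is a block-diagonal mask that only lets tokens attend within their own image and never to padding:

import torch

seq_ids = torch.tensor([[1, 1, 1, 2, 2, 0, 0, 0]])          # one packed row: two images + padding

same_image = seq_ids.unsqueeze(2) == seq_ids.unsqueeze(1)   # (B, L, L): query and key share an image id
real_keys = (seq_ids != 0).unsqueeze(1)                     # never attend to pad positions
attn_mask = same_image & real_keys                          # True where attention is allowed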