Remove padding calc from pack, minor fixes

This commit is contained in:
Ross Wightman 2023-09-14 10:12:07 -07:00
parent d81f75b461
commit f93083e2b2

View File

@ -84,13 +84,21 @@ class PackedSequence:
self.total_len += seq_len
self.num_images += 1
def to_tensors(self, max_len, max_packed, return_mask=True):
def to_tensors(self, max_seq_len, max_num_seq):
"""
Args:
max_seq_len: maximum sequence length (pad to this)
max_num_seq: maximum # of sequences (images) packed into one sequence (across the batch)
Returns:
Tuple of tensors for packed batch of images
"""
assert self.total_len > 0
assert max_len >= self.total_len
assert max_seq_len >= self.total_len
device = self.tokens[-1].device
dim = self.tokens[-1].shape[-1]
pad_len = max_len - self.total_len
seq_pad = max(0, max_packed - len(self.seq_lens))
pad_len = max_seq_len - self.total_len
seq_pad = max(0, max_num_seq - len(self.seq_lens))
seq_lens = self.seq_lens + [0] * seq_pad if seq_pad else self.seq_lens
seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=device)
if pad_len:
@ -104,9 +112,6 @@ class PackedSequence:
tokens = torch.concat(tokens)
pos_indices = torch.concat(pos_indices)
seq_ids = torch.concat(seq_ids)
if return_mask:
mask = seq_ids != 0
return tokens, pos_indices, seq_ids, seq_lens, mask
return tokens, pos_indices, seq_ids, seq_lens
@ -173,7 +178,7 @@ def pack_images(
max_packed = max(sequence.num_images, max_packed)
next_pos += 1
tensors = [p.to_tensors(max_len=max_seq_len, max_packed=max_packed) for p in packed_sequences]
tensors = [p.to_tensors(max_seq_len=max_seq_len, max_num_seq=max_packed) for p in packed_sequences]
o = [torch.stack(t) for t in zip(*tensors)]
return tuple(o)
@ -655,12 +660,12 @@ class VisionTransformerPacked(nn.Module):
@torch.jit.ignore
def no_weight_decay(self):
return {'embeds.pos_embed', 'embeds.cls_token'}
return {'pos_embed_h', 'pos_embed_w'}
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^embeds', # stem and embed
stem=r'^embeds', # stem and embed # FIXME correct when design finalized
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@ -675,7 +680,7 @@ class VisionTransformerPacked(nn.Module):
def reset_classifier(self, num_classes: int, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'token')
assert global_pool in ('', 'avg', 'attn')
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@ -693,7 +698,7 @@ class VisionTransformerPacked(nn.Module):
tokens = tokens.unbind(0)
if isinstance(tokens, (list, tuple)):
tokens, pos_indices, seq_ids, seq_lens, padding_mask = pack_images(
tokens, pos_indices, seq_ids, seq_lens = pack_images(
tokens,
self.patch_size,
max_grid_size=self.grid_size,