Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Compact _convert_beit3 fn
This commit is contained in:
parent 38c5f3bc63
commit 2ca94a6ce4
@@ -1181,87 +1181,64 @@ def _convert_aimv2(
     return out_dict
 
 
-def _convert_beit3(state_dict: dict, model):
-    """
-    Turn a BEiT-3 checkpoint into a standard VisionTransformer state-dict.
-    """
-    if 'model' in state_dict:
-        state_dict = state_dict['model']
-
-    # Remove text and mask tokens (vision-only)
-    state_dict.pop('beit3.text_embed.weight', None)
-    state_dict.pop('beit3.vision_embed.mask_token', None)
-
-    # First pass: Apply all key transformations except qkv fusion
-    intermediate_dict = {}
-    for k, v in state_dict.items():
-        # Skip B branch weights (use only A branch)
-        if '.B.' in k:
-            continue
-
-        # Apply all BEiT3 key transformations in one go
-        if 'vision_embed.cls_token' in k:
-            k = 'cls_token'
-        else:
-            k = k.replace('beit3.', '')
-            k = k.replace('embed_positions.', 'pos_embed.')
-            k = k.replace('vision_embed.', 'patch_embed.')
-            k = k.replace('encoder.', '')
-            k = k.replace('layers.', 'blocks.')
-            k = k.replace('ffn.', 'mlp.')
-            k = k.replace('ffn_layernorm.', 'norm.')
-            k = k.replace('self_attn.', 'attn.')
-            k = k.replace('self_attn_layer_norm.', 'norm1.')
-            k = k.replace('final_layer_norm.', 'norm2.')
-            k = k.replace('inner_attn_ln', 'norm')  # Map inner attention LayerNorm to scale norm
-            k = k.replace('out_proj', 'proj')  # Map out_proj to proj
-            k = k.replace('A.', '')  # Remove A branch prefix
-
-        # Handle positional embedding - skip first 2 positions (BEiT3 starts from index 2)
-        if k == 'pos_embed.weight':
-            # BEiT3 pos_embed.weight has shape [num_patches + 3, embed_dim]
-            # We want [1, num_patches + 1, embed_dim] for standard ViT (cls token + patches)
-            intermediate_dict['pos_embed'] = v[2:].unsqueeze(0)  # Skip first 2 positions, add batch dim
-        else:
-            intermediate_dict[k] = v
-
-    # Second pass: Handle qkv fusion
-    out_dict = {}
-    processed_qkv = set()
-    for k, v in intermediate_dict.items():
-        # Handle attention projections - convert separate q,k,v to fused qkv
-        if re.match(r"blocks\.(\d+)\.attn\.[qkv]_proj\.(weight|bias)", k):
-            block_idx = re.search(r"blocks\.(\d+)", k).group(1)
-            param_type = re.search(r"\.(weight|bias)$", k).group(1)
-
-            # Only process once per block per parameter type
-            block_param_key = f"{block_idx}_{param_type}"
-            if block_param_key in processed_qkv:
-                continue
-
-            # Collect all three projections for this block
-            q_key = f"blocks.{block_idx}.attn.q_proj.{param_type}"
-            k_key = f"blocks.{block_idx}.attn.k_proj.{param_type}"
-            v_key = f"blocks.{block_idx}.attn.v_proj.{param_type}"
-
-            if all(key in intermediate_dict for key in [q_key, k_key, v_key]):
-                qkv_tensor = torch.cat([
-                    intermediate_dict[q_key],
-                    intermediate_dict[k_key],
-                    intermediate_dict[v_key]
-                ], dim=0)
-                out_dict[f"blocks.{block_idx}.attn.qkv.{param_type}"] = qkv_tensor
-                processed_qkv.add(block_param_key)
-                continue
-            else:
-                assert False
-        else:
-            out_dict[k] = v
-
-    return out_dict
+def _convert_beit3(
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+) -> Dict[str, torch.Tensor]:
+    """Convert BEiT3 weights to standard VisionTransformer format."""
+    import re
+    state_dict = state_dict.get("model", state_dict)  # unwrap if needed
+
+    # Prune unused
+    for k in ("beit3.text_embed.weight", "beit3.vision_embed.mask_token"):
+        state_dict.pop(k, None)
+
+    # Key renaming rules
+    rules = [
+        (r"beit3\.", ""),
+        (r"vision_embed\.cls_token", "cls_token"),
+        (r"vision_embed\.", "patch_embed."),
+        (r"embed_positions\.", "pos_embed."),
+        (r"encoder\.", ""),
+        (r"layers\.", "blocks."),
+        (r"ffn_layernorm\.", "norm."), (r"ffn\.", "mlp."),
+        (r"self_attn_layer_norm\.", "norm1."), (r"self_attn\.", "attn."),
+        (r"final_layer_norm\.", "norm2."),
+        (r"inner_attn_ln", "norm"),
+        (r"out_proj", "proj"),
+        (r"\.A\.", "."),
+    ]
+
+    # First pass, rename keys
+    tmp = {}
+    for k, v in state_dict.items():
+        if ".B." in k:
+            continue  # use branch-A only
+        for old, new in rules:
+            k = re.sub(old, new, k)
+        if k == "pos_embed.weight":
+            # strip first two positions, [1, N+1, D]
+            tmp["pos_embed"] = v[2:].unsqueeze(0)
+        else:
+            tmp[k] = v
+
+    # Second pass, fuse q, k, v
+    out, buf = {}, {}
+    pat = re.compile(r"blocks\.(\d+)\.attn\.(q|k|v)_proj\.(weight|bias)$")
+    for k, v in tmp.items():
+        m = pat.fullmatch(k)
+        if not m:  # anything not q/k/v -> copy through
+            out[k] = v
+            continue
+
+        blk, which, kind = m.groups()  # block idx, 'q'/'k'/'v', 'weight'/'bias'
+        stash = buf.setdefault((blk, kind), {})  # gather by block & param type
+        stash[which] = v
+        if len(stash) == 3:  # have q, k, v -> concatenate
+            out[f"blocks.{blk}.attn.qkv.{kind}"] = torch.cat(
+                [stash['q'], stash['k'], stash['v']], dim=0
+            )
+
+    return out
 
 
 def checkpoint_filter_fn(
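A few details of the compacted version are worth checking with small self-contained sketches; these are illustrative additions, not code from the commit. First, the rename pass: the chain of str.replace calls becomes an ordered regex rule table, and order matters in one spot, since the vision_embed.cls_token rule must fire before the generic vision_embed. rule or the cls token would become patch_embed.cls_token. The rule table below is copied from the added code; the sample keys are illustrative.

import re

rules = [
    (r"beit3\.", ""),
    (r"vision_embed\.cls_token", "cls_token"),  # must precede the generic vision_embed rule
    (r"vision_embed\.", "patch_embed."),
    (r"embed_positions\.", "pos_embed."),
    (r"encoder\.", ""),
    (r"layers\.", "blocks."),
    (r"ffn_layernorm\.", "norm."), (r"ffn\.", "mlp."),
    (r"self_attn_layer_norm\.", "norm1."), (r"self_attn\.", "attn."),
    (r"final_layer_norm\.", "norm2."),
    (r"inner_attn_ln", "norm"),
    (r"out_proj", "proj"),
    (r"\.A\.", "."),
]

def rename(key: str) -> str:
    # Apply every rule in table order, exactly as the first pass does.
    for old, new in rules:
        key = re.sub(old, new, key)
    return key

assert rename("beit3.encoder.layers.0.self_attn.q_proj.weight") == "blocks.0.attn.q_proj.weight"
assert rename("beit3.vision_embed.cls_token") == "cls_token"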
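Second, the qkv fusion: the second pass buffers q/k/v per block and concatenates along the output dimension once all three are seen. That is equivalent to a single fused qkv linear because nn.Linear stores its weight as [out_features, in_features], so stacking along dim 0 stacks outputs. A quick standalone check (the dimensions are arbitrary):

import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 8
q, k, v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)

# Build a fused projection whose output is [q(x), k(x), v(x)] in that order.
qkv = nn.Linear(dim, 3 * dim)
with torch.no_grad():
    qkv.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
    qkv.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))

x = torch.randn(2, dim)
assert torch.allclose(qkv(x), torch.cat([q(x), k(x), v(x)], dim=-1), atol=1e-6)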
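Third, the pos_embed special case: per the comments in the removed version, BEiT-3 stores positions as [num_patches + 3, embed_dim] with two leading non-patch slots ahead of the cls token, while timm's ViT expects [1, num_patches + 1, embed_dim]. A shape-only sketch (196 and 768 are illustrative values for a patch16 base model at 224px):

import torch

num_patches, embed_dim = 196, 768  # illustrative: 224px image, 16px patches
beit3_pos = torch.randn(num_patches + 3, embed_dim)  # 2 offset slots + cls + patches
vit_pos = beit3_pos[2:].unsqueeze(0)                 # drop offset slots, add batch dim
assert vit_pos.shape == (1, num_patches + 1, embed_dim)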
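Finally, a hedged sketch of exercising the converter directly. In timm these helpers are normally dispatched through checkpoint_filter_fn (visible in the trailing diff context); the checkpoint filename below is hypothetical, and _convert_beit3 is a private helper that may move between releases.

import torch
import timm
from timm.models.vision_transformer import _convert_beit3  # private helper, assumed importable

model = timm.create_model('vit_base_patch16_224', pretrained=False)
raw = torch.load('beit3_base_patch16_224.pth', map_location='cpu')  # hypothetical file
converted = _convert_beit3(raw, model)
missing, unexpected = model.load_state_dict(converted, strict=False)
print(f'missing: {len(missing)}, unexpected: {len(unexpected)}')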