Fix a few more hiera API issues

pull/2156/head
Ross Wightman 2024-05-12 11:11:45 -07:00
parent 211d18d8ac
commit 3e03b2bf3f
1 changed file with 32 additions and 9 deletions

@@ -24,11 +24,12 @@ Adapted for timm from originals at https://github.com/facebookresearch/hiera
# --------------------------------------------------------
import math
from functools import partial
from typing import List, Tuple, Type, Callable, Optional, Union
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
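
The new checkpoint import is what the forward path further down uses for activation checkpointing; a tiny standalone sketch of the call pattern, with the module and input made up purely for illustration:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

blk = nn.Linear(16, 16)                     # stand-in for a transformer block
x = torch.randn(2, 16, requires_grad=True)

# checkpoint() recomputes blk's forward during backward instead of storing
# its activations, trading extra compute for lower memory.
y = checkpoint(blk, x, use_reentrant=False)
y.sum().backward()
```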
@@ -480,14 +481,14 @@ class Hiera(nn.Module):
):
super().__init__()
self.num_classes = num_classes
self.grad_checkpointing = False
norm_layer = get_norm_layer(norm_layer)
depth = sum(stages)
self.patch_stride = patch_stride
self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
num_tokens = math.prod(self.tokens_spatial_shape)
flat_mu_size = math.prod(mask_unit_size)
flat_q_stride = math.prod(q_stride)
assert q_pool < len(stages)
self.q_pool, self.q_stride = q_pool, q_stride
self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
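
For reference, the token bookkeeping above reduces to a few integer products; a small worked example using assumed, roughly default-sized values (224 input, 4x4 patch stride, 8x8 mask units, 2x2 q_stride):

```python
import math

img_size = (224, 224)        # assumed values, for illustration only
patch_stride = (4, 4)
mask_unit_size = (8, 8)
q_stride = (2, 2)

tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]  # [56, 56]
num_tokens = math.prod(tokens_spatial_shape)                             # 3136
flat_mu_size = math.prod(mask_unit_size)                                 # 64
flat_q_stride = math.prod(q_stride)                                      # 4
```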
@@ -532,11 +533,10 @@ class Hiera(nn.Module):
# q_pool locations
q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]]
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
# Transformer blocks
cur_stage = 0
depth = sum(stages)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList()
self.feature_info = []
for i in range(depth):
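
The decay rule relocated above simply spreads drop_path_rate linearly across all blocks; a minimal standalone sketch with example numbers (the stage layout here is illustrative):

```python
import torch

stages = (1, 2, 7, 2)        # example four-stage layout
drop_path_rate = 0.1
depth = sum(stages)          # 12 blocks in total

# Linearly increasing per-block drop-path rates: 0.0 for the first block,
# rising to drop_path_rate for the last.
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
```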
@@ -586,8 +586,9 @@ class Hiera(nn.Module):
else:
nn.init.trunc_normal_(self.pos_embed, std=0.02)
self.apply(partial(self._init_weights))
self.head.fc.weight.data.mul_(head_init_scale)
self.head.fc.bias.data.mul_(head_init_scale)
if isinstance(self.head.fc, nn.Linear):
self.head.fc.weight.data.mul_(head_init_scale)
self.head.fc.bias.data.mul_(head_init_scale)
def _init_weights(self, m, init_bias=0.02):
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
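
The isinstance guard added above matters because timm heads generally swap the final nn.Linear for nn.Identity when num_classes=0, and an Identity has no weight or bias to scale; a minimal sketch of that pattern (sizes assumed):

```python
import torch.nn as nn

head_init_scale = 0.001
num_classes = 0              # e.g. feature-extraction use, no classifier
fc = nn.Linear(768, num_classes) if num_classes > 0 else nn.Identity()

# Only scale when a real Linear head exists; nn.Identity has no parameters.
if isinstance(fc, nn.Linear):
    fc.weight.data.mul_(head_init_scale)
    fc.bias.data.mul_(head_init_scale)
```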
@@ -605,6 +606,25 @@ class Hiera(nn.Module):
else:
return ["pos_embed_spatial", "pos_embed_temporal"]
@torch.jit.ignore
def group_matcher(self, coarse: bool = False) -> Dict:
return dict(
stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|patch_embed', # stem and embed
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True) -> None:
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool, other=other)
def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
"""
Generates a random mask, mask_ratio fraction are dropped.
@@ -740,7 +760,10 @@ class Hiera(nn.Module):
intermediates = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x)
else:
x = blk(x)
if return_intermediates and i in self.stage_ends:
intermediates.append(self.reroll(x, i, mask=mask))
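
A rough usage sketch of the API surface this commit touches, assuming the model is built through timm.create_model; the model name, input size, and class count below are illustrative, not part of the commit:

```python
import timm
import torch

model = timm.create_model('hiera_tiny_224', pretrained=False)  # name assumed

# Activation (gradient) checkpointing toggle added in this commit.
model.set_grad_checkpointing(True)

# Classifier accessors/reset now follow the common timm head API.
fc = model.get_classifier()
model.reset_classifier(num_classes=10)

# group_matcher feeds timm's parameter grouping / layer-wise LR decay helpers.
groups = model.group_matcher(coarse=False)

out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected (1, 10) after the head reset above
```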