Fix a few more hiera API issues

pull/2156/head
Ross Wightman 2024-05-12 11:11:45 -07:00
parent 211d18d8ac
commit 3e03b2bf3f
1 changed file with 32 additions and 9 deletions


@@ -24,11 +24,12 @@ Adapted for timm from originals at https://github.com/facebookresearch/hiera
 # --------------------------------------------------------

 import math
 from functools import partial
-from typing import List, Tuple, Type, Callable, Optional, Union
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -480,14 +481,14 @@ class Hiera(nn.Module):
     ):
         super().__init__()
         self.num_classes = num_classes
+        self.grad_checkpointing = False
         norm_layer = get_norm_layer(norm_layer)
-        depth = sum(stages)

         self.patch_stride = patch_stride
         self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
         num_tokens = math.prod(self.tokens_spatial_shape)
         flat_mu_size = math.prod(mask_unit_size)
         flat_q_stride = math.prod(q_stride)
         assert q_pool < len(stages)
         self.q_pool, self.q_stride = q_pool, q_stride
         self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
@@ -532,11 +533,10 @@ class Hiera(nn.Module):
         # q_pool locations
         q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]]

-        # stochastic depth decay rule
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
-
         # Transformer blocks
         cur_stage = 0
+        depth = sum(stages)
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList()
         self.feature_info = []
         for i in range(depth):
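
For context, the relocated lines above compute a per-block stochastic depth (drop path) schedule that rises linearly from 0 to drop_path_rate across the full block depth. A quick illustrative sketch with hypothetical values (the stages tuple and drop_path_rate below are made up, not taken from this commit):

import torch

stages = (1, 2, 7, 2)        # hypothetical blocks-per-stage config
drop_path_rate = 0.1         # hypothetical max drop rate

depth = sum(stages)          # 12 blocks total
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
print([round(p, 3) for p in dpr])
# 12 values rising linearly from 0.0 to 0.1: earlier blocks drop less, later blocks more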
@@ -586,8 +586,9 @@ class Hiera(nn.Module):
         else:
             nn.init.trunc_normal_(self.pos_embed, std=0.02)
         self.apply(partial(self._init_weights))
-        self.head.fc.weight.data.mul_(head_init_scale)
-        self.head.fc.bias.data.mul_(head_init_scale)
+        if isinstance(self.head.fc, nn.Linear):
+            self.head.fc.weight.data.mul_(head_init_scale)
+            self.head.fc.bias.data.mul_(head_init_scale)

     def _init_weights(self, m, init_bias=0.02):
         if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
@@ -605,6 +606,25 @@ class Hiera(nn.Module):
         else:
             return ["pos_embed_spatial", "pos_embed_temporal"]

+    @torch.jit.ignore
+    def group_matcher(self, coarse: bool = False) -> Dict:
+        return dict(
+            stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|patch_embed',  # stem and embed
+            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True) -> None:
+        self.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool, other=other)
+
     def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
         """
         Generates a random mask, mask_ratio fraction are dropped.
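
Taken together, the methods added in this hunk bring Hiera in line with the rest of the timm model API. A rough usage sketch, assuming a Hiera variant created through timm (the model name below is an assumption, not part of this commit):

import timm

model = timm.create_model('hiera_tiny_224', pretrained=False)

# Replace the classification head for a new task, then inspect it
model.reset_classifier(num_classes=10)
print(model.get_classifier())        # the head's fc module, now sized for 10 classes

# group_matcher supplies regex groups used by timm's layer-wise LR decay / param grouping
groups = model.group_matcher(coarse=False)
print(groups['stem'])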
@@ -740,7 +760,10 @@ class Hiera(nn.Module):
         intermediates = []
         for i, blk in enumerate(self.blocks):
-            x = blk(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x)
+            else:
+                x = blk(x)
             if return_intermediates and i in self.stage_ends:
                 intermediates.append(self.reroll(x, i, mask=mask))
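
With the block-loop change above, enabling checkpointing trades extra compute in the backward pass for lower activation memory. A minimal sketch of turning it on for training, again assuming a timm-created Hiera model (model name is an assumption):

import timm
import torch

model = timm.create_model('hiera_tiny_224', pretrained=False)
model.set_grad_checkpointing(True)    # flips self.grad_checkpointing, checked in the loop shown above
model.train()

x = torch.randn(2, 3, 224, 224, requires_grad=True)
loss = model(x).sum()
loss.backward()                       # block activations are recomputed here rather than stored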