Fix a few more hiera API issues
parent 211d18d8ac
commit 3e03b2bf3f
@@ -24,11 +24,12 @@ Adapted for timm from originals at https://github.com/facebookresearch/hiera
 # --------------------------------------------------------
 import math
 from functools import partial
-from typing import List, Tuple, Type, Callable, Optional, Union
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -480,14 +481,14 @@ class Hiera(nn.Module):
     ):
         super().__init__()
         self.num_classes = num_classes
+        self.grad_checkpointing = False
         norm_layer = get_norm_layer(norm_layer)
-        depth = sum(stages)
 
         self.patch_stride = patch_stride
         self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
         num_tokens = math.prod(self.tokens_spatial_shape)
         flat_mu_size = math.prod(mask_unit_size)
         flat_q_stride = math.prod(q_stride)
         assert q_pool < len(stages)
         self.q_pool, self.q_stride = q_pool, q_stride
         self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
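For intuition, the constructor arithmetic above works out as follows. A minimal sketch with assumed illustrative values (`img_size`, `patch_stride`, `mask_unit_size`, and `q_stride` below are typical 2D defaults, not taken from this diff):

```python
import math

img_size = (224, 224)     # assumed input resolution
patch_stride = (4, 4)     # assumed patch embed stride
mask_unit_size = (8, 8)   # assumed mask unit, in tokens
q_stride = (2, 2)         # assumed pooling stride at q_pool stages

tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
num_tokens = math.prod(tokens_spatial_shape)
flat_mu_size = math.prod(mask_unit_size)
flat_q_stride = math.prod(q_stride)

print(tokens_spatial_shape)  # [56, 56] token grid after patch embed
print(num_tokens)            # 3136 tokens total
print(flat_mu_size)          # 64 tokens per mask unit
print(flat_q_stride)         # 4x token reduction at each q_pool stage
```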
@@ -532,11 +533,10 @@ class Hiera(nn.Module):
         # q_pool locations
         q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]]
 
-        # stochastic depth decay rule
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
-
         # Transformer blocks
         cur_stage = 0
+        depth = sum(stages)
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList()
         self.feature_info = []
         for i in range(depth):
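The relocated decay rule spreads per-block drop-path rates linearly from 0 at the first block to `drop_path_rate` at the last; a small sketch with assumed values:

```python
import torch

stages = (2, 3, 16, 3)   # assumed stage depths (Hiera-Base-like)
drop_path_rate = 0.1     # assumed

depth = sum(stages)      # 24 blocks
# One drop-path probability per block, increasing linearly with depth.
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
print(round(dpr[0], 4), round(dpr[-1], 4))  # 0.0 0.1
```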
@@ -586,8 +586,9 @@ class Hiera(nn.Module):
         else:
             nn.init.trunc_normal_(self.pos_embed, std=0.02)
         self.apply(partial(self._init_weights))
-        self.head.fc.weight.data.mul_(head_init_scale)
-        self.head.fc.bias.data.mul_(head_init_scale)
+        if isinstance(self.head.fc, nn.Linear):
+            self.head.fc.weight.data.mul_(head_init_scale)
+            self.head.fc.bias.data.mul_(head_init_scale)
 
     def _init_weights(self, m, init_bias=0.02):
         if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
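The new `isinstance` guard presumably matters because timm heads swap the final projection for an `nn.Identity` when `num_classes=0`; a tiny illustration (not from the diff):

```python
import torch.nn as nn

fc = nn.Identity()            # what a head's fc becomes with num_classes=0
print(hasattr(fc, 'weight'))  # False: unguarded .weight.data.mul_() would raise
```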
@@ -605,6 +606,25 @@ class Hiera(nn.Module):
         else:
             return ["pos_embed_spatial", "pos_embed_temporal"]
 
+    @torch.jit.ignore
+    def group_matcher(self, coarse: bool = False) -> Dict:
+        return dict(
+            stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|patch_embed',  # stem and embed
+            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True) -> None:
+        self.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool, other=other)
+
     def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
         """
         Generates a random mask, mask_ratio fraction are dropped.
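Together these methods round out timm's standard model interface for Hiera. A hedged usage sketch (the model name `hiera_base_224` is an assumption, not part of this diff):

```python
import timm

model = timm.create_model('hiera_base_224', pretrained=False)  # assumed name

# Parameter grouping, as consumed by timm's layer-wise LR decay machinery.
matcher = model.group_matcher(coarse=True)

# Swap the classifier for a 10-class task, then inspect it.
model.reset_classifier(num_classes=10)
print(model.get_classifier())
```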
@@ -740,7 +760,10 @@ class Hiera(nn.Module):
 
         intermediates = []
         for i, blk in enumerate(self.blocks):
-            x = blk(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x)
+            else:
+                x = blk(x)
             if return_intermediates and i in self.stage_ends:
                 intermediates.append(self.reroll(x, i, mask=mask))
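With the forward loop honoring the new flag, gradient checkpointing can be toggled per model; a usage sketch (same assumed model name as above):

```python
import timm
import torch

model = timm.create_model('hiera_base_224', pretrained=False)  # assumed name
model.set_grad_checkpointing(True)  # trade recompute for activation memory

x = torch.randn(2, 3, 224, 224)
model(x).sum().backward()  # each block is re-run during the backward pass
```

Note that `torch.utils.checkpoint.checkpoint` defaults to the reentrant implementation, and recent PyTorch versions warn unless `use_reentrant` is passed explicitly.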