Fix a few more hiera API issues
parent 211d18d8ac
commit 3e03b2bf3f
@@ -24,11 +24,12 @@ Adapted for timm from originals at https://github.com/facebookresearch/hiera
 # --------------------------------------------------------
 import math
 from functools import partial
-from typing import List, Tuple, Type, Callable, Optional, Union
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint


 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -480,14 +481,14 @@ class Hiera(nn.Module):
     ):
         super().__init__()
         self.num_classes = num_classes
+        self.grad_checkpointing = False
         norm_layer = get_norm_layer(norm_layer)
-        depth = sum(stages)
         self.patch_stride = patch_stride
         self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
         num_tokens = math.prod(self.tokens_spatial_shape)
         flat_mu_size = math.prod(mask_unit_size)
         flat_q_stride = math.prod(q_stride)

         assert q_pool < len(stages)
         self.q_pool, self.q_stride = q_pool, q_stride
         self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
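The constructor now seeds a self.grad_checkpointing flag (off by default, flipped by the set_grad_checkpointing method added further down), and depth = sum(stages) moves next to the code that actually consumes it. For orientation, the surrounding context does the token-grid bookkeeping; a small worked sketch of that arithmetic, using the upstream Hiera defaults (224 input, 4-pixel patch stride, 8x8 mask units, 2x2 q_stride) purely as assumed example values:

import math

# Assumed example values matching the upstream Hiera defaults, not read from this diff.
img_size = (224, 224)
patch_stride = (4, 4)
mask_unit_size = (8, 8)
q_stride = (2, 2)

tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]  # [56, 56]
num_tokens = math.prod(tokens_spatial_shape)                             # 3136 tokens
flat_mu_size = math.prod(mask_unit_size)                                 # 64 tokens per mask unit
flat_q_stride = math.prod(q_stride)                                      # 4x token reduction at each q_pool stage
print(tokens_spatial_shape, num_tokens, flat_mu_size, flat_q_stride)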
@@ -532,11 +533,10 @@
         # q_pool locations
         q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]]

-        # stochastic depth decay rule
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
-
         # Transformer blocks
         cur_stage = 0
+        depth = sum(stages)
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList()
         self.feature_info = []
         for i in range(depth):
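Relocating the stochastic depth computation keeps depth and dpr beside the block-construction loop that uses them. The decay rule simply spreads drop_path_rate linearly from 0 at the first block to the full rate at the last one; a tiny sketch with assumed values (depth=8, drop_path_rate=0.1):

import torch

# Assumed example values; in the model, depth = sum(stages) and drop_path_rate is a constructor arg.
depth = 8
drop_path_rate = 0.1

# stochastic depth decay rule: rates grow linearly from 0 to drop_path_rate across the blocks
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
print([round(r, 4) for r in dpr])
# -> approximately [0.0, 0.0143, 0.0286, 0.0429, 0.0571, 0.0714, 0.0857, 0.1]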
@@ -586,8 +586,9 @@ class Hiera(nn.Module):
         else:
             nn.init.trunc_normal_(self.pos_embed, std=0.02)
         self.apply(partial(self._init_weights))
-        self.head.fc.weight.data.mul_(head_init_scale)
-        self.head.fc.bias.data.mul_(head_init_scale)
+        if isinstance(self.head.fc, nn.Linear):
+            self.head.fc.weight.data.mul_(head_init_scale)
+            self.head.fc.bias.data.mul_(head_init_scale)

     def _init_weights(self, m, init_bias=0.02):
         if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
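The new isinstance guard protects the head_init_scale rescaling when the model is built headless: with num_classes=0 a timm classifier head typically ends in nn.Identity, which has no weight or bias attributes, so the unconditional .mul_ calls would raise. A minimal sketch of the failure mode being avoided; SimpleHead is a hypothetical stand-in, not the actual Hiera head:

import torch.nn as nn

head_init_scale = 0.001  # same role as the constructor argument in the diff


class SimpleHead(nn.Module):
    # Hypothetical stand-in for a timm classifier head: fc is a Linear layer when
    # num_classes > 0, otherwise an nn.Identity with no parameters at all.
    def __init__(self, in_features: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()


for num_classes in (1000, 0):
    head = SimpleHead(768, num_classes)
    if isinstance(head.fc, nn.Linear):
        head.fc.weight.data.mul_(head_init_scale)  # only a real Linear has weight/bias
        head.fc.bias.data.mul_(head_init_scale)
    else:
        print("num_classes=0: fc is nn.Identity, nothing to scale")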
@@ -605,6 +606,25 @@ class Hiera(nn.Module):
         else:
             return ["pos_embed_spatial", "pos_embed_temporal"]

+    @torch.jit.ignore
+    def group_matcher(self, coarse: bool = False) -> Dict:
+        return dict(
+            stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|patch_embed',  # stem and embed
+            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True) -> None:
+        self.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool, other=other)
+
     def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
         """
         Generates a random mask, mask_ratio fraction are dropped.
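These methods bring Hiera up to the common timm model interface: group_matcher supplies the parameter-grouping regexes used for things like layer-wise LR decay, set_grad_checkpointing toggles the flag consumed in the forward pass, and get_classifier/reset_classifier expose and rebuild the head. A rough usage sketch; the model name 'hiera_small_224' is an assumption, check timm.list_models('hiera*') for what your timm version actually registers:

import timm
import torch

# 'hiera_small_224' is an assumed model name; adjust to whatever your timm version lists.
model = timm.create_model('hiera_small_224', pretrained=False, num_classes=1000)

model.set_grad_checkpointing(True)   # trade compute for activation memory during training
print(type(model.get_classifier()))  # the classification layer, i.e. model.head.fc

matcher = model.group_matcher(coarse=False)
print(matcher['stem'])               # regex matching pos_embed*/patch_embed parameters

# Re-target the head for a new task, or drop it (num_classes=0) for feature extraction.
model.reset_classifier(num_classes=10)
model.reset_classifier(num_classes=0)
out = model(torch.randn(1, 3, 224, 224))  # pooled features once the head is removed
print(out.shape)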
@@ -740,7 +760,10 @@ class Hiera(nn.Module):

         intermediates = []
         for i, blk in enumerate(self.blocks):
-            x = blk(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x)
+            else:
+                x = blk(x)
             if return_intermediates and i in self.stage_ends:
                 intermediates.append(self.reroll(x, i, mask=mask))

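With the flag wired into the block loop, enabling checkpointing recomputes each block's activations during the backward pass instead of storing them, trading extra compute for lower memory. A self-contained sketch of the same control flow outside the model; the toy blocks stand in for real Hiera attention/MLP blocks:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-ins for Hiera blocks; the real ones are windowed-attention + MLP modules.
blocks = nn.ModuleList([nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(4)])
grad_checkpointing = True  # what model.set_grad_checkpointing(True) flips on the real model

x = torch.randn(2, 64, requires_grad=True)
for blk in blocks:
    if grad_checkpointing and not torch.jit.is_scripting():
        # Block activations are recomputed in backward rather than kept in memory.
        x = checkpoint(blk, x)
    else:
        x = blk(x)
x.sum().backward()  # gradients still flow; peak memory is lower for deep stacks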