mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add missing feature_info() on MobileNetV3, make hook feature output order/type consistent with bottleneck (list, decreasing fmap size)
This commit is contained in:
parent
88129b2569
commit
7be299504f
@ -24,9 +24,12 @@ An implementation of EfficienNet that covers variety of related models with effi
|
||||
|
||||
Hacked together by Ross Wightman
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from typing import List
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
||||
@ -471,7 +474,7 @@ class EfficientNetFeatures(nn.Module):
|
||||
return self._feature_info[idx]
|
||||
return [self._feature_info[i] for i in self.out_indices]
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x) -> List[torch.Tensor]:
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
@ -1,5 +1,8 @@
|
||||
import torch
|
||||
|
||||
from collections import defaultdict, OrderedDict
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
|
||||
class FeatureHooks:
|
||||
@ -25,7 +28,7 @@ class FeatureHooks:
|
||||
x = x[0] # unwrap input tuple
|
||||
self._feature_outputs[x.device][name] = x
|
||||
|
||||
def get_output(self, device):
|
||||
output = tuple(self._feature_outputs[device].values())[::-1]
|
||||
def get_output(self, device) -> List[torch.tensor]:
|
||||
output = list(self._feature_outputs[device].values())
|
||||
self._feature_outputs[device] = OrderedDict() # clear after reading
|
||||
return output
|
||||
|
@ -7,9 +7,12 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
|
||||
|
||||
Hacked together by Ross Wightman
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from typing import List
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
||||
@ -206,7 +209,16 @@ class MobileNetV3Features(nn.Module):
|
||||
return self._feature_info[idx]['num_chs']
|
||||
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
|
||||
|
||||
def forward(self, x):
|
||||
def feature_info(self, idx=None):
|
||||
""" Feature Channel Shortcut
|
||||
Returns feature channel count for each output index if idx == None. If idx is an integer, will
|
||||
return feature channel count for that feature block index (independent of out_indices setting).
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
return self._feature_info[idx]
|
||||
return [self._feature_info[i] for i in self.out_indices]
|
||||
|
||||
def forward(self, x) -> List[torch.Tensor]:
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user