mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add missing feature_info() on MobileNetV3, make hook feature output order/type consistent with bottleneck (list, decreasing fmap size)
This commit is contained in:
parent
88129b2569
commit
7be299504f
@ -24,9 +24,12 @@ An implementation of EfficienNet that covers variety of related models with effi
|
|||||||
|
|
||||||
Hacked together by Ross Wightman
|
Hacked together by Ross Wightman
|
||||||
"""
|
"""
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||||
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
||||||
@ -471,7 +474,7 @@ class EfficientNetFeatures(nn.Module):
|
|||||||
return self._feature_info[idx]
|
return self._feature_info[idx]
|
||||||
return [self._feature_info[i] for i in self.out_indices]
|
return [self._feature_info[i] for i in self.out_indices]
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x) -> List[torch.Tensor]:
|
||||||
x = self.conv_stem(x)
|
x = self.conv_stem(x)
|
||||||
x = self.bn1(x)
|
x = self.bn1(x)
|
||||||
x = self.act1(x)
|
x = self.act1(x)
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
from collections import defaultdict, OrderedDict
|
from collections import defaultdict, OrderedDict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
class FeatureHooks:
|
class FeatureHooks:
|
||||||
@ -25,7 +28,7 @@ class FeatureHooks:
|
|||||||
x = x[0] # unwrap input tuple
|
x = x[0] # unwrap input tuple
|
||||||
self._feature_outputs[x.device][name] = x
|
self._feature_outputs[x.device][name] = x
|
||||||
|
|
||||||
def get_output(self, device):
|
def get_output(self, device) -> List[torch.tensor]:
|
||||||
output = tuple(self._feature_outputs[device].values())[::-1]
|
output = list(self._feature_outputs[device].values())
|
||||||
self._feature_outputs[device] = OrderedDict() # clear after reading
|
self._feature_outputs[device] = OrderedDict() # clear after reading
|
||||||
return output
|
return output
|
||||||
|
@ -7,9 +7,12 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
|
|||||||
|
|
||||||
Hacked together by Ross Wightman
|
Hacked together by Ross Wightman
|
||||||
"""
|
"""
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||||
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
||||||
@ -206,7 +209,16 @@ class MobileNetV3Features(nn.Module):
|
|||||||
return self._feature_info[idx]['num_chs']
|
return self._feature_info[idx]['num_chs']
|
||||||
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
|
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
|
||||||
|
|
||||||
def forward(self, x):
|
def feature_info(self, idx=None):
|
||||||
|
""" Feature Channel Shortcut
|
||||||
|
Returns feature channel count for each output index if idx == None. If idx is an integer, will
|
||||||
|
return feature channel count for that feature block index (independent of out_indices setting).
|
||||||
|
"""
|
||||||
|
if isinstance(idx, int):
|
||||||
|
return self._feature_info[idx]
|
||||||
|
return [self._feature_info[i] for i in self.out_indices]
|
||||||
|
|
||||||
|
def forward(self, x) -> List[torch.Tensor]:
|
||||||
x = self.conv_stem(x)
|
x = self.conv_stem(x)
|
||||||
x = self.bn1(x)
|
x = self.bn1(x)
|
||||||
x = self.act1(x)
|
x = self.act1(x)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user