mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
* Split MobileNetV3 and EfficientNet model files and put the builder and blocks in their own files (the files were getting too large)
* Finalize CondConv EfficientNet variant
* Add the AdvProp weights files and B8 EfficientNet model
* Refine the feature extraction module for EfficientNet and MobileNetV3
32 lines
1.1 KiB
Python
32 lines
1.1 KiB
Python
from collections import defaultdict, OrderedDict
|
|
from functools import partial
|
|
|
|
|
|
class FeatureHooks:
    """Capture intermediate tensors from named submodules via module hooks.

    Each hook spec registers either a forward hook or a forward-pre hook on
    one named submodule. Whenever such a hook fires, the tensor it observed is
    stored, keyed first by that tensor's ``device`` and then by the module
    name, until :meth:`get_output` reads (and clears) the entries for that
    device.

    Args:
        hooks: iterable of dicts, each with a ``'name'`` key (a module name
            present in ``named_modules``) and a ``'type'`` key that must be
            ``'forward'`` or ``'forward_pre'``.
        named_modules: iterable of ``(name, module)`` pairs; presumably the
            result of ``model.named_modules()`` -- TODO confirm against callers.

    Raises:
        KeyError: if a hook names a module missing from ``named_modules``, or
            a hook spec lacks a ``'name'``/``'type'`` key.
        ValueError: if a hook's ``'type'`` is not supported.
    """

    _HOOK_TYPES = ('forward_pre', 'forward')

    def __init__(self, hooks, named_modules):
        # Create the output store before any hook is attached. A registered
        # hook must never fire against an object that lacks it.
        self._feature_outputs = defaultdict(OrderedDict)

        modules = dict(named_modules)

        # Resolve and validate every spec before registering anything. This
        # way a bad entry raises without leaving earlier hooks attached to the
        # model. Per-spec check order matches the original: module lookup
        # first, then type.
        targets = []
        for h in hooks:
            hook_name = h['name']
            m = modules[hook_name]
            hook_type = h['type']
            if hook_type not in self._HOOK_TYPES:
                # Explicit raise instead of `assert False`, which `python -O` strips.
                raise ValueError(f"Unsupported hook type: {hook_type!r}")
            targets.append((hook_name, hook_type, m))

        for hook_name, hook_type, m in targets:
            hook_fn = partial(self._collect_output_hook, hook_name)
            if hook_type == 'forward_pre':
                m.register_forward_pre_hook(hook_fn)
            else:  # 'forward' (validated above)
                m.register_forward_hook(hook_fn)

    def _collect_output_hook(self, name, *args):
        """Record the tensor a hook observed under ``name``.

        A forward hook is called with ``(module, input, output)`` and a
        forward-pre hook with ``(module, input)``. In both cases the tensor of
        interest is the last positional argument. A tuple, such as a
        forward-pre hook's positional-input tuple, is unwrapped to its first
        element. Returns ``None``, so the module's input/output is left
        unchanged.
        """
        x = args[-1]  # output for fwd, input for fwd_pre
        if isinstance(x, tuple):
            x = x[0]  # unwrap input tuple
        # Keyed by device, so tensors recorded on different devices are kept
        # apart until each device's entries are read.
        self._feature_outputs[x.device][name] = x

    def get_output(self, device):
        """Return and clear the tensors collected for ``device``.

        The result is a tuple in the REVERSE of the order in which module
        names were first recorded for ``device`` since the last read. A
        re-recorded name keeps its original position because this is an
        ``OrderedDict``. Returns ``()`` if nothing was recorded.
        """
        # NOTE(review): the reversal is preserved as-is; presumably callers
        # rely on this order -- confirm before changing.
        output = tuple(self._feature_outputs[device].values())[::-1]
        self._feature_outputs[device] = OrderedDict()  # clear after reading
        return output
|