Add set_layer_config context manager to adjust all layer configs at once; use it in create_model with new args. Remove a few old warning-causing constant annotations for jit.

parent f28170df3f
commit 88129b2569

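A minimal usage sketch of what the new create_model arguments enable (the model name 'resnet50' and the torch.jit.script call are illustrative, not part of this commit):

```python
import torch
from timm.models import create_model

# scriptable=True builds the model under a layer config that prefers
# jit-scriptable layer implementations (per the docstring below, this does
# not yet work for every model).
model = create_model('resnet50', pretrained=False, scriptable=True)
model.eval()

# With a scriptable layer config, a supported model can be scripted directly.
scripted = torch.jit.script(model)
```
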
@@ -10,7 +10,7 @@ from __future__ import division
 from __future__ import print_function
 
 from collections import OrderedDict
-from typing import Union, Optional, List, Tuple
+from typing import Tuple
 
 import torch
 import torch.nn as nn

@@ -1,5 +1,6 @@
 from .registry import is_model, is_model_in_modules, model_entrypoint
 from .helpers import load_checkpoint
+from .layers import set_layer_config
 
 
 def create_model(

@@ -8,6 +9,9 @@ def create_model(
         num_classes=1000,
         in_chans=3,
         checkpoint_path='',
+        scriptable=None,
+        exportable=None,
+        no_jit=None,
         **kwargs):
     """Create a model
 

@@ -17,13 +21,16 @@ def create_model(
         num_classes (int): number of classes for final fully connected layer (default: 1000)
         in_chans (int): number of input channels / colors (default: 3)
         checkpoint_path (str): path of checkpoint to load after model is initialized
+        scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
+        exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
+        no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
 
     Keyword Args:
         drop_rate (float): dropout rate for training (default: 0.0)
         global_pool (str): global pool type (default: 'avg')
         **: other kwargs are model specific
     """
-    margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
+    model_args = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
 
     # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
     is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])

@@ -47,11 +54,12 @@ def create_model(
     if kwargs.get('drop_path_rate', None) is None:
         kwargs.pop('drop_path_rate', None)
 
-    if is_model(model_name):
-        create_fn = model_entrypoint(model_name)
-        model = create_fn(**margs, **kwargs)
-    else:
-        raise RuntimeError('Unknown model (%s)' % model_name)
+    with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
+        if is_model(model_name):
+            create_fn = model_entrypoint(model_name)
+            model = create_fn(**model_args, **kwargs)
+        else:
+            raise RuntimeError('Unknown model (%s)' % model_name)
 
     if checkpoint_path:
         load_checkpoint(model, checkpoint_path)

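With the wrapper above, create_model(model_name, scriptable=True) is roughly equivalent to constructing the model inside the context manager yourself. A hedged sketch of that equivalence, mirroring the lookup logic in this hunk (model name illustrative):

```python
from timm.models.layers import set_layer_config
from timm.models.registry import is_model, model_entrypoint

model_name = 'efficientnet_b0'  # illustrative

# The config flags are only overridden while the model's modules are built;
# they revert to their previous values once the block exits.
with set_layer_config(scriptable=True):
    if is_model(model_name):
        model = model_entrypoint(model_name)(pretrained=False, num_classes=1000, in_chans=3)
    else:
        raise RuntimeError('Unknown model (%s)' % model_name)
```
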
@@ -193,7 +193,6 @@ class Mixed_7a(nn.Module):
 
 
 class Block8(nn.Module):
-    __constants__ = ['relu']  # for pre 1.4 torchscript compat
 
     def __init__(self, scale=1.0, no_relu=False):
         super(Block8, self).__init__()

@@ -4,7 +4,8 @@ from .adaptive_avgmax_pool import \
 from .anti_aliasing import AntiAliasDownsampleLayer
 from .blur_pool import BlurPool2d
 from .cond_conv2d import CondConv2d, get_condconv_initializer
-from .config import is_exportable, is_scriptable, set_exportable, set_scriptable, is_no_jit, set_no_jit
+from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
+    set_layer_config
 from .conv2d_same import Conv2dSame
 from .conv_bn_act import ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn

@@ -38,7 +38,7 @@ class CondConv2d(nn.Module):
     Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
     https://github.com/pytorch/pytorch/issues/17983
     """
-    __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
+    __constants__ = ['in_channels', 'out_channels', 'dynamic_padding']
 
     def __init__(self, in_channels, out_channels, kernel_size=3,
                  stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):

@@ -1,13 +1,18 @@
-""" Model / Layer Config Singleton
+""" Model / Layer Config singleton state
 """
-from typing import Any
+from typing import Any, Optional
 
-__all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable', 'is_no_jit', 'set_no_jit']
+__all__ = [
+    'is_exportable', 'is_scriptable', 'is_no_jit',
+    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
+]
 
 # Set to True if prefer to have layers with no jit optimization (includes activations)
 _NO_JIT = False
 
 # Set to True if prefer to have activation layers with no jit optimization
+# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying
+# the jit flags so far are activations. This will change as more layers are updated and/or added.
 _NO_ACTIVATION_JIT = False
 
 # Set to True if exporting a model with Same padding via ONNX

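These module-level flags are meant to be read through the accessor functions listed in __all__ (is_no_jit, is_scriptable, is_exportable). The helper below is a hypothetical illustration of how a layer factory might consult them when choosing between jit-scripted and plain activation layers; it is not code from this commit:

```python
from timm.models.layers import is_no_jit, is_exportable, is_scriptable

def use_jit_activations() -> bool:
    # Hypothetical selection rule: fall back to plain (non-scripted) activation
    # layers whenever no-jit, export, or scripting preferences are set.
    return not (is_no_jit() or is_exportable() or is_scriptable())
```
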
@@ -72,3 +77,39 @@ class set_scriptable:
         global _SCRIPTABLE
         _SCRIPTABLE = self.prev
         return False
+
+
+class set_layer_config:
+    """ Layer config context manager that allows setting all layer config flags at once.
+    If a flag arg is None, it will not change the current value.
+    """
+    def __init__(
+            self,
+            scriptable: Optional[bool] = None,
+            exportable: Optional[bool] = None,
+            no_jit: Optional[bool] = None,
+            no_activation_jit: Optional[bool] = None):
+        global _SCRIPTABLE
+        global _EXPORTABLE
+        global _NO_JIT
+        global _NO_ACTIVATION_JIT
+        self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
+        if scriptable is not None:
+            _SCRIPTABLE = scriptable
+        if exportable is not None:
+            _EXPORTABLE = exportable
+        if no_jit is not None:
+            _NO_JIT = no_jit
+        if no_activation_jit is not None:
+            _NO_ACTIVATION_JIT = no_activation_jit
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, *args: Any) -> bool:
+        global _SCRIPTABLE
+        global _EXPORTABLE
+        global _NO_JIT
+        global _NO_ACTIVATION_JIT
+        _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
+        return False

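Because __init__ snapshots the previous flag values and __exit__ restores them, the overrides apply only inside the with block. A small sketch of that behavior using the public accessors (the printed values assume the default flag state outside the block):

```python
from timm.models.layers import set_layer_config, is_scriptable, is_no_jit

print(is_scriptable(), is_no_jit())       # default state, e.g. False False

with set_layer_config(scriptable=True, no_jit=True):
    print(is_scriptable(), is_no_jit())   # True True inside the block

print(is_scriptable(), is_no_jit())       # previous values restored on exit
```
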
@@ -5,7 +5,7 @@ Hacked together by Ross Wightman
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Union, List, Tuple, Optional
+from typing import List, Tuple, Optional
 
 from .helpers import tup_pair
 from .padding import pad_same, get_padding_value

@@ -85,15 +85,13 @@ def validate(args):
     args.pretrained = args.pretrained or not args.checkpoint
     args.prefetcher = not args.no_prefetcher
 
-    if args.torchscript:
-        set_scriptable(True)
-
     # create model
     model = create_model(
         args.model,
+        pretrained=args.pretrained,
         num_classes=args.num_classes,
         in_chans=3,
-        pretrained=args.pretrained)
+        scriptable=args.torchscript)
 
     if args.checkpoint:
         load_checkpoint(model, args.checkpoint, args.use_ema)