""" Model / Layer Config singleton state
|
|
"""
|
|
import os
|
|
import warnings
|
|
from typing import Any, Optional
|
|
|
|
import torch
|
|
|
|
__all__ = [
    'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn',
    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn'
]

# Set to True if you prefer layers without jit optimization (includes activations)
_NO_JIT = False

# Set to True if you prefer activation layers without jit optimization
# NOTE not currently used; no_jit and no_activation_jit behave the same because activations are the
# only layers obeying the jit flags so far. This will change as more layers are updated and/or added.
_NO_ACTIVATION_JIT = False

# Set to True if exporting a model with 'SAME' padding via ONNX
_EXPORTABLE = False

# Set to True if you want to use torch.jit.script on a model
_SCRIPTABLE = False


# Use torch.nn.functional.scaled_dot_product_attention (fused attention) where possible
_HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if 'TIMM_FUSED_ATTN' in os.environ:
    _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN'])
else:
    _USE_FUSED_ATTN = 1  # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)


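# Example (illustrative): since the env var is read at import time, the default can be
# overridden from the shell; 'train.py' below is a stand-in for any entry point that
# imports this module.
#
#   TIMM_FUSED_ATTN=0 python train.py   # force fused attention off
#   TIMM_FUSED_ATTN=2 python train.py   # opt in to experimental fused-attn paths

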
def is_no_jit():
    return _NO_JIT


class set_no_jit:
    """ Context manager to disable jit-optimized layer variants.

    NOTE the flag takes effect at construction time, not on __enter__; __exit__ restores
    the previous value.
    """
    def __init__(self, mode: bool) -> None:
        global _NO_JIT
        self.prev = _NO_JIT
        _NO_JIT = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _NO_JIT
        _NO_JIT = self.prev
        return False


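# Usage sketch (illustrative; `build_model` is a placeholder for any code path that
# constructs layers while the flag is active):
#
#   with set_no_jit(True):
#       model = build_model()   # activation layers built here use non-jit impls
#   # the previous _NO_JIT value is restored on exit

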
def is_exportable():
    return _EXPORTABLE


class set_exportable:
    """ Context manager to build layers with export-friendly (e.g. ONNX) implementations.

    NOTE the flag takes effect at construction time, not on __enter__; __exit__ restores
    the previous value.
    """
    def __init__(self, mode: bool) -> None:
        global _EXPORTABLE
        self.prev = _EXPORTABLE
        _EXPORTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _EXPORTABLE
        _EXPORTABLE = self.prev
        return False


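# Usage sketch (illustrative; `build_model`, `dummy_input`, and the output path are
# placeholders):
#
#   with set_exportable(True):
#       model = build_model()   # layers pick export-safe implementations
#   torch.onnx.export(model, dummy_input, 'model.onnx')

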
def is_scriptable():
    return _SCRIPTABLE


class set_scriptable:
    """ Context manager to build layers compatible with torch.jit.script.

    NOTE the flag takes effect at construction time, not on __enter__; __exit__ restores
    the previous value.
    """
    def __init__(self, mode: bool) -> None:
        global _SCRIPTABLE
        self.prev = _SCRIPTABLE
        _SCRIPTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        _SCRIPTABLE = self.prev
        return False


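# Usage sketch (illustrative; `build_model` is a placeholder factory):
#
#   with set_scriptable(True):
#       model = build_model()   # layers restricted to script-compatible impls
#   scripted = torch.jit.script(model)

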
class set_layer_config:
    """ Layer config context manager that allows setting all layer config flags at once.

    If a flag arg is None, the current value is left unchanged. All previous values are
    restored on exit.
    """
    def __init__(
            self,
            scriptable: Optional[bool] = None,
            exportable: Optional[bool] = None,
            no_jit: Optional[bool] = None,
            no_activation_jit: Optional[bool] = None):
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
        if scriptable is not None:
            _SCRIPTABLE = scriptable
        if exportable is not None:
            _EXPORTABLE = exportable
        if no_jit is not None:
            _NO_JIT = no_jit
        if no_activation_jit is not None:
            _NO_ACTIVATION_JIT = no_activation_jit

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
        return False


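# Usage sketch (illustrative): set several flags for the duration of model construction,
# with all previous values restored afterwards:
#
#   with set_layer_config(scriptable=True, no_jit=True):
#       model = build_model()   # `build_model` is a placeholder factory

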
def use_fused_attn(experimental: bool = False) -> bool:
    # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0
    if not _HAS_FUSED_ATTN or _EXPORTABLE:
        return False
    if experimental:
        return _USE_FUSED_ATTN > 1
    return _USE_FUSED_ATTN > 0


def set_fused_attn(enable: bool = True, experimental: bool = False):
    global _USE_FUSED_ATTN
    if not _HAS_FUSED_ATTN:
        warnings.warn('This version of pytorch does not have F.scaled_dot_product_attention, fused_attn flag ignored.')
        return
    if experimental and enable:
        _USE_FUSED_ATTN = 2
    elif enable:
        _USE_FUSED_ATTN = 1
    else:
        _USE_FUSED_ATTN = 0
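

# Consumption sketch (illustrative attention module; mirrors the typical pattern of
# resolving the flag once, at layer construction time, rather than per forward pass):
#
#   import torch.nn.functional as F
#
#   class Attention(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.fused_attn = use_fused_attn()  # snapshot of the global config
#
#       def forward(self, q, k, v):
#           if self.fused_attn:
#               return F.scaled_dot_product_attention(q, k, v)
#           # unfused fallback: scaled q @ k^T, softmax, then weight the values
#           attn = (q @ k.transpose(-2, -1)) * (q.shape[-1] ** -0.5)
#           return attn.softmax(dim=-1) @ v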