Add a few missing __all__ entries.

weights_only
Ross Wightman 2024-08-07 09:35:51 -07:00
parent 10344625be
commit e9ef9424f0
4 changed files with 8 additions and 0 deletions

View File

@ -29,6 +29,8 @@ from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model from ._registry import generate_default_cfgs, register_model
__all__ = ['EfficientFormerV2']
EfficientFormer_width = { EfficientFormer_width = {
'L': (40, 80, 192, 384), # 26m 83.3% 6attn 'L': (40, 80, 192, 384), # 26m 83.3% 6attn
'S2': (32, 64, 144, 288), # 12m 81.6% 4attn dp0.02 'S2': (32, 64, 144, 288), # 12m 81.6% 4attn dp0.02

View File

@ -20,6 +20,7 @@ from ._features import feature_take_indices
from ._manipulate import checkpoint_seq from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs from ._registry import register_model, generate_default_cfgs
__all__ = ['FastVit']
def num_groups(group_size, channels): def num_groups(group_size, channels):
if not group_size: # 0 or None if not group_size: # 0 or None

View File

@ -42,6 +42,9 @@ from ._features import feature_take_indices
from ._features_fx import register_notrace_function from ._features_fx import register_notrace_function
__all__ = ['Hiera']
def conv_nd(n: int) -> Type[nn.Module]: def conv_nd(n: int) -> Type[nn.Module]:
""" """
Returns a conv with nd (e.g., Conv2d for n=2). Work up to n=3. Returns a conv with nd (e.g., Conv2d for n=2). Work up to n=3.

View File

@ -20,6 +20,8 @@ from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model from ._registry import generate_default_cfgs, register_model
__all__ = ['NextViT']
def merge_pre_bn(module, pre_bn_1, pre_bn_2=None): def merge_pre_bn(module, pre_bn_1, pre_bn_2=None):
""" Merge pre BN to reduce inference runtime. """ Merge pre BN to reduce inference runtime.