Fix a few typos, fix fastvit proj_drop, add code link

This commit is contained in:
Ross Wightman 2023-08-28 21:26:29 -07:00
parent fc5d705b83
commit c8b2f28096
2 changed files with 10 additions and 7 deletions

View File

@ -1,11 +1,13 @@
# FastViT for PyTorch
#
# Original implementation and weights from https://github.com/apple/ml-fastvit
#
# For licensing see accompanying LICENSE file at https://github.com/apple/ml-fastvit/tree/main
# Original work is copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import os
from functools import partial
from typing import List, Tuple, Optional, Union
from typing import Tuple, Optional, Union
import torch
import torch.nn as nn
@ -1141,7 +1143,7 @@ class FastVit(nn.Module):
mlp_ratio=mlp_ratios[i],
act_layer=act_layer,
norm_layer=norm_layer,
proj_drop_rate=drop_rate,
proj_drop_rate=proj_drop_rate,
drop_path_rate=dpr[i],
layer_scale_init_value=layer_scale_init_value,
lkc_use_act=lkc_use_act,

View File

@ -1,5 +1,6 @@
"""
InceptionNeXt implementation, paper: https://arxiv.org/abs/2303.16900
InceptionNeXt paper: https://arxiv.org/abs/2303.16900
Original implementation & weights from: https://github.com/sail-sg/inceptionnext
"""
from functools import partial
@ -8,14 +9,14 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, to_2tuple, create_conv2d, get_padding, SelectAdaptivePool2d
from timm.layers import trunc_normal_, DropPath, to_2tuple, get_padding, SelectAdaptivePool2d
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
class InceptionDWConv2d(nn.Module):
""" Inception depthweise convolution
""" Inception depthwise convolution
"""
def __init__(
@ -219,7 +220,7 @@ class MetaNeXtStage(nn.Module):
class MetaNeXt(nn.Module):
r""" MetaNeXt
A PyTorch impl of : `InceptionNeXt: When Inception Meets ConvNeXt` - https://arxiv.org/pdf/2203.xxxxx.pdf
A PyTorch impl of : `InceptionNeXt: When Inception Meets ConvNeXt` - https://arxiv.org/abs/2303.16900
Args:
in_chans (int): Number of input image channels. Default: 3
@ -227,7 +228,7 @@ class MetaNeXt(nn.Module):
depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 9, 3)
dims (tuple(int)): Feature dimension at each stage. Default: (96, 192, 384, 768)
token_mixers: Token mixer function. Default: nn.Identity
norm_layer: Normalziation layer. Default: nn.BatchNorm2d
norm_layer: Normalization layer. Default: nn.BatchNorm2d
act_layer: Activation function for MLP. Default: nn.GELU
mlp_ratios (int or tuple(int)): MLP ratios. Default: (4, 4, 4, 3)
head_fn: classifier head