diff --git a/.gitignore b/.gitignore
index ea4657e6..190d49e7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -122,6 +122,7 @@ venv.bak/
 *.log.json
 /work_dirs
 /mmcls/.mim
+.DS_Store
 
 # Pytorch
 *.pth
diff --git a/README.md b/README.md
index 0e1ca76d..a0ff08fd 100644
--- a/README.md
+++ b/README.md
@@ -138,6 +138,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
 - [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
 - [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
 - [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hrnet)
+- [x] [VAN](https://github.com/open-mmlab/mmclassification/tree/master/configs/van)
 - [x] [ConvMixer](https://github.com/open-mmlab/mmclassification/tree/master/configs/convmixer)
 - [x] [CSPNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/cspnet)
 - [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/master/configs/poolformer)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index f80d7992..0354d578 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -136,6 +136,7 @@ pip3 install -e .
 - [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
 - [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
 - [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hrnet)
+- [x] [VAN](https://github.com/open-mmlab/mmclassification/tree/master/configs/van)
 - [x] [ConvMixer](https://github.com/open-mmlab/mmclassification/tree/master/configs/convmixer)
 - [x] [CSPNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/cspnet)
 - [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/master/configs/poolformer)
diff --git a/configs/_base_/models/van/van_base.py b/configs/_base_/models/van/van_base.py
new file mode 100644
index 00000000..00645925
--- /dev/null
+++ b/configs/_base_/models/van/van_base.py
@@ -0,0 +1,13 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(type='VAN', arch='base', drop_path_rate=0.1),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=512,
+        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+        cal_acc=False))
diff --git a/configs/_base_/models/van/van_large.py b/configs/_base_/models/van/van_large.py
new file mode 100644
index 00000000..4ebafabd
--- /dev/null
+++ b/configs/_base_/models/van/van_large.py
@@ -0,0 +1,13 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(type='VAN', arch='large', drop_path_rate=0.2),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=512,
+        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+        cal_acc=False))
diff --git a/configs/_base_/models/van/van_small.py b/configs/_base_/models/van/van_small.py
new file mode 100644
index 00000000..320e90af
--- /dev/null
+++ b/configs/_base_/models/van/van_small.py
@@ -0,0 +1,21 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(type='VAN', arch='small', drop_path_rate=0.1),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=512,
+        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+        cal_acc=False),
+    init_cfg=[
+        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
+    ],
+    train_cfg=dict(augments=[
+        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+    ]))
diff --git a/configs/_base_/models/van/van_tiny.py b/configs/_base_/models/van/van_tiny.py
new file mode 100644
index 00000000..42791ac3
--- /dev/null
+++ b/configs/_base_/models/van/van_tiny.py
@@ -0,0 +1,21 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(type='VAN', arch='tiny', drop_path_rate=0.1),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=256,
+        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+        cal_acc=False),
+    init_cfg=[
+        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
+    ],
+    train_cfg=dict(augments=[
+        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+    ]))
diff --git a/configs/van/README.md b/configs/van/README.md
new file mode 100644
index 00000000..b356621c
--- /dev/null
+++ b/configs/van/README.md
@@ -0,0 +1,37 @@
+# Visual Attention Network
+
+> [Visual Attention Network](https://arxiv.org/pdf/2202.09741v2.pdf)
+<!-- [ALGORITHM] -->
+
+## Abstract
+
+While originally designed for natural language processing (NLP) tasks, the self-attention mechanism has recently taken various computer vision areas by storm. However, the 2D nature of images brings three challenges for applying self-attention in computer vision. (1) Treating images as 1D sequences neglects their 2D structures. (2) The quadratic complexity is too expensive for high-resolution images. (3) It only captures spatial adaptability but ignores channel adaptability. In this paper, we propose a novel large kernel attention (LKA) module to enable self-adaptive and long-range correlations in self-attention while avoiding the above issues. We further introduce a novel neural network based on LKA, namely Visual Attention Network (VAN). While extremely simple and efficient, VAN outperforms the state-of-the-art vision transformers and convolutional neural networks with a large margin in extensive experiments, including image classification, object detection, semantic segmentation, instance segmentation, etc.
+
+<div align=center>
+<img src="https://user-images.githubusercontent.com/24734142/157409411-2f622ba7-553c-4702-91be-eba03f9ea04f.png" width="80%"/>
+</div>
+
+
+## Results and models
+
+### ImageNet-1k
+
+|   Model   |   Pretrain   | resolution  | Params(M) |  Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
+|:---------:|:------------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:------:|:--------:|
+|  VAN-T\*   | From scratch |   224x224   |   4.11   |    0.88   |   75.41   |   93.02   |  [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-tiny_8xb128_in1k.py)  | [model](https://download.openmmlab.com/mmclassification/v0/van/van-tiny_8xb128_in1k_20220427-8ac0feec.pth)  |
+|  VAN-S\*   | From scratch |   224x224   |   13.86   |    2.52   |   81.01   |    95.63   | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-small_8xb128_in1k.py)  | [model](https://download.openmmlab.com/mmclassification/v0/van/van-small_8xb128_in1k_20220427-bd6a9edd.pth)  |
+|  VAN-B\*   | From scratch |   224x224   |   26.58   |   5.03   |   82.80   |    96.21   |   [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-base_8xb128_in1k.py)  | [model](https://download.openmmlab.com/mmclassification/v0/van/van-base_8xb128_in1k_20220427-5275471d.pth)  |
+|  VAN-L\* | From scratch |   224x224   |   44.77   |    8.99   |   83.86   |    96.73   |   [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-large_8xb128_in1k.py)  | [model](https://download.openmmlab.com/mmclassification/v0/van/van-large_8xb128_in1k_20220427-56159105.pth)  |
+
+*Models with \* are converted from [the official repo](https://github.com/Visual-Attention-Network/VAN-Classification). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.
+
+## Citation
+
+```
+@article{guo2022visual,
+  title={Visual Attention Network},
+  author={Guo, Meng-Hao and Lu, Cheng-Ze and Liu, Zheng-Ning and Cheng, Ming-Ming and Hu, Shi-Min},
+  journal={arXiv preprint arXiv:2202.09741},
+  year={2022}
+}
+```
diff --git a/configs/van/metafile.yml b/configs/van/metafile.yml
new file mode 100644
index 00000000..26b40558
--- /dev/null
+++ b/configs/van/metafile.yml
@@ -0,0 +1,70 @@
+Collections:
+  - Name: Visual-Attention-Network
+    Metadata:
+      Training Data: ImageNet-1k
+      Training Techniques:
+        - AdamW
+        - Weight Decay
+      Architecture:
+        - Visual Attention Network
+    Paper:
+      URL: https://arxiv.org/pdf/2202.09741v2.pdf
+      Title: "Visual Attention Network"
+    README: configs/van/README.md
+    Code:
+      URL: https://github.com/open-mmlab/mmclassification/blob/v0.23.0/mmcls/models/backbones/van.py
+      Version: v0.23.0
+
+Models:
+  - Name: van-tiny_8xb128_in1k
+    Metadata:
+      FLOPs: 4110000      # 4.11M
+      Parameters: 880000000   # 0.88G
+    In Collection: Visual-Attention-Network
+    Results:
+      - Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 75.41
+          Top 5 Accuracy: 93.02
+        Task: Image Classification
+    Weights: https://download.openmmlab.com/mmclassification/v0/van/van-tiny_8xb128_in1k_20220427-8ac0feec.pth
+    Config: configs/van/van-tiny_8xb128_in1k.py
+  - Name: van-small_8xb128_in1k
+    Metadata:
+      FLOPs:  13860000          # 13.86M
+      Parameters: 2520000000    # 2.52G
+    In Collection: Visual-Attention-Network
+    Results:
+        - Dataset: ImageNet-1k
+          Metrics:
+            Top 1 Accuracy: 81.01
+            Top 5 Accuracy: 95.63
+          Task: Image Classification
+    Weights: https://download.openmmlab.com/mmclassification/v0/van/van-small_8xb128_in1k_20220427-bd6a9edd.pth
+    Config: configs/van/van-small_8xb128_in1k.py
+  - Name: van-base_8xb128_in1k
+    Metadata:
+      FLOPs: 26580000            # 26.58M
+      Parameters: 5030000000                # 5.03G
+    In Collection: Visual-Attention-Network
+    Results:
+        - Dataset: ImageNet-1k
+          Metrics:
+            Top 1 Accuracy: 82.80
+            Top 5 Accuracy: 96.21
+          Task: Image Classification
+    Weights: https://download.openmmlab.com/mmclassification/v0/van/van-base_8xb128_in1k_20220427-5275471d.pth
+    Config: configs/van/van-base_8xb128_in1k.py
+  - Name: van-large_8xb128_in1k
+    Metadata:
+      FLOPs: 44770000              # 44.77 M
+      Parameters: 8990000000              # 8.99G
+    In Collection: Visual-Attention-Network
+    Results:
+        - Dataset: ImageNet-1k
+          Metrics:
+            Top 1 Accuracy: 83.86
+            Top 5 Accuracy: 96.73
+          Task: Image Classification
+    Weights: https://download.openmmlab.com/mmclassification/v0/van/van-large_8xb128_in1k_20220427-56159105.pth
+    Config: configs/van/van-large_8xb128_in1k.py
diff --git a/configs/van/van-base_8xb128_in1k.py b/configs/van/van-base_8xb128_in1k.py
new file mode 100644
index 00000000..704f111b
--- /dev/null
+++ b/configs/van/van-base_8xb128_in1k.py
@@ -0,0 +1,61 @@
+_base_ = [
+    '../_base_/models/van/van_base.py',
+    '../_base_/datasets/imagenet_bs64_swin_224.py',
+    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+    '../_base_/default_runtime.py'
+]
+
+# Note that the mean and variance used here are different from other configs
+img_norm_cfg = dict(
+    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        size=224,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
+    dict(
+        type='RandAugment',
+        policies={{_base_.rand_increasing_policies}},
+        num_policies=2,
+        total_level=10,
+        magnitude_level=9,
+        magnitude_std=0.5,
+        hparams=dict(
+            pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
+            interpolation='bicubic')),
+    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
+    dict(
+        type='RandomErasing',
+        erase_prob=0.25,
+        mode='rand',
+        min_area_ratio=0.02,
+        max_area_ratio=1 / 3,
+        fill_color=img_norm_cfg['mean'][::-1],
+        fill_std=img_norm_cfg['std'][::-1]),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='ImageToTensor', keys=['img']),
+    dict(type='ToTensor', keys=['gt_label']),
+    dict(type='Collect', keys=['img', 'gt_label'])
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='Resize',
+        size=(248, -1),
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='ImageToTensor', keys=['img']),
+    dict(type='Collect', keys=['img'])
+]
+
+data = dict(
+    samples_per_gpu=128,
+    train=dict(pipeline=train_pipeline),
+    val=dict(pipeline=test_pipeline),
+    test=dict(pipeline=test_pipeline))
diff --git a/configs/van/van-large_8xb128_in1k.py b/configs/van/van-large_8xb128_in1k.py
new file mode 100644
index 00000000..b55aff16
--- /dev/null
+++ b/configs/van/van-large_8xb128_in1k.py
@@ -0,0 +1,61 @@
+_base_ = [
+    '../_base_/models/van/van_large.py',
+    '../_base_/datasets/imagenet_bs64_swin_224.py',
+    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+    '../_base_/default_runtime.py'
+]
+
+# Note that the mean and variance used here are different from other configs
+img_norm_cfg = dict(
+    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        size=224,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
+    dict(
+        type='RandAugment',
+        policies={{_base_.rand_increasing_policies}},
+        num_policies=2,
+        total_level=10,
+        magnitude_level=9,
+        magnitude_std=0.5,
+        hparams=dict(
+            pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
+            interpolation='bicubic')),
+    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
+    dict(
+        type='RandomErasing',
+        erase_prob=0.25,
+        mode='rand',
+        min_area_ratio=0.02,
+        max_area_ratio=1 / 3,
+        fill_color=img_norm_cfg['mean'][::-1],
+        fill_std=img_norm_cfg['std'][::-1]),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='ImageToTensor', keys=['img']),
+    dict(type='ToTensor', keys=['gt_label']),
+    dict(type='Collect', keys=['img', 'gt_label'])
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='Resize',
+        size=(248, -1),
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='ImageToTensor', keys=['img']),
+    dict(type='Collect', keys=['img'])
+]
+
+data = dict(
+    samples_per_gpu=128,
+    train=dict(pipeline=train_pipeline),
+    val=dict(pipeline=test_pipeline),
+    test=dict(pipeline=test_pipeline))
diff --git a/configs/van/van-small_8xb128_in1k.py b/configs/van/van-small_8xb128_in1k.py
new file mode 100644
index 00000000..3b83e25a
--- /dev/null
+++ b/configs/van/van-small_8xb128_in1k.py
@@ -0,0 +1,61 @@
+_base_ = [
+    '../_base_/models/van/van_small.py',
+    '../_base_/datasets/imagenet_bs64_swin_224.py',
+    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+    '../_base_/default_runtime.py'
+]
+
+# Note that the mean and variance used here are different from other configs
+img_norm_cfg = dict(
+    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        size=224,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
+    dict(
+        type='RandAugment',
+        policies={{_base_.rand_increasing_policies}},
+        num_policies=2,
+        total_level=10,
+        magnitude_level=9,
+        magnitude_std=0.5,
+        hparams=dict(
+            pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
+            interpolation='bicubic')),
+    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
+    dict(
+        type='RandomErasing',
+        erase_prob=0.25,
+        mode='rand',
+        min_area_ratio=0.02,
+        max_area_ratio=1 / 3,
+        fill_color=img_norm_cfg['mean'][::-1],
+        fill_std=img_norm_cfg['std'][::-1]),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='ImageToTensor', keys=['img']),
+    dict(type='ToTensor', keys=['gt_label']),
+    dict(type='Collect', keys=['img', 'gt_label'])
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='Resize',
+        size=(248, -1),
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='ImageToTensor', keys=['img']),
+    dict(type='Collect', keys=['img'])
+]
+
+data = dict(
+    samples_per_gpu=128,
+    train=dict(pipeline=train_pipeline),
+    val=dict(pipeline=test_pipeline),
+    test=dict(pipeline=test_pipeline))
diff --git a/configs/van/van-tiny_8xb128_in1k.py b/configs/van/van-tiny_8xb128_in1k.py
new file mode 100644
index 00000000..1e001c1c
--- /dev/null
+++ b/configs/van/van-tiny_8xb128_in1k.py
@@ -0,0 +1,61 @@
+_base_ = [
+    '../_base_/models/van/van_tiny.py',
+    '../_base_/datasets/imagenet_bs64_swin_224.py',
+    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+    '../_base_/default_runtime.py'
+]
+
+# Note that the mean and variance used here are different from other configs
+img_norm_cfg = dict(
+    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        size=224,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
+    dict(
+        type='RandAugment',
+        policies={{_base_.rand_increasing_policies}},
+        num_policies=2,
+        total_level=10,
+        magnitude_level=9,
+        magnitude_std=0.5,
+        hparams=dict(
+            pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
+            interpolation='bicubic')),
+    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
+    dict(
+        type='RandomErasing',
+        erase_prob=0.25,
+        mode='rand',
+        min_area_ratio=0.02,
+        max_area_ratio=1 / 3,
+        fill_color=img_norm_cfg['mean'][::-1],
+        fill_std=img_norm_cfg['std'][::-1]),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='ImageToTensor', keys=['img']),
+    dict(type='ToTensor', keys=['gt_label']),
+    dict(type='Collect', keys=['img', 'gt_label'])
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='Resize',
+        size=(248, -1),
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='ImageToTensor', keys=['img']),
+    dict(type='Collect', keys=['img'])
+]
+
+data = dict(
+    samples_per_gpu=128,
+    train=dict(pipeline=train_pipeline),
+    val=dict(pipeline=test_pipeline),
+    test=dict(pipeline=test_pipeline))
diff --git a/docs/en/api/models.rst b/docs/en/api/models.rst
index 7b953024..687b8009 100644
--- a/docs/en/api/models.rst
+++ b/docs/en/api/models.rst
@@ -83,6 +83,7 @@ Backbones
    T2T_ViT
    TIMMBackbone
    TNT
+   VAN
    VGG
    VisionTransformer
 
diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md
index 707ad094..6451b70a 100644
--- a/docs/en/model_zoo.md
+++ b/docs/en/model_zoo.md
@@ -133,6 +133,10 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
 | CSPDarkNet50\*  |  27.64 | 5.04 | 80.05 | 95.07  | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/cspnet/cspdarknet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/cspnet/cspdarknet50_3rdparty_8xb32_in1k_20220329-bd275287.pth) |
 |  CSPResNet50\*  |  21.62 | 3.48 | 79.55 | 94.68  | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/cspnet/cspresnet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/cspnet/cspresnet50_3rdparty_8xb32_in1k_20220329-dd6dddfb.pth) |
 |  CSPResNeXt50\* | 20.57 | 3.11 | 79.96 | 94.96 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/cspnet/cspresnext50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/cspnet/cspresnext50_3rdparty_8xb32_in1k_20220329-2cc84d21.pth) |
+|  VAN-T\*   |   4.11   |    0.88   |   75.41   |   93.02   |  [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-tiny_8xb128_in1k.py)  | [model](https://download.openmmlab.com/mmclassification/v0/van/van-tiny_8xb128_in1k_20220427-8ac0feec.pth)  |
+|  VAN-S\*   |   13.86   |    2.52   |   81.01   |    95.63   | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-small_8xb128_in1k.py)  | [model](https://download.openmmlab.com/mmclassification/v0/van/van-small_8xb128_in1k_20220427-bd6a9edd.pth)  |
+|  VAN-B\*   |   26.58   |   5.03   |   82.80   |    96.21   |   [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-base_8xb128_in1k.py)  | [model](https://download.openmmlab.com/mmclassification/v0/van/van-base_8xb128_in1k_20220427-5275471d.pth)  |
+|  VAN-L\*   |   44.77   |    8.99   |   83.86   |    96.73   |   [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-large_8xb128_in1k.py)  | [model](https://download.openmmlab.com/mmclassification/v0/van/van-large_8xb128_in1k_20220427-56159105.pth)  |
 
 *Models with \* are converted from other repos, others are trained by ourselves.*
 
diff --git a/docs/zh_CN/changelog.md b/docs/zh_CN/changelog.md
index 6b731cd0..4444256e 120000
--- a/docs/zh_CN/changelog.md
+++ b/docs/zh_CN/changelog.md
@@ -1 +1 @@
-../en/changelog.md
\ No newline at end of file
+../en/changelog.md
diff --git a/docs/zh_CN/model_zoo.md b/docs/zh_CN/model_zoo.md
index 013a9acc..df350774 120000
--- a/docs/zh_CN/model_zoo.md
+++ b/docs/zh_CN/model_zoo.md
@@ -1 +1 @@
-../en/model_zoo.md
\ No newline at end of file
+../en/model_zoo.md
diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index 2d72b662..13f565f4 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -29,46 +29,17 @@ from .t2t_vit import T2T_ViT
 from .timm_backbone import TIMMBackbone
 from .tnt import TNT
 from .twins import PCPVT, SVT
+from .van import VAN
 from .vgg import VGG
 from .vision_transformer import VisionTransformer
 
 __all__ = [
-    'LeNet5',
-    'AlexNet',
-    'VGG',
-    'RegNet',
-    'ResNet',
-    'ResNeXt',
-    'ResNetV1d',
-    'ResNeSt',
-    'ResNet_CIFAR',
-    'SEResNet',
-    'SEResNeXt',
-    'ShuffleNetV1',
-    'ShuffleNetV2',
-    'MobileNetV2',
-    'MobileNetV3',
-    'VisionTransformer',
-    'SwinTransformer',
-    'TNT',
-    'TIMMBackbone',
-    'T2T_ViT',
-    'Res2Net',
-    'RepVGG',
-    'Conformer',
-    'MlpMixer',
-    'DistilledVisionTransformer',
-    'PCPVT',
-    'SVT',
-    'EfficientNet',
-    'ConvNeXt',
-    'HRNet',
-    'ResNetV1c',
-    'ConvMixer',
-    'CSPDarkNet',
-    'CSPResNet',
-    'CSPResNeXt',
-    'CSPNet',
-    'RepMLPNet',
-    'PoolFormer',
+    'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
+    'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
+    'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
+    'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
+    'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
+    'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c', 'ConvMixer',
+    'CSPDarkNet', 'CSPResNet', 'CSPResNeXt', 'CSPNet', 'RepMLPNet',
+    'PoolFormer', 'VAN'
 ]
diff --git a/mmcls/models/backbones/van.py b/mmcls/models/backbones/van.py
new file mode 100644
index 00000000..4022cc0d
--- /dev/null
+++ b/mmcls/models/backbones/van.py
@@ -0,0 +1,434 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
+from mmcv.cnn.bricks import DropPath
+from mmcv.cnn.bricks.transformer import PatchEmbed
+from mmcv.runner import BaseModule, ModuleList
+from mmcv.utils.parrots_wrapper import _BatchNorm
+
+from ..builder import BACKBONES
+from .base_backbone import BaseBackbone
+
+
+class MixFFN(BaseModule):
+    """An implementation of MixFFN of VAN. Refer to
+    mmdetection/mmdet/models/backbones/pvt.py.
+
+    The differences between MixFFN & FFN:
+        1. Use 1X1 Conv to replace Linear layer.
+        2. Introduce 3X3 Depth-wise Conv to encode positional information.
+
+    Args:
+        embed_dims (int): The feature dimension. Same as
+            `MultiheadAttention`.
+        feedforward_channels (int): The hidden dimension of FFNs.
+        act_cfg (dict, optional): The activation config for FFNs.
+            Default: dict(type='GELU').
+        ffn_drop (float, optional): Probability of an element to be
+            zeroed in FFN. Default 0.0.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 feedforward_channels,
+                 act_cfg=dict(type='GELU'),
+                 ffn_drop=0.,
+                 init_cfg=None):
+        super(MixFFN, self).__init__(init_cfg=init_cfg)
+
+        self.embed_dims = embed_dims
+        self.feedforward_channels = feedforward_channels
+        self.act_cfg = act_cfg
+
+        self.fc1 = Conv2d(
+            in_channels=embed_dims,
+            out_channels=feedforward_channels,
+            kernel_size=1)
+        self.dwconv = Conv2d(
+            in_channels=feedforward_channels,
+            out_channels=feedforward_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=True,
+            groups=feedforward_channels)
+        self.act = build_activation_layer(act_cfg)
+        self.fc2 = Conv2d(
+            in_channels=feedforward_channels,
+            out_channels=embed_dims,
+            kernel_size=1)
+        self.drop = nn.Dropout(ffn_drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.dwconv(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class LKA(BaseModule):
+    """Large Kernel Attention(LKA) of VAN.
+
+    .. code:: text
+            DW_conv (depth-wise convolution)
+                            |
+                            |
+        DW_D_conv (depth-wise dilation convolution)
+                            |
+                            |
+        Transition Convolution (1×1 convolution)
+
+    Args:
+        embed_dims (int): Number of input channels.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self, embed_dims, init_cfg=None):
+        super(LKA, self).__init__(init_cfg=init_cfg)
+
+        # a spatial local convolution (depth-wise convolution)
+        self.DW_conv = Conv2d(
+            in_channels=embed_dims,
+            out_channels=embed_dims,
+            kernel_size=5,
+            padding=2,
+            groups=embed_dims)
+
+        # a spatial long-range convolution (depth-wise dilation convolution)
+        self.DW_D_conv = Conv2d(
+            in_channels=embed_dims,
+            out_channels=embed_dims,
+            kernel_size=7,
+            stride=1,
+            padding=9,
+            groups=embed_dims,
+            dilation=3)
+
+        self.conv1 = Conv2d(
+            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+
+    def forward(self, x):
+        u = x.clone()
+        attn = self.DW_conv(x)
+        attn = self.DW_D_conv(attn)
+        attn = self.conv1(attn)
+
+        return u * attn
+
+
+class SpatialAttention(BaseModule):
+    """Basic attention module in VANBloack.
+
+    Args:
+        embed_dims (int): Number of input channels.
+        act_cfg (dict, optional): The activation config for FFNs.
+            Default: dict(type='GELU').
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self, embed_dims, act_cfg=dict(type='GELU'), init_cfg=None):
+        super(SpatialAttention, self).__init__(init_cfg=init_cfg)
+
+        self.proj_1 = Conv2d(
+            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+        self.activation = build_activation_layer(act_cfg)
+        self.spatial_gating_unit = LKA(embed_dims)
+        self.proj_2 = Conv2d(
+            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+
+    def forward(self, x):
+        shorcut = x.clone()
+        x = self.proj_1(x)
+        x = self.activation(x)
+        x = self.spatial_gating_unit(x)
+        x = self.proj_2(x)
+        x = x + shorcut
+        return x
+
+
+class VANBlock(BaseModule):
+    """A block of VAN.
+
+    Args:
+        embed_dims (int): Number of input channels.
+        ffn_ratio (float): The expansion ratio of feedforward network hidden
+            layer channels. Defaults to 4.
+        drop_rate (float): Dropout rate after embedding. Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
+        act_cfg (dict, optional): The activation config for FFNs.
+            Default: dict(type='GELU').
+        layer_scale_init_value (float): Init value for Layer Scale.
+            Defaults to 1e-2.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 ffn_ratio=4.,
+                 drop_rate=0.,
+                 drop_path_rate=0.,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='BN', eps=1e-5),
+                 layer_scale_init_value=1e-2,
+                 init_cfg=None):
+        super(VANBlock, self).__init__(init_cfg=init_cfg)
+        self.out_channels = embed_dims
+
+        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
+        self.attn = SpatialAttention(embed_dims, act_cfg=act_cfg)
+        self.drop_path = DropPath(
+            drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
+        mlp_hidden_dim = int(embed_dims * ffn_ratio)
+        self.mlp = MixFFN(
+            embed_dims=embed_dims,
+            feedforward_channels=mlp_hidden_dim,
+            act_cfg=act_cfg,
+            ffn_drop=drop_rate)
+        self.layer_scale_1 = nn.Parameter(
+            layer_scale_init_value * torch.ones((embed_dims)),
+            requires_grad=True) if layer_scale_init_value > 0 else None
+        self.layer_scale_2 = nn.Parameter(
+            layer_scale_init_value * torch.ones((embed_dims)),
+            requires_grad=True) if layer_scale_init_value > 0 else None
+
+    def forward(self, x):
+        identity = x
+        x = self.norm1(x)
+        x = self.attn(x)
+        if self.layer_scale_1 is not None:
+            x = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * x
+        x = identity + self.drop_path(x)
+
+        identity = x
+        x = self.norm2(x)
+        x = self.mlp(x)
+        if self.layer_scale_2 is not None:
+            x = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * x
+        x = identity + self.drop_path(x)
+
+        return x
+
+
+class VANPatchEmbed(PatchEmbed):
+    """Image to Patch Embedding of VAN.
+
+    The differences between VANPatchEmbed & PatchEmbed:
+        1. Use BN.
+        2. Do not use 'flatten' and 'transpose'.
+    """
+
+    def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs):
+        super(VANPatchEmbed, self).__init__(*args, norm_cfg=norm_cfg, **kwargs)
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
+        Returns:
+            tuple: Contains merged results and its spatial shape.
+            - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
+            - out_size (tuple[int]): Spatial shape of x, arrange as
+              (out_h, out_w).
+        """
+
+        if self.adaptive_padding:
+            x = self.adaptive_padding(x)
+
+        x = self.projection(x)
+        out_size = (x.shape[2], x.shape[3])
+        if self.norm is not None:
+            x = self.norm(x)
+        return x, out_size
+
+
+@BACKBONES.register_module()
+class VAN(BaseBackbone):
+    """Visual Attention Network.
+
+    A PyTorch implement of : `Visual Attention Network
+    <https://arxiv.org/pdf/2202.09741v2.pdf>`_
+
+    Inspiration from
+    https://github.com/Visual-Attention-Network/VAN-Classification
+
+    Args:
+        arch (str | dict): Visual Attention Network architecture.
+            If use string, choose from 'tiny', 'small', 'base' and 'large'.
+            If use dict, it should have below keys:
+
+            - **embed_dims** (List[int]): The dimensions of embedding.
+            - **depths** (List[int]): The number of blocks in each stage.
+            - **ffn_ratios** (List[int]): The number of expansion ratio of
+            feedforward network hidden layer channels.
+
+            Defaults to 'tiny'.
+        patch_sizes (List[int | tuple]): The patch size in patch embeddings.
+            Defaults to [7, 3, 3, 3].
+        in_channels (int): The num of input channels. Defaults to 3.
+        drop_rate (float): Dropout rate after embedding. Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
+        out_indices (Sequence[int]): Output from which stages.
+            Default: ``(3, )``.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Defaults to -1.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Defaults to False.
+        norm_cfg (dict): Config dict for normalization layer for all output
+            features. Defaults to ``dict(type='LN')``
+        block_cfgs (Sequence[dict] | dict): The extra config of each block.
+            Defaults to empty dicts.
+        init_cfg (dict, optional): The Config for initialization.
+            Defaults to None.
+
+    Examples:
+        >>> from mmcls.models import VAN
+        >>> import torch
+        >>> cfg = dict(arch='tiny')
+        >>> model = VAN(**cfg)
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outputs = model(inputs)
+        >>> for out in outputs:
+        >>>     print(out.size())
+        (1, 256, 7, 7)
+    """
+    arch_zoo = {
+        **dict.fromkeys(['t', 'tiny'],
+                        {'embed_dims': [32, 64, 160, 256],
+                         'depths': [3, 3, 5, 2],
+                         'ffn_ratios': [8, 8, 4, 4]}),
+        **dict.fromkeys(['s', 'small'],
+                        {'embed_dims': [64, 128, 320, 512],
+                         'depths': [2, 2, 4, 2],
+                         'ffn_ratios': [8, 8, 4, 4]}),
+        **dict.fromkeys(['b', 'base'],
+                        {'embed_dims': [64, 128, 320, 512],
+                         'depths': [3, 3, 12, 3],
+                         'ffn_ratios': [8, 8, 4, 4]}),
+        **dict.fromkeys(['l', 'large'],
+                        {'embed_dims': [64, 128, 320, 512],
+                         'depths': [3, 5, 27, 3],
+                         'ffn_ratios': [8, 8, 4, 4]}),
+    }  # yapf: disable
+
+    def __init__(self,
+                 arch='tiny',
+                 patch_sizes=[7, 3, 3, 3],
+                 in_channels=3,
+                 drop_rate=0.,
+                 drop_path_rate=0.,
+                 out_indices=(3, ),
+                 frozen_stages=-1,
+                 norm_eval=False,
+                 norm_cfg=dict(type='LN'),
+                 block_cfgs=dict(),
+                 init_cfg=None):
+        super(VAN, self).__init__(init_cfg=init_cfg)
+
+        if isinstance(arch, str):
+            arch = arch.lower()
+            assert arch in set(self.arch_zoo), \
+                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+            self.arch_settings = self.arch_zoo[arch]
+        else:
+            essential_keys = {'embed_dims', 'depths', 'ffn_ratios'}
+            assert isinstance(arch, dict) and set(arch) == essential_keys, \
+                f'Custom arch needs a dict with keys {essential_keys}'
+            self.arch_settings = arch
+
+        self.embed_dims = self.arch_settings['embed_dims']
+        self.depths = self.arch_settings['depths']
+        self.ffn_ratios = self.arch_settings['ffn_ratios']
+        self.num_stages = len(self.depths)
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+        self.norm_eval = norm_eval
+
+        total_depth = sum(self.depths)
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
+        ]  # stochastic depth decay rule
+
+        cur_block_idx = 0
+        for i, depth in enumerate(self.depths):
+            patch_embed = VANPatchEmbed(
+                in_channels=in_channels if i == 0 else self.embed_dims[i - 1],
+                input_size=None,
+                embed_dims=self.embed_dims[i],
+                kernel_size=patch_sizes[i],
+                stride=patch_sizes[i] // 2 + 1,
+                padding=(patch_sizes[i] // 2, patch_sizes[i] // 2),
+                norm_cfg=dict(type='BN'))
+
+            blocks = ModuleList([
+                VANBlock(
+                    embed_dims=self.embed_dims[i],
+                    ffn_ratio=self.ffn_ratios[i],
+                    drop_rate=drop_rate,
+                    drop_path_rate=dpr[cur_block_idx + j],
+                    **block_cfgs) for j in range(depth)
+            ])
+            cur_block_idx += depth
+            norm = build_norm_layer(norm_cfg, self.embed_dims[i])[1]
+
+            self.add_module(f'patch_embed{i + 1}', patch_embed)
+            self.add_module(f'blocks{i + 1}', blocks)
+            self.add_module(f'norm{i + 1}', norm)
+
+    def train(self, mode=True):
+        super(VAN, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval have effect on BatchNorm only
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+
+    def _freeze_stages(self):
+        for i in range(0, self.frozen_stages + 1):
+            # freeze patch embed
+            m = getattr(self, f'patch_embed{i + 1}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+            # freeze blocks
+            m = getattr(self, f'blocks{i + 1}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+            # freeze norm
+            m = getattr(self, f'norm{i + 1}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+    def forward(self, x):
+        outs = []
+        for i in range(self.num_stages):
+            patch_embed = getattr(self, f'patch_embed{i + 1}')
+            blocks = getattr(self, f'blocks{i + 1}')
+            norm = getattr(self, f'norm{i + 1}')
+            x, hw_shape = patch_embed(x)
+            for block in blocks:
+                x = block(x)
+            x = x.flatten(2).transpose(1, 2)
+            x = norm(x)
+            x = x.reshape(-1, *hw_shape,
+                          block.out_channels).permute(0, 3, 1, 2).contiguous()
+            if i in self.out_indices:
+                outs.append(x)
+
+        return tuple(outs)
diff --git a/model-index.yml b/model-index.yml
index 5d5c767a..81932fd6 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -22,6 +22,7 @@ Import:
   - configs/hrnet/metafile.yml
   - configs/repmlp/metafile.yml
   - configs/wrn/metafile.yml
+  - configs/van/metafile.yml
   - configs/cspnet/metafile.yml
   - configs/convmixer/metafile.yml
   - configs/poolformer/metafile.yml
diff --git a/tests/test_models/test_backbones/test_van.py b/tests/test_models/test_backbones/test_van.py
new file mode 100644
index 00000000..136ce973
--- /dev/null
+++ b/tests/test_models/test_backbones/test_van.py
@@ -0,0 +1,188 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from copy import deepcopy
+from itertools import chain
+from unittest import TestCase
+
+import torch
+from mmcv.utils.parrots_wrapper import _BatchNorm
+from torch import nn
+
+from mmcls.models.backbones import VAN
+
+
+def check_norm_state(modules, train_state):
+    """Check if norm layer is in correct train state."""
+    for mod in modules:
+        if isinstance(mod, _BatchNorm):
+            if mod.training != train_state:
+                return False
+    return True
+
+
+class TestVAN(TestCase):
+
+    def setUp(self):
+        self.cfg = dict(arch='t', drop_path_rate=0.1)
+
+    def test_arch(self):
+        # Test invalid default arch
+        with self.assertRaisesRegex(AssertionError, 'not in default archs'):
+            cfg = deepcopy(self.cfg)
+            cfg['arch'] = 'unknown'
+            VAN(**cfg)
+
+        # Test invalid custom arch
+        with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
+            cfg = deepcopy(self.cfg)
+            cfg['arch'] = {
+                'embed_dims': [32, 64, 160, 256],
+                'ffn_ratios': [8, 8, 4, 4],
+            }
+            VAN(**cfg)
+
+        # Test custom arch
+        cfg = deepcopy(self.cfg)
+        embed_dims = [32, 64, 160, 256]
+        depths = [3, 3, 5, 2]
+        ffn_ratios = [8, 8, 4, 4]
+        cfg['arch'] = {
+            'embed_dims': embed_dims,
+            'depths': depths,
+            'ffn_ratios': ffn_ratios
+        }
+        model = VAN(**cfg)
+
+        for i in range(len(depths)):
+            stage = getattr(model, f'blocks{i + 1}')
+            self.assertEqual(stage[-1].out_channels, embed_dims[i])
+            self.assertEqual(len(stage), depths[i])
+
+    def test_init_weights(self):
+        # test weight init cfg
+        cfg = deepcopy(self.cfg)
+        cfg['init_cfg'] = [
+            dict(
+                type='Kaiming',
+                layer='Conv2d',
+                mode='fan_in',
+                nonlinearity='linear')
+        ]
+        model = VAN(**cfg)
+        ori_weight = model.patch_embed1.projection.weight.clone().detach()
+
+        model.init_weights()
+        initialized_weight = model.patch_embed1.projection.weight
+        self.assertFalse(torch.allclose(ori_weight, initialized_weight))
+
+    def test_forward(self):
+        imgs = torch.randn(3, 3, 224, 224)
+
+        cfg = deepcopy(self.cfg)
+        model = VAN(**cfg)
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 1)
+        feat = outs[-1]
+        self.assertEqual(feat.shape, (3, 256, 7, 7))
+
+        # test with patch_sizes
+        cfg = deepcopy(self.cfg)
+        cfg['patch_sizes'] = [7, 5, 5, 5]
+        model = VAN(**cfg)
+        outs = model(torch.randn(3, 3, 224, 224))
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 1)
+        feat = outs[-1]
+        self.assertEqual(feat.shape, (3, 256, 3, 3))
+
+        # test multiple output indices
+        cfg = deepcopy(self.cfg)
+        cfg['out_indices'] = (0, 1, 2, 3)
+        model = VAN(**cfg)
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 4)
+        for emb_size, stride, out in zip([32, 64, 160, 256], [1, 2, 4, 8],
+                                         outs):
+            self.assertEqual(out.shape,
+                             (3, emb_size, 56 // stride, 56 // stride))
+
+        # test with dynamic input shape
+        imgs1 = torch.randn(3, 3, 224, 224)
+        imgs2 = torch.randn(3, 3, 256, 256)
+        imgs3 = torch.randn(3, 3, 256, 309)
+        cfg = deepcopy(self.cfg)
+        model = VAN(**cfg)
+        for imgs in [imgs1, imgs2, imgs3]:
+            outs = model(imgs)
+            self.assertIsInstance(outs, tuple)
+            self.assertEqual(len(outs), 1)
+            feat = outs[-1]
+            expect_feat_shape = (math.ceil(imgs.shape[2] / 32),
+                                 math.ceil(imgs.shape[3] / 32))
+            self.assertEqual(feat.shape, (3, 256, *expect_feat_shape))
+
+    def test_structure(self):
+        # test drop_path_rate decay
+        cfg = deepcopy(self.cfg)
+        cfg['drop_path_rate'] = 0.2
+        model = VAN(**cfg)
+        depths = model.arch_settings['depths']
+        stages = [model.blocks1, model.blocks2, model.blocks3, model.blocks4]
+        blocks = chain(*[stage for stage in stages])
+        total_depth = sum(depths)
+        dpr = [
+            x.item()
+            for x in torch.linspace(0, cfg['drop_path_rate'], total_depth)
+        ]
+        for i, (block, expect_prob) in enumerate(zip(blocks, dpr)):
+            if expect_prob == 0:
+                assert isinstance(block.drop_path, nn.Identity)
+            else:
+                self.assertAlmostEqual(block.drop_path.drop_prob, expect_prob)
+
+        # test VAN with norm_eval=True
+        cfg = deepcopy(self.cfg)
+        cfg['norm_eval'] = True
+        cfg['norm_cfg'] = dict(type='BN')
+        model = VAN(**cfg)
+        model.init_weights()
+        model.train()
+        self.assertTrue(check_norm_state(model.modules(), False))
+
+        # test VAN with first stage frozen.
+        cfg = deepcopy(self.cfg)
+        frozen_stages = 0
+        cfg['frozen_stages'] = frozen_stages
+        cfg['out_indices'] = (0, 1, 2, 3)
+        model = VAN(**cfg)
+        model.init_weights()
+        model.train()
+
+        # the patch_embed and first stage should not require grad.
+        self.assertFalse(model.patch_embed1.training)
+        for param in model.patch_embed1.parameters():
+            self.assertFalse(param.requires_grad)
+        for i in range(frozen_stages + 1):
+            patch = getattr(model, f'patch_embed{i+1}')
+            for param in patch.parameters():
+                self.assertFalse(param.requires_grad)
+            blocks = getattr(model, f'blocks{i + 1}')
+            for param in blocks.parameters():
+                self.assertFalse(param.requires_grad)
+            norm = getattr(model, f'norm{i + 1}')
+            for param in norm.parameters():
+                self.assertFalse(param.requires_grad)
+
+        # the second stage should require grad.
+        for i in range(frozen_stages + 1, 4):
+            patch = getattr(model, f'patch_embed{i + 1}')
+            for param in patch.parameters():
+                self.assertTrue(param.requires_grad)
+            blocks = getattr(model, f'blocks{i+1}')
+            for param in blocks.parameters():
+                self.assertTrue(param.requires_grad)
+            norm = getattr(model, f'norm{i + 1}')
+            for param in norm.parameters():
+                self.assertTrue(param.requires_grad)
diff --git a/tools/convert_models/van2mmcls.py b/tools/convert_models/van2mmcls.py
new file mode 100644
index 00000000..5ea7d9ca
--- /dev/null
+++ b/tools/convert_models/van2mmcls.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os.path as osp
+from collections import OrderedDict
+
+import mmcv
+import torch
+from mmcv.runner import CheckpointLoader
+
+
+def convert_van(ckpt):
+
+    new_ckpt = OrderedDict()
+
+    for k, v in list(ckpt.items()):
+        new_v = v
+        if k.startswith('head'):
+            new_k = k.replace('head.', 'head.fc.')
+            new_ckpt[new_k] = new_v
+            continue
+        elif k.startswith('patch_embed'):
+            if 'proj.' in k:
+                new_k = k.replace('proj.', 'projection.')
+            else:
+                new_k = k
+        elif k.startswith('block'):
+            new_k = k.replace('block', 'blocks')
+            if 'attn.spatial_gating_unit' in new_k:
+                new_k = new_k.replace('conv0', 'DW_conv')
+                new_k = new_k.replace('conv_spatial', 'DW_D_conv')
+            if 'dwconv.dwconv' in new_k:
+                new_k = new_k.replace('dwconv.dwconv', 'dwconv')
+        else:
+            new_k = k
+
+        if not new_k.startswith('head'):
+            new_k = 'backbone.' + new_k
+        new_ckpt[new_k] = new_v
+    return new_ckpt
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Convert keys in pretrained van models to mmcls style.')
+    parser.add_argument('src', help='src model path or url')
+    # The dst path must be a full path of the new checkpoint.
+    parser.add_argument('dst', help='save path')
+    args = parser.parse_args()
+
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    else:
+        state_dict = checkpoint
+
+    weight = convert_van(state_dict)
+    mmcv.mkdir_or_exist(osp.dirname(args.dst))
+    torch.save(weight, args.dst)
+
+    print('Done!!')
+
+
+if __name__ == '__main__':
+    main()