[Feature] Support ConvMixer. (#716)
* basic support for ConvMixer * simplify * add data pipeine config for timm * Add model readme and metafile * add unittest for convmixer * add copyright * modify * add tests * update model * add conv2dAdaptivePadding replacement * update model index * fix comments * Update checkpoint path Co-authored-by: mzr1996 <mzr1996@163.com>pull/745/head
parent
3482521587
commit
04cb42a768
|
@ -126,6 +126,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
|
|||
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
|
||||
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
|
||||
- [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hrnet)
|
||||
- [x] [ConvMixer](https://github.com/open-mmlab/mmclassification/tree/master/configs/convmixer)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
@ -124,6 +124,7 @@ MMClassification 是一款基于 PyTorch 的开源图像分类工具箱,是 [O
|
|||
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
|
||||
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
|
||||
- [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hrnet)
|
||||
- [x] [ConvMixer](https://github.com/open-mmlab/mmclassification/tree/master/configs/convmixer)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
_base_ = ['./pipelines/rand_aug.py']
|
||||
|
||||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
size=224,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies={{_base_.rand_increasing_policies}},
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
|
||||
interpolation='bicubic')),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=img_norm_cfg['mean'][::-1],
|
||||
fill_std=img_norm_cfg['std'][::-1]),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label'])
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
size=(233, -1),
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=64,
|
||||
workers_per_gpu=8,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/val',
|
||||
ann_file='data/imagenet/meta/val.txt',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
# replace `data/val` with `data/test` for standard test
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/val',
|
||||
ann_file='data/imagenet/meta/val.txt',
|
||||
pipeline=test_pipeline))
|
||||
|
||||
evaluation = dict(interval=10, metric='accuracy')
|
|
@ -0,0 +1,11 @@
|
|||
# Model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='ConvMixer', arch='1024/20'),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=1024,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
))
|
|
@ -0,0 +1,11 @@
|
|||
# Model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='ConvMixer', arch='1536/20'),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=1536,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
))
|
|
@ -0,0 +1,11 @@
|
|||
# Model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='ConvMixer', arch='768/32', act_cfg=dict(type='ReLU')),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=768,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
))
|
|
@ -0,0 +1,41 @@
|
|||
# ConvMixer
|
||||
|
||||
> [Patches Are All You Need?](https://arxiv.org/abs/2201.09792)
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
<!-- [ABSTRACT] -->
|
||||
Although convolutional networks have been the dominant architecture for vision tasks for many years, recent experiments have shown that Transformer-based models, most notably the Vision Transformer (ViT), may exceed their performance in some settings. However, due to the quadratic runtime of the self-attention layers in Transformers, ViTs require the use of patch embeddings, which group together small regions of the image into single input features, in order to be applied to larger image sizes. This raises a question: Is the performance of ViTs due to the inherently-more-powerful Transformer architecture, or is it at least partly due to using patches as the input representation? In this paper, we present some evidence for the latter: specifically, we propose the ConvMixer, an extremely simple model that is similar in spirit to the ViT and the even-more-basic MLP-Mixer in that it operates directly on patches as input, separates the mixing of spatial and channel dimensions, and maintains equal size and resolution throughout the network. In contrast, however, the ConvMixer uses only standard convolutions to achieve the mixing steps. Despite its simplicity, we show that the ConvMixer outperforms the ViT, MLP-Mixer, and some of their variants for similar parameter counts and data set sizes, in addition to outperforming classical vision models such as the ResNet.
|
||||
|
||||
<!-- [IMAGE] -->
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/42952108/156284977-abf2245e-d9ba-4e0d-8e10-c0664a20f4c8.png" width="100%"/>
|
||||
</div>
|
||||
|
||||
## Results and models
|
||||
|
||||
### ImageNet-1k
|
||||
|
||||
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||
|:---------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
|
||||
| ConvMixer-768/32\* | 21.11 | 19.62 | 80.16 | 95.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convmixer/convmixer-768-32_10xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convmixer/convmixer-768-32_3rdparty_10xb64_in1k_20220323-bca1f7b8.pth) |
|
||||
| ConvMixer-1024/20\* | 24.38 | 5.55 | 76.94 | 93.36 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convmixer/convmixer-1024-20_10xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convmixer/convmixer-1024-20_3rdparty_10xb64_in1k_20220323-48f8aeba.pth) |
|
||||
| ConvMixer-1536/20\* | 51.63 | 48.71 | 81.37 | 95.61 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convmixer/convmixer-1536-20_10xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convmixer/convmixer-1536_20_3rdparty_10xb64_in1k_20220323-ea5786f3.pth) |
|
||||
|
||||
|
||||
*Models with \* are converted from the [official repo](https://github.com/locuslab/convmixer). The config files of these models are only for inference. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{trockman2022patches,
|
||||
title={Patches Are All You Need?},
|
||||
author={Asher Trockman and J. Zico Kolter},
|
||||
year={2022},
|
||||
eprint={2201.09792},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,10 @@
|
|||
_base_ = [
|
||||
'../_base_/models/convmixer/convmixer-1024-20.py',
|
||||
'../_base_/datasets/imagenet_bs64_convmixer_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
optimizer = dict(lr=0.01)
|
||||
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=150)
|
|
@ -0,0 +1,10 @@
|
|||
_base_ = [
|
||||
'../_base_/models/convmixer/convmixer-1536-20.py',
|
||||
'../_base_/datasets/imagenet_bs64_convmixer_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
optimizer = dict(lr=0.01)
|
||||
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=150)
|
|
@ -0,0 +1,10 @@
|
|||
_base_ = [
|
||||
'../_base_/models/convmixer/convmixer-768-32.py',
|
||||
'../_base_/datasets/imagenet_bs64_convmixer_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
optimizer = dict(lr=0.01)
|
||||
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=300)
|
|
@ -0,0 +1,61 @@
|
|||
Collections:
|
||||
- Name: ConvMixer
|
||||
Metadata:
|
||||
Training Data: ImageNet-1k
|
||||
Architecture:
|
||||
- 1x1 Convolution
|
||||
- LayerScale
|
||||
Paper:
|
||||
URL: https://arxiv.org/abs/2201.09792
|
||||
Title: Patches Are All You Need?
|
||||
README: configs/convmixer/README.md
|
||||
|
||||
Models:
|
||||
- Name: convmixer-768-32_10xb64_in1k
|
||||
Metadata:
|
||||
FLOPs: 19623051264
|
||||
Parameters: 21110248
|
||||
In Collections: ConvMixer
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 80.16
|
||||
Top 5 Accuracy: 95.08
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convmixer/convmixer-768-32_3rdparty_10xb64_in1k_20220323-bca1f7b8.pth
|
||||
Config: configs/convmixer/convmixer-768-32_10xb64_in1k.py
|
||||
Converted From:
|
||||
Weights: https://github.com/tmp-iclr/convmixer/releases/download/v1.0/convmixer_768_32_ks7_p7_relu.pth.tar
|
||||
Code: https://github.com/locuslab/convmixer
|
||||
- Name: convmixer-1024-20_10xb64_in1k
|
||||
Metadata:
|
||||
FLOPs: 5550112768
|
||||
Parameters: 24383464
|
||||
In Collections: ConvMixer
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 76.94
|
||||
Top 5 Accuracy: 93.36
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convmixer/convmixer-1024-20_3rdparty_10xb64_in1k_20220323-48f8aeba.pth
|
||||
Config: configs/convmixer/convmixer-1024-20_10xb64_in1k.py
|
||||
Converted From:
|
||||
Weights: https://github.com/tmp-iclr/convmixer/releases/download/v1.0/convmixer_1024_20_ks9_p14.pth.tar
|
||||
Code: https://github.com/locuslab/convmixer
|
||||
- Name: convmixer-1536-20_10xb64_in1k
|
||||
Metadata:
|
||||
FLOPs: 48713170944
|
||||
Parameters: 51625960
|
||||
In Collections: ConvMixer
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 81.37
|
||||
Top 5 Accuracy: 95.61
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convmixer/convmixer-1536_20_3rdparty_10xb64_in1k_20220323-ea5786f3.pth
|
||||
Config: configs/convmixer/convmixer-1536-20_10xb64_in1k.py
|
||||
Converted From:
|
||||
Weights: https://github.com/tmp-iclr/convmixer/releases/download/v1.0/convmixer_1536_20_ks9_p7.pth.tar
|
||||
Code: https://github.com/locuslab/convmixer
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .alexnet import AlexNet
|
||||
from .conformer import Conformer
|
||||
from .convmixer import ConvMixer
|
||||
from .convnext import ConvNeXt
|
||||
from .deit import DistilledVisionTransformer
|
||||
from .efficientnet import EfficientNet
|
||||
|
@ -34,5 +35,5 @@ __all__ = [
|
|||
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
|
||||
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
|
||||
'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
|
||||
'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c'
|
||||
'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c', 'ConvMixer'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,176 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn.bricks import (Conv2dAdaptivePadding, build_activation_layer,
|
||||
build_norm_layer)
|
||||
from mmcv.utils import digit_version
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
class Residual(nn.Module):
|
||||
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(x) + x
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ConvMixer(BaseBackbone):
|
||||
"""ConvMixer. .
|
||||
|
||||
A PyTorch implementation of : `Patches Are All You Need?
|
||||
<https://arxiv.org/pdf/2201.09792.pdf>`_
|
||||
|
||||
Modified from the `official repo
|
||||
<https://github.com/locuslab/convmixer/blob/main/convmixer.py>`_
|
||||
and `timm
|
||||
<https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convmixer.py>`_.
|
||||
|
||||
Args:
|
||||
arch (str | dict): The model's architecture. If string, it should be
|
||||
one of architecture in ``ConvMixer.arch_settings``. And if dict, it
|
||||
should include the following two keys:
|
||||
|
||||
- embed_dims (int): The dimensions of patch embedding.
|
||||
- depth (int): Number of repetitions of ConvMixer Layer.
|
||||
- patch_size (int): The patch size.
|
||||
- kernel_size (int): The kernel size of depthwise conv layers.
|
||||
|
||||
Defaults to '768/32'.
|
||||
in_channels (int): Number of input image channels. Defaults to 3.
|
||||
patch_size (int): The size of one patch in the patch embed layer.
|
||||
Defaults to 7.
|
||||
norm_cfg (dict): The config dict for norm layers.
|
||||
Defaults to ``dict(type='BN')``.
|
||||
act_cfg (dict): The config dict for activation after each convolution.
|
||||
Defaults to ``dict(type='GELU')``.
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, means the last stage.
|
||||
frozen_stages (int): Stages to be frozen (all param fixed).
|
||||
Defaults to 0, which means not freezing any parameters.
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
"""
|
||||
arch_settings = {
|
||||
'768/32': {
|
||||
'embed_dims': 768,
|
||||
'depth': 32,
|
||||
'patch_size': 7,
|
||||
'kernel_size': 7
|
||||
},
|
||||
'1024/20': {
|
||||
'embed_dims': 1024,
|
||||
'depth': 20,
|
||||
'patch_size': 14,
|
||||
'kernel_size': 9
|
||||
},
|
||||
'1536/20': {
|
||||
'embed_dims': 1536,
|
||||
'depth': 20,
|
||||
'patch_size': 7,
|
||||
'kernel_size': 9
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
arch='768/32',
|
||||
in_channels=3,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
out_indices=-1,
|
||||
frozen_stages=0,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
if isinstance(arch, str):
|
||||
assert arch in self.arch_settings, \
|
||||
f'Unavailable arch, please choose from ' \
|
||||
f'({set(self.arch_settings)}) or pass a dict.'
|
||||
arch = self.arch_settings[arch]
|
||||
elif isinstance(arch, dict):
|
||||
essential_keys = {
|
||||
'embed_dims', 'depth', 'patch_size', 'kernel_size'
|
||||
}
|
||||
assert isinstance(arch, dict) and essential_keys <= set(arch), \
|
||||
f'Custom arch needs a dict with keys {essential_keys}'
|
||||
|
||||
self.embed_dims = arch['embed_dims']
|
||||
self.depth = arch['depth']
|
||||
self.patch_size = arch['patch_size']
|
||||
self.kernel_size = arch['kernel_size']
|
||||
self.act = build_activation_layer(act_cfg)
|
||||
|
||||
# check out indices and frozen stages
|
||||
if isinstance(out_indices, int):
|
||||
out_indices = [out_indices]
|
||||
assert isinstance(out_indices, Sequence), \
|
||||
f'"out_indices" must by a sequence or int, ' \
|
||||
f'get {type(out_indices)} instead.'
|
||||
for i, index in enumerate(out_indices):
|
||||
if index < 0:
|
||||
out_indices[i] = self.depth + index
|
||||
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
|
||||
self.out_indices = out_indices
|
||||
self.frozen_stages = frozen_stages
|
||||
|
||||
# Set stem layers
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels,
|
||||
self.embed_dims,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size), self.act,
|
||||
build_norm_layer(norm_cfg, self.embed_dims)[1])
|
||||
|
||||
# Set conv2d according to torch version
|
||||
convfunc = nn.Conv2d
|
||||
if digit_version(torch.__version__) < digit_version('1.9.0'):
|
||||
convfunc = Conv2dAdaptivePadding
|
||||
|
||||
# Repetitions of ConvMixer Layer
|
||||
self.stages = nn.Sequential(*[
|
||||
nn.Sequential(
|
||||
Residual(
|
||||
nn.Sequential(
|
||||
convfunc(
|
||||
self.embed_dims,
|
||||
self.embed_dims,
|
||||
self.kernel_size,
|
||||
groups=self.embed_dims,
|
||||
padding='same'), self.act,
|
||||
build_norm_layer(norm_cfg, self.embed_dims)[1])),
|
||||
nn.Conv2d(self.embed_dims, self.embed_dims, kernel_size=1),
|
||||
self.act,
|
||||
build_norm_layer(norm_cfg, self.embed_dims)[1])
|
||||
for _ in range(self.depth)
|
||||
])
|
||||
|
||||
self._freeze_stages()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.stem(x)
|
||||
outs = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x = stage(x)
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
|
||||
# x = self.pooling(x).flatten(1)
|
||||
return tuple(outs)
|
||||
|
||||
def train(self, mode=True):
|
||||
super(ConvMixer, self).train(mode)
|
||||
self._freeze_stages()
|
||||
|
||||
def _freeze_stages(self):
|
||||
for i in range(self.frozen_stages):
|
||||
stage = self.stages[i]
|
||||
stage.eval()
|
||||
for param in stage.parameters():
|
||||
param.requires_grad = False
|
|
@ -21,3 +21,4 @@ Import:
|
|||
- configs/convnext/metafile.yml
|
||||
- configs/hrnet/metafile.yml
|
||||
- configs/wrn/metafile.yml
|
||||
- configs/convmixer/metafile.yml
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcls.models.backbones import ConvMixer
|
||||
|
||||
|
||||
def test_assertion():
|
||||
with pytest.raises(AssertionError):
|
||||
ConvMixer(arch='unknown')
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# ConvMixer arch dict should include essential_keys,
|
||||
ConvMixer(arch=dict(channels=[2, 3, 4, 5]))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# ConvMixer out_indices should be valid depth.
|
||||
ConvMixer(out_indices=-100)
|
||||
|
||||
|
||||
def test_convmixer():
|
||||
|
||||
# Test forward
|
||||
model = ConvMixer(arch='768/32')
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 1
|
||||
assert feat[0].shape == torch.Size([1, 768, 32, 32])
|
||||
|
||||
# Test forward with multiple outputs
|
||||
model = ConvMixer(arch='768/32', out_indices=range(32))
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 32
|
||||
for f in feat:
|
||||
assert f.shape == torch.Size([1, 768, 32, 32])
|
||||
|
||||
# Test with custom arch
|
||||
model = ConvMixer(
|
||||
arch={
|
||||
'embed_dims': 99,
|
||||
'depth': 5,
|
||||
'patch_size': 5,
|
||||
'kernel_size': 9
|
||||
},
|
||||
out_indices=range(5))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
for f in feat:
|
||||
assert f.shape == torch.Size([1, 99, 44, 44])
|
||||
|
||||
# Test with even kernel size arch
|
||||
model = ConvMixer(arch={
|
||||
'embed_dims': 99,
|
||||
'depth': 5,
|
||||
'patch_size': 5,
|
||||
'kernel_size': 8
|
||||
})
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 1
|
||||
assert feat[0].shape == torch.Size([1, 99, 44, 44])
|
||||
|
||||
# Test frozen_stages
|
||||
model = ConvMixer(arch='768/32', frozen_stages=10)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
for i in range(10):
|
||||
assert not model.stages[i].training
|
||||
|
||||
for i in range(10, 32):
|
||||
assert model.stages[i].training
|
Loading…
Reference in New Issue