[Feature] Support XCiT Backbone. (#1305)

* update model file

* Update XCiT implementation and configs.

* Update metafiles

* Update metafile

* Fix floor divide

* Improve memory usage

---------

Co-authored-by: qingtian <459291290@qq.com>
Co-authored-by: mzr1996 <mzr1996@163.com>
QINGTIAN 2023-02-15 10:32:35 +08:00 committed by GitHub
parent bedf4e9f64
commit 8352951f3d
37 changed files with 2775 additions and 21 deletions

@@ -1,5 +1,6 @@
import argparse
import copy
import re
from functools import partial
from pathlib import Path
@@ -70,6 +71,7 @@ def parse_args():
    parser.add_argument('--out', '-o', type=Path, help='The output path.')
    parser.add_argument(
        '--view', action='store_true', help='Only pretty print the metafile.')
    parser.add_argument('--csv', type=str, help='Use a csv to update models.')
    args = parser.parse_args()
    return args
@@ -165,14 +167,17 @@ def fill_collection(collection: dict):
    return collection


def fill_model(model: dict, defaults: dict):
def fill_model_by_prompt(model: dict, defaults: dict):
    # Name
    if model.get('Name') is None:
        name = prompt(
            'Please input the model [red]name[/]: ', allow_empty=False)
        model['Name'] = name

    # In Collection
    model['In Collection'] = defaults.get('In Collection')

    # Config
    config = model.get('Config')
    if config is None:
        config = prompt(
@@ -182,6 +187,7 @@ def fill_model(model: dict, defaults: dict):
        config = str(Path(config).absolute().relative_to(MMCLS_ROOT))
    model['Config'] = config

    # Metadata.Flops, Metadata.Parameters
    flops = model.get('Metadata', {}).get('FLOPs')
    params = model.get('Metadata', {}).get('Parameters')
    if model.get('Config') is not None and (
@@ -280,6 +286,100 @@ def fill_model(model: dict, defaults: dict):
    return model


def update_model_by_dict(model: dict, update_dict: dict, defaults: dict):
    # Name
    if 'name override' in update_dict:
        model['Name'] = update_dict['name override']

    # In Collection
    model['In Collection'] = defaults.get('In Collection')

    # Config
    if 'config' in update_dict:
        config = update_dict['config'].strip()
        config = str(Path(config).absolute().relative_to(MMCLS_ROOT))
        config_updated = (config != model.get('Config'))
        model['Config'] = config
    else:
        config_updated = False

    # Metadata.Flops, Metadata.Parameters
    flops = model.get('Metadata', {}).get('FLOPs')
    params = model.get('Metadata', {}).get('Parameters')
    if config_updated and (flops is None or params is None):
        print(f'Automatically compute FLOPs and Parameters of {model["Name"]}')
        flops, params = get_flops(str(MMCLS_ROOT / model['Config']))
    model.setdefault('Metadata', {})
    model['Metadata']['FLOPs'] = flops
    model['Metadata']['Parameters'] = params

    # Metadata.Training Data
    if 'metadata.training data' in update_dict:
        train_data = update_dict['metadata.training data'].strip()
        train_data = re.split(r'\s+', train_data)
        if len(train_data) > 1:
            model['Metadata']['Training Data'] = train_data
        elif len(train_data) == 1:
            model['Metadata']['Training Data'] = train_data[0]

    # Results.Dataset
    if 'results.dataset' in update_dict:
        test_data = update_dict['results.dataset'].strip()
        results = model.get('Results') or [{}]
        result = results[0]
        result['Dataset'] = test_data
        model['Results'] = results

    # Results.Metrics.Top 1 Accuracy
    result = None
    if 'results.metrics.top 1 accuracy' in update_dict:
        top1 = update_dict['results.metrics.top 1 accuracy']
        results = model.get('Results') or [{}]
        result = results[0]
        result.setdefault('Metrics', {})
        result['Metrics']['Top 1 Accuracy'] = round(float(top1), 2)
        task = 'Image Classification'
        model['Results'] = results

    # Results.Metrics.Top 5 Accuracy
    if 'results.metrics.top 5 accuracy' in update_dict:
        top5 = update_dict['results.metrics.top 5 accuracy']
        results = model.get('Results') or [{}]
        result = results[0]
        result.setdefault('Metrics', {})
        result['Metrics']['Top 5 Accuracy'] = round(float(top5), 2)
        task = 'Image Classification'
        model['Results'] = results

    if result is not None:
        result['Metrics']['Task'] = task

    # Weights
    if 'weights' in update_dict:
        weights = update_dict['weights'].strip()
        model['Weights'] = weights

    # Converted From.Code
    if 'converted from.code' in update_dict:
        from_code = update_dict['converted from.code'].strip()
        model.setdefault('Converted From', {})
        model['Converted From']['Code'] = from_code

    # Converted From.Weights
    if 'converted from.weights' in update_dict:
        from_weight = update_dict['converted from.weights'].strip()
        model.setdefault('Converted From', {})
        model['Converted From']['Weights'] = from_weight

    order = [
        'Name', 'Metadata', 'In Collection', 'Results', 'Weights', 'Config',
        'Converted From'
    ]
    model = {k: model[k] for k in sorted(model.keys(), key=order.index)}
    return model


def format_collection(collection: dict):
    yaml_str = yaml_dump(collection)
    return Panel(
@@ -325,35 +425,44 @@ def main():
    models = content.get('Models', [])
    updated_models = []
    try:
        if args.csv is not None:
            import pandas as pd
            df = pd.read_csv(args.csv).rename(columns=lambda x: x.strip().lower())
            assert df['name'].is_unique, 'The csv has duplicated model names.'
            models_dict = {item['Name']: item for item in models}
            for update_dict in df.to_dict('records'):
                assert 'name' in update_dict, 'The csv must have the `Name` field.'
                model_name = update_dict['name'].strip()
                model = models_dict.pop(model_name, {'Name': model_name})
                model = update_model_by_dict(model, update_dict, model_defaults)
                updated_models.append(model)
            updated_models.extend(models_dict.values())
        else:
            for model in models:
                console.print(format_model(model))
                ori_model = copy.deepcopy(model)
                model = fill_model(model, model_defaults)
                model = fill_model_by_prompt(model, model_defaults)
                if ori_model != model:
                    console.print(format_model(model))
                updated_models.append(model)

            while Confirm.ask('Add new model?'):
                model = fill_model({}, model_defaults)
                model = fill_model_by_prompt({}, model_defaults)
                updated_models.append(model)
    finally:
        # Save updated models even if an error happened.
        updated_models.sort(key=lambda item: (item.get('Metadata', {}).get(
            'FLOPs', 0), len(item['Name'])))
        if args.out is not None:
            with open(args.out, 'w') as f:
                yaml_dump({'Collections': [collection]}, f)
                f.write('\n')
                yaml_dump({'Models': updated_models}, f)
        else:
            modelindex = {
                'Collections': [collection],
                'Models': updated_models
            }
            yaml_str = yaml_dump(modelindex)
            console.print(Syntax(yaml_str, 'yaml', background_color='default'))
            console.print('Specify [red]`--out`[/] to dump to file.')
    # Save updated models even if an error happened.
    updated_models.sort(key=lambda item: (item.get('Metadata', {}).get(
        'FLOPs', 0), len(item['Name'])))
    if args.out is not None:
        with open(args.out, 'w') as f:
            yaml_dump({'Collections': [collection]}, f)
            f.write('\n')
            yaml_dump({'Models': updated_models}, f)
    else:
        modelindex = {'Collections': [collection], 'Models': updated_models}
        yaml_str = yaml_dump(modelindex)
        console.print(Syntax(yaml_str, 'yaml', background_color='default'))
        console.print('Specify [red]`--out`[/] to dump to file.')

if __name__ == '__main__':
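
For reference, the new `--csv` mode expects one row per model with lower-cased column headers matching the keys handled by `update_model_by_dict`; only the `name` column is mandatory. Below is a hedged sketch of such an update sheet (the file name `xcit_update.csv` is made up for the example; the values are copied from the XCiT metafile added in this commit):

```python
# Illustrative only: build a one-row update sheet for the `--csv` option.
# Column names are stripped and lower-cased by the script; `name` is required.
import pandas as pd

rows = [{
    'name': 'xcit-nano-12-p16_3rdparty_in1k',
    'config': 'configs/xcit/xcit-nano-12-p16_8xb128_in1k.py',
    'metadata.training data': 'ImageNet-1k',
    'results.dataset': 'ImageNet-1k',
    'results.metrics.top 1 accuracy': 70.35,
    'results.metrics.top 5 accuracy': 89.98,
    'weights': 'https://download.openmmlab.com/mmclassification/v0/xcit/'
               'xcit-nano-12-p16_3rdparty_in1k_20230213-ed776c38.pth',
    'converted from.code': 'https://github.com/facebookresearch/xcit',
    'converted from.weights':
    'https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224.pth',
}]
pd.DataFrame(rows).to_csv('xcit_update.csv', index=False)  # then pass via --csv
```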

@@ -0,0 +1,75 @@
# XCiT
> [XCiT: Cross-Covariance Image Transformers](https://arxiv.org/abs/2106.09681)
<!-- [ALGORITHM] -->
## Abstract
Following their success in natural language processing, transformers have recently shown much promise for computer vision. The self-attention operation underlying transformers yields global interactions between all tokens, i.e. words or image patches, and enables flexible modelling of image data beyond the local interactions of convolutions. This flexibility, however, comes with a quadratic complexity in time and memory, hindering application to long sequences and high-resolution images. We propose a "transposed" version of self-attention that operates across feature channels rather than tokens, where the interactions are based on the cross-covariance matrix between keys and queries. The resulting cross-covariance attention (XCA) has linear complexity in the number of tokens, and allows efficient processing of high-resolution images. Our cross-covariance image transformer (XCiT) is built upon XCA. It combines the accuracy of conventional transformers with the scalability of convolutional architectures. We validate the effectiveness and generality of XCiT by reporting excellent results on multiple vision benchmarks, including image classification and self-supervised feature learning on ImageNet-1k, object detection and instance segmentation on COCO, and semantic segmentation on ADE20k.
<div align=center>
<img src="https://user-images.githubusercontent.com/26739999/218900814-64a44606-150b-4757-aec8-7015c77a9fd1.png" width="60%"/>
</div>
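
As a concrete illustration of the idea described in the abstract, the snippet below is a minimal, self-contained PyTorch sketch of cross-covariance attention: queries and keys are L2-normalized along the token axis, and attention is computed over the d×d channel cross-covariance, so the cost grows linearly with the number of tokens. It is an approximation for explanation purposes, not the `XCiT` backbone implementation added in this PR; the class name and defaults are chosen for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossCovarianceAttention(nn.Module):
    """Toy XCA block: attention over feature channels, linear in tokens."""

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape  # (batch, tokens, channels)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 4, 1).unbind(0)  # (B, heads, head_dim, N)
        # Normalize along the token axis, then attend over the channel dimension.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature  # (B, heads, d, d)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
        return self.proj(out)


x = torch.randn(2, 196, 192)  # e.g. 14x14 patches with embedding dim 192
print(CrossCovarianceAttention(192, num_heads=4)(x).shape)  # (2, 196, 192)
```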
## Results and models
### ImageNet-1k
| Model | Pretrain | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :-------------------------------------------: | :----------: | :-------: | :------: | :-------: | :-------: | :-------------------------------------------------: | :-------------------------------------------------------: |
| xcit-nano-12-p16_3rdparty_in1k\* | From scratch | 3.05 | 0.56 | 70.35 | 89.98 | [config](./xcit-nano-12-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p16_3rdparty_in1k_20230213-ed776c38.pth) |
| xcit-nano-12-p16_3rdparty-dist_in1k\* | Distillation | 3.05 | 0.56 | 72.36 | 91.02 | [config](./xcit-nano-12-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p16_3rdparty-dist_in1k_20230213-fb247f7b.pth) |
| xcit-nano-12-p16_3rdparty-dist_in1k-384px\* | Distillation | 3.05 | 1.64 | 74.93 | 92.42 | [config](./xcit-nano-12-p16_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p16_3rdparty-dist_in1k-384px_20230213-712db4d4.pth) |
| xcit-nano-12-p8_3rdparty_in1k\* | From scratch | 3.05 | 2.16 | 73.80 | 92.08 | [config](./xcit-nano-12-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p8_3rdparty_in1k_20230213-3370c293.pth) |
| xcit-nano-12-p8_3rdparty-dist_in1k\* | Distillation | 3.05 | 2.16 | 76.17 | 93.08 | [config](./xcit-nano-12-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p8_3rdparty-dist_in1k_20230213-2f87d2b3.pth) |
| xcit-nano-12-p8_3rdparty-dist_in1k-384px\* | Distillation | 3.05 | 6.34 | 77.69 | 94.09 | [config](./xcit-nano-12-p8_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p8_3rdparty-dist_in1k-384px_20230213-09d925ef.pth) |
| xcit-tiny-12-p16_3rdparty_in1k\* | From scratch | 6.72 | 1.24 | 77.21 | 93.62 | [config](./xcit-tiny-12-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p16_3rdparty_in1k_20230213-82c547ca.pth) |
| xcit-tiny-12-p16_3rdparty-dist_in1k\* | Distillation | 6.72 | 1.24 | 78.70 | 94.12 | [config](./xcit-tiny-12-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p16_3rdparty-dist_in1k_20230213-d5fde0a3.pth) |
| xcit-tiny-24-p16_3rdparty_in1k\* | From scratch | 12.12 | 2.34 | 79.47 | 94.85 | [config](./xcit-tiny-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p16_3rdparty_in1k_20230213-366c1cd0.pth) |
| xcit-tiny-24-p16_3rdparty-dist_in1k\* | Distillation | 12.12 | 2.34 | 80.51 | 95.17 | [config](./xcit-tiny-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p16_3rdparty-dist_in1k_20230213-b472e80a.pth) |
| xcit-tiny-12-p16_3rdparty-dist_in1k-384px\* | Distillation | 6.72 | 3.64 | 80.58 | 95.38 | [config](./xcit-tiny-12-p16_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p16_3rdparty-dist_in1k-384px_20230213-00a20023.pth) |
| xcit-tiny-12-p8_3rdparty_in1k\* | From scratch | 6.71 | 4.81 | 79.75 | 94.88 | [config](./xcit-tiny-12-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p8_3rdparty_in1k_20230213-8b02f8f5.pth) |
| xcit-tiny-12-p8_3rdparty-dist_in1k\* | Distillation | 6.71 | 4.81 | 81.26 | 95.46 | [config](./xcit-tiny-12-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p8_3rdparty-dist_in1k_20230213-f3f9b44f.pth) |
| xcit-tiny-24-p8_3rdparty_in1k\* | From scratch | 12.11 | 9.21 | 81.70 | 95.90 | [config](./xcit-tiny-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p8_3rdparty_in1k_20230213-4b9ba392.pth) |
| xcit-tiny-24-p8_3rdparty-dist_in1k\* | Distillation | 12.11 | 9.21 | 82.62 | 96.16 | [config](./xcit-tiny-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p8_3rdparty-dist_in1k_20230213-ad9c44b0.pth) |
| xcit-tiny-12-p8_3rdparty-dist_in1k-384px\* | Distillation | 6.71 | 14.13 | 82.46 | 96.22 | [config](./xcit-tiny-12-p8_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p8_3rdparty-dist_in1k-384px_20230213-a072174a.pth) |
| xcit-tiny-24-p16_3rdparty-dist_in1k-384px\* | Distillation | 12.12 | 6.87 | 82.43 | 96.20 | [config](./xcit-tiny-24-p16_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p16_3rdparty-dist_in1k-384px_20230213-20e13917.pth) |
| xcit-tiny-24-p8_3rdparty-dist_in1k-384px\* | Distillation | 12.11 | 27.05 | 83.77 | 96.72 | [config](./xcit-tiny-24-p8_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p8_3rdparty-dist_in1k-384px_20230213-30d5e5ec.pth) |
| xcit-small-12-p16_3rdparty_in1k\* | From scratch | 26.25 | 4.81 | 81.87 | 95.77 | [config](./xcit-small-12-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p16_3rdparty_in1k_20230213-d36779d2.pth) |
| xcit-small-12-p16_3rdparty-dist_in1k\* | Distillation | 26.25 | 4.81 | 83.12 | 96.41 | [config](./xcit-small-12-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p16_3rdparty-dist_in1k_20230213-c95bbae1.pth) |
| xcit-small-24-p16_3rdparty_in1k\* | From scratch | 47.67 | 9.10 | 82.38 | 95.93 | [config](./xcit-small-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p16_3rdparty_in1k_20230213-40febe38.pth) |
| xcit-small-24-p16_3rdparty-dist_in1k\* | Distillation | 47.67 | 9.10 | 83.70 | 96.61 | [config](./xcit-small-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p16_3rdparty-dist_in1k_20230213-130d7262.pth) |
| xcit-small-12-p16_3rdparty-dist_in1k-384px\* | Distillation | 26.25 | 14.14 | 84.74 | 97.19 | [config](./xcit-small-12-p16_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p16_3rdparty-dist_in1k-384px_20230213-ba36c982.pth) |
| xcit-small-12-p8_3rdparty_in1k\* | From scratch | 26.21 | 18.69 | 83.21 | 96.41 | [config](./xcit-small-12-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p8_3rdparty_in1k_20230213-9e364ce3.pth) |
| xcit-small-12-p8_3rdparty-dist_in1k\* | Distillation | 26.21 | 18.69 | 83.97 | 96.81 | [config](./xcit-small-12-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p8_3rdparty-dist_in1k_20230213-71886580.pth) |
| xcit-small-24-p16_3rdparty-dist_in1k-384px\* | Distillation | 47.67 | 26.72 | 85.10 | 97.32 | [config](./xcit-small-24-p16_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p16_3rdparty-dist_in1k-384px_20230213-28fa2d0e.pth) |
| xcit-small-24-p8_3rdparty_in1k\* | From scratch | 47.63 | 35.81 | 83.62 | 96.51 | [config](./xcit-small-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p8_3rdparty_in1k_20230213-280ebcc7.pth) |
| xcit-small-24-p8_3rdparty-dist_in1k\* | Distillation | 47.63 | 35.81 | 84.68 | 97.07 | [config](./xcit-small-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p8_3rdparty-dist_in1k_20230213-f2773c78.pth) |
| xcit-small-12-p8_3rdparty-dist_in1k-384px\* | Distillation | 26.21 | 54.92 | 85.12 | 97.31 | [config](./xcit-small-12-p8_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p8_3rdparty-dist_in1k-384px_20230214-9f2178bc.pth) |
| xcit-small-24-p8_3rdparty-dist_in1k-384px\* | Distillation | 47.63 | 105.24 | 85.57 | 97.60 | [config](./xcit-small-24-p8_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p8_3rdparty-dist_in1k-384px_20230214-57298eca.pth) |
| xcit-medium-24-p16_3rdparty_in1k\* | From scratch | 84.40 | 16.13 | 82.56 | 95.82 | [config](./xcit-medium-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p16_3rdparty_in1k_20230213-ad0aa92e.pth) |
| xcit-medium-24-p16_3rdparty-dist_in1k\* | Distillation | 84.40 | 16.13 | 84.15 | 96.82 | [config](./xcit-medium-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p16_3rdparty-dist_in1k_20230213-aca5cd0c.pth) |
| xcit-medium-24-p16_3rdparty-dist_in1k-384px\* | Distillation | 84.40 | 47.39 | 85.47 | 97.49 | [config](./xcit-medium-24-p16_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p16_3rdparty-dist_in1k-384px_20230214-6c23a201.pth) |
| xcit-medium-24-p8_3rdparty_in1k\* | From scratch | 84.32 | 63.52 | 83.61 | 96.23 | [config](./xcit-medium-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p8_3rdparty_in1k_20230214-c362850b.pth) |
| xcit-medium-24-p8_3rdparty-dist_in1k\* | Distillation | 84.32 | 63.52 | 85.00 | 97.16 | [config](./xcit-medium-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p8_3rdparty-dist_in1k_20230214-625c953b.pth) |
| xcit-medium-24-p8_3rdparty-dist_in1k-384px\* | Distillation | 84.32 | 186.67 | 85.87 | 97.61 | [config](./xcit-medium-24-p8_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p8_3rdparty-dist_in1k-384px_20230214-5db925e0.pth) |
| xcit-large-24-p16_3rdparty_in1k\* | From scratch | 189.10 | 35.86 | 82.97 | 95.86 | [config](./xcit-large-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p16_3rdparty_in1k_20230214-d29d2529.pth) |
| xcit-large-24-p16_3rdparty-dist_in1k\* | Distillation | 189.10 | 35.86 | 84.61 | 97.07 | [config](./xcit-large-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p16_3rdparty-dist_in1k_20230214-4fea599c.pth) |
| xcit-large-24-p16_3rdparty-dist_in1k-384px\* | Distillation | 189.10 | 105.35 | 85.78 | 97.60 | [config](./xcit-large-24-p16_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p16_3rdparty-dist_in1k-384px_20230214-bd515a34.pth) |
| xcit-large-24-p8_3rdparty_in1k\* | From scratch | 188.93 | 141.23 | 84.23 | 96.58 | [config](./xcit-large-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p8_3rdparty_in1k_20230214-08f2f664.pth) |
| xcit-large-24-p8_3rdparty-dist_in1k\* | Distillation | 188.93 | 141.23 | 85.14 | 97.32 | [config](./xcit-large-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p8_3rdparty-dist_in1k_20230214-8c092b34.pth) |
| xcit-large-24-p8_3rdparty-dist_in1k-384px\* | Distillation | 188.93 | 415.00 | 86.13 | 97.75 | [config](./xcit-large-24-p8_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p8_3rdparty-dist_in1k-384px_20230214-9f718b1a.pth) |
*Models with \* are converted from the [official repo](https://github.com/facebookresearch/xcit). The config files of these models are only for inference. We haven't reproduced the training accuracy with these configs, and contributions of reproduction results are welcome.*
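
To try one of the checkpoints above, something like the following should work with the `mmcls.apis` inference helpers of the 1.x series (`init_model` / `inference_model`); treat it as a sketch rather than an official usage snippet, and adjust paths as needed (`demo/demo.JPEG` ships with the repository):

```python
from mmcls.apis import inference_model, init_model

config = 'configs/xcit/xcit-nano-12-p16_8xb128_in1k.py'
checkpoint = ('https://download.openmmlab.com/mmclassification/v0/xcit/'
              'xcit-nano-12-p16_3rdparty_in1k_20230213-ed776c38.pth')

model = init_model(config, checkpoint, device='cpu')
result = inference_model(model, 'demo/demo.JPEG')  # predicted label and score
print(result)
```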
## Citation
```bibtex
@article{el2021xcit,
title={XCiT: Cross-Covariance Image Transformers},
author={El-Nouby, Alaaeldin and Touvron, Hugo and Caron, Mathilde and Bojanowski, Piotr and Douze, Matthijs and Joulin, Armand and Laptev, Ivan and Neverova, Natalia and Synnaeve, Gabriel and Verbeek, Jakob and others},
journal={arXiv preprint arXiv:2106.09681},
year={2021}
}
```

@@ -0,0 +1,727 @@
Collections:
- Name: XCiT
Metadata:
Architecture:
- Class Attention
- Local Patch Interaction
- Cross-Covariance Attention
Paper:
Title: 'XCiT: Cross-Covariance Image Transformers'
URL: https://arxiv.org/abs/2106.09681
README: configs/xcit/README.md
Models:
- Name: xcit-nano-12-p16_3rdparty_in1k
Metadata:
FLOPs: 557074560
Parameters: 3053224
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 70.35
Top 5 Accuracy: 89.98
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p16_3rdparty_in1k_20230213-ed776c38.pth
Config: configs/xcit/xcit-nano-12-p16_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224.pth
- Name: xcit-nano-12-p16_3rdparty-dist_in1k
Metadata:
FLOPs: 557074560
Parameters: 3053224
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 72.36
Top 5 Accuracy: 91.02
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p16_3rdparty-dist_in1k_20230213-fb247f7b.pth
Config: configs/xcit/xcit-nano-12-p16_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224_dist.pth
- Name: xcit-tiny-12-p16_3rdparty_in1k
Metadata:
FLOPs: 1239698112
Parameters: 6716272
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 77.21
Top 5 Accuracy: 93.62
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p16_3rdparty_in1k_20230213-82c547ca.pth
Config: configs/xcit/xcit-tiny-12-p16_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224.pth
- Name: xcit-tiny-12-p16_3rdparty-dist_in1k
Metadata:
FLOPs: 1239698112
Parameters: 6716272
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 78.7
Top 5 Accuracy: 94.12
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p16_3rdparty-dist_in1k_20230213-d5fde0a3.pth
Config: configs/xcit/xcit-tiny-12-p16_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224_dist.pth
- Name: xcit-nano-12-p16_3rdparty-dist_in1k-384px
Metadata:
FLOPs: 1636347520
Parameters: 3053224
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 74.93
Top 5 Accuracy: 92.42
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p16_3rdparty-dist_in1k-384px_20230213-712db4d4.pth
Config: configs/xcit/xcit-nano-12-p16_8xb128_in1k-384px.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_384_dist.pth
- Name: xcit-nano-12-p8_3rdparty_in1k
Metadata:
FLOPs: 2156861056
Parameters: 3049016
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 73.8
Top 5 Accuracy: 92.08
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p8_3rdparty_in1k_20230213-3370c293.pth
Config: configs/xcit/xcit-nano-12-p8_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224.pth
- Name: xcit-nano-12-p8_3rdparty-dist_in1k
Metadata:
FLOPs: 2156861056
Parameters: 3049016
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 76.17
Top 5 Accuracy: 93.08
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p8_3rdparty-dist_in1k_20230213-2f87d2b3.pth
Config: configs/xcit/xcit-nano-12-p8_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224_dist.pth
- Name: xcit-tiny-24-p16_3rdparty_in1k
Metadata:
FLOPs: 2339305152
Parameters: 12116896
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 79.47
Top 5 Accuracy: 94.85
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p16_3rdparty_in1k_20230213-366c1cd0.pth
Config: configs/xcit/xcit-tiny-24-p16_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224.pth
- Name: xcit-tiny-24-p16_3rdparty-dist_in1k
Metadata:
FLOPs: 2339305152
Parameters: 12116896
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 80.51
Top 5 Accuracy: 95.17
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p16_3rdparty-dist_in1k_20230213-b472e80a.pth
Config: configs/xcit/xcit-tiny-24-p16_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224_dist.pth
- Name: xcit-tiny-12-p16_3rdparty-dist_in1k-384px
Metadata:
FLOPs: 3641468352
Parameters: 6716272
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 80.58
Top 5 Accuracy: 95.38
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p16_3rdparty-dist_in1k-384px_20230213-00a20023.pth
Config: configs/xcit/xcit-tiny-12-p16_8xb128_in1k-384px.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_384_dist.pth
- Name: xcit-tiny-12-p8_3rdparty_in1k
Metadata:
FLOPs: 4807399872
Parameters: 6706504
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 79.75
Top 5 Accuracy: 94.88
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p8_3rdparty_in1k_20230213-8b02f8f5.pth
Config: configs/xcit/xcit-tiny-12-p8_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224.pth
- Name: xcit-tiny-12-p8_3rdparty-dist_in1k
Metadata:
FLOPs: 4807399872
Parameters: 6706504
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 81.26
Top 5 Accuracy: 95.46
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p8_3rdparty-dist_in1k_20230213-f3f9b44f.pth
Config: configs/xcit/xcit-tiny-12-p8_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224_dist.pth
- Name: xcit-small-12-p16_3rdparty_in1k
Metadata:
FLOPs: 4814951808
Parameters: 26253304
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 81.87
Top 5 Accuracy: 95.77
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p16_3rdparty_in1k_20230213-d36779d2.pth
Config: configs/xcit/xcit-small-12-p16_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pth
- Name: xcit-small-12-p16_3rdparty-dist_in1k
Metadata:
FLOPs: 4814951808
Parameters: 26253304
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.12
Top 5 Accuracy: 96.41
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p16_3rdparty-dist_in1k_20230213-c95bbae1.pth
Config: configs/xcit/xcit-small-12-p16_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224_dist.pth
- Name: xcit-nano-12-p8_3rdparty-dist_in1k-384px
Metadata:
FLOPs: 6337760896
Parameters: 3049016
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 77.69
Top 5 Accuracy: 94.09
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p8_3rdparty-dist_in1k-384px_20230213-09d925ef.pth
Config: configs/xcit/xcit-nano-12-p8_8xb128_in1k-384px.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_384_dist.pth
- Name: xcit-tiny-24-p16_3rdparty-dist_in1k-384px
Metadata:
FLOPs: 6872966592
Parameters: 12116896
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.43
Top 5 Accuracy: 96.2
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p16_3rdparty-dist_in1k-384px_20230213-20e13917.pth
Config: configs/xcit/xcit-tiny-24-p16_8xb128_in1k-384px.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_384_dist.pth
- Name: xcit-small-24-p16_3rdparty_in1k
Metadata:
FLOPs: 9095064960
Parameters: 47671384
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.38
Top 5 Accuracy: 95.93
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p16_3rdparty_in1k_20230213-40febe38.pth
Config: configs/xcit/xcit-small-24-p16_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224.pth
- Name: xcit-small-24-p16_3rdparty-dist_in1k
Metadata:
FLOPs: 9095064960
Parameters: 47671384
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.7
Top 5 Accuracy: 96.61
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p16_3rdparty-dist_in1k_20230213-130d7262.pth
Config: configs/xcit/xcit-small-24-p16_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224_dist.pth
- Name: xcit-tiny-24-p8_3rdparty_in1k
Metadata:
FLOPs: 9205828032
Parameters: 12107128
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 81.7
Top 5 Accuracy: 95.9
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p8_3rdparty_in1k_20230213-4b9ba392.pth
Config: configs/xcit/xcit-tiny-24-p8_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224.pth
- Name: xcit-tiny-24-p8_3rdparty-dist_in1k
Metadata:
FLOPs: 9205828032
Parameters: 12107128
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.62
Top 5 Accuracy: 96.16
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p8_3rdparty-dist_in1k_20230213-ad9c44b0.pth
Config: configs/xcit/xcit-tiny-24-p8_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224_dist.pth
- Name: xcit-tiny-12-p8_3rdparty-dist_in1k-384px
Metadata:
FLOPs: 14126142912
Parameters: 6706504
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.46
Top 5 Accuracy: 96.22
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p8_3rdparty-dist_in1k-384px_20230213-a072174a.pth
Config: configs/xcit/xcit-tiny-12-p8_8xb128_in1k-384px.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_384_dist.pth
- Name: xcit-small-12-p16_3rdparty-dist_in1k-384px
Metadata:
FLOPs: 14143179648
Parameters: 26253304
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 84.74
Top 5 Accuracy: 97.19
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p16_3rdparty-dist_in1k-384px_20230213-ba36c982.pth
Config: configs/xcit/xcit-small-12-p16_8xb128_in1k-384px.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_384_dist.pth
- Name: xcit-medium-24-p16_3rdparty_in1k
Metadata:
FLOPs: 16129561088
Parameters: 84395752
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.56
Top 5 Accuracy: 95.82
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p16_3rdparty_in1k_20230213-ad0aa92e.pth
Config: configs/xcit/xcit-medium-24-p16_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224.pth
- Name: xcit-medium-24-p16_3rdparty-dist_in1k
Metadata:
FLOPs: 16129561088
Parameters: 84395752
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 84.15
Top 5 Accuracy: 96.82
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p16_3rdparty-dist_in1k_20230213-aca5cd0c.pth
Config: configs/xcit/xcit-medium-24-p16_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224_dist.pth
- Name: xcit-small-12-p8_3rdparty_in1k
Metadata:
FLOPs: 18691601280
Parameters: 26213032
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.21
Top 5 Accuracy: 96.41
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p8_3rdparty_in1k_20230213-9e364ce3.pth
Config: configs/xcit/xcit-small-12-p8_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224.pth
- Name: xcit-small-12-p8_3rdparty-dist_in1k
Metadata:
FLOPs: 18691601280
Parameters: 26213032
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.97
Top 5 Accuracy: 96.81
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p8_3rdparty-dist_in1k_20230213-71886580.pth
Config: configs/xcit/xcit-small-12-p8_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224_dist.pth
- Name: xcit-small-24-p16_3rdparty-dist_in1k-384px
Metadata:
FLOPs: 26721471360
Parameters: 47671384
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 85.1
Top 5 Accuracy: 97.32
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p16_3rdparty-dist_in1k-384px_20230213-28fa2d0e.pth
Config: configs/xcit/xcit-small-24-p16_8xb128_in1k-384px.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_384_dist.pth
- Name: xcit-tiny-24-p8_3rdparty-dist_in1k-384px
Metadata:
FLOPs: 27052135872
Parameters: 12107128
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.77
Top 5 Accuracy: 96.72
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p8_3rdparty-dist_in1k-384px_20230213-30d5e5ec.pth
Config: configs/xcit/xcit-tiny-24-p8_8xb128_in1k-384px.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_384_dist.pth
- Name: xcit-small-24-p8_3rdparty_in1k
Metadata:
FLOPs: 35812053888
Parameters: 47631112
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.62
Top 5 Accuracy: 96.51
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p8_3rdparty_in1k_20230213-280ebcc7.pth
Config: configs/xcit/xcit-small-24-p8_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224.pth
- Name: xcit-small-24-p8_3rdparty-dist_in1k
Metadata:
FLOPs: 35812053888
Parameters: 47631112
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 84.68
Top 5 Accuracy: 97.07
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p8_3rdparty-dist_in1k_20230213-f2773c78.pth
Config: configs/xcit/xcit-small-24-p8_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224_dist.pth
- Name: xcit-large-24-p16_3rdparty_in1k
Metadata:
FLOPs: 35855948544
Parameters: 189096136
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.97
Top 5 Accuracy: 95.86
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p16_3rdparty_in1k_20230214-d29d2529.pth
Config: configs/xcit/xcit-large-24-p16_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224.pth
- Name: xcit-large-24-p16_3rdparty-dist_in1k
Metadata:
FLOPs: 35855948544
Parameters: 189096136
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 84.61
Top 5 Accuracy: 97.07
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p16_3rdparty-dist_in1k_20230214-4fea599c.pth
Config: configs/xcit/xcit-large-24-p16_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224_dist.pth
- Name: xcit-medium-24-p16_3rdparty-dist_in1k-384px
Metadata:
FLOPs: 47388932608
Parameters: 84395752
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 85.47
Top 5 Accuracy: 97.49
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p16_3rdparty-dist_in1k-384px_20230214-6c23a201.pth
Config: configs/xcit/xcit-medium-24-p16_8xb128_in1k-384px.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_384_dist.pth
- Name: xcit-small-12-p8_3rdparty-dist_in1k-384px
Metadata:
FLOPs: 54923537280
Parameters: 26213032
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 85.12
Top 5 Accuracy: 97.31
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p8_3rdparty-dist_in1k-384px_20230214-9f2178bc.pth
Config: configs/xcit/xcit-small-12-p8_8xb128_in1k-384px.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_384_dist.pth
- Name: xcit-medium-24-p8_3rdparty_in1k
Metadata:
FLOPs: 63524706816
Parameters: 84323624
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.61
Top 5 Accuracy: 96.23
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p8_3rdparty_in1k_20230214-c362850b.pth
Config: configs/xcit/xcit-medium-24-p8_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224.pth
- Name: xcit-medium-24-p8_3rdparty-dist_in1k
Metadata:
FLOPs: 63524706816
Parameters: 84323624
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 85.0
Top 5 Accuracy: 97.16
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p8_3rdparty-dist_in1k_20230214-625c953b.pth
Config: configs/xcit/xcit-medium-24-p8_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224_dist.pth
- Name: xcit-small-24-p8_3rdparty-dist_in1k-384px
Metadata:
FLOPs: 105236704128
Parameters: 47631112
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 85.57
Top 5 Accuracy: 97.6
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p8_3rdparty-dist_in1k-384px_20230214-57298eca.pth
Config: configs/xcit/xcit-small-24-p8_8xb128_in1k-384px.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_384_dist.pth
- Name: xcit-large-24-p16_3rdparty-dist_in1k-384px
Metadata:
FLOPs: 105345095424
Parameters: 189096136
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 85.78
Top 5 Accuracy: 97.6
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p16_3rdparty-dist_in1k-384px_20230214-bd515a34.pth
Config: configs/xcit/xcit-large-24-p16_8xb128_in1k-384px.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_384_dist.pth
- Name: xcit-large-24-p8_3rdparty_in1k
Metadata:
FLOPs: 141225699072
Parameters: 188932648
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 84.23
Top 5 Accuracy: 96.58
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p8_3rdparty_in1k_20230214-08f2f664.pth
Config: configs/xcit/xcit-large-24-p8_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224.pth
- Name: xcit-large-24-p8_3rdparty-dist_in1k
Metadata:
FLOPs: 141225699072
Parameters: 188932648
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 85.14
Top 5 Accuracy: 97.32
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p8_3rdparty-dist_in1k_20230214-8c092b34.pth
Config: configs/xcit/xcit-large-24-p8_8xb128_in1k.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224_dist.pth
- Name: xcit-medium-24-p8_3rdparty-dist_in1k-384px
Metadata:
FLOPs: 186672626176
Parameters: 84323624
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 85.87
Top 5 Accuracy: 97.61
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p8_3rdparty-dist_in1k-384px_20230214-5db925e0.pth
Config: configs/xcit/xcit-medium-24-p8_8xb128_in1k-384px.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_384_dist.pth
- Name: xcit-large-24-p8_3rdparty-dist_in1k-384px
Metadata:
FLOPs: 415003137792
Parameters: 188932648
Training Data: ImageNet-1k
In Collection: XCiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 86.13
Top 5 Accuracy: 97.75
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p8_3rdparty-dist_in1k-384px_20230214-9f718b1a.pth
Config: configs/xcit/xcit-large-24-p8_8xb128_in1k-384px.py
Converted From:
Code: https://github.com/facebookresearch/xcit
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_384_dist.pth
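
A quick way to sanity-check that the metafile above stays in sync with the README table is to load it and print the key fields. This is a hedged sketch that assumes the file is saved at the conventional `configs/xcit/metafile.yml` location:

```python
# Assumed path; adjust to wherever the metafile lives in the checkout.
import yaml

with open('configs/xcit/metafile.yml') as f:
    index = yaml.safe_load(f)

for model in index['Models']:
    meta = model['Metadata']
    top1 = model['Results'][0]['Metrics']['Top 1 Accuracy']
    print(f"{model['Name']}: {meta['Parameters'] / 1e6:.2f}M params, "
          f"{meta['FLOPs'] / 1e9:.2f}G FLOPs, top-1 {top1}")
```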

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=16,
embed_dims=768,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=16,
embed_dims=768,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=8,
embed_dims=768,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=8,
embed_dims=768,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=16,
embed_dims=512,
depth=24,
num_heads=8,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=16,
embed_dims=512,
depth=24,
num_heads=8,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=8,
embed_dims=512,
depth=24,
num_heads=8,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=8,
embed_dims=512,
depth=24,
num_heads=8,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=16,
embed_dims=128,
depth=12,
num_heads=4,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1.0,
tokens_norm=False,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=128,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=16,
embed_dims=128,
depth=12,
num_heads=4,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1.0,
tokens_norm=False,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=128,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=8,
embed_dims=128,
depth=12,
num_heads=4,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1.0,
tokens_norm=False,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=128,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=8,
embed_dims=128,
depth=12,
num_heads=4,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1.0,
tokens_norm=False,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=128,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=16,
embed_dims=384,
depth=12,
num_heads=8,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1.0,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=384,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=16,
embed_dims=384,
depth=12,
num_heads=8,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1.0,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=384,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=8,
embed_dims=384,
depth=12,
num_heads=8,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1.0,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=384,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=8,
embed_dims=384,
depth=12,
num_heads=8,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1.0,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=384,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=16,
embed_dims=384,
depth=24,
num_heads=8,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=384,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=16,
embed_dims=384,
depth=24,
num_heads=8,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=384,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)


@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=8,
embed_dims=384,
depth=24,
num_heads=8,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=384,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(arguments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)


@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=8,
embed_dims=384,
depth=24,
num_heads=8,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=384,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(arguments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)


@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=16,
embed_dims=192,
depth=12,
num_heads=4,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1.0,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=192,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(arguments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)


@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=16,
embed_dims=192,
depth=12,
num_heads=4,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1.0,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=192,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(arguments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)


@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=8,
embed_dims=192,
depth=12,
num_heads=4,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1.0,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=192,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(arguments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)


@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=8,
embed_dims=192,
depth=12,
num_heads=4,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1.0,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=192,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(arguments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)


@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=16,
embed_dims=192,
depth=24,
num_heads=4,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=192,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(arguments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)


@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=16,
embed_dims=192,
depth=24,
num_heads=4,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=192,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(arguments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)


@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=8,
embed_dims=192,
depth=24,
num_heads=4,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=192,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(arguments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)


@ -0,0 +1,34 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='XCiT',
patch_size=8,
embed_dims=192,
depth=24,
num_heads=4,
mlp_ratio=4,
qkv_bias=True,
layer_scale_init_value=1e-5,
tokens_norm=True,
out_type='cls_token',
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=192,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
train_cfg=dict(arguments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
# dataset settings
train_dataloader = dict(batch_size=128)


@ -112,6 +112,7 @@ Backbones
VGG
Vig
VisionTransformer
XCiT
.. module:: mmcls.models.necks


@ -51,6 +51,7 @@ from .van import VAN
from .vgg import VGG
from .vig import PyramidVig, Vig
from .vision_transformer import VisionTransformer
from .xcit import XCiT
__all__ = [
'LeNet5',
@ -112,4 +113,5 @@ __all__ = [
'LeViT',
'Vig',
'PyramidVig',
'XCiT',
]


@ -0,0 +1,770 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from functools import partial
from typing import Optional, Sequence, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import ConvModule, DropPath
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule, Sequential
from mmengine.model.weight_init import trunc_normal_
from mmengine.utils import digit_version
from mmcls.registry import MODELS
from ..utils import build_norm_layer, to_2tuple
from .base_backbone import BaseBackbone
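# NOTE: `torch.div(..., rounding_mode='floor')` only exists from PyTorch 1.8
# onwards, so fall back to `torch.floor_divide` on older versions. Both give
# identical results for the non-negative tensors used in this module.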
if digit_version(torch.__version__) < digit_version('1.8.0'):
floor_div = torch.floor_divide
else:
floor_div = partial(torch.div, rounding_mode='floor')
class ClassAttntion(BaseModule):
"""Class Attention Module.
A PyTorch implementation of Class Attention Module introduced by:
`Going deeper with Image Transformers <https://arxiv.org/abs/2103.17239>`_
taken from
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
with slight modifications to do CA
Args:
dim (int): The feature dimension.
num_heads (int): Parallel attention heads. Defaults to 8.
qkv_bias (bool): enable bias for qkv if True. Defaults to False.
attn_drop (float): The drop out rate for attention output weights.
Defaults to 0.
proj_drop (float): The drop out rate for linear output weights.
Defaults to 0.
init_cfg (dict | list[dict], optional): Initialization config dict.
Defaults to None.
""" # noqa: E501
def __init__(self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
attn_drop: float = 0.,
proj_drop: float = 0.,
init_cfg=None):
super(ClassAttntion, self).__init__(init_cfg=init_cfg)
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.k = nn.Linear(dim, dim, bias=qkv_bias)
self.v = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
# We only need to calculate query of cls token.
q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads,
C // self.num_heads).permute(
0, 2, 1, 3)
k = self.k(x).reshape(B, N, self.num_heads,
C // self.num_heads).permute(0, 2, 1, 3)
q = q * self.scale
v = self.v(x).reshape(B, N, self.num_heads,
C // self.num_heads).permute(0, 2, 1, 3)
attn = (q @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
x_cls = self.proj(x_cls)
x_cls = self.proj_drop(x_cls)
return x_cls
class PositionalEncodingFourier(BaseModule):
"""Positional Encoding using a fourier kernel.
A PyTorch implementation of Positional Encoding relying on
a fourier kernel introduced by:
`Attention is all you Need <https://arxiv.org/abs/1706.03762>`_
Based on the `official XCiT code
<https://github.com/facebookresearch/xcit/blob/master/xcit.py>`_
Args:
hidden_dim (int): The hidden feature dimension. Defaults to 32.
dim (int): The output feature dimension. Defaults to 768.
temperature (int): A control variable for position encoding.
Defaults to 10000.
init_cfg (dict | list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
hidden_dim: int = 32,
dim: int = 768,
temperature: int = 10000,
init_cfg=None):
super(PositionalEncodingFourier, self).__init__(init_cfg=init_cfg)
self.token_projection = ConvModule(
in_channels=hidden_dim * 2,
out_channels=dim,
kernel_size=1,
conv_cfg=None,
norm_cfg=None,
act_cfg=None)
self.scale = 2 * math.pi
self.temperature = temperature
self.hidden_dim = hidden_dim
self.dim = dim
self.eps = 1e-6
def forward(self, B: int, H: int, W: int):
device = self.token_projection.conv.weight.device
y_embed = torch.arange(
1, H + 1, device=device).unsqueeze(1).repeat(1, 1, W).float()
x_embed = torch.arange(1, W + 1, device=device).repeat(1, H, 1).float()
y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
dim_t = torch.arange(self.hidden_dim, device=device).float()
dim_t = floor_div(dim_t, 2)
dim_t = self.temperature**(2 * dim_t / self.hidden_dim)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
[pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()],
dim=4).flatten(3)
pos_y = torch.stack(
[pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()],
dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
pos = self.token_projection(pos)
return pos.repeat(B, 1, 1, 1) # (B, C, H, W)
class ConvPatchEmbed(BaseModule):
"""Patch Embedding using multiple convolution layers.
Args:
img_size (int, tuple): The size of the input image. Defaults to 224,
    which means an input size of 224*224.
patch_size (int): The patch size in conv patch embedding.
Defaults to 16.
in_channels (int): The input channels of this module.
Defaults to 3.
embed_dims (int): The feature dimension. Defaults to 768.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='BN')``.
act_cfg (dict): Config dict for activation layer.
Defaults to ``dict(type='GELU')``.
init_cfg (dict | list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
img_size: Union[int, tuple] = 224,
patch_size: int = 16,
in_channels: int = 3,
embed_dims: int = 768,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='GELU'),
init_cfg=None):
super(ConvPatchEmbed, self).__init__(init_cfg=init_cfg)
img_size = to_2tuple(img_size)
num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
conv = partial(
ConvModule,
kernel_size=3,
stride=2,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
)
layer = []
if patch_size == 16:
layer.append(
conv(in_channels=in_channels, out_channels=embed_dims // 8))
layer.append(
conv(
in_channels=embed_dims // 8, out_channels=embed_dims // 4))
elif patch_size == 8:
layer.append(
conv(in_channels=in_channels, out_channels=embed_dims // 4))
else:
raise ValueError('For patch embedding, the patch size must be 16 '
f'or 8, but got patch size {self.patch_size}.')
layer.append(
conv(in_channels=embed_dims // 4, out_channels=embed_dims // 2))
layer.append(
conv(
in_channels=embed_dims // 2,
out_channels=embed_dims,
act_cfg=None,
))
self.proj = Sequential(*layer)
def forward(self, x: torch.Tensor):
x = self.proj(x)
Hp, Wp = x.shape[2], x.shape[3]
x = x.flatten(2).transpose(1, 2) # (B, N, C)
return x, (Hp, Wp)
class ClassAttentionBlock(BaseModule):
"""Transformer block using Class Attention.
Args:
dim (int): The feature dimension.
num_heads (int): Parallel attention heads.
mlp_ratio (float): The hidden dimension ratio for FFN.
Defaults to 4.
qkv_bias (bool): enable bias for qkv if True. Defaults to False.
drop (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
attn_drop (float): The drop out rate for attention output weights.
Defaults to 0.
drop_path (float): Stochastic depth rate. Defaults to 0.
layer_scale_init_value (float): The initial value for layer scale.
Defaults to 1.
tokens_norm (bool): Whether to normalize all tokens or just the
cls_token in the CA. Defaults to False.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN', eps=1e-6)``.
act_cfg (dict): Config dict for activation layer.
Defaults to ``dict(type='GELU')``.
init_cfg (dict | list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.,
qkv_bias: bool = False,
drop=0.,
attn_drop=0.,
drop_path=0.,
layer_scale_init_value=1.,
tokens_norm=False,
norm_cfg=dict(type='LN', eps=1e-6),
act_cfg=dict(type='GELU'),
init_cfg=None):
super(ClassAttentionBlock, self).__init__(init_cfg=init_cfg)
self.norm1 = build_norm_layer(norm_cfg, dim)
self.attn = ClassAttntion(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = build_norm_layer(norm_cfg, dim)
self.ffn = FFN(
embed_dims=dim,
feedforward_channels=int(dim * mlp_ratio),
act_cfg=act_cfg,
ffn_drop=drop,
)
if layer_scale_init_value > 0:
self.gamma1 = nn.Parameter(layer_scale_init_value *
torch.ones(dim))
self.gamma2 = nn.Parameter(layer_scale_init_value *
torch.ones(dim))
else:
self.gamma1, self.gamma2 = 1.0, 1.0
# See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 # noqa: E501
self.tokens_norm = tokens_norm
def forward(self, x):
x_norm1 = self.norm1(x)
x_attn = torch.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1)
x = x + self.drop_path(self.gamma1 * x_attn)
if self.tokens_norm:
x = self.norm2(x)
else:
x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1)
x_res = x
cls_token = x[:, 0:1]
cls_token = self.gamma2 * self.ffn(cls_token, identity=0)
x = torch.cat([cls_token, x[:, 1:]], dim=1)
x = x_res + self.drop_path(x)
return x
class LPI(BaseModule):
"""Local Patch Interaction module.
A PyTorch implementation of Local Patch Interaction module
as in XCiT introduced by `XCiT: Cross-Covariance Image Transformers
<https://arxiv.org/abs/2106.09681>`_
Local Patch Interaction module that allows explicit communication between
tokens in 3x3 windows to augment the implicit communication performed by
the block diagonal scatter attention. Implemented using 2 layers of
separable 3x3 convolutions with GeLU and BatchNorm2d
Args:
in_features (int): The input channels.
out_features (int, optional): The output channels. Defaults to None.
kernel_size (int): The kernel_size in ConvModule. Defaults to 3.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='BN')``.
act_cfg (dict): Config dict for activation layer.
Defaults to ``dict(type='GELU')``.
init_cfg (dict | list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
in_features: int,
out_features: Optional[int] = None,
kernel_size: int = 3,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='GELU'),
init_cfg=None):
super(LPI, self).__init__(init_cfg=init_cfg)
out_features = out_features or in_features
padding = kernel_size // 2
self.conv1 = ConvModule(
in_channels=in_features,
out_channels=in_features,
kernel_size=kernel_size,
padding=padding,
groups=in_features,
bias=True,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
order=('conv', 'act', 'norm'))
self.conv2 = ConvModule(
in_channels=in_features,
out_channels=out_features,
kernel_size=kernel_size,
padding=padding,
groups=out_features,
norm_cfg=None,
act_cfg=None)
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
B, N, C = x.shape
x = x.permute(0, 2, 1).reshape(B, C, H, W)
x = self.conv1(x)
x = self.conv2(x)
x = x.reshape(B, C, N).permute(0, 2, 1)
return x
class XCA(BaseModule):
r"""Cross-Covariance Attention module.
A PyTorch implementation of Cross-Covariance Attention module
as in XCiT introduced by `XCiT: Cross-Covariance Image Transformers
<https://arxiv.org/abs/2106.09681>`_
In Cross-Covariance Attention (XCA), the channels are updated using a
weighted sum. The weights are obtained from the (softmax normalized)
Cross-covariance matrix :math:`(Q^T \cdot K \in d_h \times d_h)`
Args:
dim (int): The feature dimension.
num_heads (int): Parallel attention heads. Defaults to 8.
qkv_bias (bool): enable bias for qkv if True. Defaults to False.
attn_drop (float): The drop out rate for attention output weights.
Defaults to 0.
proj_drop (float): The drop out rate for linear output weights.
Defaults to 0.
init_cfg (dict | list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
attn_drop: float = 0.,
proj_drop: float = 0.,
init_cfg=None):
super(XCA, self).__init__(init_cfg=init_cfg)
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
# (qkv, B, num_heads, channels per head, N)
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 4, 1)
q, k, v = qkv.unbind(0)
# Paper section 3.2 l2-Normalization and temperature scaling
q = F.normalize(q, dim=-1)
k = F.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
# (B, num_heads, C', N) -> (B, N, num_heads, C') -> (B, N, C)
x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class XCABlock(BaseModule):
"""Transformer block using XCA.
Args:
dim (int): The feature dimension.
num_heads (int): Parallel attention heads.
mlp_ratio (float): The hidden dimension ratio for FFNs.
Defaults to 4.
qkv_bias (bool): enable bias for qkv if True. Defaults to False.
drop (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
attn_drop (float): The drop out rate for attention output weights.
Defaults to 0.
drop_path (float): Stochastic depth rate. Defaults to 0.
layer_scale_init_value (float): The initial value for layer scale.
Defaults to 1.
bn_norm_cfg (dict): Config dict for batchnorm in LPI and
ConvPatchEmbed. Defaults to ``dict(type='BN')``.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN', eps=1e-6)``.
act_cfg (dict): Config dict for activation layer.
Defaults to ``dict(type='GELU')``.
init_cfg (dict | list[dict], optional): Initialization config dict.
"""
def __init__(self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.,
qkv_bias: bool = False,
drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
layer_scale_init_value: float = 1.,
bn_norm_cfg=dict(type='BN'),
norm_cfg=dict(type='LN', eps=1e-6),
act_cfg=dict(type='GELU'),
init_cfg=None):
super(XCABlock, self).__init__(init_cfg=init_cfg)
self.norm1 = build_norm_layer(norm_cfg, dim)
self.attn = XCA(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm3 = build_norm_layer(norm_cfg, dim)
self.local_mp = LPI(
in_features=dim,
norm_cfg=bn_norm_cfg,
act_cfg=act_cfg,
)
self.norm2 = build_norm_layer(norm_cfg, dim)
self.ffn = FFN(
embed_dims=dim,
feedforward_channels=int(dim * mlp_ratio),
act_cfg=act_cfg,
ffn_drop=drop,
)
self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
self.gamma3 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
def forward(self, x, H: int, W: int):
x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
# NOTE official code has 3 then 2, so keeping it the same to be
# consistent with loaded weights See
# https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 # noqa: E501
x = x + self.drop_path(
self.gamma3 * self.local_mp(self.norm3(x), H, W))
x = x + self.drop_path(
self.gamma2 * self.ffn(self.norm2(x), identity=0))
return x
@MODELS.register_module()
class XCiT(BaseBackbone):
"""XCiT backbone.
A PyTorch implementation of XCiT backbone introduced by:
`XCiT: Cross-Covariance Image Transformers
<https://arxiv.org/abs/2106.09681>`_
Args:
img_size (int, tuple): Input image size. Defaults to 224.
patch_size (int): Patch size. Defaults to 16.
in_channels (int): Number of input channels. Defaults to 3.
embed_dims (int): Embedding dimension. Defaults to 768.
depth (int): The number of XCA transformer blocks. Defaults to 12.
cls_attn_layers (int): Depth of Class attention layers.
Defaults to 2.
num_heads (int): Number of attention heads. Defaults to 12.
mlp_ratio (float): Ratio of the FFN hidden dim to the embedding dim.
Defaults to 4.
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
attn_drop_rate (float): The drop out rate for attention output weights.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
use_pos_embed (bool): Whether to use positional encoding.
Defaults to True.
layer_scale_init_value (float): The initial value for layer scale.
Defaults to 1.
tokens_norm (bool): Whether to normalize all tokens or just the
    cls_token in the CA. Defaults to False.
out_type (str): The type of output features. Allowed choices are
    'raw', 'featmap', 'avg_featmap' and 'cls_token'.
    Defaults to 'cls_token'.
out_indices (Sequence[int]): Output from which layers.
    Defaults to (-1, ).
final_norm (bool): Whether to apply a final normalization layer
    after the last class attention block. Defaults to True.
frozen_stages (int): Layers to be frozen (all param fixed), and 0
means to freeze the stem stage. Defaults to -1, which means
not freeze any parameters.
bn_norm_cfg (dict): Config dict for the batch norm layers in LPI and
ConvPatchEmbed. Defaults to ``dict(type='BN')``.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN', eps=1e-6)``.
act_cfg (dict): Config dict for activation layer.
Defaults to ``dict(type='GELU')``.
init_cfg (dict | list[dict], optional): Initialization config dict.
"""
def __init__(self,
img_size: Union[int, tuple] = 224,
patch_size: int = 16,
in_channels: int = 3,
embed_dims: int = 768,
depth: int = 12,
cls_attn_layers: int = 2,
num_heads: int = 12,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
use_pos_embed: bool = True,
layer_scale_init_value: float = 1.,
tokens_norm: bool = False,
out_type: str = 'cls_token',
out_indices: Sequence[int] = (-1, ),
final_norm: bool = True,
frozen_stages: int = -1,
bn_norm_cfg=dict(type='BN'),
norm_cfg=dict(type='LN', eps=1e-6),
act_cfg=dict(type='GELU'),
init_cfg=dict(type='TruncNormal', layer='Linear')):
super(XCiT, self).__init__(init_cfg=init_cfg)
img_size = to_2tuple(img_size)
if (img_size[0] % patch_size != 0) or (img_size[1] % patch_size != 0):
raise ValueError(f'`patch_size` ({patch_size}) should divide '
f'the image shape ({img_size}) evenly.')
self.embed_dims = embed_dims
assert out_type in ('raw', 'featmap', 'avg_featmap', 'cls_token')
self.out_type = out_type
self.patch_embed = ConvPatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dims=embed_dims,
norm_cfg=bn_norm_cfg,
act_cfg=act_cfg,
)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
self.use_pos_embed = use_pos_embed
if use_pos_embed:
self.pos_embed = PositionalEncodingFourier(dim=embed_dims)
self.pos_drop = nn.Dropout(p=drop_rate)
self.xca_layers = nn.ModuleList()
self.ca_layers = nn.ModuleList()
self.num_layers = depth + cls_attn_layers
for _ in range(depth):
self.xca_layers.append(
XCABlock(
dim=embed_dims,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=drop_path_rate,
bn_norm_cfg=bn_norm_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
layer_scale_init_value=layer_scale_init_value,
))
for _ in range(cls_attn_layers):
self.ca_layers.append(
ClassAttentionBlock(
dim=embed_dims,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
layer_scale_init_value=layer_scale_init_value,
tokens_norm=tokens_norm,
))
if final_norm:
self.norm = build_norm_layer(norm_cfg, embed_dims)
# Transform out_indices
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
out_indices = list(out_indices)
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.num_layers + index
assert 0 <= out_indices[i] <= self.num_layers, \
f'Invalid out_indices {index}.'
self.out_indices = out_indices
if frozen_stages > self.num_layers + 1:
raise ValueError('frozen_stages must not be greater than '
                 f'{self.num_layers + 1}, but got {frozen_stages}')
self.frozen_stages = frozen_stages
def init_weights(self):
super().init_weights()
if self.init_cfg is not None and self.init_cfg['type'] == 'Pretrained':
return
trunc_normal_(self.cls_token, std=.02)
def _freeze_stages(self):
if self.frozen_stages < 0:
return
# freeze position embedding
if self.use_pos_embed:
self.pos_embed.eval()
for param in self.pos_embed.parameters():
param.requires_grad = False
# freeze patch embedding
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
# set dropout to eval mode
self.pos_drop.eval()
# freeze cls_token, which is only used by the class attention layers (self.ca_layers)
if self.frozen_stages > len(self.xca_layers):
self.cls_token.requires_grad = False
# freeze layers
for i in range(1, self.frozen_stages):
if i <= len(self.xca_layers):
m = self.xca_layers[i - 1]
else:
m = self.ca_layers[i - len(self.xca_layers) - 1]
m.eval()
for param in m.parameters():
param.requires_grad = False
# freeze the final layer norm if all stages are frozen
if self.frozen_stages == len(self.xca_layers) + len(self.ca_layers):
self.norm.eval()
for param in self.norm.parameters():
param.requires_grad = False
def forward(self, x):
outs = []
B = x.shape[0]
# x is (B, N, C). (Hp, Wp) is the patch resolution
x, (Hp, Wp) = self.patch_embed(x)
if self.use_pos_embed:
# (B, C, Hp, Wp) -> (B, C, N) -> (B, N, C)
pos_encoding = self.pos_embed(B, Hp, Wp)
x = x + pos_encoding.reshape(B, -1, x.size(1)).permute(0, 2, 1)
x = self.pos_drop(x)
for i, layer in enumerate(self.xca_layers):
x = layer(x, Hp, Wp)
if i in self.out_indices:
outs.append(self._format_output(x, (Hp, Wp), False))
x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
for i, layer in enumerate(self.ca_layers):
x = layer(x)
if i == len(self.ca_layers) - 1:
x = self.norm(x)
if i + len(self.xca_layers) in self.out_indices:
outs.append(self._format_output(x, (Hp, Wp), True))
return tuple(outs)
def _format_output(self, x, hw, with_cls_token: bool):
if self.out_type == 'raw':
return x
if self.out_type == 'cls_token':
if not with_cls_token:
raise ValueError(
'Cannot output cls_token since there is no cls_token.')
return x[:, 0]
patch_token = x[:, 1:] if with_cls_token else x
if self.out_type == 'featmap':
B = x.size(0)
# (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
if self.out_type == 'avg_featmap':
return patch_token.mean(dim=1)
def train(self, mode=True):
super().train(mode)
self._freeze_stages()
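
As a quick usage sketch (not part of the file above), the registered backbone can also be built directly; the hyper-parameters below approximate the nano variant used in the configs, and the expected output shape matches the unit tests further down.

import torch

from mmcls.models import XCiT

backbone = XCiT(
    patch_size=16,
    embed_dims=128,
    depth=12,
    num_heads=4,
    out_type='cls_token')
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 224, 224))
# The backbone returns a tuple; with the default out_indices=(-1, ) it holds
# a single cls_token tensor of shape (1, 128).
print(feats[0].shape)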


@ -52,3 +52,4 @@ Import:
- configs/levit/metafile.yml
- configs/vig/metafile.yml
- configs/arcface/metafile.yml
- configs/xcit/metafile.yml


@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
# The basic forward/backward tests are in ../test_models.py
import torch
from mmcls.apis import get_model
def test_out_type():
inputs = torch.rand(1, 3, 224, 224)
model = get_model(
'xcit-nano-12-p16_3rdparty_in1k',
backbone=dict(out_type='raw'),
neck=None,
head=None)
outputs = model(inputs)[0]
assert outputs.shape == (1, 197, 128)
model = get_model(
'xcit-nano-12-p16_3rdparty_in1k',
backbone=dict(out_type='featmap'),
neck=None,
head=None)
outputs = model(inputs)[0]
assert outputs.shape == (1, 128, 14, 14)
model = get_model(
'xcit-nano-12-p16_3rdparty_in1k',
backbone=dict(out_type='cls_token'),
neck=None,
head=None)
outputs = model(inputs)[0]
assert outputs.shape == (1, 128)
model = get_model(
'xcit-nano-12-p16_3rdparty_in1k',
backbone=dict(out_type='avg_featmap'),
neck=None,
head=None)
outputs = model(inputs)[0]
assert outputs.shape == (1, 128)


@ -0,0 +1,76 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
import pytest
import torch
import mmcls.models
from mmcls.apis import ModelHub, get_model
@dataclass
class Cfg:
name: str
backbone: type
num_classes: int = 1000
build: bool = True
forward: bool = True
backward: bool = True
input_shape: tuple = (1, 3, 224, 224)
test_list = [
Cfg(name='xcit-small-12-p16_3rdparty_in1k', backbone=mmcls.models.XCiT),
Cfg(name='xcit-nano-12-p8_3rdparty-dist_in1k-384px',
backbone=mmcls.models.XCiT,
input_shape=(1, 3, 384, 384)),
]
@pytest.mark.parametrize('cfg', test_list)
def test_build(cfg: Cfg):
if not cfg.build:
return
model_name = cfg.name
ModelHub._register_mmcls_models()
assert ModelHub.has(model_name)
model = get_model(model_name)
backbone_class = cfg.backbone
assert isinstance(model.backbone, backbone_class)
@pytest.mark.parametrize('cfg', test_list)
def test_forward(cfg: Cfg):
if not cfg.forward:
return
model = get_model(cfg.name)
inputs = torch.rand(*cfg.input_shape)
outputs = model(inputs)
assert outputs.shape == (1, cfg.num_classes)
feats = model.extract_feat(inputs)
assert isinstance(feats, tuple)
assert len(feats) == 1
@pytest.mark.parametrize('cfg', test_list)
def test_backward(cfg: Cfg):
if not cfg.backward:
return
model = get_model(cfg.name)
inputs = torch.rand(*cfg.input_shape)
outputs = model(inputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum(
[x.grad.numel() for x in model.parameters() if x.grad is not None])
assert outputs.shape[-1] == cfg.num_classes
num_params = sum([x.numel() for x in model.parameters()])
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'