[Feature] Support XCiT Backbone. (#1305)
* update model file * Update XCiT implementation and configs. * Update metafiles * Update metafile * Fix floor divide * Imporve memory usage --------- Co-authored-by: qingtian <459291290@qq.com> Co-authored-by: mzr1996 <mzr1996@163.com>pull/1345/head
parent
bedf4e9f64
commit
8352951f3d
|
@ -1,5 +1,6 @@
|
|||
import argparse
|
||||
import copy
|
||||
import re
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -70,6 +71,7 @@ def parse_args():
|
|||
parser.add_argument('--out', '-o', type=Path, help='The output path.')
|
||||
parser.add_argument(
|
||||
'--view', action='store_true', help='Only pretty print the metafile.')
|
||||
parser.add_argument('--csv', type=str, help='Use a csv to update models.')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
@ -165,14 +167,17 @@ def fill_collection(collection: dict):
|
|||
return collection
|
||||
|
||||
|
||||
def fill_model(model: dict, defaults: dict):
|
||||
def fill_model_by_prompt(model: dict, defaults: dict):
|
||||
# Name
|
||||
if model.get('Name') is None:
|
||||
name = prompt(
|
||||
'Please input the model [red]name[/]: ', allow_empty=False)
|
||||
model['Name'] = name
|
||||
|
||||
# In Collection
|
||||
model['In Collection'] = defaults.get('In Collection')
|
||||
|
||||
# Config
|
||||
config = model.get('Config')
|
||||
if config is None:
|
||||
config = prompt(
|
||||
|
@ -182,6 +187,7 @@ def fill_model(model: dict, defaults: dict):
|
|||
config = str(Path(config).absolute().relative_to(MMCLS_ROOT))
|
||||
model['Config'] = config
|
||||
|
||||
# Metadata.Flops, Metadata.Parameters
|
||||
flops = model.get('Metadata', {}).get('FLOPs')
|
||||
params = model.get('Metadata', {}).get('Parameters')
|
||||
if model.get('Config') is not None and (
|
||||
|
@ -280,6 +286,100 @@ def fill_model(model: dict, defaults: dict):
|
|||
return model
|
||||
|
||||
|
||||
def update_model_by_dict(model: dict, update_dict: dict, defaults: dict):
|
||||
# Name
|
||||
if 'name override' in update_dict:
|
||||
model['Name'] = update_dict['name override']
|
||||
|
||||
# In Collection
|
||||
model['In Collection'] = defaults.get('In Collection')
|
||||
|
||||
# Config
|
||||
if 'config' in update_dict:
|
||||
config = update_dict['config'].strip()
|
||||
config = str(Path(config).absolute().relative_to(MMCLS_ROOT))
|
||||
config_updated = (config != model.get('Config'))
|
||||
model['Config'] = config
|
||||
else:
|
||||
config_updated = False
|
||||
|
||||
# Metadata.Flops, Metadata.Parameters
|
||||
flops = model.get('Metadata', {}).get('FLOPs')
|
||||
params = model.get('Metadata', {}).get('Parameters')
|
||||
if config_updated and (flops is None or params is None):
|
||||
print(f'Automatically compute FLOPs and Parameters of {model["Name"]}')
|
||||
flops, params = get_flops(str(MMCLS_ROOT / model['Config']))
|
||||
|
||||
model.setdefault('Metadata', {})
|
||||
model['Metadata']['FLOPs'] = flops
|
||||
model['Metadata']['Parameters'] = params
|
||||
|
||||
# Metadata.Training Data
|
||||
if 'metadata.training data' in update_dict:
|
||||
train_data = update_dict['metadata.training data'].strip()
|
||||
train_data = re.split(r'\s+', train_data)
|
||||
if len(train_data) > 1:
|
||||
model['Metadata']['Training Data'] = train_data
|
||||
elif len(train_data) == 1:
|
||||
model['Metadata']['Training Data'] = train_data[0]
|
||||
|
||||
# Results.Dataset
|
||||
if 'results.dataset' in update_dict:
|
||||
test_data = update_dict['results.dataset'].strip()
|
||||
results = model.get('Results') or [{}]
|
||||
result = results[0]
|
||||
result['Dataset'] = test_data
|
||||
model['Results'] = results
|
||||
|
||||
# Results.Metrics.Top 1 Accuracy
|
||||
result = None
|
||||
if 'results.metrics.top 1 accuracy' in update_dict:
|
||||
top1 = update_dict['results.metrics.top 1 accuracy']
|
||||
results = model.get('Results') or [{}]
|
||||
result = results[0]
|
||||
result.setdefault('Metrics', {})
|
||||
result['Metrics']['Top 1 Accuracy'] = round(float(top1), 2)
|
||||
task = 'Image Classification'
|
||||
model['Results'] = results
|
||||
|
||||
# Results.Metrics.Top 5 Accuracy
|
||||
if 'results.metrics.top 5 accuracy' in update_dict:
|
||||
top5 = update_dict['results.metrics.top 5 accuracy']
|
||||
results = model.get('Results') or [{}]
|
||||
result = results[0]
|
||||
result.setdefault('Metrics', {})
|
||||
result['Metrics']['Top 5 Accuracy'] = round(float(top5), 2)
|
||||
task = 'Image Classification'
|
||||
model['Results'] = results
|
||||
|
||||
if result is not None:
|
||||
result['Metrics']['Task'] = task
|
||||
|
||||
# Weights
|
||||
if 'weights' in update_dict:
|
||||
weights = update_dict['weights'].strip()
|
||||
model['Weights'] = weights
|
||||
|
||||
# Converted From.Code
|
||||
if 'converted from.code' in update_dict:
|
||||
from_code = update_dict['converted from.code'].strip()
|
||||
model.setdefault('Converted From', {})
|
||||
model['Converted From']['Code'] = from_code
|
||||
|
||||
# Converted From.Weights
|
||||
if 'converted from.weights' in update_dict:
|
||||
from_weight = update_dict['converted from.weights'].strip()
|
||||
model.setdefault('Converted From', {})
|
||||
model['Converted From']['Weights'] = from_weight
|
||||
|
||||
order = [
|
||||
'Name', 'Metadata', 'In Collection', 'Results', 'Weights', 'Config',
|
||||
'Converted From'
|
||||
]
|
||||
model = {k: model[k] for k in sorted(model.keys(), key=order.index)}
|
||||
return model
|
||||
|
||||
|
||||
def format_collection(collection: dict):
|
||||
yaml_str = yaml_dump(collection)
|
||||
return Panel(
|
||||
|
@ -325,35 +425,44 @@ def main():
|
|||
models = content.get('Models', [])
|
||||
updated_models = []
|
||||
|
||||
try:
|
||||
if args.csv is not None:
|
||||
import pandas as pd
|
||||
df = pd.read_csv(args.csv).rename(columns=lambda x: x.strip().lower())
|
||||
assert df['name'].is_unique, 'The csv has duplicated model names.'
|
||||
models_dict = {item['Name']: item for item in models}
|
||||
for update_dict in df.to_dict('records'):
|
||||
assert 'name' in update_dict, 'The csv must have the `Name` field.'
|
||||
model_name = update_dict['name'].strip()
|
||||
model = models_dict.pop(model_name, {'Name': model_name})
|
||||
model = update_model_by_dict(model, update_dict, model_defaults)
|
||||
updated_models.append(model)
|
||||
updated_models.extend(models_dict.values())
|
||||
else:
|
||||
for model in models:
|
||||
console.print(format_model(model))
|
||||
ori_model = copy.deepcopy(model)
|
||||
model = fill_model(model, model_defaults)
|
||||
model = fill_model_by_prompt(model, model_defaults)
|
||||
if ori_model != model:
|
||||
console.print(format_model(model))
|
||||
updated_models.append(model)
|
||||
|
||||
while Confirm.ask('Add new model?'):
|
||||
model = fill_model({}, model_defaults)
|
||||
model = fill_model_by_prompt({}, model_defaults)
|
||||
updated_models.append(model)
|
||||
finally:
|
||||
# Save updated models even error happened.
|
||||
updated_models.sort(key=lambda item: (item.get('Metadata', {}).get(
|
||||
'FLOPs', 0), len(item['Name'])))
|
||||
if args.out is not None:
|
||||
with open(args.out, 'w') as f:
|
||||
yaml_dump({'Collections': [collection]}, f)
|
||||
f.write('\n')
|
||||
yaml_dump({'Models': updated_models}, f)
|
||||
else:
|
||||
modelindex = {
|
||||
'Collections': [collection],
|
||||
'Models': updated_models
|
||||
}
|
||||
yaml_str = yaml_dump(modelindex)
|
||||
console.print(Syntax(yaml_str, 'yaml', background_color='default'))
|
||||
console.print('Specify [red]`--out`[/] to dump to file.')
|
||||
|
||||
# Save updated models even error happened.
|
||||
updated_models.sort(key=lambda item: (item.get('Metadata', {}).get(
|
||||
'FLOPs', 0), len(item['Name'])))
|
||||
if args.out is not None:
|
||||
with open(args.out, 'w') as f:
|
||||
yaml_dump({'Collections': [collection]}, f)
|
||||
f.write('\n')
|
||||
yaml_dump({'Models': updated_models}, f)
|
||||
else:
|
||||
modelindex = {'Collections': [collection], 'Models': updated_models}
|
||||
yaml_str = yaml_dump(modelindex)
|
||||
console.print(Syntax(yaml_str, 'yaml', background_color='default'))
|
||||
console.print('Specify [red]`--out`[/] to dump to file.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
# XCiT
|
||||
|
||||
> [XCiT: Cross-Covariance Image Transformers](https://arxiv.org/abs/2106.09681)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
Following their success in natural language processing, transformers have recently shown much promise for computer vision. The self-attention operation underlying transformers yields global interactions between all tokens ,i.e. words or image patches, and enables flexible modelling of image data beyond the local interactions of convolutions. This flexibility, however, comes with a quadratic complexity in time and memory, hindering application to long sequences and high-resolution images. We propose a "transposed" version of self-attention that operates across feature channels rather than tokens, where the interactions are based on the cross-covariance matrix between keys and queries. The resulting cross-covariance attention (XCA) has linear complexity in the number of tokens, and allows efficient processing of high-resolution images. Our cross-covariance image transformer (XCiT) is built upon XCA. It combines the accuracy of conventional transformers with the scalability of convolutional architectures. We validate the effectiveness and generality of XCiT by reporting excellent results on multiple vision benchmarks, including image classification and self-supervised feature learning on ImageNet-1k, object detection and instance segmentation on COCO, and semantic segmentation on ADE20k.
|
||||
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/26739999/218900814-64a44606-150b-4757-aec8-7015c77a9fd1.png" width="60%"/>
|
||||
</div>
|
||||
|
||||
## Results and models
|
||||
|
||||
### ImageNet-1k
|
||||
|
||||
| Model | Pretrain | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||
| :-------------------------------------------: | :----------: | :-------: | :------: | :-------: | :-------: | :-------------------------------------------------: | :-------------------------------------------------------: |
|
||||
| xcit-nano-12-p16_3rdparty_in1k\* | From scratch | 3.05 | 0.56 | 70.35 | 89.98 | [config](./xcit-nano-12-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p16_3rdparty_in1k_20230213-ed776c38.pth) |
|
||||
| xcit-nano-12-p16_3rdparty-dist_in1k\* | Distillation | 3.05 | 0.56 | 72.36 | 91.02 | [config](./xcit-nano-12-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p16_3rdparty-dist_in1k_20230213-fb247f7b.pth) |
|
||||
| xcit-nano-12-p16_3rdparty-dist_in1k-384px\* | Distillation | 3.05 | 1.64 | 74.93 | 92.42 | [config](./xcit-nano-12-p16_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p16_3rdparty-dist_in1k-384px_20230213-712db4d4.pth) |
|
||||
| xcit-nano-12-p8_3rdparty_in1k\* | From scratch | 3.05 | 2.16 | 73.80 | 92.08 | [config](./xcit-nano-12-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p8_3rdparty_in1k_20230213-3370c293.pth) |
|
||||
| xcit-nano-12-p8_3rdparty-dist_in1k\* | Distillation | 3.05 | 2.16 | 76.17 | 93.08 | [config](./xcit-nano-12-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p8_3rdparty-dist_in1k_20230213-2f87d2b3.pth) |
|
||||
| xcit-nano-12-p8_3rdparty-dist_in1k-384px\* | Distillation | 3.05 | 6.34 | 77.69 | 94.09 | [config](./xcit-nano-12-p8_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p8_3rdparty-dist_in1k-384px_20230213-09d925ef.pth) |
|
||||
| xcit-tiny-12-p16_3rdparty_in1k\* | From scratch | 6.72 | 1.24 | 77.21 | 93.62 | [config](./xcit-tiny-12-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p16_3rdparty_in1k_20230213-82c547ca.pth) |
|
||||
| xcit-tiny-12-p16_3rdparty-dist_in1k\* | Distillation | 6.72 | 1.24 | 78.70 | 94.12 | [config](./xcit-tiny-12-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p16_3rdparty-dist_in1k_20230213-d5fde0a3.pth) |
|
||||
| xcit-tiny-24-p16_3rdparty_in1k\* | From scratch | 12.12 | 2.34 | 79.47 | 94.85 | [config](./xcit-tiny-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p16_3rdparty_in1k_20230213-366c1cd0.pth) |
|
||||
| xcit-tiny-24-p16_3rdparty-dist_in1k\* | Distillation | 12.12 | 2.34 | 80.51 | 95.17 | [config](./xcit-tiny-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p16_3rdparty-dist_in1k_20230213-b472e80a.pth) |
|
||||
| xcit-tiny-12-p16_3rdparty-dist_in1k-384px\* | Distillation | 6.72 | 3.64 | 80.58 | 95.38 | [config](./xcit-tiny-12-p16_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p16_3rdparty-dist_in1k-384px_20230213-00a20023.pth) |
|
||||
| xcit-tiny-12-p8_3rdparty_in1k\* | From scratch | 6.71 | 4.81 | 79.75 | 94.88 | [config](./xcit-tiny-12-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p8_3rdparty_in1k_20230213-8b02f8f5.pth) |
|
||||
| xcit-tiny-12-p8_3rdparty-dist_in1k\* | Distillation | 6.71 | 4.81 | 81.26 | 95.46 | [config](./xcit-tiny-12-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p8_3rdparty-dist_in1k_20230213-f3f9b44f.pth) |
|
||||
| xcit-tiny-24-p8_3rdparty_in1k\* | From scratch | 12.11 | 9.21 | 81.70 | 95.90 | [config](./xcit-tiny-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p8_3rdparty_in1k_20230213-4b9ba392.pth) |
|
||||
| xcit-tiny-24-p8_3rdparty-dist_in1k\* | Distillation | 12.11 | 9.21 | 82.62 | 96.16 | [config](./xcit-tiny-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p8_3rdparty-dist_in1k_20230213-ad9c44b0.pth) |
|
||||
| xcit-tiny-12-p8_3rdparty-dist_in1k-384px\* | Distillation | 6.71 | 14.13 | 82.46 | 96.22 | [config](./xcit-tiny-12-p8_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p8_3rdparty-dist_in1k-384px_20230213-a072174a.pth) |
|
||||
| xcit-tiny-24-p16_3rdparty-dist_in1k-384px\* | Distillation | 12.12 | 6.87 | 82.43 | 96.20 | [config](./xcit-tiny-24-p16_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p16_3rdparty-dist_in1k-384px_20230213-20e13917.pth) |
|
||||
| xcit-tiny-24-p8_3rdparty-dist_in1k-384px\* | Distillation | 12.11 | 27.05 | 83.77 | 96.72 | [config](./xcit-tiny-24-p8_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p8_3rdparty-dist_in1k-384px_20230213-30d5e5ec.pth) |
|
||||
| xcit-small-12-p16_3rdparty_in1k\* | From scratch | 26.25 | 4.81 | 81.87 | 95.77 | [config](./xcit-small-12-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p16_3rdparty_in1k_20230213-d36779d2.pth) |
|
||||
| xcit-small-12-p16_3rdparty-dist_in1k\* | Distillation | 26.25 | 4.81 | 83.12 | 96.41 | [config](./xcit-small-12-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p16_3rdparty-dist_in1k_20230213-c95bbae1.pth) |
|
||||
| xcit-small-24-p16_3rdparty_in1k\* | From scratch | 47.67 | 9.10 | 82.38 | 95.93 | [config](./xcit-small-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p16_3rdparty_in1k_20230213-40febe38.pth) |
|
||||
| xcit-small-24-p16_3rdparty-dist_in1k\* | Distillation | 47.67 | 9.10 | 83.70 | 96.61 | [config](./xcit-small-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p16_3rdparty-dist_in1k_20230213-130d7262.pth) |
|
||||
| xcit-small-12-p16_3rdparty-dist_in1k-384px\* | Distillation | 26.25 | 14.14 | 84.74 | 97.19 | [config](./xcit-small-12-p16_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p16_3rdparty-dist_in1k-384px_20230213-ba36c982.pth) |
|
||||
| xcit-small-12-p8_3rdparty_in1k\* | From scratch | 26.21 | 18.69 | 83.21 | 96.41 | [config](./xcit-small-12-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p8_3rdparty_in1k_20230213-9e364ce3.pth) |
|
||||
| xcit-small-12-p8_3rdparty-dist_in1k\* | Distillation | 26.21 | 18.69 | 83.97 | 96.81 | [config](./xcit-small-12-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p8_3rdparty-dist_in1k_20230213-71886580.pth) |
|
||||
| xcit-small-24-p16_3rdparty-dist_in1k-384px\* | Distillation | 47.67 | 26.72 | 85.10 | 97.32 | [config](./xcit-small-24-p16_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p16_3rdparty-dist_in1k-384px_20230213-28fa2d0e.pth) |
|
||||
| xcit-small-24-p8_3rdparty_in1k\* | From scratch | 47.63 | 35.81 | 83.62 | 96.51 | [config](./xcit-small-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p8_3rdparty_in1k_20230213-280ebcc7.pth) |
|
||||
| xcit-small-24-p8_3rdparty-dist_in1k\* | Distillation | 47.63 | 35.81 | 84.68 | 97.07 | [config](./xcit-small-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p8_3rdparty-dist_in1k_20230213-f2773c78.pth) |
|
||||
| xcit-small-12-p8_3rdparty-dist_in1k-384px\* | Distillation | 26.21 | 54.92 | 85.12 | 97.31 | [config](./xcit-small-12-p8_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p8_3rdparty-dist_in1k-384px_20230214-9f2178bc.pth) |
|
||||
| xcit-small-24-p8_3rdparty-dist_in1k-384px\* | Distillation | 47.63 | 105.24 | 85.57 | 97.60 | [config](./xcit-small-24-p8_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p8_3rdparty-dist_in1k-384px_20230214-57298eca.pth) |
|
||||
| xcit-medium-24-p16_3rdparty_in1k\* | From scratch | 84.40 | 16.13 | 82.56 | 95.82 | [config](./xcit-medium-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p16_3rdparty_in1k_20230213-ad0aa92e.pth) |
|
||||
| xcit-medium-24-p16_3rdparty-dist_in1k\* | Distillation | 84.40 | 16.13 | 84.15 | 96.82 | [config](./xcit-medium-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p16_3rdparty-dist_in1k_20230213-aca5cd0c.pth) |
|
||||
| xcit-medium-24-p16_3rdparty-dist_in1k-384px\* | Distillation | 84.40 | 47.39 | 85.47 | 97.49 | [config](./xcit-medium-24-p16_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p16_3rdparty-dist_in1k-384px_20230214-6c23a201.pth) |
|
||||
| xcit-medium-24-p8_3rdparty_in1k\* | From scratch | 84.32 | 63.52 | 83.61 | 96.23 | [config](./xcit-medium-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p8_3rdparty_in1k_20230214-c362850b.pth) |
|
||||
| xcit-medium-24-p8_3rdparty-dist_in1k\* | Distillation | 84.32 | 63.52 | 85.00 | 97.16 | [config](./xcit-medium-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p8_3rdparty-dist_in1k_20230214-625c953b.pth) |
|
||||
| xcit-medium-24-p8_3rdparty-dist_in1k-384px\* | Distillation | 84.32 | 186.67 | 85.87 | 97.61 | [config](./xcit-medium-24-p8_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p8_3rdparty-dist_in1k-384px_20230214-5db925e0.pth) |
|
||||
| xcit-large-24-p16_3rdparty_in1k\* | From scratch | 189.10 | 35.86 | 82.97 | 95.86 | [config](./xcit-large-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p16_3rdparty_in1k_20230214-d29d2529.pth) |
|
||||
| xcit-large-24-p16_3rdparty-dist_in1k\* | Distillation | 189.10 | 35.86 | 84.61 | 97.07 | [config](./xcit-large-24-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p16_3rdparty-dist_in1k_20230214-4fea599c.pth) |
|
||||
| xcit-large-24-p16_3rdparty-dist_in1k-384px\* | Distillation | 189.10 | 105.35 | 85.78 | 97.60 | [config](./xcit-large-24-p16_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p16_3rdparty-dist_in1k-384px_20230214-bd515a34.pth) |
|
||||
| xcit-large-24-p8_3rdparty_in1k\* | From scratch | 188.93 | 141.23 | 84.23 | 96.58 | [config](./xcit-large-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p8_3rdparty_in1k_20230214-08f2f664.pth) |
|
||||
| xcit-large-24-p8_3rdparty-dist_in1k\* | Distillation | 188.93 | 141.23 | 85.14 | 97.32 | [config](./xcit-large-24-p8_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p8_3rdparty-dist_in1k_20230214-8c092b34.pth) |
|
||||
| xcit-large-24-p8_3rdparty-dist_in1k-384px\* | Distillation | 188.93 | 415.00 | 86.13 | 97.75 | [config](./xcit-large-24-p8_8xb128_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p8_3rdparty-dist_in1k-384px_20230214-9f718b1a.pth) |
|
||||
|
||||
*Models with * are converted from the [official repo](https://github.com/facebookresearch/xcit). The config files of these models are only for inference. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{el2021xcit,
|
||||
title={XCiT: Cross-Covariance Image Transformers},
|
||||
author={El-Nouby, Alaaeldin and Touvron, Hugo and Caron, Mathilde and Bojanowski, Piotr and Douze, Matthijs and Joulin, Armand and Laptev, Ivan and Neverova, Natalia and Synnaeve, Gabriel and Verbeek, Jakob and others},
|
||||
journal={arXiv preprint arXiv:2106.09681},
|
||||
year={2021}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,727 @@
|
|||
Collections:
|
||||
- Name: XCiT
|
||||
Metadata:
|
||||
Architecture:
|
||||
- Class Attention
|
||||
- Local Patch Interaction
|
||||
- Cross-Covariance Attention
|
||||
Paper:
|
||||
Title: 'XCiT: Cross-Covariance Image Transformers'
|
||||
URL: https://arxiv.org/abs/2106.09681
|
||||
README: configs/xcit/README.md
|
||||
|
||||
Models:
|
||||
- Name: xcit-nano-12-p16_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 557074560
|
||||
Parameters: 3053224
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 70.35
|
||||
Top 5 Accuracy: 89.98
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p16_3rdparty_in1k_20230213-ed776c38.pth
|
||||
Config: configs/xcit/xcit-nano-12-p16_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224.pth
|
||||
- Name: xcit-nano-12-p16_3rdparty-dist_in1k
|
||||
Metadata:
|
||||
FLOPs: 557074560
|
||||
Parameters: 3053224
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 72.36
|
||||
Top 5 Accuracy: 91.02
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p16_3rdparty-dist_in1k_20230213-fb247f7b.pth
|
||||
Config: configs/xcit/xcit-nano-12-p16_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224_dist.pth
|
||||
- Name: xcit-tiny-12-p16_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 1239698112
|
||||
Parameters: 6716272
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 77.21
|
||||
Top 5 Accuracy: 93.62
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p16_3rdparty_in1k_20230213-82c547ca.pth
|
||||
Config: configs/xcit/xcit-tiny-12-p16_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224.pth
|
||||
- Name: xcit-tiny-12-p16_3rdparty-dist_in1k
|
||||
Metadata:
|
||||
FLOPs: 1239698112
|
||||
Parameters: 6716272
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 78.7
|
||||
Top 5 Accuracy: 94.12
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p16_3rdparty-dist_in1k_20230213-d5fde0a3.pth
|
||||
Config: configs/xcit/xcit-tiny-12-p16_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224_dist.pth
|
||||
- Name: xcit-nano-12-p16_3rdparty-dist_in1k-384px
|
||||
Metadata:
|
||||
FLOPs: 1636347520
|
||||
Parameters: 3053224
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 74.93
|
||||
Top 5 Accuracy: 92.42
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p16_3rdparty-dist_in1k-384px_20230213-712db4d4.pth
|
||||
Config: configs/xcit/xcit-nano-12-p16_8xb128_in1k-384px.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_384_dist.pth
|
||||
- Name: xcit-nano-12-p8_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 2156861056
|
||||
Parameters: 3049016
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 73.8
|
||||
Top 5 Accuracy: 92.08
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p8_3rdparty_in1k_20230213-3370c293.pth
|
||||
Config: configs/xcit/xcit-nano-12-p8_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224.pth
|
||||
- Name: xcit-nano-12-p8_3rdparty-dist_in1k
|
||||
Metadata:
|
||||
FLOPs: 2156861056
|
||||
Parameters: 3049016
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 76.17
|
||||
Top 5 Accuracy: 93.08
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p8_3rdparty-dist_in1k_20230213-2f87d2b3.pth
|
||||
Config: configs/xcit/xcit-nano-12-p8_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224_dist.pth
|
||||
- Name: xcit-tiny-24-p16_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 2339305152
|
||||
Parameters: 12116896
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 79.47
|
||||
Top 5 Accuracy: 94.85
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p16_3rdparty_in1k_20230213-366c1cd0.pth
|
||||
Config: configs/xcit/xcit-tiny-24-p16_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224.pth
|
||||
- Name: xcit-tiny-24-p16_3rdparty-dist_in1k
|
||||
Metadata:
|
||||
FLOPs: 2339305152
|
||||
Parameters: 12116896
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 80.51
|
||||
Top 5 Accuracy: 95.17
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p16_3rdparty-dist_in1k_20230213-b472e80a.pth
|
||||
Config: configs/xcit/xcit-tiny-24-p16_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224_dist.pth
|
||||
- Name: xcit-tiny-12-p16_3rdparty-dist_in1k-384px
|
||||
Metadata:
|
||||
FLOPs: 3641468352
|
||||
Parameters: 6716272
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 80.58
|
||||
Top 5 Accuracy: 95.38
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p16_3rdparty-dist_in1k-384px_20230213-00a20023.pth
|
||||
Config: configs/xcit/xcit-tiny-12-p16_8xb128_in1k-384px.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_384_dist.pth
|
||||
- Name: xcit-tiny-12-p8_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 4807399872
|
||||
Parameters: 6706504
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 79.75
|
||||
Top 5 Accuracy: 94.88
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p8_3rdparty_in1k_20230213-8b02f8f5.pth
|
||||
Config: configs/xcit/xcit-tiny-12-p8_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224.pth
|
||||
- Name: xcit-tiny-12-p8_3rdparty-dist_in1k
|
||||
Metadata:
|
||||
FLOPs: 4807399872
|
||||
Parameters: 6706504
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 81.26
|
||||
Top 5 Accuracy: 95.46
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p8_3rdparty-dist_in1k_20230213-f3f9b44f.pth
|
||||
Config: configs/xcit/xcit-tiny-12-p8_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224_dist.pth
|
||||
- Name: xcit-small-12-p16_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 4814951808
|
||||
Parameters: 26253304
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 81.87
|
||||
Top 5 Accuracy: 95.77
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p16_3rdparty_in1k_20230213-d36779d2.pth
|
||||
Config: configs/xcit/xcit-small-12-p16_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pth
|
||||
- Name: xcit-small-12-p16_3rdparty-dist_in1k
|
||||
Metadata:
|
||||
FLOPs: 4814951808
|
||||
Parameters: 26253304
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 83.12
|
||||
Top 5 Accuracy: 96.41
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p16_3rdparty-dist_in1k_20230213-c95bbae1.pth
|
||||
Config: configs/xcit/xcit-small-12-p16_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224_dist.pth
|
||||
- Name: xcit-nano-12-p8_3rdparty-dist_in1k-384px
|
||||
Metadata:
|
||||
FLOPs: 6337760896
|
||||
Parameters: 3049016
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 77.69
|
||||
Top 5 Accuracy: 94.09
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-nano-12-p8_3rdparty-dist_in1k-384px_20230213-09d925ef.pth
|
||||
Config: configs/xcit/xcit-nano-12-p8_8xb128_in1k-384px.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_384_dist.pth
|
||||
- Name: xcit-tiny-24-p16_3rdparty-dist_in1k-384px
|
||||
Metadata:
|
||||
FLOPs: 6872966592
|
||||
Parameters: 12116896
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 82.43
|
||||
Top 5 Accuracy: 96.2
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p16_3rdparty-dist_in1k-384px_20230213-20e13917.pth
|
||||
Config: configs/xcit/xcit-tiny-24-p16_8xb128_in1k-384px.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_384_dist.pth
|
||||
- Name: xcit-small-24-p16_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 9095064960
|
||||
Parameters: 47671384
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 82.38
|
||||
Top 5 Accuracy: 95.93
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p16_3rdparty_in1k_20230213-40febe38.pth
|
||||
Config: configs/xcit/xcit-small-24-p16_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224.pth
|
||||
- Name: xcit-small-24-p16_3rdparty-dist_in1k
|
||||
Metadata:
|
||||
FLOPs: 9095064960
|
||||
Parameters: 47671384
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 83.7
|
||||
Top 5 Accuracy: 96.61
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p16_3rdparty-dist_in1k_20230213-130d7262.pth
|
||||
Config: configs/xcit/xcit-small-24-p16_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224_dist.pth
|
||||
- Name: xcit-tiny-24-p8_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 9205828032
|
||||
Parameters: 12107128
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 81.7
|
||||
Top 5 Accuracy: 95.9
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p8_3rdparty_in1k_20230213-4b9ba392.pth
|
||||
Config: configs/xcit/xcit-tiny-24-p8_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224.pth
|
||||
- Name: xcit-tiny-24-p8_3rdparty-dist_in1k
|
||||
Metadata:
|
||||
FLOPs: 9205828032
|
||||
Parameters: 12107128
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 82.62
|
||||
Top 5 Accuracy: 96.16
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p8_3rdparty-dist_in1k_20230213-ad9c44b0.pth
|
||||
Config: configs/xcit/xcit-tiny-24-p8_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224_dist.pth
|
||||
- Name: xcit-tiny-12-p8_3rdparty-dist_in1k-384px
|
||||
Metadata:
|
||||
FLOPs: 14126142912
|
||||
Parameters: 6706504
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 82.46
|
||||
Top 5 Accuracy: 96.22
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-12-p8_3rdparty-dist_in1k-384px_20230213-a072174a.pth
|
||||
Config: configs/xcit/xcit-tiny-12-p8_8xb128_in1k-384px.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_384_dist.pth
|
||||
- Name: xcit-small-12-p16_3rdparty-dist_in1k-384px
|
||||
Metadata:
|
||||
FLOPs: 14143179648
|
||||
Parameters: 26253304
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 84.74
|
||||
Top 5 Accuracy: 97.19
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p16_3rdparty-dist_in1k-384px_20230213-ba36c982.pth
|
||||
Config: configs/xcit/xcit-small-12-p16_8xb128_in1k-384px.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_384_dist.pth
|
||||
- Name: xcit-medium-24-p16_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 16129561088
|
||||
Parameters: 84395752
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 82.56
|
||||
Top 5 Accuracy: 95.82
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p16_3rdparty_in1k_20230213-ad0aa92e.pth
|
||||
Config: configs/xcit/xcit-medium-24-p16_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224.pth
|
||||
- Name: xcit-medium-24-p16_3rdparty-dist_in1k
|
||||
Metadata:
|
||||
FLOPs: 16129561088
|
||||
Parameters: 84395752
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 84.15
|
||||
Top 5 Accuracy: 96.82
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p16_3rdparty-dist_in1k_20230213-aca5cd0c.pth
|
||||
Config: configs/xcit/xcit-medium-24-p16_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224_dist.pth
|
||||
- Name: xcit-small-12-p8_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 18691601280
|
||||
Parameters: 26213032
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 83.21
|
||||
Top 5 Accuracy: 96.41
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p8_3rdparty_in1k_20230213-9e364ce3.pth
|
||||
Config: configs/xcit/xcit-small-12-p8_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224.pth
|
||||
- Name: xcit-small-12-p8_3rdparty-dist_in1k
|
||||
Metadata:
|
||||
FLOPs: 18691601280
|
||||
Parameters: 26213032
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 83.97
|
||||
Top 5 Accuracy: 96.81
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p8_3rdparty-dist_in1k_20230213-71886580.pth
|
||||
Config: configs/xcit/xcit-small-12-p8_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224_dist.pth
|
||||
- Name: xcit-small-24-p16_3rdparty-dist_in1k-384px
|
||||
Metadata:
|
||||
FLOPs: 26721471360
|
||||
Parameters: 47671384
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 85.1
|
||||
Top 5 Accuracy: 97.32
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p16_3rdparty-dist_in1k-384px_20230213-28fa2d0e.pth
|
||||
Config: configs/xcit/xcit-small-24-p16_8xb128_in1k-384px.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_384_dist.pth
|
||||
- Name: xcit-tiny-24-p8_3rdparty-dist_in1k-384px
|
||||
Metadata:
|
||||
FLOPs: 27052135872
|
||||
Parameters: 12107128
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 83.77
|
||||
Top 5 Accuracy: 96.72
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-tiny-24-p8_3rdparty-dist_in1k-384px_20230213-30d5e5ec.pth
|
||||
Config: configs/xcit/xcit-tiny-24-p8_8xb128_in1k-384px.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_384_dist.pth
|
||||
- Name: xcit-small-24-p8_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 35812053888
|
||||
Parameters: 47631112
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 83.62
|
||||
Top 5 Accuracy: 96.51
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p8_3rdparty_in1k_20230213-280ebcc7.pth
|
||||
Config: configs/xcit/xcit-small-24-p8_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224.pth
|
||||
- Name: xcit-small-24-p8_3rdparty-dist_in1k
|
||||
Metadata:
|
||||
FLOPs: 35812053888
|
||||
Parameters: 47631112
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 84.68
|
||||
Top 5 Accuracy: 97.07
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p8_3rdparty-dist_in1k_20230213-f2773c78.pth
|
||||
Config: configs/xcit/xcit-small-24-p8_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224_dist.pth
|
||||
- Name: xcit-large-24-p16_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 35855948544
|
||||
Parameters: 189096136
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 82.97
|
||||
Top 5 Accuracy: 95.86
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p16_3rdparty_in1k_20230214-d29d2529.pth
|
||||
Config: configs/xcit/xcit-large-24-p16_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224.pth
|
||||
- Name: xcit-large-24-p16_3rdparty-dist_in1k
|
||||
Metadata:
|
||||
FLOPs: 35855948544
|
||||
Parameters: 189096136
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 84.61
|
||||
Top 5 Accuracy: 97.07
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p16_3rdparty-dist_in1k_20230214-4fea599c.pth
|
||||
Config: configs/xcit/xcit-large-24-p16_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224_dist.pth
|
||||
- Name: xcit-medium-24-p16_3rdparty-dist_in1k-384px
|
||||
Metadata:
|
||||
FLOPs: 47388932608
|
||||
Parameters: 84395752
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 85.47
|
||||
Top 5 Accuracy: 97.49
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p16_3rdparty-dist_in1k-384px_20230214-6c23a201.pth
|
||||
Config: configs/xcit/xcit-medium-24-p16_8xb128_in1k-384px.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_384_dist.pth
|
||||
- Name: xcit-small-12-p8_3rdparty-dist_in1k-384px
|
||||
Metadata:
|
||||
FLOPs: 54923537280
|
||||
Parameters: 26213032
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 85.12
|
||||
Top 5 Accuracy: 97.31
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-12-p8_3rdparty-dist_in1k-384px_20230214-9f2178bc.pth
|
||||
Config: configs/xcit/xcit-small-12-p8_8xb128_in1k-384px.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_384_dist.pth
|
||||
- Name: xcit-medium-24-p8_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 63524706816
|
||||
Parameters: 84323624
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 83.61
|
||||
Top 5 Accuracy: 96.23
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p8_3rdparty_in1k_20230214-c362850b.pth
|
||||
Config: configs/xcit/xcit-medium-24-p8_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224.pth
|
||||
- Name: xcit-medium-24-p8_3rdparty-dist_in1k
|
||||
Metadata:
|
||||
FLOPs: 63524706816
|
||||
Parameters: 84323624
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 85.0
|
||||
Top 5 Accuracy: 97.16
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p8_3rdparty-dist_in1k_20230214-625c953b.pth
|
||||
Config: configs/xcit/xcit-medium-24-p8_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224_dist.pth
|
||||
- Name: xcit-small-24-p8_3rdparty-dist_in1k-384px
|
||||
Metadata:
|
||||
FLOPs: 105236704128
|
||||
Parameters: 47631112
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 85.57
|
||||
Top 5 Accuracy: 97.6
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-small-24-p8_3rdparty-dist_in1k-384px_20230214-57298eca.pth
|
||||
Config: configs/xcit/xcit-small-24-p8_8xb128_in1k-384px.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_384_dist.pth
|
||||
- Name: xcit-large-24-p16_3rdparty-dist_in1k-384px
|
||||
Metadata:
|
||||
FLOPs: 105345095424
|
||||
Parameters: 189096136
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 85.78
|
||||
Top 5 Accuracy: 97.6
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p16_3rdparty-dist_in1k-384px_20230214-bd515a34.pth
|
||||
Config: configs/xcit/xcit-large-24-p16_8xb128_in1k-384px.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_384_dist.pth
|
||||
- Name: xcit-large-24-p8_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 141225699072
|
||||
Parameters: 188932648
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 84.23
|
||||
Top 5 Accuracy: 96.58
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p8_3rdparty_in1k_20230214-08f2f664.pth
|
||||
Config: configs/xcit/xcit-large-24-p8_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224.pth
|
||||
- Name: xcit-large-24-p8_3rdparty-dist_in1k
|
||||
Metadata:
|
||||
FLOPs: 141225699072
|
||||
Parameters: 188932648
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 85.14
|
||||
Top 5 Accuracy: 97.32
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p8_3rdparty-dist_in1k_20230214-8c092b34.pth
|
||||
Config: configs/xcit/xcit-large-24-p8_8xb128_in1k.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224_dist.pth
|
||||
- Name: xcit-medium-24-p8_3rdparty-dist_in1k-384px
|
||||
Metadata:
|
||||
FLOPs: 186672626176
|
||||
Parameters: 84323624
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 85.87
|
||||
Top 5 Accuracy: 97.61
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-medium-24-p8_3rdparty-dist_in1k-384px_20230214-5db925e0.pth
|
||||
Config: configs/xcit/xcit-medium-24-p8_8xb128_in1k-384px.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_384_dist.pth
|
||||
- Name: xcit-large-24-p8_3rdparty-dist_in1k-384px
|
||||
Metadata:
|
||||
FLOPs: 415003137792
|
||||
Parameters: 188932648
|
||||
Training Data: ImageNet-1k
|
||||
In Collection: XCiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 86.13
|
||||
Top 5 Accuracy: 97.75
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/xcit/xcit-large-24-p8_3rdparty-dist_in1k-384px_20230214-9f718b1a.pth
|
||||
Config: configs/xcit/xcit-large-24-p8_8xb128_in1k-384px.py
|
||||
Converted From:
|
||||
Code: https://github.com/facebookresearch/xcit
|
||||
Weights: https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_384_dist.pth
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_384.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=16,
|
||||
embed_dims=768,
|
||||
depth=24,
|
||||
num_heads=16,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=768,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=16,
|
||||
embed_dims=768,
|
||||
depth=24,
|
||||
num_heads=16,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=768,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_384.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=8,
|
||||
embed_dims=768,
|
||||
depth=24,
|
||||
num_heads=16,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=768,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=8,
|
||||
embed_dims=768,
|
||||
depth=24,
|
||||
num_heads=16,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=768,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_384.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=16,
|
||||
embed_dims=512,
|
||||
depth=24,
|
||||
num_heads=8,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=512,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=16,
|
||||
embed_dims=512,
|
||||
depth=24,
|
||||
num_heads=8,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=512,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_384.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=8,
|
||||
embed_dims=512,
|
||||
depth=24,
|
||||
num_heads=8,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=512,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=8,
|
||||
embed_dims=512,
|
||||
depth=24,
|
||||
num_heads=8,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=512,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_384.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=16,
|
||||
embed_dims=128,
|
||||
depth=12,
|
||||
num_heads=4,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1.0,
|
||||
tokens_norm=False,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=128,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=16,
|
||||
embed_dims=128,
|
||||
depth=12,
|
||||
num_heads=4,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1.0,
|
||||
tokens_norm=False,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=128,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_384.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=8,
|
||||
embed_dims=128,
|
||||
depth=12,
|
||||
num_heads=4,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1.0,
|
||||
tokens_norm=False,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=128,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=8,
|
||||
embed_dims=128,
|
||||
depth=12,
|
||||
num_heads=4,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1.0,
|
||||
tokens_norm=False,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=128,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_384.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=16,
|
||||
embed_dims=384,
|
||||
depth=12,
|
||||
num_heads=8,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1.0,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=384,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=16,
|
||||
embed_dims=384,
|
||||
depth=12,
|
||||
num_heads=8,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1.0,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=384,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_384.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=8,
|
||||
embed_dims=384,
|
||||
depth=12,
|
||||
num_heads=8,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1.0,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=384,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=8,
|
||||
embed_dims=384,
|
||||
depth=12,
|
||||
num_heads=8,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1.0,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=384,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_384.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=16,
|
||||
embed_dims=384,
|
||||
depth=24,
|
||||
num_heads=8,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=384,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=16,
|
||||
embed_dims=384,
|
||||
depth=24,
|
||||
num_heads=8,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=384,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_384.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=8,
|
||||
embed_dims=384,
|
||||
depth=24,
|
||||
num_heads=8,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=384,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=8,
|
||||
embed_dims=384,
|
||||
depth=24,
|
||||
num_heads=8,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=384,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_384.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=16,
|
||||
embed_dims=192,
|
||||
depth=12,
|
||||
num_heads=4,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1.0,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=192,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=16,
|
||||
embed_dims=192,
|
||||
depth=12,
|
||||
num_heads=4,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1.0,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=192,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_384.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=8,
|
||||
embed_dims=192,
|
||||
depth=12,
|
||||
num_heads=4,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1.0,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=192,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=8,
|
||||
embed_dims=192,
|
||||
depth=12,
|
||||
num_heads=4,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1.0,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=192,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_384.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=16,
|
||||
embed_dims=192,
|
||||
depth=24,
|
||||
num_heads=4,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=192,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=16,
|
||||
embed_dims=192,
|
||||
depth=24,
|
||||
num_heads=4,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=192,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_384.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=8,
|
||||
embed_dims=192,
|
||||
depth=24,
|
||||
num_heads=4,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=192,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='XCiT',
|
||||
patch_size=8,
|
||||
embed_dims=192,
|
||||
depth=24,
|
||||
num_heads=4,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
tokens_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=192,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
),
|
||||
train_cfg=dict(arguments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0),
|
||||
]),
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
|
@ -112,6 +112,7 @@ Backbones
|
|||
VGG
|
||||
Vig
|
||||
VisionTransformer
|
||||
XCiT
|
||||
|
||||
.. module:: mmcls.models.necks
|
||||
|
||||
|
|
|
@ -51,6 +51,7 @@ from .van import VAN
|
|||
from .vgg import VGG
|
||||
from .vig import PyramidVig, Vig
|
||||
from .vision_transformer import VisionTransformer
|
||||
from .xcit import XCiT
|
||||
|
||||
__all__ = [
|
||||
'LeNet5',
|
||||
|
@ -112,4 +113,5 @@ __all__ = [
|
|||
'LeViT',
|
||||
'Vig',
|
||||
'PyramidVig',
|
||||
'XCiT',
|
||||
]
|
||||
|
|
|
@ -0,0 +1,770 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn.bricks import ConvModule, DropPath
|
||||
from mmcv.cnn.bricks.transformer import FFN
|
||||
from mmengine.model import BaseModule, Sequential
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
from mmengine.utils import digit_version
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from ..utils import build_norm_layer, to_2tuple
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
if digit_version(torch.__version__) < digit_version('1.8.0'):
|
||||
floor_div = torch.floor_divide
|
||||
else:
|
||||
floor_div = partial(torch.div, rounding_mode='floor')
|
||||
|
||||
|
||||
class ClassAttntion(BaseModule):
|
||||
"""Class Attention Module.
|
||||
|
||||
A PyTorch implementation of Class Attention Module introduced by:
|
||||
`Going deeper with Image Transformers <https://arxiv.org/abs/2103.17239>`_
|
||||
|
||||
taken from
|
||||
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
with slight modifications to do CA
|
||||
|
||||
Args:
|
||||
dim (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads. Defaults to 8.
|
||||
qkv_bias (bool): enable bias for qkv if True. Defaults to False.
|
||||
attn_drop (float): The drop out rate for attention output weights.
|
||||
Defaults to 0.
|
||||
proj_drop (float): The drop out rate for linear output weights.
|
||||
Defaults to 0.
|
||||
init_cfg (dict | list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
init_cfg=None):
|
||||
|
||||
super(ClassAttntion, self).__init__(init_cfg=init_cfg)
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
||||
self.k = nn.Linear(dim, dim, bias=qkv_bias)
|
||||
self.v = nn.Linear(dim, dim, bias=qkv_bias)
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
# We only need to calculate query of cls token.
|
||||
q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads,
|
||||
C // self.num_heads).permute(
|
||||
0, 2, 1, 3)
|
||||
k = self.k(x).reshape(B, N, self.num_heads,
|
||||
C // self.num_heads).permute(0, 2, 1, 3)
|
||||
|
||||
q = q * self.scale
|
||||
v = self.v(x).reshape(B, N, self.num_heads,
|
||||
C // self.num_heads).permute(0, 2, 1, 3)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
|
||||
x_cls = self.proj(x_cls)
|
||||
x_cls = self.proj_drop(x_cls)
|
||||
|
||||
return x_cls
|
||||
|
||||
|
||||
class PositionalEncodingFourier(BaseModule):
|
||||
"""Positional Encoding using a fourier kernel.
|
||||
|
||||
A PyTorch implementation of Positional Encoding relying on
|
||||
a fourier kernel introduced by:
|
||||
`Attention is all you Need <https://arxiv.org/abs/1706.03762>`_
|
||||
|
||||
Based on the `official XCiT code
|
||||
<https://github.com/facebookresearch/xcit/blob/master/xcit.py>`_
|
||||
|
||||
Args:
|
||||
hidden_dim (int): The hidden feature dimension. Defaults to 32.
|
||||
dim (int): The output feature dimension. Defaults to 768.
|
||||
temperature (int): A control variable for position encoding.
|
||||
Defaults to 10000.
|
||||
init_cfg (dict | list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_dim: int = 32,
|
||||
dim: int = 768,
|
||||
temperature: int = 10000,
|
||||
init_cfg=None):
|
||||
super(PositionalEncodingFourier, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
self.token_projection = ConvModule(
|
||||
in_channels=hidden_dim * 2,
|
||||
out_channels=dim,
|
||||
kernel_size=1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=None)
|
||||
self.scale = 2 * math.pi
|
||||
self.temperature = temperature
|
||||
self.hidden_dim = hidden_dim
|
||||
self.dim = dim
|
||||
self.eps = 1e-6
|
||||
|
||||
def forward(self, B: int, H: int, W: int):
|
||||
device = self.token_projection.conv.weight.device
|
||||
y_embed = torch.arange(
|
||||
1, H + 1, device=device).unsqueeze(1).repeat(1, 1, W).float()
|
||||
x_embed = torch.arange(1, W + 1, device=device).repeat(1, H, 1).float()
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.hidden_dim, device=device).float()
|
||||
dim_t = floor_div(dim_t, 2)
|
||||
dim_t = self.temperature**(2 * dim_t / self.hidden_dim)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack(
|
||||
[pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()],
|
||||
dim=4).flatten(3)
|
||||
pos_y = torch.stack(
|
||||
[pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()],
|
||||
dim=4).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
pos = self.token_projection(pos)
|
||||
return pos.repeat(B, 1, 1, 1) # (B, C, H, W)
|
||||
|
||||
|
||||
class ConvPatchEmbed(BaseModule):
|
||||
"""Patch Embedding using multiple convolution layers.
|
||||
|
||||
Args:
|
||||
img_size (int, tuple): input image size.
|
||||
Defaults to 224, means the size is 224*224.
|
||||
patch_size (int): The patch size in conv patch embedding.
|
||||
Defaults to 16.
|
||||
in_channels (int): The input channels of this module.
|
||||
Defaults to 3.
|
||||
embed_dims (int): The feature dimension
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='BN')``.
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Defaults to ``dict(type='GELU')``.
|
||||
init_cfg (dict | list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size: Union[int, tuple] = 224,
|
||||
patch_size: int = 16,
|
||||
in_channels: int = 3,
|
||||
embed_dims: int = 768,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
init_cfg=None):
|
||||
super(ConvPatchEmbed, self).__init__(init_cfg=init_cfg)
|
||||
img_size = to_2tuple(img_size)
|
||||
num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size)
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_patches = num_patches
|
||||
|
||||
conv = partial(
|
||||
ConvModule,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
)
|
||||
|
||||
layer = []
|
||||
if patch_size == 16:
|
||||
layer.append(
|
||||
conv(in_channels=in_channels, out_channels=embed_dims // 8))
|
||||
layer.append(
|
||||
conv(
|
||||
in_channels=embed_dims // 8, out_channels=embed_dims // 4))
|
||||
elif patch_size == 8:
|
||||
layer.append(
|
||||
conv(in_channels=in_channels, out_channels=embed_dims // 4))
|
||||
else:
|
||||
raise ValueError('For patch embedding, the patch size must be 16 '
|
||||
f'or 8, but get patch size {self.patch_size}.')
|
||||
|
||||
layer.append(
|
||||
conv(in_channels=embed_dims // 4, out_channels=embed_dims // 2))
|
||||
layer.append(
|
||||
conv(
|
||||
in_channels=embed_dims // 2,
|
||||
out_channels=embed_dims,
|
||||
act_cfg=None,
|
||||
))
|
||||
|
||||
self.proj = Sequential(*layer)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.proj(x)
|
||||
Hp, Wp = x.shape[2], x.shape[3]
|
||||
x = x.flatten(2).transpose(1, 2) # (B, N, C)
|
||||
return x, (Hp, Wp)
|
||||
|
||||
|
||||
class ClassAttentionBlock(BaseModule):
|
||||
"""Transformer block using Class Attention.
|
||||
|
||||
Args:
|
||||
dim (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
mlp_ratio (float): The hidden dimension ratio for FFN.
|
||||
Defaults to 4.
|
||||
qkv_bias (bool): enable bias for qkv if True. Defaults to False.
|
||||
drop (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Defaults to 0.
|
||||
attn_drop (float): The drop out rate for attention output weights.
|
||||
Defaults to 0.
|
||||
drop_path (float): Stochastic depth rate. Defaults to 0.
|
||||
layer_scale_init_value (float): The initial value for layer scale.
|
||||
Defaults to 1.
|
||||
tokens_norm (bool): Whether to normalize all tokens or just the
|
||||
cls_token in the CA. Defaults to False.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN', eps=1e-6)``.
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Defaults to ``dict(type='GELU')``.
|
||||
init_cfg (dict | list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = False,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
layer_scale_init_value=1.,
|
||||
tokens_norm=False,
|
||||
norm_cfg=dict(type='LN', eps=1e-6),
|
||||
act_cfg=dict(type='GELU'),
|
||||
init_cfg=None):
|
||||
|
||||
super(ClassAttentionBlock, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
self.norm1 = build_norm_layer(norm_cfg, dim)
|
||||
|
||||
self.attn = ClassAttntion(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
)
|
||||
|
||||
self.drop_path = DropPath(
|
||||
drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
self.norm2 = build_norm_layer(norm_cfg, dim)
|
||||
|
||||
self.ffn = FFN(
|
||||
embed_dims=dim,
|
||||
feedforward_channels=int(dim * mlp_ratio),
|
||||
act_cfg=act_cfg,
|
||||
ffn_drop=drop,
|
||||
)
|
||||
|
||||
if layer_scale_init_value > 0:
|
||||
self.gamma1 = nn.Parameter(layer_scale_init_value *
|
||||
torch.ones(dim))
|
||||
self.gamma2 = nn.Parameter(layer_scale_init_value *
|
||||
torch.ones(dim))
|
||||
else:
|
||||
self.gamma1, self.gamma2 = 1.0, 1.0
|
||||
|
||||
# See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 # noqa: E501
|
||||
self.tokens_norm = tokens_norm
|
||||
|
||||
def forward(self, x):
|
||||
x_norm1 = self.norm1(x)
|
||||
x_attn = torch.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1)
|
||||
x = x + self.drop_path(self.gamma1 * x_attn)
|
||||
if self.tokens_norm:
|
||||
x = self.norm2(x)
|
||||
else:
|
||||
x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1)
|
||||
x_res = x
|
||||
cls_token = x[:, 0:1]
|
||||
cls_token = self.gamma2 * self.ffn(cls_token, identity=0)
|
||||
x = torch.cat([cls_token, x[:, 1:]], dim=1)
|
||||
x = x_res + self.drop_path(x)
|
||||
return x
|
||||
|
||||
|
||||
class LPI(BaseModule):
|
||||
"""Local Patch Interaction module.
|
||||
|
||||
A PyTorch implementation of Local Patch Interaction module
|
||||
as in XCiT introduced by `XCiT: Cross-Covariance Image Transformers
|
||||
<https://arxiv.org/abs/2106.096819>`_
|
||||
|
||||
Local Patch Interaction module that allows explicit communication between
|
||||
tokens in 3x3 windows to augment the implicit communication performed by
|
||||
the block diagonal scatter attention. Implemented using 2 layers of
|
||||
separable 3x3 convolutions with GeLU and BatchNorm2d
|
||||
|
||||
Args:
|
||||
in_features (int): The input channels.
|
||||
out_features (int, optional): The output channels. Defaults to None.
|
||||
kernel_size (int): The kernel_size in ConvModule. Defaults to 3.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='BN')``.
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Defaults to ``dict(type='GELU')``.
|
||||
init_cfg (dict | list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: Optional[int] = None,
|
||||
kernel_size: int = 3,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
init_cfg=None):
|
||||
super(LPI, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
out_features = out_features or in_features
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.conv1 = ConvModule(
|
||||
in_channels=in_features,
|
||||
out_channels=in_features,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
groups=in_features,
|
||||
bias=True,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
order=('conv', 'act', 'norm'))
|
||||
|
||||
self.conv2 = ConvModule(
|
||||
in_channels=in_features,
|
||||
out_channels=out_features,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
groups=out_features,
|
||||
norm_cfg=None,
|
||||
act_cfg=None)
|
||||
|
||||
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
x = x.permute(0, 2, 1).reshape(B, C, H, W)
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = x.reshape(B, C, N).permute(0, 2, 1)
|
||||
return x
|
||||
|
||||
|
||||
class XCA(BaseModule):
|
||||
r"""Cross-Covariance Attention module.
|
||||
|
||||
A PyTorch implementation of Cross-Covariance Attention module
|
||||
as in XCiT introduced by `XCiT: Cross-Covariance Image Transformers
|
||||
<https://arxiv.org/abs/2106.096819>`_
|
||||
|
||||
In Cross-Covariance Attention (XCA), the channels are updated using a
|
||||
weighted sum. The weights are obtained from the (softmax normalized)
|
||||
Cross-covariance matrix :math:`(Q^T \cdot K \in d_h \times d_h)`
|
||||
|
||||
Args:
|
||||
dim (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads. Defaults to 8.
|
||||
qkv_bias (bool): enable bias for qkv if True. Defaults to False.
|
||||
attn_drop (float): The drop out rate for attention output weights.
|
||||
Defaults to 0.
|
||||
proj_drop (float): The drop out rate for linear output weights.
|
||||
Defaults to 0.
|
||||
init_cfg (dict | list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
init_cfg=None):
|
||||
super(XCA, self).__init__(init_cfg=init_cfg)
|
||||
self.num_heads = num_heads
|
||||
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
# (qkv, B, num_heads, channels per head, N)
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
|
||||
C // self.num_heads).permute(2, 0, 3, 4, 1)
|
||||
q, k, v = qkv.unbind(0)
|
||||
|
||||
# Paper section 3.2 l2-Normalization and temperature scaling
|
||||
q = F.normalize(q, dim=-1)
|
||||
k = F.normalize(k, dim=-1)
|
||||
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
# (B, num_heads, C', N) -> (B, N, num_heads, C') -> (B, N C)
|
||||
x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class XCABlock(BaseModule):
|
||||
"""Transformer block using XCA.
|
||||
|
||||
Args:
|
||||
dim (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
mlp_ratio (float): The hidden dimension ratio for FFNs.
|
||||
Defaults to 4.
|
||||
qkv_bias (bool): enable bias for qkv if True. Defaults to False.
|
||||
drop (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Defaults to 0.
|
||||
attn_drop (float): The drop out rate for attention output weights.
|
||||
Defaults to 0.
|
||||
drop_path (float): Stochastic depth rate. Defaults to 0.
|
||||
layer_scale_init_value (float): The initial value for layer scale.
|
||||
Defaults to 1.
|
||||
bn_norm_cfg (dict): Config dict for batchnorm in LPI and
|
||||
ConvPatchEmbed. Defaults to ``dict(type='BN')``.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN', eps=1e-6)``.
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Defaults to ``dict(type='GELU')``.
|
||||
init_cfg (dict | list[dict], optional): Initialization config dict.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = False,
|
||||
drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: float = 0.,
|
||||
layer_scale_init_value: float = 1.,
|
||||
bn_norm_cfg=dict(type='BN'),
|
||||
norm_cfg=dict(type='LN', eps=1e-6),
|
||||
act_cfg=dict(type='GELU'),
|
||||
init_cfg=None):
|
||||
super(XCABlock, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
self.norm1 = build_norm_layer(norm_cfg, dim)
|
||||
self.attn = XCA(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
)
|
||||
self.drop_path = DropPath(
|
||||
drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
self.norm3 = build_norm_layer(norm_cfg, dim)
|
||||
self.local_mp = LPI(
|
||||
in_features=dim,
|
||||
norm_cfg=bn_norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
)
|
||||
|
||||
self.norm2 = build_norm_layer(norm_cfg, dim)
|
||||
self.ffn = FFN(
|
||||
embed_dims=dim,
|
||||
feedforward_channels=int(dim * mlp_ratio),
|
||||
act_cfg=act_cfg,
|
||||
ffn_drop=drop,
|
||||
)
|
||||
|
||||
self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
|
||||
self.gamma3 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
|
||||
self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
|
||||
|
||||
def forward(self, x, H: int, W: int):
|
||||
x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
|
||||
# NOTE official code has 3 then 2, so keeping it the same to be
|
||||
# consistent with loaded weights See
|
||||
# https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 # noqa: E501
|
||||
x = x + self.drop_path(
|
||||
self.gamma3 * self.local_mp(self.norm3(x), H, W))
|
||||
x = x + self.drop_path(
|
||||
self.gamma2 * self.ffn(self.norm2(x), identity=0))
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class XCiT(BaseBackbone):
|
||||
"""XCiT backbone.
|
||||
|
||||
A PyTorch implementation of XCiT backbone introduced by:
|
||||
`XCiT: Cross-Covariance Image Transformers
|
||||
<https://arxiv.org/abs/2106.096819>`_
|
||||
|
||||
Args:
|
||||
img_size (int, tuple): Input image size. Defaults to 224.
|
||||
patch_size (int): Patch size. Defaults to 16.
|
||||
in_channels (int): Number of input channels. Defaults to 3.
|
||||
embed_dims (int): Embedding dimension. Defaults to 768.
|
||||
depth (int): depth of vision transformer. Defaults to 12.
|
||||
cls_attn_layers (int): Depth of Class attention layers.
|
||||
Defaults to 2.
|
||||
num_heads (int): Number of attention heads. Defaults to 12.
|
||||
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
|
||||
Defaults to 4.
|
||||
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Defaults to 0.
|
||||
attn_drop_rate (float): The drop out rate for attention output weights.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
|
||||
use_pos_embed (bool): Whether to use positional encoding.
|
||||
Defaults to True.
|
||||
layer_scale_init_value (float): The initial value for layer scale.
|
||||
Defaults to 1.
|
||||
tokens_norm (bool): Whether to normalize all tokens or just the
|
||||
cls_token in the CA. Defaults to False.
|
||||
out_indices (Sequence[int]): Output from which layers.
|
||||
Defaults to (-1, ).
|
||||
frozen_stages (int): Layers to be frozen (all param fixed), and 0
|
||||
means to freeze the stem stage. Defaults to -1, which means
|
||||
not freeze any parameters.
|
||||
bn_norm_cfg (dict): Config dict for the batch norm layers in LPI and
|
||||
ConvPatchEmbed. Defaults to ``dict(type='BN')``.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN', eps=1e-6)``.
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Defaults to ``dict(type='GELU')``.
|
||||
init_cfg (dict | list[dict], optional): Initialization config dict.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size: Union[int, tuple] = 224,
|
||||
patch_size: int = 16,
|
||||
in_channels: int = 3,
|
||||
embed_dims: int = 768,
|
||||
depth: int = 12,
|
||||
cls_attn_layers: int = 2,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
use_pos_embed: bool = True,
|
||||
layer_scale_init_value: float = 1.,
|
||||
tokens_norm: bool = False,
|
||||
out_type: str = 'cls_token',
|
||||
out_indices: Sequence[int] = (-1, ),
|
||||
final_norm: bool = True,
|
||||
frozen_stages: int = -1,
|
||||
bn_norm_cfg=dict(type='BN'),
|
||||
norm_cfg=dict(type='LN', eps=1e-6),
|
||||
act_cfg=dict(type='GELU'),
|
||||
init_cfg=dict(type='TruncNormal', layer='Linear')):
|
||||
super(XCiT, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
img_size = to_2tuple(img_size)
|
||||
if (img_size[0] % patch_size != 0) or (img_size[1] % patch_size != 0):
|
||||
raise ValueError(f'`patch_size` ({patch_size}) should divide '
|
||||
f'the image shape ({img_size}) evenly.')
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
|
||||
assert out_type in ('raw', 'featmap', 'avg_featmap', 'cls_token')
|
||||
self.out_type = out_type
|
||||
|
||||
self.patch_embed = ConvPatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims,
|
||||
norm_cfg=bn_norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
self.use_pos_embed = use_pos_embed
|
||||
if use_pos_embed:
|
||||
self.pos_embed = PositionalEncodingFourier(dim=embed_dims)
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
self.xca_layers = nn.ModuleList()
|
||||
self.ca_layers = nn.ModuleList()
|
||||
self.num_layers = depth + cls_attn_layers
|
||||
|
||||
for _ in range(depth):
|
||||
self.xca_layers.append(
|
||||
XCABlock(
|
||||
dim=embed_dims,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=drop_path_rate,
|
||||
bn_norm_cfg=bn_norm_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
))
|
||||
|
||||
for _ in range(cls_attn_layers):
|
||||
self.ca_layers.append(
|
||||
ClassAttentionBlock(
|
||||
dim=embed_dims,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
tokens_norm=tokens_norm,
|
||||
))
|
||||
|
||||
if final_norm:
|
||||
self.norm = build_norm_layer(norm_cfg, embed_dims)
|
||||
|
||||
# Transform out_indices
|
||||
if isinstance(out_indices, int):
|
||||
out_indices = [out_indices]
|
||||
assert isinstance(out_indices, Sequence), \
|
||||
f'"out_indices" must by a sequence or int, ' \
|
||||
f'get {type(out_indices)} instead.'
|
||||
out_indices = list(out_indices)
|
||||
for i, index in enumerate(out_indices):
|
||||
if index < 0:
|
||||
out_indices[i] = self.num_layers + index
|
||||
assert 0 <= out_indices[i] <= self.num_layers, \
|
||||
f'Invalid out_indices {index}.'
|
||||
self.out_indices = out_indices
|
||||
|
||||
if frozen_stages > self.num_layers + 1:
|
||||
raise ValueError('frozen_stages must be less than '
|
||||
f'{self.num_layers} but get {frozen_stages}')
|
||||
self.frozen_stages = frozen_stages
|
||||
|
||||
def init_weights(self):
|
||||
super().init_weights()
|
||||
|
||||
if self.init_cfg is not None and self.init_cfg['type'] == 'Pretrained':
|
||||
return
|
||||
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
|
||||
def _freeze_stages(self):
|
||||
if self.frozen_stages < 0:
|
||||
return
|
||||
|
||||
# freeze position embedding
|
||||
if self.use_pos_embed:
|
||||
self.pos_embed.eval()
|
||||
for param in self.pos_embed.parameters():
|
||||
param.requires_grad = False
|
||||
# freeze patch embedding
|
||||
self.patch_embed.eval()
|
||||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
# set dropout to eval model
|
||||
self.pos_drop.eval()
|
||||
# freeze cls_token, only use in self.Clslayers
|
||||
if self.frozen_stages > len(self.xca_layers):
|
||||
self.cls_token.requires_grad = False
|
||||
# freeze layers
|
||||
for i in range(1, self.frozen_stages):
|
||||
if i <= len(self.xca_layers):
|
||||
m = self.xca_layers[i - 1]
|
||||
else:
|
||||
m = self.ca_layers[i - len(self.xca_layers) - 1]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# freeze the last layer norm if all_stages are frozen
|
||||
if self.frozen_stages == len(self.xca_layers) + len(self.ca_layers):
|
||||
self.norm.eval()
|
||||
for param in self.norm.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
outs = []
|
||||
B = x.shape[0]
|
||||
# x is (B, N, C). (Hp, Hw) is the patch resolution
|
||||
x, (Hp, Wp) = self.patch_embed(x)
|
||||
|
||||
if self.use_pos_embed:
|
||||
# (B, C, Hp, Wp) -> (B, C, N) -> (B, N, C)
|
||||
pos_encoding = self.pos_embed(B, Hp, Wp)
|
||||
x = x + pos_encoding.reshape(B, -1, x.size(1)).permute(0, 2, 1)
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for i, layer in enumerate(self.xca_layers):
|
||||
x = layer(x, Hp, Wp)
|
||||
if i in self.out_indices:
|
||||
outs.append(self._format_output(x, (Hp, Wp), False))
|
||||
|
||||
x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
|
||||
|
||||
for i, layer in enumerate(self.ca_layers):
|
||||
x = layer(x)
|
||||
if i == len(self.ca_layers) - 1:
|
||||
x = self.norm(x)
|
||||
if i + len(self.xca_layers) in self.out_indices:
|
||||
outs.append(self._format_output(x, (Hp, Wp), True))
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
def _format_output(self, x, hw, with_cls_token: bool):
|
||||
if self.out_type == 'raw':
|
||||
return x
|
||||
if self.out_type == 'cls_token':
|
||||
if not with_cls_token:
|
||||
raise ValueError(
|
||||
'Cannot output cls_token since there is no cls_token.')
|
||||
return x[:, 0]
|
||||
|
||||
patch_token = x[:, 1:] if with_cls_token else x
|
||||
if self.out_type == 'featmap':
|
||||
B = x.size(0)
|
||||
# (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
|
||||
return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
|
||||
if self.out_type == 'avg_featmap':
|
||||
return patch_token.mean(dim=1)
|
||||
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
self._freeze_stages()
|
|
@ -52,3 +52,4 @@ Import:
|
|||
- configs/levit/metafile.yml
|
||||
- configs/vig/metafile.yml
|
||||
- configs/arcface/metafile.yml
|
||||
- configs/xcit/metafile.yml
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# The basic forward/backward tests are in ../test_models.py
|
||||
import torch
|
||||
|
||||
from mmcls.apis import get_model
|
||||
|
||||
|
||||
def test_out_type():
|
||||
inputs = torch.rand(1, 3, 224, 224)
|
||||
|
||||
model = get_model(
|
||||
'xcit-nano-12-p16_3rdparty_in1k',
|
||||
backbone=dict(out_type='raw'),
|
||||
neck=None,
|
||||
head=None)
|
||||
outputs = model(inputs)[0]
|
||||
assert outputs.shape == (1, 197, 128)
|
||||
|
||||
model = get_model(
|
||||
'xcit-nano-12-p16_3rdparty_in1k',
|
||||
backbone=dict(out_type='featmap'),
|
||||
neck=None,
|
||||
head=None)
|
||||
outputs = model(inputs)[0]
|
||||
assert outputs.shape == (1, 128, 14, 14)
|
||||
|
||||
model = get_model(
|
||||
'xcit-nano-12-p16_3rdparty_in1k',
|
||||
backbone=dict(out_type='cls_token'),
|
||||
neck=None,
|
||||
head=None)
|
||||
outputs = model(inputs)[0]
|
||||
assert outputs.shape == (1, 128)
|
||||
|
||||
model = get_model(
|
||||
'xcit-nano-12-p16_3rdparty_in1k',
|
||||
backbone=dict(out_type='avg_featmap'),
|
||||
neck=None,
|
||||
head=None)
|
||||
outputs = model(inputs)[0]
|
||||
assert outputs.shape == (1, 128)
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import mmcls.models
|
||||
from mmcls.apis import ModelHub, get_model
|
||||
|
||||
|
||||
@dataclass
|
||||
class Cfg:
|
||||
name: str
|
||||
backbone: type
|
||||
num_classes: int = 1000
|
||||
build: bool = True
|
||||
forward: bool = True
|
||||
backward: bool = True
|
||||
input_shape: tuple = (1, 3, 224, 224)
|
||||
|
||||
|
||||
test_list = [
|
||||
Cfg(name='xcit-small-12-p16_3rdparty_in1k', backbone=mmcls.models.XCiT),
|
||||
Cfg(name='xcit-nano-12-p8_3rdparty-dist_in1k-384px',
|
||||
backbone=mmcls.models.XCiT,
|
||||
input_shape=(1, 3, 384, 384)),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cfg', test_list)
|
||||
def test_build(cfg: Cfg):
|
||||
if not cfg.build:
|
||||
return
|
||||
|
||||
model_name = cfg.name
|
||||
ModelHub._register_mmcls_models()
|
||||
assert ModelHub.has(model_name)
|
||||
|
||||
model = get_model(model_name)
|
||||
backbone_class = cfg.backbone
|
||||
assert isinstance(model.backbone, backbone_class)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cfg', test_list)
|
||||
def test_forward(cfg: Cfg):
|
||||
if not cfg.forward:
|
||||
return
|
||||
|
||||
model = get_model(cfg.name)
|
||||
inputs = torch.rand(*cfg.input_shape)
|
||||
outputs = model(inputs)
|
||||
assert outputs.shape == (1, cfg.num_classes)
|
||||
|
||||
feats = model.extract_feat(inputs)
|
||||
assert isinstance(feats, tuple)
|
||||
assert len(feats) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cfg', test_list)
|
||||
def test_backward(cfg: Cfg):
|
||||
if not cfg.backward:
|
||||
return
|
||||
|
||||
model = get_model(cfg.name)
|
||||
inputs = torch.rand(*cfg.input_shape)
|
||||
outputs = model(inputs)
|
||||
outputs.mean().backward()
|
||||
|
||||
for n, x in model.named_parameters():
|
||||
assert x.grad is not None, f'No gradient for {n}'
|
||||
num_grad = sum(
|
||||
[x.grad.numel() for x in model.parameters() if x.grad is not None])
|
||||
assert outputs.shape[-1] == cfg.num_classes
|
||||
num_params = sum([x.numel() for x in model.parameters()])
|
||||
assert num_params == num_grad, 'Some parameters are missing gradients'
|
||||
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
Loading…
Reference in New Issue