[Refactor] Refactor dev scripts

pull/913/head
mzr1996 2022-07-12 08:10:59 +00:00
parent c992e24617
commit 24bcf069f8
89 changed files with 958 additions and 1429 deletions

View File

@ -1,54 +1,31 @@
import logging
import re
import tempfile
from argparse import ArgumentParser
from pathlib import Path
from time import time
from typing import OrderedDict
import mmcv
import numpy as np
import torch
from mmcv import Config
from mmcv.parallel import collate, scatter
from mmengine import Config, MMLogger, Runner
from mmengine.dataset import Compose
from modelindex.load_model_index import load
from rich.console import Console
from rich.table import Table
from mmcls.apis import init_model
from mmcls.core.visualization.image import imshow_infos
from mmcls.datasets.imagenet import ImageNet
from mmcls.datasets.pipelines import Compose
from mmcls.utils import get_root_logger
from mmcls.core import ClsVisualizer
from mmcls.datasets import CIFAR10, CIFAR100, ImageNet
from mmcls.utils import register_all_modules
console = Console()
MMCLS_ROOT = Path(__file__).absolute().parents[2]
CIFAR10_CLASSES = [
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
'ship', 'truck'
]
CIFAR100_CLASSES = [
'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree',
'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy',
'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail',
'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train',
'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf',
'woman', 'worm'
]
classes_map = {
'ImageNet-1k': ImageNet.CLASSES,
'CIFAR-10': CIFAR10_CLASSES,
'CIFAR-100': CIFAR100_CLASSES
'CIFAR-10': CIFAR10.CLASSES,
'CIFAR-100': CIFAR100.CLASSES,
}
@ -75,34 +52,30 @@ def parse_args():
'--flops-str',
action='store_true',
help='Output FLOPs and params counts in a string form.')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
args = parser.parse_args()
return args
def inference(config_file, checkpoint, classes, args):
def inference(config_file, checkpoint, work_dir, args, exp_name):
cfg = Config.fromfile(config_file)
model = init_model(cfg, checkpoint, device=args.device)
model.CLASSES = classes
cfg.work_dir = work_dir
cfg.load_from = checkpoint
cfg.log_level = 'WARN'
cfg.experiment_name = exp_name
# build the data pipeline
if cfg.data.test.pipeline[0]['type'] != 'LoadImageFromFile':
cfg.data.test.pipeline.insert(0, dict(type='LoadImageFromFile'))
if cfg.data.test.type in ['CIFAR10', 'CIFAR100']:
test_dataset = cfg.test_dataloader.dataset
if test_dataset.pipeline[0]['type'] != 'LoadImageFromFile':
test_dataset.pipeline.insert(0, dict(type='LoadImageFromFile'))
if test_dataset.type in ['CIFAR10', 'CIFAR100']:
# The image shape of CIFAR is (32, 32, 3)
cfg.data.test.pipeline.insert(1, dict(type='Resize', scale=32))
test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
data = dict(img_info=dict(filename=args.img), img_prefix=None)
data = Compose(test_dataset.pipeline)({'img_path': args.img})
resolution = tuple(data['inputs'].shape[1:])
test_pipeline = Compose(cfg.data.test.pipeline)
data = test_pipeline(data)
resolution = tuple(data['img'].shape[1:])
data = collate([data], samples_per_gpu=1)
if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [args.device])[0]
runner: Runner = Runner.from_cfg(cfg)
model = runner.model
# forward the model
result = {'resolution': resolution}
@ -111,18 +84,12 @@ def inference(config_file, checkpoint, classes, args):
time_record = []
for _ in range(10):
start = time()
scores = model(return_loss=False, **data)
model.val_step([data])
time_record.append((time() - start) * 1000)
result['time_mean'] = np.mean(time_record[1:-1])
result['time_std'] = np.std(time_record[1:-1])
else:
scores = model(return_loss=False, **data)
pred_score = np.max(scores, axis=1)[0]
pred_label = np.argmax(scores, axis=1)[0]
result['pred_label'] = pred_label
result['pred_score'] = float(pred_score)
result['pred_class'] = model.CLASSES[result['pred_label']]
model.val_step([data])
result['model'] = config_file.stem
@ -177,13 +144,17 @@ def show_summary(summary_data, args):
# Sample test whether the inference code is correct
def main(args):
register_all_modules()
model_index_file = MMCLS_ROOT / 'model-index.yml'
model_index = load(str(model_index_file))
model_index.build_models_with_collections()
models = OrderedDict({model.name: model for model in model_index.models})
logger = get_root_logger(
log_file='benchmark_test_image.log', log_level=logging.INFO)
logger = MMLogger(
'validation',
logger_name='validation',
log_file='benchmark_test_image.log',
log_level=logging.INFO)
if args.models:
patterns = [re.compile(pattern) for pattern in args.models]
@ -198,6 +169,7 @@ def main(args):
models = filter_models
summary_data = {}
tmpdir = tempfile.TemporaryDirectory()
for model_name, model_info in models.items():
if model_info.config is None:
@ -209,7 +181,6 @@ def main(args):
logger.info(f'Processing: {model_name}')
http_prefix = 'https://download.openmmlab.com/mmclassification/'
dataset = model_info.results[0].dataset
if args.checkpoint_root is not None:
root = args.checkpoint_root
if 's3://' in args.checkpoint_root:
@ -235,18 +206,25 @@ def main(args):
try:
# build the model from a config file and a checkpoint file
result = inference(MMCLS_ROOT / config, checkpoint,
classes_map[dataset], args)
result = inference(MMCLS_ROOT / config, checkpoint, tmpdir.name,
args, model_name)
result['valid'] = 'PASS'
except Exception as e:
logger.error(f'"{config}" : {repr(e)}')
except Exception:
import traceback
logger.error(f'"{config}" :\n{traceback.format_exc()}')
result = {'valid': 'FAIL'}
summary_data[model_name] = result
# show the results
if args.show:
imshow_infos(args.img, result, wait_time=args.wait_time)
vis = ClsVisualizer.get_instance('valid')
vis.set_image(mmcv.imread(args.img))
vis.draw_texts(
texts='\n'.join([f'{k}: {v}' for k, v in result.items()]),
positions=np.array([(5, 5)]))
vis.show(wait_time=args.wait_time)
tmpdir.cleanup()
show_summary(summary_data, args)

View File

@ -15,8 +15,8 @@ from rich.table import Table
console = Console()
MMCLS_ROOT = Path(__file__).absolute().parents[2]
METRICS_MAP = {
'Top 1 Accuracy': 'accuracy_top-1',
'Top 5 Accuracy': 'accuracy_top-5'
'Top 1 Accuracy': 'accuracy/top1',
'Top 5 Accuracy': 'accuracy/top5'
}
@ -96,6 +96,7 @@ def create_test_job_batch(commands, model_info, args, port, script_name):
job_name = f'{args.job_name}_{fname}'
work_dir = Path(args.work_dir) / fname
work_dir.mkdir(parents=True, exist_ok=True)
result_file = work_dir / 'result.pkl'
if args.mail is not None and 'NONE' not in args.mail_type:
mail_cfg = (f'#SBATCH --mail {args.mail}\n'
@ -121,8 +122,8 @@ def create_test_job_batch(commands, model_info, args, port, script_name):
f'#SBATCH --ntasks=8\n'
f'#SBATCH --cpus-per-task=5\n\n'
f'{runner} -u {script_name} {config} {checkpoint} '
f'--out={work_dir / "result.pkl"} --metrics accuracy '
f'--out-items=none '
f'--work-dir={work_dir} '
f'--out={result_file} '
f'--cfg-option dist_params.port={port} '
f'--launcher={launcher}\n')
@ -214,14 +215,14 @@ def save_summary(summary_data, models_map, work_dir):
row = [model_name]
if 'Top 1 Accuracy' in summary:
metric = summary['Top 1 Accuracy']
row.append(f"{metric['expect']:.2f}")
row.append(f"{metric['result']:.2f}")
row.append(str(round(metric['expect'], 2)))
row.append(str(round(metric['result'], 2)))
else:
row.extend([''] * 2)
if 'Top 5 Accuracy' in summary:
metric = summary['Top 5 Accuracy']
row.append(f"{metric['expect']:.2f}")
row.append(f"{metric['result']:.2f}")
row.append(str(round(metric['expect'], 2)))
row.append(str(round(metric['result'], 2)))
else:
row.extend([''] * 2)
@ -253,8 +254,8 @@ def show_summary(summary_data):
for metric_key in METRICS_MAP:
if metric_key in summary:
metric = summary[metric_key]
expect = metric['expect']
result = metric['result']
expect = round(metric['expect'], 2)
result = round(metric['result'], 2)
color = set_color(result, expect)
row.append(f'{expect:.2f}')
row.append(f'[{color}]{result:.2f}[/{color}]')
@ -310,9 +311,7 @@ def summary(args):
# extract metrics
summary = {'date': date.strftime('%Y-%m-%d')}
for key_yml, key_res in METRICS_MAP.items():
if key_yml in expect_metrics:
assert key_res in results, \
f'{model_name}: No metric "{key_res}"'
if key_yml in expect_metrics and key_res in results:
expect_result = float(expect_metrics[key_yml])
result = float(results[key_res])
summary[key_yml] = dict(expect=expect_result, result=result)

View File

@ -14,8 +14,8 @@ from rich.table import Table
console = Console()
METRICS_MAP = {
'Top 1 Accuracy': 'accuracy_top-1',
'Top 5 Accuracy': 'accuracy_top-5'
'Top 1 Accuracy': 'accuracy/top1',
'Top 5 Accuracy': 'accuracy/top5'
}
@ -280,23 +280,24 @@ def summary(args):
summary_data = {}
for model_name, model_info in models.items():
# Skip if not found any log file.
summary_data[model_name] = {}
if model_name not in dir_map:
summary_data[model_name] = {}
continue
# Skip if not found any vis_data folder.
sub_dir = dir_map[model_name]
log_files = list(sub_dir.glob('*.log.json'))
if len(log_files) == 0:
vis_folders = [d for d in sub_dir.iterdir() if d.is_dir()]
if len(vis_folders) == 0:
continue
log_file = sorted(vis_folders)[-1] / 'vis_data' / 'scalars.json'
if not log_file.exists():
continue
log_file = sorted(log_files)[-1]
# parse train log
with open(log_file) as f:
json_logs = [json.loads(s) for s in f.readlines()]
val_logs = [
log for log in json_logs
if 'mode' in log and log['mode'] == 'val'
]
val_logs = [log for log in json_logs if 'loss' not in log]
if len(val_logs) == 0:
continue
@ -311,9 +312,10 @@ def summary(args):
f'{model_name}: No metric "{key_res}"'
expect_result = float(expect_metrics[key_yml])
last = float(val_logs[-1][key_res])
best_log = sorted(val_logs, key=lambda x: x[key_res])[-1]
best_log, best_epoch = sorted(
zip(val_logs, range(len(val_logs))),
key=lambda x: x[0][key_res])[-1]
best = float(best_log[key_res])
best_epoch = int(best_log['epoch'])
summary[key_yml] = dict(
expect=expect_result,

View File

@ -1,22 +1,11 @@
Models:
- Name: resnet34
- Name: resnet50
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 73.85
Top 5 Accuracy: 91.53
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_batch256_imagenet_20200708-32ffb4f7.pth
Config: configs/resnet/resnet34_8xb32_in1k.py
Gpus: 8
- Name: vgg11bn
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 70.75
Top 5 Accuracy: 90.12
Weights: https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth
Config: configs/vgg/vgg11bn_8xb32_in1k.py
Top 1 Accuracy: 76.55
Top 5 Accuracy: 93.06
Config: configs/resnet/resnet50_8xb32_in1k.py
Gpus: 8
- Name: seresnet50
@ -25,49 +14,26 @@ Models:
Metrics:
Top 1 Accuracy: 77.74
Top 5 Accuracy: 93.84
Weights: https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth
Config: configs/seresnet/seresnet50_8xb32_in1k.py
Gpus: 8
- Name: resnext50
- Name: vit-base
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 77.92
Top 5 Accuracy: 93.74
Weights: https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_batch256_imagenet_20200708-c07adbb7.pth
Config: configs/resnext/resnext50-32x4d_8xb32_in1k.py
Gpus: 8
Top 1 Accuracy: 82.37
Top 5 Accuracy: 96.15
Config: configs/vision_transformer/vit-base-p16_pt-32xb128-mae_in1k-224.py
Gpus: 32
- Name: mobilenet
- Name: mobilenetv2
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 71.86
Top 5 Accuracy: 90.42
Weights: https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth
Config: configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py
Gpus: 8
Months:
- 1
- 4
- 7
- 10
- Name: shufflenet_v1
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 68.13
Top 5 Accuracy: 87.81
Weights: https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth
Config: configs/shufflenet_v1/shufflenet-v1-1x_16xb64_in1k.py
Gpus: 16
Months:
- 2
- 5
- 8
- 11
- Name: swin_tiny
Results:
@ -78,6 +44,43 @@ Models:
Weights: https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth
Config: configs/swin_transformer/swin-tiny_16xb64_in1k.py
Gpus: 16
- Name: vgg16
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 71.62
Top 5 Accuracy: 90.49
Config: configs/vgg/vgg16_8xb32_in1k.py
Gpus: 8
Months:
- 1
- 4
- 7
- 10
- Name: shufflenet_v2
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 69.55
Top 5 Accuracy: 88.92
Config: configs/shufflenet_v2/shufflenet-v2-1x_16xb64_in1k.py
Gpus: 16
Months:
- 2
- 5
- 8
- 11
- Name: resnet-rsb
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 80.12
Top 5 Accuracy: 94.78
Config: configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py
Gpus: 8
Months:
- 3
- 6

View File

@ -0,0 +1,10 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(type='InceptionV3', num_classes=1000, aux_logits=False),
neck=None,
head=dict(
type='ClsHead',
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5)),
)

View File

@ -6,6 +6,6 @@ param_scheduler = dict(
type='MultiStepLR', by_epoch=True, milestones=[100, 150], gamma=0.1)
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=200)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=200, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -34,6 +34,6 @@ param_scheduler = [
]
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=300)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -38,6 +38,6 @@ param_scheduler = [
]
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=300)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -9,6 +9,6 @@ param_scheduler = [
]
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=300)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -17,6 +17,6 @@ param_scheduler = [
]
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=300)
val_cfg = dict(interval=1) # validate every other epoch
train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -12,6 +12,6 @@ param_scheduler = [
]
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=100)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -33,6 +33,6 @@ param_scheduler = [
]
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=300)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -26,6 +26,6 @@ param_scheduler = [
]
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=100)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -23,6 +23,6 @@ param_scheduler = [
]
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=100)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -7,6 +7,6 @@ param_scheduler = dict(
type='MultiStepLR', by_epoch=True, milestones=[30, 60, 90], gamma=0.1)
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=100)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -7,6 +7,6 @@ param_scheduler = dict(
type='MultiStepLR', by_epoch=True, milestones=[40, 80, 120], gamma=0.1)
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=140)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=140, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -25,6 +25,6 @@ param_scheduler = [
]
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=200)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=200, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -7,6 +7,6 @@ param_scheduler = dict(
type='CosineAnnealingLR', T_max=100, by_epoch=True, begin=0, end=100)
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=100)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -6,6 +6,6 @@ optim_wrapper = dict(
param_scheduler = dict(type='StepLR', by_epoch=True, step_size=1, gamma=0.98)
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=300)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -30,6 +30,6 @@ param_scheduler = [
]
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=300)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -32,7 +32,7 @@ test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=288,
scale=256,
edge='short',
backend='pillow',
interpolation='bicubic'),

View File

@ -0,0 +1,34 @@
# Inception V3
> [Rethinking the Inception Architecture for Computer Vision](http://arxiv.org/abs/1512.00567)
<!-- [ALGORITHM] -->
## Abstract
Convolutional networks are at the core of most state-of-the-art computer vision solutions for a wide variety of tasks. Since 2014 very deep convolutional networks started to become mainstream, yielding substantial gains in various benchmarks. Although increased model size and computational cost tend to translate to immediate quality gains for most tasks (as long as enough labeled data is provided for training), computational efficiency and low parameter count are still enabling factors for various use cases such as mobile vision and big-data scenarios. Here we explore ways to scale up networks in ways that aim at utilizing the added computation as efficiently as possible by suitably factorized convolutions and aggressive regularization. We benchmark our methods on the ILSVRC 2012 classification challenge validation set demonstrate substantial gains over the state of the art: 21.2% top-1 and 5.6% top-5 error for single frame evaluation using a network with a computational cost of 5 billion multiply-adds per inference and with using less than 25 million parameters. With an ensemble of 4 models and multi-crop evaluation, we report 3.5% top-5 error on the validation set (3.6% error on the test set) and 17.3% top-1 error on the validation set.
<div align=center>
<img src="https://user-images.githubusercontent.com/26739999/177241797-c103eff4-79bb-414d-aef6-eac323b65a50.png" width="40%"/>
</div>
## Results and models
### ImageNet-1k
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
| Inception V3\* | 23.83 | 5.75 | 77.57 | 93.58 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/inception_v3/inception-v3_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/inception-v3/inception-v3_3rdparty_8xb32_in1k_20220615-dcd4d910.pth) |
*Models with \* are converted from the [official repo](https://github.com/pytorch/vision/blob/main/torchvision/models/inception.py#L28). The config files of these models are only for inference. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
## Citation
```bibtex
@inproceedings{szegedy2016rethinking,
title={Rethinking the inception architecture for computer vision},
author={Szegedy, Christian and Vanhoucke, Vincent and Ioffe, Sergey and Shlens, Jon and Wojna, Zbigniew},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={2818--2826},
year={2016}
}
```

View File

@ -0,0 +1,24 @@
_base_ = [
'../_base_/models/inception_v3.py',
'../_base_/datasets/imagenet_bs32.py',
'../_base_/schedules/imagenet_bs256_coslr.py',
'../_base_/default_runtime.py',
]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=299),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackClsInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=342, edge='short'),
dict(type='CenterCrop', crop_size=299),
dict(type='PackClsInputs'),
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline))

View File

@ -0,0 +1,37 @@
Collections:
- Name: Inception V3
Metadata:
Training Data: ImageNet-1k
Training Techniques:
- SGD with Momentum
- Weight Decay
Training Resources: 8x V100 GPUs
Epochs: 100
Batch Size: 256
Architecture:
- Inception
Paper:
URL: http://arxiv.org/abs/1512.00567
Title: "Rethinking the Inception Architecture for Computer Vision"
README: configs/inception_v3/README.md
Code:
URL: TODO
Version: TODO
Models:
- Name: inception-v3_3rdparty_8xb32_in1k
Metadata:
FLOPs: 5745177632
Parameters: 23834568
In Collection: Inception V3
Results:
- Task: Image Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 77.57
Top 5 Accuracy: 93.58
Weights: https://download.openmmlab.com/mmclassification/v0/inception-v3/inception-v3_3rdparty_8xb32_in1k_20220615-dcd4d910.pth
Config: configs/inception_v3/inception-v3_8xb32_in1k.py
Converted From:
Weights: https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth
Code: https://github.com/pytorch/vision/blob/main/torchvision/models/inception.py#L28

View File

@ -48,8 +48,8 @@ param_scheduler = dict(
gamma=0.1, # decay to 0.1 times.
)
train_cfg = dict(by_epoch=True, max_epochs=5) # train 5 epochs
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=5, val_interval=1) # train 5 epochs
val_cfg = dict()
test_cfg = dict()
# runtime settings

View File

@ -18,6 +18,6 @@ optim_wrapper = dict(
param_scheduler = dict(type='StepLR', by_epoch=True, step_size=2, gamma=0.973)
train_cfg = dict(by_epoch=True, max_epochs=600)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -18,6 +18,6 @@ optim_wrapper = dict(
param_scheduler = dict(type='StepLR', by_epoch=True, step_size=2, gamma=0.973)
train_cfg = dict(by_epoch=True, max_epochs=600)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -1,6 +1,39 @@
_base_ = [
'../_base_/models/repvgg-B3_lbs-mixup_in1k.py',
'../_base_/datasets/imagenet_bs64_pil_resize_autoaug.py',
'../_base_/datasets/imagenet_bs64_pil_resize.py',
'../_base_/schedules/imagenet_bs256_200e_coslr_warmup.py',
'../_base_/default_runtime.py'
]
preprocess_cfg = dict(
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = preprocess_cfg['mean'][::-1]
bgr_std = preprocess_cfg['std'][::-1]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=224, backend='pillow'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='AutoAugment',
policies='imagenet',
hparams=dict(pad_val=[round(x) for x in bgr_mean])),
dict(type='PackClsInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=256, edge='short', backend='pillow'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs'),
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline))

View File

@ -36,8 +36,8 @@ param_scheduler = [
dict(type='ConstantLR', factor=0.1, by_epoch=True, begin=300, end=310),
]
train_cfg = dict(by_epoch=True, max_epochs=310)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=310, val_interval=1)
val_cfg = dict()
test_cfg = dict()
# runtime settings

View File

@ -36,8 +36,8 @@ param_scheduler = [
dict(type='ConstantLR', factor=0.1, by_epoch=True, begin=300, end=310),
]
train_cfg = dict(by_epoch=True, max_epochs=310)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=310, val_interval=1)
val_cfg = dict()
test_cfg = dict()
# runtime settings

View File

@ -36,8 +36,8 @@ param_scheduler = [
dict(type='ConstantLR', factor=0.1, by_epoch=True, begin=300, end=310),
]
train_cfg = dict(by_epoch=True, max_epochs=310)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=310, val_interval=1)
val_cfg = dict()
test_cfg = dict()
# runtime settings

View File

@ -46,6 +46,6 @@ param_scheduler = [
dict(type='CosineAnnealingLR', T_max=295, by_epoch=True, begin=5, end=300)
]
train_cfg = dict(by_epoch=True, max_epochs=300)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -33,6 +33,6 @@ optim_wrapper = dict(
param_scheduler = dict(type='StepLR', by_epoch=True, step_size=20, gamma=0.1)
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=40)
val_cfg = dict(interval=1) # validate every epoch
train_cfg = dict(by_epoch=True, max_epochs=40, val_interval=1)
val_cfg = dict()
test_cfg = dict()

View File

@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .data_structures import * # noqa: F401, F403
from .evaluation import * # noqa: F401, F403
from .hook import * # noqa: F401, F403
from .optimizers import * # noqa: F401, F403
from .utils import * # noqa: F401, F403

View File

@ -1,12 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .eval_hooks import DistEvalHook, EvalHook
from .eval_metrics import (calculate_confusion_matrix, f1_score, precision,
precision_recall_f1, recall, support)
from .mean_ap import average_precision, mAP
from .multilabel_eval_metrics import average_performance
__all__ = [
'precision', 'recall', 'f1_score', 'support', 'average_precision', 'mAP',
'average_performance', 'calculate_confusion_matrix', 'precision_recall_f1',
'EvalHook', 'DistEvalHook'
]

View File

@ -1,259 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
import numpy as np
import torch
from torch.nn.functional import one_hot
def calculate_confusion_matrix(pred, target):
"""Calculate confusion matrix according to the prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction with shape (N, C).
target (torch.Tensor | np.array): The target of each prediction with
shape (N, 1) or (N,).
Returns:
torch.Tensor: Confusion matrix
The shape is (C, C), where C is the number of classes.
"""
if isinstance(pred, np.ndarray):
pred = torch.from_numpy(pred)
if isinstance(target, np.ndarray):
target = torch.from_numpy(target)
assert (
isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor)), \
(f'pred and target should be torch.Tensor or np.ndarray, '
f'but got {type(pred)} and {type(target)}.')
# Modified from PyTorch-Ignite
num_classes = pred.size(1)
pred_label = torch.argmax(pred, dim=1).flatten()
target_label = target.flatten()
assert len(pred_label) == len(target_label)
with torch.no_grad():
indices = num_classes * target_label + pred_label
matrix = torch.bincount(indices, minlength=num_classes**2)
matrix = matrix.reshape(num_classes, num_classes)
return matrix
def precision_recall_f1(pred, target, average_mode='macro', thrs=0.):
"""Calculate precision, recall and f1 score according to the prediction and
target.
Args:
pred (torch.Tensor | np.array): The model prediction with shape (N, C).
target (torch.Tensor | np.array): The target of each prediction with
shape (N, 1) or (N,).
average_mode (str): The type of averaging performed on the result.
Options are 'macro' and 'none'. If 'none', the scores for each
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
thrs (Number | tuple[Number], optional): Predictions with scores under
the thresholds are considered negative. Defaults to 0.
Returns:
tuple: tuple containing precision, recall, f1 score.
The type of precision, recall, f1 score is one of the following:
+----------------------------+--------------------+-------------------+
| Args | ``thrs`` is number | ``thrs`` is tuple |
+============================+====================+===================+
| ``average_mode`` = "macro" | float | list[float] |
+----------------------------+--------------------+-------------------+
| ``average_mode`` = "none" | np.array | list[np.array] |
+----------------------------+--------------------+-------------------+
"""
allowed_average_mode = ['macro', 'none']
if average_mode not in allowed_average_mode:
raise ValueError(f'Unsupport type of averaging {average_mode}.')
if isinstance(pred, np.ndarray):
pred = torch.from_numpy(pred)
assert isinstance(pred, torch.Tensor), \
(f'pred should be torch.Tensor or np.ndarray, but got {type(pred)}.')
if isinstance(target, np.ndarray):
target = torch.from_numpy(target).long()
assert isinstance(target, torch.Tensor), \
f'target should be torch.Tensor or np.ndarray, ' \
f'but got {type(target)}.'
if isinstance(thrs, Number):
thrs = (thrs, )
return_single = True
elif isinstance(thrs, tuple):
return_single = False
else:
raise TypeError(
f'thrs should be a number or tuple, but got {type(thrs)}.')
num_classes = pred.size(1)
pred_score, pred_label = torch.topk(pred, k=1)
pred_score = pred_score.flatten()
pred_label = pred_label.flatten()
gt_positive = one_hot(target.flatten(), num_classes)
precisions = []
recalls = []
f1_scores = []
for thr in thrs:
# Only prediction values larger than thr are counted as positive
pred_positive = one_hot(pred_label, num_classes)
if thr is not None:
pred_positive[pred_score <= thr] = 0
class_correct = (pred_positive & gt_positive).sum(0)
precision = class_correct / np.maximum(pred_positive.sum(0), 1.) * 100
recall = class_correct / np.maximum(gt_positive.sum(0), 1.) * 100
f1_score = 2 * precision * recall / np.maximum(
precision + recall,
torch.finfo(torch.float32).eps)
if average_mode == 'macro':
precision = float(precision.mean())
recall = float(recall.mean())
f1_score = float(f1_score.mean())
elif average_mode == 'none':
precision = precision.detach().cpu().numpy()
recall = recall.detach().cpu().numpy()
f1_score = f1_score.detach().cpu().numpy()
else:
raise ValueError(f'Unsupport type of averaging {average_mode}.')
precisions.append(precision)
recalls.append(recall)
f1_scores.append(f1_score)
if return_single:
return precisions[0], recalls[0], f1_scores[0]
else:
return precisions, recalls, f1_scores
def precision(pred, target, average_mode='macro', thrs=0.):
"""Calculate precision according to the prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction with shape (N, C).
target (torch.Tensor | np.array): The target of each prediction with
shape (N, 1) or (N,).
average_mode (str): The type of averaging performed on the result.
Options are 'macro' and 'none'. If 'none', the scores for each
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
thrs (Number | tuple[Number], optional): Predictions with scores under
the thresholds are considered negative. Defaults to 0.
Returns:
float | np.array | list[float | np.array]: Precision.
+----------------------------+--------------------+-------------------+
| Args | ``thrs`` is number | ``thrs`` is tuple |
+============================+====================+===================+
| ``average_mode`` = "macro" | float | list[float] |
+----------------------------+--------------------+-------------------+
| ``average_mode`` = "none" | np.array | list[np.array] |
+----------------------------+--------------------+-------------------+
"""
precisions, _, _ = precision_recall_f1(pred, target, average_mode, thrs)
return precisions
def recall(pred, target, average_mode='macro', thrs=0.):
"""Calculate recall according to the prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction with shape (N, C).
target (torch.Tensor | np.array): The target of each prediction with
shape (N, 1) or (N,).
average_mode (str): The type of averaging performed on the result.
Options are 'macro' and 'none'. If 'none', the scores for each
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
thrs (Number | tuple[Number], optional): Predictions with scores under
the thresholds are considered negative. Defaults to 0.
Returns:
float | np.array | list[float | np.array]: Recall.
+----------------------------+--------------------+-------------------+
| Args | ``thrs`` is number | ``thrs`` is tuple |
+============================+====================+===================+
| ``average_mode`` = "macro" | float | list[float] |
+----------------------------+--------------------+-------------------+
| ``average_mode`` = "none" | np.array | list[np.array] |
+----------------------------+--------------------+-------------------+
"""
_, recalls, _ = precision_recall_f1(pred, target, average_mode, thrs)
return recalls
def f1_score(pred, target, average_mode='macro', thrs=0.):
"""Calculate F1 score according to the prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction with shape (N, C).
target (torch.Tensor | np.array): The target of each prediction with
shape (N, 1) or (N,).
average_mode (str): The type of averaging performed on the result.
Options are 'macro' and 'none'. If 'none', the scores for each
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
thrs (Number | tuple[Number], optional): Predictions with scores under
the thresholds are considered negative. Defaults to 0.
Returns:
float | np.array | list[float | np.array]: F1 score.
+----------------------------+--------------------+-------------------+
| Args | ``thrs`` is number | ``thrs`` is tuple |
+============================+====================+===================+
| ``average_mode`` = "macro" | float | list[float] |
+----------------------------+--------------------+-------------------+
| ``average_mode`` = "none" | np.array | list[np.array] |
+----------------------------+--------------------+-------------------+
"""
_, _, f1_scores = precision_recall_f1(pred, target, average_mode, thrs)
return f1_scores
def support(pred, target, average_mode='macro'):
"""Calculate the total number of occurrences of each label according to the
prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction with shape (N, C).
target (torch.Tensor | np.array): The target of each prediction with
shape (N, 1) or (N,).
average_mode (str): The type of averaging performed on the result.
Options are 'macro' and 'none'. If 'none', the scores for each
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted sum.
Defaults to 'macro'.
Returns:
float | np.array: Support.
- If the ``average_mode`` is set to macro, the function returns
a single float.
- If the ``average_mode`` is set to none, the function returns
a np.array with shape C.
"""
confusion_matrix = calculate_confusion_matrix(pred, target)
with torch.no_grad():
res = confusion_matrix.sum(1)
if average_mode == 'macro':
res = float(res.sum().numpy())
elif average_mode == 'none':
res = res.numpy()
else:
raise ValueError(f'Unsupport type of averaging {average_mode}.')
return res

View File

@ -1,74 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
def average_precision(pred, target):
r"""Calculate the average precision for a single class.
AP summarizes a precision-recall curve as the weighted mean of maximum
precisions obtained for any r'>r, where r is the recall:
.. math::
\text{AP} = \sum_n (R_n - R_{n-1}) P_n
Note that no approximation is involved since the curve is piecewise
constant.
Args:
pred (np.ndarray): The model prediction with shape (N, ).
target (np.ndarray): The target of each prediction with shape (N, ).
Returns:
float: a single float as average precision value.
"""
eps = np.finfo(np.float32).eps
# sort examples
sort_inds = np.argsort(-pred)
sort_target = target[sort_inds]
# count true positive examples
pos_inds = sort_target == 1
tp = np.cumsum(pos_inds)
total_pos = tp[-1]
# count not difficult examples
pn_inds = sort_target != -1
pn = np.cumsum(pn_inds)
tp[np.logical_not(pos_inds)] = 0
precision = tp / np.maximum(pn, eps)
ap = np.sum(precision) / np.maximum(total_pos, eps)
return ap
def mAP(pred, target):
"""Calculate the mean average precision with respect of classes.
Args:
pred (torch.Tensor | np.ndarray): The model prediction with shape
(N, C), where C is the number of classes.
target (torch.Tensor | np.ndarray): The target of each prediction with
shape (N, C), where C is the number of classes. 1 stands for
positive examples, 0 stands for negative examples and -1 stands for
difficult examples.
Returns:
float: A single float as mAP value.
"""
if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor):
pred = pred.detach().cpu().numpy()
target = target.detach().cpu().numpy()
elif not (isinstance(pred, np.ndarray) and isinstance(target, np.ndarray)):
raise TypeError('pred and target should both be torch.Tensor or'
'np.ndarray')
assert pred.shape == \
target.shape, 'pred and target should be in the same shape.'
num_classes = pred.shape[1]
ap = np.zeros(num_classes)
for k in range(num_classes):
ap[k] = average_precision(pred[:, k], target[:, k])
mean_ap = ap.mean() * 100.0
return mean_ap

View File

@ -1,72 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import numpy as np
import torch
def average_performance(pred, target, thr=None, k=None):
"""Calculate CP, CR, CF1, OP, OR, OF1, where C stands for per-class
average, O stands for overall average, P stands for precision, R stands for
recall and F1 stands for F1-score.
Args:
pred (torch.Tensor | np.ndarray): The model prediction with shape
(N, C), where C is the number of classes.
target (torch.Tensor | np.ndarray): The target of each prediction with
shape (N, C), where C is the number of classes. 1 stands for
positive examples, 0 stands for negative examples and -1 stands for
difficult examples.
thr (float): The confidence threshold. Defaults to None.
k (int): Top-k performance. Note that if thr and k are both given, k
will be ignored. Defaults to None.
Returns:
tuple: (CP, CR, CF1, OP, OR, OF1)
"""
if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor):
pred = pred.detach().cpu().numpy()
target = target.detach().cpu().numpy()
elif not (isinstance(pred, np.ndarray) and isinstance(target, np.ndarray)):
raise TypeError('pred and target should both be torch.Tensor or'
'np.ndarray')
if thr is None and k is None:
thr = 0.5
warnings.warn('Neither thr nor k is given, set thr as 0.5 by '
'default.')
elif thr is not None and k is not None:
warnings.warn('Both thr and k are given, use threshold in favor of '
'top-k.')
assert pred.shape == \
target.shape, 'pred and target should be in the same shape.'
eps = np.finfo(np.float32).eps
target[target == -1] = 0
if thr is not None:
# a label is predicted positive if the confidence is no lower than thr
pos_inds = pred >= thr
else:
# top-k labels will be predicted positive for any example
sort_inds = np.argsort(-pred, axis=1)
sort_inds_ = sort_inds[:, :k]
inds = np.indices(sort_inds_.shape)
pos_inds = np.zeros_like(pred)
pos_inds[inds[0], sort_inds_] = 1
tp = (pos_inds * target) == 1
fp = (pos_inds * (1 - target)) == 1
fn = ((1 - pos_inds) * target) == 1
precision_class = tp.sum(axis=0) / np.maximum(
tp.sum(axis=0) + fp.sum(axis=0), eps)
recall_class = tp.sum(axis=0) / np.maximum(
tp.sum(axis=0) + fn.sum(axis=0), eps)
CP = precision_class.mean() * 100.0
CR = recall_class.mean() * 100.0
CF1 = 2 * CP * CR / np.maximum(CP + CR, eps)
OP = tp.sum() / np.maximum(tp.sum() + fp.sum(), eps) * 100.0
OR = tp.sum() / np.maximum(tp.sum() + fn.sum(), eps) * 100.0
OF1 = 2 * OP * OR / np.maximum(OP + OR, eps)
return CP, CR, CF1, OP, OR, OF1

View File

@ -4,8 +4,7 @@ from .builder import build_dataset
from .cifar import CIFAR10, CIFAR100
from .cub import CUB
from .custom import CustomDataset
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
KFoldDataset, RepeatDataset)
from .dataset_wrappers import KFoldDataset
from .imagenet import ImageNet, ImageNet21k
from .mnist import MNIST, FashionMNIST
from .multi_label import MultiLabelDataset
@ -15,7 +14,6 @@ from .voc import VOC
__all__ = [
'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
'VOC', 'build_dataset', 'ConcatDataset', 'RepeatDataset',
'ClassBalancedDataset', 'ImageNet21k', 'KFoldDataset', 'CUB',
'VOC', 'build_dataset', 'ImageNet21k', 'KFoldDataset', 'CUB',
'CustomDataset', 'MultiLabelDataset'
]

View File

@ -3,9 +3,9 @@ from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
import mmengine
from mmengine import FileClient
from mmengine.logging import MMLogger
from mmcls.registry import DATASETS
from mmcls.utils import get_root_logger
from .base_dataset import BaseDataset
@ -193,7 +193,7 @@ class CustomDataset(BaseDataset):
self._metainfo['classes'] = tuple(classes)
if empty_classes:
logger = get_root_logger()
logger = MMLogger.get_current_instance()
logger.warning(
'Found no valid file in the folder '
f'{", ".join(empty_classes)}. '

View File

@ -1,281 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import math
from collections import defaultdict
import numpy as np
from mmengine.logging import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from mmcls.registry import DATASETS
@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
"""A wrapper of concatenated dataset.
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
add `get_cat_ids` function.
Args:
datasets (list[:obj:`BaseDataset`]): A list of datasets.
separate_eval (bool): Whether to evaluate the results
separately if it is used as validation dataset.
Defaults to True.
"""
def __init__(self, datasets, separate_eval=True):
super(ConcatDataset, self).__init__(datasets)
self.separate_eval = separate_eval
self.CLASSES = datasets[0].CLASSES
if not separate_eval:
if len(set([type(ds) for ds in datasets])) != 1:
raise NotImplementedError(
'To evaluate a concat dataset non-separately, '
'all the datasets should have same types')
def get_cat_ids(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError(
'absolute value of index should not exceed dataset length')
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx].get_cat_ids(sample_idx)
def evaluate(self, results, *args, indices=None, logger=None, **kwargs):
"""Evaluate the results.
Args:
results (list[list | tuple]): Testing results of the dataset.
indices (list, optional): The indices of samples corresponding to
the results. It's unavailable on ConcatDataset.
Defaults to None.
logger (logging.Logger or str, optional): If the type of logger is
``logging.Logger``, we directly use logger to log messages.
Some special loggers are:
- "silent": No message will be printed.
- "current": Use latest created logger to log message.
- other str: Instance name of logger. The corresponding logger
will log message if it has been created, otherwise will raise a
`ValueError`.
- None: The `print()` method will be used to print log
messages.
Returns:
dict[str: float]: AP results of the total dataset or each separate
dataset if `self.separate_eval=True`.
"""
if indices is not None:
raise NotImplementedError(
'Use indices to evaluate speific samples in a ConcatDataset '
'is not supported by now.')
assert len(results) == len(self), \
('Dataset and results have different sizes: '
f'{len(self)} v.s. {len(results)}')
# Check whether all the datasets support evaluation
for dataset in self.datasets:
assert hasattr(dataset, 'evaluate'), \
f"{type(dataset)} haven't implemented the evaluate function."
if self.separate_eval:
total_eval_results = dict()
for dataset_idx, dataset in enumerate(self.datasets):
start_idx = 0 if dataset_idx == 0 else \
self.cumulative_sizes[dataset_idx-1]
end_idx = self.cumulative_sizes[dataset_idx]
results_per_dataset = results[start_idx:end_idx]
print_log(
f'Evaluateing dataset-{dataset_idx} with '
f'{len(results_per_dataset)} images now',
logger=logger)
eval_results_per_dataset = dataset.evaluate(
results_per_dataset, *args, logger=logger, **kwargs)
for k, v in eval_results_per_dataset.items():
total_eval_results.update({f'{dataset_idx}_{k}': v})
return total_eval_results
else:
original_data_infos = self.datasets[0].data_infos
self.datasets[0].data_infos = sum(
[dataset.data_infos for dataset in self.datasets], [])
eval_results = self.datasets[0].evaluate(
results, logger=logger, **kwargs)
self.datasets[0].data_infos = original_data_infos
return eval_results
@DATASETS.register_module()
class RepeatDataset(object):
"""A wrapper of repeated dataset.
The length of repeated dataset will be `times` larger than the original
dataset. This is useful when the data loading time is long but the dataset
is small. Using RepeatDataset can reduce the data loading time between
epochs.
Args:
dataset (:obj:`BaseDataset`): The dataset to be repeated.
times (int): Repeat times.
"""
def __init__(self, dataset, times):
self.dataset = dataset
self.times = times
self.CLASSES = dataset.CLASSES
self._ori_len = len(self.dataset)
def __getitem__(self, idx):
return self.dataset[idx % self._ori_len]
def get_cat_ids(self, idx):
return self.dataset.get_cat_ids(idx % self._ori_len)
def __len__(self):
return self.times * self._ori_len
def evaluate(self, *args, **kwargs):
raise NotImplementedError(
'evaluate results on a repeated dataset is weird. '
'Please inference and evaluate on the original dataset.')
def __repr__(self):
"""Print the number of instance number."""
dataset_type = 'Test' if self.test_mode else 'Train'
result = (
f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
f'{dataset_type} dataset with total number of samples {len(self)}.'
)
return result
# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
@DATASETS.register_module()
class ClassBalancedDataset(object):
r"""A wrapper of repeated dataset with repeat factor.
Suitable for training on class imbalanced datasets like LVIS. Following the
sampling strategy in `this paper`_, in each epoch, an image may appear
multiple times based on its "repeat factor".
.. _this paper: https://arxiv.org/pdf/1908.03195.pdf
The repeat factor for an image is a function of the frequency the rarest
category labeled in that image. The "frequency of category c" in [0, 1]
is defined by the fraction of images in the training set (without repeats)
in which category c appears.
The dataset needs to implement :func:`self.get_cat_ids` to support
ClassBalancedDataset.
The repeat factor is computed as followed.
1. For each category c, compute the fraction :math:`f(c)` of images that
contain it.
2. For each category c, compute the category-level repeat factor
.. math::
r(c) = \max(1, \sqrt{\frac{t}{f(c)}})
3. For each image I and its labels :math:`L(I)`, compute the image-level
repeat factor
.. math::
r(I) = \max_{c \in L(I)} r(c)
Args:
dataset (:obj:`BaseDataset`): The dataset to be repeated.
oversample_thr (float): frequency threshold below which data is
repeated. For categories with ``f_c`` >= ``oversample_thr``, there
is no oversampling. For categories with ``f_c`` <
``oversample_thr``, the degree of oversampling following the
square-root inverse frequency heuristic above.
"""
def __init__(self, dataset, oversample_thr):
self.dataset = dataset
self.oversample_thr = oversample_thr
self.CLASSES = dataset.CLASSES
repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
repeat_indices = []
for dataset_index, repeat_factor in enumerate(repeat_factors):
repeat_indices.extend([dataset_index] * math.ceil(repeat_factor))
self.repeat_indices = repeat_indices
flags = []
if hasattr(self.dataset, 'flag'):
for flag, repeat_factor in zip(self.dataset.flag, repeat_factors):
flags.extend([flag] * int(math.ceil(repeat_factor)))
assert len(flags) == len(repeat_indices)
self.flag = np.asarray(flags, dtype=np.uint8)
def _get_repeat_factors(self, dataset, repeat_thr):
# 1. For each category c, compute the fraction # of images
# that contain it: f(c)
category_freq = defaultdict(int)
num_images = len(dataset)
for idx in range(num_images):
cat_ids = set(self.dataset.get_cat_ids(idx))
for cat_id in cat_ids:
category_freq[cat_id] += 1
for k, v in category_freq.items():
assert v > 0, f'caterogy {k} does not contain any images'
category_freq[k] = v / num_images
# 2. For each category c, compute the category-level repeat factor:
# r(c) = max(1, sqrt(t/f(c)))
category_repeat = {
cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
for cat_id, cat_freq in category_freq.items()
}
# 3. For each image I and its labels L(I), compute the image-level
# repeat factor:
# r(I) = max_{c in L(I)} r(c)
repeat_factors = []
for idx in range(num_images):
cat_ids = set(self.dataset.get_cat_ids(idx))
repeat_factor = max(
{category_repeat[cat_id]
for cat_id in cat_ids})
repeat_factors.append(repeat_factor)
return repeat_factors
def __getitem__(self, idx):
ori_index = self.repeat_indices[idx]
return self.dataset[ori_index]
def __len__(self):
return len(self.repeat_indices)
def evaluate(self, *args, **kwargs):
raise NotImplementedError(
'evaluate results on a class-balanced dataset is weird. '
'Please inference and evaluate on the original dataset.')
def __repr__(self):
"""Print the number of instance number."""
dataset_type = 'Test' if self.test_mode else 'Train'
result = (
f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
f'{dataset_type} dataset with total number of samples {len(self)}.'
)
return result
@DATASETS.register_module()
class KFoldDataset:
"""A wrapper of dataset for K-Fold cross-validation.

View File

@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union
from mmengine.logging import MMLogger
from mmcls.registry import DATASETS
from mmcls.utils import get_root_logger
from .categories import IMAGENET_CATEGORIES
from .custom import CustomDataset
@ -78,7 +79,7 @@ class ImageNet21k(CustomDataset):
'The `multi_label` option is not supported by now.')
self.multi_label = multi_label
logger = get_root_logger()
logger = MMLogger.get_current_instance()
if not ann_file:
logger.warning(

View File

@ -8,6 +8,7 @@ from .deit import DistilledVisionTransformer
from .densenet import DenseNet
from .efficientnet import EfficientNet
from .hrnet import HRNet
from .inception_v3 import InceptionV3
from .lenet import LeNet5
from .mlp_mixer import MlpMixer
from .mobilenet_v2 import MobileNetV2
@ -42,5 +43,5 @@ __all__ = [
'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c', 'ConvMixer',
'CSPDarkNet', 'CSPResNet', 'CSPResNeXt', 'CSPNet', 'RepMLPNet',
'PoolFormer', 'DenseNet', 'VAN'
'PoolFormer', 'DenseNet', 'VAN', 'InceptionV3'
]

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
class BaseBackbone(BaseModule, metaclass=ABCMeta):

View File

@ -7,11 +7,11 @@ import torch.nn.functional as F
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import AdaptivePadding
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmengine.model import BaseModule
from mmengine.model.utils import trunc_normal_
from mmcls.registry import MODELS
from mmcls.utils import get_root_logger
from .base_backbone import BaseBackbone, BaseModule
from .base_backbone import BaseBackbone
from .vision_transformer import TransformerEncoderLayer
@ -576,17 +576,12 @@ class Conformer(BaseBackbone):
def init_weights(self):
super(Conformer, self).init_weights()
logger = get_root_logger()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
return
else:
logger.info(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
self.apply(self._init_weights)
self.apply(self._init_weights)
def forward(self, x):
output = []

View File

@ -8,8 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import (NORM_LAYERS, DropPath, build_activation_layer,
build_norm_layer)
from mmcv.runner import BaseModule
from mmcv.runner.base_module import ModuleList, Sequential
from mmengine.model import BaseModule, ModuleList, Sequential
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -6,7 +6,7 @@ import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmcv.cnn.bricks import DropPath
from mmcv.runner import BaseModule, Sequential
from mmengine.model import BaseModule, Sequential
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.registry import MODELS

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmengine.model.utils import trunc_normal_
from mmcls.registry import MODELS
from .vision_transformer import VisionTransformer

View File

@ -7,7 +7,7 @@ import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn.bricks import ConvModule, DropPath
from mmcv.runner import BaseModule, Sequential
from mmengine.model import BaseModule, Sequential
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.models.utils import InvertedResidual, SELayer, make_divisible

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule, ModuleList, Sequential
from mmengine.model import BaseModule, ModuleList, Sequential
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.registry import MODELS

View File

@ -0,0 +1,501 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import torch
import torch.nn as nn
from mmcv.cnn import build_conv_layer
from mmengine.model import BaseModule
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone
class BasicConv2d(BaseModule):
"""A basic convolution block including convolution, batch norm and ReLU.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
conv_cfg (dict, optional): The config of convolution layer.
Defaults to None, which means to use ``nn.Conv2d``.
init_cfg (dict, optional): The config of initialization.
Defaults to None.
**kwargs: Other keyword arguments of the convolution layer.
"""
def __init__(self,
in_channels: int,
out_channels: int,
conv_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None,
**kwargs) -> None:
super().__init__(init_cfg=init_cfg)
self.conv = build_conv_layer(
conv_cfg, in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
self.relu = nn.ReLU(inplace=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function."""
x = self.conv(x)
x = self.bn(x)
return self.relu(x)
class InceptionA(BaseModule):
"""Type-A Inception block.
Args:
in_channels (int): The number of input channels.
pool_features (int): The number of channels in pooling branch.
conv_cfg (dict, optional): The convolution layer config in the
:class:`BasicConv2d` block. Defaults to None.
init_cfg (dict, optional): The config of initialization.
Defaults to None.
"""
def __init__(self,
in_channels: int,
pool_features: int,
conv_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None):
super().__init__(init_cfg=init_cfg)
self.branch1x1 = BasicConv2d(
in_channels, 64, kernel_size=1, conv_cfg=conv_cfg)
self.branch5x5_1 = BasicConv2d(
in_channels, 48, kernel_size=1, conv_cfg=conv_cfg)
self.branch5x5_2 = BasicConv2d(
48, 64, kernel_size=5, padding=2, conv_cfg=conv_cfg)
self.branch3x3dbl_1 = BasicConv2d(
in_channels, 64, kernel_size=1, conv_cfg=conv_cfg)
self.branch3x3dbl_2 = BasicConv2d(
64, 96, kernel_size=3, padding=1, conv_cfg=conv_cfg)
self.branch3x3dbl_3 = BasicConv2d(
96, 96, kernel_size=3, padding=1, conv_cfg=conv_cfg)
self.branch_pool_downsample = nn.AvgPool2d(
kernel_size=3, stride=1, padding=1)
self.branch_pool = BasicConv2d(
in_channels, pool_features, kernel_size=1, conv_cfg=conv_cfg)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function."""
branch1x1 = self.branch1x1(x)
branch5x5 = self.branch5x5_1(x)
branch5x5 = self.branch5x5_2(branch5x5)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
branch_pool = self.branch_pool_downsample(x)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
return torch.cat(outputs, 1)
class InceptionB(BaseModule):
"""Type-B Inception block.
Args:
in_channels (int): The number of input channels.
conv_cfg (dict, optional): The convolution layer config in the
:class:`BasicConv2d` block. Defaults to None.
init_cfg (dict, optional): The config of initialization.
Defaults to None.
"""
def __init__(self,
in_channels: int,
conv_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None):
super().__init__(init_cfg=init_cfg)
self.branch3x3 = BasicConv2d(
in_channels, 384, kernel_size=3, stride=2, conv_cfg=conv_cfg)
self.branch3x3dbl_1 = BasicConv2d(
in_channels, 64, kernel_size=1, conv_cfg=conv_cfg)
self.branch3x3dbl_2 = BasicConv2d(
64, 96, kernel_size=3, padding=1, conv_cfg=conv_cfg)
self.branch3x3dbl_3 = BasicConv2d(
96, 96, kernel_size=3, stride=2, conv_cfg=conv_cfg)
self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function."""
branch3x3 = self.branch3x3(x)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
branch_pool = self.branch_pool(x)
outputs = [branch3x3, branch3x3dbl, branch_pool]
return torch.cat(outputs, 1)
class InceptionC(BaseModule):
"""Type-C Inception block.
Args:
in_channels (int): The number of input channels.
channels_7x7 (int): The number of channels in 7x7 convolution branch.
conv_cfg (dict, optional): The convolution layer config in the
:class:`BasicConv2d` block. Defaults to None.
init_cfg (dict, optional): The config of initialization.
Defaults to None.
"""
def __init__(self,
in_channels: int,
channels_7x7: int,
conv_cfg: Optional[dict] = None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.branch1x1 = BasicConv2d(
in_channels, 192, kernel_size=1, conv_cfg=conv_cfg)
c7 = channels_7x7
self.branch7x7_1 = BasicConv2d(
in_channels, c7, kernel_size=1, conv_cfg=conv_cfg)
self.branch7x7_2 = BasicConv2d(
c7, c7, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg)
self.branch7x7_3 = BasicConv2d(
c7, 192, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg)
self.branch7x7dbl_1 = BasicConv2d(
in_channels, c7, kernel_size=1, conv_cfg=conv_cfg)
self.branch7x7dbl_2 = BasicConv2d(
c7, c7, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg)
self.branch7x7dbl_3 = BasicConv2d(
c7, c7, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg)
self.branch7x7dbl_4 = BasicConv2d(
c7, c7, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg)
self.branch7x7dbl_5 = BasicConv2d(
c7, 192, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg)
self.branch_pool_downsample = nn.AvgPool2d(
kernel_size=3, stride=1, padding=1)
self.branch_pool = BasicConv2d(
in_channels, 192, kernel_size=1, conv_cfg=conv_cfg)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function."""
branch1x1 = self.branch1x1(x)
branch7x7 = self.branch7x7_1(x)
branch7x7 = self.branch7x7_2(branch7x7)
branch7x7 = self.branch7x7_3(branch7x7)
branch7x7dbl = self.branch7x7dbl_1(x)
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
branch_pool = self.branch_pool_downsample(x)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
return torch.cat(outputs, 1)
class InceptionD(BaseModule):
"""Type-D Inception block.
Args:
in_channels (int): The number of input channels.
conv_cfg (dict, optional): The convolution layer config in the
:class:`BasicConv2d` block. Defaults to None.
init_cfg (dict, optional): The config of initialization.
Defaults to None.
"""
def __init__(self,
in_channels: int,
conv_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None):
super().__init__(init_cfg=init_cfg)
self.branch3x3_1 = BasicConv2d(
in_channels, 192, kernel_size=1, conv_cfg=conv_cfg)
self.branch3x3_2 = BasicConv2d(
192, 320, kernel_size=3, stride=2, conv_cfg=conv_cfg)
self.branch7x7x3_1 = BasicConv2d(
in_channels, 192, kernel_size=1, conv_cfg=conv_cfg)
self.branch7x7x3_2 = BasicConv2d(
192, 192, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg)
self.branch7x7x3_3 = BasicConv2d(
192, 192, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg)
self.branch7x7x3_4 = BasicConv2d(
192, 192, kernel_size=3, stride=2, conv_cfg=conv_cfg)
self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function."""
branch3x3 = self.branch3x3_1(x)
branch3x3 = self.branch3x3_2(branch3x3)
branch7x7x3 = self.branch7x7x3_1(x)
branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
branch_pool = self.branch_pool(x)
outputs = [branch3x3, branch7x7x3, branch_pool]
return torch.cat(outputs, 1)
class InceptionE(BaseModule):
"""Type-E Inception block.
Args:
in_channels (int): The number of input channels.
conv_cfg (dict, optional): The convolution layer config in the
:class:`BasicConv2d` block. Defaults to None.
init_cfg (dict, optional): The config of initialization.
Defaults to None.
"""
def __init__(self,
in_channels: int,
conv_cfg: Optional[dict] = None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.branch1x1 = BasicConv2d(
in_channels, 320, kernel_size=1, conv_cfg=conv_cfg)
self.branch3x3_1 = BasicConv2d(
in_channels, 384, kernel_size=1, conv_cfg=conv_cfg)
self.branch3x3_2a = BasicConv2d(
384, 384, kernel_size=(1, 3), padding=(0, 1), conv_cfg=conv_cfg)
self.branch3x3_2b = BasicConv2d(
384, 384, kernel_size=(3, 1), padding=(1, 0), conv_cfg=conv_cfg)
self.branch3x3dbl_1 = BasicConv2d(
in_channels, 448, kernel_size=1, conv_cfg=conv_cfg)
self.branch3x3dbl_2 = BasicConv2d(
448, 384, kernel_size=3, padding=1, conv_cfg=conv_cfg)
self.branch3x3dbl_3a = BasicConv2d(
384, 384, kernel_size=(1, 3), padding=(0, 1), conv_cfg=conv_cfg)
self.branch3x3dbl_3b = BasicConv2d(
384, 384, kernel_size=(3, 1), padding=(1, 0), conv_cfg=conv_cfg)
self.branch_pool_downsample = nn.AvgPool2d(
kernel_size=3, stride=1, padding=1)
self.branch_pool = BasicConv2d(
in_channels, 192, kernel_size=1, conv_cfg=conv_cfg)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function."""
branch1x1 = self.branch1x1(x)
branch3x3 = self.branch3x3_1(x)
branch3x3 = [
self.branch3x3_2a(branch3x3),
self.branch3x3_2b(branch3x3),
]
branch3x3 = torch.cat(branch3x3, 1)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = [
self.branch3x3dbl_3a(branch3x3dbl),
self.branch3x3dbl_3b(branch3x3dbl),
]
branch3x3dbl = torch.cat(branch3x3dbl, 1)
branch_pool = self.branch_pool_downsample(x)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
return torch.cat(outputs, 1)
class InceptionAux(BaseModule):
"""The Inception block for the auxiliary classification branch.
Args:
in_channels (int): The number of input channels.
num_classes (int): The number of categroies.
conv_cfg (dict, optional): The convolution layer config in the
:class:`BasicConv2d` block. Defaults to None.
init_cfg (dict, optional): The config of initialization.
Defaults to use trunc normal with ``std=0.01`` for Conv2d layers
and use trunc normal with ``std=0.001`` for Linear layers..
"""
def __init__(self,
in_channels: int,
num_classes: int,
conv_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = [
dict(type='TruncNormal', layer='Conv2d', std=0.01),
dict(type='TruncNormal', layer='Linear', std=0.001)
]):
super().__init__(init_cfg=init_cfg)
self.downsample = nn.AvgPool2d(kernel_size=5, stride=3)
self.conv0 = BasicConv2d(
in_channels, 128, kernel_size=1, conv_cfg=conv_cfg)
self.conv1 = BasicConv2d(128, 768, kernel_size=5, conv_cfg=conv_cfg)
self.gap = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(768, num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function."""
# N x 768 x 17 x 17
x = self.downsample(x)
# N x 768 x 5 x 5
x = self.conv0(x)
# N x 128 x 5 x 5
x = self.conv1(x)
# N x 768 x 1 x 1
# Adaptive average pooling
x = self.gap(x)
# N x 768 x 1 x 1
x = torch.flatten(x, 1)
# N x 768
x = self.fc(x)
# N x 1000
return x
@MODELS.register_module()
class InceptionV3(BaseBackbone):
"""Inception V3 backbone.
A PyTorch implementation of `Rethinking the Inception Architecture for
Computer Vision <https://arxiv.org/abs/1512.00567>`_
This implementation is modified from
https://github.com/pytorch/vision/blob/main/torchvision/models/inception.py.
Licensed under the BSD 3-Clause License.
Args:
num_classes (int): The number of categroies. Defaults to 1000.
aux_logits (bool): Whether to enable the auxiliary branch. If False,
the auxiliary logits output will be None. Defaults to False.
dropout (float): Dropout rate. Defaults to 0.5.
init_cfg (dict, optional): The config of initialization. Defaults
to use trunc normal with ``std=0.1`` for all Conv2d and Linear
layers and constant with ``val=1`` for all BatchNorm2d layers.
Example:
>>> import torch
>>> from mmcls.models import build_backbone
>>>
>>> inputs = torch.rand(2, 3, 299, 299)
>>> cfg = dict(type='InceptionV3', num_classes=100)
>>> backbone = build_backbone(cfg)
>>> aux_out, out = backbone(inputs)
>>> # The auxiliary branch is disabled by default.
>>> assert aux_out is None
>>> print(out.shape)
torch.Size([2, 100])
>>> cfg = dict(type='InceptionV3', num_classes=100, aux_logits=True)
>>> backbone = build_backbone(cfg)
>>> aux_out, out = backbone(inputs)
>>> print(aux_out.shape, out.shape)
torch.Size([2, 100]) torch.Size([2, 100])
"""
def __init__(
self,
num_classes: int = 1000,
aux_logits: bool = False,
dropout: float = 0.5,
init_cfg: Optional[dict] = [
dict(type='TruncNormal', layer=['Conv2d', 'Linear'], std=0.1),
dict(type='Constant', layer='BatchNorm2d', val=1)
],
) -> None:
super().__init__(init_cfg=init_cfg)
self.aux_logits = aux_logits
self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
self.Mixed_5b = InceptionA(192, pool_features=32)
self.Mixed_5c = InceptionA(256, pool_features=64)
self.Mixed_5d = InceptionA(288, pool_features=64)
self.Mixed_6a = InceptionB(288)
self.Mixed_6b = InceptionC(768, channels_7x7=128)
self.Mixed_6c = InceptionC(768, channels_7x7=160)
self.Mixed_6d = InceptionC(768, channels_7x7=160)
self.Mixed_6e = InceptionC(768, channels_7x7=192)
self.AuxLogits: Optional[nn.Module] = None
if aux_logits:
self.AuxLogits = InceptionAux(768, num_classes)
self.Mixed_7a = InceptionD(768)
self.Mixed_7b = InceptionE(1280)
self.Mixed_7c = InceptionE(2048)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.dropout = nn.Dropout(p=dropout)
self.fc = nn.Linear(2048, num_classes)
def forward(
self,
x: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
"""Forward function."""
# N x 3 x 299 x 299
x = self.Conv2d_1a_3x3(x)
# N x 32 x 149 x 149
x = self.Conv2d_2a_3x3(x)
# N x 32 x 147 x 147
x = self.Conv2d_2b_3x3(x)
# N x 64 x 147 x 147
x = self.maxpool1(x)
# N x 64 x 73 x 73
x = self.Conv2d_3b_1x1(x)
# N x 80 x 73 x 73
x = self.Conv2d_4a_3x3(x)
# N x 192 x 71 x 71
x = self.maxpool2(x)
# N x 192 x 35 x 35
x = self.Mixed_5b(x)
# N x 256 x 35 x 35
x = self.Mixed_5c(x)
# N x 288 x 35 x 35
x = self.Mixed_5d(x)
# N x 288 x 35 x 35
x = self.Mixed_6a(x)
# N x 768 x 17 x 17
x = self.Mixed_6b(x)
# N x 768 x 17 x 17
x = self.Mixed_6c(x)
# N x 768 x 17 x 17
x = self.Mixed_6d(x)
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
aux: Optional[torch.Tensor] = None
if self.aux_logits and self.training:
aux = self.AuxLogits(x)
# N x 768 x 17 x 17
x = self.Mixed_7a(x)
# N x 1280 x 8 x 8
x = self.Mixed_7b(x)
# N x 2048 x 8 x 8
x = self.Mixed_7c(x)
# N x 2048 x 8 x 8
# Adaptive average pooling
x = self.avgpool(x)
# N x 2048 x 1 x 1
x = self.dropout(x)
# N x 2048 x 1 x 1
x = torch.flatten(x, 1)
# N x 2048
x = self.fc(x)
# N x 1000 (num_classes)
return aux, x

View File

@ -4,7 +4,7 @@ from typing import Sequence
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmcv.runner.base_module import BaseModule, ModuleList
from mmengine.model import BaseModule, ModuleList
from mmcls.registry import MODELS
from ..utils import to_2tuple

View File

@ -2,7 +2,7 @@
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.utils import make_divisible

View File

@ -4,7 +4,7 @@ from typing import Sequence
import torch
import torch.nn as nn
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -6,7 +6,7 @@ import torch.nn.functional as F
from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
build_norm_layer)
from mmcv.cnn.bricks.transformer import PatchEmbed as _PatchEmbed
from mmcv.runner import BaseModule, ModuleList, Sequential
from mmengine.model import BaseModule, ModuleList, Sequential
from mmcls.models.utils import SELayer, to_2tuple
from mmcls.registry import MODELS

View File

@ -3,8 +3,8 @@ import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule, Sequential
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.model import BaseModule, Sequential
from mmcls.registry import MODELS
from ..utils.se_layer import SELayer

View File

@ -5,8 +5,8 @@ import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
build_norm_layer, constant_init)
from mmcv.cnn.bricks import DropPath
from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.model import BaseModule
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -4,7 +4,7 @@ import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_activation_layer, constant_init,
normal_init)
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.utils import channel_shuffle, make_divisible

View File

@ -2,8 +2,9 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, constant_init, normal_init
from mmcv.runner import BaseModule
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmengine.model.utils import constant_init, normal_init
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.utils import channel_shuffle

View File

@ -8,9 +8,9 @@ import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed, PatchMerging
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.model import BaseModule, ModuleList
from mmengine.model.utils import trunc_normal_
from mmcls.registry import MODELS
from ..utils import (ShiftWindowMSA, resize_pos_embed,
@ -488,8 +488,8 @@ class SwinTransformer(BaseBackbone):
ckpt_pos_embed_shape = state_dict[name].shape
if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
from mmcls.utils import get_root_logger
logger = get_root_logger()
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info(
'Resize the absolute_pos_embed shape from '
f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')
@ -523,8 +523,8 @@ class SwinTransformer(BaseBackbone):
new_rel_pos_bias = resize_relative_position_bias_table(
src_size, dst_size,
relative_position_bias_table_pretrained, nH1)
from mmcls.utils import get_root_logger
logger = get_root_logger()
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info('Resize the relative_position_bias_table from '
f'{state_dict[ckpt_key].shape} to '
f'{new_rel_pos_bias.shape}')

View File

@ -7,8 +7,8 @@ import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmengine.model import BaseModule, ModuleList
from mmengine.model.utils import trunc_normal_
from mmcls.registry import MODELS
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
@ -381,8 +381,8 @@ class T2T_ViT(BaseBackbone):
ckpt_pos_embed_shape = state_dict[name].shape
if self.pos_embed.shape != ckpt_pos_embed_shape:
from mmcls.utils import get_root_logger
logger = get_root_logger()
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info(
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
f'to {self.pos_embed.shape}.')

View File

@ -7,9 +7,9 @@ except ImportError:
import warnings
from mmcv.cnn.bricks.registry import NORM_LAYERS
from mmengine.logging import MMLogger
from mmcls.registry import MODELS
from ...utils import get_root_logger
from .base_backbone import BaseBackbone
@ -20,7 +20,7 @@ def print_timm_feature_info(feature_info):
feature_info (list[dict] | timm.models.features.FeatureInfo | None):
feature_info of timm backbone.
"""
logger = get_root_logger()
logger = MMLogger.get_current_instance()
if feature_info is None:
logger.warning('This backbone does not have feature_info')
elif isinstance(feature_info, list):

View File

@ -5,8 +5,8 @@ import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmengine.model import BaseModule, ModuleList
from mmengine.model.utils import trunc_normal_
from mmcls.registry import MODELS
from ..utils import to_2tuple

View File

@ -7,9 +7,8 @@ import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
trunc_normal_init)
from mmcv.runner import BaseModule, ModuleList
from mmengine.model import BaseModule, ModuleList
from mmengine.model.utils import constant_init, normal_init, trunc_normal_init
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.utils.attention import MultiheadAttention

View File

@ -4,8 +4,8 @@ import torch.nn as nn
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.bricks.transformer import PatchEmbed
from mmcv.runner import BaseModule, ModuleList
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.model import BaseModule, ModuleList
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -6,11 +6,10 @@ import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmengine.model import BaseModule, ModuleList
from mmengine.model.utils import trunc_normal_
from mmcls.registry import MODELS
from mmcls.utils import get_root_logger
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
from .base_backbone import BaseBackbone
@ -316,12 +315,11 @@ class VisionTransformer(BaseBackbone):
ckpt_pos_embed_shape = state_dict[name].shape
if self.pos_embed.shape != ckpt_pos_embed_shape:
from mmengine.logging import print_log
logger = get_root_logger()
print_log(
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info(
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
f'to {self.pos_embed.shape}.',
logger=logger)
f'to {self.pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))

View File

@ -2,8 +2,8 @@
from abc import ABCMeta, abstractmethod
from typing import List, Optional, Tuple
from mmcv.runner import BaseModule
from mmengine import BaseDataElement
from mmengine.model import BaseModule
class BaseHead(BaseModule, metaclass=ABCMeta):

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List, Tuple
import torch
import torch.nn as nn
from mmcls.registry import MODELS
from mmcls.utils import get_root_logger
from .vision_transformer_head import VisionTransformerClsHead
@ -57,9 +57,9 @@ class DeiTClsHead(VisionTransformerClsHead):
def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor:
"""The forward process."""
logger = get_root_logger()
logger.warning("MMClassification doesn't support to train the "
'distilled version DeiT.')
if self.training:
warnings.warn('MMClassification cannot train the '
'distilled version DeiT.')
cls_token, dist_token = self.pre_logits(feats)
# The final classification head.
cls_score = (self.layers.head(cls_token) +

View File

@ -4,7 +4,7 @@ from typing import Dict, Optional, Sequence, Tuple
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.runner import BaseModule, ModuleList
from mmengine.model import BaseModule, ModuleList
from mmcls.registry import MODELS
from .cls_head import ClsHead

View File

@ -6,8 +6,8 @@ from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner import Sequential
from mmengine.model import Sequential
from mmengine.model.utils import trunc_normal_
from mmcls.registry import MODELS
from .cls_head import ClsHead

View File

@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .accuracy import Accuracy, accuracy
from .asymmetric_loss import AsymmetricLoss, asymmetric_loss
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
cross_entropy)
@ -10,8 +9,8 @@ from .utils import (convert_to_one_hot, reduce_loss, weight_reduce_loss,
weighted_loss)
__all__ = [
'accuracy', 'Accuracy', 'asymmetric_loss', 'AsymmetricLoss',
'cross_entropy', 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'asymmetric_loss', 'AsymmetricLoss', 'cross_entropy',
'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'LabelSmoothLoss', 'weighted_loss', 'FocalLoss',
'sigmoid_focal_loss', 'convert_to_one_hot', 'SeesawLoss'
]

View File

@ -1,143 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
import numpy as np
import torch
import torch.nn as nn
def accuracy_numpy(pred, target, topk=(1, ), thrs=0.):
if isinstance(thrs, Number):
thrs = (thrs, )
res_single = True
elif isinstance(thrs, tuple):
res_single = False
else:
raise TypeError(
f'thrs should be a number or tuple, but got {type(thrs)}.')
res = []
maxk = max(topk)
num = pred.shape[0]
static_inds = np.indices((num, maxk))[0]
pred_label = pred.argpartition(-maxk, axis=1)[:, -maxk:]
pred_score = pred[static_inds, pred_label]
sort_inds = np.argsort(pred_score, axis=1)[:, ::-1]
pred_label = pred_label[static_inds, sort_inds]
pred_score = pred_score[static_inds, sort_inds]
for k in topk:
correct_k = pred_label[:, :k] == target.reshape(-1, 1)
res_thr = []
for thr in thrs:
# Only prediction values larger than thr are counted as correct
_correct_k = correct_k & (pred_score[:, :k] > thr)
_correct_k = np.logical_or.reduce(_correct_k, axis=1)
res_thr.append((_correct_k.sum() * 100. / num))
if res_single:
res.append(res_thr[0])
else:
res.append(res_thr)
return res
def accuracy_torch(pred, target, topk=(1, ), thrs=0.):
if isinstance(thrs, Number):
thrs = (thrs, )
res_single = True
elif isinstance(thrs, tuple):
res_single = False
else:
raise TypeError(
f'thrs should be a number or tuple, but got {type(thrs)}.')
res = []
maxk = max(topk)
num = pred.size(0)
pred = pred.float()
pred_score, pred_label = pred.topk(maxk, dim=1)
pred_label = pred_label.t()
correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
for k in topk:
res_thr = []
for thr in thrs:
# Only prediction values larger than thr are counted as correct
_correct = correct & (pred_score.t() > thr)
correct_k = _correct[:k].reshape(-1).float().sum(0, keepdim=True)
res_thr.append((correct_k.mul_(100. / num)))
if res_single:
res.append(res_thr[0])
else:
res.append(res_thr)
return res
def accuracy(pred, target, topk=1, thrs=0.):
"""Calculate accuracy according to the prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction.
target (torch.Tensor | np.array): The target of each prediction
topk (int | tuple[int]): If the predictions in ``topk``
matches the target, the predictions will be regarded as
correct ones. Defaults to 1.
thrs (Number | tuple[Number], optional): Predictions with scores under
the thresholds are considered negative. Defaults to 0.
Returns:
torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]]: Accuracy
- torch.Tensor: If both ``topk`` and ``thrs`` is a single value.
- list[torch.Tensor]: If one of ``topk`` or ``thrs`` is a tuple.
- list[list[torch.Tensor]]: If both ``topk`` and ``thrs`` is a \
tuple. And the first dim is ``topk``, the second dim is ``thrs``.
"""
assert isinstance(topk, (int, tuple))
if isinstance(topk, int):
topk = (topk, )
return_single = True
else:
return_single = False
assert isinstance(pred, (torch.Tensor, np.ndarray)), \
f'The pred should be torch.Tensor or np.ndarray ' \
f'instead of {type(pred)}.'
assert isinstance(target, (torch.Tensor, np.ndarray)), \
f'The target should be torch.Tensor or np.ndarray ' \
f'instead of {type(target)}.'
# torch version is faster in most situations.
to_tensor = (lambda x: torch.from_numpy(x)
if isinstance(x, np.ndarray) else x)
pred = to_tensor(pred)
target = to_tensor(target)
res = accuracy_torch(pred, target, topk, thrs)
return res[0] if return_single else res
class Accuracy(nn.Module):
def __init__(self, topk=(1, )):
"""Module to calculate the accuracy.
Args:
topk (tuple): The criterion used to calculate the
accuracy. Defaults to (1,).
"""
super().__init__()
self.topk = topk
def forward(self, pred, target):
"""Forward function to calculate accuracy.
Args:
pred (torch.Tensor): Prediction of models.
target (torch.Tensor): Target for each prediction.
Returns:
list[torch.Tensor]: The accuracies under different topk criterions.
"""
return accuracy(pred, target, self.topk)

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn.bricks import ConvModule
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from mmcls.registry import MODELS
from ..backbones.resnet import Bottleneck, ResLayer

View File

@ -4,10 +4,9 @@ import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.registry import DROPOUT_LAYERS
from mmcv.cnn.bricks.transformer import build_dropout
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule
from mmcv.cnn.bricks.drop import build_dropout
from mmengine.model import BaseModule
from mmengine.model.utils import trunc_normal_
from mmcls.registry import MODELS
from .helpers import to_2tuple
@ -364,7 +363,7 @@ class MultiheadAttention(BaseModule):
self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
self.out_drop = DROPOUT_LAYERS.build(dropout_layer)
self.out_drop = build_dropout(dropout_layer)
def forward(self, x):
B, N, _ = x.shape

View File

@ -6,7 +6,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule
from mmengine.model import BaseModule
from .helpers import to_2tuple

View File

@ -3,7 +3,7 @@ import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import DropPath
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from .se_layer import SELayer

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.runner.base_module import BaseModule
from mmengine.model import BaseModule
class ConditionalPositionEncoding(BaseModule):

View File

@ -2,7 +2,7 @@
import mmcv
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from .make_divisible import make_divisible

View File

@ -1,9 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .logger import get_root_logger, load_json_log
from .setup_env import register_all_modules, setup_multi_processes
from .setup_env import register_all_modules
__all__ = [
'collect_env', 'get_root_logger', 'load_json_log', 'setup_multi_processes',
'register_all_modules'
]
__all__ = ['collect_env', 'register_all_modules']

View File

@ -1,64 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import logging
from collections import defaultdict
from mmengine.logging import MMLogger
def get_root_logger(log_file=None, log_level=logging.INFO):
"""Get root logger.
Args:
log_file (str, optional): File path of log. Defaults to None.
log_level (int, optional): The level of logger.
Defaults to :obj:`logging.INFO`.
Returns:
:obj:`logging.Logger`: The obtained logger
"""
try:
return MMLogger.get_instance(
'mmcls',
logger_name='mmcls',
log_file=log_file,
log_level=log_level)
except AssertionError:
# if root logger already existed, no extra kwargs needed.
return MMLogger.get_instance('mmcls')
def load_json_log(json_log):
"""load and convert json_logs to log_dicts.
Args:
json_log (str): The path of the json log file.
Returns:
dict[int, dict[str, list]]:
Key is the epoch, value is a sub dict. The keys in each sub dict
are different metrics, e.g. memory, bbox_mAP, and the value is a
list of corresponding values in all iterations in this epoch.
.. code-block:: python
# An example output
{
1: {'iter': [100, 200, 300], 'loss': [6.94, 6.73, 6.53]},
2: {'iter': [100, 200, 300], 'loss': [6.33, 6.20, 6.07]},
...
}
"""
log_dict = dict()
with open(json_log, 'r') as log_file:
for line in log_file:
log = json.loads(line.strip())
# skip lines without `epoch` field
if 'epoch' not in log:
continue
epoch = log.pop('epoch')
if epoch not in log_dict:
log_dict[epoch] = defaultdict(list)
for k, v in log.items():
log_dict[epoch][k].append(v)
return log_dict

View File

@ -1,54 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import os
import platform
import warnings
import cv2
import torch.multiprocessing as mp
from mmengine import DefaultScope
def setup_multi_processes(cfg):
"""Setup multi-processing environment variables."""
# set multi-process start method as `fork` to speed up the training
if platform.system() != 'Windows':
mp_start_method = cfg.get('mp_start_method', 'fork')
current_method = mp.get_start_method(allow_none=True)
if current_method is not None and current_method != mp_start_method:
warnings.warn(
f'Multi-processing start method `{mp_start_method}` is '
f'different from the previous setting `{current_method}`.'
f'It will be force set to `{mp_start_method}`. You can change '
f'this behavior by changing `mp_start_method` in your config.')
mp.set_start_method(mp_start_method, force=True)
# disable opencv multithreading to avoid system being overloaded
opencv_num_threads = cfg.get('opencv_num_threads', 0)
cv2.setNumThreads(opencv_num_threads)
# setup OMP threads
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
omp_num_threads = 1
warnings.warn(
f'Setting OMP_NUM_THREADS environment variable for each process '
f'to be {omp_num_threads} in default, to avoid your system being '
f'overloaded, please further tune the variable for optimal '
f'performance in your application as needed.')
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
# setup MKL threads
if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
mkl_num_threads = 1
warnings.warn(
f'Setting MKL_NUM_THREADS environment variable for each process '
f'to be {mkl_num_threads} in default, to avoid your system being '
f'overloaded, please further tune the variable for optimal '
f'performance in your application as needed.')
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
def register_all_modules(init_default_scope: bool = True) -> None:
"""Register all modules in mmcls into the registries.

View File

@ -27,3 +27,4 @@ Import:
- configs/convmixer/metafile.yml
- configs/densenet/metafile.yml
- configs/poolformer/metafile.yml
- configs/inception_v3/metafile.yml

View File

@ -7,12 +7,13 @@ from unittest import TestCase
from unittest.mock import MagicMock, call, patch
import numpy as np
from mmengine.logging import MMLogger
from mmengine.registry import TRANSFORMS
from mmcls.registry import DATASETS
from mmcls.utils import get_root_logger
from mmcls.utils import register_all_modules
mmcls_logger = get_root_logger()
register_all_modules()
ASSETS_ROOT = osp.abspath(
osp.join(osp.dirname(__file__), '../../data/dataset'))
@ -218,7 +219,8 @@ class TestCustomDataset(TestBaseDataset):
'ann_file': '',
'extensions': ('.jpeg', )
}
with self.assertLogs(mmcls_logger, 'WARN') as log:
logger = MMLogger.get_current_instance()
with self.assertLogs(logger, 'WARN') as log:
dataset = dataset_class(**cfg)
self.assertIn('Supported extensions are: .jpeg', log.output[0])
self.assertEqual(len(dataset), 1)
@ -291,13 +293,14 @@ class TestImageNet21k(TestCustomDataset):
# Warn about ann_file
cfg = {**self.DEFAULT_ARGS, 'ann_file': ''}
with self.assertLogs(mmcls_logger, 'WARN') as log:
logger = MMLogger.get_current_instance()
with self.assertLogs(logger, 'WARN') as log:
dataset_class(**cfg)
self.assertIn('specify the `ann_file`', log.output[0])
# Warn about classes
cfg = {**self.DEFAULT_ARGS, 'classes': None}
with self.assertLogs(mmcls_logger, 'WARN') as log:
with self.assertLogs(logger, 'WARN') as log:
dataset_class(**cfg)
self.assertIn('specify the `classes`', log.output[0])

View File

@ -1,93 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
import pytest
import torch
from mmcls.core import average_performance, mAP
from mmcls.models.losses.accuracy import Accuracy, accuracy_numpy
def test_mAP():
target = torch.Tensor([[1, 1, 0, -1], [1, 1, 0, -1], [0, -1, 1, -1],
[0, 1, 0, -1]])
pred = torch.Tensor([[0.9, 0.8, 0.3, 0.2], [0.1, 0.2, 0.2, 0.1],
[0.7, 0.5, 0.9, 0.3], [0.8, 0.1, 0.1, 0.2]])
# target and pred should both be np.ndarray or torch.Tensor
with pytest.raises(TypeError):
target_list = target.tolist()
_ = mAP(pred, target_list)
# target and pred should be in the same shape
with pytest.raises(AssertionError):
target_shorter = target[:-1]
_ = mAP(pred, target_shorter)
assert mAP(pred, target) == pytest.approx(68.75, rel=1e-2)
target_no_difficult = torch.Tensor([[1, 1, 0, 0], [0, 1, 0, 0],
[0, 0, 1, 0], [1, 0, 0, 0]])
assert mAP(pred, target_no_difficult) == pytest.approx(70.83, rel=1e-2)
def test_average_performance():
target = torch.Tensor([[1, 1, 0, -1], [1, 1, 0, -1], [0, -1, 1, -1],
[0, 1, 0, -1], [0, 1, 0, -1]])
pred = torch.Tensor([[0.9, 0.8, 0.3, 0.2], [0.1, 0.2, 0.2, 0.1],
[0.7, 0.5, 0.9, 0.3], [0.8, 0.1, 0.1, 0.2],
[0.8, 0.1, 0.1, 0.2]])
# target and pred should both be np.ndarray or torch.Tensor
with pytest.raises(TypeError):
target_list = target.tolist()
_ = average_performance(pred, target_list)
# target and pred should be in the same shape
with pytest.raises(AssertionError):
target_shorter = target[:-1]
_ = average_performance(pred, target_shorter)
assert average_performance(pred, target) == average_performance(
pred, target, thr=0.5)
assert average_performance(pred, target, thr=0.5, k=2) \
== average_performance(pred, target, thr=0.5)
assert average_performance(
pred, target, thr=0.3) == pytest.approx(
(31.25, 43.75, 36.46, 33.33, 42.86, 37.50), rel=1e-2)
assert average_performance(
pred, target, k=2) == pytest.approx(
(43.75, 50.00, 46.67, 40.00, 57.14, 47.06), rel=1e-2)
def test_accuracy():
pred_tensor = torch.tensor([[0.1, 0.2, 0.4], [0.2, 0.5, 0.3],
[0.4, 0.3, 0.1], [0.8, 0.9, 0.0]])
target_tensor = torch.tensor([2, 0, 0, 0])
pred_array = pred_tensor.numpy()
target_array = target_tensor.numpy()
acc_top1 = 50.
acc_top2 = 75.
compute_acc = Accuracy(topk=1)
assert compute_acc(pred_tensor, target_tensor) == acc_top1
assert compute_acc(pred_array, target_array) == acc_top1
compute_acc = Accuracy(topk=(1, ))
assert compute_acc(pred_tensor, target_tensor)[0] == acc_top1
assert compute_acc(pred_array, target_array)[0] == acc_top1
compute_acc = Accuracy(topk=(1, 2))
assert compute_acc(pred_tensor, target_array)[0] == acc_top1
assert compute_acc(pred_tensor, target_tensor)[1] == acc_top2
assert compute_acc(pred_array, target_array)[0] == acc_top1
assert compute_acc(pred_array, target_array)[1] == acc_top2
with pytest.raises(AssertionError):
compute_acc(pred_tensor, 'other_type')
# test accuracy_numpy
compute_acc = partial(accuracy_numpy, topk=(1, 2))
assert compute_acc(pred_array, target_array)[0] == acc_top1
assert compute_acc(pred_array, target_array)[1] == acc_top2

View File

@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
from types import MethodType
from unittest import TestCase
import torch
from mmcls.models import InceptionV3
from mmcls.models.backbones.inception_v3 import InceptionAux
class TestInceptionV3(TestCase):
DEFAULT_ARGS = dict(num_classes=10, aux_logits=False, dropout=0.)
def test_structure(self):
# Test without auxiliary branch.
model = InceptionV3(**self.DEFAULT_ARGS)
self.assertIsNone(model.AuxLogits)
# Test with auxiliary branch.
cfg = {**self.DEFAULT_ARGS, 'aux_logits': True}
model = InceptionV3(**cfg)
self.assertIsInstance(model.AuxLogits, InceptionAux)
def test_init_weights(self):
cfg = {**self.DEFAULT_ARGS, 'aux_logits': True}
model = InceptionV3(**cfg)
init_info = {}
def get_init_info(self, *args):
for name, param in self.named_parameters():
init_info[name] = ''.join(
self._params_init_info[param]['init_info'])
model._dump_init_info = MethodType(get_init_info, model)
model.init_weights()
self.assertIn('TruncNormalInit: a=-2, b=2, mean=0, std=0.1, bias=0',
init_info['Conv2d_1a_3x3.conv.weight'])
self.assertIn('TruncNormalInit: a=-2, b=2, mean=0, std=0.01, bias=0',
init_info['AuxLogits.conv0.conv.weight'])
self.assertIn('TruncNormalInit: a=-2, b=2, mean=0, std=0.001, bias=0',
init_info['AuxLogits.fc.weight'])
def test_forward(self):
inputs = torch.rand(2, 3, 299, 299)
model = InceptionV3(**self.DEFAULT_ARGS)
aux_out, out = model(inputs)
self.assertIsNone(aux_out)
self.assertEqual(out.shape, (2, 10))
cfg = {**self.DEFAULT_ARGS, 'aux_logits': True}
model = InceptionV3(**cfg)
aux_out, out = model(inputs)
self.assertEqual(aux_out.shape, (2, 10))
self.assertEqual(out.shape, (2, 10))

View File

@ -60,7 +60,8 @@ def test_timm_backbone():
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 1
assert feat[0].shape == torch.Size((1, 192))
# Disable the test since TIMM's behavior changes between 0.5.4 and 0.5.5
# assert feat[0].shape == torch.Size((1, 197, 192))
def test_timm_backbone_features_only():

View File

@ -1,53 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from mmengine.logging import MMLogger
from mmcls.utils import get_root_logger, load_json_log
def test_get_root_logger():
# set all logger instance
MMLogger._instance_dict = {}
with tempfile.TemporaryDirectory() as tmpdirname:
log_path = osp.join(tmpdirname, 'test.log')
logger = get_root_logger(log_file=log_path)
message1 = 'adhsuadghj'
logger.info(message1)
logger2 = get_root_logger()
message2 = 'm,tkrgmkr'
logger2.info(message2)
with open(log_path, 'r') as f:
lines = f.readlines()
assert message1 in lines[0]
assert message2 in lines[1]
assert logger is logger2
handlers = list(logger.handlers)
for handler in handlers:
handler.close()
logger.removeHandler(handler)
def test_load_json_log():
log_path = 'tests/data/test.logjson'
log_dict = load_json_log(log_path)
# test log_dict
assert set(log_dict.keys()) == set([1, 2, 3])
# test epoch dict in log_dict
assert set(log_dict[1].keys()) == set(
['iter', 'lr', 'memory', 'data_time', 'time', 'mode'])
assert isinstance(log_dict[1]['lr'], list)
assert len(log_dict[1]['iter']) == 4
assert len(log_dict[1]['lr']) == 4
assert len(log_dict[2]['iter']) == 3
assert len(log_dict[2]['lr']) == 3
assert log_dict[3]['iter'] == [10, 20]
assert log_dict[3]['lr'] == [0.33305, 0.34759]

View File

@ -1,16 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import multiprocessing as mp
import os
import platform
import sys
from unittest import TestCase
import cv2
from mmcv import Config
from mmengine import DefaultScope
from mmcls.utils import register_all_modules, setup_multi_processes
from mmcls.utils import register_all_modules
class TestSetupEnv(TestCase):
@ -42,62 +37,3 @@ class TestSetupEnv(TestCase):
with self.assertWarnsRegex(
Warning, 'The current default scope "test" is not "mmcls"'):
register_all_modules(init_default_scope=True)
def test_setup_multi_processes():
# temp save system setting
sys_start_mehod = mp.get_start_method(allow_none=True)
sys_cv_threads = cv2.getNumThreads()
# pop and temp save system env vars
sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None)
sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None)
# test config without setting env
config = dict(data=dict(workers_per_gpu=2))
cfg = Config(config)
setup_multi_processes(cfg)
assert os.getenv('OMP_NUM_THREADS') == '1'
assert os.getenv('MKL_NUM_THREADS') == '1'
# when set to 0, the num threads will be 1
assert cv2.getNumThreads() == 1
if platform.system() != 'Windows':
assert mp.get_start_method() == 'fork'
# test num workers <= 1
os.environ.pop('OMP_NUM_THREADS')
os.environ.pop('MKL_NUM_THREADS')
config = dict(data=dict(workers_per_gpu=0))
cfg = Config(config)
setup_multi_processes(cfg)
assert 'OMP_NUM_THREADS' not in os.environ
assert 'MKL_NUM_THREADS' not in os.environ
# test manually set env var
os.environ['OMP_NUM_THREADS'] = '4'
config = dict(data=dict(workers_per_gpu=2))
cfg = Config(config)
setup_multi_processes(cfg)
assert os.getenv('OMP_NUM_THREADS') == '4'
# test manually set opencv threads and mp start method
config = dict(
data=dict(workers_per_gpu=2),
opencv_num_threads=4,
mp_start_method='spawn')
cfg = Config(config)
setup_multi_processes(cfg)
assert cv2.getNumThreads() == 4
assert mp.get_start_method() == 'spawn'
# revert setting to avoid affecting other programs
if sys_start_mehod:
mp.set_start_method(sys_start_mehod, force=True)
cv2.setNumThreads(sys_cv_threads)
if sys_omp_threads:
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
else:
os.environ.pop('OMP_NUM_THREADS')
if sys_mkl_threads:
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
else:
os.environ.pop('MKL_NUM_THREADS')

View File

@ -3,6 +3,7 @@ import argparse
import os
import os.path as osp
import mmengine
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
@ -17,6 +18,7 @@ def parse_args():
parser.add_argument(
'--work-dir',
help='the directory to save the file containing evaluation metrics')
parser.add_argument('--out', help='the file to save metric results.')
parser.add_argument(
'--dump',
type=str,
@ -118,6 +120,16 @@ def main():
# build the runner from config
runner = Runner.from_cfg(cfg)
if args.out:
class SaveMetricHook(mmengine.Hook):
def after_test_epoch(self, _, metrics=None):
if metrics is not None:
mmengine.dump(metrics, args.out)
runner.register_hook(SaveMetricHook(), 'LOWEST')
# start testing
runner.test()