[Improve] Update benchmark scripts (#1028)
* Update train benchmark scripts
* Add `--cfg-options` for dev scripts and enhance `--range`.
* Fix bug of regex expression.
* Fix minor bugs
* Update ShuffleNet configs
* Update rsb-a1 configs and label smooth loss mode.
* Update inference dev scripts
* Import fileio from `mmengine` instead of `mmcv`.
* Fix lint
* Update pre-commit hook
* Use `use_sigmoid` option instead of "bce" mode in label smooth loss.

parent e4e8047563
commit e9e2d48cb2
@@ -9,8 +9,10 @@ from typing import OrderedDict
import mmcv
import numpy as np
import torch
from mmengine import Config, MMLogger, Runner
from mmengine.dataset import Compose
from mmengine import Config, DictAction, MMLogger
from mmengine.dataset import Compose, default_collate
from mmengine.fileio import FileClient
from mmengine.runner import Runner
from modelindex.load_model_index import load
from rich.console import Console
from rich.table import Table
@@ -52,6 +54,16 @@ def parse_args():
'--flops-str',
action='store_true',
help='Output FLOPs and params counts in a string form.')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
args = parser.parse_args()
return args
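Note: `DictAction` is the mmengine argparse action that turns `key=value` pairs into a dict which `Config.merge_from_dict` folds back into the config. A minimal sketch of how the new `--cfg-options` flag is consumed (the option values here are made up):

    import argparse
    from mmengine import Config, DictAction

    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg-options', nargs='+', action=DictAction)
    args = parser.parse_args(
        ['--cfg-options', 'test_dataloader.batch_size=64', 'model.backbone.depth=101'])

    cfg = Config(dict(test_dataloader=dict(batch_size=32), model=dict(backbone=dict(depth=50))))
    cfg.merge_from_dict(args.cfg_options)  # dotted keys are merged as nested fields
    print(cfg.test_dataloader.batch_size)  # 64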
@@ -62,6 +74,8 @@ def inference(config_file, checkpoint, work_dir, args, exp_name):
cfg.load_from = checkpoint
cfg.log_level = 'WARN'
cfg.experiment_name = exp_name
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

# build the data pipeline
test_dataset = cfg.test_dataloader.dataset
@@ -72,7 +86,8 @@ def inference(config_file, checkpoint, work_dir, args, exp_name):
test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))

data = Compose(test_dataset.pipeline)({'img_path': args.img})
resolution = tuple(data['inputs'].shape[1:])
data = default_collate([data])
resolution = tuple(data['inputs'].shape[-2:])

runner: Runner = Runner.from_cfg(cfg)
model = runner.model
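Note: the sample is now collated into a batch before `val_step`, so the input resolution is read from the last two axes instead of everything after the channel axis. A standalone sketch of the shape change in plain PyTorch (mmengine's `default_collate` behaves the same way for plain tensors):

    import torch
    from torch.utils.data import default_collate

    sample = {'inputs': torch.zeros(3, 224, 224)}  # single CHW image from the pipeline
    batch = default_collate([sample])              # 'inputs' becomes shape (1, 3, 224, 224)
    print(tuple(sample['inputs'].shape[1:]))       # (224, 224)  old: from the unbatched sample
    print(tuple(batch['inputs'].shape[-2:]))       # (224, 224)  new: from the collated batch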
@@ -83,26 +98,30 @@ def inference(config_file, checkpoint, work_dir, args, exp_name):
if args.inference_time:
time_record = []
for _ in range(10):
model.val_step(data)  # warmup before profiling
torch.cuda.synchronize()
start = time()
model.val_step([data])
model.val_step(data)
torch.cuda.synchronize()
time_record.append((time() - start) * 1000)
result['time_mean'] = np.mean(time_record[1:-1])
result['time_std'] = np.std(time_record[1:-1])
else:
model.val_step([data])
model.val_step(data)

result['model'] = config_file.stem

if args.flops:
from mmcv.cnn.utils import get_model_complexity_info
from fvcore.nn import FlopCountAnalysis, parameter_count
from fvcore.nn.print_model_statistics import _format_size
_format_size = _format_size if args.flops_str else lambda x: x
with torch.no_grad():
if hasattr(model, 'extract_feat'):
model.forward = model.extract_feat
flops, params = get_model_complexity_info(
model,
input_shape=(3, ) + resolution,
print_per_layer_stat=False,
as_strings=args.flops_str)
model.to('cpu')
inputs = (torch.randn((1, 3, *resolution)), )
flops = _format_size(FlopCountAnalysis(model, inputs).total())
params = _format_size(parameter_count(model)[''])
result['flops'] = flops if args.flops_str else int(flops)
result['params'] = params if args.flops_str else int(params)
else:
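Note: FLOPs and parameter counting now go through fvcore instead of mmcv. A standalone sketch of the fvcore calls used above, with a stand-in model (the layer and numbers below are only an example):

    import torch
    from fvcore.nn import FlopCountAnalysis, parameter_count
    from fvcore.nn.print_model_statistics import _format_size

    model = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)  # stand-in for the classifier
    inputs = (torch.randn(1, 3, 224, 224), )

    flops = FlopCountAnalysis(model, inputs).total()  # raw integer count
    params = parameter_count(model)['']               # '' keys the whole model
    print(_format_size(flops), _format_size(params))  # human-readable, e.g. '10.838M' '0.224K'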
@@ -184,7 +203,6 @@ def main(args):
if args.checkpoint_root is not None:
root = args.checkpoint_root
if 's3://' in args.checkpoint_root:
from mmcv.fileio import FileClient
from petrel_client.common.exception import AccessDeniedError
file_client = FileClient.infer_client(uri=root)
checkpoint = file_client.join_path(
@@ -62,6 +62,12 @@ def parse_args():
action='store_true',
help='Summarize benchmark test results.')
parser.add_argument('--save', action='store_true', help='Save the summary')
parser.add_argument(
'--cfg-options',
nargs='+',
type=str,
default=[],
help='Config options for all config files.')

args = parser.parse_args()
return args
@@ -76,7 +82,7 @@ def create_test_job_batch(commands, model_info, args, port, script_name):

http_prefix = 'https://download.openmmlab.com/mmclassification/'
if 's3://' in args.checkpoint_root:
from mmcv.fileio import FileClient
from mmengine.fileio import FileClient
from petrel_client.common.exception import AccessDeniedError
file_client = FileClient.infer_client(uri=args.checkpoint_root)
checkpoint = file_client.join_path(
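Note: as the commit message says, fileio helpers now come from `mmengine` rather than `mmcv`. A small sketch of the two calls involved, assuming mmengine's `FileClient` keeps the mmcv-style interface (the bucket path is made up, and an s3 URI needs the petrel backend installed to actually resolve):

    from mmengine.fileio import FileClient

    root = 's3://my-bucket/checkpoints'              # hypothetical checkpoint root
    file_client = FileClient.infer_client(uri=root)  # picks a backend from the URI prefix
    checkpoint = file_client.join_path(root, 'resnet50_8xb32_in1k.pth')
    print(checkpoint)                                # s3://my-bucket/checkpoints/resnet50_8xb32_in1k.pth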
@@ -125,6 +131,7 @@ def create_test_job_batch(commands, model_info, args, port, script_name):
f'--work-dir={work_dir} '
f'--out={result_file} '
f'--cfg-option dist_params.port={port} '
f'{" ".join(args.cfg_options)} '
f'--launcher={launcher}\n')

with open(work_dir / 'job.sh', 'w') as f:
@@ -3,22 +3,48 @@ import json
import os
import os.path as osp
import re
from collections import OrderedDict
from datetime import datetime
from pathlib import Path
from zipfile import ZipFile

import yaml
from modelindex.load_model_index import load
from rich.console import Console
from rich.syntax import Syntax
from rich.table import Table

console = Console()
MMCLS_ROOT = Path(__file__).absolute().parents[2]
CYCLE_LEVELS = ['month', 'quarter', 'half-year', 'no-training']
METRICS_MAP = {
'Top 1 Accuracy': 'accuracy/top1',
'Top 5 Accuracy': 'accuracy/top5'
}


class RangeAction(argparse.Action):

def __call__(self, parser, namespace, values: str, option_string):
matches = re.match(r'([><=]*)([-\w]+)', values)
if matches is None:
raise ValueError(f'Unavailable range option {values}')
symbol, range_str = matches.groups()
assert range_str in CYCLE_LEVELS, \
f'{range_str} are not in {CYCLE_LEVELS}.'
level = CYCLE_LEVELS.index(range_str)
symbol = symbol or '<='
ranges = set()
if '=' in symbol:
ranges.add(level)
if '>' in symbol:
ranges.update(range(level + 1, len(CYCLE_LEVELS)))
if '<' in symbol:
ranges.update(range(level))
assert len(ranges) > 0, 'No range are selected.'
setattr(namespace, self.dest, ranges)


def parse_args():
parser = argparse.ArgumentParser(
description='Train models (in bench_train.yml) and compare accuracy.')
@@ -32,6 +58,14 @@ def parse_args():
parser.add_argument('--port', type=int, default=29666, help='dist port')
parser.add_argument(
'--models', nargs='+', type=str, help='Specify model names to run.')
parser.add_argument(
'--range',
type=str,
default={0},
action=RangeAction,
metavar='{month,quarter,half-year,no-training}',
help='The training benchmark range, "no-training" means all models '
"including those we haven't trained.")
parser.add_argument(
'--work-dir',
default='work_dirs/benchmark_train',
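Note: the enhanced `--range` accepts an optional comparison prefix and expands it into a set of cycle-level indices. A quick check of the semantics, assuming `RangeAction` and `CYCLE_LEVELS` from the hunk above are in scope:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--range', action=RangeAction, default={0})

    print(parser.parse_args(['--range', 'quarter']).range)     # {0, 1}    '<=' is the default symbol
    print(parser.parse_args(['--range', '>=quarter']).range)   # {1, 2, 3}
    print(parser.parse_args(['--range', '<half-year']).range)  # {0, 1}
    print(parser.parse_args([]).range)                         # {0}       only monthly models by default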
@@ -63,18 +97,33 @@ def parse_args():
'--save',
action='store_true',
help='Save the summary and archive log files.')
parser.add_argument(
'--cfg-options',
nargs='+',
type=str,
default=[],
help='Config options for all config files.')

args = parser.parse_args()
return args


def get_gpu_number(model_info):
config = osp.basename(model_info.config)
matches = re.match(r'.*[-_](\d+)xb(\d+).*', config)
if matches is None:
raise ValueError(
'Cannot get gpu numbers from the config name {config}')
gpus = int(matches.groups()[0])
return gpus


def create_train_job_batch(commands, model_info, args, port, script_name):

fname = model_info.name

assert 'Gpus' in model_info.data, \
f"Haven't specify gpu numbers for {fname}"
gpus = model_info.data['Gpus']
gpus = get_gpu_number(model_info)
gpus_per_node = min(gpus, 8)

config = Path(model_info.config)
assert config.exists(), f'"{fname}": {config} not found.'
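Note: `get_gpu_number` derives the GPU count from the `{gpus}xb{batch}` token in the config file name, replacing the hand-written `Gpus` field that used to live in bench_train.yml; `gpus_per_node = min(gpus, 8)` then maps that count onto the SBATCH resources below. A quick check of the regex on a few config names:

    import re

    def gpu_count(config_name):
        matches = re.match(r'.*[-_](\d+)xb(\d+).*', config_name)
        return int(matches.groups()[0])

    print(gpu_count('resnet50_8xb32_in1k.py'))               # 8  -> one node, gres=gpu:8
    print(gpu_count('swin-small_16xb64_in1k.py'))            # 16 -> 8 GPUs per node, ntasks=16
    print(gpu_count('vit-base-p16_pt-32xb128-mae_in1k.py'))  # 32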
@@ -101,15 +150,17 @@ def create_train_job_batch(commands, model_info, args, port, script_name):
f'#SBATCH --output {work_dir}/job.%j.out\n'
f'#SBATCH --partition={args.partition}\n'
f'#SBATCH --job-name {job_name}\n'
f'#SBATCH --gres=gpu:8\n'
f'#SBATCH --gres=gpu:{gpus_per_node}\n'
f'{mail_cfg}{quota_cfg}'
f'#SBATCH --ntasks-per-node=8\n'
f'#SBATCH --ntasks-per-node={gpus_per_node}\n'
f'#SBATCH --ntasks={gpus}\n'
f'#SBATCH --cpus-per-task=5\n\n'
f'{runner} -u {script_name} {config} '
f'--work-dir={work_dir} --cfg-option '
f'dist_params.port={port} '
f'checkpoint_config.max_keep_ckpts=10 '
f'env_cfg.dist_cfg.port={port} '
f'{" ".join(args.cfg_options)} '
f'default_hooks.checkpoint.max_keep_ckpts=2 '
f'default_hooks.checkpoint.save_best="auto" '
f'--launcher={launcher}\n')

with open(work_dir / 'job.sh', 'w') as f:
@@ -124,33 +175,16 @@ def create_train_job_batch(commands, model_info, args, port, script_name):
return work_dir / 'job.sh'


def train(args):
models_cfg = load(str(Path(__file__).parent / 'bench_train.yml'))
models_cfg.build_models_with_collections()
models = {model.name: model for model in models_cfg.models}

def train(models, args):
script_name = osp.join('tools', 'train.py')
port = args.port

commands = []
if args.models:
patterns = [re.compile(pattern) for pattern in args.models]
filter_models = {}
for k, v in models.items():
if any([re.match(pattern, k) for pattern in patterns]):
filter_models[k] = v
if len(filter_models) == 0:
print('No model found, please specify models in:')
print('\n'.join(models.keys()))
return
models = filter_models

for model_info in models.values():
months = model_info.data.get('Months', range(1, 13))
if datetime.now().month in months:
script_path = create_train_job_batch(commands, model_info, args,
port, script_name)
port += 1
script_path = create_train_job_batch(commands, model_info, args, port,
script_name)
port += 1

command_str = '\n'.join(commands)
@@ -245,12 +279,14 @@ def show_summary(summary_data):
metric = summary[metric_key]
expect = metric['expect']
last = metric['last']
last_epoch = metric['last_epoch']
last_color = set_color(last, expect)
best = metric['best']
best_color = set_color(best, expect)
best_epoch = metric['best_epoch']
row.append(f'{expect:.2f}')
row.append(f'[{last_color}]{last:.2f}[/{last_color}]')
row.append(
f'[{last_color}]{last:.2f}[/{last_color}] ({last_epoch})')
row.append(
f'[{best_color}]{best:.2f}[/{best_color}] ({best_epoch})')
table.add_row(*row)
@@ -258,25 +294,11 @@ def show_summary(summary_data):
console.print(table)


def summary(args):
models_cfg = load(str(Path(__file__).parent / 'bench_train.yml'))
models = {model.name: model for model in models_cfg.models}
def summary(models, args):

work_dir = Path(args.work_dir)
dir_map = {p.name: p for p in work_dir.iterdir() if p.is_dir()}

if args.models:
patterns = [re.compile(pattern) for pattern in args.models]
filter_models = {}
for k, v in models.items():
if any([re.match(pattern, k) for pattern in patterns]):
filter_models[k] = v
if len(filter_models) == 0:
print('No model found, please specify models in:')
print('\n'.join(models.keys()))
return
models = filter_models

summary_data = {}
for model_name, model_info in models.items():
@@ -287,17 +309,19 @@ def summary(args):

# Skip if not found any vis_data folder.
sub_dir = dir_map[model_name]
vis_folders = [d for d in sub_dir.iterdir() if d.is_dir()]
if len(vis_folders) == 0:
continue
log_file = sorted(vis_folders)[-1] / 'vis_data' / 'scalars.json'
if not log_file.exists():
log_files = [f for f in sub_dir.glob('*/vis_data/scalars.json')]
if len(log_files) == 0:
continue
log_file = sorted(log_files)[-1]

# parse train log
with open(log_file) as f:
json_logs = [json.loads(s) for s in f.readlines()]
val_logs = [log for log in json_logs if 'loss' not in log]
val_logs = [
log for log in json_logs
# TODO: need a better method to extract validate log
if 'loss' not in log and 'accuracy/top1' in log
]

if len(val_logs) == 0:
continue
@@ -320,9 +344,10 @@ def summary(args):
summary[key_yml] = dict(
expect=expect_result,
last=last,
last_epoch=len(val_logs),
best=best,
best_epoch=best_epoch)
summary_data[model_name] = summary
best_epoch=best_epoch + 1)
summary_data[model_name].update(summary)

show_summary(summary_data)
if args.save:
@@ -332,10 +357,39 @@ def summary(args):
def main():
args = parse_args()

model_index_file = MMCLS_ROOT / 'model-index.yml'
model_index = load(str(model_index_file))
model_index.build_models_with_collections()
all_models = {model.name: model for model in model_index.models}

with open(Path(__file__).parent / 'bench_train.yml', 'r') as f:
train_items = yaml.safe_load(f)
models = OrderedDict()
for item in train_items:
name = item['Name']
model_info = all_models[name]
model_info.cycle = item.get('Cycle', None)
cycle = getattr(model_info, 'cycle', 'month')
cycle_level = CYCLE_LEVELS.index(cycle)
if cycle_level in args.range:
models[name] = model_info

if args.models:
patterns = [re.compile(pattern) for pattern in args.models]
filter_models = {}
for k, v in models.items():
if any([re.match(pattern, k) for pattern in patterns]):
filter_models[k] = v
if len(filter_models) == 0:
print('No model found, please specify models in:')
print('\n'.join(models.keys()))
return
models = filter_models

if args.summary:
summary(args)
summary(models, args)
else:
train(args)
train(models, args)


if __name__ == '__main__':
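Note: `main()` now joins the slimmed-down bench_train.yml entries with model-index.yml and keeps only the models whose `Cycle` level falls inside `--range`. A small sketch of that filter, using entry names that appear in the new bench_train.yml below:

    CYCLE_LEVELS = ['month', 'quarter', 'half-year', 'no-training']

    train_items = [
        {'Name': 'resnet50_8xb32_in1k', 'Cycle': 'month'},
        {'Name': 'vgg16_8xb32_in1k', 'Cycle': 'quarter'},
        {'Name': 'van-small_8xb128_in1k', 'Cycle': 'no-training'},
    ]
    args_range = {0, 1}  # e.g. produced by --range '<=quarter'

    selected = [item['Name'] for item in train_items
                if CYCLE_LEVELS.index(item.get('Cycle', 'month')) in args_range]
    print(selected)  # ['resnet50_8xb32_in1k', 'vgg16_8xb32_in1k']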
@@ -1,88 +1,86 @@
Models:
- Name: resnet50
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 76.55
Top 5 Accuracy: 93.06
Config: configs/resnet/resnet50_8xb32_in1k.py
Gpus: 8
- Name: mobilenet-v2_8xb32_in1k
Cycle: month

- Name: seresnet50
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 77.74
Top 5 Accuracy: 93.84
Config: configs/seresnet/seresnet50_8xb32_in1k.py
Gpus: 8
- Name: resnet50_8xb32_in1k
Cycle: month

- Name: vit-base
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.37
Top 5 Accuracy: 96.15
Config: configs/vision_transformer/vit-base-p16_pt-32xb128-mae_in1k-224.py
Gpus: 32
- Name: seresnet50_8xb32_in1k
Cycle: month

- Name: mobilenetv2
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 71.86
Top 5 Accuracy: 90.42
Config: configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py
Gpus: 8
- Name: swin-small_16xb64_in1k
Cycle: month

- Name: swin_tiny
Results:
- Dataset: ImageNet
Metrics:
Top 1 Accuracy: 81.18
Top 5 Accuracy: 95.61
Weights: https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth
Config: configs/swin_transformer/swin-tiny_16xb64_in1k.py
Gpus: 16
- Name: vit-base-p16_pt-32xb128-mae_in1k
Cycle: month

- Name: vgg16
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 71.62
Top 5 Accuracy: 90.49
Config: configs/vgg/vgg16_8xb32_in1k.py
Gpus: 8
Months:
- 1
- 4
- 7
- 10
- Name: resnet50_8xb256-rsb-a1-600e_in1k
Cycle: quarter

- Name: shufflenet_v2
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 69.55
Top 5 Accuracy: 88.92
Config: configs/shufflenet_v2/shufflenet-v2-1x_16xb64_in1k.py
Gpus: 16
Months:
- 2
- 5
- 8
- 11
- Name: resnext50-32x4d_8xb32_in1k
Cycle: quarter

- Name: resnet-rsb
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 80.12
Top 5 Accuracy: 94.78
Config: configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py
Gpus: 8
Months:
- 3
- 6
- 9
- 12
- Name: shufflenet-v2-1x_16xb64_in1k
Cycle: quarter

- Name: vgg16_8xb32_in1k
Cycle: quarter

- Name: shufflenet-v1-1x_16xb64_in1k
Cycle: half-year

- Name: t2t-vit-t-14_8xb64_in1k
Cycle: half-year

- Name: regnetx-1.6gf_8xb128_in1k
Cycle: half-year

- Name: van-small_8xb128_in1k
Cycle: no-training

- Name: res2net50-w14-s8_3rdparty_8xb32_in1k
Cycle: no-training

- Name: repvgg-A2_3rdparty_4xb64-coslr-120e_in1k
Cycle: no-training

- Name: tnt-small-p16_3rdparty_in1k
Cycle: no-training

- Name: mlp-mixer-base-p16_3rdparty_64xb64_in1k
Cycle: no-training

- Name: conformer-small-p16_3rdparty_8xb128_in1k
Cycle: no-training

- Name: twins-pcpvt-base_3rdparty_8xb128_in1k
Cycle: no-training

- Name: efficientnet-b0_3rdparty_8xb32_in1k
Cycle: no-training

- Name: convnext-small_3rdparty_32xb128_in1k
Cycle: no-training

- Name: hrnet-w18_3rdparty_8xb32_in1k
Cycle: no-training

- Name: repmlp-base_3rdparty_8xb64_in1k
Cycle: no-training

- Name: wide-resnet50_3rdparty_8xb32_in1k
Cycle: no-training

- Name: cspresnet50_3rdparty_8xb32_in1k
Cycle: no-training

- Name: convmixer-768-32_10xb64_in1k
Cycle: no-training

- Name: densenet169_4xb256_in1k
Cycle: no-training

- Name: poolformer-s24_3rdparty_32xb128_in1k
Cycle: no-training

- Name: inception-v3_3rdparty_8xb32_in1k
Cycle: no-training
@@ -44,10 +44,10 @@ repos:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
- repo: https://github.com/open-mmlab/pre-commit-hooks
rev: v0.2.0
rev: v0.4.0
hooks:
- id: check-copyright
args: ["mmcls", "tests", "demo", "tools"]
args: ["mmcls", "tests", "demo", "tools", "--excludes", "mmcls/.mim/", "--ignore-file-not-found-error"]
# - repo: local
# hooks:
# - id: clang-format
@@ -6,14 +6,8 @@ optim_wrapper = dict(

# learning policy
param_scheduler = [
dict(
type='ConstantLR',
factor=0.1,
by_epoch=True,
begin=0,
end=5,
convert_to_iter_based=True),
dict(type='PolyLR', eta_min=0, by_epoch=True, begin=5, end=300)
dict(type='ConstantLR', factor=0.1, by_epoch=False, begin=0, end=5000),
dict(type='PolyLR', eta_min=0, by_epoch=False, begin=5000)
]

# train, val, test setting
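Note: the ShuffleNet schedule is now written in iteration units (`by_epoch=False`), so the constant warmup runs for a fixed 5000 iterations and the `PolyLR` stage with no `end` decays for the rest of training. The epoch equivalence is only approximate and depends on the dataloader size; a sketch of the new fragment with the units spelled out:

    param_scheduler = [
        dict(type='ConstantLR', factor=0.1, by_epoch=False, begin=0, end=5000),  # warmup, counted in iterations
        dict(type='PolyLR', eta_min=0, by_epoch=False, begin=5000),              # no `end`: decay until training ends
    ]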
@@ -16,6 +16,7 @@ model = dict(
type='LabelSmoothLoss',
label_smooth_val=0.1,
mode='original',
use_sigmoid=True,
)),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.2, num_classes=1000),
@@ -24,37 +24,42 @@ class LabelSmoothLoss(nn.Module):
label_smooth_val (float): The degree of label smoothing.
num_classes (int, optional): Number of classes. Defaults to None.
mode (str): Refers to notes, Options are 'original', 'classy_vision',
'multi_label'. Defaults to 'original'
'multi_label'. Defaults to 'original'.
use_sigmoid (bool, optional): Whether the prediction uses sigmoid of
softmax. Defaults to None, which means to use sigmoid in
"multi_label" mode and not use in other modes.
reduction (str): The method used to reduce the loss.
Options are "none", "mean" and "sum". Defaults to 'mean'.
loss_weight (float): Weight of the loss. Defaults to 1.0.

Notes:
if the mode is "original", this will use the same label smooth method
as the original paper as:
- if the mode is **"original"**, this will use the same label smooth
method as the original paper as:

.. math::
(1-\epsilon)\delta_{k, y} + \frac{\epsilon}{K}
.. math::
(1-\epsilon)\delta_{k, y} + \frac{\epsilon}{K}

where epsilon is the `label_smooth_val`, K is the num_classes and
delta(k,y) is Dirac delta, which equals 1 for k=y and 0 otherwise.
where :math:`\epsilon` is the ``label_smooth_val``, :math:`K` is the
``num_classes`` and :math:`\delta_{k, y}` is Dirac delta, which
equals 1 for :math:`k=y` and 0 otherwise.

if the mode is "classy_vision", this will use the same label smooth
method as the facebookresearch/ClassyVision repo as:
- if the mode is **"classy_vision"**, this will use the same label
smooth method as the facebookresearch/ClassyVision repo as:

.. math::
\frac{\delta_{k, y} + \epsilon/K}{1+\epsilon}
.. math::
\frac{\delta_{k, y} + \epsilon/K}{1+\epsilon}

if the mode is "multi_label", this will accept labels from multi-label
task and smoothing them as:
- if the mode is **"multi_label"**, this will accept labels from
multi-label task and smoothing them as:

.. math::
(1-2\epsilon)\delta_{k, y} + \epsilon
.. math::
(1-2\epsilon)\delta_{k, y} + \epsilon
"""

def __init__(self,
label_smooth_val,
num_classes=None,
use_sigmoid=None,
mode='original',
reduction='mean',
loss_weight=1.0):
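Note: a worked instance of the "original" rule above, independent of the mmcls implementation: with K = 5 classes and `label_smooth_val` = 0.1, a one-hot target becomes:

    import torch

    eps, K = 0.1, 5
    onehot = torch.tensor([0., 0., 1., 0., 0.])
    smoothed = (1 - eps) * onehot + eps / K
    print(smoothed)  # tensor([0.0200, 0.0200, 0.9200, 0.0200, 0.0200])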
@@ -82,12 +87,21 @@ class LabelSmoothLoss(nn.Module):
self._eps = label_smooth_val
if mode == 'classy_vision':
self._eps = label_smooth_val / (1 + label_smooth_val)

if mode == 'multi_label':
self.ce = CrossEntropyLoss(use_sigmoid=True)
if not use_sigmoid:
from mmengine.logging import MMLogger
MMLogger.get_current_instance().warning(
'For multi-label tasks, please set `use_sigmoid=True` '
'to use binary cross entropy.')
self.smooth_label = self.multilabel_smooth_label
use_sigmoid = True if use_sigmoid is None else use_sigmoid
else:
self.ce = CrossEntropyLoss(use_soft=True)
self.smooth_label = self.original_smooth_label
use_sigmoid = False if use_sigmoid is None else use_sigmoid

self.ce = CrossEntropyLoss(
use_sigmoid=use_sigmoid, use_soft=not use_sigmoid)

def generate_one_hot_like_label(self, label):
"""This function takes one-hot or index label vectors and computes one-
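Note: with `use_sigmoid=True` the smoothed target goes through binary cross-entropy (what the rsb-a1 configs previously requested via a "bce"-style mode), otherwise through a soft cross-entropy. A standalone sketch of the two reductions in plain PyTorch, not the mmcls implementation itself:

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[2.0, -1.0, 0.5]])
    target = torch.tensor([[0.9333, 0.0333, 0.0333]])  # one-hot smoothed with eps=0.1, K=3

    soft_ce = (-target * F.log_softmax(logits, dim=1)).sum(dim=1).mean()  # use_sigmoid=False path
    bce = F.binary_cross_entropy_with_logits(logits, target)              # use_sigmoid=True path
    print(soft_ce.item(), bce.item())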
@@ -247,6 +247,17 @@ def test_label_smooth_loss():
correct = 0.2269 # from timm
assert loss(cls_score, label) - correct <= 0.0001

loss_cfg = dict(
type='LabelSmoothLoss',
label_smooth_val=0.1,
mode='original',
use_sigmoid=True,
reduction='mean',
loss_weight=1.0)
loss = build_loss(loss_cfg)
correct = 0.3633 # from timm
assert loss(cls_score, label) - correct <= 0.0001

# test classy_vision mode label smooth loss
loss_cfg = dict(
type='LabelSmoothLoss',