From e9e2d48cb2265c03a7ed4912a5fca1925ab9955c Mon Sep 17 00:00:00 2001 From: Ma Zerun Date: Tue, 20 Sep 2022 15:50:21 +0800 Subject: [PATCH] [Improve] Update benchmark scripts (#1028) * Update train benchmark scripts * Add `--cfg-options` for dev scripts and enhance `--range`. * Fix bug of regex expression. * Fix minor bugs * Update ShuffleNet configs * Update rsb-a1 configs and label smooth loss mode. * Update inference dev scripts * From `mmengine` instead of `mmcv` import fileio. * Fix lint * Update pre-commit hook * Use `use_sigmoid` option instead of "bce" mode in label smooth loss. --- .../benchmark_regression/1-benchmark_valid.py | 42 +++-- .../benchmark_regression/2-benchmark_test.py | 9 +- .../benchmark_regression/3-benchmark_train.py | 162 ++++++++++++------ .../benchmark_regression/bench_train.yml | 160 +++++++++-------- .pre-commit-config.yaml | 4 +- .../imagenet_bs1024_linearlr_bn_nowd.py | 10 +- .../resnet50_8xb256-rsb-a1-600e_in1k.py | 1 + mmcls/models/losses/label_smooth_loss.py | 48 ++++-- tests/test_models/test_losses.py | 11 ++ 9 files changed, 272 insertions(+), 175 deletions(-) diff --git a/.dev_scripts/benchmark_regression/1-benchmark_valid.py b/.dev_scripts/benchmark_regression/1-benchmark_valid.py index a9e33c64..fed85242 100644 --- a/.dev_scripts/benchmark_regression/1-benchmark_valid.py +++ b/.dev_scripts/benchmark_regression/1-benchmark_valid.py @@ -9,8 +9,10 @@ from typing import OrderedDict import mmcv import numpy as np import torch -from mmengine import Config, MMLogger, Runner -from mmengine.dataset import Compose +from mmengine import Config, DictAction, MMLogger +from mmengine.dataset import Compose, default_collate +from mmengine.fileio import FileClient +from mmengine.runner import Runner from modelindex.load_model_index import load from rich.console import Console from rich.table import Table @@ -52,6 +54,16 @@ def parse_args(): '--flops-str', action='store_true', help='Output FLOPs and params counts in a string form.') + 
parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') args = parser.parse_args() return args @@ -62,6 +74,8 @@ def inference(config_file, checkpoint, work_dir, args, exp_name): cfg.load_from = checkpoint cfg.log_level = 'WARN' cfg.experiment_name = exp_name + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) # build the data pipeline test_dataset = cfg.test_dataloader.dataset @@ -72,7 +86,8 @@ def inference(config_file, checkpoint, work_dir, args, exp_name): test_dataset.pipeline.insert(1, dict(type='Resize', scale=32)) data = Compose(test_dataset.pipeline)({'img_path': args.img}) - resolution = tuple(data['inputs'].shape[1:]) + data = default_collate([data]) + resolution = tuple(data['inputs'].shape[-2:]) runner: Runner = Runner.from_cfg(cfg) model = runner.model @@ -83,26 +98,30 @@ def inference(config_file, checkpoint, work_dir, args, exp_name): if args.inference_time: time_record = [] for _ in range(10): + model.val_step(data) # warmup before profiling + torch.cuda.synchronize() start = time() - model.val_step([data]) + model.val_step(data) + torch.cuda.synchronize() time_record.append((time() - start) * 1000) result['time_mean'] = np.mean(time_record[1:-1]) result['time_std'] = np.std(time_record[1:-1]) else: - model.val_step([data]) + model.val_step(data) result['model'] = config_file.stem if args.flops: - from mmcv.cnn.utils import get_model_complexity_info + from fvcore.nn import FlopCountAnalysis, parameter_count + from fvcore.nn.print_model_statistics import _format_size + _format_size = _format_size if args.flops_str else lambda x: x with 
torch.no_grad(): if hasattr(model, 'extract_feat'): model.forward = model.extract_feat - flops, params = get_model_complexity_info( - model, - input_shape=(3, ) + resolution, - print_per_layer_stat=False, - as_strings=args.flops_str) + model.to('cpu') + inputs = (torch.randn((1, 3, *resolution)), ) + flops = _format_size(FlopCountAnalysis(model, inputs).total()) + params = _format_size(parameter_count(model)['']) result['flops'] = flops if args.flops_str else int(flops) result['params'] = params if args.flops_str else int(params) else: @@ -184,7 +203,6 @@ def main(args): if args.checkpoint_root is not None: root = args.checkpoint_root if 's3://' in args.checkpoint_root: - from mmcv.fileio import FileClient from petrel_client.common.exception import AccessDeniedError file_client = FileClient.infer_client(uri=root) checkpoint = file_client.join_path( diff --git a/.dev_scripts/benchmark_regression/2-benchmark_test.py b/.dev_scripts/benchmark_regression/2-benchmark_test.py index 380e2519..cd49bd09 100644 --- a/.dev_scripts/benchmark_regression/2-benchmark_test.py +++ b/.dev_scripts/benchmark_regression/2-benchmark_test.py @@ -62,6 +62,12 @@ def parse_args(): action='store_true', help='Summarize benchmark test results.') parser.add_argument('--save', action='store_true', help='Save the summary') + parser.add_argument( + '--cfg-options', + nargs='+', + type=str, + default=[], + help='Config options for all config files.') args = parser.parse_args() return args @@ -76,7 +82,7 @@ def create_test_job_batch(commands, model_info, args, port, script_name): http_prefix = 'https://download.openmmlab.com/mmclassification/' if 's3://' in args.checkpoint_root: - from mmcv.fileio import FileClient + from mmengine.fileio import FileClient from petrel_client.common.exception import AccessDeniedError file_client = FileClient.infer_client(uri=args.checkpoint_root) checkpoint = file_client.join_path( @@ -125,6 +131,7 @@ def create_test_job_batch(commands, model_info, args, port, 
script_name): f'--work-dir={work_dir} ' f'--out={result_file} ' f'--cfg-option dist_params.port={port} ' + f'{" ".join(args.cfg_options)} ' f'--launcher={launcher}\n') with open(work_dir / 'job.sh', 'w') as f: diff --git a/.dev_scripts/benchmark_regression/3-benchmark_train.py b/.dev_scripts/benchmark_regression/3-benchmark_train.py index 6d78be50..9e240cc5 100644 --- a/.dev_scripts/benchmark_regression/3-benchmark_train.py +++ b/.dev_scripts/benchmark_regression/3-benchmark_train.py @@ -3,22 +3,48 @@ import json import os import os.path as osp import re +from collections import OrderedDict from datetime import datetime from pathlib import Path from zipfile import ZipFile +import yaml from modelindex.load_model_index import load from rich.console import Console from rich.syntax import Syntax from rich.table import Table console = Console() +MMCLS_ROOT = Path(__file__).absolute().parents[2] +CYCLE_LEVELS = ['month', 'quarter', 'half-year', 'no-training'] METRICS_MAP = { 'Top 1 Accuracy': 'accuracy/top1', 'Top 5 Accuracy': 'accuracy/top5' } +class RangeAction(argparse.Action): + + def __call__(self, parser, namespace, values: str, option_string): + matches = re.match(r'([><=]*)([-\w]+)', values) + if matches is None: + raise ValueError(f'Unavailable range option {values}') + symbol, range_str = matches.groups() + assert range_str in CYCLE_LEVELS, \ + f'{range_str} are not in {CYCLE_LEVELS}.' + level = CYCLE_LEVELS.index(range_str) + symbol = symbol or '<=' + ranges = set() + if '=' in symbol: + ranges.add(level) + if '>' in symbol: + ranges.update(range(level + 1, len(CYCLE_LEVELS))) + if '<' in symbol: + ranges.update(range(level)) + assert len(ranges) > 0, 'No range are selected.' 
+ setattr(namespace, self.dest, ranges) + + def parse_args(): parser = argparse.ArgumentParser( description='Train models (in bench_train.yml) and compare accuracy.') @@ -32,6 +58,14 @@ def parse_args(): parser.add_argument('--port', type=int, default=29666, help='dist port') parser.add_argument( '--models', nargs='+', type=str, help='Specify model names to run.') + parser.add_argument( + '--range', + type=str, + default={0}, + action=RangeAction, + metavar='{month,quarter,half-year,no-training}', + help='The training benchmark range, "no-training" means all models ' + "including those we haven't trained.") parser.add_argument( '--work-dir', default='work_dirs/benchmark_train', @@ -63,18 +97,33 @@ def parse_args(): '--save', action='store_true', help='Save the summary and archive log files.') + parser.add_argument( + '--cfg-options', + nargs='+', + type=str, + default=[], + help='Config options for all config files.') args = parser.parse_args() return args +def get_gpu_number(model_info): + config = osp.basename(model_info.config) + matches = re.match(r'.*[-_](\d+)xb(\d+).*', config) + if matches is None: + raise ValueError( + 'Cannot get gpu numbers from the config name {config}') + gpus = int(matches.groups()[0]) + return gpus + + def create_train_job_batch(commands, model_info, args, port, script_name): fname = model_info.name - assert 'Gpus' in model_info.data, \ - f"Haven't specify gpu numbers for {fname}" - gpus = model_info.data['Gpus'] + gpus = get_gpu_number(model_info) + gpus_per_node = min(gpus, 8) config = Path(model_info.config) assert config.exists(), f'"{fname}": {config} not found.' 
@@ -101,15 +150,17 @@ def create_train_job_batch(commands, model_info, args, port, script_name): f'#SBATCH --output {work_dir}/job.%j.out\n' f'#SBATCH --partition={args.partition}\n' f'#SBATCH --job-name {job_name}\n' - f'#SBATCH --gres=gpu:8\n' + f'#SBATCH --gres=gpu:{gpus_per_node}\n' f'{mail_cfg}{quota_cfg}' - f'#SBATCH --ntasks-per-node=8\n' + f'#SBATCH --ntasks-per-node={gpus_per_node}\n' f'#SBATCH --ntasks={gpus}\n' f'#SBATCH --cpus-per-task=5\n\n' f'{runner} -u {script_name} {config} ' f'--work-dir={work_dir} --cfg-option ' - f'dist_params.port={port} ' - f'checkpoint_config.max_keep_ckpts=10 ' + f'env_cfg.dist_cfg.port={port} ' + f'{" ".join(args.cfg_options)} ' + f'default_hooks.checkpoint.max_keep_ckpts=2 ' + f'default_hooks.checkpoint.save_best="auto" ' f'--launcher={launcher}\n') with open(work_dir / 'job.sh', 'w') as f: @@ -124,33 +175,16 @@ def create_train_job_batch(commands, model_info, args, port, script_name): return work_dir / 'job.sh' -def train(args): - models_cfg = load(str(Path(__file__).parent / 'bench_train.yml')) - models_cfg.build_models_with_collections() - models = {model.name: model for model in models_cfg.models} - +def train(models, args): script_name = osp.join('tools', 'train.py') port = args.port commands = [] - if args.models: - patterns = [re.compile(pattern) for pattern in args.models] - filter_models = {} - for k, v in models.items(): - if any([re.match(pattern, k) for pattern in patterns]): - filter_models[k] = v - if len(filter_models) == 0: - print('No model found, please specify models in:') - print('\n'.join(models.keys())) - return - models = filter_models for model_info in models.values(): - months = model_info.data.get('Months', range(1, 13)) - if datetime.now().month in months: - script_path = create_train_job_batch(commands, model_info, args, - port, script_name) - port += 1 + script_path = create_train_job_batch(commands, model_info, args, port, + script_name) + port += 1 command_str = '\n'.join(commands) @@ -245,12 
+279,14 @@ def show_summary(summary_data): metric = summary[metric_key] expect = metric['expect'] last = metric['last'] + last_epoch = metric['last_epoch'] last_color = set_color(last, expect) best = metric['best'] best_color = set_color(best, expect) best_epoch = metric['best_epoch'] row.append(f'{expect:.2f}') - row.append(f'[{last_color}]{last:.2f}[/{last_color}]') + row.append( + f'[{last_color}]{last:.2f}[/{last_color}] ({last_epoch})') row.append( f'[{best_color}]{best:.2f}[/{best_color}] ({best_epoch})') table.add_row(*row) @@ -258,25 +294,11 @@ def show_summary(summary_data): console.print(table) -def summary(args): - models_cfg = load(str(Path(__file__).parent / 'bench_train.yml')) - models = {model.name: model for model in models_cfg.models} +def summary(models, args): work_dir = Path(args.work_dir) dir_map = {p.name: p for p in work_dir.iterdir() if p.is_dir()} - if args.models: - patterns = [re.compile(pattern) for pattern in args.models] - filter_models = {} - for k, v in models.items(): - if any([re.match(pattern, k) for pattern in patterns]): - filter_models[k] = v - if len(filter_models) == 0: - print('No model found, please specify models in:') - print('\n'.join(models.keys())) - return - models = filter_models - summary_data = {} for model_name, model_info in models.items(): @@ -287,17 +309,19 @@ def summary(args): # Skip if not found any vis_data folder. 
sub_dir = dir_map[model_name] - vis_folders = [d for d in sub_dir.iterdir() if d.is_dir()] - if len(vis_folders) == 0: - continue - log_file = sorted(vis_folders)[-1] / 'vis_data' / 'scalars.json' - if not log_file.exists(): + log_files = [f for f in sub_dir.glob('*/vis_data/scalars.json')] + if len(log_files) == 0: continue + log_file = sorted(log_files)[-1] # parse train log with open(log_file) as f: json_logs = [json.loads(s) for s in f.readlines()] - val_logs = [log for log in json_logs if 'loss' not in log] + val_logs = [ + log for log in json_logs + # TODO: need a better method to extract validate log + if 'loss' not in log and 'accuracy/top1' in log + ] if len(val_logs) == 0: continue @@ -320,9 +344,10 @@ def summary(args): summary[key_yml] = dict( expect=expect_result, last=last, + last_epoch=len(val_logs), best=best, - best_epoch=best_epoch) - summary_data[model_name] = summary + best_epoch=best_epoch + 1) + summary_data[model_name].update(summary) show_summary(summary_data) if args.save: @@ -332,10 +357,39 @@ def summary(args): def main(): args = parse_args() + model_index_file = MMCLS_ROOT / 'model-index.yml' + model_index = load(str(model_index_file)) + model_index.build_models_with_collections() + all_models = {model.name: model for model in model_index.models} + + with open(Path(__file__).parent / 'bench_train.yml', 'r') as f: + train_items = yaml.safe_load(f) + models = OrderedDict() + for item in train_items: + name = item['Name'] + model_info = all_models[name] + model_info.cycle = item.get('Cycle', None) + cycle = getattr(model_info, 'cycle', 'month') + cycle_level = CYCLE_LEVELS.index(cycle) + if cycle_level in args.range: + models[name] = model_info + + if args.models: + patterns = [re.compile(pattern) for pattern in args.models] + filter_models = {} + for k, v in models.items(): + if any([re.match(pattern, k) for pattern in patterns]): + filter_models[k] = v + if len(filter_models) == 0: + print('No model found, please specify models in:') + 
print('\n'.join(models.keys())) + return + models = filter_models + if args.summary: - summary(args) + summary(models, args) else: - train(args) + train(models, args) if __name__ == '__main__': diff --git a/.dev_scripts/benchmark_regression/bench_train.yml b/.dev_scripts/benchmark_regression/bench_train.yml index c7326849..9f6e11eb 100644 --- a/.dev_scripts/benchmark_regression/bench_train.yml +++ b/.dev_scripts/benchmark_regression/bench_train.yml @@ -1,88 +1,86 @@ -Models: - - Name: resnet50 - Results: - - Dataset: ImageNet-1k - Metrics: - Top 1 Accuracy: 76.55 - Top 5 Accuracy: 93.06 - Config: configs/resnet/resnet50_8xb32_in1k.py - Gpus: 8 +- Name: mobilenet-v2_8xb32_in1k + Cycle: month - - Name: seresnet50 - Results: - - Dataset: ImageNet-1k - Metrics: - Top 1 Accuracy: 77.74 - Top 5 Accuracy: 93.84 - Config: configs/seresnet/seresnet50_8xb32_in1k.py - Gpus: 8 +- Name: resnet50_8xb32_in1k + Cycle: month - - Name: vit-base - Results: - - Dataset: ImageNet-1k - Metrics: - Top 1 Accuracy: 82.37 - Top 5 Accuracy: 96.15 - Config: configs/vision_transformer/vit-base-p16_pt-32xb128-mae_in1k-224.py - Gpus: 32 +- Name: seresnet50_8xb32_in1k + Cycle: month - - Name: mobilenetv2 - Results: - - Dataset: ImageNet-1k - Metrics: - Top 1 Accuracy: 71.86 - Top 5 Accuracy: 90.42 - Config: configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py - Gpus: 8 +- Name: swin-small_16xb64_in1k + Cycle: month - - Name: swin_tiny - Results: - - Dataset: ImageNet - Metrics: - Top 1 Accuracy: 81.18 - Top 5 Accuracy: 95.61 - Weights: https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth - Config: configs/swin_transformer/swin-tiny_16xb64_in1k.py - Gpus: 16 +- Name: vit-base-p16_pt-32xb128-mae_in1k + Cycle: month - - Name: vgg16 - Results: - - Dataset: ImageNet-1k - Metrics: - Top 1 Accuracy: 71.62 - Top 5 Accuracy: 90.49 - Config: configs/vgg/vgg16_8xb32_in1k.py - Gpus: 8 - Months: - - 1 - - 4 - - 7 - - 10 +- Name: 
resnet50_8xb256-rsb-a1-600e_in1k + Cycle: quarter - - Name: shufflenet_v2 - Results: - - Dataset: ImageNet-1k - Metrics: - Top 1 Accuracy: 69.55 - Top 5 Accuracy: 88.92 - Config: configs/shufflenet_v2/shufflenet-v2-1x_16xb64_in1k.py - Gpus: 16 - Months: - - 2 - - 5 - - 8 - - 11 +- Name: resnext50-32x4d_8xb32_in1k + Cycle: quarter - - Name: resnet-rsb - Results: - - Dataset: ImageNet-1k - Metrics: - Top 1 Accuracy: 80.12 - Top 5 Accuracy: 94.78 - Config: configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py - Gpus: 8 - Months: - - 3 - - 6 - - 9 - - 12 +- Name: shufflenet-v2-1x_16xb64_in1k + Cycle: quarter + +- Name: vgg16_8xb32_in1k + Cycle: quarter + +- Name: shufflenet-v1-1x_16xb64_in1k + Cycle: half-year + +- Name: t2t-vit-t-14_8xb64_in1k + Cycle: half-year + +- Name: regnetx-1.6gf_8xb128_in1k + Cycle: half-year + +- Name: van-small_8xb128_in1k + Cycle: no-training + +- Name: res2net50-w14-s8_3rdparty_8xb32_in1k + Cycle: no-training + +- Name: repvgg-A2_3rdparty_4xb64-coslr-120e_in1k + Cycle: no-training + +- Name: tnt-small-p16_3rdparty_in1k + Cycle: no-training + +- Name: mlp-mixer-base-p16_3rdparty_64xb64_in1k + Cycle: no-training + +- Name: conformer-small-p16_3rdparty_8xb128_in1k + Cycle: no-training + +- Name: twins-pcpvt-base_3rdparty_8xb128_in1k + Cycle: no-training + +- Name: efficientnet-b0_3rdparty_8xb32_in1k + Cycle: no-training + +- Name: convnext-small_3rdparty_32xb128_in1k + Cycle: no-training + +- Name: hrnet-w18_3rdparty_8xb32_in1k + Cycle: no-training + +- Name: repmlp-base_3rdparty_8xb64_in1k + Cycle: no-training + +- Name: wide-resnet50_3rdparty_8xb32_in1k + Cycle: no-training + +- Name: cspresnet50_3rdparty_8xb32_in1k + Cycle: no-training + +- Name: convmixer-768-32_10xb64_in1k + Cycle: no-training + +- Name: densenet169_4xb256_in1k + Cycle: no-training + +- Name: poolformer-s24_3rdparty_32xb128_in1k + Cycle: no-training + +- Name: inception-v3_3rdparty_8xb32_in1k + Cycle: no-training diff --git a/.pre-commit-config.yaml 
b/.pre-commit-config.yaml index 0d19d5f6..c55af0f2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,10 +44,10 @@ repos: - id: docformatter args: ["--in-place", "--wrap-descriptions", "79"] - repo: https://github.com/open-mmlab/pre-commit-hooks - rev: v0.2.0 + rev: v0.4.0 hooks: - id: check-copyright - args: ["mmcls", "tests", "demo", "tools"] + args: ["mmcls", "tests", "demo", "tools", "--excludes", "mmcls/.mim/", "--ignore-file-not-found-error"] # - repo: local # hooks: # - id: clang-format diff --git a/configs/_base_/schedules/imagenet_bs1024_linearlr_bn_nowd.py b/configs/_base_/schedules/imagenet_bs1024_linearlr_bn_nowd.py index 18b8554d..cf38d473 100644 --- a/configs/_base_/schedules/imagenet_bs1024_linearlr_bn_nowd.py +++ b/configs/_base_/schedules/imagenet_bs1024_linearlr_bn_nowd.py @@ -6,14 +6,8 @@ optim_wrapper = dict( # learning policy param_scheduler = [ - dict( - type='ConstantLR', - factor=0.1, - by_epoch=True, - begin=0, - end=5, - convert_to_iter_based=True), - dict(type='PolyLR', eta_min=0, by_epoch=True, begin=5, end=300) + dict(type='ConstantLR', factor=0.1, by_epoch=False, begin=0, end=5000), + dict(type='PolyLR', eta_min=0, by_epoch=False, begin=5000) ] # train, val, test setting diff --git a/configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py b/configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py index 3d2d5894..1c213127 100644 --- a/configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py +++ b/configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py @@ -16,6 +16,7 @@ model = dict( type='LabelSmoothLoss', label_smooth_val=0.1, mode='original', + use_sigmoid=True, )), train_cfg=dict(augments=[ dict(type='Mixup', alpha=0.2, num_classes=1000), diff --git a/mmcls/models/losses/label_smooth_loss.py b/mmcls/models/losses/label_smooth_loss.py index 99e50a77..b6e48c09 100644 --- a/mmcls/models/losses/label_smooth_loss.py +++ b/mmcls/models/losses/label_smooth_loss.py @@ -24,37 +24,42 @@ class LabelSmoothLoss(nn.Module): label_smooth_val 
(float): The degree of label smoothing. num_classes (int, optional): Number of classes. Defaults to None. mode (str): Refers to notes, Options are 'original', 'classy_vision', - 'multi_label'. Defaults to 'original' + 'multi_label'. Defaults to 'original'. + use_sigmoid (bool, optional): Whether the prediction uses sigmoid or + softmax. Defaults to None, which means to use sigmoid in + "multi_label" mode and not use in other modes. reduction (str): The method used to reduce the loss. Options are "none", "mean" and "sum". Defaults to 'mean'. loss_weight (float): Weight of the loss. Defaults to 1.0. Notes: - if the mode is "original", this will use the same label smooth method - as the original paper as: + - if the mode is **"original"**, this will use the same label smooth + method as the original paper as: - .. math:: - (1-\epsilon)\delta_{k, y} + \frac{\epsilon}{K} + .. math:: + (1-\epsilon)\delta_{k, y} + \frac{\epsilon}{K} - where epsilon is the `label_smooth_val`, K is the num_classes and - delta(k,y) is Dirac delta, which equals 1 for k=y and 0 otherwise. + where :math:`\epsilon` is the ``label_smooth_val``, :math:`K` is the + ``num_classes`` and :math:`\delta_{k, y}` is Dirac delta, which + equals 1 for :math:`k=y` and 0 otherwise. - if the mode is "classy_vision", this will use the same label smooth - method as the facebookresearch/ClassyVision repo as: + - if the mode is **"classy_vision"**, this will use the same label + smooth method as the facebookresearch/ClassyVision repo as: - .. math:: - \frac{\delta_{k, y} + \epsilon/K}{1+\epsilon} + .. math:: + \frac{\delta_{k, y} + \epsilon/K}{1+\epsilon} - if the mode is "multi_label", this will accept labels from multi-label - task and smoothing them as: + - if the mode is **"multi_label"**, this will accept labels from + multi-label task and smooth them as: - .. math:: - (1-2\epsilon)\delta_{k, y} + \epsilon + .. 
math:: + (1-2\epsilon)\delta_{k, y} + \epsilon """ def __init__(self, label_smooth_val, num_classes=None, + use_sigmoid=None, mode='original', reduction='mean', loss_weight=1.0): @@ -82,12 +87,21 @@ class LabelSmoothLoss(nn.Module): self._eps = label_smooth_val if mode == 'classy_vision': self._eps = label_smooth_val / (1 + label_smooth_val) + if mode == 'multi_label': - self.ce = CrossEntropyLoss(use_sigmoid=True) + if not use_sigmoid: + from mmengine.logging import MMLogger + MMLogger.get_current_instance().warning( + 'For multi-label tasks, please set `use_sigmoid=True` ' + 'to use binary cross entropy.') self.smooth_label = self.multilabel_smooth_label + use_sigmoid = True if use_sigmoid is None else use_sigmoid else: - self.ce = CrossEntropyLoss(use_soft=True) self.smooth_label = self.original_smooth_label + use_sigmoid = False if use_sigmoid is None else use_sigmoid + + self.ce = CrossEntropyLoss( + use_sigmoid=use_sigmoid, use_soft=not use_sigmoid) def generate_one_hot_like_label(self, label): """This function takes one-hot or index label vectors and computes one- diff --git a/tests/test_models/test_losses.py b/tests/test_models/test_losses.py index 74eec620..442da9df 100644 --- a/tests/test_models/test_losses.py +++ b/tests/test_models/test_losses.py @@ -247,6 +247,17 @@ def test_label_smooth_loss(): correct = 0.2269 # from timm assert loss(cls_score, label) - correct <= 0.0001 + loss_cfg = dict( + type='LabelSmoothLoss', + label_smooth_val=0.1, + mode='original', + use_sigmoid=True, + reduction='mean', + loss_weight=1.0) + loss = build_loss(loss_cfg) + correct = 0.3633 # from timm + assert loss(cls_score, label) - correct <= 0.0001 + # test classy_vision mode label smooth loss loss_cfg = dict( type='LabelSmoothLoss',