diff --git a/.dev_scripts/benchmark_regression/4-benchmark_speed.py b/.dev_scripts/benchmark_regression/4-benchmark_speed.py
new file mode 100644
index 00000000..71ec017a
--- /dev/null
+++ b/.dev_scripts/benchmark_regression/4-benchmark_speed.py
@@ -0,0 +1,273 @@
+import logging
+import re
+import time
+from argparse import ArgumentParser
+from copy import deepcopy
+from datetime import datetime
+from pathlib import Path
+from typing import OrderedDict
+
+import numpy as np
+import torch
+from mmcv import Config
+from mmcv.parallel.data_parallel import MMDataParallel
+from mmcv.parallel.distributed import MMDistributedDataParallel
+from mmcv.runner import load_checkpoint, wrap_fp16_model
+from mmengine.logging.logger import MMLogger
+from modelindex.load_model_index import load
+from rich.console import Console
+from rich.table import Table
+
+from mmcls.datasets.builder import build_dataloader
+from mmcls.datasets.pipelines import Compose
+from mmcls.models.builder import build_classifier
+
+console = Console()
+MMCLS_ROOT = Path(__file__).absolute().parents[2]
+logger = MMLogger(
+    name='benchmark',
+    logger_name='benchmark',
+    log_file='benchmark_speed.log',
+    log_level=logging.INFO)
+
+
+def parse_args():
+    parser = ArgumentParser(
+        description='Get FPS of all models in model-index.yml')
+    parser.add_argument(
+        '--checkpoint-root',
+        help='Checkpoint file root path. If set, load checkpoint before test.')
+    parser.add_argument(
+        '--models', nargs='+', help='Model name patterns to benchmark.')
+    parser.add_argument(
+        '--work-dir',
+        type=Path,
+        default='work_dirs/benchmark_speed',
+        help='The directory to save speed test results.')
+    parser.add_argument(
+        '--max-iter', type=int, default=2048, help='Number of test images.')
+    parser.add_argument(
+        '--batch-size',
+        type=int,
+        default=64,
+        help='The batch size for inference.')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='job launcher')
+    parser.add_argument(
+        '--device', default='cuda', help='Device used for inference')
+    parser.add_argument(
+        '--gpu-id',
+        type=int,
+        default=0,
+        help='id of gpu to use '
+        '(only applicable to non-distributed testing)')
+    args = parser.parse_args()
+    return args
+
+
+class ToyDataset:
+    """A dummy dataset used to provide images for the benchmark."""
+
+    def __init__(self, num, hw) -> None:
+        data = []
+        for _ in range(num):
+            if isinstance(hw, int):
+                w = h = hw
+            else:
+                w, h = hw
+            img = np.random.randint(0, 256, size=(h, w, 3), dtype=np.uint8)
+            data.append({'img': img})
+        self.data = data
+        self.pipeline = None
+
+    def __getitem__(self, idx):
+        return self.pipeline(deepcopy(self.data[idx]))
+
+    def __len__(self):
+        return len(self.data)
+
+
+def measure_fps(config_file, checkpoint, dataset, args, distributed=False):
+    cfg = Config.fromfile(config_file)
+    # set cudnn_benchmark
+    if cfg.get('cudnn_benchmark', False):
+        torch.backends.cudnn.benchmark = True
+
+    # build the data pipeline
+    if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
+        cfg.data.test.pipeline.pop(0)
+
+    dataset.pipeline = Compose(cfg.data.test.pipeline)
+    resolution = tuple(dataset[0]['img'].shape[1:])
+
+    # build the dataloader
+    data_loader = build_dataloader(
+        dataset,
+        samples_per_gpu=args.batch_size,
+        # Because extra worker processes occupy additional CPU resources,
+        # the FPS statistics are less stable when workers_per_gpu is not 0,
+        # so it is fixed to 0 here.
+        workers_per_gpu=0,
+        dist=False if args.launcher == 'none' else True,
+        shuffle=False,
+        drop_last=True,
+        persistent_workers=False)
+
+    # build the model and load checkpoint
+    model = build_classifier(cfg.model)
+    fp16_cfg = cfg.get('fp16', None)
+    if fp16_cfg is not None:
+        wrap_fp16_model(model)
+    if checkpoint is not None:
+        load_checkpoint(model, checkpoint, map_location='cpu')
+
+    if not distributed:
+        if args.device == 'cpu':
+            model = model.cpu()
+        else:
+            model = MMDataParallel(model, device_ids=[args.gpu_id])
+    else:
+        model = MMDistributedDataParallel(
+            model.cuda(),
+            device_ids=[torch.cuda.current_device()],
+            broadcast_buffers=False)
+    model.eval()
+
+    # the first several iterations may be very slow, so skip them
+    num_warmup = 5
+    infer_time = []
+    fps = 0
+
+    # forward the model
+    result = {'model': config_file.stem, 'resolution': resolution}
+    for i, data in enumerate(data_loader):
+        torch.cuda.synchronize()
+        start_time = time.perf_counter()
+
+        with torch.no_grad():
+            model(return_loss=False, **data)
+
+        torch.cuda.synchronize()
+        elapsed = (time.perf_counter() - start_time) / args.batch_size
+
+        if i >= num_warmup:
+            infer_time.append(elapsed)
+            if (i + 1) % 8 == 0:
+                fps = (i + 1 - num_warmup) / sum(infer_time)
+                print(
+                    f'Done image [{(i + 1)*args.batch_size:<4}/'
+                    f'{args.max_iter}], fps: {fps:.1f} img / s, '
+                    f'times per image: {1000 / fps:.1f} ms / img',
+                    flush=True)
+    result['fps'] = (len(data_loader) - num_warmup) / sum(infer_time)
+    result['time_mean'] = np.mean(infer_time) * 1000
+    result['time_std'] = np.std(infer_time) * 1000
+
+    return result
+
+
+def show_summary(summary_data, args):
+    table = Table(title='Speed Benchmark Regression Summary')
+    table.add_column('Model')
+    table.add_column('Resolution (h, w)')
+    table.add_column('FPS (img/s)')
+    table.add_column('Inference Time (std) (ms/img)')
+
+    for model_name, summary in summary_data.items():
+        row = [model_name]
+        row.append(str(summary['resolution']))
+        row.append(f"{summary['fps']:.2f}")
+        time_mean = f"{summary['time_mean']:.2f}"
+        time_std = f"{summary['time_std']:.2f}"
+        row.append(f'{time_mean}\t({time_std})'.expandtabs(8))
+        table.add_row(*row)
+
+    console.print(table)
+
+
+# Benchmark the inference speed of all selected models
+def main(args):
+    model_index_file = MMCLS_ROOT / 'model-index.yml'
+    model_index = load(str(model_index_file))
+    model_index.build_models_with_collections()
+    models = OrderedDict({model.name: model for model in model_index.models})
+
+    if args.models:
+        patterns = [re.compile(pattern) for pattern in args.models]
+        filter_models = {}
+        for k, v in models.items():
+            if any([re.match(pattern, k) for pattern in patterns]):
+                filter_models[k] = v
+        if len(filter_models) == 0:
+            print('No model found, please specify models in:')
+            print('\n'.join(models.keys()))
+            return
+        models = filter_models
+
+    dataset_map = {
+        # The shape comes from the average image size of ImageNet.
+        'ImageNet-1k': ToyDataset(args.max_iter, (442, 522)),
+        'CIFAR-10': ToyDataset(args.max_iter, 32),
+        'CIFAR-100': ToyDataset(args.max_iter, 32),
+    }
+
+    summary_data = {}
+    for model_name, model_info in models.items():
+
+        if model_info.config is None:
+            continue
+
+        config = Path(model_info.config)
+        assert config.exists(), f'{model_name}: {config} not found.'
+
+        logger.info(f'Processing: {model_name}')
+
+        http_prefix = 'https://download.openmmlab.com/mmclassification/'
+        dataset = model_info.results[0].dataset
+        if dataset not in dataset_map.keys():
+            continue
+        if args.checkpoint_root is not None:
+            root = args.checkpoint_root
+            if 's3://' in args.checkpoint_root:
+                from mmcv.fileio import FileClient
+                from petrel_client.common.exception import AccessDeniedError
+                file_client = FileClient.infer_client(uri=root)
+                checkpoint = file_client.join_path(
+                    root, model_info.weights[len(http_prefix):])
+                try:
+                    exists = file_client.exists(checkpoint)
+                except AccessDeniedError:
+                    exists = False
+            else:
+                checkpoint = Path(root) / model_info.weights[len(http_prefix):]
+                exists = checkpoint.exists()
+            if exists:
+                checkpoint = str(checkpoint)
+            else:
+                print(f'WARNING: {model_name}: {checkpoint} not found.')
+                checkpoint = None
+        else:
+            checkpoint = None
+
+        # build the model from a config file and a checkpoint file
+        result = measure_fps(MMCLS_ROOT / config, checkpoint,
+                             dataset_map[dataset], args)
+
+        summary_data[model_name] = result
+
+    show_summary(summary_data, args)
+    args.work_dir.mkdir(parents=True, exist_ok=True)
+    out_path = args.work_dir / datetime.now().strftime('%Y-%m-%d.csv')
+    with open(out_path, 'w') as f:
+        f.write('MODEL,SHAPE,FPS\n')
+        for model, summary in summary_data.items():
+            f.write(
+                f'{model},"{summary["resolution"]}",{summary["fps"]:.2f}\n')
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main(args)
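
Reviewer note: the FPS reported by measure_fps comes from per-image latencies. Each batch forward pass is timed between two torch.cuda.synchronize() calls, the elapsed time is divided by the batch size, and the first num_warmup iterations are discarded, so the final figure is the reciprocal of the mean per-image time. The snippet below is a minimal, self-contained sketch of only that bookkeeping; summarize_speed and run_batch are hypothetical names used for illustration and are not part of the script above.

    import time

    import numpy as np


    def summarize_speed(run_batch, num_iters=32, batch_size=64, num_warmup=5):
        """Collect per-image latencies and summarize them like measure_fps."""
        infer_time = []
        for i in range(num_iters):
            start = time.perf_counter()
            run_batch()  # stand-in for one forward pass over a batch
            per_image = (time.perf_counter() - start) / batch_size
            if i >= num_warmup:  # the first iterations are warm-up and skipped
                infer_time.append(per_image)
        return {
            'fps': len(infer_time) / sum(infer_time),  # img / s
            'time_mean': np.mean(infer_time) * 1000,  # ms / img
            'time_std': np.std(infer_time) * 1000,  # ms / img
        }


    if __name__ == '__main__':
        # Simulate a model that needs about 20 ms per batch of 64 images.
        print(summarize_speed(lambda: time.sleep(0.02)))

Because infer_time keeps exactly one per-image value per batch, len(infer_time) / sum(infer_time) matches the script's (len(data_loader) - num_warmup) / sum(infer_time) and reads as images per second.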