import logging
import re
import tempfile
from argparse import ArgumentParser
from collections import OrderedDict
from pathlib import Path
from time import time

import mmcv
import numpy as np
import torch
from mmengine import Config, DictAction, MMLogger
from mmengine.dataset import Compose, default_collate
from mmengine.fileio import FileClient
from mmengine.runner import Runner, load_checkpoint
from modelindex.load_model_index import load
from rich.console import Console
from rich.table import Table

from mmpretrain.apis import init_model
from mmpretrain.datasets import CIFAR10, CIFAR100, ImageNet
from mmpretrain.utils import register_all_modules
from mmpretrain.visualization import ClsVisualizer
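# Validate all models listed in model-index.yml: build each config, optionally
# load its checkpoint, run a single forward pass on a demo image, and report
# the resolution, inference time and FLOPs/params in a summary table.
#
# Example invocation (the script name below is only illustrative):
#   python benchmark_valid.py --checkpoint-root work_dirs/checkpoints \
#       --models 'resnet.*' --inference-time --flops --flops-str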
console = Console()

MMCLS_ROOT = Path(__file__).absolute().parents[2]

classes_map = {
    'ImageNet-1k': ImageNet.CLASSES,
    'CIFAR-10': CIFAR10.CLASSES,
    'CIFAR-100': CIFAR100.CLASSES,
}


def parse_args():
    parser = ArgumentParser(description='Valid all models in model-index.yml')
    parser.add_argument(
        '--checkpoint-root',
        help='Checkpoint file root path. If set, load checkpoint before test.')
    parser.add_argument('--img', default='demo/demo.JPEG', help='Image file')
    parser.add_argument(
        '--models', nargs='+', help='Names of the models to inference.')
    parser.add_argument('--show', action='store_true', help='Show results.')
    parser.add_argument(
        '--wait-time',
        type=float,
        default=1,
        help='The interval of show (s); 0 means blocking.')
    parser.add_argument(
        '--inference-time',
        action='store_true',
        help='Test inference time by running each model 10 times.')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=1,
        help='The batch size during the inference.')
    parser.add_argument(
        '--flops', action='store_true', help='Get FLOPs and params of models')
    parser.add_argument(
        '--flops-str',
        action='store_true',
        help='Output FLOPs and params counts in a string form.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='Override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    args = parser.parse_args()
    return args


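# Build a model from one config (optionally loading its checkpoint), run a
# forward pass on the demo image, and collect resolution, timing and
# FLOPs/params statistics for the summary table.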
def inference(config_file, checkpoint, work_dir, args, exp_name):
    cfg = Config.fromfile(config_file)
    cfg.work_dir = work_dir
    cfg.load_from = checkpoint
    cfg.log_level = 'WARN'
    cfg.experiment_name = exp_name
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    if 'test_dataloader' in cfg:
        # build the data pipeline
        test_dataset = cfg.test_dataloader.dataset
        if test_dataset.pipeline[0]['type'] != 'LoadImageFromFile':
            test_dataset.pipeline.insert(0, dict(type='LoadImageFromFile'))
        if test_dataset.type in ['CIFAR10', 'CIFAR100']:
            # The image shape of CIFAR is (32, 32, 3)
            test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))

        data = Compose(test_dataset.pipeline)({'img_path': args.img})
        data = default_collate([data] * args.batch_size)
        resolution = tuple(data['inputs'].shape[-2:])
        model = Runner.from_cfg(cfg).model
        model.eval()
        forward = model.val_step
    else:
        # For configs that only build a model (no test dataloader).
        model = init_model(cfg)
        model.eval()
        data = torch.empty(1, 3, 224, 224).to(model.data_preprocessor.device)
        resolution = (224, 224)
        forward = model.extract_feat

    if checkpoint is not None:
        load_checkpoint(model, checkpoint, map_location='cpu')

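    # Timing note: with --inference-time, every loop iteration performs an
    # untimed warmup forward followed by a timed forward (with CUDA
    # synchronization), and the first and last of the 10 measurements are
    # dropped when computing the mean/std.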
    # forward the model
    result = {'resolution': resolution}
    with torch.no_grad():
        if args.inference_time:
            time_record = []
            for _ in range(10):
                forward(data)  # warmup before profiling
                torch.cuda.synchronize()
                start = time()
                forward(data)
                torch.cuda.synchronize()
                time_record.append((time() - start) / args.batch_size * 1000)
            result['time_mean'] = np.mean(time_record[1:-1])
            result['time_std'] = np.std(time_record[1:-1])
        else:
            forward(data)

    result['model'] = config_file.stem

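    # FLOPs/params are measured with fvcore: the model is moved to CPU, its
    # forward is redirected to extract_feat when available, and a random input
    # at the test resolution is analyzed. --flops-str formats the counts as
    # readable strings instead of raw integers.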
    if args.flops:
        from fvcore.nn import FlopCountAnalysis, parameter_count
        from fvcore.nn.print_model_statistics import _format_size
        _format_size = _format_size if args.flops_str else lambda x: x
        with torch.no_grad():
            if hasattr(model, 'extract_feat'):
                model.forward = model.extract_feat
                model.to('cpu')
                inputs = (torch.randn((1, 3, *resolution)), )
                flops = _format_size(FlopCountAnalysis(model, inputs).total())
                params = _format_size(parameter_count(model)[''])
                result['flops'] = flops if args.flops_str else int(flops)
                result['params'] = params if args.flops_str else int(params)
            else:
                result['flops'] = ''
                result['params'] = ''

    return result


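# Render the collected results as a rich table: failing models show a red FAIL
# flag, while passing models additionally report resolution and, depending on
# the enabled flags, inference time and FLOPs/params.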
def show_summary(summary_data, args):
    table = Table(title='Validation Benchmark Regression Summary')
    table.add_column('Model')
    table.add_column('Validation')
    table.add_column('Resolution (h, w)')
    if args.inference_time:
        table.add_column('Inference Time (std) (ms/im)')
    if args.flops:
        table.add_column('Flops', justify='right', width=13)
        table.add_column('Params', justify='right', width=11)

    for model_name, summary in summary_data.items():
        row = [model_name]
        valid = summary['valid']
        color = 'green' if valid == 'PASS' else 'red'
        row.append(f'[{color}]{valid}[/{color}]')
        if valid == 'PASS':
            row.append(str(summary['resolution']))
            if args.inference_time:
                time_mean = f"{summary['time_mean']:.2f}"
                time_std = f"{summary['time_std']:.2f}"
                row.append(f'{time_mean}\t({time_std})'.expandtabs(8))
            if args.flops:
                row.append(str(summary['flops']))
                row.append(str(summary['params']))
        table.add_row(*row)

    console.print(table)


# Run a sample test to check whether the inference code is correct.
def main(args):
    register_all_modules()
    model_index_file = MMCLS_ROOT / 'model-index.yml'
    model_index = load(str(model_index_file))
    model_index.build_models_with_collections()
    models = OrderedDict({model.name: model for model in model_index.models})

    logger = MMLogger(
        'validation',
        logger_name='validation',
        log_file='benchmark_test_image.log',
        log_level=logging.INFO)

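    # --models accepts one or more regular expressions; a model is kept if any
    # pattern matches its name in the model index.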
    if args.models:
        patterns = [re.compile(pattern) for pattern in args.models]
        filter_models = {}
        for k, v in models.items():
            if any([re.match(pattern, k) for pattern in patterns]):
                filter_models[k] = v
        if len(filter_models) == 0:
            print('No model found, please specify models in:')
            print('\n'.join(models.keys()))
            return
        models = filter_models

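    # Validate every selected model inside a shared temporary work directory
    # and record its result (or failure) for the final summary.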
    summary_data = {}
    tmpdir = tempfile.TemporaryDirectory()
    for model_name, model_info in models.items():

        if model_info.config is None:
            continue

        config = Path(model_info.config)
        assert config.exists(), f'{model_name}: {config} not found.'

        logger.info(f'Processing: {model_name}')

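        # Map the released download URL of the weights onto --checkpoint-root,
        # which may be a local directory or an s3:// path accessed through the
        # petrel client; a missing checkpoint only triggers a warning.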
        http_prefix = 'https://download.openmmlab.com/mmclassification/'
        if args.checkpoint_root is not None:
            root = args.checkpoint_root
            if 's3://' in args.checkpoint_root:
                from petrel_client.common.exception import AccessDeniedError
                file_client = FileClient.infer_client(uri=root)
                checkpoint = file_client.join_path(
                    root, model_info.weights[len(http_prefix):])
                try:
                    exists = file_client.exists(checkpoint)
                except AccessDeniedError:
                    exists = False
            else:
                checkpoint = Path(root) / model_info.weights[len(http_prefix):]
                exists = checkpoint.exists()
            if exists:
                checkpoint = str(checkpoint)
            else:
                print(f'WARNING: {model_name}: {checkpoint} not found.')
                checkpoint = None
        else:
            checkpoint = None

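        # A failed build or forward marks the model as FAIL instead of
        # aborting the whole benchmark; the full traceback goes to the log.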
        try:
            # build the model from a config file and a checkpoint file
            result = inference(MMCLS_ROOT / config, checkpoint, tmpdir.name,
                               args, model_name)
            result['valid'] = 'PASS'
        except Exception:
            import traceback
            logger.error(f'"{config}" :\n{traceback.format_exc()}')
            result = {'valid': 'FAIL'}

        summary_data[model_name] = result
        # show the results
        if args.show:
            vis = ClsVisualizer.get_instance('valid')
            vis.set_image(mmcv.imread(args.img))
            vis.draw_texts(
                texts='\n'.join([f'{k}: {v}' for k, v in result.items()]),
                positions=np.array([(5, 5)]))
            vis.show(wait_time=args.wait_time)

    tmpdir.cleanup()
    show_summary(summary_data, args)


if __name__ == '__main__':
    args = parse_args()
    main(args)