mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Enhance] Support --batch-size
option in the validation benchmark tool.
This commit is contained in:
parent
38bea383c1
commit
23cad6a0e1
@ -48,6 +48,11 @@ def parse_args():
|
|||||||
'--inference-time',
|
'--inference-time',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='Test inference time by run 10 times for each model.')
|
help='Test inference time by run 10 times for each model.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--batch-size',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help='The batch size during the inference.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--flops', action='store_true', help='Get Flops and Params of models')
|
'--flops', action='store_true', help='Get Flops and Params of models')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -86,7 +91,7 @@ def inference(config_file, checkpoint, work_dir, args, exp_name):
|
|||||||
test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
|
test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
|
||||||
|
|
||||||
data = Compose(test_dataset.pipeline)({'img_path': args.img})
|
data = Compose(test_dataset.pipeline)({'img_path': args.img})
|
||||||
data = default_collate([data])
|
data = default_collate([data] * args.batch_size)
|
||||||
resolution = tuple(data['inputs'].shape[-2:])
|
resolution = tuple(data['inputs'].shape[-2:])
|
||||||
|
|
||||||
runner: Runner = Runner.from_cfg(cfg)
|
runner: Runner = Runner.from_cfg(cfg)
|
||||||
@ -103,7 +108,7 @@ def inference(config_file, checkpoint, work_dir, args, exp_name):
|
|||||||
start = time()
|
start = time()
|
||||||
model.val_step(data)
|
model.val_step(data)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
time_record.append((time() - start) * 1000)
|
time_record.append((time() - start) / args.batch_size * 1000)
|
||||||
result['time_mean'] = np.mean(time_record[1:-1])
|
result['time_mean'] = np.mean(time_record[1:-1])
|
||||||
result['time_std'] = np.std(time_record[1:-1])
|
result['time_std'] = np.std(time_record[1:-1])
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user