mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Enhance] Support --batch-size
option in the validation benchmark tool.
This commit is contained in:
parent
38bea383c1
commit
23cad6a0e1
@ -48,6 +48,11 @@ def parse_args():
|
||||
'--inference-time',
|
||||
action='store_true',
|
||||
help='Test inference time by run 10 times for each model.')
|
||||
parser.add_argument(
|
||||
'--batch-size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='The batch size during the inference.')
|
||||
parser.add_argument(
|
||||
'--flops', action='store_true', help='Get Flops and Params of models')
|
||||
parser.add_argument(
|
||||
@ -86,7 +91,7 @@ def inference(config_file, checkpoint, work_dir, args, exp_name):
|
||||
test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
|
||||
|
||||
data = Compose(test_dataset.pipeline)({'img_path': args.img})
|
||||
data = default_collate([data])
|
||||
data = default_collate([data] * args.batch_size)
|
||||
resolution = tuple(data['inputs'].shape[-2:])
|
||||
|
||||
runner: Runner = Runner.from_cfg(cfg)
|
||||
@ -103,7 +108,7 @@ def inference(config_file, checkpoint, work_dir, args, exp_name):
|
||||
start = time()
|
||||
model.val_step(data)
|
||||
torch.cuda.synchronize()
|
||||
time_record.append((time() - start) * 1000)
|
||||
time_record.append((time() - start) / args.batch_size * 1000)
|
||||
result['time_mean'] = np.mean(time_record[1:-1])
|
||||
result['time_std'] = np.std(time_record[1:-1])
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user