update timer (#1179)

pull/1189/head
AllentDan 2022-10-11 14:59:50 +08:00 committed by GitHub
parent b406c1ab0a
commit 645eefae50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 11 deletions

View File

@ -36,9 +36,10 @@ def process_model_config(model_cfg: mmengine.Config,
pipeline[i].meta_keys = tuple(j for j in pipeline[i].meta_keys
if j != 'instances')
# for static exporting
if input_shape is not None and transform.type == 'Resize':
pipeline[i].keep_ratio = False
pipeline[i].scale = tuple(input_shape)
if input_shape is not None:
if transform.type in ('Resize', 'ShortScaleAspectJitter'):
pipeline[i] = mmengine.ConfigDict(
dict(type='Resize', scale=input_shape, keep_ratio=False))
pipeline = [
transform for transform in pipeline

View File

@ -2,6 +2,7 @@
import time
import warnings
from contextlib import contextmanager
from logging import Logger
from typing import Optional
import numpy as np
@ -43,6 +44,7 @@ class TimeCounter:
log_interval=log_interval,
warmup=warmup,
with_sync=with_sync,
batch_size=1,
enable=False)
def fun(*args, **kwargs):
@ -51,6 +53,7 @@ class TimeCounter:
log_interval = cls.names[name]['log_interval']
warmup = cls.names[name]['warmup']
with_sync = cls.names[name]['with_sync']
batch_size = cls.names[name]['batch_size']
enable = cls.names[name]['enable']
count += 1
@ -66,7 +69,7 @@ class TimeCounter:
if enable:
if with_sync and torch.cuda.is_available():
torch.cuda.synchronize()
elapsed = time.perf_counter() - start_time
elapsed = (time.perf_counter() - start_time) / batch_size
if enable and count > warmup:
execute_time.append(elapsed)
@ -93,7 +96,9 @@ class TimeCounter:
log_interval: int = 1,
with_sync: bool = False,
file: Optional[str] = None,
logger=None):
logger: Optional[Logger] = None,
batch_size: int = 1,
**kwargs):
"""Activate the time counter.
Args:
@ -105,6 +110,8 @@ class TimeCounter:
default False.
file (str | None): The file to save output messages. The default
is `None`.
logger (Logger): The logger for the timer. Default to None.
batch_size (int): The batch size. Default to 1.
"""
assert warmup >= 1
if logger is None:
@ -118,12 +125,14 @@ class TimeCounter:
cls.names[func_name]['warmup'] = warmup
cls.names[func_name]['log_interval'] = log_interval
cls.names[func_name]['with_sync'] = with_sync
cls.names[func_name]['batch_size'] = batch_size
cls.names[func_name]['enable'] = True
else:
for name in cls.names:
cls.names[name]['warmup'] = warmup
cls.names[name]['log_interval'] = log_interval
cls.names[name]['with_sync'] = with_sync
cls.names[name]['batch_size'] = batch_size
cls.names[name]['enable'] = True
yield
if func_name is not None:

View File

@ -112,12 +112,12 @@ def main():
is_pytorch = model_ext in ['.pth', '.pt']
if is_pytorch:
# load pytorch model
model = task_processor.init_pytorch_model(args.model[0])
model = task_processor.build_pytorch_model(args.model[0])
model = TorchWrapper(model)
backend = Backend.PYTORCH.value
else:
# load the model of the backend
model = task_processor.init_backend_model(args.model)
model = task_processor.build_backend_model(args.model)
backend = get_backend(deploy_cfg).value
model = model.eval().to(args.device)
@ -140,11 +140,17 @@ def main():
]
image_files = image_files[:total_nrof_image]
with TimeCounter.activate(
warmup=args.warmup, log_interval=20, with_sync=with_sync):
warmup=args.warmup,
log_interval=20,
with_sync=with_sync,
batch_size=args.batch_size):
for i in range(0, total_nrof_image, args.batch_size):
batch_files = image_files[i:(i + args.batch_size)]
data, _ = task_processor.create_input(batch_files, input_shape)
task_processor.run_inference(model, data)
data, _ = task_processor.create_input(
batch_files,
input_shape,
data_preprocessor=getattr(model, 'data_preprocessor', None))
model.test_step(data)
print('----- Settings:')
settings = PrettyTable()

View File

@ -143,7 +143,8 @@ def main():
warmup=args.warmup,
log_interval=args.log_interval,
with_sync=with_sync,
file=args.log2file):
file=args.log2file,
batch_size=test_dataloader['batch_size']):
runner.test()
else:
runner.test()