update timer (#1179)
parent b406c1ab0a
commit 645eefae50

@@ -36,9 +36,10 @@ def process_model_config(model_cfg: mmengine.Config,
             pipeline[i].meta_keys = tuple(j for j in pipeline[i].meta_keys
                                           if j != 'instances')
         # for static exporting
-        if input_shape is not None and transform.type == 'Resize':
-            pipeline[i].keep_ratio = False
-            pipeline[i].scale = tuple(input_shape)
+        if input_shape is not None:
+            if transform.type in ('Resize', 'ShortScaleAspectJitter'):
+                pipeline[i] = mmengine.ConfigDict(
+                    dict(type='Resize', scale=input_shape, keep_ratio=False))
 
     pipeline = [
         transform for transform in pipeline
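The rewritten branch above replaces any `Resize` or `ShortScaleAspectJitter` transform with a fixed-scale `Resize` whenever a static `input_shape` is given, so the exported model always sees a deterministic input size. A minimal sketch of the effect on a pipeline config; the two-step pipeline below is invented for illustration, only the rewrite logic comes from this diff:

```python
import mmengine

# Hypothetical test pipeline; the transforms here are placeholders.
pipeline = [
    mmengine.ConfigDict(dict(type='LoadImageFromFile')),
    mmengine.ConfigDict(dict(type='ShortScaleAspectJitter', short_size=736)),
]
input_shape = (640, 640)  # static export shape

for i, transform in enumerate(pipeline):
    if input_shape is not None:
        if transform.type in ('Resize', 'ShortScaleAspectJitter'):
            # Any resize-like step becomes a deterministic, fixed-scale Resize.
            pipeline[i] = mmengine.ConfigDict(
                dict(type='Resize', scale=input_shape, keep_ratio=False))

print(pipeline[1])  # -> {'type': 'Resize', 'scale': (640, 640), ...}
```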
@@ -2,6 +2,7 @@
 import time
 import warnings
 from contextlib import contextmanager
+from logging import Logger
 from typing import Optional
 
 import numpy as np
@@ -43,6 +44,7 @@ class TimeCounter:
                 log_interval=log_interval,
                 warmup=warmup,
                 with_sync=with_sync,
+                batch_size=1,
                 enable=False)
 
         def fun(*args, **kwargs):
@@ -51,6 +53,7 @@ class TimeCounter:
                 log_interval = cls.names[name]['log_interval']
                 warmup = cls.names[name]['warmup']
                 with_sync = cls.names[name]['with_sync']
+                batch_size = cls.names[name]['batch_size']
                 enable = cls.names[name]['enable']
 
                 count += 1
@@ -66,7 +69,7 @@ class TimeCounter:
                 if enable:
                     if with_sync and torch.cuda.is_available():
                         torch.cuda.synchronize()
-                    elapsed = time.perf_counter() - start_time
+                    elapsed = (time.perf_counter() - start_time) / batch_size
 
                 if enable and count > warmup:
                     execute_time.append(elapsed)
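The one-line change above is the core of the commit: each measured interval is divided by `batch_size`, so `execute_time` accumulates per-sample latency rather than per-batch latency. A self-contained sketch of the arithmetic; the sleep stands in for a batched forward pass:

```python
import time

batch_size = 8
execute_time = []

start_time = time.perf_counter()
time.sleep(0.08)  # stand-in for one forward pass over a batch of 8
elapsed = (time.perf_counter() - start_time) / batch_size

execute_time.append(elapsed)
print(f'{elapsed * 1000:.2f} ms/sample')  # ~10 ms/sample, not ~80 ms/batch
```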
@@ -93,7 +96,9 @@ class TimeCounter:
                  log_interval: int = 1,
                  with_sync: bool = False,
                  file: Optional[str] = None,
-                 logger=None):
+                 logger: Optional[Logger] = None,
+                 batch_size: int = 1,
+                 **kwargs):
         """Activate the time counter.
 
         Args:
@@ -105,6 +110,8 @@ class TimeCounter:
                 default False.
             file (str | None): The file to save output messages. The default
                 is `None`.
+            logger (Logger): The logger for the timer. Default to None.
+            batch_size (int): The batch size. Default to 1.
         """
         assert warmup >= 1
         if logger is None:
@@ -118,12 +125,14 @@ class TimeCounter:
             cls.names[func_name]['warmup'] = warmup
             cls.names[func_name]['log_interval'] = log_interval
             cls.names[func_name]['with_sync'] = with_sync
+            cls.names[func_name]['batch_size'] = batch_size
             cls.names[func_name]['enable'] = True
         else:
             for name in cls.names:
                 cls.names[name]['warmup'] = warmup
                 cls.names[name]['log_interval'] = log_interval
                 cls.names[name]['with_sync'] = with_sync
+                cls.names[name]['batch_size'] = batch_size
                 cls.names[name]['enable'] = True
         yield
         if func_name is not None:
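With `batch_size` now threaded through both timer registration and `activate`, a caller can make every registered timer report per-sample numbers. A hypothetical usage sketch; the decorated function, tensor shapes, and batch size are invented, and the import path assumes the timer lives at `mmdeploy.utils.timer`:

```python
import torch
from mmdeploy.utils.timer import TimeCounter

@TimeCounter.count_time('dummy_infer')
def dummy_infer(batch: torch.Tensor) -> torch.Tensor:
    return batch * 2  # stand-in for a batched forward pass

with TimeCounter.activate(
        func_name='dummy_infer',
        warmup=2,
        log_interval=10,
        batch_size=16):  # every measured interval is divided by 16
    for _ in range(50):
        dummy_infer(torch.rand(16, 3, 224, 224))
```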
@@ -112,12 +112,12 @@ def main():
     is_pytorch = model_ext in ['.pth', '.pt']
     if is_pytorch:
         # load pytorch model
-        model = task_processor.init_pytorch_model(args.model[0])
+        model = task_processor.build_pytorch_model(args.model[0])
         model = TorchWrapper(model)
         backend = Backend.PYTORCH.value
     else:
         # load the model of the backend
-        model = task_processor.init_backend_model(args.model)
+        model = task_processor.build_backend_model(args.model)
         backend = get_backend(deploy_cfg).value
 
     model = model.eval().to(args.device)
@@ -140,11 +140,17 @@ def main():
     ]
     image_files = image_files[:total_nrof_image]
     with TimeCounter.activate(
-            warmup=args.warmup, log_interval=20, with_sync=with_sync):
+            warmup=args.warmup,
+            log_interval=20,
+            with_sync=with_sync,
+            batch_size=args.batch_size):
         for i in range(0, total_nrof_image, args.batch_size):
             batch_files = image_files[i:(i + args.batch_size)]
-            data, _ = task_processor.create_input(batch_files, input_shape)
-            task_processor.run_inference(model, data)
+            data, _ = task_processor.create_input(
+                batch_files,
+                input_shape,
+                data_preprocessor=getattr(model, 'data_preprocessor', None))
+            model.test_step(data)
 
     print('----- Settings:')
     settings = PrettyTable()
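The loop above also moves inference from a task-processor helper onto the model itself: `create_input` is handed the model's `data_preprocessor` so preprocessing matches what the model expects, and timing wraps `model.test_step(data)`. A rough sketch of the mmengine-style contract being relied on; `ToyModel` and its shapes are invented:

```python
import torch
from mmengine.model import BaseModel

class ToyModel(BaseModel):
    """Minimal stand-in for a deployable model."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, inputs, data_samples=None, mode='predict'):
        # test_step routes through forward with mode='predict'.
        return self.linear(inputs)

model = ToyModel().eval()
data = dict(inputs=torch.rand(8, 4))
with torch.no_grad():
    # test_step applies model.data_preprocessor, then forward(mode='predict').
    outputs = model.test_step(data)
print(outputs.shape)  # torch.Size([8, 2])
```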
@@ -143,7 +143,8 @@ def main():
                 warmup=args.warmup,
                 log_interval=args.log_interval,
                 with_sync=with_sync,
-                file=args.log2file):
+                file=args.log2file,
+                batch_size=test_dataloader['batch_size']):
             runner.test()
         else:
             runner.test()