mirror of https://github.com/alibaba/EasyCV.git
32 lines
903 B
Python
32 lines
903 B
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import time
|
|
|
|
from mmcv.runner import Hook, get_dist_info
|
|
|
|
from .registry import HOOKS
|
|
|
|
|
|
@HOOKS.register_module
class TIMEHook(Hook):
    """Print how long data loading and the model step take per iteration.

    Two wall-clock marks (in milliseconds) are kept on the instance:

    * ``end_load``  -- taken in ``before_train_iter``.
    * ``end_infer`` -- taken in ``after_train_iter``.

    The gap from ``end_infer`` to the next ``end_load`` is reported as the
    data-loading time; the gap from ``end_load`` to ``end_infer`` is reported
    as the model time. Only the process whose rank is 0 prints.
    """

    def __init__(self, end_momentum=1., **kwargs):
        """Set up rank information and the timing marks.

        Args:
            end_momentum: Accepted but not used by this hook.
            **kwargs: Accepted but not used by this hook.
        """
        self.rank, self.num_replicas = get_dist_info()
        # Current wall-clock time as an integer number of milliseconds.
        self.now_time = lambda: int(round(time.time() * 1000))
        # Seed both marks with "now" rather than 0 / leaving one undefined:
        # otherwise the first report would be the number of milliseconds
        # since the Unix epoch, and ``after_train_iter`` would raise
        # AttributeError if it ran before any ``before_train_iter``.
        self.end_infer = self.now_time()
        self.end_load = self.end_infer

    def before_train_iter(self, runner):
        """Mark the end of data loading and print its duration on rank 0.

        Args:
            runner: The runner driving training; not used here.
        """
        self.end_load = self.now_time()
        if self.rank == 0:
            print(self.rank,
                  ' load data need : %d ms' % (self.end_load - self.end_infer))

    def after_train_iter(self, runner):
        """Mark the end of the iteration and print its duration on rank 0.

        Args:
            runner: The runner driving training; not used here.
        """
        self.end_infer = self.now_time()
        if self.rank == 0:
            print(
                self.rank,
                ' infer model need : %d ms' % (self.end_infer - self.end_load))
|