Make PaviLoggerHook connect to the pavi service by itself.

- PaviLoggerHook gains a `config_file` argument and a `before_run` hook that
  calls `connect(runner)`, so the standalone `pavi_hook_connect` helper is
  removed from the module and from both package `__all__` lists.
- PaviClient routes its messages through `_print_log`, which uses the logger
  passed to `connect` when one is set and falls back to `print` otherwise.
- Messages are built with `str.format` only; mixing a `%d` placeholder with
  `.format()` would leave the `%d` literal and drop the last argument.
---
diff --git a/mmcv/torchpack/hooks/__init__.py b/mmcv/torchpack/hooks/__init__.py
index 495021c22..67fff16fc 100644
--- a/mmcv/torchpack/hooks/__init__.py
+++ b/mmcv/torchpack/hooks/__init__.py
@@ -6,10 +6,10 @@ from .optimizer_stepper import OptimizerHook
 from .iter_timer import IterTimerHook
 from .sampler_seed import DistSamplerSeedHook
 from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook,
-                     pavi_hook_connect, TensorboardLoggerHook)
+                     TensorboardLoggerHook)
 
 __all__ = [
     'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
     'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'TextLoggerHook',
-    'PaviLoggerHook', 'pavi_hook_connect', 'TensorboardLoggerHook'
+    'PaviLoggerHook', 'TensorboardLoggerHook'
 ]
diff --git a/mmcv/torchpack/hooks/logger/__init__.py b/mmcv/torchpack/hooks/logger/__init__.py
index 25b77c245..8cbaf12b1 100644
--- a/mmcv/torchpack/hooks/logger/__init__.py
+++ b/mmcv/torchpack/hooks/logger/__init__.py
@@ -1,9 +1,8 @@
 from .base import LoggerHook
-from .pavi import PaviLoggerHook, pavi_hook_connect
+from .pavi import PaviLoggerHook
 from .tensorboard import TensorboardLoggerHook
 from .text import TextLoggerHook
 
 __all__ = [
-    'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'pavi_hook_connect',
-    'TensorboardLoggerHook'
+    'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook'
 ]
diff --git a/mmcv/torchpack/hooks/logger/pavi.py b/mmcv/torchpack/hooks/logger/pavi.py
index a82bb98fe..669c4555b 100644
--- a/mmcv/torchpack/hooks/logger/pavi.py
+++ b/mmcv/torchpack/hooks/logger/pavi.py
@@ -1,6 +1,8 @@
 from __future__ import print_function
 
+import logging
 import os
+import os.path as osp
 import time
 from datetime import datetime
 from threading import Thread
@@ -20,6 +22,7 @@ class PaviClient(object):
         self.password = self._get_env_var(password, 'PAVI_PASSWORD')
         self.instance_id = instance_id
         self.log_queue = None
+        self.logger = None
 
     def _get_env_var(self, var, env_var):
         if var is not None:
@@ -32,25 +35,28 @@ class PaviClient(object):
                 format(env_var))
         return var
 
+    def _print_log(self, msg, level=logging.INFO, *args, **kwargs):
+        if self.logger is not None:
+            self.logger.log(level, msg, *args, **kwargs)
+        else:
+            print(msg, *args, **kwargs)
+
     def connect(self,
                 model_name,
                 work_dir=None,
                 info=dict(),
                 timeout=5,
                 logger=None):
-        if logger:
-            log_info = logger.info
-            log_error = logger.error
-        else:
-            log_info = log_error = print
-        log_info('connecting pavi service {}...'.format(self.url))
+        if logger is not None:
+            self.logger = logger
+        self._print_log('connecting pavi service {}...'.format(self.url))
         post_data = dict(
             time=str(datetime.now()),
             username=self.username,
             password=self.password,
             instance_id=self.instance_id,
             model=model_name,
-            work_dir=os.path.abspath(work_dir) if work_dir else '',
+            work_dir=osp.abspath(work_dir) if work_dir else '',
             session_file=info.get('session_file', ''),
             session_text=info.get('session_text', ''),
             model_text=info.get('model_text', ''),
@@ -58,21 +64,26 @@ class PaviClient(object):
         try:
             response = requests.post(self.url, json=post_data, timeout=timeout)
         except Exception as ex:
-            log_error('fail to connect to pavi service: {}'.format(ex))
+            self._print_log(
+                'fail to connect to pavi service: {}'.format(ex),
+                level=logging.ERROR)
         else:
             if response.status_code == 200:
                 self.instance_id = response.text
-                log_info('pavi service connected, instance_id: {}'.format(
-                    self.instance_id))
+                self._print_log(
+                    'pavi service connected, instance_id: {}'.format(
+                        self.instance_id))
                 self.log_queue = Queue()
                 self.log_thread = Thread(target=self.post_worker_fn)
                 self.log_thread.daemon = True
                 self.log_thread.start()
                 return True
             else:
-                log_error('fail to connect to pavi service, status code: '
-                          '{}, err message: {}'.format(response.status_code,
-                                                       response.reason))
+                self._print_log(
+                    'fail to connect to pavi service, status code: '
+                    '{}, err message: {}'.format(response.status_code,
+                                                 response.reason),
+                    level=logging.ERROR)
         return False
 
     def post_worker_fn(self, max_retry=3, queue_timeout=1, req_timeout=3):
@@ -82,7 +93,9 @@ class PaviClient(object):
             except Empty:
                 time.sleep(1)
             except Exception as ex:
-                print('fail to get logs from queue: {}'.format(ex))
+                self._print_log(
+                    'fail to get logs from queue: {}'.format(ex),
+                    level=logging.ERROR)
             else:
                 retry = 0
                 while retry < max_retry:
@@ -91,17 +104,24 @@ class PaviClient(object):
                             self.url, json=log, timeout=req_timeout)
                     except Exception as ex:
                         retry += 1
-                        print('error when posting logs to pavi: {}'.format(ex))
+                        self._print_log(
+                            'error when posting logs to pavi: {}'.format(ex),
+                            level=logging.ERROR)
                     else:
                         status_code = response.status_code
                         if status_code == 200:
                             break
                         else:
-                            print('unexpected status code: %d, err msg: %s',
-                                  status_code, response.reason)
+                            self._print_log(
+                                'unexpected status code: {}, err msg: {}'.
+                                format(status_code, response.reason),
+                                level=logging.ERROR)
                             retry += 1
                 if retry == max_retry:
-                    print('fail to send logs of iteration %d', log['iter_num'])
+                    self._print_log(
+                        'fail to send logs of iteration {}'.format(
+                            log['iter_num']),
+                        level=logging.ERROR)
 
     def log(self, phase, iter, outputs):
         if self.log_queue is not None:
@@ -123,21 +143,29 @@ class PaviLoggerHook(LoggerHook):
                  username=None,
                  password=None,
                  instance_id=None,
+                 config_file=None,
                  interval=10,
                  reset_meter=True,
                  ignore_last=True):
         self.pavi = PaviClient(url, username, password, instance_id)
+        self.config_file = config_file
         super(PaviLoggerHook, self).__init__(interval, reset_meter,
                                              ignore_last)
 
+    def before_run(self, runner):
+        super(PaviLoggerHook, self).before_run(runner)
+        self.connect(runner)
+
     @master_only
-    def connect(self,
-                model_name,
-                work_dir=None,
-                info=dict(),
-                timeout=5,
-                logger=None):
-        return self.pavi.connect(model_name, work_dir, info, timeout, logger)
+    def connect(self, runner, timeout=5):
+        cfg_info = dict()
+        if self.config_file is not None:
+            with open(self.config_file, 'r') as f:
+                config_text = f.read()
+            cfg_info.update(
+                session_file=self.config_file, session_text=config_text)
+        return self.pavi.connect(runner.model_name, runner.work_dir, cfg_info,
+                                 timeout, runner.logger)
 
     @master_only
     def log(self, runner):
@@ -145,17 +173,3 @@ class PaviLoggerHook(LoggerHook):
         log_outs.pop('time', None)
         log_outs.pop('data_time', None)
         self.pavi.log(runner.mode, runner.iter, log_outs)
-
-
-def pavi_hook_connect(runner, cfg_filename, cfg_text):
-    for hook in runner.hooks:
-        if isinstance(hook, PaviLoggerHook):
-            hook.connect(
-                runner.model_name,
-                runner.work_dir,
-                info={
-                    'session_file': cfg_filename,
-                    'session_text': cfg_text
-                },
-                logger=runner.logger)
-            break