mirror of https://github.com/open-mmlab/mmcv.git
auto-connect pavi service before running
parent
95cb88535f
commit
f9d8870f45
|
@ -6,10 +6,10 @@ from .optimizer_stepper import OptimizerHook
|
|||
from .iter_timer import IterTimerHook
|
||||
from .sampler_seed import DistSamplerSeedHook
|
||||
from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook,
|
||||
pavi_hook_connect, TensorboardLoggerHook)
|
||||
TensorboardLoggerHook)
|
||||
|
||||
__all__ = [
|
||||
'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
|
||||
'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'TextLoggerHook',
|
||||
'PaviLoggerHook', 'pavi_hook_connect', 'TensorboardLoggerHook'
|
||||
'PaviLoggerHook', 'TensorboardLoggerHook'
|
||||
]
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
from .base import LoggerHook
|
||||
from .pavi import PaviLoggerHook, pavi_hook_connect
|
||||
from .pavi import PaviLoggerHook
|
||||
from .tensorboard import TensorboardLoggerHook
|
||||
from .text import TextLoggerHook
|
||||
|
||||
__all__ = [
|
||||
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'pavi_hook_connect',
|
||||
'TensorboardLoggerHook'
|
||||
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook'
|
||||
]
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import time
|
||||
from datetime import datetime
|
||||
from threading import Thread
|
||||
|
@ -20,6 +22,7 @@ class PaviClient(object):
|
|||
self.password = self._get_env_var(password, 'PAVI_PASSWORD')
|
||||
self.instance_id = instance_id
|
||||
self.log_queue = None
|
||||
self.logger = None
|
||||
|
||||
def _get_env_var(self, var, env_var):
|
||||
if var is not None:
|
||||
|
@ -32,25 +35,28 @@ class PaviClient(object):
|
|||
format(env_var))
|
||||
return var
|
||||
|
||||
def _print_log(self, msg, level=logging.INFO, *args, **kwargs):
|
||||
if self.logger is not None:
|
||||
self.logger.log(level, msg, *args, **kwargs)
|
||||
else:
|
||||
print(msg, *args, **kwargs)
|
||||
|
||||
def connect(self,
|
||||
model_name,
|
||||
work_dir=None,
|
||||
info=dict(),
|
||||
timeout=5,
|
||||
logger=None):
|
||||
if logger:
|
||||
log_info = logger.info
|
||||
log_error = logger.error
|
||||
else:
|
||||
log_info = log_error = print
|
||||
log_info('connecting pavi service {}...'.format(self.url))
|
||||
if logger is not None:
|
||||
self.logger = logger
|
||||
self._print_log('connecting pavi service {}...'.format(self.url))
|
||||
post_data = dict(
|
||||
time=str(datetime.now()),
|
||||
username=self.username,
|
||||
password=self.password,
|
||||
instance_id=self.instance_id,
|
||||
model=model_name,
|
||||
work_dir=os.path.abspath(work_dir) if work_dir else '',
|
||||
work_dir=osp.abspath(work_dir) if work_dir else '',
|
||||
session_file=info.get('session_file', ''),
|
||||
session_text=info.get('session_text', ''),
|
||||
model_text=info.get('model_text', ''),
|
||||
|
@ -58,21 +64,26 @@ class PaviClient(object):
|
|||
try:
|
||||
response = requests.post(self.url, json=post_data, timeout=timeout)
|
||||
except Exception as ex:
|
||||
log_error('fail to connect to pavi service: {}'.format(ex))
|
||||
self._print_log(
|
||||
'fail to connect to pavi service: {}'.format(ex),
|
||||
level=logging.ERROR)
|
||||
else:
|
||||
if response.status_code == 200:
|
||||
self.instance_id = response.text
|
||||
log_info('pavi service connected, instance_id: {}'.format(
|
||||
self.instance_id))
|
||||
self._print_log(
|
||||
'pavi service connected, instance_id: {}'.format(
|
||||
self.instance_id))
|
||||
self.log_queue = Queue()
|
||||
self.log_thread = Thread(target=self.post_worker_fn)
|
||||
self.log_thread.daemon = True
|
||||
self.log_thread.start()
|
||||
return True
|
||||
else:
|
||||
log_error('fail to connect to pavi service, status code: '
|
||||
'{}, err message: {}'.format(response.status_code,
|
||||
response.reason))
|
||||
self._print_log(
|
||||
'fail to connect to pavi service, status code: '
|
||||
'{}, err message: {}'.format(response.status_code,
|
||||
response.reason),
|
||||
level=logging.ERROR)
|
||||
return False
|
||||
|
||||
def post_worker_fn(self, max_retry=3, queue_timeout=1, req_timeout=3):
|
||||
|
@ -82,7 +93,9 @@ class PaviClient(object):
|
|||
except Empty:
|
||||
time.sleep(1)
|
||||
except Exception as ex:
|
||||
print('fail to get logs from queue: {}'.format(ex))
|
||||
self._print_log(
|
||||
'fail to get logs from queue: {}'.format(ex),
|
||||
level=logging.ERROR)
|
||||
else:
|
||||
retry = 0
|
||||
while retry < max_retry:
|
||||
|
@ -91,17 +104,24 @@ class PaviClient(object):
|
|||
self.url, json=log, timeout=req_timeout)
|
||||
except Exception as ex:
|
||||
retry += 1
|
||||
print('error when posting logs to pavi: {}'.format(ex))
|
||||
self._print_log(
|
||||
'error when posting logs to pavi: {}'.format(ex),
|
||||
level=logging.ERROR)
|
||||
else:
|
||||
status_code = response.status_code
|
||||
if status_code == 200:
|
||||
break
|
||||
else:
|
||||
print('unexpected status code: %d, err msg: %s',
|
||||
status_code, response.reason)
|
||||
self._print_log(
|
||||
'unexpected status code: %d, err msg: {}'.
|
||||
format(status_code, response.reason),
|
||||
level=logging.ERROR)
|
||||
retry += 1
|
||||
if retry == max_retry:
|
||||
print('fail to send logs of iteration %d', log['iter_num'])
|
||||
self._print_log(
|
||||
'fail to send logs of iteration {}'.format(
|
||||
log['iter_num']),
|
||||
level=logging.ERROR)
|
||||
|
||||
def log(self, phase, iter, outputs):
|
||||
if self.log_queue is not None:
|
||||
|
@ -123,21 +143,29 @@ class PaviLoggerHook(LoggerHook):
|
|||
username=None,
|
||||
password=None,
|
||||
instance_id=None,
|
||||
config_file=None,
|
||||
interval=10,
|
||||
reset_meter=True,
|
||||
ignore_last=True):
|
||||
self.pavi = PaviClient(url, username, password, instance_id)
|
||||
self.config_file = config_file
|
||||
super(PaviLoggerHook, self).__init__(interval, reset_meter,
|
||||
ignore_last)
|
||||
|
||||
def before_run(self, runner):
|
||||
super(PaviLoggerHook, self).before_run(runner)
|
||||
self.connect(runner)
|
||||
|
||||
@master_only
|
||||
def connect(self,
|
||||
model_name,
|
||||
work_dir=None,
|
||||
info=dict(),
|
||||
timeout=5,
|
||||
logger=None):
|
||||
return self.pavi.connect(model_name, work_dir, info, timeout, logger)
|
||||
def connect(self, runner, timeout=5):
|
||||
cfg_info = dict()
|
||||
if self.config_file is not None:
|
||||
with open(self.config_file, 'r') as f:
|
||||
config_text = f.read()
|
||||
cfg_info.update(
|
||||
session_file=self.config_file, session_text=config_text)
|
||||
return self.pavi.connect(runner.model_name, runner.work_dir, cfg_info,
|
||||
timeout, runner.logger)
|
||||
|
||||
@master_only
|
||||
def log(self, runner):
|
||||
|
@ -145,17 +173,3 @@ class PaviLoggerHook(LoggerHook):
|
|||
log_outs.pop('time', None)
|
||||
log_outs.pop('data_time', None)
|
||||
self.pavi.log(runner.mode, runner.iter, log_outs)
|
||||
|
||||
|
||||
def pavi_hook_connect(runner, cfg_filename, cfg_text):
|
||||
for hook in runner.hooks:
|
||||
if isinstance(hook, PaviLoggerHook):
|
||||
hook.connect(
|
||||
runner.model_name,
|
||||
runner.work_dir,
|
||||
info={
|
||||
'session_file': cfg_filename,
|
||||
'session_text': cfg_text
|
||||
},
|
||||
logger=runner.logger)
|
||||
break
|
||||
|
|
Loading…
Reference in New Issue