auto-connect pavi service before running

pull/14/head
Kai Chen 2018-10-01 18:46:56 +08:00
parent 95cb88535f
commit f9d8870f45
3 changed files with 57 additions and 44 deletions

View File

@ -6,10 +6,10 @@ from .optimizer_stepper import OptimizerHook
from .iter_timer import IterTimerHook
from .sampler_seed import DistSamplerSeedHook
from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook,
pavi_hook_connect, TensorboardLoggerHook)
TensorboardLoggerHook)
__all__ = [
'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'TextLoggerHook',
'PaviLoggerHook', 'pavi_hook_connect', 'TensorboardLoggerHook'
'PaviLoggerHook', 'TensorboardLoggerHook'
]

View File

@ -1,9 +1,8 @@
from .base import LoggerHook
from .pavi import PaviLoggerHook, pavi_hook_connect
from .pavi import PaviLoggerHook
from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook
__all__ = [
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'pavi_hook_connect',
'TensorboardLoggerHook'
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook'
]

View File

@ -1,6 +1,8 @@
from __future__ import print_function
import logging
import os
import os.path as osp
import time
from datetime import datetime
from threading import Thread
@ -20,6 +22,7 @@ class PaviClient(object):
self.password = self._get_env_var(password, 'PAVI_PASSWORD')
self.instance_id = instance_id
self.log_queue = None
self.logger = None
def _get_env_var(self, var, env_var):
if var is not None:
@ -32,25 +35,28 @@ class PaviClient(object):
format(env_var))
return var
def _print_log(self, msg, level=logging.INFO, *args, **kwargs):
if self.logger is not None:
self.logger.log(level, msg, *args, **kwargs)
else:
print(msg, *args, **kwargs)
def connect(self,
model_name,
work_dir=None,
info=dict(),
timeout=5,
logger=None):
if logger:
log_info = logger.info
log_error = logger.error
else:
log_info = log_error = print
log_info('connecting pavi service {}...'.format(self.url))
if logger is not None:
self.logger = logger
self._print_log('connecting pavi service {}...'.format(self.url))
post_data = dict(
time=str(datetime.now()),
username=self.username,
password=self.password,
instance_id=self.instance_id,
model=model_name,
work_dir=os.path.abspath(work_dir) if work_dir else '',
work_dir=osp.abspath(work_dir) if work_dir else '',
session_file=info.get('session_file', ''),
session_text=info.get('session_text', ''),
model_text=info.get('model_text', ''),
@ -58,21 +64,26 @@ class PaviClient(object):
try:
response = requests.post(self.url, json=post_data, timeout=timeout)
except Exception as ex:
log_error('fail to connect to pavi service: {}'.format(ex))
self._print_log(
'fail to connect to pavi service: {}'.format(ex),
level=logging.ERROR)
else:
if response.status_code == 200:
self.instance_id = response.text
log_info('pavi service connected, instance_id: {}'.format(
self.instance_id))
self._print_log(
'pavi service connected, instance_id: {}'.format(
self.instance_id))
self.log_queue = Queue()
self.log_thread = Thread(target=self.post_worker_fn)
self.log_thread.daemon = True
self.log_thread.start()
return True
else:
log_error('fail to connect to pavi service, status code: '
'{}, err message: {}'.format(response.status_code,
response.reason))
self._print_log(
'fail to connect to pavi service, status code: '
'{}, err message: {}'.format(response.status_code,
response.reason),
level=logging.ERROR)
return False
def post_worker_fn(self, max_retry=3, queue_timeout=1, req_timeout=3):
@ -82,7 +93,9 @@ class PaviClient(object):
except Empty:
time.sleep(1)
except Exception as ex:
print('fail to get logs from queue: {}'.format(ex))
self._print_log(
'fail to get logs from queue: {}'.format(ex),
level=logging.ERROR)
else:
retry = 0
while retry < max_retry:
@ -91,17 +104,24 @@ class PaviClient(object):
self.url, json=log, timeout=req_timeout)
except Exception as ex:
retry += 1
print('error when posting logs to pavi: {}'.format(ex))
self._print_log(
'error when posting logs to pavi: {}'.format(ex),
level=logging.ERROR)
else:
status_code = response.status_code
if status_code == 200:
break
else:
print('unexpected status code: %d, err msg: %s',
status_code, response.reason)
self._print_log(
'unexpected status code: %d, err msg: {}'.
format(status_code, response.reason),
level=logging.ERROR)
retry += 1
if retry == max_retry:
print('fail to send logs of iteration %d', log['iter_num'])
self._print_log(
'fail to send logs of iteration {}'.format(
log['iter_num']),
level=logging.ERROR)
def log(self, phase, iter, outputs):
if self.log_queue is not None:
@ -123,21 +143,29 @@ class PaviLoggerHook(LoggerHook):
username=None,
password=None,
instance_id=None,
config_file=None,
interval=10,
reset_meter=True,
ignore_last=True):
self.pavi = PaviClient(url, username, password, instance_id)
self.config_file = config_file
super(PaviLoggerHook, self).__init__(interval, reset_meter,
ignore_last)
def before_run(self, runner):
super(PaviLoggerHook, self).before_run(runner)
self.connect(runner)
@master_only
def connect(self,
model_name,
work_dir=None,
info=dict(),
timeout=5,
logger=None):
return self.pavi.connect(model_name, work_dir, info, timeout, logger)
def connect(self, runner, timeout=5):
cfg_info = dict()
if self.config_file is not None:
with open(self.config_file, 'r') as f:
config_text = f.read()
cfg_info.update(
session_file=self.config_file, session_text=config_text)
return self.pavi.connect(runner.model_name, runner.work_dir, cfg_info,
timeout, runner.logger)
@master_only
def log(self, runner):
@ -145,17 +173,3 @@ class PaviLoggerHook(LoggerHook):
log_outs.pop('time', None)
log_outs.pop('data_time', None)
self.pavi.log(runner.mode, runner.iter, log_outs)
def pavi_hook_connect(runner, cfg_filename, cfg_text):
for hook in runner.hooks:
if isinstance(hook, PaviLoggerHook):
hook.connect(
runner.model_name,
runner.work_dir,
info={
'session_file': cfg_filename,
'session_text': cfg_text
},
logger=runner.logger)
break