remove the outdated PaviLoggerHook and add requirements.txt (#185)

This commit is contained in:
Kai Chen 2020-02-12 22:46:12 +08:00 committed by GitHub
parent a500a64621
commit f4272a8875
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 8 additions and 187 deletions

View File

@ -1,11 +1,9 @@
# Copyright (c) Open-MMLab. All rights reserved.
from .base import LoggerHook
from .pavi import PaviLoggerHook
from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook
from .wandb import WandbLoggerHook
__all__ = [
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook'
'LoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', 'WandbLoggerHook'
]

View File

@ -1,178 +0,0 @@
# Copyright (c) Open-MMLab. All rights reserved.
from __future__ import print_function
import logging
import os
import os.path as osp
import time
from datetime import datetime
from threading import Thread
import requests
from six.moves.queue import Empty, Queue
from ...dist_utils import master_only
from ...utils import get_host_info
from .base import LoggerHook
class PaviClient(object):
def __init__(self, url, username=None, password=None, instance_id=None):
self.url = url
self.username = self._get_env_var(username, 'PAVI_USERNAME')
self.password = self._get_env_var(password, 'PAVI_PASSWORD')
self.instance_id = instance_id
self.log_queue = None
self.logger = None
def _get_env_var(self, var, env_var):
if var is not None:
return str(var)
var = os.getenv(env_var)
if not var:
raise ValueError(
'"{}" is neither specified nor defined as env variables'.
format(env_var))
return var
def _print_log(self, msg, level=logging.INFO, *args, **kwargs):
if self.logger is not None:
self.logger.log(level, msg, *args, **kwargs)
else:
print(msg, *args, **kwargs)
def connect(self,
model_name,
work_dir=None,
info=dict(),
timeout=5,
logger=None):
if logger is not None:
self.logger = logger
self._print_log('connecting pavi service {}...'.format(self.url))
post_data = dict(
time=str(datetime.now()),
username=self.username,
password=self.password,
instance_id=self.instance_id,
model=model_name,
work_dir=osp.abspath(work_dir) if work_dir else '',
session_file=info.get('session_file', ''),
session_text=info.get('session_text', ''),
model_text=info.get('model_text', ''),
device=get_host_info())
try:
response = requests.post(self.url, json=post_data, timeout=timeout)
except Exception as ex:
self._print_log(
'fail to connect to pavi service: {}'.format(ex),
level=logging.ERROR)
else:
if response.status_code == 200:
self.instance_id = response.text
self._print_log(
'pavi service connected, instance_id: {}'.format(
self.instance_id))
self.log_queue = Queue()
self.log_thread = Thread(target=self.post_worker_fn)
self.log_thread.daemon = True
self.log_thread.start()
return True
else:
self._print_log(
'fail to connect to pavi service, status code: '
'{}, err message: {}'.format(response.status_code,
response.reason),
level=logging.ERROR)
return False
def post_worker_fn(self, max_retry=3, queue_timeout=1, req_timeout=3):
while True:
try:
log = self.log_queue.get(timeout=queue_timeout)
except Empty:
time.sleep(1)
except Exception as ex:
self._print_log(
'fail to get logs from queue: {}'.format(ex),
level=logging.ERROR)
else:
retry = 0
while retry < max_retry:
try:
response = requests.post(
self.url, json=log, timeout=req_timeout)
except Exception as ex:
retry += 1
self._print_log(
'error when posting logs to pavi: {}'.format(ex),
level=logging.ERROR)
else:
status_code = response.status_code
if status_code == 200:
break
else:
self._print_log(
'unexpected status code: {}, err msg: {}'.
format(status_code, response.reason),
level=logging.ERROR)
retry += 1
if retry == max_retry:
self._print_log(
'fail to send logs of iteration {}'.format(
log['iter_num']),
level=logging.ERROR)
def log(self, phase, iter, outputs):
if self.log_queue is not None:
logs = {
'time': str(datetime.now()),
'instance_id': self.instance_id,
'flow_id': phase,
'iter_num': iter,
'outputs': outputs,
'msg': ''
}
self.log_queue.put(logs)
class PaviLoggerHook(LoggerHook):
def __init__(self,
url,
username=None,
password=None,
instance_id=None,
config_file=None,
interval=10,
ignore_last=True,
reset_flag=True):
self.pavi = PaviClient(url, username, password, instance_id)
self.config_file = config_file
super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag)
def before_run(self, runner):
super(PaviLoggerHook, self).before_run(runner)
self.connect(runner)
@master_only
def connect(self, runner, timeout=5):
cfg_info = dict()
if self.config_file is not None:
with open(self.config_file, 'r') as f:
config_text = f.read()
cfg_info.update(
session_file=self.config_file, session_text=config_text)
return self.pavi.connect(runner.model_name, runner.work_dir, cfg_info,
timeout, runner.logger)
@master_only
def log(self, runner):
log_outs = runner.log_buffer.output.copy()
log_outs.pop('time', None)
log_outs.pop('data_time', None)
for k, v in log_outs.items():
if isinstance(v, str):
log_outs.pop(k)
self.pavi.log(runner.mode, runner.iter + 1, log_outs)

4
requirements.txt Normal file
View File

@ -0,0 +1,4 @@
addict
numpy
pyyaml
six

View File

@ -17,6 +17,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmcv
known_third_party = Cython,addict,cv2,mock,numpy,pytest,requests,resnet_cifar,six,torch,torchvision,yaml
known_third_party = Cython,addict,cv2,mock,numpy,pytest,resnet_cifar,six,torch,torchvision,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

View File

@ -25,12 +25,9 @@ def choose_requirement(primary, secondary):
install_requires = [
'numpy>=1.11.1',
'pyyaml',
'six',
'addict',
'requests',
line.strip() for line in open('requirements.txt', 'r').readlines()
]
# If first not installed install second package
CHOOSE_INSTALL_REQUIRES = [('opencv-python-headless>=3', 'opencv-python>=3')]