[Enhancement] .dev Python files updated to get better performance and syntax (#2020)
* logger hooks samples updated * [Docs] Details for WandBLoggerHook Added * [Docs] lint test pass * [Enhancement] .dev Python files updated to get better performance and quality * [Docs] Details for WandBLoggerHook Added * [Docs] lint test pass * [Enhancement] .dev Python files updated to get better performance and quality * [Enhancement] lint test passed * [Enhancement] Change Some Line from Previous to Support Python<3.9 * Update .dev/gather_models.py Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>pull/2073/head
parent
ecd1ecb6ba
commit
31395a83bd
|
@ -53,8 +53,7 @@ def parse_args():
|
|||
'-s', '--show', action='store_true', help='show results')
|
||||
parser.add_argument(
|
||||
'-d', '--device', default='cuda:0', help='Device used for inference')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def inference_model(config_name, checkpoint, args, logger=None):
|
||||
|
@ -66,11 +65,10 @@ def inference_model(config_name, checkpoint, args, logger=None):
|
|||
0.5, 0.75, 1.0, 1.25, 1.5, 1.75
|
||||
]
|
||||
cfg.data.test.pipeline[1].flip = True
|
||||
elif logger is None:
|
||||
print(f'{config_name}: unable to start aug test', flush=True)
|
||||
else:
|
||||
if logger is not None:
|
||||
logger.error(f'{config_name}: unable to start aug test')
|
||||
else:
|
||||
print(f'{config_name}: unable to start aug test', flush=True)
|
||||
logger.error(f'{config_name}: unable to start aug test')
|
||||
|
||||
model = init_segmentor(cfg, checkpoint, device=args.device)
|
||||
# test a single image
|
||||
|
|
|
@ -18,12 +18,9 @@ def check_url(url):
|
|||
Returns:
|
||||
int, bool: status code and check flag.
|
||||
"""
|
||||
flag = True
|
||||
r = requests.head(url)
|
||||
status_code = r.status_code
|
||||
if status_code == 403 or status_code == 404:
|
||||
flag = False
|
||||
|
||||
flag = status_code not in [403, 404]
|
||||
return status_code, flag
|
||||
|
||||
|
||||
|
@ -35,8 +32,7 @@ def parse_args():
|
|||
type=str,
|
||||
help='Select the model needed to check')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
|
|
|
@ -62,8 +62,8 @@ if __name__ == '__main__':
|
|||
continue
|
||||
|
||||
# Compare between new benchmark results and previous metrics
|
||||
differential_results = dict()
|
||||
new_metrics = dict()
|
||||
differential_results = {}
|
||||
new_metrics = {}
|
||||
for record_metric_key in previous_metrics:
|
||||
if record_metric_key not in metric['metric']:
|
||||
raise KeyError('record_metric_key not exist, please '
|
||||
|
|
|
@ -72,9 +72,9 @@ if __name__ == '__main__':
|
|||
print(f'log file error: {log_json_path}')
|
||||
continue
|
||||
|
||||
differential_results = dict()
|
||||
old_results = dict()
|
||||
new_results = dict()
|
||||
differential_results = {}
|
||||
old_results = {}
|
||||
new_results = {}
|
||||
for metric_key in model_performance:
|
||||
if metric_key in ['mIoU']:
|
||||
metric = round(model_performance[metric_key] * 100, 2)
|
||||
|
|
|
@ -33,7 +33,7 @@ def process_checkpoint(in_file, out_file):
|
|||
# The hash code calculation and rename command differ on different system
|
||||
# platform.
|
||||
sha = calculate_file_sha256(out_file)
|
||||
final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8])
|
||||
final_file = out_file.rstrip('.pth') + f'-{sha[:8]}.pth'
|
||||
os.rename(out_file, final_file)
|
||||
|
||||
# Remove prefix and suffix
|
||||
|
@ -50,25 +50,23 @@ def get_final_iter(config):
|
|||
|
||||
|
||||
def get_final_results(log_json_path, iter_num):
|
||||
result_dict = dict()
|
||||
result_dict = {}
|
||||
last_iter = 0
|
||||
with open(log_json_path, 'r') as f:
|
||||
for line in f.readlines():
|
||||
for line in f:
|
||||
log_line = json.loads(line)
|
||||
if 'mode' not in log_line.keys():
|
||||
continue
|
||||
|
||||
# When evaluation, the 'iter' of new log json is the evaluation
|
||||
# steps on single gpu.
|
||||
flag1 = ('aAcc' in log_line) or (log_line['mode'] == 'val')
|
||||
flag2 = (last_iter == iter_num - 50) or (last_iter == iter_num)
|
||||
flag1 = 'aAcc' in log_line or log_line['mode'] == 'val'
|
||||
flag2 = last_iter in [iter_num - 50, iter_num]
|
||||
if flag1 and flag2:
|
||||
result_dict.update({
|
||||
key: log_line[key]
|
||||
for key in RESULTS_LUT if key in log_line
|
||||
})
|
||||
return result_dict
|
||||
|
||||
last_iter = log_line['iter']
|
||||
|
||||
|
||||
|
@ -123,7 +121,7 @@ def main():
|
|||
exp_dir = osp.join(work_dir, config_name)
|
||||
# check whether the exps is finished
|
||||
final_iter = get_final_iter(used_config)
|
||||
final_model = 'iter_{}.pth'.format(final_iter)
|
||||
final_model = f'iter_{final_iter}.pth'
|
||||
model_path = osp.join(exp_dir, final_model)
|
||||
|
||||
# skip if the model is still training
|
||||
|
@ -135,7 +133,7 @@ def main():
|
|||
log_json_paths = glob.glob(osp.join(exp_dir, '*.log.json'))
|
||||
log_json_path = log_json_paths[0]
|
||||
model_performance = None
|
||||
for idx, _log_json_path in enumerate(log_json_paths):
|
||||
for _log_json_path in log_json_paths:
|
||||
model_performance = get_final_results(_log_json_path, final_iter)
|
||||
if model_performance is not None:
|
||||
log_json_path = _log_json_path
|
||||
|
@ -161,9 +159,10 @@ def main():
|
|||
model_publish_dir = osp.join(collect_dir, config_name)
|
||||
|
||||
publish_model_path = osp.join(model_publish_dir,
|
||||
config_name + '_' + model['model_time'])
|
||||
f'{config_name}_' + model['model_time'])
|
||||
|
||||
trained_model_path = osp.join(work_dir, config_name,
|
||||
'iter_{}.pth'.format(model['iters']))
|
||||
f'iter_{model["iters"]}.pth')
|
||||
if osp.exists(model_publish_dir):
|
||||
for file in os.listdir(model_publish_dir):
|
||||
if file.endswith('.pth'):
|
||||
|
|
|
@ -20,8 +20,7 @@ def parse_args():
|
|||
default='.dev/benchmark_evaluation.sh',
|
||||
help='path to save model benchmark script')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def process_model_info(model_info, work_dir):
|
||||
|
@ -30,10 +29,9 @@ def process_model_info(model_info, work_dir):
|
|||
job_name = fname
|
||||
checkpoint = model_info['checkpoint'].strip()
|
||||
work_dir = osp.join(work_dir, fname)
|
||||
if not isinstance(model_info['eval'], list):
|
||||
evals = [model_info['eval']]
|
||||
else:
|
||||
evals = model_info['eval']
|
||||
evals = model_info['eval'] if isinstance(model_info['eval'],
|
||||
list) else [model_info['eval']]
|
||||
|
||||
eval = ' '.join(evals)
|
||||
return dict(
|
||||
config=config,
|
||||
|
|
|
@ -69,14 +69,11 @@ def main():
|
|||
port = args.port
|
||||
partition_name = 'PARTITION=$1'
|
||||
|
||||
commands = []
|
||||
commands.append(partition_name)
|
||||
commands.append('\n')
|
||||
commands.append('\n')
|
||||
commands = [partition_name, '\n', '\n']
|
||||
|
||||
with open(args.txt_path, 'r') as f:
|
||||
model_cfgs = f.readlines()
|
||||
for i, cfg in enumerate(model_cfgs):
|
||||
for cfg in model_cfgs:
|
||||
create_train_bash_info(commands, cfg, script_name, '$PARTITION',
|
||||
port)
|
||||
port += 1
|
||||
|
|
|
@ -27,15 +27,11 @@ from utils import load_config
|
|||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='extract info from log.json')
|
||||
parser.add_argument('config_dir')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def has_keyword(name: str, keywords: list):
|
||||
for a_keyword in keywords:
|
||||
if a_keyword in name:
|
||||
return True
|
||||
return False
|
||||
return any(a_keyword in name for a_keyword in keywords)
|
||||
|
||||
|
||||
def main():
|
||||
|
|
|
@ -19,8 +19,7 @@ def parse_args():
|
|||
type=str,
|
||||
default='mmsegmentation/v0.5',
|
||||
help='destination folder')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
|
|
|
@ -221,9 +221,13 @@ log_config = dict( # config to register logger hook
|
|||
hooks=[
|
||||
dict(type='TextLoggerHook', by_epoch=False),
|
||||
dict(type='TensorboardLoggerHook', by_epoch=False),
|
||||
dict(type='MMSegWandbHook', by_epoch=False, init_kwargs={'entity': entity, 'project': project, 'config': cfg_dict}), # The Wandb logger is also supported, It requires `wandb` to be installed.
|
||||
dict(type='MMSegWandbHook', by_epoch=False, # The Wandb logger is also supported, It requires `wandb` to be installed.
|
||||
init_kwargs={'entity': "OpenMMLab", # The entity used to log on Wandb
|
||||
'project': "MMSeg", # Project name in WandB
|
||||
'config': cfg_dict}), # Check https://docs.wandb.ai/ref/python/init for more init arguments.
|
||||
# MMSegWandbHook is mmseg implementation of WandbLoggerHook. ClearMLLoggerHook, DvcliveLoggerHook, MlflowLoggerHook, NeptuneLoggerHook, PaviLoggerHook, SegmindLoggerHook are also supported based on MMCV implementation.
|
||||
])
|
||||
|
||||
dist_params = dict(backend='nccl') # Parameters to setup distributed training, the port can also be set.
|
||||
log_level = 'INFO' # The level of logging.
|
||||
load_from = None # load models as a pre-trained model from a given path. This will not resume training.
|
||||
|
|
|
@ -214,10 +214,13 @@ data = dict(
|
|||
]))
|
||||
log_config = dict( # 注册日志钩 (register logger hook) 的配置文件。
|
||||
interval=50, # 打印日志的间隔
|
||||
hooks=[
|
||||
hooks=[ # 训练期间执行的钩子
|
||||
dict(type='TextLoggerHook', by_epoch=False),
|
||||
dict(type='TensorboardLoggerHook', by_epoch=False),
|
||||
dict(type='MMSegWandbHook', by_epoch=False, init_kwargs={'entity': entity, 'project': project, 'config': cfg_dict}), # 同样支持 Wandb 日志
|
||||
dict(type='MMSegWandbHook', by_epoch=False, # 还支持 Wandb 记录器,它需要安装 `wandb`。
|
||||
init_kwargs={'entity': "OpenMMLab", # 用于登录wandb的实体
|
||||
'project': "mmseg", # WandB中的项目名称
|
||||
'config': cfg_dict}), # 检查 https://docs.wandb.ai/ref/python/init 以获取更多初始化参数
|
||||
])
|
||||
|
||||
dist_params = dict(backend='nccl') # 用于设置分布式训练的参数,端口也同样可被设置。
|
||||
|
|
71
setup.py
71
setup.py
|
@ -47,8 +47,7 @@ def parse_requirements(fname='requirements.txt', with_version=True):
|
|||
if line.startswith('-r '):
|
||||
# Allow specifying requirements in other files
|
||||
target = line.split(' ')[1]
|
||||
for info in parse_require_file(target):
|
||||
yield info
|
||||
yield from parse_require_file(target)
|
||||
else:
|
||||
info = {'line': line}
|
||||
if line.startswith('-e '):
|
||||
|
@ -58,7 +57,6 @@ def parse_requirements(fname='requirements.txt', with_version=True):
|
|||
pat = '(' + '|'.join(['>=', '==', '>']) + ')'
|
||||
parts = re.split(pat, line, maxsplit=1)
|
||||
parts = [p.strip() for p in parts]
|
||||
|
||||
info['package'] = parts[0]
|
||||
if len(parts) > 1:
|
||||
op, rest = parts[1:]
|
||||
|
@ -69,8 +67,8 @@ def parse_requirements(fname='requirements.txt', with_version=True):
|
|||
rest.split(';'))
|
||||
info['platform_deps'] = platform_deps
|
||||
else:
|
||||
version = rest # NOQA
|
||||
info['version'] = (op, version)
|
||||
version = rest
|
||||
info['version'] = op, version
|
||||
yield info
|
||||
|
||||
def parse_require_file(fpath):
|
||||
|
@ -78,22 +76,21 @@ def parse_requirements(fname='requirements.txt', with_version=True):
|
|||
for line in f.readlines():
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
for info in parse_line(line):
|
||||
yield info
|
||||
yield from parse_line(line)
|
||||
|
||||
def gen_packages_items():
|
||||
if exists(require_fpath):
|
||||
for info in parse_require_file(require_fpath):
|
||||
parts = [info['package']]
|
||||
if with_version and 'version' in info:
|
||||
parts.extend(info['version'])
|
||||
if not sys.version.startswith('3.4'):
|
||||
# apparently package_deps are broken in 3.4
|
||||
platform_deps = info.get('platform_deps')
|
||||
if platform_deps is not None:
|
||||
parts.append(';' + platform_deps)
|
||||
item = ''.join(parts)
|
||||
yield item
|
||||
if not exists(require_fpath):
|
||||
return
|
||||
for info in parse_require_file(require_fpath):
|
||||
parts = [info['package']]
|
||||
if with_version and 'version' in info:
|
||||
parts.extend(info['version'])
|
||||
if not sys.version.startswith('3.4'):
|
||||
platform_deps = info.get('platform_deps')
|
||||
if platform_deps is not None:
|
||||
parts.append(f';{platform_deps}')
|
||||
item = ''.join(parts)
|
||||
yield item
|
||||
|
||||
packages = list(gen_packages_items())
|
||||
return packages
|
||||
|
@ -110,35 +107,28 @@ def add_mim_extension():
|
|||
# parse installment mode
|
||||
if 'develop' in sys.argv:
|
||||
# installed by `pip install -e .`
|
||||
if platform.system() == 'Windows':
|
||||
# set `copy` mode here since symlink fails on Windows.
|
||||
mode = 'copy'
|
||||
else:
|
||||
mode = 'symlink'
|
||||
elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv or \
|
||||
platform.system() == 'Windows':
|
||||
# set `copy` mode here since symlink fails on Windows.
|
||||
mode = 'copy' if platform.system() == 'Windows' else 'symlink'
|
||||
elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv or platform.system(
|
||||
) == 'Windows':
|
||||
# installed by `pip install .`
|
||||
# or create source distribution by `python setup.py sdist`
|
||||
# set `copy` mode here since symlink fails with WinError on Windows.
|
||||
mode = 'copy'
|
||||
else:
|
||||
return
|
||||
|
||||
filenames = ['tools', 'configs', 'model-index.yml']
|
||||
repo_path = osp.dirname(__file__)
|
||||
mim_path = osp.join(repo_path, 'mmseg', '.mim')
|
||||
os.makedirs(mim_path, exist_ok=True)
|
||||
|
||||
for filename in filenames:
|
||||
if osp.exists(filename):
|
||||
src_path = osp.join(repo_path, filename)
|
||||
tar_path = osp.join(mim_path, filename)
|
||||
|
||||
if osp.isfile(tar_path) or osp.islink(tar_path):
|
||||
os.remove(tar_path)
|
||||
elif osp.isdir(tar_path):
|
||||
shutil.rmtree(tar_path)
|
||||
|
||||
if mode == 'symlink':
|
||||
src_relpath = osp.relpath(src_path, osp.dirname(tar_path))
|
||||
try:
|
||||
|
@ -149,20 +139,19 @@ def add_mim_extension():
|
|||
# the error happens, the src file will be copied
|
||||
mode = 'copy'
|
||||
warnings.warn(
|
||||
f'Failed to create a symbolic link for {src_relpath}, '
|
||||
f'and it will be copied to {tar_path}')
|
||||
f'Failed to create a symbolic link for {src_relpath},'
|
||||
f' and it will be copied to {tar_path}')
|
||||
|
||||
else:
|
||||
continue
|
||||
|
||||
if mode == 'copy':
|
||||
if osp.isfile(src_path):
|
||||
shutil.copyfile(src_path, tar_path)
|
||||
elif osp.isdir(src_path):
|
||||
shutil.copytree(src_path, tar_path)
|
||||
else:
|
||||
warnings.warn(f'Cannot copy file {src_path}.')
|
||||
else:
|
||||
if mode != 'copy':
|
||||
raise ValueError(f'Invalid mode {mode}')
|
||||
if osp.isfile(src_path):
|
||||
shutil.copyfile(src_path, tar_path)
|
||||
elif osp.isdir(src_path):
|
||||
shutil.copytree(src_path, tar_path)
|
||||
else:
|
||||
warnings.warn(f'Cannot copy file {src_path}.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue