mmfewshot/tools/train.py
Linyiqi 7ac7436d11
dataset refactoring v2 (#10)
* add doc string

* update configs

* update caffe config

* update caffe config

* update config

* rm ignore ann in fewshot

* add difficult option in voc dataset

* add difficult option in voc dataset

* add difficult option in voc dataset

* show datasets

* add FSCE file

* fix FSCE bug

* update config

* update config

* update config

* update config

* update config

* update config

* run script

* update config

* update config

* test aug

* test aug

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* fix comment

* fix comment

* fix comment

* change voc index

* update config

* change voc iou

* change voc iou

* change voc iou

* change voc iou

* Revert "change voc iou"

This reverts commit e5f4ca8bc4701c7c69e900b746eaed14cccf87bb.

* reset voc

* Revert "Merge branch 'lyq-fsce' of https://github.com/linyq17/mmfewshot into lyq-fsce"

This reverts commit 1fde859a1625eb0cb900e4a9baa1125c2f387ae5, reversing
changes made to 75c40006a23f456a56b688a63e66f133d064e01d.

* update weight_decay

* add fsod training code

* fsod training code

fsod training code

fsod training code

fsod training code

fsod training code

fsod training code

* disable group config

* fix empty gt bug

fix empty gt bug

fix empty gt bug

* fsod test code

* add support template init into test code

* add support template init into test code

* fix coco del bug

* fix coco del bug

* update training config

* update test config

* test config

* test config

* test config

* update anchor config

* update anchor config

* update get bbox

* update get bbox

* update training config

* update loss weight

* update loss weight

* update loss weight

* update training config

* update training config

* add repeat dataset into dataset wrapper

* mv base code to fsod

* update config

* update config

* update config

* update docstr of script

* update config of cl branch

* fix few shot config bug

* add docstr

* disable flip

* update weight decay config

* add docstr and update loss weight

* add docstr and update loss weight

* add docstr and update loss weight

* update loss weight

* update loss weight

* update loss weight

* fix arpn bug

* update lr config

* fix test bug

* update config

* update config

* update config

* update config

* update config

* fix support order bug

* fsdetview training and testing code

* update config name

* update config

* update config

* update config

* check data loader

* update config

* update config

* update config

* fix dataloader bug

* update config

* check rank

* check rank

* check rank

* check rank

* check rank

* check rank

* check rank

* check rank

* rm check rank

* dataset refactoring

* add save dataset function

* add doc string

* update config

* update config

* fix dataset builder bug

* update ckpt_surgery script

* update ckpt_surgery script

* rename arpn to attention_rpn

* update tfa config

* update tfa config

* update tfa config

* update tfa config

* update tfa config

* update tfa config

* update script

* update config

* update fsdetview config

* update fsdetview config

* update fsdetview config

* update tfa config

* update tfa config

* update fsdetview config

* run script

* run script

* run script

* fix comments

* fix comments

* fix comments

* fix save dataset bug

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update dataset doc

* fix dataset loading bug

* fix dataset loading bug

* create dataset pr

* create dataset pr

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* fix eval class splits bug

* fix eval class splits bug

* update attention rpn voc config

* update attention rpn voc config

* update attention rpn voc config

* update attention rpn config

* fix voc dataset bug

* fix voc dataset bug

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* update config

* add dataset visualization

* rm unused arg

* rm unused arg

* update config

* add visual dataset

* fix dataset loading bug

* fix dataset loading bug

* update config

* update config

* add visualize dataset code

* voc base training debug

* update tfa voc base config

* update config

* update config

* add multiple training

* update tfa voc lr

* update config

* update config

* update config

* update config

* update config

* update voc metric

* fix voc metric

* add dataset generate code

* update base training

* update save dataset

* update doc string

* create pr

* create pr

* fix dataset loading bug

* fix comments

* save support set for QueryAwareDataset

* fix comments

* fix comments

fix pipeline parameter

* fix comments

* refactoring ann_file

* fix comments

* add config check

* refactoring ann_cfg datasetwrapper

* add doc string

* fix bug

* fix bug

* add dataset name & fix doc str

* fix doc str

* fix doc str

* fix doc str

* fix doc str

* rm model config
2021-07-23 12:58:12 +08:00

225 lines
8.4 KiB
Python

import argparse
import copy
import os
import os.path as osp
import time
import warnings

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash
from mmdet.utils import collect_env, get_root_logger

import mmfewshot  # noqa: F401, F403
from mmfewshot import __version__
from mmfewshot.apis import set_random_seed, train_model
from mmfewshot.builders.dataset_builder import build_dataset
from mmfewshot.builders.model_builder import build_model
from mmfewshot.utils.check_config import check_config


def parse_args():
    parser = argparse.ArgumentParser(description='Train a FewShot model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file (deprecated), '
        'change to --cfg-options instead.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both '
            'specified, --options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options

    return args


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    cfg = check_config(cfg)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        rank, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text
    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds in this priority: CLI > config > distributed default
    if args.seed is not None:
        seed = args.seed
    elif cfg.get('seed', None) is not None:
        seed = cfg.seed
    elif distributed:
        seed = 1234567
        # warnings.warn actually emits the message; constructing a bare
        # Warning(...) object would silently discard it.
        warnings.warn(
            'When using DistributedDataParallel, each rank initializes a '
            'different random seed, which causes different random actions '
            'on each rank. In the few shot setting, novel shots may be '
            'generated by random sampling. If the ranks do not all use the '
            'same seed, each rank will sample different data, which causes '
            f'UNFAIR data usage. Therefore, the seed is set to {seed} by '
            'default.')
    else:
        seed = None

    if seed is not None:
        logger.info(f'Set random seed to {seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(seed, deterministic=args.deterministic)
    meta['seed'] = seed
    meta['exp_name'] = osp.basename(args.config)

    # get fixed parameters
    frozen_parameters = cfg.model.pop('frozen_parameters', None)

    model = build_model(cfg.model, task_type=cfg.task_type)
    model.init_weights()
    # fix parameters by prefix; note that the check below is a substring
    # match, so a prefix also matches anywhere inside a parameter name
    if frozen_parameters is not None:
        for name, param in model.named_parameters():
            for frozen_prefix in frozen_parameters:
                if frozen_prefix in name:
                    param.requires_grad = False

    # If save_dataset is set to True, dataset will be saved into json.
    save_dataset = cfg.data.train.pop('save_dataset', False)

    datasets = [build_dataset(cfg.data.train, task_type=cfg.task_type)]

    if save_dataset:
        save_dataset_path = osp.join(cfg.work_dir,
                                     f'{timestamp}_saved_data.json')
        if cfg.data.train.type == 'RepeatDataset':
            datasets[0].dataset.save_data_infos(save_dataset_path)
        else:
            datasets[0].save_data_infos(save_dataset_path)

    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset, task_type=cfg.task_type))
    if cfg.checkpoint_config is not None:
        # save mmfewshot version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmfewshot_version=__version__ + get_git_hash()[:7],
            CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_model(
        model,
        datasets,
        cfg,
        task_type=cfg.task_type,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()
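
The script resolves `work_dir` and the random seed by precedence: a command-line value wins over the config value, which wins over a fallback. The sketch below isolates that logic so it can be checked without mmcv or a GPU. The helper names `resolve_work_dir` and `resolve_seed` and the constant `DEFAULT_DISTRIBUTED_SEED` are hypothetical and are not part of mmfewshot; only the precedence order and the fallback value 1234567 are taken from `main()` above.

```python
import os.path as osp

# Fallback seed used by main() when training is distributed and neither the
# CLI nor the config provides a seed.
DEFAULT_DISTRIBUTED_SEED = 1234567


def resolve_work_dir(cli_work_dir, cfg_work_dir, config_path):
    """Hypothetical helper: CLI value > config value > config file name."""
    if cli_work_dir is not None:
        return cli_work_dir
    if cfg_work_dir is not None:
        return cfg_work_dir
    # e.g. 'configs/foo.py' falls back to './work_dirs/foo'
    config_name = osp.splitext(osp.basename(config_path))[0]
    return osp.join('./work_dirs', config_name)


def resolve_seed(cli_seed, cfg_seed, distributed):
    """Hypothetical helper: CLI seed > config seed > distributed default."""
    if cli_seed is not None:
        return cli_seed
    if cfg_seed is not None:
        return cfg_seed
    # Distributed runs get a shared default so that every rank samples the
    # same few-shot data; single-process runs stay unseeded.
    return DEFAULT_DISTRIBUTED_SEED if distributed else None


if __name__ == '__main__':
    print(resolve_work_dir(None, None, 'configs/foo.py'))
    print(resolve_seed(None, None, distributed=True))
```

Passing the CLI and config values in as plain arguments keeps the helpers independent of `argparse` and `Config`, which is the only reason they are written as separate functions here.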