mirror of https://github.com/open-mmlab/mmocr.git
Implement get_root_logger and train_detector (#4)
parent
47f5906f0a
commit
393ed7fc5a
|
@ -1,3 +1,4 @@
|
|||
from .inference import model_inference
|
||||
from .train import train_detector
|
||||
|
||||
__all__ = ['model_inference']
|
||||
__all__ = ['model_inference', 'train_detector']
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
import warnings
|
||||
|
||||
import torch
|
||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
|
||||
Fp16OptimizerHook, OptimizerHook, build_optimizer,
|
||||
build_runner)
|
||||
from mmcv.utils import build_from_cfg
|
||||
|
||||
from mmdet.core import DistEvalHook, EvalHook
|
||||
from mmdet.datasets import (build_dataloader, build_dataset,
|
||||
replace_ImageToTensor)
|
||||
from mmocr.utils import get_root_logger
|
||||
|
||||
|
||||
def train_detector(model,
|
||||
dataset,
|
||||
cfg,
|
||||
distributed=False,
|
||||
validate=False,
|
||||
timestamp=None,
|
||||
meta=None):
|
||||
logger = get_root_logger(cfg.log_level)
|
||||
|
||||
# prepare data loaders
|
||||
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
|
||||
if 'imgs_per_gpu' in cfg.data:
|
||||
logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
|
||||
'Please use "samples_per_gpu" instead')
|
||||
if 'samples_per_gpu' in cfg.data:
|
||||
logger.warning(
|
||||
f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
|
||||
f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
|
||||
f'={cfg.data.imgs_per_gpu} is used in this experiments')
|
||||
else:
|
||||
logger.warning(
|
||||
'Automatically set "samples_per_gpu"="imgs_per_gpu"='
|
||||
f'{cfg.data.imgs_per_gpu} in this experiments')
|
||||
cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
|
||||
|
||||
data_loaders = [
|
||||
build_dataloader(
|
||||
ds,
|
||||
cfg.data.samples_per_gpu,
|
||||
cfg.data.workers_per_gpu,
|
||||
# cfg.gpus will be ignored if distributed
|
||||
len(cfg.gpu_ids),
|
||||
dist=distributed,
|
||||
seed=cfg.seed) for ds in dataset
|
||||
]
|
||||
|
||||
# put model on gpus
|
||||
if distributed:
|
||||
find_unused_parameters = cfg.get('find_unused_parameters', False)
|
||||
# Sets the `find_unused_parameters` parameter in
|
||||
# torch.nn.parallel.DistributedDataParallel
|
||||
model = MMDistributedDataParallel(
|
||||
model.cuda(),
|
||||
device_ids=[torch.cuda.current_device()],
|
||||
broadcast_buffers=False,
|
||||
find_unused_parameters=find_unused_parameters)
|
||||
else:
|
||||
model = MMDataParallel(
|
||||
model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
|
||||
|
||||
# build runner
|
||||
optimizer = build_optimizer(model, cfg.optimizer)
|
||||
|
||||
if 'runner' not in cfg:
|
||||
cfg.runner = {
|
||||
'type': 'EpochBasedRunner',
|
||||
'max_epochs': cfg.total_epochs
|
||||
}
|
||||
warnings.warn(
|
||||
'config is now expected to have a `runner` section, '
|
||||
'please set `runner` in your config.', UserWarning)
|
||||
else:
|
||||
if 'total_epochs' in cfg:
|
||||
assert cfg.total_epochs == cfg.runner.max_epochs
|
||||
|
||||
runner = build_runner(
|
||||
cfg.runner,
|
||||
default_args=dict(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
work_dir=cfg.work_dir,
|
||||
logger=logger,
|
||||
meta=meta))
|
||||
|
||||
# an ugly workaround to make .log and .log.json filenames the same
|
||||
runner.timestamp = timestamp
|
||||
|
||||
# fp16 setting
|
||||
fp16_cfg = cfg.get('fp16', None)
|
||||
if fp16_cfg is not None:
|
||||
optimizer_config = Fp16OptimizerHook(
|
||||
**cfg.optimizer_config, **fp16_cfg, distributed=distributed)
|
||||
elif distributed and 'type' not in cfg.optimizer_config:
|
||||
optimizer_config = OptimizerHook(**cfg.optimizer_config)
|
||||
else:
|
||||
optimizer_config = cfg.optimizer_config
|
||||
|
||||
# register hooks
|
||||
runner.register_training_hooks(cfg.lr_config, optimizer_config,
|
||||
cfg.checkpoint_config, cfg.log_config,
|
||||
cfg.get('momentum_config', None))
|
||||
if distributed:
|
||||
if isinstance(runner, EpochBasedRunner):
|
||||
runner.register_hook(DistSamplerSeedHook())
|
||||
|
||||
# register eval hooks
|
||||
if validate:
|
||||
# Support batch_size > 1 in validation
|
||||
val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
|
||||
if val_samples_per_gpu > 1:
|
||||
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
|
||||
cfg.data.val.pipeline = replace_ImageToTensor(
|
||||
cfg.data.val.pipeline)
|
||||
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
|
||||
val_dataloader = build_dataloader(
|
||||
val_dataset,
|
||||
samples_per_gpu=val_samples_per_gpu,
|
||||
workers_per_gpu=cfg.data.workers_per_gpu,
|
||||
dist=distributed,
|
||||
shuffle=False)
|
||||
eval_cfg = cfg.get('evaluation', {})
|
||||
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
|
||||
eval_hook = DistEvalHook if distributed else EvalHook
|
||||
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
|
||||
|
||||
# user-defined hooks
|
||||
if cfg.get('custom_hooks', None):
|
||||
custom_hooks = cfg.custom_hooks
|
||||
assert isinstance(custom_hooks, list), \
|
||||
f'custom_hooks expect list type, but got {type(custom_hooks)}'
|
||||
for hook_cfg in cfg.custom_hooks:
|
||||
assert isinstance(hook_cfg, dict), \
|
||||
'Each item in custom_hooks expects dict type, but got ' \
|
||||
f'{type(hook_cfg)}'
|
||||
hook_cfg = hook_cfg.copy()
|
||||
priority = hook_cfg.pop('priority', 'NORMAL')
|
||||
hook = build_from_cfg(hook_cfg, HOOKS)
|
||||
runner.register_hook(hook, priority=priority)
|
||||
|
||||
if cfg.resume_from:
|
||||
runner.resume(cfg.resume_from)
|
||||
elif cfg.load_from:
|
||||
runner.load_checkpoint(cfg.load_from)
|
||||
runner.run(data_loaders, cfg.workflow)
|
|
@ -8,7 +8,7 @@ from mmcv.runner import load_checkpoint
|
|||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmdet.models.builder import BACKBONES
|
||||
from mmdet.utils import get_root_logger
|
||||
from mmocr.utils import get_root_logger
|
||||
|
||||
|
||||
class UpConvBlock(nn.Module):
|
||||
|
|
|
@ -9,8 +9,8 @@ import torch.nn as nn
|
|||
from mmcv.runner import auto_fp16
|
||||
from mmcv.utils import print_log
|
||||
|
||||
from mmdet.utils import get_root_logger
|
||||
from mmocr.core import imshow_text_label
|
||||
from mmocr.utils import get_root_logger
|
||||
|
||||
|
||||
class BaseRecognizer(nn.Module, metaclass=ABCMeta):
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from mmcv.utils import Registry, build_from_cfg
|
||||
|
||||
from mmdet.utils import get_root_logger
|
||||
from .check_argument import (equal_len, is_2dlist, is_3dlist, is_ndarray_list,
|
||||
is_none_or_type, is_type_list, valid_boundary)
|
||||
from .collect_env import collect_env
|
||||
from .img_util import drop_orientation
|
||||
from .lmdb_util import lmdb_converter
|
||||
from .logger import get_root_logger
|
||||
|
||||
__all__ = [
|
||||
'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
import logging
|
||||
|
||||
from mmcv.utils import get_logger
|
||||
|
||||
|
||||
def get_root_logger(log_file=None, log_level=logging.INFO):
|
||||
"""Use `get_logger` method in mmcv to get the root logger.
|
||||
|
||||
The logger will be initialized if it has not been initialized. By default a
|
||||
StreamHandler will be added. If `log_file` is specified, a FileHandler will
|
||||
also be added. The name of the root logger is the top-level package name,
|
||||
e.g., "mmpose".
|
||||
|
||||
Args:
|
||||
log_file (str | None): The log filename. If specified, a FileHandler
|
||||
will be added to the root logger.
|
||||
log_level (int): The root logger level. Note that only the process of
|
||||
rank 0 is affected, while other processes will set the level to
|
||||
"Error" and be silent most of the time.
|
||||
|
||||
Returns:
|
||||
logging.Logger: The root logger.
|
||||
"""
|
||||
return get_logger(__name__.split('.')[0], log_file, log_level)
|
|
@ -8,11 +8,11 @@ import torch
|
|||
from mmcv.utils import ProgressBar
|
||||
|
||||
from mmdet.apis import init_detector
|
||||
from mmdet.utils import get_root_logger
|
||||
from mmocr.apis import model_inference
|
||||
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
|
||||
from mmocr.datasets import build_dataset # noqa: F401
|
||||
from mmocr.models import build_detector # noqa: F401
|
||||
from mmocr.utils import get_root_logger
|
||||
|
||||
|
||||
def save_results(img_paths, pred_labels, gt_labels, res_dir):
|
||||
|
|
|
@ -8,14 +8,14 @@ import warnings
|
|||
import mmcv
|
||||
import torch
|
||||
from mmcv import Config, DictAction
|
||||
from mmcv.runner import get_dist_info, init_dist
|
||||
from mmcv.runner import get_dist_info, init_dist, set_random_seed
|
||||
from mmcv.utils import get_git_hash
|
||||
|
||||
from mmdet import __version__
|
||||
from mmdet.apis import set_random_seed, train_detector
|
||||
from mmdet.utils import collect_env, get_root_logger
|
||||
from mmocr import __version__
|
||||
from mmocr.apis import train_detector
|
||||
from mmocr.datasets import build_dataset
|
||||
from mmocr.models import build_detector
|
||||
from mmocr.utils import collect_env, get_root_logger
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -187,7 +187,7 @@ def main():
|
|||
# save mmdet version, config file content and class names in
|
||||
# checkpoints as meta data
|
||||
cfg.checkpoint_config.meta = dict(
|
||||
mmdet_version=__version__ + get_git_hash()[:7],
|
||||
mmocr_version=__version__ + get_git_hash()[:7],
|
||||
CLASSES=datasets[0].CLASSES)
|
||||
# add an attribute for visualization convenience
|
||||
model.CLASSES = datasets[0].CLASSES
|
||||
|
|
Loading…
Reference in New Issue