OWOD/detectron2/engine/defaults.py

579 lines
22 KiB
Python

# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
This file contains components with some default boilerplate logic user may need
in training / testing. They will not work for everyone, but many users may find them useful.
The behavior of functions/classes in this file is subject to change,
since they are meant to represent the "common default behavior" people need in their projects.
"""
import argparse
import logging
import os
import sys
from collections import OrderedDict
import torch
from fvcore.common.file_io import PathManager
from fvcore.nn.precise_bn import get_bn_modules
from torch.nn.parallel import DistributedDataParallel
import detectron2.data.transforms as T
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.data import (
MetadataCatalog,
build_detection_test_loader,
build_detection_train_loader,
)
from detectron2.evaluation import (
DatasetEvaluator,
inference_on_dataset,
print_csv_format,
verify_results,
)
from detectron2.modeling import build_model
from detectron2.solver import build_lr_scheduler, build_optimizer
from detectron2.utils import comm
from detectron2.utils.collect_env import collect_env_info
from detectron2.utils.env import seed_all_rng
from detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
from detectron2.utils.logger import setup_logger
from . import hooks
from .train_loop import SimpleTrainer
__all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
def default_argument_parser(epilog=None):
"""
Create a parser with some common arguments used by detectron2 users.
Args:
epilog (str): epilog passed to ArgumentParser describing the usage.
Returns:
argparse.ArgumentParser:
"""
parser = argparse.ArgumentParser(
epilog=epilog
or f"""
Examples:
Run on single machine:
$ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth
Run on multiple machines:
(machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
(machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
parser.add_argument(
"--resume",
action="store_true",
help="whether to attempt to resume from the checkpoint directory",
)
parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
parser.add_argument(
"--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
)
# PyTorch still may leave orphan processes in multi-gpu training.
# Therefore we use a deterministic way to obtain port,
# so that users are aware of orphan processes by seeing the port occupied.
port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
parser.add_argument(
"--dist-url",
default="tcp://127.0.0.1:{}".format(port),
help="initialization URL for pytorch distributed backend. See "
"https://pytorch.org/docs/stable/distributed.html for details.",
)
parser.add_argument(
"opts",
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER,
)
return parser
def default_setup(cfg, args):
"""
Perform some basic common setups at the beginning of a job, including:
1. Set up the detectron2 logger
2. Log basic information about environment, cmdline arguments, and config
3. Backup the config to the output directory
Args:
cfg (CfgNode): the full config to be used
args (argparse.NameSpace): the command line arguments to be logged
"""
output_dir = cfg.OUTPUT_DIR
if comm.is_main_process() and output_dir:
PathManager.mkdirs(output_dir)
rank = comm.get_rank()
setup_logger(output_dir, distributed_rank=rank, name="fvcore")
logger = setup_logger(output_dir, distributed_rank=rank)
logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
logger.info("Environment info:\n" + collect_env_info())
logger.info("Command line arguments: " + str(args))
if hasattr(args, "config_file") and args.config_file != "":
logger.info(
"Contents of args.config_file={}:\n{}".format(
args.config_file, PathManager.open(args.config_file, "r").read()
)
)
logger.info("Running with full config:\n{}".format(cfg))
if comm.is_main_process() and output_dir:
# Note: some of our scripts may expect the existence of
# config.yaml in output directory
path = os.path.join(output_dir, "config.yaml")
with PathManager.open(path, "w") as f:
f.write(cfg.dump())
logger.info("Full config saved to {}".format(path))
# make sure each worker has a different, yet deterministic seed if specified
seed_all_rng(None if cfg.SEED < 0 else cfg.SEED + rank)
# cudnn benchmark has large overhead. It shouldn't be used considering the small size of
# typical validation set.
if not (hasattr(args, "eval_only") and args.eval_only):
torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
class DefaultPredictor:
"""
Create a simple end-to-end predictor with the given config that runs on
single device for a single input image.
Compared to using the model directly, this class does the following additions:
1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
4. Take one input image and produce a single output, instead of a batch.
If you'd like to do anything more fancy, please refer to its source code
as examples to build and use the model manually.
Attributes:
metadata (Metadata): the metadata of the underlying dataset, obtained from
cfg.DATASETS.TEST.
Examples:
::
pred = DefaultPredictor(cfg)
inputs = cv2.imread("input.jpg")
outputs = pred(inputs)
"""
def __init__(self, cfg):
self.cfg = cfg.clone() # cfg can be modified by model
self.model = build_model(self.cfg)
self.model.eval()
if len(cfg.DATASETS.TEST):
self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
checkpointer = DetectionCheckpointer(self.model)
checkpointer.load(cfg.MODEL.WEIGHTS)
self.aug = T.ResizeShortestEdge(
[cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
)
self.input_format = cfg.INPUT.FORMAT
assert self.input_format in ["RGB", "BGR"], self.input_format
def __call__(self, original_image):
"""
Args:
original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
Returns:
predictions (dict):
the output of the model for one image only.
See :doc:`/tutorials/models` for details about the format.
"""
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
# Apply pre-processing to image.
if self.input_format == "RGB":
# whether the model expects BGR inputs or RGB
original_image = original_image[:, :, ::-1]
height, width = original_image.shape[:2]
image = self.aug.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
inputs = {"image": image, "height": height, "width": width}
predictions = self.model([inputs])[0]
return predictions
class DefaultTrainer(SimpleTrainer):
"""
A trainer with default training logic.
It is a subclass of :class:`SimpleTrainer` and instantiates everything needed from the
config. It does the following:
1. Create model, optimizer, scheduler, dataloader from the given config.
2. Load a checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
`resume_or_load` is called.
3. Register a few common hooks defined by the config.
It is created to simplify the **standard model training workflow** and reduce code boilerplate
for users who only need the standard training workflow, with standard features.
It means this class makes *many assumptions* about your training logic that
may easily become invalid in a new research. In fact, any assumptions beyond those made in the
:class:`SimpleTrainer` are too much for research.
The code of this class has been annotated about restrictive assumptions it makes.
When they do not work for you, you're encouraged to:
1. Overwrite methods of this class, OR:
2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
nothing else. You can then add your own hooks if needed. OR:
3. Write your own training loop similar to `tools/plain_train_net.py`.
See the :doc:`/tutorials/training` tutorials for more details.
Note that the behavior of this class, like other functions/classes in
this file, is not stable, since it is meant to represent the "common default behavior".
It is only guaranteed to work well with the standard models and training workflow in detectron2.
To obtain more stable behavior, write your own training logic with other public APIs.
Examples:
::
trainer = DefaultTrainer(cfg)
trainer.resume_or_load() # load last checkpoint or MODEL.WEIGHTS
trainer.train()
Attributes:
scheduler:
checkpointer (DetectionCheckpointer):
cfg (CfgNode):
"""
def __init__(self, cfg):
"""
Args:
cfg (CfgNode):
"""
logger = logging.getLogger("detectron2")
if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
setup_logger()
cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
# Assume these objects must be constructed in this order.
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
data_loader = self.build_train_loader(cfg)
# For training, wrap with DDP. But don't need this for inference.
if comm.get_world_size() > 1:
model = DistributedDataParallel(
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
)
super().__init__(model, data_loader, optimizer)
self.scheduler = self.build_lr_scheduler(cfg, optimizer)
# Assume no other objects need to be checkpointed.
# We can later make it checkpoint the stateful hooks
self.checkpointer = DetectionCheckpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
cfg.OUTPUT_DIR,
optimizer=optimizer,
scheduler=self.scheduler,
)
self.start_iter = 0
self.max_iter = cfg.SOLVER.MAX_ITER
self.cfg = cfg
self.register_hooks(self.build_hooks())
def resume_or_load(self, resume=True):
"""
If `resume==True`, and last checkpoint exists, resume from it, load all checkpointables
(eg. optimizer and scheduler) and update iteration counter from it. ``cfg.MODEL.WEIGHTS``
will not be used.
Otherwise, load the model specified by the config (skip all checkpointables) and start from
the first iteration.
Args:
resume (bool): whether to do resume or not
"""
checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
if resume and self.checkpointer.has_checkpoint():
self.start_iter = checkpoint.get("iteration", -1) + 1
# The checkpoint stores the training iteration that just finished, thus we start
# at the next iteration (or iter zero if there's no checkpoint).
def build_hooks(self):
"""
Build a list of default hooks, including timing, evaluation,
checkpointing, lr scheduling, precise BN, writing events.
Returns:
list[HookBase]:
"""
cfg = self.cfg.clone()
cfg.defrost()
cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
ret = [
hooks.IterationTimer(),
hooks.LRScheduler(self.optimizer, self.scheduler),
hooks.PreciseBN(
# Run at the same freq as (but before) evaluation.
cfg.TEST.EVAL_PERIOD,
self.model,
# Build a new data loader to not affect training
self.build_train_loader(cfg),
cfg.TEST.PRECISE_BN.NUM_ITER,
)
if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
else None,
]
# Do PreciseBN before checkpointer, because it updates the model and need to
# be saved by checkpointer.
# This is not always the best: if checkpointing has a different frequency,
# some checkpoints may have more precise statistics than others.
if comm.is_main_process():
ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
def test_and_save_results():
self._last_eval_results = self.test(self.cfg, self.model)
return self._last_eval_results
# Do evaluation after checkpointer, because then if it fails,
# we can use the saved checkpoint to debug.
ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
if comm.is_main_process():
# run writers in the end, so that evaluation metrics are written
ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
return ret
def build_writers(self):
"""
Build a list of writers to be used. By default it contains
writers that write metrics to the screen,
a json file, and a tensorboard event file respectively.
If you'd like a different list of writers, you can overwrite it in
your trainer.
Returns:
list[EventWriter]: a list of :class:`EventWriter` objects.
It is now implemented by:
::
return [
CommonMetricPrinter(self.max_iter),
JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
TensorboardXWriter(self.cfg.OUTPUT_DIR),
]
"""
# Here the default print/log frequency of each writer is used.
return [
# It may not always print what you want to see, since it prints "common" metrics only.
CommonMetricPrinter(self.max_iter),
JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
TensorboardXWriter(self.cfg.OUTPUT_DIR),
]
def train(self):
"""
Run training.
Returns:
OrderedDict of results, if evaluation is enabled. Otherwise None.
"""
super().train(self.start_iter, self.max_iter)
if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
assert hasattr(
self, "_last_eval_results"
), "No evaluation results obtained during training!"
verify_results(self.cfg, self._last_eval_results)
return self._last_eval_results
@classmethod
def build_model(cls, cfg):
"""
Returns:
torch.nn.Module:
It now calls :func:`detectron2.modeling.build_model`.
Overwrite it if you'd like a different model.
"""
model = build_model(cfg)
logger = logging.getLogger(__name__)
logger.info("Model:\n{}".format(model))
return model
@classmethod
def build_optimizer(cls, cfg, model):
"""
Returns:
torch.optim.Optimizer:
It now calls :func:`detectron2.solver.build_optimizer`.
Overwrite it if you'd like a different optimizer.
"""
return build_optimizer(cfg, model)
@classmethod
def build_lr_scheduler(cls, cfg, optimizer):
"""
It now calls :func:`detectron2.solver.build_lr_scheduler`.
Overwrite it if you'd like a different scheduler.
"""
return build_lr_scheduler(cfg, optimizer)
@classmethod
def build_train_loader(cls, cfg):
"""
Returns:
iterable
It now calls :func:`detectron2.data.build_detection_train_loader`.
Overwrite it if you'd like a different data loader.
"""
return build_detection_train_loader(cfg)
@classmethod
def build_test_loader(cls, cfg, dataset_name):
"""
Returns:
iterable
It now calls :func:`detectron2.data.build_detection_test_loader`.
Overwrite it if you'd like a different data loader.
"""
return build_detection_test_loader(cfg, dataset_name)
@classmethod
def build_evaluator(cls, cfg, dataset_name):
"""
Returns:
DatasetEvaluator or None
It is not implemented by default.
"""
raise NotImplementedError(
"""
If you want DefaultTrainer to automatically run evaluation,
please implement `build_evaluator()` in subclasses (see train_net.py for example).
Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example).
"""
)
@classmethod
def test(cls, cfg, model, evaluators=None):
"""
Args:
cfg (CfgNode):
model (nn.Module):
evaluators (list[DatasetEvaluator] or None): if None, will call
:meth:`build_evaluator`. Otherwise, must have the same length as
`cfg.DATASETS.TEST`.
Returns:
dict: a dict of result metrics
"""
logger = logging.getLogger(__name__)
if isinstance(evaluators, DatasetEvaluator):
evaluators = [evaluators]
if evaluators is not None:
assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
len(cfg.DATASETS.TEST), len(evaluators)
)
results = OrderedDict()
for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
data_loader = cls.build_test_loader(cfg, dataset_name)
# When evaluators are passed in as arguments,
# implicitly assume that evaluators can be created before data_loader.
if evaluators is not None:
evaluator = evaluators[idx]
else:
try:
evaluator = cls.build_evaluator(cfg, dataset_name)
except NotImplementedError:
logger.warn(
"No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
"or implement its `build_evaluator` method."
)
results[dataset_name] = {}
continue
results_i = inference_on_dataset(model, data_loader, evaluator)
results[dataset_name] = results_i
if comm.is_main_process():
assert isinstance(
results_i, dict
), "Evaluator must return a dict on the main process. Got {} instead.".format(
results_i
)
logger.info("Evaluation results for {} in csv format:".format(dataset_name))
print_csv_format(results_i)
if len(results) == 1:
results = list(results.values())[0]
return results
@staticmethod
def auto_scale_workers(cfg, num_workers: int):
"""
When the config is defined for certain number of workers (according to
``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
workers currently in use, returns a new cfg where the total batch size
is scaled so that the per-GPU batch size stays the same as the
original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.
Other config options are also scaled accordingly:
* training steps and warmup steps are scaled inverse proportionally.
* learning rate are scaled proportionally, following :paper:`ImageNet in 1h`.
It returns the original config if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
Returns:
CfgNode: a new config
"""
old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
if old_world_size == 0 or old_world_size == num_workers:
return cfg
cfg = cfg.clone()
frozen = cfg.is_frozen()
cfg.defrost()
assert (
cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
), "Invalid REFERENCE_WORLD_SIZE in config!"
scale = num_workers / old_world_size
bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers # maintain invariant
logger = logging.getLogger(__name__)
logger.info(
f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
f"max_iter={max_iter}, warmup={warmup_iter}."
)
if frozen:
cfg.freeze()
return cfg