mirror of https://github.com/FoundationVision/GLEE
132 lines
4.3 KiB
Python
132 lines
4.3 KiB
Python
#!/usr/bin/env python
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
"""
|
|
Training script using the new "LazyConfig" python config files.
|
|
|
|
This script reads a given python config file and runs the training or evaluation.
|
|
It can be used to train any models or dataset as long as they can be
|
|
instantiated by the recursive construction defined in the given config file.
|
|
|
|
Besides lazy construction of models, dataloader, etc., this script expects a
|
|
few common configuration parameters currently defined in "configs/common/train.py".
|
|
To add more complicated training logic, you can easily add other configs
|
|
in the config file and implement a new train_net.py to handle them.
|
|
"""
|
|
import logging
|
|
|
|
from detectron2.checkpoint import DetectionCheckpointer
|
|
from detectron2.config import LazyConfig, instantiate
|
|
from detectron2.engine import (
|
|
AMPTrainer,
|
|
SimpleTrainer,
|
|
default_argument_parser,
|
|
default_setup,
|
|
default_writers,
|
|
hooks,
|
|
launch,
|
|
)
|
|
from detectron2.engine.defaults import create_ddp_model
|
|
from detectron2.evaluation import inference_on_dataset, print_csv_format
|
|
from detectron2.utils import comm
|
|
|
|
logger = logging.getLogger("detectron2")
|
|
|
|
|
|
def do_test(cfg, model):
    """
    Evaluate ``model`` on the test dataloader when an evaluator is configured.

    Args:
        cfg: config object; ``cfg.dataloader`` is checked for an "evaluator"
            entry, and ``cfg.dataloader.test`` / ``cfg.dataloader.evaluator``
            are instantiated for the evaluation run.
        model: the model passed through to ``inference_on_dataset``.

    Returns:
        Whatever ``inference_on_dataset`` returns, or ``None`` when
        ``cfg.dataloader`` has no "evaluator" entry (nothing is run then).
    """
    if "evaluator" not in cfg.dataloader:
        # No evaluator configured: skip evaluation entirely.
        return None

    test_loader = instantiate(cfg.dataloader.test)
    evaluator = instantiate(cfg.dataloader.evaluator)
    results = inference_on_dataset(model, test_loader, evaluator)
    print_csv_format(results)
    return results
|
|
|
|
|
|
def do_train(args, cfg):
    """
    Build all training components from the config and run the training loop.

    Args:
        args: parsed command-line arguments; only ``args.resume`` is read here.
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.{train,test}: instantiate to dataloaders
            dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to a fvcore scheduler
            train: other misc config defined in `configs/common/train.py`, including:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
    """
    net = instantiate(cfg.model)
    log = logging.getLogger("detectron2")
    log.info("Model:\n{}".format(net))
    net.to(cfg.train.device)

    # The optimizer config needs the live model before it can be instantiated.
    cfg.optimizer.params.model = net
    optimizer = instantiate(cfg.optimizer)

    data_loader = instantiate(cfg.dataloader.train)

    ddp_net = create_ddp_model(net, **cfg.train.ddp)

    if cfg.train.amp.enabled:
        trainer_cls = AMPTrainer
    else:
        trainer_cls = SimpleTrainer
    trainer = trainer_cls(ddp_net, data_loader, optimizer)

    checkpointer = DetectionCheckpointer(
        ddp_net,
        cfg.train.output_dir,
        trainer=trainer,
    )

    def run_eval():
        return do_test(cfg, ddp_net)

    # Hooks are assembled in registration order. Checkpointing and metric
    # writing are only created on the main process; other processes get a
    # ``None`` placeholder in the same position.
    training_hooks = [
        hooks.IterationTimer(),
        hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
    ]
    if comm.is_main_process():
        training_hooks.append(hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer))
    else:
        training_hooks.append(None)
    training_hooks.append(hooks.EvalHook(cfg.train.eval_period, run_eval))
    if comm.is_main_process():
        training_hooks.append(
            hooks.PeriodicWriter(
                default_writers(cfg.train.output_dir, cfg.train.max_iter),
                period=cfg.train.log_period,
            )
        )
    else:
        training_hooks.append(None)
    trainer.register_hooks(training_hooks)

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
    # The checkpoint stores the training iteration that just finished, thus a
    # resumed run starts at the next iteration; a fresh run starts at 0.
    resuming = args.resume and checkpointer.has_checkpoint()
    start_iter = trainer.iter + 1 if resuming else 0
    trainer.train(start_iter, cfg.train.max_iter)
|
|
|
|
|
|
def main(args):
    """
    Load the lazy config, apply command-line overrides, then train or evaluate.

    Args:
        args: parsed command-line arguments. Reads ``config_file``, ``opts``
            and ``eval_only`` here; the whole object is also passed on to
            ``default_setup`` and ``do_train``.
    """
    base_cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(base_cfg, args.opts)
    default_setup(cfg, args)

    if not args.eval_only:
        do_train(args, cfg)
        return

    # Evaluation only: build the model, load the initial checkpoint, and
    # print the evaluation result.
    eval_model = instantiate(cfg.model)
    eval_model.to(cfg.train.device)
    eval_model = create_ddp_model(eval_model)
    DetectionCheckpointer(eval_model).load(cfg.train.init_checkpoint)
    print(do_test(cfg, eval_model))
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, then hand `main` to
    # `launch` together with the GPU / machine settings from the arguments.
    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),  # positional arguments forwarded to `main`
    )
|