# mirror of https://github.com/RE-OWOD/RE-OWOD
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
|
||
|
Detectron2 training script with a plain training loop.
|
||
|
|
||
|
This script reads a given config file and runs the training or evaluation.
|
||
|
It is an entry point that is able to train standard models in detectron2.
|
||
|
|
||
|
In order to let one script support training of many models,
|
||
|
this script contains logic that are specific to these built-in models and therefore
|
||
|
may not be suitable for your own project.
|
||
|
For example, your research project perhaps only needs a single "evaluator".
|
||
|
|
||
|
Therefore, we recommend you to use detectron2 as a library and take
|
||
|
this file as an example of how to use the library.
|
||
|
You may want to write your own script with your datasets and other customizations.
|
||
|
|
||
|
Compared to "train_net.py", this script supports fewer default features.
|
||
|
It also includes fewer abstraction, therefore is easier to add custom logic.
|
||
|
"""
|
||
|
|
||
|
import logging
import os
from collections import OrderedDict
import torch
from torch.nn.parallel import DistributedDataParallel

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
from detectron2.config import get_cfg
from detectron2.data import (
    MetadataCatalog,
    build_detection_test_loader,
    build_detection_train_loader,
)
from detectron2.engine import default_argument_parser, default_setup, launch
from detectron2.evaluation import (
    CityscapesInstanceEvaluator,
    CityscapesSemSegEvaluator,
    COCOEvaluator,
    COCOPanopticEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    SemSegEvaluator,
    inference_on_dataset,
    print_csv_format,
)
from detectron2.modeling import build_model
from detectron2.solver import build_lr_scheduler, build_optimizer
from detectron2.utils.events import (
    CommonMetricPrinter,
    EventStorage,
    JSONWriter,
    TensorboardXWriter,
)

logger = logging.getLogger("detectron2")


def get_evaluator(cfg, dataset_name, output_folder=None):
    """
    Create evaluator(s) for a given dataset.
    This uses the special metadata "evaluator_type" associated with each builtin dataset.
    For your own dataset, you can simply create an evaluator manually in your
    script, without having to worry about the hacky if-else logic here.
    """
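    # A minimal sketch of doing that manually (the dataset name "my_dataset_val" is
    # hypothetical and assumed to be registered in COCO format):
    #   evaluator = COCOEvaluator("my_dataset_val", cfg, True, output_folder)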
    if output_folder is None:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
    evaluator_list = []
    evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
    if evaluator_type in ["sem_seg", "coco_panoptic_seg"]:
        evaluator_list.append(
            SemSegEvaluator(
                dataset_name,
                distributed=True,
                num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
                ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
                output_dir=output_folder,
            )
        )
    if evaluator_type in ["coco", "coco_panoptic_seg"]:
        evaluator_list.append(COCOEvaluator(dataset_name, cfg, True, output_folder))
    if evaluator_type == "coco_panoptic_seg":
        evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
    if evaluator_type == "cityscapes_instance":
        assert (
            torch.cuda.device_count() >= comm.get_rank()
        ), "CityscapesEvaluator currently does not work with multiple machines."
        return CityscapesInstanceEvaluator(dataset_name)
    if evaluator_type == "cityscapes_sem_seg":
        assert (
            torch.cuda.device_count() >= comm.get_rank()
        ), "CityscapesEvaluator currently does not work with multiple machines."
        return CityscapesSemSegEvaluator(dataset_name)
    if evaluator_type == "pascal_voc":
        return PascalVOCDetectionEvaluator(dataset_name)
    if evaluator_type == "lvis":
        return LVISEvaluator(dataset_name, cfg, True, output_folder)
    if len(evaluator_list) == 0:
        raise NotImplementedError(
            "no Evaluator for the dataset {} with the type {}".format(dataset_name, evaluator_type)
        )
    if len(evaluator_list) == 1:
        return evaluator_list[0]
    return DatasetEvaluators(evaluator_list)


def do_test(cfg, model):
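    # Run the model on every dataset in cfg.DATASETS.TEST and collect per-dataset
    # metrics; when there is a single test dataset, its metrics dict is returned
    # directly instead of a one-entry mapping.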
    results = OrderedDict()
    for dataset_name in cfg.DATASETS.TEST:
        data_loader = build_detection_test_loader(cfg, dataset_name)
        evaluator = get_evaluator(
            cfg, dataset_name, os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
        )
        results_i = inference_on_dataset(model, data_loader, evaluator)
        results[dataset_name] = results_i
        if comm.is_main_process():
            logger.info("Evaluation results for {} in csv format:".format(dataset_name))
            print_csv_format(results_i)
    if len(results) == 1:
        results = list(results.values())[0]
    return results


def do_train(cfg, model, resume=False):
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = DetectionCheckpointer(
        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
    )
    start_iter = (
        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
    )
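    # When resume=True and a checkpoint exists in cfg.OUTPUT_DIR, resume_or_load()
    # above also restores the optimizer/scheduler state and the saved iteration, so
    # training continues where it stopped; otherwise cfg.MODEL.WEIGHTS is loaded and
    # start_iter is 0.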
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
    )

    writers = (
        [
            CommonMetricPrinter(max_iter),
            JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(cfg.OUTPUT_DIR),
        ]
        if comm.is_main_process()
        else []
    )

    # compared to "train_net.py", we do not support accurate timing and
    # precise BN here, because they are not trivial to implement in a small training loop
    data_loader = build_detection_train_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            storage.iter = iteration

            loss_dict = model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

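            # comm.reduce_dict() averages each loss across workers so the logged
            # values are comparable regardless of the number of GPUs; the backward
            # pass below still uses this worker's own (unreduced) `losses`.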
            loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
            scheduler.step()

            if (
                cfg.TEST.EVAL_PERIOD > 0
                and (iteration + 1) % cfg.TEST.EVAL_PERIOD == 0
                and iteration != max_iter - 1
            ):
                do_test(cfg, model)
                # Compared to "train_net.py", the test results are not dumped to EventStorage
                comm.synchronize()

            if iteration - start_iter > 5 and (
                (iteration + 1) % 20 == 0 or iteration == max_iter - 1
            ):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)


def setup(args):
    """
    Create the config and perform basic setup.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(
        cfg, args
    )  # if you don't like any of the default setup, write your own setup code
    return cfg
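
# Config values can also be overridden from the command line via the trailing
# key-value "opts" (consumed by cfg.merge_from_list above), e.g. by appending
# "SOLVER.MAX_ITER 1000" to the command used to launch this script.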


def main(args):
    cfg = setup(args)

    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))
    if args.eval_only:
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        return do_test(cfg, model)

    distributed = comm.get_world_size() > 1
    if distributed:
        model = DistributedDataParallel(
            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
        )

    do_train(cfg, model, resume=args.resume)
    return do_test(cfg, model)


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
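
# Example invocations (a sketch; the filename and config path are hypothetical, the
# flags come from detectron2's default_argument_parser):
#   Train on one GPU:
#     python plain_train_net.py --config-file configs/my_config.yaml --num-gpus 1
#   Evaluate a trained model:
#     python plain_train_net.py --config-file configs/my_config.yaml --eval-only MODEL.WEIGHTS /path/to/model.pth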