mirror of https://github.com/RE-OWOD/RE-OWOD
119 lines
4.3 KiB
Python
119 lines
4.3 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
|
|
import logging
|
|
import os
|
|
from collections import OrderedDict
|
|
|
|
from detectron2.checkpoint import DetectionCheckpointer
|
|
from detectron2.config import CfgNode
|
|
from detectron2.engine import DefaultTrainer
|
|
from detectron2.evaluation import COCOEvaluator, DatasetEvaluators
|
|
from detectron2.utils.events import EventWriter, get_event_storage
|
|
|
|
from densepose import (
|
|
DensePoseCOCOEvaluator,
|
|
DensePoseDatasetMapperTTA,
|
|
DensePoseGeneralizedRCNNWithTTA,
|
|
load_from_cfg,
|
|
)
|
|
from densepose.data import (
|
|
DatasetMapper,
|
|
build_combined_loader,
|
|
build_detection_test_loader,
|
|
build_detection_train_loader,
|
|
build_inference_based_loaders,
|
|
has_inference_based_loaders,
|
|
)
|
|
|
|
|
|
class SampleCountingLoader:
    """Pass batches through from ``loader`` while recording per-dataset instance counts.

    Every element of a batch is read as a mapping with a ``"dataset"`` key (the
    dataset name) and an ``"instances"`` value that supports ``len()``. For each
    batch, the summed instance count of every dataset present in it is stored in
    the current event storage as scalar ``batch/<dataset_name>``. The batch is
    then yielded unchanged.
    """

    def __init__(self, loader):
        """
        Args:
            loader: iterable of batches to wrap (in this file, the combined
                training loader built by ``Trainer.build_train_loader``).
        """
        self.loader = loader

    def __iter__(self):
        """Yield the batches of ``self.loader``, logging instance counts for each."""
        # This method is a generator, so nothing below runs until the first
        # next() call. get_event_storage() presumably needs an active
        # EventStorage context at that point (detectron2 trainer loop).
        it = iter(self.loader)
        storage = get_event_storage()
        # A plain for-loop ends only when the wrapped iterator is exhausted.
        # Unlike a manual next()/except StopIteration around the whole body,
        # it does not silently swallow a StopIteration raised elsewhere in this
        # body or thrown in at the yield.
        for batch in it:
            num_inst_per_dataset = {}
            for data in batch:
                dataset_name = data["dataset"]
                num_inst_per_dataset[dataset_name] = num_inst_per_dataset.get(
                    dataset_name, 0
                ) + len(data["instances"])
            for dataset_name, num_inst in num_inst_per_dataset.items():
                storage.put_scalar(f"batch/{dataset_name}", num_inst)
            yield batch
|
|
|
|
|
|
class SampleCountMetricPrinter(EventWriter):
    """Event writer that logs smoothed ``batch/*`` instance-count scalars.

    Every history in the current event storage whose key starts with
    ``"batch/"`` (the keys ``SampleCountingLoader`` writes) is reported as
    ``"<key> <average>"``. The average is taken over the last ``window_size``
    recorded values.
    """

    def __init__(self, window_size=20):
        """
        Args:
            window_size: number of most recent values averaged per key
                (defaults to 20, the previously hard-coded window).
        """
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self._window_size = window_size

    def write(self):
        """Log one comma-separated line of averaged ``batch/*`` statistics.

        Nothing is logged when the storage holds no ``batch/*`` histories.
        For example, when training runs without ``SampleCountingLoader``, the
        previous code emitted an empty log line on every write.
        """
        storage = get_event_storage()
        batch_stats_strs = [
            f"{key} {buf.avg(self._window_size)}"
            for key, buf in storage.histories().items()
            if key.startswith("batch/")
        ]
        if not batch_stats_strs:
            return
        self.logger.info(", ".join(batch_stats_strs))
|
|
|
|
|
|
class Trainer(DefaultTrainer):
    """DensePose customization of detectron2's ``DefaultTrainer``.

    It overrides:
    - evaluator construction;
    - test and train data loaders, with optional inference-based loaders
      driven by a bootstrap model;
    - event writers;
    - evaluation with test-time augmentation.
    """

    @classmethod
    def build_evaluator(cls, cfg: CfgNode, dataset_name, output_folder=None):
        """Create the evaluators for ``dataset_name``.

        Args:
            cfg: full config node.
            dataset_name: name of the dataset being evaluated.
            output_folder: directory for evaluation output. When ``None``,
                ``<cfg.OUTPUT_DIR>/inference`` is used.

        Returns:
            A ``DatasetEvaluators`` wrapping a ``COCOEvaluator``. A
            ``DensePoseCOCOEvaluator`` is appended when
            ``cfg.MODEL.DENSEPOSE_ON`` is set.
        """
        target_dir = (
            os.path.join(cfg.OUTPUT_DIR, "inference")
            if output_folder is None
            else output_folder
        )
        parts = [COCOEvaluator(dataset_name, cfg, True, target_dir)]
        if cfg.MODEL.DENSEPOSE_ON:
            parts.append(DensePoseCOCOEvaluator(dataset_name, True, target_dir))
        return DatasetEvaluators(parts)

    @classmethod
    def build_test_loader(cls, cfg: CfgNode, dataset_name):
        """Return the test loader for ``dataset_name``, using the DensePose ``DatasetMapper`` (is_train=False)."""
        test_mapper = DatasetMapper(cfg, False)
        return build_detection_test_loader(cfg, dataset_name, mapper=test_mapper)

    @classmethod
    def build_train_loader(cls, cfg: CfgNode):
        """Return the training loader.

        Without inference-based loaders configured, this is the plain
        detection train loader. Otherwise, a bootstrap model is built, moved to
        ``cfg.BOOTSTRAP_MODEL.DEVICE`` and loaded from
        ``cfg.BOOTSTRAP_MODEL.WEIGHTS``. It is then handed to
        ``build_inference_based_loaders``. The resulting loaders are combined
        with the plain loader, which gets ratio 1.0, and the combination is
        wrapped in a ``SampleCountingLoader``.
        """
        base_loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))
        if not has_inference_based_loaders(cfg):
            return base_loader

        bootstrap_model = cls.build_model(cfg)
        bootstrap_model.to(cfg.BOOTSTRAP_MODEL.DEVICE)
        checkpointer = DetectionCheckpointer(bootstrap_model)
        checkpointer.resume_or_load(cfg.BOOTSTRAP_MODEL.WEIGHTS, resume=False)

        extra_loaders, extra_ratios = build_inference_based_loaders(cfg, bootstrap_model)
        all_loaders = [base_loader] + extra_loaders
        all_ratios = [1.0] + extra_ratios
        combined = build_combined_loader(cfg, all_loaders, all_ratios)
        return SampleCountingLoader(combined)

    def build_writers(self):
        """Return the default writers followed by a ``SampleCountMetricPrinter``."""
        return [*super().build_writers(), SampleCountMetricPrinter()]

    @classmethod
    def test_with_TTA(cls, cfg: CfgNode, model):
        """Evaluate ``model`` with test-time augmentation on every test dataset.

        The model is wrapped in ``DensePoseGeneralizedRCNNWithTTA``, using
        transform data from ``load_from_cfg(cfg)`` and a
        ``DensePoseDatasetMapperTTA``. Output goes to
        ``<cfg.OUTPUT_DIR>/inference_TTA``.

        Returns:
            ``OrderedDict`` of the ``cls.test`` results, with ``"_TTA"``
            appended to every key.
        """
        logger = logging.getLogger("detectron2.trainer")
        # Meant as a final evaluation pass after training; TTA is only
        # supported by some R-CNN models.
        logger.info("Running inference with test-time augmentation ...")
        tta_model = DensePoseGeneralizedRCNNWithTTA(
            cfg, model, load_from_cfg(cfg), DensePoseDatasetMapperTTA(cfg)
        )
        tta_dir = os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
        evaluators = [
            cls.build_evaluator(cfg, name, output_folder=tta_dir)
            for name in cfg.DATASETS.TEST
        ]
        results = cls.test(cfg, tta_model, evaluators)
        return OrderedDict((key + "_TTA", value) for key, value in results.items())
|