# Segment-Everything-Everywhe.../pipeline/XDecoderPipeline.py
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
# --------------------------------------------------------
import logging
import time
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Tuple, Dict, List, Union
from infinibatch import iterators
from trainer.default_trainer import DefaultTrainer
from detectron2.evaluation import inference_on_dataset
from detectron2.utils.logger import log_every_n_seconds
from detectron2.data import MetadataCatalog
from modeling import build_model
from modeling.utils import get_class_names
from modeling.BaseModel import BaseModel
from datasets import build_evaluator, build_eval_dataloader, build_train_dataloader
from utils.distributed import is_main_process
from utils.constants import COCO_PANOPTIC_CLASSES
from trainer.utils.misc import move_batch_to_device, cast_batch_to_half
from .utils.misc import hook_metadata, hook_switcher, hook_opt
logger = logging.getLogger(__name__)


class XDecoderPipeline:
    def __init__(self, opt):
        self._opt = opt
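
    # Build the X-Decoder model, switch it to training mode, log the full
    # architecture on the main process, and return it wrapped in BaseModel,
    # keyed by model name for the trainer.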
    def initialize_model(self):
        model_name = "default"
        model = build_model(self._opt)
        model.train()

        if is_main_process():
            logger.info(model)

        raw_models = {model_name: BaseModel(self._opt, model)}
        return raw_models
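
    # Lazily build and cache dataloaders. For evaluation, the loaders for all
    # TEST datasets are built once and indexed by dataset label, and a matching
    # evaluator is rebuilt for the selected dataset. For training, the number of
    # optimizer updates per epoch is derived from the dataloader length and
    # GRADIENT_ACCUMULATE_STEP and stashed in LR_SCHEDULER_PARAMS.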
    def get_dataloaders(
        self, trainer: DefaultTrainer,
        dataset_label: str,
        is_evaluation: bool
    ) -> Union[DataLoader, iterators.CheckpointableIterator]:
        distributed = self._opt['world_size'] > 1

        if is_evaluation:
            if not hasattr(self, 'valid_loader'):
                dataloaders = build_eval_dataloader(self._opt)
                self.valid_loader = dataloaders
            else:
                dataloaders = self.valid_loader
            idx = 0 if dataset_label=='dev' else self._opt['DATASETS']['TEST'].index(dataset_label)
            dataloader = dataloaders[idx]
            self.evaluator = build_evaluator(self._opt, self._opt['DATASETS']['TEST'][idx], self._opt['SAVE_DIR'])
        else:
            if not hasattr(self, 'train_loader'):
                dataloader = build_train_dataloader(self._opt)
                self.train_loader = dataloader
                logger.info(f'num of train samples: {len(dataloader)}')
            else:
                dataloader = self.train_loader

            # temp solution for lr scheduler
            steps_total = len(self.train_loader)
            steps_acc = self._opt['GRADIENT_ACCUMULATE_STEP']
            steps_update = steps_total // steps_acc
            self._opt["LR_SCHEDULER_PARAMS"]["steps_update_per_epoch"] = steps_update
        return dataloader

    @staticmethod
    def forward_func(trainer, batch):
        loss = trainer.models['default'](batch)
        return loss
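
    # One training step: move the batch to the target device (casting to half
    # precision when FP16/DeepSpeed is enabled), compute the loss dict via the
    # trainer, sum the individual losses, then backpropagate and update the
    # 'default' model. Returns per-loss scalars and the sample count.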
    def forward_step(
        self,
        trainer: DefaultTrainer,
        batch,
        grad_acc_batches: List,
        grad_acc_index: int,
        is_distributed: bool,
    ) -> Tuple[Dict[str, float], Dict[str, int], Dict]:
        loss_info, sample_size_info, extra_info = {}, {}, {}
        batch = move_batch_to_device(batch, self._opt['device'])
        if self._opt['FP16']:
            # in FP16 mode, DeepSpeed casts the model to FP16, so the input needs to be manually cast to FP16
            batch = cast_batch_to_half(batch)

        loss = trainer.compute_loss(self.forward_func, batch)
        loss_info = {k: v.detach().item() for k,v in loss.items()}
        sample_size_info = {'num_samples': len(batch)}
        loss = sum(loss for loss in loss.values())
        trainer.backward_loss(loss, model_names=['default'])
        trainer.update_model(model_name='default')
        return loss_info, sample_size_info, extra_info
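
    # Evaluate the model on every dataset in DATASETS.TEST. For each dataset the
    # metadata, class names, and text embeddings are swapped in (see hook_metadata
    # / hook_switcher), inference is run with per-stage timing that skips a few
    # warmup iterations, and results are gathered on the main process. The
    # per-dataset state is restored to the training configuration afterwards.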
    def evaluate_model(
        self,
        trainer: DefaultTrainer,
        save_folder,
    ) -> Dict:
        model = trainer.raw_models['default'].eval()
        self._opt = hook_opt(self._opt)
        dataset_names = self._opt['DATASETS']['TEST']
        scores = {}
        summary = {}

        for dataset_label in dataset_names:
            torch.cuda.empty_cache()
            eval_batch_gen = self.get_dataloaders(trainer, dataset_label, is_evaluation=True)
            self.evaluator.reset()
            with torch.no_grad():
                names = get_class_names(dataset_label)
                model.model.metadata = MetadataCatalog.get(dataset_label)
                model.model.metadata = hook_metadata(model.model.metadata, dataset_label)
                eval_type = model.model.metadata.evaluator_type
                if 'background' in names:
                    # exclude the dedicated background entry from the number of predicted classes
                    model.model.sem_seg_head.num_classes = len(names) - 1
                model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(names, is_eval=True)
                hook_switcher(model, dataset_label)
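
                # Timing mirrors detectron2's inference_on_dataset: the first
                # num_warmup iterations are excluded from the reported averages
                # (the counters are reset once warmup finishes below).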
                total = len(eval_batch_gen)
                num_warmup = min(5, total - 1)
                start_time = time.perf_counter()
                total_data_time = 0
                total_compute_time = 0
                total_eval_time = 0
                start_data_time = time.perf_counter()

                for idx, batch in enumerate(eval_batch_gen):
                    total_data_time += time.perf_counter() - start_data_time
                    if idx == num_warmup:
                        start_time = time.perf_counter()
                        total_data_time = 0
                        total_compute_time = 0
                        total_eval_time = 0

                    start_compute_time = time.perf_counter()
                    batch = move_batch_to_device(batch, self._opt['device'])
                    if self._opt['FP16']:
                        # in FP16 mode, DeepSpeed casts the model to FP16, so the input needs to be manually cast to FP16
                        batch = cast_batch_to_half(batch)
                    outputs = model(batch, mode=eval_type)

                    if torch.cuda.is_available():
                        torch.cuda.synchronize()
                    total_compute_time += time.perf_counter() - start_compute_time

                    start_eval_time = time.perf_counter()
                    self.evaluator.process(batch, outputs)
                    total_eval_time += time.perf_counter() - start_eval_time
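
                    # Number of iterations since the warmup reset; used to
                    # average the accumulated per-stage timings.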
                    iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
                    data_seconds_per_iter = total_data_time / iters_after_start
                    compute_seconds_per_iter = total_compute_time / iters_after_start
                    eval_seconds_per_iter = total_eval_time / iters_after_start
                    total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start

                    if is_main_process() and (idx >= num_warmup * 2 or compute_seconds_per_iter > 5):
                        eta = datetime.timedelta(seconds=int(total_seconds_per_iter * (total - idx - 1)))
                        log_every_n_seconds(
                            logging.INFO,
                            (
                                f"Task {dataset_label}. "
                                f"Inference done {idx + 1}/{total}. "
                                f"Dataloading: {data_seconds_per_iter:.4f} s/iter. "
                                f"Inference: {compute_seconds_per_iter:.4f} s/iter. "
                                f"Eval: {eval_seconds_per_iter:.4f} s/iter. "
                                f"Total: {total_seconds_per_iter:.4f} s/iter. "
                                f"ETA={eta}"
                            ),
                            n=5,
                        )
                    start_data_time = time.perf_counter()

            results = self.evaluator.evaluate()
            model.model.sem_seg_head.predictor.lang_encoder.reset_text_embeddings()

            if is_main_process():
                scores["{}/{}".format(dataset_label, eval_type)] = results

        # set back to training state
        model.model.sem_seg_head.num_classes = self._opt['MODEL']['ENCODER']['NUM_CLASSES']
        model.model.metadata = MetadataCatalog.get(self._opt['DATASETS']['TRAIN'][0])
        return scores