diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py
index c16a9b6..3a21bdb 100644
--- a/fastreid/engine/defaults.py
+++ b/fastreid/engine/defaults.py
@@ -31,14 +31,8 @@ from fastreid.utils.file_io import PathManager
 from fastreid.utils.logger import setup_logger
 from . import hooks
 from .train_loop import TrainerBase, AMPTrainer, SimpleTrainer
+from torch.nn.parallel import DistributedDataParallel
-try:
-    import apex
-    from apex import amp
-    from apex.parallel import DistributedDataParallel
-except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example if you want to"
-                      "train with DDP")
 
 __all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
 
@@ -214,19 +208,14 @@ class DefaultTrainer(TrainerBase):
         model = self.build_model(cfg)
         optimizer = self.build_optimizer(cfg, model)
-        optimizer_ckpt = dict(optimizer=optimizer)
-        if cfg.SOLVER.FP16_ENABLED:
-            model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
-            optimizer_ckpt.update(dict(amp=amp))
-
         # For training, wrap with DDP. But don't need this for inference.
         if comm.get_world_size() > 1:
             # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
             # for part of the parameters is not updated.
-            # model = DistributedDataParallel(
-            #     model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
-            # )
-            model = DistributedDataParallel(model, delay_allreduce=True)
+            model = DistributedDataParallel(
+                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False,
+                find_unused_parameters=True
+            )
 
         self._trainer = (AMPTrainer if cfg.SOLVER.FP16_ENABLED else SimpleTrainer)(
             model, data_loader, optimizer
@@ -242,7 +231,7 @@ class DefaultTrainer(TrainerBase):
             model,
             cfg.OUTPUT_DIR,
             save_to_disk=comm.is_main_process(),
-            **optimizer_ckpt,
+            optimizer=optimizer,
             **self.scheduler,
         )
diff --git a/fastreid/engine/hooks.py b/fastreid/engine/hooks.py
index 6a63770..bdde826 100644
--- a/fastreid/engine/hooks.py
+++ b/fastreid/engine/hooks.py
@@ -10,8 +10,7 @@ import time
 from collections import Counter
 
 import torch
-from apex.parallel import DistributedDataParallel
-from torch import nn
+from torch.nn.parallel import DistributedDataParallel
 
 from fastreid.evaluation.testing import flatten_results_dict
 from fastreid.solver import optim
diff --git a/fastreid/engine/train_loop.py b/fastreid/engine/train_loop.py
index 179f9e3..4622fb1 100644
--- a/fastreid/engine/train_loop.py
+++ b/fastreid/engine/train_loop.py
@@ -11,7 +11,7 @@ from typing import Dict
 
 import numpy as np
 import torch
-from apex import amp
+from torch.nn.parallel import DataParallel, DistributedDataParallel
 
 import fastreid.utils.comm as comm
 from fastreid.utils.events import EventStorage, get_event_storage
@@ -97,9 +97,10 @@ class TrainerBase:
     We made no assumptions about the existence of dataloader, optimizer, model, etc.
 
     Attributes:
         iter(int): the current iteration.
+        epoch(int): the current epoch.
         start_iter(int): The iteration to start with.
             By convention the minimum possible value is 0.
-        max_iter(int): The iteration to end training.
+        max_epoch (int): The epoch to end training.
         storage(EventStorage): An EventStorage that's opened during the course of training.
     """
@@ -126,7 +127,7 @@ class TrainerBase:
     def train(self, start_epoch: int, max_epoch: int, iters_per_epoch: int):
         """
         Args:
-            start_iter, max_iter (int): See docs above
+            start_epoch, max_epoch (int): See docs above
         """
         logger = logging.getLogger(__name__)
         logger.info("Starting training from epoch {}".format(start_epoch))
@@ -298,29 +299,50 @@ class SimpleTrainer(TrainerBase):
 
 
 class AMPTrainer(SimpleTrainer):
     """
-    Like :class:`SimpleTrainer`, but uses apex automatic mixed precision
+    Like :class:`SimpleTrainer`, but uses automatic mixed precision
     in the training loop.
     """
 
+    def __init__(self, model, data_loader, optimizer, grad_scaler=None):
+        """
+
+        Args:
+            model, data_loader, optimizer: same as in :class:`SimpleTrainer`.
+            grad_scaler: torch GradScaler to automatically scale gradients.
+        """
+        unsupported = "AMPTrainer does not support single-process multi-device training!"
+        if isinstance(model, DistributedDataParallel):
+            assert not (model.device_ids and len(model.device_ids) > 1), unsupported
+        assert not isinstance(model, DataParallel), unsupported
+
+        super().__init__(model, data_loader, optimizer)
+
+        if grad_scaler is None:
+            from torch.cuda.amp import GradScaler
+
+            grad_scaler = GradScaler()
+        self.grad_scaler = grad_scaler
+
     def run_step(self):
         """
         Implement the AMP training logic.
         """
         assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
         assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
+        from torch.cuda.amp import autocast
 
         start = time.perf_counter()
         data = next(self._data_loader_iter)
         data_time = time.perf_counter() - start
 
-        loss_dict = self.model(data)
-        losses = sum(loss_dict.values())
+        with autocast():
+            loss_dict = self.model(data)
+            losses = sum(loss_dict.values())
 
         self.optimizer.zero_grad()
-
-        with amp.scale_loss(losses, self.optimizer) as scaled_loss:
-            scaled_loss.backward()
+        self.grad_scaler.scale(losses).backward()
 
         self._write_metrics(loss_dict, data_time)
 
-        self.optimizer.step()
+        self.grad_scaler.step(self.optimizer)
+        self.grad_scaler.update()
diff --git a/fastreid/layers/batch_norm.py b/fastreid/layers/batch_norm.py
index da8a733..70a3a0d 100644
--- a/fastreid/layers/batch_norm.py
+++ b/fastreid/layers/batch_norm.py
@@ -10,11 +10,6 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-try:
-    from apex import parallel
-except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run model with syncBN")
-
 __all__ = ["IBN", "get_norm"]
 
 
@@ -28,7 +23,7 @@ class BatchNorm(nn.BatchNorm2d):
         self.bias.requires_grad_(not bias_freeze)
 
 
-class SyncBatchNorm(parallel.SyncBatchNorm):
+class SyncBatchNorm(nn.SyncBatchNorm):
     def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                  bias_init=0.0):
         super().__init__(num_features, eps=eps, momentum=momentum)
diff --git a/fastreid/utils/checkpoint.py b/fastreid/utils/checkpoint.py
index d40d545..6753fa7 100644
--- a/fastreid/utils/checkpoint.py
+++ b/fastreid/utils/checkpoint.py
@@ -12,16 +12,10 @@ import numpy as np
 import torch
 import torch.nn as nn
 from termcolor import colored
-from torch.nn.parallel import DataParallel
+from torch.nn.parallel import DistributedDataParallel, DataParallel
 
 from fastreid.utils.file_io import PathManager
 
-try:
-    from apex.parallel import DistributedDataParallel
-except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example if you want to"
-                      "train with DDP")
-
 
 class _IncompatibleKeys(
     NamedTuple(
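# --- Illustrative sketch, not part of the patch above ---
# A minimal, standalone example of the torch.cuda.amp pattern that AMPTrainer.run_step
# switches to: autocast around the forward pass and a persistent GradScaler in place of
# apex's amp.initialize / amp.scale_loss. The model here is assumed to return a dict of
# losses (as fastreid models do); `model`, `data`, and `optimizer` are placeholders.
import torch
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()  # create once (the patch does this in AMPTrainer.__init__)


def amp_train_step(model: torch.nn.Module, data, optimizer: torch.optim.Optimizer):
    with autocast():                    # eligible ops run in float16 during the forward pass
        loss_dict = model(data)
        losses = sum(loss_dict.values())

    optimizer.zero_grad()
    scaler.scale(losses).backward()     # backward on the scaled loss
    scaler.step(optimizer)              # unscales gradients, skips the step on inf/NaN
    scaler.update()                     # adjust the loss scale for the next iteration
    return loss_dict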
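# --- Illustrative sketch, not part of the patch above ---
# The native DDP wrapping that replaces apex's
# DistributedDataParallel(model, delay_allreduce=True) in defaults.py. `local_rank` stands
# in for comm.get_local_rank(); the world-size check mirrors comm.get_world_size() > 1.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def wrap_for_training(model: torch.nn.Module, local_rank: int) -> torch.nn.Module:
    # Wrap only for multi-process training; inference does not need DDP.
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        model = DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False,
            # Needed when only part of the parameters receive gradients each step,
            # see https://github.com/pytorch/pytorch/issues/22049.
            find_unused_parameters=True,
        )
    return model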