remove apex dependency ()

Summary: Use PyTorch 1.6 (or above) built-in AMP training
pull/443/head
Xingyu Liao 2021-03-23 12:12:35 +08:00 committed by GitHub
parent 883fd4aede
commit 15c556c43a
5 changed files with 41 additions and 42 deletions
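
Since `torch.cuda.amp` only ships with PyTorch 1.6 and later, a lightweight runtime check (illustrative only, not part of this commit) can catch older installs early:

try:
    from torch.cuda.amp import GradScaler, autocast  # available from PyTorch 1.6 on
except ImportError as e:
    raise RuntimeError("native AMP requires PyTorch >= 1.6") from e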


@@ -31,14 +31,8 @@ from fastreid.utils.file_io import PathManager
 from fastreid.utils.logger import setup_logger

 from . import hooks
 from .train_loop import TrainerBase, AMPTrainer, SimpleTrainer
+from torch.nn.parallel import DistributedDataParallel
-try:
-    import apex
-    from apex import amp
-    from apex.parallel import DistributedDataParallel
-except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example if you want to"
-                      "train with DDP")

 __all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
@@ -214,19 +208,14 @@ class DefaultTrainer(TrainerBase):
         model = self.build_model(cfg)
         optimizer = self.build_optimizer(cfg, model)
-        optimizer_ckpt = dict(optimizer=optimizer)
-        if cfg.SOLVER.FP16_ENABLED:
-            model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
-            optimizer_ckpt.update(dict(amp=amp))

         # For training, wrap with DDP. But don't need this for inference.
         if comm.get_world_size() > 1:
             # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
             # for part of the parameters is not updated.
-            # model = DistributedDataParallel(
-            #     model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
-            # )
-            model = DistributedDataParallel(model, delay_allreduce=True)
+            model = DistributedDataParallel(
+                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False,
+                find_unused_parameters=True
+            )

         self._trainer = (AMPTrainer if cfg.SOLVER.FP16_ENABLED else SimpleTrainer)(
             model, data_loader, optimizer

@@ -242,7 +231,7 @@ class DefaultTrainer(TrainerBase):
             model,
             cfg.OUTPUT_DIR,
             save_to_disk=comm.is_main_process(),
-            **optimizer_ckpt,
+            optimizer=optimizer,
             **self.scheduler,
         )
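
For orientation, the setup that replaces apex's `amp.initialize` plus `DistributedDataParallel(delay_allreduce=True)` is just the stock PyTorch DDP wrapper; mixed precision now happens per step inside AMPTrainer. A minimal sketch of the wrapping step under those assumptions, with `wrap_for_training`, `local_rank`, and `world_size` standing in for fastreid's `comm` helpers:

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel

def wrap_for_training(model: nn.Module, local_rank: int, world_size: int) -> nn.Module:
    # No amp.initialize call is needed anymore; only the DDP wrapping remains here.
    if world_size > 1:
        model = DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False,
            # Some parameters receive no gradient (see the comment in the hunk above),
            # hence find_unused_parameters=True.
            find_unused_parameters=True,
        )
    return model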


@@ -10,8 +10,7 @@ import time
 from collections import Counter

 import torch
-from apex.parallel import DistributedDataParallel
 from torch import nn
+from torch.nn.parallel import DistributedDataParallel

 from fastreid.evaluation.testing import flatten_results_dict
 from fastreid.solver import optim


@@ -11,7 +11,7 @@ from typing import Dict
 import numpy as np
 import torch
-from apex import amp
+from torch.nn.parallel import DataParallel, DistributedDataParallel

 import fastreid.utils.comm as comm
 from fastreid.utils.events import EventStorage, get_event_storage
@@ -97,9 +97,10 @@ class TrainerBase:
     We made no assumptions about the existence of dataloader, optimizer, model, etc.

     Attributes:
         iter(int): the current iteration.
+        epoch(int): the current epoch.
         start_iter(int): The iteration to start with.
             By convention the minimum possible value is 0.
-        max_iter(int): The iteration to end training.
+        max_epoch (int): The epoch to end training.
         storage(EventStorage): An EventStorage that's opened during the course of training.
     """
@@ -126,7 +127,7 @@ class TrainerBase:
     def train(self, start_epoch: int, max_epoch: int, iters_per_epoch: int):
         """
         Args:
-            start_iter, max_iter (int): See docs above
+            start_epoch, max_epoch (int): See docs above
         """
         logger = logging.getLogger(__name__)
         logger.info("Starting training from epoch {}".format(start_epoch))
@@ -298,29 +299,50 @@ class SimpleTrainer(TrainerBase):

 class AMPTrainer(SimpleTrainer):
     """
-    Like :class:`SimpleTrainer`, but uses apex automatic mixed precision
+    Like :class:`SimpleTrainer`, but uses automatic mixed precision
     in the training loop.
     """

+    def __init__(self, model, data_loader, optimizer, grad_scaler=None):
+        """
+        Args:
+            model, data_loader, optimizer: same as in :class:`SimpleTrainer`.
+            grad_scaler: torch GradScaler to automatically scale gradients.
+        """
+        unsupported = "AMPTrainer does not support single-process multi-device training!"
+        if isinstance(model, DistributedDataParallel):
+            assert not (model.device_ids and len(model.device_ids) > 1), unsupported
+        assert not isinstance(model, DataParallel), unsupported
+
+        super().__init__(model, data_loader, optimizer)
+
+        if grad_scaler is None:
+            from torch.cuda.amp import GradScaler
+            grad_scaler = GradScaler()
+        self.grad_scaler = grad_scaler

     def run_step(self):
         """
         Implement the AMP training logic.
         """
         assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
+        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
+        from torch.cuda.amp import autocast

         start = time.perf_counter()
         data = next(self._data_loader_iter)
         data_time = time.perf_counter() - start

-        loss_dict = self.model(data)
-        losses = sum(loss_dict.values())
+        with autocast():
+            loss_dict = self.model(data)
+            losses = sum(loss_dict.values())

         self.optimizer.zero_grad()
-        with amp.scale_loss(losses, self.optimizer) as scaled_loss:
-            scaled_loss.backward()
+        self.grad_scaler.scale(losses).backward()

         self._write_metrics(loss_dict, data_time)

-        self.optimizer.step()
+        self.grad_scaler.step(self.optimizer)
+        self.grad_scaler.update()
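
The new run_step follows the standard torch.cuda.amp recipe: forward under autocast, scale the loss before backward, then let the scaler drive the optimizer step. A self-contained toy version of that loop (illustrative only; needs a CUDA device, and the tiny model here is not from fastreid):

import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

for _ in range(10):
    inputs = torch.randn(8, 16, device="cuda")
    targets = torch.randn(8, 4, device="cuda")

    with autocast():                 # forward pass runs in mixed precision
        loss = F.mse_loss(model(inputs), targets)

    optimizer.zero_grad()
    scaler.scale(loss).backward()    # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)           # unscales gradients; skips the step on inf/nan
    scaler.update()                  # adjusts the scale factor for the next iteration

Unlike apex, there is no global `amp` object to checkpoint: in AMPTrainer the state lives on `self.grad_scaler`, and the `amp` entry previously passed to the Checkpointer is simply dropped (see the Checkpointer hunk above).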


@@ -10,11 +10,6 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-try:
-    from apex import parallel
-except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run model with syncBN")

 __all__ = ["IBN", "get_norm"]

@@ -28,7 +23,7 @@ class BatchNorm(nn.BatchNorm2d):
         self.bias.requires_grad_(not bias_freeze)


-class SyncBatchNorm(parallel.SyncBatchNorm):
+class SyncBatchNorm(nn.SyncBatchNorm):
     def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                  bias_init=0.0):
         super().__init__(num_features, eps=eps, momentum=momentum)
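
With the base class switched to `nn.SyncBatchNorm`, no apex install is needed for cross-GPU batch-norm statistics. A small illustration of the built-in API (the `convert_sync_batchnorm` call and the toy model are not part of this diff, just the usual way to retrofit an existing network):

import torch
from torch import nn

# Drop-in replacement for apex.parallel.SyncBatchNorm:
sync_bn = nn.SyncBatchNorm(64, eps=1e-05, momentum=0.1)

# Converting every BatchNorm*d layer in an existing model:
model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64), nn.ReLU())
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

Statistics are only synchronized when the model is trained under DistributedDataParallel with an initialized process group; otherwise it behaves like regular batch norm.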


@@ -12,16 +12,10 @@ import numpy as np
 import torch
 import torch.nn as nn
 from termcolor import colored
-from torch.nn.parallel import DataParallel
+from torch.nn.parallel import DistributedDataParallel, DataParallel

 from fastreid.utils.file_io import PathManager

-try:
-    from apex.parallel import DistributedDataParallel
-except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example if you want to"
-                      "train with DDP")


 class _IncompatibleKeys(
     NamedTuple(
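
The Checkpointer now only needs the stock wrappers in order to recognize and unwrap a parallelized model before saving or loading weights. A hedged sketch of that unwrapping idiom (the helper name is mine, not fastreid's exact implementation):

from torch import nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

def unwrap_model(model: nn.Module) -> nn.Module:
    # Both wrappers expose the underlying network as `.module`, so checkpoints
    # can be written and read against the bare model's state_dict keys.
    if isinstance(model, (DistributedDataParallel, DataParallel)):
        return model.module
    return model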