mirror of https://github.com/JDAI-CV/fast-reid.git
parent 883fd4aede
commit 15c556c43a
fastreid
fastreid/engine/defaults.py
@@ -31,14 +31,8 @@ from fastreid.utils.file_io import PathManager
 from fastreid.utils.logger import setup_logger
 from . import hooks
 from .train_loop import TrainerBase, AMPTrainer, SimpleTrainer
+from torch.nn.parallel import DistributedDataParallel
 
-try:
-    import apex
-    from apex import amp
-    from apex.parallel import DistributedDataParallel
-except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example if you want to"
-                      "train with DDP")
 
 __all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
@@ -214,19 +208,14 @@ class DefaultTrainer(TrainerBase):
         model = self.build_model(cfg)
         optimizer = self.build_optimizer(cfg, model)
 
-        optimizer_ckpt = dict(optimizer=optimizer)
-        if cfg.SOLVER.FP16_ENABLED:
-            model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
-            optimizer_ckpt.update(dict(amp=amp))
-
         # For training, wrap with DDP. But don't need this for inference.
         if comm.get_world_size() > 1:
             # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
             # for part of the parameters is not updated.
-            # model = DistributedDataParallel(
-            #     model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
-            # )
-            model = DistributedDataParallel(model, delay_allreduce=True)
+            model = DistributedDataParallel(
+                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False,
+                find_unused_parameters=True
+            )
 
+        self._trainer = (AMPTrainer if cfg.SOLVER.FP16_ENABLED else SimpleTrainer)(
+            model, data_loader, optimizer
@@ -242,7 +231,7 @@ class DefaultTrainer(TrainerBase):
             model,
             cfg.OUTPUT_DIR,
             save_to_disk=comm.is_main_process(),
-            **optimizer_ckpt,
+            optimizer=optimizer,
             **self.scheduler,
         )
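For context, the native-DDP wrap that replaces apex's `DistributedDataParallel(model, delay_allreduce=True)` follows the standard PyTorch pattern. A minimal sketch (the `wrap_ddp` helper and `local_rank` handling are illustrative, not fast-reid API; assumes `torch.distributed` is already initialized):

import torch
from torch.nn.parallel import DistributedDataParallel


def wrap_ddp(model: torch.nn.Module, local_rank: int) -> torch.nn.Module:
    # `find_unused_parameters=True` tolerates parameters that receive no
    # gradient in some iterations (see pytorch/pytorch#22049), at the cost
    # of an extra graph traversal per step.
    model = model.cuda(local_rank)
    return DistributedDataParallel(
        model,
        device_ids=[local_rank],
        broadcast_buffers=False,
        find_unused_parameters=True,
    )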
fastreid/engine/hooks.py
@@ -10,8 +10,7 @@ import time
 from collections import Counter
 
 import torch
-from apex.parallel import DistributedDataParallel
 from torch import nn
+from torch.nn.parallel import DistributedDataParallel
 
 from fastreid.evaluation.testing import flatten_results_dict
 from fastreid.solver import optim
fastreid/engine/train_loop.py
@@ -11,7 +11,7 @@ from typing import Dict
 
 import numpy as np
 import torch
-from apex import amp
+from torch.nn.parallel import DataParallel, DistributedDataParallel
 
 import fastreid.utils.comm as comm
 from fastreid.utils.events import EventStorage, get_event_storage
@@ -97,9 +97,10 @@ class TrainerBase:
     We made no assumptions about the existence of dataloader, optimizer, model, etc.
     Attributes:
         iter(int): the current iteration.
         epoch(int): the current epoch.
         start_iter(int): The iteration to start with.
             By convention the minimum possible value is 0.
         max_iter(int): The iteration to end training.
+        max_epoch (int): The epoch to end training.
         storage(EventStorage): An EventStorage that's opened during the course of training.
     """
@@ -126,7 +127,7 @@ class TrainerBase:
     def train(self, start_epoch: int, max_epoch: int, iters_per_epoch: int):
         """
         Args:
-            start_iter, max_iter (int): See docs above
+            start_epoch, max_epoch (int): See docs above
         """
         logger = logging.getLogger(__name__)
         logger.info("Starting training from epoch {}".format(start_epoch))
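The corrected docstring reflects that this loop is epoch-based. Schematically (an illustrative sketch with hypothetical names, not the real `TrainerBase.train`, which also fires hooks and manages an `EventStorage`):

def train_loop(run_step, start_epoch: int, max_epoch: int, iters_per_epoch: int) -> int:
    # Global iteration counter, kept consistent with the epoch bookkeeping.
    iteration = start_epoch * iters_per_epoch
    for epoch in range(start_epoch, max_epoch):
        for _ in range(iters_per_epoch):
            run_step()  # one forward/backward/optimizer step
            iteration += 1
    return iteration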
@@ -298,29 +299,50 @@ class SimpleTrainer(TrainerBase):
 
 class AMPTrainer(SimpleTrainer):
     """
-    Like :class:`SimpleTrainer`, but uses apex automatic mixed precision
+    Like :class:`SimpleTrainer`, but uses automatic mixed precision
     in the training loop.
     """
 
+    def __init__(self, model, data_loader, optimizer, grad_scaler=None):
+        """
+
+        Args:
+            model, data_loader, optimizer: same as in :class:`SimpleTrainer`.
+            grad_scaler: torch GradScaler to automatically scale gradients.
+        """
+        unsupported = "AMPTrainer does not support single-process multi-device training!"
+        if isinstance(model, DistributedDataParallel):
+            assert not (model.device_ids and len(model.device_ids) > 1), unsupported
+        assert not isinstance(model, DataParallel), unsupported
+
+        super().__init__(model, data_loader, optimizer)
+
+        if grad_scaler is None:
+            from torch.cuda.amp import GradScaler
+
+            grad_scaler = GradScaler()
+        self.grad_scaler = grad_scaler
+
     def run_step(self):
         """
         Implement the AMP training logic.
         """
         assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
         assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
+        from torch.cuda.amp import autocast
 
         start = time.perf_counter()
         data = next(self._data_loader_iter)
         data_time = time.perf_counter() - start
 
-        loss_dict = self.model(data)
-        losses = sum(loss_dict.values())
+        with autocast():
+            loss_dict = self.model(data)
+            losses = sum(loss_dict.values())
 
         self.optimizer.zero_grad()
-        with amp.scale_loss(losses, self.optimizer) as scaled_loss:
-            scaled_loss.backward()
+        self.grad_scaler.scale(losses).backward()
 
         self._write_metrics(loss_dict, data_time)
 
-        self.optimizer.step()
+        self.grad_scaler.step(self.optimizer)
+        self.grad_scaler.update()
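The new `run_step` is the canonical `torch.cuda.amp` recipe. A self-contained toy version (placeholder model, data, and loss; requires a CUDA device):

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

for _ in range(10):
    inputs = torch.randn(8, 16, device="cuda")
    targets = torch.randn(8, 4, device="cuda")

    optimizer.zero_grad()
    with autocast():  # forward pass runs in mixed precision
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)  # unscales grads; skips the step on inf/NaN
    scaler.update()  # adjust the scale factor for the next iteration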
fastreid/layers/batch_norm.py
@@ -10,11 +10,6 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-try:
-    from apex import parallel
-except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run model with syncBN")
-
 __all__ = ["IBN", "get_norm"]
@@ -28,7 +23,7 @@ class BatchNorm(nn.BatchNorm2d):
         self.bias.requires_grad_(not bias_freeze)
 
 
-class SyncBatchNorm(parallel.SyncBatchNorm):
+class SyncBatchNorm(nn.SyncBatchNorm):
     def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                  bias_init=0.0):
         super().__init__(num_features, eps=eps, momentum=momentum)
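With the base class now `nn.SyncBatchNorm`, an existing model can also be converted wholesale using stock PyTorch; a minimal sketch (toy model; the converted layers need an initialized process group at training time):

from torch import nn

# Replace every BatchNorm layer in the model with torch-native SyncBatchNorm.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)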
fastreid/utils/checkpoint.py
@@ -12,16 +12,10 @@ import numpy as np
 import torch
 import torch.nn as nn
 from termcolor import colored
-from torch.nn.parallel import DataParallel
+from torch.nn.parallel import DistributedDataParallel, DataParallel
 
 from fastreid.utils.file_io import PathManager
 
-try:
-    from apex.parallel import DistributedDataParallel
-except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example if you want to"
-                      "train with DDP")
-
 
 class _IncompatibleKeys(
     NamedTuple(
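The checkpointer imports both wrapper classes so it can reach the underlying module before reading or writing a `state_dict`. A sketch of that common unwrapping pattern (hypothetical `unwrap_model` helper, not the fast-reid implementation):

from torch import nn
from torch.nn.parallel import DataParallel, DistributedDataParallel


def unwrap_model(model: nn.Module) -> nn.Module:
    # Both wrappers expose the wrapped network as `.module`; unwrap it so
    # checkpoint keys are saved without the "module." prefix.
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        return model.module
    return model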