diff --git a/README.md b/README.md
index e23a99f..40d647c 100644
--- a/README.md
+++ b/README.md
@@ -8,13 +8,15 @@ FastReID is a research platform that implements state-of-the-art re-identificati

 ## What's New

+- [June 2021] [Contiguous parameters](https://github.com/PhilJd/contiguous_pytorch_params) is supported, which can
+  accelerate training by ~20%.
 - [May 2021] Vision Transformer backbone supported, see `configs/Market1501/bagtricks_vit.yml`.
 - [Apr 2021] Partial FC supported in [FastFace](projects/FastFace)!
 - [Jan 2021] TRT network definition APIs in [FastRT](projects/FastRT) has been released! Thanks for [Darren](https://github.com/TCHeish)'s contribution.
 - [Jan 2021] NAIC20(reid track) [1-st solution](projects/NAIC20) based on fastreid has been released!
-- [Jan 2021] FastReID V1.0 has been released!🎉
-Support many tasks beyond reid, such image retrieval and face recognition. See [release notes](https://github.com/JDAI-CV/fast-reid/releases/tag/v1.0.0).
+- [Jan 2021] FastReID V1.0 has been released!🎉
+  Support many tasks beyond reid, such as image retrieval and face recognition. See [release notes](https://github.com/JDAI-CV/fast-reid/releases/tag/v1.0.0).
 - [Oct 2020] Added the [Hyper-Parameter Optimization](projects/FastTune) based on fastreid. See `projects/FastTune`.
 - [Sep 2020] Added the [person attribute recognition](projects/FastAttr) based on fastreid. See `projects/FastAttr`.
 - [Sep 2020] Automatic Mixed Precision training is supported with `apex`. Set `cfg.SOLVER.FP16_ENABLED=True` to switch it on.
diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py
index e6bc403..2eefca2 100644
--- a/fastreid/config/defaults.py
+++ b/fastreid/config/defaults.py
@@ -242,7 +242,7 @@ _C.SOLVER.NESTEROV = False

 _C.SOLVER.WEIGHT_DECAY = 0.0005
 # The weight decay that's applied to parameters of normalization layers
 # (typically the affine transformation)
-_C.SOLVER.WEIGHT_DECAY_NORM = 0.0
+_C.SOLVER.WEIGHT_DECAY_NORM = 0.0005
 # The previous detection code used a 2x higher LR and 0 WD for bias.
 # This is not useful (at least for recent models). You should avoid
diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py
index ae0912f..e1ad678 100644
--- a/fastreid/engine/defaults.py
+++ b/fastreid/engine/defaults.py
@@ -201,7 +201,7 @@ class DefaultTrainer(TrainerBase):
         data_loader = self.build_train_loader(cfg)
         cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
         model = self.build_model(cfg)
-        optimizer = self.build_optimizer(cfg, model)
+        optimizer, param_wrapper = self.build_optimizer(cfg, model)

         # For training, wrap with DDP. But don't need this for inference.
         if comm.get_world_size() > 1:
@@ -212,7 +212,7 @@ class DefaultTrainer(TrainerBase):
             )

         self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
-            model, data_loader, optimizer
+            model, data_loader, optimizer, param_wrapper
         )

         self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
diff --git a/fastreid/engine/train_loop.py b/fastreid/engine/train_loop.py
index 4622fb1..f22ac05 100644
--- a/fastreid/engine/train_loop.py
+++ b/fastreid/engine/train_loop.py
@@ -15,6 +15,7 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel

 import fastreid.utils.comm as comm
 from fastreid.utils.events import EventStorage, get_event_storage
+from fastreid.utils.params import ContiguousParams

 __all__ = ["HookBase", "TrainerBase", "SimpleTrainer"]

@@ -197,7 +198,7 @@ class SimpleTrainer(TrainerBase):
     or write your own training loop.
""" - def __init__(self, model, data_loader, optimizer): + def __init__(self, model, data_loader, optimizer, param_wrapper): """ Args: model: a torch Module. Takes a data from data_loader and returns a @@ -219,6 +220,7 @@ class SimpleTrainer(TrainerBase): self.data_loader = data_loader self._data_loader_iter = iter(data_loader) self.optimizer = optimizer + self.param_wrapper = param_wrapper def run_step(self): """ @@ -254,6 +256,8 @@ class SimpleTrainer(TrainerBase): wrap the optimizer with your custom `step()` method. """ self.optimizer.step() + if isinstance(self.param_wrapper, ContiguousParams): + self.param_wrapper.assert_buffer_is_valid() def _write_metrics(self, loss_dict: Dict[str, torch.Tensor], data_time: float): """ @@ -303,7 +307,7 @@ class AMPTrainer(SimpleTrainer): in the training loop. """ - def __init__(self, model, data_loader, optimizer, grad_scaler=None): + def __init__(self, model, data_loader, optimizer, param_wrapper, grad_scaler=None): """ Args: @@ -315,7 +319,7 @@ class AMPTrainer(SimpleTrainer): assert not (model.device_ids and len(model.device_ids) > 1), unsupported assert not isinstance(model, DataParallel), unsupported - super().__init__(model, data_loader, optimizer) + super().__init__(model, data_loader, optimizer, param_wrapper) if grad_scaler is None: from torch.cuda.amp import GradScaler @@ -346,3 +350,5 @@ class AMPTrainer(SimpleTrainer): self.grad_scaler.step(self.optimizer) self.grad_scaler.update() + if isinstance(self.param_wrapper, ContiguousParams): + self.param_wrapper.assert_buffer_is_valid() diff --git a/fastreid/solver/build.py b/fastreid/solver/build.py index bd29b0e..2933758 100644 --- a/fastreid/solver/build.py +++ b/fastreid/solver/build.py @@ -9,12 +9,14 @@ import copy import itertools import math +import re from enum import Enum from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union import torch from fastreid.config import CfgNode +from fastreid.utils.params import ContiguousParams from . 
import lr_scheduler _GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]] @@ -60,6 +62,7 @@ def _generate_optimizer_class_with_gradient_clipping( per_param_clipper is None or global_clipper is None ), "Not allowed to use both per-parameter clipping and global clipping" + @torch.no_grad() def optimizer_wgc_step(self, closure=None): if per_param_clipper is not None: for group in self.param_groups: @@ -117,26 +120,31 @@ def maybe_add_gradient_clipping( def _generate_optimizer_class_with_freeze_layer( optimizer: Type[torch.optim.Optimizer], *, - freeze_layers: Optional[List] = None, freeze_iters: int = 0, ) -> Type[torch.optim.Optimizer]: - assert ( - freeze_layers is not None and freeze_iters > 0 - ), "No layers need to be frozen or freeze iterations is 0" + assert freeze_iters > 0, "No layers need to be frozen or freeze iterations is 0" cnt = 0 - + @torch.no_grad() def optimizer_wfl_step(self, closure=None): nonlocal cnt if cnt < freeze_iters: cnt += 1 + param_ref = [] + grad_ref = [] for group in self.param_groups: - if group["name"].split('.')[0] in freeze_layers: + if group["freeze_status"] == "freeze": for p in group["params"]: if p.grad is not None: + param_ref.append(p) + grad_ref.append(p.grad) p.grad = None - optimizer.step(self, closure) + optimizer.step(self, closure) + for p, g in zip(param_ref, grad_ref): + p.grad = g + else: + optimizer.step(self, closure) OptimizerWithFreezeLayer = type( optimizer.__name__ + "WithFreezeLayer", @@ -149,7 +157,7 @@ def _generate_optimizer_class_with_freeze_layer( def maybe_add_freeze_layer( cfg: CfgNode, optimizer: Type[torch.optim.Optimizer] ) -> Type[torch.optim.Optimizer]: - if len(cfg.MODEL.FREEZE_LAYERS) == 0 or cfg.SOLVER.FREEZE_ITERS == 0: + if len(cfg.MODEL.FREEZE_LAYERS) == 0 or cfg.SOLVER.FREEZE_ITERS <= 0: return optimizer if isinstance(optimizer, torch.optim.Optimizer): @@ -160,7 +168,6 @@ def maybe_add_freeze_layer( OptimizerWithFreezeLayer = _generate_optimizer_class_with_freeze_layer( optimizer_type, - freeze_layers=cfg.MODEL.FREEZE_LAYERS, freeze_iters=cfg.SOLVER.FREEZE_ITERS ) if isinstance(optimizer, torch.optim.Optimizer): @@ -170,37 +177,35 @@ def maybe_add_freeze_layer( return OptimizerWithFreezeLayer -def build_optimizer(cfg, model): +def build_optimizer(cfg, model, contiguous=True): params = get_default_optimizer_params( model, base_lr=cfg.SOLVER.BASE_LR, + weight_decay=cfg.SOLVER.WEIGHT_DECAY, weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM, bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR, heads_lr_factor=cfg.SOLVER.HEADS_LR_FACTOR, - weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS + weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS, + freeze_layers=cfg.MODEL.FREEZE_LAYERS if cfg.SOLVER.FREEZE_ITERS > 0 else [], ) + if contiguous: + params = ContiguousParams(params) solver_opt = cfg.SOLVER.OPT if solver_opt == "SGD": return maybe_add_freeze_layer( cfg, maybe_add_gradient_clipping(cfg, torch.optim.SGD) )( - params, - lr=cfg.SOLVER.BASE_LR, + params.contiguous() if contiguous else params, momentum=cfg.SOLVER.MOMENTUM, nesterov=cfg.SOLVER.NESTEROV, - weight_decay=cfg.SOLVER.WEIGHT_DECAY, - ) + ), params else: return maybe_add_freeze_layer( cfg, maybe_add_gradient_clipping(cfg, getattr(torch.optim, solver_opt)) - )( - params, - lr=cfg.SOLVER.BASE_LR, - weight_decay=cfg.SOLVER.WEIGHT_DECAY, - ) + )(params.contiguous() if contiguous else params), params def get_default_optimizer_params( @@ -212,6 +217,7 @@ def get_default_optimizer_params( heads_lr_factor: Optional[float] = 1.0, weight_decay_bias: Optional[float] = None, 
overrides: Optional[Dict[str, Dict[str, float]]] = None, + freeze_layers: Optional[list] = [], ): """ Get default param list for optimizer, with support for a few types of @@ -228,6 +234,7 @@ def get_default_optimizer_params( (LR, weight decay) for module parameters with a given name; e.g. ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and weight decay values for all module parameters named `embedding`. + freeze_layers: layer names for freezing. For common detection models, ``weight_decay_norm`` is the only option needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings from Detectron1 that are not found useful. @@ -257,6 +264,8 @@ def get_default_optimizer_params( raise ValueError("Conflicting overrides for 'bias'") overrides["bias"] = bias_overrides + layer_names_pattern = [re.compile(name) for name in freeze_layers] + norm_module_types = ( torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, @@ -288,8 +297,15 @@ def get_default_optimizer_params( hyperparams.update(overrides.get(module_param_name, {})) if module_name.split('.')[0] == "heads" and (heads_lr_factor is not None and heads_lr_factor != 1.0): hyperparams["lr"] = hyperparams.get("lr", base_lr) * heads_lr_factor - params.append({"name": module_name + '.' + module_param_name, - "params": [value], **hyperparams}) + name = module_name + '.' + module_param_name + freeze_status = "normal" + # Search freeze layer names, it must match from beginning, so use `match` not `search` + for pattern in layer_names_pattern: + if pattern.match(name) is not None: + freeze_status = "freeze" + break + + params.append({"freeze_status": freeze_status, "params": [value], **hyperparams}) return params diff --git a/fastreid/utils/params.py b/fastreid/utils/params.py new file mode 100644 index 0000000..8b05e4d --- /dev/null +++ b/fastreid/utils/params.py @@ -0,0 +1,103 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: sherlockliao01@gmail.com +""" + +# based on: https://github.com/PhilJd/contiguous_pytorch_params/blob/master/contiguous_params/params.py + +from collections import OrderedDict + +import torch + + +class ContiguousParams: + + def __init__(self, parameters): + # Create a list of the parameters to prevent emptying an iterator. + self._parameters = parameters + self._param_buffer = [] + self._grad_buffer = [] + self._group_dict = OrderedDict() + self._name_buffer = [] + self._init_buffers() + # Store the data pointers for each parameter into the buffer. These + # can be used to check if an operation overwrites the gradient/data + # tensor (invalidating the assumption of a contiguous buffer). 
+ self.data_pointers = [] + self.grad_pointers = [] + self.make_params_contiguous() + + def _init_buffers(self): + dtype = self._parameters[0]["params"][0].dtype + device = self._parameters[0]["params"][0].device + if not all(p["params"][0].dtype == dtype for p in self._parameters): + raise ValueError("All parameters must be of the same dtype.") + if not all(p["params"][0].device == device for p in self._parameters): + raise ValueError("All parameters must be on the same device.") + + # Group parameters by lr and weight decay + for param_dict in self._parameters: + freeze_status = param_dict["freeze_status"] + param_key = freeze_status + '_' + str(param_dict["lr"]) + '_' + str(param_dict["weight_decay"]) + if param_key not in self._group_dict: + self._group_dict[param_key] = [] + self._group_dict[param_key].append(param_dict) + + for key, params in self._group_dict.items(): + size = sum(p["params"][0].numel() for p in params) + self._param_buffer.append(torch.zeros(size, dtype=dtype, device=device)) + self._grad_buffer.append(torch.zeros(size, dtype=dtype, device=device)) + self._name_buffer.append(key) + + def make_params_contiguous(self): + """Create a buffer to hold all params and update the params to be views of the buffer. + Args: + parameters: An iterable of parameters. + """ + for i, params in enumerate(self._group_dict.values()): + index = 0 + for param_dict in params: + p = param_dict["params"][0] + size = p.numel() + self._param_buffer[i][index:index + size] = p.data.view(-1) + p.data = self._param_buffer[i][index:index + size].view(p.data.shape) + p.grad = self._grad_buffer[i][index:index + size].view(p.data.shape) + self.data_pointers.append(p.data.data_ptr) + self.grad_pointers.append(p.grad.data.data_ptr) + index += size + # Bend the param_buffer to use grad_buffer to track its gradients. + self._param_buffer[i].grad = self._grad_buffer[i] + + def contiguous(self): + """Return all parameters as one contiguous buffer.""" + return [{ + "freeze_status": self._name_buffer[i].split('_')[0], + "params": self._param_buffer[i], + "lr": float(self._name_buffer[i].split('_')[1]), + "weight_decay": float(self._name_buffer[i].split('_')[2]), + } for i in range(len(self._param_buffer))] + + def original(self): + """Return the non-flattened parameters.""" + return self._parameters + + def buffer_is_valid(self): + """Verify that all parameters and gradients still use the buffer.""" + i = 0 + for params in self._group_dict.values(): + for param_dict in params: + p = param_dict["params"][0] + data_ptr = self.data_pointers[i] + grad_ptr = self.grad_pointers[i] + if (p.data.data_ptr() != data_ptr()) or (p.grad.data.data_ptr() != grad_ptr()): + return False + i += 1 + return True + + def assert_buffer_is_valid(self): + if not self.buffer_is_valid(): + raise ValueError( + "The data or gradient buffer has been invalidated. Please make " + "sure to use inplace operations only when updating parameters " + "or gradients.") diff --git a/projects/FastClas/configs/base-clas.yaml b/projects/FastClas/configs/base-clas.yaml index bff7e9a..029df28 100644 --- a/projects/FastClas/configs/base-clas.yaml +++ b/projects/FastClas/configs/base-clas.yaml @@ -56,7 +56,7 @@ SOLVER: BIAS_LR_FACTOR: 1. WEIGHT_DECAY: 0.0005 WEIGHT_DECAY_BIAS: 0. 
- IMS_PER_BATCH: 4 + IMS_PER_BATCH: 16 ETA_MIN_LR: 0.00003 diff --git a/projects/FastFace/fastface/trainer.py b/projects/FastFace/fastface/trainer.py index 90db184..6561841 100644 --- a/projects/FastFace/fastface/trainer.py +++ b/projects/FastFace/fastface/trainer.py @@ -16,9 +16,11 @@ from fastreid.data.transforms import build_transforms from fastreid.engine import hooks from fastreid.engine.defaults import DefaultTrainer, TrainerBase from fastreid.engine.train_loop import SimpleTrainer, AMPTrainer +from fastreid.solver import build_optimizer from fastreid.utils import comm from fastreid.utils.checkpoint import Checkpointer from fastreid.utils.logger import setup_logger +from fastreid.utils.params import ContiguousParams from .face_data import MXFaceDataset from .face_data import TestFaceDataset from .face_evaluator import FaceEvaluator @@ -39,7 +41,7 @@ class FaceTrainer(DefaultTrainer): data_loader = self.build_train_loader(cfg) cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes) model = self.build_model(cfg) - optimizer = self.build_optimizer(cfg, model) + optimizer, param_wrapper = self.build_optimizer(cfg, model) if cfg.MODEL.HEADS.PFC.ENABLED: # fmt: off @@ -54,7 +56,7 @@ class FaceTrainer(DefaultTrainer): # Partial-FC module embedding_size = embedding_dim if embedding_dim > 0 else feat_dim self.pfc_module = PartialFC(embedding_size, num_classes, sample_rate, cls_type, scale, margin) - self.pfc_optimizer = self.build_optimizer(cfg, self.pfc_module) + self.pfc_optimizer, _ = build_optimizer(cfg, self.pfc_module, False) # For training, wrap with DDP. But don't need this for inference. if comm.get_world_size() > 1: @@ -67,11 +69,11 @@ class FaceTrainer(DefaultTrainer): if cfg.MODEL.HEADS.PFC.ENABLED: mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size() grad_scaler = MaxClipGradScaler(mini_batch_size, 128 * mini_batch_size, growth_interval=100) - self._trainer = PFCTrainer(model, data_loader, optimizer, + self._trainer = PFCTrainer(model, data_loader, optimizer, param_wrapper, self.pfc_module, self.pfc_optimizer, cfg.SOLVER.AMP.ENABLED, grad_scaler) else: self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)( - model, data_loader, optimizer + model, data_loader, optimizer, param_wrapper ) self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH @@ -124,7 +126,8 @@ class FaceTrainer(DefaultTrainer): # Backbone loading state_dict super().resume_or_load(resume) # Partial-FC loading state_dict - self.pfc_checkpointer.resume_or_load('', resume=resume) + if self.cfg.MODEL.HEADS.PFC.ENABLED: + self.pfc_checkpointer.resume_or_load('', resume=resume) @classmethod def build_train_loader(cls, cfg): @@ -161,8 +164,9 @@ class PFCTrainer(SimpleTrainer): https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/partial_fc.py """ - def __init__(self, model, data_loader, optimizer, pfc_module, pfc_optimizer, amp_enabled, grad_scaler): - super().__init__(model, data_loader, optimizer) + def __init__(self, model, data_loader, optimizer, param_wrapper, pfc_module, pfc_optimizer, amp_enabled, + grad_scaler): + super().__init__(model, data_loader, optimizer, param_wrapper) self.pfc_module = pfc_module self.pfc_optimizer = pfc_optimizer @@ -200,3 +204,5 @@ class PFCTrainer(SimpleTrainer): self.pfc_module.update() self.optimizer.zero_grad() self.pfc_optimizer.zero_grad() + if isinstance(self.param_wrapper, ContiguousParams): + self.param_wrapper.assert_buffer_is_valid()
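Usage sketch (Python): a minimal, hypothetical example of how the pieces added above fit together — param groups in the shape produced by `get_default_optimizer_params`, wrapped in `ContiguousParams`, with the flat buffers handed to the optimizer via `.contiguous()` and `.assert_buffer_is_valid()` called after each step, mirroring what `SimpleTrainer`/`AMPTrainer` now do. The toy model and the hard-coded hyperparameter values are made up for illustration; in fastreid itself `build_optimizer(cfg, model)` performs the wrapping and returns `(optimizer, param_wrapper)`.

```python
import torch
from torch import nn

from fastreid.utils.params import ContiguousParams

# Toy two-layer model standing in for a fastreid backbone/head (illustrative only).
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))

# get_default_optimizer_params() normally builds these groups from the cfg; each
# group carries one tensor plus the keys ContiguousParams reads:
# "freeze_status", "lr" and "weight_decay".
param_groups = [
    {"freeze_status": "normal", "params": [p], "lr": 0.01, "weight_decay": 0.0005}
    for p in model.parameters()
]

wrapper = ContiguousParams(param_groups)  # copies params/grads into flat buffers
# One optimizer param group per (freeze_status, lr, weight_decay) combination.
optimizer = torch.optim.SGD(wrapper.contiguous(), momentum=0.9)

for _ in range(3):
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()
    # Zero gradients in place so the shared gradient buffer stays attached;
    # set_to_none=True would drop the flat buffers' .grad and the optimizer
    # would silently stop seeing gradients.
    optimizer.zero_grad(set_to_none=False)
    # Raises if any operation re-allocated a parameter or gradient tensor and
    # detached it from the contiguous buffer (same check the trainers run).
    wrapper.assert_buffer_is_valid()
```

Because parameters are grouped by identical hyperparameters, the optimizer updates a handful of large flat tensors instead of hundreds of small ones, which is the source of the ~20% training speedup the README bullet refers to.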