mirror of https://github.com/JDAI-CV/fast-reid.git
feat: Add contiguous parameters support
Support contiguous parameters for faster training. Parameters can be split into different contiguous groups by freeze layer, lr and weight decay.
pull/542/head
parent 44d1e04e9a
commit 7e652fea2a
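For orientation, here is a minimal, hedged sketch of the workflow this commit introduces. It is illustrative only and not part of the patch: the toy `nn.Sequential` model and the hand-built parameter-group dicts stand in for what `get_default_optimizer_params` produces inside `build_optimizer`, and `set_to_none=False` is an assumption needed on newer PyTorch so that zeroing gradients stays in place.

import torch
from torch import nn

from fastreid.utils.params import ContiguousParams

# A toy model standing in for a FastReID backbone and head.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))

# get_default_optimizer_params() now emits one dict per parameter, tagged with
# "freeze_status", "lr" and "weight_decay"; these hand-built groups mimic that.
param_groups = [
    {"freeze_status": "normal", "params": [p], "lr": 0.01, "weight_decay": 5e-4}
    for p in model.parameters()
]

wrapper = ContiguousParams(param_groups)
# contiguous() returns one flat parameter buffer per unique
# (freeze_status, lr, weight_decay) key, so the optimizer updates a few large
# tensors instead of many small ones. All groups above share a single key.
optimizer = torch.optim.SGD(wrapper.contiguous(), momentum=0.9)

# Zero gradients in place: setting them to None would detach the shared
# gradient buffer that the model parameters are views of.
optimizer.zero_grad(set_to_none=False)
loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()

# The trainers updated in this commit run this check after every step to catch
# any out-of-place operation that replaced a parameter or gradient tensor.
wrapper.assert_buffer_is_valid()

Inside FastReID itself, trainers simply call `optimizer, param_wrapper = build_optimizer(cfg, model)`; the FastFace Partial-FC optimizer in this same commit opts out by passing `contiguous=False` (`build_optimizer(cfg, self.pfc_module, False)`), which skips the wrapper and hands the raw parameter groups to the optimizer.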
@@ -8,13 +8,15 @@ FastReID is a research platform that implements state-of-the-art re-identificati
 ## What's New
 
+- [June 2021] [Contiguous parameters](https://github.com/PhilJd/contiguous_pytorch_params) is supported, now it can
+accelerate ~20%.
 - [May 2021] Vision Transformer backbone supported, see `configs/Market1501/bagtricks_vit.yml`.
 - [Apr 2021] Partial FC supported in [FastFace](projects/FastFace)!
 - [Jan 2021] TRT network definition APIs in [FastRT](projects/FastRT) has been released!
 Thanks for [Darren](https://github.com/TCHeish)'s contribution.
 - [Jan 2021] NAIC20(reid track) [1-st solution](projects/NAIC20) based on fastreid has been released!
 - [Jan 2021] FastReID V1.0 has been released!🎉
 Support many tasks beyond reid, such image retrieval and face recognition. See [release notes](https://github.com/JDAI-CV/fast-reid/releases/tag/v1.0.0).
 - [Oct 2020] Added the [Hyper-Parameter Optimization](projects/FastTune) based on fastreid. See `projects/FastTune`.
 - [Sep 2020] Added the [person attribute recognition](projects/FastAttr) based on fastreid. See `projects/FastAttr`.
 - [Sep 2020] Automatic Mixed Precision training is supported with `apex`. Set `cfg.SOLVER.FP16_ENABLED=True` to switch it on.
@@ -242,7 +242,7 @@ _C.SOLVER.NESTEROV = False
 _C.SOLVER.WEIGHT_DECAY = 0.0005
 # The weight decay that's applied to parameters of normalization layers
 # (typically the affine transformation)
-_C.SOLVER.WEIGHT_DECAY_NORM = 0.0
+_C.SOLVER.WEIGHT_DECAY_NORM = 0.0005
 
 # The previous detection code used a 2x higher LR and 0 WD for bias.
 # This is not useful (at least for recent models). You should avoid
@@ -201,7 +201,7 @@ class DefaultTrainer(TrainerBase):
         data_loader = self.build_train_loader(cfg)
         cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
         model = self.build_model(cfg)
-        optimizer = self.build_optimizer(cfg, model)
+        optimizer, param_wrapper = self.build_optimizer(cfg, model)
 
         # For training, wrap with DDP. But don't need this for inference.
         if comm.get_world_size() > 1:
@@ -212,7 +212,7 @@ class DefaultTrainer(TrainerBase):
             )
 
         self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
-            model, data_loader, optimizer
+            model, data_loader, optimizer, param_wrapper
         )
 
         self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
@@ -15,6 +15,7 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel
 
 import fastreid.utils.comm as comm
 from fastreid.utils.events import EventStorage, get_event_storage
+from fastreid.utils.params import ContiguousParams
 
 __all__ = ["HookBase", "TrainerBase", "SimpleTrainer"]
@@ -197,7 +198,7 @@ class SimpleTrainer(TrainerBase):
     or write your own training loop.
     """
 
-    def __init__(self, model, data_loader, optimizer):
+    def __init__(self, model, data_loader, optimizer, param_wrapper):
         """
         Args:
             model: a torch Module. Takes a data from data_loader and returns a
@@ -219,6 +220,7 @@ class SimpleTrainer(TrainerBase):
         self.data_loader = data_loader
         self._data_loader_iter = iter(data_loader)
         self.optimizer = optimizer
+        self.param_wrapper = param_wrapper
 
     def run_step(self):
         """
@@ -254,6 +256,8 @@ class SimpleTrainer(TrainerBase):
         wrap the optimizer with your custom `step()` method.
         """
         self.optimizer.step()
+        if isinstance(self.param_wrapper, ContiguousParams):
+            self.param_wrapper.assert_buffer_is_valid()
 
     def _write_metrics(self, loss_dict: Dict[str, torch.Tensor], data_time: float):
         """
@@ -303,7 +307,7 @@ class AMPTrainer(SimpleTrainer):
     in the training loop.
     """
 
-    def __init__(self, model, data_loader, optimizer, grad_scaler=None):
+    def __init__(self, model, data_loader, optimizer, param_wrapper, grad_scaler=None):
         """
 
         Args:
@@ -315,7 +319,7 @@ class AMPTrainer(SimpleTrainer):
         assert not (model.device_ids and len(model.device_ids) > 1), unsupported
         assert not isinstance(model, DataParallel), unsupported
 
-        super().__init__(model, data_loader, optimizer)
+        super().__init__(model, data_loader, optimizer, param_wrapper)
 
         if grad_scaler is None:
             from torch.cuda.amp import GradScaler
@@ -346,3 +350,5 @@ class AMPTrainer(SimpleTrainer):
 
         self.grad_scaler.step(self.optimizer)
         self.grad_scaler.update()
+        if isinstance(self.param_wrapper, ContiguousParams):
+            self.param_wrapper.assert_buffer_is_valid()
@@ -9,12 +9,14 @@
 import copy
 import itertools
 import math
+import re
 from enum import Enum
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
 
 import torch
 
 from fastreid.config import CfgNode
+from fastreid.utils.params import ContiguousParams
 from . import lr_scheduler
 
 _GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
@@ -60,6 +62,7 @@ def _generate_optimizer_class_with_gradient_clipping(
         per_param_clipper is None or global_clipper is None
     ), "Not allowed to use both per-parameter clipping and global clipping"
 
+    @torch.no_grad()
    def optimizer_wgc_step(self, closure=None):
        if per_param_clipper is not None:
            for group in self.param_groups:
@@ -117,26 +120,31 @@ def maybe_add_gradient_clipping(
 def _generate_optimizer_class_with_freeze_layer(
         optimizer: Type[torch.optim.Optimizer],
         *,
-        freeze_layers: Optional[List] = None,
         freeze_iters: int = 0,
 ) -> Type[torch.optim.Optimizer]:
-    assert (
-        freeze_layers is not None and freeze_iters > 0
-    ), "No layers need to be frozen or freeze iterations is 0"
+    assert freeze_iters > 0, "No layers need to be frozen or freeze iterations is 0"
 
     cnt = 0
 
     @torch.no_grad()
     def optimizer_wfl_step(self, closure=None):
         nonlocal cnt
         if cnt < freeze_iters:
             cnt += 1
+            param_ref = []
+            grad_ref = []
             for group in self.param_groups:
-                if group["name"].split('.')[0] in freeze_layers:
+                if group["freeze_status"] == "freeze":
                     for p in group["params"]:
                         if p.grad is not None:
+                            param_ref.append(p)
+                            grad_ref.append(p.grad)
                             p.grad = None
-            optimizer.step(self, closure)
+            optimizer.step(self, closure)
+            for p, g in zip(param_ref, grad_ref):
+                p.grad = g
         else:
             optimizer.step(self, closure)
 
     OptimizerWithFreezeLayer = type(
         optimizer.__name__ + "WithFreezeLayer",
@@ -149,7 +157,7 @@ def _generate_optimizer_class_with_freeze_layer(
 def maybe_add_freeze_layer(
         cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
 ) -> Type[torch.optim.Optimizer]:
-    if len(cfg.MODEL.FREEZE_LAYERS) == 0 or cfg.SOLVER.FREEZE_ITERS == 0:
+    if len(cfg.MODEL.FREEZE_LAYERS) == 0 or cfg.SOLVER.FREEZE_ITERS <= 0:
         return optimizer
 
     if isinstance(optimizer, torch.optim.Optimizer):
@@ -160,7 +168,6 @@ def maybe_add_freeze_layer(
 
     OptimizerWithFreezeLayer = _generate_optimizer_class_with_freeze_layer(
         optimizer_type,
-        freeze_layers=cfg.MODEL.FREEZE_LAYERS,
         freeze_iters=cfg.SOLVER.FREEZE_ITERS
     )
     if isinstance(optimizer, torch.optim.Optimizer):
@@ -170,37 +177,35 @@ def maybe_add_freeze_layer(
     return OptimizerWithFreezeLayer
 
 
-def build_optimizer(cfg, model):
+def build_optimizer(cfg, model, contiguous=True):
     params = get_default_optimizer_params(
         model,
         base_lr=cfg.SOLVER.BASE_LR,
         weight_decay=cfg.SOLVER.WEIGHT_DECAY,
         weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
         bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
         heads_lr_factor=cfg.SOLVER.HEADS_LR_FACTOR,
-        weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS
+        weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
+        freeze_layers=cfg.MODEL.FREEZE_LAYERS if cfg.SOLVER.FREEZE_ITERS > 0 else [],
     )
 
+    if contiguous:
+        params = ContiguousParams(params)
     solver_opt = cfg.SOLVER.OPT
     if solver_opt == "SGD":
         return maybe_add_freeze_layer(
             cfg,
             maybe_add_gradient_clipping(cfg, torch.optim.SGD)
         )(
-            params,
-            lr=cfg.SOLVER.BASE_LR,
+            params.contiguous() if contiguous else params,
             momentum=cfg.SOLVER.MOMENTUM,
             nesterov=cfg.SOLVER.NESTEROV,
             weight_decay=cfg.SOLVER.WEIGHT_DECAY,
-        )
+        ), params
     else:
         return maybe_add_freeze_layer(
             cfg,
             maybe_add_gradient_clipping(cfg, getattr(torch.optim, solver_opt))
-        )(
-            params,
-            lr=cfg.SOLVER.BASE_LR,
-            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
-        )
+        )(params.contiguous() if contiguous else params), params
 
 
 def get_default_optimizer_params(
@@ -212,6 +217,7 @@ def get_default_optimizer_params(
         heads_lr_factor: Optional[float] = 1.0,
         weight_decay_bias: Optional[float] = None,
         overrides: Optional[Dict[str, Dict[str, float]]] = None,
+        freeze_layers: Optional[list] = [],
 ):
     """
     Get default param list for optimizer, with support for a few types of
@@ -228,6 +234,7 @@ def get_default_optimizer_params(
             (LR, weight decay) for module parameters with a given name; e.g.
             ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
             weight decay values for all module parameters named `embedding`.
+        freeze_layers: layer names for freezing.
     For common detection models, ``weight_decay_norm`` is the only option
     needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings
     from Detectron1 that are not found useful.
@@ -257,6 +264,8 @@ def get_default_optimizer_params(
             raise ValueError("Conflicting overrides for 'bias'")
         overrides["bias"] = bias_overrides
 
+    layer_names_pattern = [re.compile(name) for name in freeze_layers]
+
     norm_module_types = (
         torch.nn.BatchNorm1d,
         torch.nn.BatchNorm2d,
@@ -288,8 +297,15 @@ def get_default_optimizer_params(
             hyperparams.update(overrides.get(module_param_name, {}))
             if module_name.split('.')[0] == "heads" and (heads_lr_factor is not None and heads_lr_factor != 1.0):
                 hyperparams["lr"] = hyperparams.get("lr", base_lr) * heads_lr_factor
-            params.append({"name": module_name + '.' + module_param_name,
-                           "params": [value], **hyperparams})
+            name = module_name + '.' + module_param_name
+            freeze_status = "normal"
+            # Search the frozen-layer patterns; the name must match from the beginning, so use `match`, not `search`.
+            for pattern in layer_names_pattern:
+                if pattern.match(name) is not None:
+                    freeze_status = "freeze"
+                    break
+
+            params.append({"freeze_status": freeze_status, "params": [value], **hyperparams})
     return params
@@ -0,0 +1,103 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+# based on: https://github.com/PhilJd/contiguous_pytorch_params/blob/master/contiguous_params/params.py
+
+from collections import OrderedDict
+
+import torch
+
+
+class ContiguousParams:
+
+    def __init__(self, parameters):
+        # Keep a reference to the parameter group dicts (as produced by
+        # `get_default_optimizer_params`) so the original, non-flattened
+        # parameters remain accessible via `original()`.
+        self._parameters = parameters
+        self._param_buffer = []
+        self._grad_buffer = []
+        self._group_dict = OrderedDict()
+        self._name_buffer = []
+        self._init_buffers()
+        # Store the data pointers for each parameter into the buffer. These
+        # can be used to check if an operation overwrites the gradient/data
+        # tensor (invalidating the assumption of a contiguous buffer).
+        self.data_pointers = []
+        self.grad_pointers = []
+        self.make_params_contiguous()
+
+    def _init_buffers(self):
+        dtype = self._parameters[0]["params"][0].dtype
+        device = self._parameters[0]["params"][0].device
+        if not all(p["params"][0].dtype == dtype for p in self._parameters):
+            raise ValueError("All parameters must be of the same dtype.")
+        if not all(p["params"][0].device == device for p in self._parameters):
+            raise ValueError("All parameters must be on the same device.")
+
+        # Group parameters by freeze status, lr and weight decay
+        for param_dict in self._parameters:
+            freeze_status = param_dict["freeze_status"]
+            param_key = freeze_status + '_' + str(param_dict["lr"]) + '_' + str(param_dict["weight_decay"])
+            if param_key not in self._group_dict:
+                self._group_dict[param_key] = []
+            self._group_dict[param_key].append(param_dict)
+
+        for key, params in self._group_dict.items():
+            size = sum(p["params"][0].numel() for p in params)
+            self._param_buffer.append(torch.zeros(size, dtype=dtype, device=device))
+            self._grad_buffer.append(torch.zeros(size, dtype=dtype, device=device))
+            self._name_buffer.append(key)
+
+    def make_params_contiguous(self):
+        """Create a buffer per group to hold all params and update the params to be views of the buffer."""
+        for i, params in enumerate(self._group_dict.values()):
+            index = 0
+            for param_dict in params:
+                p = param_dict["params"][0]
+                size = p.numel()
+                self._param_buffer[i][index:index + size] = p.data.view(-1)
+                p.data = self._param_buffer[i][index:index + size].view(p.data.shape)
+                p.grad = self._grad_buffer[i][index:index + size].view(p.data.shape)
+                self.data_pointers.append(p.data.data_ptr)
+                self.grad_pointers.append(p.grad.data.data_ptr)
+                index += size
+            # Bend the param_buffer to use grad_buffer to track its gradients.
+            self._param_buffer[i].grad = self._grad_buffer[i]
+
+    def contiguous(self):
+        """Return all parameters as one contiguous buffer per group."""
+        return [{
+            "freeze_status": self._name_buffer[i].split('_')[0],
+            "params": self._param_buffer[i],
+            "lr": float(self._name_buffer[i].split('_')[1]),
+            "weight_decay": float(self._name_buffer[i].split('_')[2]),
+        } for i in range(len(self._param_buffer))]
+
+    def original(self):
+        """Return the non-flattened parameters."""
+        return self._parameters
+
+    def buffer_is_valid(self):
+        """Verify that all parameters and gradients still use the buffer."""
+        i = 0
+        for params in self._group_dict.values():
+            for param_dict in params:
+                p = param_dict["params"][0]
+                data_ptr = self.data_pointers[i]
+                grad_ptr = self.grad_pointers[i]
+                if (p.data.data_ptr() != data_ptr()) or (p.grad.data.data_ptr() != grad_ptr()):
+                    return False
+                i += 1
+        return True
+
+    def assert_buffer_is_valid(self):
+        if not self.buffer_is_valid():
+            raise ValueError(
+                "The data or gradient buffer has been invalidated. Please make "
+                "sure to use inplace operations only when updating parameters "
+                "or gradients.")
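To make the guarantee behind `assert_buffer_is_valid` concrete, here is a short, hedged sketch (again with a throwaway `nn.Linear`, not part of the patch) showing which kinds of updates keep the shared buffer intact and which invalidate it.

import torch
from torch import nn

from fastreid.utils.params import ContiguousParams

layer = nn.Linear(4, 4)
groups = [
    {"freeze_status": "normal", "params": [p], "lr": 0.01, "weight_decay": 0.0}
    for p in layer.parameters()
]
wrapper = ContiguousParams(groups)

# Right after construction, every parameter and its gradient are views into
# the flat buffers, so the check passes.
assert wrapper.buffer_is_valid()

# In-place updates keep the views (and the check) intact ...
with torch.no_grad():
    layer.weight.mul_(0.5)
assert wrapper.buffer_is_valid()

# ... but rebinding .data to freshly allocated storage breaks the contiguity
# assumption; this is exactly what assert_buffer_is_valid() is meant to catch.
layer.weight.data = layer.weight.data.clone()
assert not wrapper.buffer_is_valid()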
@@ -56,7 +56,7 @@ SOLVER:
   BIAS_LR_FACTOR: 1.
   WEIGHT_DECAY: 0.0005
   WEIGHT_DECAY_BIAS: 0.
-  IMS_PER_BATCH: 4
+  IMS_PER_BATCH: 16
 
   ETA_MIN_LR: 0.00003
@@ -16,9 +16,11 @@ from fastreid.data.transforms import build_transforms
 from fastreid.engine import hooks
 from fastreid.engine.defaults import DefaultTrainer, TrainerBase
 from fastreid.engine.train_loop import SimpleTrainer, AMPTrainer
+from fastreid.solver import build_optimizer
 from fastreid.utils import comm
 from fastreid.utils.checkpoint import Checkpointer
 from fastreid.utils.logger import setup_logger
+from fastreid.utils.params import ContiguousParams
 from .face_data import MXFaceDataset
 from .face_data import TestFaceDataset
 from .face_evaluator import FaceEvaluator
@@ -39,7 +41,7 @@ class FaceTrainer(DefaultTrainer):
         data_loader = self.build_train_loader(cfg)
         cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
         model = self.build_model(cfg)
-        optimizer = self.build_optimizer(cfg, model)
+        optimizer, param_wrapper = self.build_optimizer(cfg, model)
 
         if cfg.MODEL.HEADS.PFC.ENABLED:
             # fmt: off
@@ -54,7 +56,7 @@ class FaceTrainer(DefaultTrainer):
             # Partial-FC module
             embedding_size = embedding_dim if embedding_dim > 0 else feat_dim
             self.pfc_module = PartialFC(embedding_size, num_classes, sample_rate, cls_type, scale, margin)
-            self.pfc_optimizer = self.build_optimizer(cfg, self.pfc_module)
+            self.pfc_optimizer, _ = build_optimizer(cfg, self.pfc_module, False)
 
         # For training, wrap with DDP. But don't need this for inference.
         if comm.get_world_size() > 1:
@@ -67,11 +69,11 @@ class FaceTrainer(DefaultTrainer):
         if cfg.MODEL.HEADS.PFC.ENABLED:
             mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
             grad_scaler = MaxClipGradScaler(mini_batch_size, 128 * mini_batch_size, growth_interval=100)
-            self._trainer = PFCTrainer(model, data_loader, optimizer,
+            self._trainer = PFCTrainer(model, data_loader, optimizer, param_wrapper,
                                        self.pfc_module, self.pfc_optimizer, cfg.SOLVER.AMP.ENABLED, grad_scaler)
         else:
             self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
-                model, data_loader, optimizer
+                model, data_loader, optimizer, param_wrapper
             )
 
         self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
@@ -124,7 +126,8 @@ class FaceTrainer(DefaultTrainer):
         # Backbone loading state_dict
         super().resume_or_load(resume)
         # Partial-FC loading state_dict
-        self.pfc_checkpointer.resume_or_load('', resume=resume)
+        if self.cfg.MODEL.HEADS.PFC.ENABLED:
+            self.pfc_checkpointer.resume_or_load('', resume=resume)
 
     @classmethod
     def build_train_loader(cls, cfg):
@@ -161,8 +164,9 @@ class PFCTrainer(SimpleTrainer):
     https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/partial_fc.py
     """
 
-    def __init__(self, model, data_loader, optimizer, pfc_module, pfc_optimizer, amp_enabled, grad_scaler):
-        super().__init__(model, data_loader, optimizer)
+    def __init__(self, model, data_loader, optimizer, param_wrapper, pfc_module, pfc_optimizer, amp_enabled,
+                 grad_scaler):
+        super().__init__(model, data_loader, optimizer, param_wrapper)
 
         self.pfc_module = pfc_module
         self.pfc_optimizer = pfc_optimizer
@@ -200,3 +204,5 @@ class PFCTrainer(SimpleTrainer):
         self.pfc_module.update()
         self.optimizer.zero_grad()
         self.pfc_optimizer.zero_grad()
+        if isinstance(self.param_wrapper, ContiguousParams):
+            self.param_wrapper.assert_buffer_is_valid()