diff --git a/mmcv/runner/__init__.py b/mmcv/runner/__init__.py index 34d58cb8f..623583bca 100644 --- a/mmcv/runner/__init__.py +++ b/mmcv/runner/__init__.py @@ -11,10 +11,12 @@ from .epoch_based_runner import EpochBasedRunner, Runner from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook, DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook, - Fp16OptimizerHook, Hook, IterTimerHook, LoggerHook, - LrUpdaterHook, MlflowLoggerHook, NeptuneLoggerHook, - OptimizerHook, PaviLoggerHook, SyncBuffersHook, - TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook) + Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook, + GradientCumulativeOptimizerHook, Hook, IterTimerHook, + LoggerHook, LrUpdaterHook, MlflowLoggerHook, + NeptuneLoggerHook, OptimizerHook, PaviLoggerHook, + SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook, + WandbLoggerHook) from .iter_based_runner import IterBasedRunner, IterLoader from .log_buffer import LogBuffer from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS, @@ -39,5 +41,6 @@ __all__ = [ 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule', '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential', - 'ModuleList' + 'ModuleList', 'GradientCumulativeOptimizerHook', + 'GradientCumulativeFp16OptimizerHook' ] diff --git a/mmcv/runner/hooks/__init__.py b/mmcv/runner/hooks/__init__.py index cbc4810a9..915af28ce 100644 --- a/mmcv/runner/hooks/__init__.py +++ b/mmcv/runner/hooks/__init__.py @@ -11,7 +11,8 @@ from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook, from .lr_updater import LrUpdaterHook from .memory import EmptyCacheHook from .momentum_updater import MomentumUpdaterHook -from .optimizer import Fp16OptimizerHook, OptimizerHook +from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook, + GradientCumulativeOptimizerHook, OptimizerHook) from .profiler import ProfilerHook from .sampler_seed import DistSamplerSeedHook from .sync_buffer import SyncBuffersHook @@ -23,5 +24,6 @@ __all__ = [ 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', 'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook', - 'DistEvalHook', 'ProfilerHook' + 'DistEvalHook', 'ProfilerHook', 'GradientCumulativeOptimizerHook', + 'GradientCumulativeFp16OptimizerHook' ] diff --git a/mmcv/runner/hooks/optimizer.py b/mmcv/runner/hooks/optimizer.py index 9810e43d6..f575ceda0 100644 --- a/mmcv/runner/hooks/optimizer.py +++ b/mmcv/runner/hooks/optimizer.py @@ -5,7 +5,7 @@ from itertools import chain from torch.nn.utils import clip_grad -from mmcv.utils import TORCH_VERSION, digit_version +from mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version from ..dist_utils import allreduce_grads from ..fp16_utils import LossScaler, wrap_fp16_model from .hook import HOOKS, Hook @@ -42,6 +42,91 @@ class OptimizerHook(Hook): runner.optimizer.step() +@HOOKS.register_module() +class GradientCumulativeOptimizerHook(OptimizerHook): + """Optimizer Hook implements multi-iters gradient cumulating. + + Args: + cumulative_iters (int, optional): Num of gradient cumulative iters. + The optimizer will step every `cumulative_iters` iters. + Defaults to 1. + + Examples: + >>> # Use cumulative_iters to simulate a large batch size + >>> # It is helpful when the hardware cannot handle a large batch size. + >>> loader = DataLoader(data, batch_size=64) + >>> optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4) + >>> # almost equals to + >>> loader = DataLoader(data, batch_size=256) + >>> optim_hook = OptimizerHook() + """ + + def __init__(self, cumulative_iters=1, **kwargs): + super(GradientCumulativeOptimizerHook, self).__init__(**kwargs) + + assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \ + f'cumulative_iters only accepts positive int, but got ' \ + f'{type(cumulative_iters)} instead.' + + self.cumulative_iters = cumulative_iters + self.divisible_iters = 0 + self.remainder_iters = 0 + self.initialized = False + + def has_batch_norm(self, module): + if isinstance(module, _BatchNorm): + return True + for m in module.children(): + if self.has_batch_norm(m): + return True + return False + + def _init(self, runner): + if runner.iter % self.cumulative_iters != 0: + runner.logger.warning( + 'Resume iter number is not divisible by cumulative_iters in ' + 'GradientCumulativeOptimizerHook, which means the gradient of ' + 'some iters is lost and the result may be influenced slightly.' + ) + + if self.has_batch_norm(runner.model) and self.cumulative_iters > 1: + runner.logger.warning( + 'GradientCumulativeOptimizerHook may slightly decrease ' + 'performance if the model has BatchNorm layers.') + + residual_iters = runner.max_iters - runner.iter + + self.divisible_iters = ( + residual_iters // self.cumulative_iters * self.cumulative_iters) + self.remainder_iters = residual_iters - self.divisible_iters + + self.initialized = True + + def after_train_iter(self, runner): + if not self.initialized: + self._init(runner) + + if runner.iter < self.divisible_iters: + loss_factor = self.cumulative_iters + else: + loss_factor = self.remainder_iters + loss = runner.outputs['loss'] + loss = loss / loss_factor + loss.backward() + + if (self.every_n_iters(runner, self.cumulative_iters) + or self.is_last_iter(runner)): + + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters()) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update({'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + runner.optimizer.step() + runner.optimizer.zero_grad() + + if (TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) >= digit_version('1.6.0')): @@ -152,6 +237,60 @@ if (TORCH_VERSION != 'parrots' # save state_dict of loss_scaler runner.meta.setdefault( 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + @HOOKS.register_module() + class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook, + Fp16OptimizerHook): + """Fp16 optimizer Hook (using PyTorch's implementation) implements + multi-iters gradient cumulating. + + If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend, + to take care of the optimization procedure. + """ + + def __init__(self, *args, **kwargs): + super(GradientCumulativeFp16OptimizerHook, + self).__init__(*args, **kwargs) + + def after_train_iter(self, runner): + if not self.initialized: + self._init(runner) + + if runner.iter < self.divisible_iters: + loss_factor = self.cumulative_iters + else: + loss_factor = self.remainder_iters + loss = runner.outputs['loss'] + loss = loss / loss_factor + + self.loss_scaler.scale(loss).backward() + + if (self.every_n_iters(runner, self.cumulative_iters) + or self.is_last_iter(runner)): + + # copy fp16 grads in the model to fp32 params in the optimizer + self.loss_scaler.unscale_(runner.optimizer) + + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters()) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update( + {'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + + # backward and update scaler + self.loss_scaler.step(runner.optimizer) + self.loss_scaler.update(self._scale_update_param) + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + # clear grads + runner.model.zero_grad() + runner.optimizer.zero_grad() + else: @HOOKS.register_module() @@ -295,3 +434,75 @@ else: # save state_dict of loss_scaler runner.meta.setdefault( 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + @HOOKS.register_module() + class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook, + Fp16OptimizerHook): + """Fp16 optimizer Hook (using mmcv implementation) implements multi- + iters gradient cumulating.""" + + def __init__(self, *args, **kwargs): + super(GradientCumulativeFp16OptimizerHook, + self).__init__(*args, **kwargs) + + def after_train_iter(self, runner): + if not self.initialized: + self._init(runner) + + if runner.iter < self.divisible_iters: + loss_factor = self.cumulative_iters + else: + loss_factor = self.remainder_iters + + loss = runner.outputs['loss'] + loss = loss / loss_factor + + # scale the loss value + scaled_loss = loss * self.loss_scaler.loss_scale + scaled_loss.backward() + + if (self.every_n_iters(runner, self.cumulative_iters) + or self.is_last_iter(runner)): + + # copy fp16 grads in the model to fp32 params in the optimizer + fp32_weights = [] + for param_group in runner.optimizer.param_groups: + fp32_weights += param_group['params'] + self.copy_grads_to_fp32(runner.model, fp32_weights) + # allreduce grads + if self.distributed: + allreduce_grads(fp32_weights, self.coalesce, + self.bucket_size_mb) + + has_overflow = self.loss_scaler.has_overflow(fp32_weights) + # if has overflow, skip this iteration + if not has_overflow: + # scale the gradients back + for param in fp32_weights: + if param.grad is not None: + param.grad.div_(self.loss_scaler.loss_scale) + if self.grad_clip is not None: + grad_norm = self.clip_grads(fp32_weights) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update( + {'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + # update fp32 params + runner.optimizer.step() + # copy fp32 params to the fp16 model + self.copy_params_to_fp16(runner.model, fp32_weights) + else: + runner.logger.warning( + 'Check overflow, downscale loss scale ' + f'to {self.loss_scaler.cur_scale}') + + self.loss_scaler.update_scale(has_overflow) + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + # clear grads + runner.model.zero_grad() + runner.optimizer.zero_grad() diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py index a0f5a0af8..424c58678 100644 --- a/tests/test_runner/test_hooks.py +++ b/tests/test_runner/test_hooks.py @@ -20,8 +20,12 @@ from torch.nn.init import constant_ from torch.utils.data import DataLoader from mmcv.runner import (CheckpointHook, DvcliveLoggerHook, EMAHook, - IterTimerHook, MlflowLoggerHook, NeptuneLoggerHook, + Fp16OptimizerHook, + GradientCumulativeFp16OptimizerHook, + GradientCumulativeOptimizerHook, IterTimerHook, + MlflowLoggerHook, NeptuneLoggerHook, OptimizerHook, PaviLoggerHook, WandbLoggerHook, build_runner) +from mmcv.runner.fp16_utils import auto_fp16 from mmcv.runner.hooks.hook import HOOKS, Hook from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook, CyclicLrUpdaterHook, @@ -1229,3 +1233,207 @@ def test_get_triggered_stages(): # stages output have order, so here is list instead of set. expected_stages = ['before_run', 'after_train_epoch', 'after_val_epoch'] assert hook.get_triggered_stages() == expected_stages + + +def test_gradient_cumulative_optimizer_hook(): + + class ToyModel(nn.Module): + + def __init__(self, with_norm=False): + super().__init__() + self.fp16_enabled = False + self.fc = nn.Linear(3, 2) + nn.init.constant_(self.fc.weight, 1.) + nn.init.constant_(self.fc.bias, 1.) + self.with_norm = with_norm + if with_norm: + self.norm = nn.BatchNorm1d(2) + + def forward(self, x): + x = self.fc(x) + if self.with_norm: + x = self.norm(x) + return x + + def train_step(self, x, optimizer, **kwargs): + return dict(loss=self(x).mean(), num_samples=x.shape[0]) + + def val_step(self, x, optimizer, **kwargs): + return dict(loss=self(x).mean(), num_samples=x.shape[0]) + + def build_toy_runner(config=dict(type='EpochBasedRunner', max_epochs=3)): + model = ToyModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.02) + tmp_dir = tempfile.mkdtemp() + + runner = build_runner( + config, + default_args=dict( + model=model, + work_dir=tmp_dir, + optimizer=optimizer, + logger=logging.getLogger(), + meta=dict())) + return runner + + with pytest.raises(AssertionError): + # cumulative_iters only accepts int + GradientCumulativeOptimizerHook(cumulative_iters='str') + + with pytest.raises(AssertionError): + # cumulative_iters only accepts positive number + GradientCumulativeOptimizerHook(cumulative_iters=-1) + + # test epoch based runner + data = torch.rand((6, 3)) + # optimize with cumulative_iters + loader_1 = DataLoader(data, batch_size=1) + runner_1 = build_toy_runner() + optimizer_hook = GradientCumulativeOptimizerHook( + grad_clip=dict(max_norm=0.2), cumulative_iters=3) + runner_1.register_hook(optimizer_hook) + runner_1.run([loader_1], [('train', 1)]) + + # optimize without cumulative_iters + loader_2 = DataLoader(data, batch_size=3) + runner_2 = build_toy_runner() + optimizer_hook = OptimizerHook(grad_clip=dict(max_norm=0.2)) + runner_2.register_hook(optimizer_hook) + runner_2.run([loader_2], [('train', 1)]) + + # test optimizer works well + assert (runner_1.model.fc.weight < 1).all() + assert (runner_1.model.fc.bias < 1).all() + # test optimizer with cumulative_iters gets the same results + assert torch.allclose(runner_1.model.fc.weight, runner_2.model.fc.weight) + assert torch.allclose(runner_1.model.fc.bias, runner_2.model.fc.bias) + shutil.rmtree(runner_1.work_dir) + shutil.rmtree(runner_2.work_dir) + + # test iter based runner + data = torch.rand((8, 3)) + # optimize with cumulative_iters + loader_1 = DataLoader(data, batch_size=1) + runner_1 = build_toy_runner(dict(type='IterBasedRunner', max_iters=8)) + optimizer_hook = GradientCumulativeOptimizerHook( + grad_clip=dict(max_norm=0.2), cumulative_iters=3) + runner_1.register_hook(optimizer_hook) + runner_1.run([loader_1], [('train', 1)]) + + # optimize without cumulative_iters + loader_2_divisible = DataLoader(data[:6], batch_size=3) + loader_2_remainder = DataLoader(data[6:], batch_size=2) + runner_2 = build_toy_runner(dict(type='IterBasedRunner', max_iters=3)) + optimizer_hook = OptimizerHook(grad_clip=dict(max_norm=0.2)) + runner_2.register_hook(optimizer_hook) + runner_2.run([loader_2_divisible, loader_2_remainder], [('train', 2), + ('train', 1)]) + + # test optimizer works well + assert (runner_1.model.fc.weight < 1).all() + assert (runner_1.model.fc.bias < 1).all() + # test optimizer with cumulative_iters gets the same results + assert torch.allclose(runner_1.model.fc.weight, runner_2.model.fc.weight) + assert torch.allclose(runner_1.model.fc.bias, runner_2.model.fc.bias) + shutil.rmtree(runner_1.work_dir) + shutil.rmtree(runner_2.work_dir) + + # test has_batch_norm + model = ToyModel(with_norm=True) + optimizer_hook = GradientCumulativeOptimizerHook( + grad_clip=dict(max_norm=0.2), cumulative_iters=3) + assert optimizer_hook.has_batch_norm(model) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_gradient_cumulative_fp16_optimizer_hook(): + + class ToyModel(nn.Module): + + def __init__(self): + super().__init__() + self.fp16_enabled = False + self.fc = nn.Linear(3, 2) + nn.init.constant_(self.fc.weight, 1.) + nn.init.constant_(self.fc.bias, 1.) + + @auto_fp16(apply_to=('x', )) + def forward(self, x): + x = self.fc(x) + return x + + def train_step(self, x, optimizer, **kwargs): + return dict(loss=self(x).mean(), num_samples=x.shape[0]) + + def val_step(self, x, optimizer, **kwargs): + return dict(loss=self(x).mean(), num_samples=x.shape[0]) + + def build_toy_runner(config=dict(type='EpochBasedRunner', max_epochs=3)): + model = ToyModel().cuda() + optimizer = torch.optim.SGD(model.parameters(), lr=0.02) + tmp_dir = tempfile.mkdtemp() + + runner = build_runner( + config, + default_args=dict( + model=model, + work_dir=tmp_dir, + optimizer=optimizer, + logger=logging.getLogger(), + meta=dict())) + return runner + + # test epoch based runner + data = torch.rand((6, 3)).cuda() + # optimize with cumulative_iters + loader_1 = DataLoader(data, batch_size=1) + runner_1 = build_toy_runner() + optimizer_hook = GradientCumulativeFp16OptimizerHook( + grad_clip=dict(max_norm=0.2), cumulative_iters=3) + runner_1.register_hook(optimizer_hook) + runner_1.run([loader_1], [('train', 1)]) + + # optimize without cumulative_iters + loader_2 = DataLoader(data, batch_size=3) + runner_2 = build_toy_runner() + optimizer_hook = Fp16OptimizerHook(grad_clip=dict(max_norm=0.2)) + runner_2.register_hook(optimizer_hook) + runner_2.run([loader_2], [('train', 1)]) + + # test optimizer works well + assert (runner_1.model.fc.weight < 1).all() + assert (runner_1.model.fc.bias < 1).all() + # test optimizer with cumulative_iters gets the same results + assert torch.allclose(runner_1.model.fc.weight, runner_2.model.fc.weight) + assert torch.allclose(runner_1.model.fc.bias, runner_2.model.fc.bias) + shutil.rmtree(runner_1.work_dir) + shutil.rmtree(runner_2.work_dir) + + # test iter based runner + data = torch.rand((8, 3)).cuda() + # optimize with cumulative_iters + loader_1 = DataLoader(data, batch_size=1) + runner_1 = build_toy_runner(dict(type='IterBasedRunner', max_iters=8)) + optimizer_hook = GradientCumulativeFp16OptimizerHook( + grad_clip=dict(max_norm=0.2), cumulative_iters=3) + runner_1.register_hook(optimizer_hook) + runner_1.run([loader_1], [('train', 1)]) + + # optimize without cumulative_iters + loader_2_divisible = DataLoader(data[:6], batch_size=3) + loader_2_remainder = DataLoader(data[6:], batch_size=2) + runner_2 = build_toy_runner(dict(type='IterBasedRunner', max_iters=3)) + optimizer_hook = Fp16OptimizerHook(grad_clip=dict(max_norm=0.2)) + runner_2.register_hook(optimizer_hook) + runner_2.run([loader_2_divisible, loader_2_remainder], [('train', 2), + ('train', 1)]) + + # test optimizer works well + assert (runner_1.model.fc.weight < 1).all() + assert (runner_1.model.fc.bias < 1).all() + # test optimizer with cumulative_iters gets the same results + assert torch.allclose(runner_1.model.fc.weight, runner_2.model.fc.weight) + assert torch.allclose(runner_1.model.fc.bias, runner_2.model.fc.bias) + shutil.rmtree(runner_1.work_dir) + shutil.rmtree(runner_2.work_dir)