From 0ea07c07505a16fefcea97cfdfc24b11ed04958a Mon Sep 17 00:00:00 2001
From: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
Date: Mon, 8 Aug 2022 15:01:47 +0800
Subject: [PATCH] [Refactor] Refactor hooks to support non-dist training (#397)

* apply mmengine.dist instead of torch.dist
* apply get_model to densecl_hook
* fix bug to pass ut
* update typehint
---
 mmselfsup/engine/hooks/densecl_hook.py        | 11 +++++-----
 mmselfsup/engine/hooks/swav_hook.py           | 22 +++++++++----------
 mmselfsup/models/utils/data_preprocessor.py   | 15 ++++++++-----
 mmselfsup/utils/__init__.py                   |  4 +++-
 mmselfsup/utils/misc.py                       | 18 +++++++++++++++
 .../test_hooks/test_densecl_hook.py           | 21 +++++-------------
 .../test_engine/test_hooks/test_swav_hook.py  | 19 +++++-----------
 7 files changed, 58 insertions(+), 52 deletions(-)
 create mode 100644 mmselfsup/utils/misc.py

diff --git a/mmselfsup/engine/hooks/densecl_hook.py b/mmselfsup/engine/hooks/densecl_hook.py
index 7ab1f916..c98261de 100644
--- a/mmselfsup/engine/hooks/densecl_hook.py
+++ b/mmselfsup/engine/hooks/densecl_hook.py
@@ -4,6 +4,7 @@ from typing import Optional, Sequence
 from mmengine.hooks import Hook
 
 from mmselfsup.registry import HOOKS
+from mmselfsup.utils import get_model
 
 
 @HOOKS.register_module()
@@ -23,19 +24,19 @@ class DenseCLHook(Hook):
 
     def before_train(self, runner) -> None:
         """Obtain ``loss_lambda`` from algorithm."""
-        assert hasattr(runner.model.module, 'loss_lambda'), \
+        assert hasattr(get_model(runner.model), 'loss_lambda'), \
             "The runner must have attribute \"loss_lambda\" in DenseCL."
-        self.loss_lambda = runner.model.module.loss_lambda
+        self.loss_lambda = get_model(runner.model).loss_lambda
 
     def before_train_iter(self,
                           runner,
                           batch_idx: int,
                           data_batch: Optional[Sequence[dict]] = None) -> None:
         """Adjust ``loss_lambda`` every train iter."""
-        assert hasattr(runner.model.module, 'loss_lambda'), \
+        assert hasattr(get_model(runner.model), 'loss_lambda'), \
             "The runner must have attribute \"loss_lambda\" in DenseCL."
         cur_iter = runner.iter
         if cur_iter >= self.start_iters:
-            runner.model.module.loss_lambda = self.loss_lambda
+            get_model(runner.model).loss_lambda = self.loss_lambda
         else:
-            runner.model.module.loss_lambda = 0.
+            get_model(runner.model).loss_lambda = 0.
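
The densecl_hook.py change above replaces direct ``runner.model.module`` access with ``get_model(runner.model)``, so the warmup logic works whether or not the model is wrapped for distributed training: ``loss_lambda`` is held at 0 until ``start_iters`` and restored afterwards. A rough sketch of that behaviour on a plain, unwrapped module (the toy class below is illustrative; ``get_model`` is the helper this patch adds in mmselfsup/utils/misc.py):

import torch.nn as nn

from mmselfsup.utils import get_model  # helper added by this patch


class ToyDenseCL(nn.Module):
    loss_lambda = 0.5  # the attribute DenseCLHook manages


model = ToyDenseCL()  # not wrapped, as in single-GPU / non-dist training
start_iters, saved = 10, get_model(model).loss_lambda

for cur_iter in range(20):
    # same branching as DenseCLHook.before_train_iter
    get_model(model).loss_lambda = saved if cur_iter >= start_iters else 0.

assert get_model(model).loss_lambda == 0.5  # restored after the warmup iters
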
diff --git a/mmselfsup/engine/hooks/swav_hook.py b/mmselfsup/engine/hooks/swav_hook.py
index e3b3a923..d46bf57c 100644
--- a/mmselfsup/engine/hooks/swav_hook.py
+++ b/mmselfsup/engine/hooks/swav_hook.py
@@ -3,11 +3,12 @@ import os.path as osp
 from typing import Dict, List, Optional, Sequence
 
 import torch
-import torch.distributed as dist
+from mmengine.dist import get_rank, get_world_size, is_distributed
 from mmengine.hooks import Hook
 from mmengine.logging import MMLogger
 
 from mmselfsup.registry import HOOKS
+from mmselfsup.utils import get_model
 
 
 @HOOKS.register_module()
@@ -45,8 +46,7 @@ class SwAVHook(Hook):
                 interval: Optional[int] = 1,
                 frozen_layers_cfg: Optional[Dict] = dict()
     ) -> None:
-        self.batch_size = batch_size * dist.get_world_size()\
-            if dist.is_initialized() else batch_size
+        self.batch_size = batch_size * get_world_size()
         self.epoch_queue_starts = epoch_queue_starts
         self.crops_for_assign = crops_for_assign
         self.feat_dim = feat_dim
@@ -58,16 +58,16 @@ class SwAVHook(Hook):
 
     def before_run(self, runner) -> None:
         """Check whether the queues exist locally or not."""
-        if dist.is_initialized():
+        if is_distributed():
             self.queue_path = osp.join(runner.work_dir,
-                                       'queue' + str(dist.get_rank()) + '.pth')
+                                       'queue' + str(get_rank()) + '.pth')
         else:
             self.queue_path = osp.join(runner.work_dir, 'queue.pth')
 
         # load the queues if queues exist locally
         if osp.isfile(self.queue_path):
             self.queue = torch.load(self.queue_path)['queue']
-            runner.model.module.head.loss.queue = self.queue
+            get_model(runner.model).head.loss.queue = self.queue
             MMLogger.get_current_instance().info(
                 f'Load queue from file: {self.queue_path}')
 
@@ -82,12 +82,12 @@ class SwAVHook(Hook):
         for layer, frozen_iters in self.frozen_layers_cfg.items():
             if runner.iter < frozen_iters and self.requires_grad:
                 self.requires_grad = False
-                for name, p in runner.model.module.named_parameters():
+                for name, p in get_model(runner.model).named_parameters():
                     if layer in name:
                         p.requires_grad = False
             elif runner.iter >= frozen_iters and not self.requires_grad:
                 self.requires_grad = True
-                for name, p in runner.model.module.named_parameters():
+                for name, p in get_model(runner.model).named_parameters():
                     if layer in name:
                         p.requires_grad = True
 
@@ -104,12 +104,12 @@ class SwAVHook(Hook):
             ).cuda()
 
         # set the boolean type of use_the_queue
-        runner.model.module.head.loss.queue = self.queue
-        runner.model.module.head.loss.use_queue = False
+        get_model(runner.model).head.loss.queue = self.queue
+        get_model(runner.model).head.loss.use_queue = False
 
     def after_train_epoch(self, runner) -> None:
         """Save the queues locally."""
-        self.queue = runner.model.module.head.loss.queue
+        self.queue = get_model(runner.model).head.loss.queue
 
         if self.queue is not None and self.every_n_epochs(
                 runner, self.interval):
diff --git a/mmselfsup/models/utils/data_preprocessor.py b/mmselfsup/models/utils/data_preprocessor.py
index 8d4350be..9a18f2b3 100644
--- a/mmselfsup/models/utils/data_preprocessor.py
+++ b/mmselfsup/models/utils/data_preprocessor.py
@@ -82,8 +82,9 @@ class SelfSupDataPreprocessor(ImgDataPreprocessor):
         # :class:`mmengine.ImgDataPreprocessor`. Since there are multiple views
         # for an image for some algorithms, e.g. SimCLR, each item in inputs
         # is a list, containing multi-views for an image.
-        inputs = [[(img_ - self.mean) / self.std for img_ in _input]
-                  for _input in inputs]
+        if self._enable_normalize:
+            inputs = [[(img_ - self.mean) / self.std for img_ in _input]
+                      for _input in inputs]
 
         batch_inputs = []
         for i in range(len(inputs[0])):
@@ -125,8 +126,9 @@ class RelativeLocDataPreprocessor(SelfSupDataPreprocessor):
         # :class:`mmengine.ImgDataPreprocessor`. Since there are multiple views
         # for an image for some algorithms, e.g. SimCLR, each item in inputs
         # is a list, containing multi-views for an image.
-        inputs = [[(img_ - self.mean) / self.std for img_ in _input]
-                  for _input in inputs]
+        if self._enable_normalize:
+            inputs = [[(img_ - self.mean) / self.std for img_ in _input]
+                      for _input in inputs]
 
         batch_inputs = []
         for i in range(len(inputs[0])):
@@ -180,8 +182,9 @@ class RotationPredDataPreprocessor(SelfSupDataPreprocessor):
         # :class:`mmengine.ImgDataPreprocessor`. Since there are multiple views
         # for an image for some algorithms, e.g. SimCLR, each item in inputs
         # is a list, containing multi-views for an image.
-        inputs = [[(img_ - self.mean) / self.std for img_ in _input]
-                  for _input in inputs]
+        if self._enable_normalize:
+            inputs = [[(img_ - self.mean) / self.std for img_ in _input]
+                      for _input in inputs]
 
         batch_inputs = []
         for i in range(len(inputs[0])):
diff --git a/mmselfsup/utils/__init__.py b/mmselfsup/utils/__init__.py
index a4cc0513..cc95a0d6 100644
--- a/mmselfsup/utils/__init__.py
+++ b/mmselfsup/utils/__init__.py
@@ -5,11 +5,13 @@ from .collect import dist_forward_collect, nondist_forward_collect
 from .collect_env import collect_env
 from .distributed_sinkhorn import distributed_sinkhorn
 from .gather import concat_all_gather, gather_tensors, gather_tensors_batch
+from .misc import get_model
 from .setup_env import register_all_modules
 
 __all__ = [
     'AliasMethod', 'batch_shuffle_ddp', 'batch_unshuffle_ddp',
     'dist_forward_collect', 'nondist_forward_collect', 'collect_env',
     'sync_random_seed', 'distributed_sinkhorn', 'concat_all_gather',
-    'gather_tensors', 'gather_tensors_batch', 'register_all_modules'
+    'gather_tensors', 'gather_tensors_batch', 'register_all_modules',
+    'get_model'
 ]
diff --git a/mmselfsup/utils/misc.py b/mmselfsup/utils/misc.py
new file mode 100644
index 00000000..e8c52f4f
--- /dev/null
+++ b/mmselfsup/utils/misc.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmengine.model import BaseModel, is_model_wrapper
+
+
+def get_model(model: nn.Module) -> BaseModel:
+    """Get model if the input model is a model wrapper.
+
+    Args:
+        model (nn.Module): A model that may be a model wrapper.
+
+    Returns:
+        BaseModel: The model without model wrapper.
+    """
+    if is_model_wrapper(model):
+        return model.module
+    else:
+        return model
diff --git a/tests/test_engine/test_hooks/test_densecl_hook.py b/tests/test_engine/test_hooks/test_densecl_hook.py
index 40a80dcc..9dc89e97 100644
--- a/tests/test_engine/test_hooks/test_densecl_hook.py
+++ b/tests/test_engine/test_hooks/test_densecl_hook.py
@@ -6,7 +6,6 @@ import torch
 import torch.nn as nn
 from mmengine import Runner
 from mmengine.data import LabelData
-from mmengine.model import BaseModel as EngineBaseModel
 from mmengine.model import BaseModule
 from mmengine.optim import OptimWrapper
 from torch.utils.data import Dataset
@@ -15,6 +14,7 @@ from mmselfsup.engine import DenseCLHook
 from mmselfsup.models.algorithms import BaseModel
 from mmselfsup.registry import MODELS
 from mmselfsup.structures import SelfSupDataSample
+from mmselfsup.utils import get_model
 
 
 class DummyDataset(Dataset):
@@ -33,7 +33,7 @@ class DummyDataset(Dataset):
         data_sample = SelfSupDataSample()
         gt_label = LabelData(value=self.label[index])
         setattr(data_sample, 'gt_label', gt_label)
-        return dict(inputs=self.data[index], data_sample=data_sample)
+        return dict(inputs=[self.data[index]], data_sample=data_sample)
 
 
 @MODELS.register_module()
@@ -58,7 +58,7 @@ class ToyModel(BaseModel):
         for x in data_samples:
             labels.append(x.gt_label.value)
         labels = torch.stack(labels)
-        outputs = self.backbone(batch_inputs)
+        outputs = self.backbone(batch_inputs[0])
         loss = (labels - outputs).sum()
         outputs = dict(loss=loss)
         return outputs
@@ -78,18 +78,9 @@ class TestDenseCLHook(TestCase):
         toy_model = ToyModel().to(device)
         densecl_hook = DenseCLHook(start_iters=1)
 
-        class DummyWrapper(EngineBaseModel):
-
-            def __init__(self, model):
-                super().__init__()
-                self.module = model
-
-            def forward(self, *args, **kwargs):
-                return self.module(*args, **kwargs)
-
         # test DenseCLHook with model wrapper
         runner = Runner(
-            model=DummyWrapper(toy_model),
+            model=toy_model,
             work_dir=self.temp_dir.name,
             train_dataloader=dict(
                 dataset=dummy_dataset,
@@ -108,6 +99,6 @@
         runner.train()
 
         if runner.iter >= 1:
-            assert runner.model.module.loss_lambda == 0.5
+            assert get_model(runner.model).loss_lambda == 0.5
         else:
-            assert runner.model.module.loss_lambda == 0.
+            assert get_model(runner.model).loss_lambda == 0.
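
In the test changes above, the dummy dataset now returns ``inputs`` as a list of views rather than a bare tensor, matching the multi-view convention described in the data_preprocessor.py comments, and the toy model reads the first view via ``batch_inputs[0]``. A rough sketch of that convention (shapes and names are illustrative, not part of the patch):

import torch

# One sample: `inputs` is a list of views; single-view algorithms use one element.
sample = dict(inputs=[torch.rand(3, 32, 32)])

# After batching, the model receives one tensor per view; view 0 holds the batch.
batch_inputs = [torch.stack([sample['inputs'][0], sample['inputs'][0]])]
assert batch_inputs[0].shape == (2, 3, 32, 32)
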
diff --git a/tests/test_engine/test_hooks/test_swav_hook.py b/tests/test_engine/test_hooks/test_swav_hook.py
index ef79758d..e7d3660a 100644
--- a/tests/test_engine/test_hooks/test_swav_hook.py
+++ b/tests/test_engine/test_hooks/test_swav_hook.py
@@ -6,7 +6,6 @@ import torch
 import torch.nn as nn
 from mmengine import Runner
 from mmengine.data import LabelData
-from mmengine.model import BaseModel as EngineBaseModel
 from mmengine.model import BaseModule
 from mmengine.optim import OptimWrapper
 from torch.utils.data import Dataset
@@ -16,6 +15,7 @@ from mmselfsup.models.algorithms import BaseModel
 from mmselfsup.models.heads import SwAVHead
 from mmselfsup.registry import MODELS
 from mmselfsup.structures import SelfSupDataSample
+from mmselfsup.utils import get_model
 
 
 class DummyDataset(Dataset):
@@ -34,7 +34,7 @@ class DummyDataset(Dataset):
         data_sample = SelfSupDataSample()
         gt_label = LabelData(value=self.label[index])
         setattr(data_sample, 'gt_label', gt_label)
-        return dict(inputs=self.data[index], data_sample=data_sample)
+        return dict(inputs=[self.data[index]], data_sample=data_sample)
 
 
 @MODELS.register_module()
@@ -65,7 +65,7 @@ class ToyModel(BaseModel):
         for x in data_samples:
             labels.append(x.gt_label.value)
         labels = torch.stack(labels)
-        outputs = self.backbone(batch_inputs)
+        outputs = self.backbone(batch_inputs[0])
         loss = (labels - outputs).sum()
         outputs = dict(loss=loss)
         return outputs
@@ -91,18 +91,9 @@ class TestSwAVHook(TestCase):
             queue_length=300,
             frozen_layers_cfg=dict(prototypes=2))
 
-        class DummyWrapper(EngineBaseModel):
-
-            def __init__(self, model):
-                super().__init__()
-                self.module = model
-
-            def forward(self, *args, **kwargs):
-                return self.module(*args, **kwargs)
-
         # test SwAVHook
         runner = Runner(
-            model=DummyWrapper(toy_model),
+            model=toy_model,
             work_dir=self.temp_dir.name,
             train_dataloader=dict(
                 dataset=dummy_dataset,
@@ -124,4 +115,4 @@
             if isinstance(hook, SwAVHook):
                 assert hook.queue_length == 300
 
-        assert runner.model.module.head.loss.use_queue is False
+        assert get_model(runner.model).head.loss.use_queue is False
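
With these changes the hooks and tests resolve the algorithm through ``get_model``, which returns plain modules unchanged and unwraps model wrappers, so one code path serves both single-GPU and distributed runs. A minimal usage sketch (the wrapped case is commented out because constructing a distributed wrapper needs an initialized process group):

import torch.nn as nn

from mmselfsup.utils import get_model

model = nn.Linear(2, 2)

# Non-distributed training: the model is not wrapped, so it is returned as-is.
assert get_model(model) is model

# Distributed training (sketch): the wrapper keeps the algorithm under `.module`
# and get_model returns that underlying module.
# from mmengine.model import MMDistributedDataParallel
# wrapped = MMDistributedDataParallel(module=model, device_ids=[0])
# assert get_model(wrapped) is model
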