[Refactor] Refactor hooks to support non-dist training (#397)

* apply mmengine.dist instead of torch.dist

* apply get_model to densecl_hook

* fix bug to pass unit tests

* update typehint
Yixiao Fang 2022-08-08 15:01:47 +08:00 committed by GitHub
parent a703ba2fcb
commit 0ea07c0750
7 changed files with 58 additions and 52 deletions
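
Note: the common thread in this refactor is that hooks no longer assume a DistributedDataParallel wrapper (`runner.model.module`) or an initialised `torch.distributed` process group. Below is a minimal sketch of the two idioms the PR adopts, assuming only that mmengine is installed; everything except the mmengine imports is illustrative.

```python
import torch.nn as nn
from mmengine.dist import get_rank, get_world_size, is_distributed
from mmengine.model import is_model_wrapper


def unwrap(model: nn.Module) -> nn.Module:
    # Same idea as the new mmselfsup.utils.get_model helper below: return the
    # bare model whether or not it is wrapped (e.g. by MMDistributedDataParallel).
    return model.module if is_model_wrapper(model) else model


# mmengine.dist helpers fall back to single-process defaults when no process
# group is initialised, so no explicit torch.distributed.is_initialized()
# branch is needed.
effective_batch_size = 32 * get_world_size()  # world size is 1 without dist
print(is_distributed(), get_rank(), effective_batch_size)
```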

View File

@@ -4,6 +4,7 @@ from typing import Optional, Sequence
from mmengine.hooks import Hook
from mmselfsup.registry import HOOKS
+from mmselfsup.utils import get_model
@HOOKS.register_module()
@@ -23,19 +24,19 @@ class DenseCLHook(Hook):
def before_train(self, runner) -> None:
"""Obtain ``loss_lambda`` from algorithm."""
-assert hasattr(runner.model.module, 'loss_lambda'), \
+assert hasattr(get_model(runner.model), 'loss_lambda'), \
"The runner must have attribute \"loss_lambda\" in DenseCL."
-self.loss_lambda = runner.model.module.loss_lambda
+self.loss_lambda = get_model(runner.model).loss_lambda
def before_train_iter(self,
runner,
batch_idx: int,
data_batch: Optional[Sequence[dict]] = None) -> None:
"""Adjust ``loss_lambda`` every train iter."""
-assert hasattr(runner.model.module, 'loss_lambda'), \
+assert hasattr(get_model(runner.model), 'loss_lambda'), \
"The runner must have attribute \"loss_lambda\" in DenseCL."
cur_iter = runner.iter
if cur_iter >= self.start_iters:
-runner.model.module.loss_lambda = self.loss_lambda
+get_model(runner.model).loss_lambda = self.loss_lambda
else:
-runner.model.module.loss_lambda = 0.
+get_model(runner.model).loss_lambda = 0.
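
Note: the scheduling logic itself is unchanged; before `start_iters` the dense loss weight is zeroed and afterwards the configured `loss_lambda` is restored. A hedged, standalone sketch of that schedule (the class and function names here are illustrative, not part of the PR):

```python
class ToyAlgo:
    loss_lambda = 0.5  # the value DenseCLHook caches in before_train


def schedule_loss_lambda(model, cur_iter, start_iters, loss_lambda):
    # Mirrors before_train_iter: disable the dense loss term early in training.
    model.loss_lambda = loss_lambda if cur_iter >= start_iters else 0.


algo = ToyAlgo()
schedule_loss_lambda(algo, cur_iter=0, start_iters=1, loss_lambda=0.5)
assert algo.loss_lambda == 0.
schedule_loss_lambda(algo, cur_iter=1, start_iters=1, loss_lambda=0.5)
assert algo.loss_lambda == 0.5
```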

View File

@@ -3,11 +3,12 @@ import os.path as osp
from typing import Dict, List, Optional, Sequence
import torch
-import torch.distributed as dist
+from mmengine.dist import get_rank, get_world_size, is_distributed
from mmengine.hooks import Hook
from mmengine.logging import MMLogger
from mmselfsup.registry import HOOKS
+from mmselfsup.utils import get_model
@HOOKS.register_module()
@@ -45,8 +46,7 @@ class SwAVHook(Hook):
interval: Optional[int] = 1,
frozen_layers_cfg: Optional[Dict] = dict()
) -> None:
-self.batch_size = batch_size * dist.get_world_size()\
-    if dist.is_initialized() else batch_size
+self.batch_size = batch_size * get_world_size()
self.epoch_queue_starts = epoch_queue_starts
self.crops_for_assign = crops_for_assign
self.feat_dim = feat_dim
@@ -58,16 +58,16 @@ class SwAVHook(Hook):
def before_run(self, runner) -> None:
"""Check whether the queues exist locally or not."""
-if dist.is_initialized():
+if is_distributed():
self.queue_path = osp.join(runner.work_dir,
-'queue' + str(dist.get_rank()) + '.pth')
+'queue' + str(get_rank()) + '.pth')
else:
self.queue_path = osp.join(runner.work_dir, 'queue.pth')
# load the queues if queues exist locally
if osp.isfile(self.queue_path):
self.queue = torch.load(self.queue_path)['queue']
-runner.model.module.head.loss.queue = self.queue
+get_model(runner.model).head.loss.queue = self.queue
MMLogger.get_current_instance().info(
f'Load queue from file: {self.queue_path}')
@@ -82,12 +82,12 @@ class SwAVHook(Hook):
for layer, frozen_iters in self.frozen_layers_cfg.items():
if runner.iter < frozen_iters and self.requires_grad:
self.requires_grad = False
-for name, p in runner.model.module.named_parameters():
+for name, p in get_model(runner.model).named_parameters():
if layer in name:
p.requires_grad = False
elif runner.iter >= frozen_iters and not self.requires_grad:
self.requires_grad = True
-for name, p in runner.model.module.named_parameters():
+for name, p in get_model(runner.model).named_parameters():
if layer in name:
p.requires_grad = True
@@ -104,12 +104,12 @@ class SwAVHook(Hook):
).cuda()
# set the boolean type of use_the_queue
-runner.model.module.head.loss.queue = self.queue
-runner.model.module.head.loss.use_queue = False
+get_model(runner.model).head.loss.queue = self.queue
+get_model(runner.model).head.loss.use_queue = False
def after_train_epoch(self, runner) -> None:
"""Save the queues locally."""
-self.queue = runner.model.module.head.loss.queue
+self.queue = get_model(runner.model).head.loss.queue
if self.queue is not None and self.every_n_epochs(
runner, self.interval):
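
Note: because `get_rank()` and `is_distributed()` are safe to call without a process group, the queue file naming works in both launch modes. A small sketch of the resulting naming scheme (the helper name is illustrative):

```python
import os.path as osp

from mmengine.dist import get_rank, is_distributed


def queue_path(work_dir: str) -> str:
    # One queue file per rank under distributed training, a single file otherwise.
    if is_distributed():
        return osp.join(work_dir, 'queue' + str(get_rank()) + '.pth')
    return osp.join(work_dir, 'queue.pth')
```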

View File

@@ -82,8 +82,9 @@ class SelfSupDataPreprocessor(ImgDataPreprocessor):
# :class:`mmengine.ImgDataPreprocessor`. Since there are multiple views
# for an image for some algorithms, e.g. SimCLR, each item in inputs
# is a list, containing multi-views for an image.
-inputs = [[(img_ - self.mean) / self.std for img_ in _input]
-          for _input in inputs]
+if self._enable_normalize:
+    inputs = [[(img_ - self.mean) / self.std for img_ in _input]
+              for _input in inputs]
batch_inputs = []
for i in range(len(inputs[0])):
@@ -125,8 +126,9 @@ class RelativeLocDataPreprocessor(SelfSupDataPreprocessor):
# :class:`mmengine.ImgDataPreprocessor`. Since there are multiple views
# for an image for some algorithms, e.g. SimCLR, each item in inputs
# is a list, containing multi-views for an image.
-inputs = [[(img_ - self.mean) / self.std for img_ in _input]
-          for _input in inputs]
+if self._enable_normalize:
+    inputs = [[(img_ - self.mean) / self.std for img_ in _input]
+              for _input in inputs]
batch_inputs = []
for i in range(len(inputs[0])):
@@ -180,8 +182,9 @@ class RotationPredDataPreprocessor(SelfSupDataPreprocessor):
# :class:`mmengine.ImgDataPreprocessor`. Since there are multiple views
# for an image for some algorithms, e.g. SimCLR, each item in inputs
# is a list, containing multi-views for an image.
-inputs = [[(img_ - self.mean) / self.std for img_ in _input]
-          for _input in inputs]
+if self._enable_normalize:
+    inputs = [[(img_ - self.mean) / self.std for img_ in _input]
+              for _input in inputs]
batch_inputs = []
for i in range(len(inputs[0])):
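
Note: the three preprocessors now normalize the multi-view inputs only when normalization is configured; `self._enable_normalize` is set by mmengine's `ImgDataPreprocessor` when `mean`/`std` are given. A minimal, self-contained sketch of the same guard, with illustrative names:

```python
import torch


def normalize_views(inputs, mean=None, std=None):
    # `inputs` is a list of samples, each a list of view tensors.
    enable_normalize = mean is not None and std is not None
    if enable_normalize:
        inputs = [[(img - mean) / std for img in views] for views in inputs]
    return inputs


views = [[torch.full((3, 4, 4), 255.0) for _ in range(2)]]  # 1 sample, 2 views
out = normalize_views(views, mean=torch.tensor(127.5), std=torch.tensor(127.5))
assert torch.allclose(out[0][0], torch.ones(3, 4, 4))
```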

View File

@@ -5,11 +5,13 @@ from .collect import dist_forward_collect, nondist_forward_collect
from .collect_env import collect_env
from .distributed_sinkhorn import distributed_sinkhorn
from .gather import concat_all_gather, gather_tensors, gather_tensors_batch
+from .misc import get_model
from .setup_env import register_all_modules
__all__ = [
'AliasMethod', 'batch_shuffle_ddp', 'batch_unshuffle_ddp',
'dist_forward_collect', 'nondist_forward_collect', 'collect_env',
'sync_random_seed', 'distributed_sinkhorn', 'concat_all_gather',
-'gather_tensors', 'gather_tensors_batch', 'register_all_modules'
+'gather_tensors', 'gather_tensors_batch', 'register_all_modules',
+'get_model'
]

View File

@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmengine.model import BaseModel, is_model_wrapper
+
+
+def get_model(model: nn.Module) -> BaseModel:
+    """Get model if the input model is a model wrapper.
+
+    Args:
+        model (nn.Module): A model may be a model wrapper.
+
+    Returns:
+        BaseModel: The model without model wrapper.
+    """
+    if is_model_wrapper(model):
+        return model.module
+    else:
+        return model
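
Usage is symmetric for wrapped and plain models; a hedged example (constructing a real MMDistributedDataParallel wrapper needs an initialised process group, so only the plain case is exercised here):

```python
import torch.nn as nn

from mmselfsup.utils import get_model

model = nn.Linear(2, 2)
assert get_model(model) is model  # a plain module is returned unchanged
# For a wrapped model, e.g. MMDistributedDataParallel(module=model, ...),
# get_model(wrapper) would return the inner `model`.
```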

View File

@@ -6,7 +6,6 @@ import torch
import torch.nn as nn
from mmengine import Runner
from mmengine.data import LabelData
-from mmengine.model import BaseModel as EngineBaseModel
from mmengine.model import BaseModule
from mmengine.optim import OptimWrapper
from torch.utils.data import Dataset
@@ -15,6 +14,7 @@ from mmselfsup.engine import DenseCLHook
from mmselfsup.models.algorithms import BaseModel
from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample
+from mmselfsup.utils import get_model
class DummyDataset(Dataset):
@@ -33,7 +33,7 @@ class DummyDataset(Dataset):
data_sample = SelfSupDataSample()
gt_label = LabelData(value=self.label[index])
setattr(data_sample, 'gt_label', gt_label)
-return dict(inputs=self.data[index], data_sample=data_sample)
+return dict(inputs=[self.data[index]], data_sample=data_sample)
@MODELS.register_module()
@@ -58,7 +58,7 @@ class ToyModel(BaseModel):
for x in data_samples:
labels.append(x.gt_label.value)
labels = torch.stack(labels)
-outputs = self.backbone(batch_inputs)
+outputs = self.backbone(batch_inputs[0])
loss = (labels - outputs).sum()
outputs = dict(loss=loss)
return outputs
@@ -78,18 +78,9 @@ class TestDenseCLHook(TestCase):
toy_model = ToyModel().to(device)
densecl_hook = DenseCLHook(start_iters=1)
-class DummyWrapper(EngineBaseModel):
-    def __init__(self, model):
-        super().__init__()
-        self.module = model
-    def forward(self, *args, **kwargs):
-        return self.module(*args, **kwargs)
# test DenseCLHook with model wrapper
runner = Runner(
-model=DummyWrapper(toy_model),
+model=toy_model,
work_dir=self.temp_dir.name,
train_dataloader=dict(
dataset=dummy_dataset,
@@ -108,6 +99,6 @@ class TestDenseCLHook(TestCase):
runner.train()
if runner.iter >= 1:
-assert runner.model.module.loss_lambda == 0.5
+assert get_model(runner.model).loss_lambda == 0.5
else:
-assert runner.model.module.loss_lambda == 0.
+assert get_model(runner.model).loss_lambda == 0.
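
Note: two test-side adjustments follow from the refactor. The hand-rolled `DummyWrapper` is gone because `get_model` handles both wrapped and plain models, and dataset items now return `inputs` as a list of view tensors, matching the multi-view convention consumed by the preprocessor and by `forward` via `batch_inputs[0]`. A tiny sketch of that item format (names are illustrative):

```python
import torch


def toy_item(img: torch.Tensor) -> dict:
    # `inputs` is a list of views, even for a single-view sample.
    return dict(inputs=[img])


item = toy_item(torch.rand(3, 8, 8))
assert isinstance(item['inputs'], list) and item['inputs'][0].shape == (3, 8, 8)
```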

View File

@@ -6,7 +6,6 @@ import torch
import torch.nn as nn
from mmengine import Runner
from mmengine.data import LabelData
-from mmengine.model import BaseModel as EngineBaseModel
from mmengine.model import BaseModule
from mmengine.optim import OptimWrapper
from torch.utils.data import Dataset
@@ -16,6 +15,7 @@ from mmselfsup.models.algorithms import BaseModel
from mmselfsup.models.heads import SwAVHead
from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample
+from mmselfsup.utils import get_model
class DummyDataset(Dataset):
@@ -34,7 +34,7 @@ class DummyDataset(Dataset):
data_sample = SelfSupDataSample()
gt_label = LabelData(value=self.label[index])
setattr(data_sample, 'gt_label', gt_label)
-return dict(inputs=self.data[index], data_sample=data_sample)
+return dict(inputs=[self.data[index]], data_sample=data_sample)
@MODELS.register_module()
@@ -65,7 +65,7 @@ class ToyModel(BaseModel):
for x in data_samples:
labels.append(x.gt_label.value)
labels = torch.stack(labels)
-outputs = self.backbone(batch_inputs)
+outputs = self.backbone(batch_inputs[0])
loss = (labels - outputs).sum()
outputs = dict(loss=loss)
return outputs
@@ -91,18 +91,9 @@ class TestSwAVHook(TestCase):
queue_length=300,
frozen_layers_cfg=dict(prototypes=2))
-class DummyWrapper(EngineBaseModel):
-    def __init__(self, model):
-        super().__init__()
-        self.module = model
-    def forward(self, *args, **kwargs):
-        return self.module(*args, **kwargs)
# test SwAVHook
runner = Runner(
-model=DummyWrapper(toy_model),
+model=toy_model,
work_dir=self.temp_dir.name,
train_dataloader=dict(
dataset=dummy_dataset,
@@ -124,4 +115,4 @@ class TestSwAVHook(TestCase):
if isinstance(hook, SwAVHook):
assert hook.queue_length == 300
-assert runner.model.module.head.loss.use_queue is False
+assert get_model(runner.model).head.loss.use_queue is False