From e7b69a8f631dee8c64b1fede0c64b52fdcf77bcf Mon Sep 17 00:00:00 2001
From: "jiangnana.jnn"
Date: Thu, 29 Sep 2022 17:23:03 +0800
Subject: [PATCH] fix code

---
 .../bevformer/bevformer_base.py             |  7 +-
 easycv/datasets/__init__.py                 |  9 ++-
 easycv/datasets/builder.py                  |  6 +-
 easycv/datasets/loader/__init__.py          |  6 +-
 easycv/datasets/loader/build_loader.py      | 69 ++++++++++---------
 easycv/datasets/loader/sampler.py           | 23 +++++--
 easycv/datasets/registry.py                 |  1 +
 easycv/hooks/eval_hook.py                   | 21 ++++++
 easycv/models/base.py                       | 37 ++++++----
 .../detection3d/detectors/mvx_two_stage.py  |  1 +
 10 files changed, 119 insertions(+), 61 deletions(-)

diff --git a/configs/autonomous_driving/bevformer/bevformer_base.py b/configs/autonomous_driving/bevformer/bevformer_base.py
index ec365c69..42098c12 100644
--- a/configs/autonomous_driving/bevformer/bevformer_base.py
+++ b/configs/autonomous_driving/bevformer/bevformer_base.py
@@ -196,6 +196,9 @@ test_pipeline = [
 data = dict(
     imgs_per_gpu=1,
     workers_per_gpu=4,
+    # TODO: support custom sampler config
+    # shuffler_sampler=dict(type='DistributedGroupSampler'),
+    # nonshuffler_sampler=dict(type='DistributedSampler'),
     train=dict(
         type=dataset_type,
         data_root=data_root,
@@ -226,9 +229,7 @@ data = dict(
         pipeline=test_pipeline,
         bev_size=(bev_h_, bev_w_),
         classes=class_names,
-        modality=input_modality),
-    shuffler_sampler=dict(type='DistributedGroupSampler'),
-    nonshuffler_sampler=dict(type='DistributedSampler'))
+        modality=input_modality))
 
 paramwise_cfg = dict(custom_keys={
     'img_backbone': dict(lr_mult=0.1),
diff --git a/easycv/datasets/__init__.py b/easycv/datasets/__init__.py
index bd2304de..935d8782 100644
--- a/easycv/datasets/__init__.py
+++ b/easycv/datasets/__init__.py
@@ -1,6 +1,9 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from . import (classification, detection, detection3d, face, ocr, pose,
                segmentation, selfsup, shared)
-from .builder import build_dali_dataset, build_dataset
-from .loader import DistributedGroupSampler, GroupSampler, build_dataloader
-from .registry import DATASETS
+from .builder import (build_dali_dataset, build_dataset, build_datasource,
+                      build_sampler)
+from .loader import (DistributedGivenIterationSampler, DistributedGroupSampler,
+                     DistributedMPSampler, DistributedSampler, GroupSampler,
+                     RASampler, build_dataloader)
+from .registry import DATASETS, DATASOURCES, PIPELINES, SAMPLERS
diff --git a/easycv/datasets/builder.py b/easycv/datasets/builder.py
index 3e2911e2..2e541e92 100644
--- a/easycv/datasets/builder.py
+++ b/easycv/datasets/builder.py
@@ -4,7 +4,7 @@ import copy
 from easycv.datasets.shared.dataset_wrappers import (ConcatDataset,
                                                      RepeatDataset)
 from easycv.utils.registry import build_from_cfg
-from .registry import DALIDATASETS, DATASETS, DATASOURCES
+from .registry import DALIDATASETS, DATASETS, DATASOURCES, SAMPLERS
 
 
 def _concat_dataset(cfg, default_args=None):
@@ -47,3 +47,7 @@ def build_dali_dataset(cfg, default_args=None):
 
 def build_datasource(cfg):
     return build_from_cfg(cfg, DATASOURCES)
+
+
+def build_sampler(cfg, default_args=None):
+    return build_from_cfg(cfg, SAMPLERS, default_args)
diff --git a/easycv/datasets/loader/__init__.py b/easycv/datasets/loader/__init__.py
index fe6b1b0d..1234dc3f 100644
--- a/easycv/datasets/loader/__init__.py
+++ b/easycv/datasets/loader/__init__.py
@@ -1,9 +1,11 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .build_loader import build_dataloader
 from .sampler import (DistributedGivenIterationSampler,
-                      DistributedGroupSampler, GroupSampler)
+                      DistributedGroupSampler, DistributedMPSampler,
+                      DistributedSampler, GroupSampler, RASampler)
 
 __all__ = [
     'GroupSampler', 'DistributedGroupSampler', 'build_dataloader',
-    'DistributedGivenIterationSampler'
+    'DistributedGivenIterationSampler', 'DistributedMPSampler', 'RASampler',
+    'DistributedSampler'
 ]
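Note on the builder changes above: samplers are now built from config dicts through the new SAMPLERS registry. A minimal sketch of the intended call pattern (not part of the patch; my_dataset is a placeholder for an already-built torch Dataset, and the argument names follow the default_sampler_args used in build_loader.py below):

    from easycv.datasets.builder import build_sampler

    # 'DistributedSampler' must match a class registered in SAMPLERS
    # (see easycv/datasets/loader/sampler.py); default_args are merged
    # into the config before the class is instantiated.
    sampler = build_sampler(
        dict(type='DistributedSampler'),
        default_args=dict(
            dataset=my_dataset,  # placeholder dataset object
            num_replicas=8,
            rank=0,
            shuffle=True))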
diff --git a/easycv/datasets/loader/build_loader.py b/easycv/datasets/loader/build_loader.py
index 4977553b..3681c1a6 100644
--- a/easycv/datasets/loader/build_loader.py
+++ b/easycv/datasets/loader/build_loader.py
@@ -8,14 +8,14 @@ import numpy as np
 import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader
 
+from easycv.datasets.builder import build_sampler
 from easycv.datasets.shared.odps_reader import set_dataloader_workid
 from easycv.framework.errors import NotImplementedError
 from easycv.utils.dist_utils import sync_random_seed
 from easycv.utils.torchacc_util import is_torchacc_enabled
 from .collate import CollateWrapper
-from .sampler import DistributedMPSampler, DistributedSampler, RASampler
 
 if platform.system() != 'Windows':
     # https://github.com/pytorch/pytorch/issues/973
@@ -37,6 +37,7 @@ def build_dataloader(dataset,
                      persistent_workers=False,
                      collate_hooks=None,
                      use_repeated_augment_sampler=False,
+                     sampler=None,
                      **kwargs):
     """Build PyTorch DataLoader.
     In distributed training, each GPU/process has a dataloader.
@@ -68,44 +69,48 @@ def build_dataloader(dataset,
     if dist:
         seed = sync_random_seed(seed)
-        split_huge_listfile_byrank = getattr(dataset,
-                                             'split_huge_listfile_byrank',
-                                             False)
-
-        if use_repeated_augment_sampler:
-            sampler = RASampler(dataset, world_size, rank, shuffle=shuffle)
-        elif hasattr(dataset, 'm_per_class') and dataset.m_per_class > 1:
-            sampler = DistributedMPSampler(
-                dataset,
-                world_size,
-                rank,
-                shuffle=shuffle,
-                split_huge_listfile_byrank=split_huge_listfile_byrank)
-        else:
-            sampler = DistributedSampler(
-                dataset,
-                world_size,
-                rank,
-                shuffle=shuffle,
-                seed=seed,
-                split_huge_listfile_byrank=split_huge_listfile_byrank)
         batch_size = imgs_per_gpu
         num_workers = workers_per_gpu
     else:
         if replace:
             raise NotImplementedError
-
-        if use_repeated_augment_sampler:
-            sampler = RASampler(dataset, 1, 0, shuffle=shuffle)
-        elif hasattr(dataset, 'm_per_class') and dataset.m_per_class > 1:
-            sampler = DistributedMPSampler(
-                dataset, 1, 0, shuffle=shuffle, replace=replace)
-        else:
-            sampler = RandomSampler(
-                dataset) if shuffle else None  # TODO: set replace
         batch_size = num_gpus * imgs_per_gpu
         num_workers = num_gpus * workers_per_gpu
 
+    default_sampler_args = dict(
+        dataset=dataset,
+        num_replicas=world_size,
+        rank=rank,
+        shuffle=shuffle,
+        seed=seed,
+        replace=replace)
+
+    split_huge_listfile_byrank = getattr(dataset, 'split_huge_listfile_byrank',
+                                         False)
+
+    if sampler is not None:
+        sampler_cfg = sampler
+        sampler_cfg.update(default_sampler_args)
+    elif use_repeated_augment_sampler:
+        sampler_cfg = dict(type='RASampler', **default_sampler_args)
+    elif hasattr(dataset, 'm_per_class') and dataset.m_per_class > 1:
+        sampler_cfg = dict(
+            type='DistributedMPSampler',
+            split_huge_listfile_byrank=split_huge_listfile_byrank,
+            **default_sampler_args)
+    else:
+        if dist:
+            sampler_cfg = dict(
+                type='DistributedSampler',
+                split_huge_listfile_byrank=split_huge_listfile_byrank,
+                **default_sampler_args)
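The net effect of the build_loader.py refactor: all sampler selection funnels into a single sampler_cfg resolved by build_sampler, and callers can override the built-in selection via the new sampler argument. A hedged usage sketch (my_dataset and the sizes are placeholders; note that build_dataloader force-updates the config with dataset/num_replicas/rank/shuffle/seed/replace, so the chosen sampler type must tolerate those keys):

    from easycv.datasets import build_dataloader

    loader = build_dataloader(
        my_dataset,  # placeholder dataset object
        imgs_per_gpu=2,
        workers_per_gpu=2,
        dist=True,
        shuffle=True,
        seed=42,
        sampler=dict(type='DistributedSampler'))  # overrides default choice

This is presumably also the mechanism that the commented-out shuffler_sampler/nonshuffler_sampler TODO in the BEVFormer config above will eventually target.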
+        else:
+            sampler_cfg = dict(
+                type='RandomSampler',
+                data_source=dataset) if shuffle else None  # TODO: set replace
+
+    sampler = build_sampler(sampler_cfg) if sampler_cfg is not None else None
+
     init_fn = partial(
         worker_init_fn, num_workers=num_workers,
diff --git a/easycv/datasets/loader/sampler.py b/easycv/datasets/loader/sampler.py
index fd39d054..61a88eff 100644
--- a/easycv/datasets/loader/sampler.py
+++ b/easycv/datasets/loader/sampler.py
@@ -1,7 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from __future__ import division
 import math
-import os
 import random
 
 import numpy as np
@@ -9,11 +8,16 @@ import torch
 import torch.distributed as dist
 from mmcv.runner import get_dist_info
 from torch.utils.data import DistributedSampler as _DistributedSampler
-from torch.utils.data import Sampler
+from torch.utils.data import RandomSampler, Sampler
 
+from easycv.datasets.registry import SAMPLERS
 from easycv.framework.errors import ValueError
+from easycv.utils.dist_utils import local_rank
+
+SAMPLERS.register_module(RandomSampler)
 
 
+@SAMPLERS.register_module()
 class DistributedMPSampler(_DistributedSampler):
 
     def __init__(self,
@@ -21,7 +25,8 @@ class DistributedMPSampler(_DistributedSampler):
                  num_replicas=None,
                  rank=None,
                  shuffle=True,
-                 split_huge_listfile_byrank=False):
+                 split_huge_listfile_byrank=False,
+                 **kwargs):
         """A distributed sampler which supports sampling m instances per class at a time for classification datasets.
            dataset: pytorch dataset object
            num_replicas (optional): Number of processes participating in
@@ -33,9 +38,7 @@
         """
         super().__init__(dataset, num_replicas=num_replicas, rank=rank)
-        current_env = os.environ.copy()
-        self.local_rank = int(current_env['LOCAL_RANK'])
-
+        self.local_rank = local_rank()
         self.shuffle = shuffle
         self.unif_sampling_flag = False
         self.split_huge_listfile_byrank = split_huge_listfile_byrank
@@ -158,6 +161,7 @@
         return self.length
 
 
+@SAMPLERS.register_module()
 class DistributedSampler(_DistributedSampler):
 
     def __init__(
@@ -256,6 +260,7 @@
         return self.num_samples if not self.split_huge_listfile_byrank else self.num_samples * self.num_replicas
 
 
+@SAMPLERS.register_module()
 class GroupSampler(Sampler):
 
     def __init__(self, dataset, samples_per_gpu=1):
@@ -297,6 +302,7 @@
         return self.num_samples
 
 
+@SAMPLERS.register_module()
 class DistributedGroupSampler(Sampler):
     """Sampler that restricts data loading to a subset of the dataset.
     It is especially useful in conjunction with
@@ -389,6 +395,7 @@
         self.epoch = epoch
 
 
+@SAMPLERS.register_module()
 class DistributedGivenIterationSampler(Sampler):
 
     def __init__(self,
@@ -476,6 +483,7 @@
         pass
 
 
+@SAMPLERS.register_module()
 class RASampler(torch.utils.data.Sampler):
     """Sampler that restricts data loading to a subset of the dataset for
     distributed, with repeated augmentation.
@@ -489,7 +497,8 @@ class RASampler(torch.utils.data.Sampler):
                  num_replicas=None,
                  rank=None,
                  shuffle=True,
-                 num_repeats: int = 3):
+                 num_repeats: int = 3,
+                 **kwargs):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError(
diff --git a/easycv/datasets/registry.py b/easycv/datasets/registry.py
index b8f25477..649372ea 100644
--- a/easycv/datasets/registry.py
+++ b/easycv/datasets/registry.py
@@ -5,3 +5,4 @@ DATASOURCES = Registry('datasource')
 DATASETS = Registry('dataset')
 DALIDATASETS = Registry('dalidataset')
 PIPELINES = Registry('pipeline')
+SAMPLERS = Registry('sampler')
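With the SAMPLERS registry in place, custom samplers can be registered and then selected by name. A minimal sketch (the class below is hypothetical, not part of EasyCV; the **kwargs mirrors the patch's pattern of swallowing unused builder arguments such as num_replicas, rank, seed and replace):

    from torch.utils.data import Sampler

    from easycv.datasets.registry import SAMPLERS


    @SAMPLERS.register_module()
    class SequentialDebugSampler(Sampler):  # hypothetical example
        """Yield indices in order, for deterministic debugging runs."""

        def __init__(self, dataset, **kwargs):
            # build_dataloader always injects dataset/num_replicas/rank/
            # shuffle/seed/replace; accept and ignore what is not needed.
            self.dataset = dataset

        def __iter__(self):
            return iter(range(len(self.dataset)))

        def __len__(self):
            return len(self.dataset)

It could then be selected with sampler=dict(type='SequentialDebugSampler') in build_dataloader.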
diff --git a/easycv/hooks/eval_hook.py b/easycv/hooks/eval_hook.py
index a4065617..c0967898 100644
--- a/easycv/hooks/eval_hook.py
+++ b/easycv/hooks/eval_hook.py
@@ -3,7 +3,9 @@ import os.path as osp
 from collections import OrderedDict
 
 import torch
+import torch.distributed as dist
 from mmcv.runner import Hook
+from torch.nn.modules.batchnorm import _BatchNorm
 from torch.utils.data import DataLoader
 
 from easycv.datasets.loader.loader_wrapper import TorchaccLoaderWrapper
@@ -136,6 +138,9 @@
             processes. Default: None.
         gpu_collect (bool): Whether to use gpu or cpu to collect results.
             Default: False.
+        broadcast_bn_buffer (bool): Whether to broadcast the buffers
+            (running_mean and running_var) of rank 0 to other ranks
+            before evaluation. Default: True.
     """
 
     def __init__(self,
@@ -145,6 +150,7 @@
                  initial=False,
                  gpu_collect=False,
                  flush_buffer=True,
+                 broadcast_bn_buffer=True,
                  **eval_kwargs):
 
         super(DistEvalHook, self).__init__(
@@ -155,11 +161,26 @@
             flush_buffer=flush_buffer,
             **eval_kwargs)
 
+        self.broadcast_bn_buffer = broadcast_bn_buffer
         self.gpu_collect = self.eval_kwargs.pop('gpu_collect', gpu_collect)
 
     def after_train_epoch(self, runner):
         if not self.every_n_epochs(runner, self.interval):
             return
+
+        # Synchronization of BatchNorm's buffers (running_mean
+        # and running_var) is not supported in the DDP of pytorch,
+        # which may cause inconsistent performance of models on
+        # different ranks, so we broadcast BatchNorm's buffers
+        # of rank 0 to other ranks to avoid this.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for name, module in model.named_modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
+
         from easycv.apis import multi_gpu_test
         results = multi_gpu_test(
             runner.model,
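For reference, the broadcast logic added to after_train_epoch, extracted as a standalone helper (a sketch under the same assumptions as the hook: torch.distributed is initialized and every rank calls it):

    import torch.distributed as dist
    from torch.nn.modules.batchnorm import _BatchNorm


    def broadcast_bn_buffers(model, src=0):
        # DDP syncs gradients but not BatchNorm running stats, so the
        # buffers can drift apart across ranks; copying rank-src buffers
        # to all ranks makes every process evaluate the same model.
        for _, module in model.named_modules():
            if isinstance(module, _BatchNorm) and module.track_running_stats:
                dist.broadcast(module.running_var, src)
                dist.broadcast(module.running_mean, src)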
diff --git a/easycv/models/base.py b/easycv/models/base.py
index 9eceec4c..ad97c9ac 100644
--- a/easycv/models/base.py
+++ b/easycv/models/base.py
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import copy
 import logging
+import warnings
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict
 from typing import Dict
@@ -20,24 +21,34 @@ class BaseModel(nn.Module, metaclass=ABCMeta):
 
     def __init__(self, init_cfg=None):
         super(BaseModel, self).__init__()
+        self._is_init = False
         self.init_cfg = copy.deepcopy(init_cfg)
 
+    @property
+    def is_init(self) -> bool:
+        return self._is_init
+
     def init_weights(self):
         module_name = self.__class__.__name__
+        if not self._is_init:
+            if self.init_cfg:
+                print_log(
+                    f'initialize {module_name} with init_cfg {self.init_cfg}')
+                initialize(self, self.init_cfg)
+                if isinstance(self.init_cfg, dict):
+                    # prevent the parameters of the pre-trained model from being overwritten by the `init_weights`
+                    if self.init_cfg['type'] == 'Pretrained':
+                        logging.warning(
+                            'Skip `init_cfg` with `Pretrained` type!')
+                        return
 
-        if self.init_cfg:
-            print_log(
-                f'initialize {module_name} with init_cfg {self.init_cfg}')
-            initialize(self, self.init_cfg)
-            if isinstance(self.init_cfg, dict):
-                # prevent the parameters of the pre-trained model from being overwritten by the `init_weights`
-                if self.init_cfg['type'] == 'Pretrained':
-                    logging.warning('Skip `init_cfg` with `Pretrained` type!')
-                    return
-
-        for m in self.children():
-            if hasattr(m, 'init_weights'):
-                m.init_weights()
+            for m in self.children():
+                if hasattr(m, 'init_weights'):
+                    m.init_weights()
+            self._is_init = True
+        else:
+            warnings.warn(f'init_weights of {self.__class__.__name__} has '
+                          f'been called more than once.')
 
     @abstractmethod
     def forward_train(self, img: Tensor, **kwargs) -> Dict[str, Tensor]:
diff --git a/easycv/models/detection3d/detectors/mvx_two_stage.py b/easycv/models/detection3d/detectors/mvx_two_stage.py
index f58871f5..58048d44 100644
--- a/easycv/models/detection3d/detectors/mvx_two_stage.py
+++ b/easycv/models/detection3d/detectors/mvx_two_stage.py
@@ -104,6 +104,7 @@
             'key, please consider using init_cfg')
         self.pts_backbone.init_cfg = dict(
             type='Pretrained', checkpoint=pts_pretrained)
+        self.init_weights()
 
     @property
     def with_img_shared_head(self):
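To illustrate the once-only semantics that base.py gains above (a sketch with a hypothetical subclass; it assumes forward_train is the only abstract method of BaseModel):

    import torch.nn as nn

    from easycv.models.base import BaseModel


    class ToyModel(BaseModel):  # hypothetical subclass for illustration
        def __init__(self):
            super().__init__(init_cfg=dict(type='Kaiming', layer='Conv2d'))
            self.conv = nn.Conv2d(3, 8, 3)

        def forward_train(self, img, **kwargs):
            return {'loss': self.conv(img).mean()}


    model = ToyModel()
    model.init_weights()  # initializes once and flips model.is_init to True
    model.init_weights()  # second call only warns; weights stay untouched

Note the exception in init_weights itself: a dict init_cfg of type 'Pretrained' returns before self._is_init is set, so pretrained loading (as triggered by the new self.init_weights() call in MVXTwoStageDetector) is not blocked by the guard.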