add_bevformer
jiangnana.jnn 2022-09-29 17:23:03 +08:00
parent ff3c2bd2c1
commit e7b69a8f63
10 changed files with 119 additions and 61 deletions

View File

@ -196,6 +196,9 @@ test_pipeline = [
data = dict(
imgs_per_gpu=1,
workers_per_gpu=4,
# TODO: support custom sampler config
# shuffler_sampler=dict(type='DistributedGroupSampler'),
# nonshuffler_sampler=dict(type='DistributedSampler'),
train=dict(
type=dataset_type,
data_root=data_root,
@ -226,9 +229,7 @@ data = dict(
pipeline=test_pipeline,
bev_size=(bev_h_, bev_w_),
classes=class_names,
modality=input_modality),
shuffler_sampler=dict(type='DistributedGroupSampler'),
nonshuffler_sampler=dict(type='DistributedSampler'))
modality=input_modality))
paramwise_cfg = dict(custom_keys={
'img_backbone': dict(lr_mult=0.1),

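Purely as a sketch of where the TODO above points, the data dict with the custom sampler keys re-enabled might look like the snippet below; the key names simply mirror the commented-out lines and are not honoured by the dataloader builder in this commit.

# Hypothetical config shape once custom sampler keys are supported again;
# not functional in this commit (see the TODO above).
data = dict(
    imgs_per_gpu=1,
    workers_per_gpu=4,
    shuffler_sampler=dict(type='DistributedGroupSampler'),
    nonshuffler_sampler=dict(type='DistributedSampler'),
    train=dict(
        type=dataset_type,
        data_root=data_root))  # remaining train/val/test keys unchanged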
View File

@ -1,6 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from . import (classification, detection, detection3d, face, ocr, pose,
segmentation, selfsup, shared)
from .builder import build_dali_dataset, build_dataset
from .loader import DistributedGroupSampler, GroupSampler, build_dataloader
from .registry import DATASETS
from .builder import (build_dali_dataset, build_dataset, build_datasource,
build_sampler)
from .loader import (DistributedGivenIterationSampler, DistributedGroupSampler,
DistributedMPSampler, DistributedSampler, GroupSampler,
RASampler, build_dataloader)
from .registry import DATASETS, DATASOURCES, PIPELINES, SAMPLERS

View File

@ -4,7 +4,7 @@ import copy
from easycv.datasets.shared.dataset_wrappers import (ConcatDataset,
RepeatDataset)
from easycv.utils.registry import build_from_cfg
from .registry import DALIDATASETS, DATASETS, DATASOURCES
from .registry import DALIDATASETS, DATASETS, DATASOURCES, SAMPLERS
def _concat_dataset(cfg, default_args=None):
@ -47,3 +47,7 @@ def build_dali_dataset(cfg, default_args=None):
def build_datasource(cfg):
return build_from_cfg(cfg, DATASOURCES)
def build_sampler(cfg, default_args=None):
return build_from_cfg(cfg, SAMPLERS, default_args)

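As a minimal usage sketch (the concrete values and the `dataset` / `world_size` / `rank` placeholders are illustrative, not taken from the commit), the new helper builds any sampler registered in SAMPLERS from a plain config dict:

# Illustrative only: build a registered sampler from a config dict.
from easycv.datasets.builder import build_sampler

sampler_cfg = dict(type='DistributedSampler', shuffle=True, seed=0)
# default_args lets the caller inject runtime objects such as the dataset
# and the distributed world size / rank.
sampler = build_sampler(
    sampler_cfg,
    default_args=dict(dataset=dataset, num_replicas=world_size, rank=rank))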
View File

@ -1,9 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .build_loader import build_dataloader
from .sampler import (DistributedGivenIterationSampler,
DistributedGroupSampler, GroupSampler)
DistributedGroupSampler, DistributedMPSampler,
DistributedSampler, GroupSampler, RASampler)
__all__ = [
'GroupSampler', 'DistributedGroupSampler', 'build_dataloader',
'DistributedGivenIterationSampler'
'DistributedGivenIterationSampler', 'DistributedMPSampler', 'RASampler',
'DistributedSampler'
]

View File

@ -8,14 +8,14 @@ import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import DataLoader
from easycv.datasets.builder import build_sampler
from easycv.datasets.shared.odps_reader import set_dataloader_workid
from easycv.framework.errors import NotImplementedError
from easycv.utils.dist_utils import sync_random_seed
from easycv.utils.torchacc_util import is_torchacc_enabled
from .collate import CollateWrapper
from .sampler import DistributedMPSampler, DistributedSampler, RASampler
if platform.system() != 'Windows':
# https://github.com/pytorch/pytorch/issues/973
@ -37,6 +37,7 @@ def build_dataloader(dataset,
persistent_workers=False,
collate_hooks=None,
use_repeated_augment_sampler=False,
sampler=None,
**kwargs):
"""Build PyTorch DataLoader.
In distributed training, each GPU/process has a dataloader.
@ -68,44 +69,48 @@ def build_dataloader(dataset,
if dist:
seed = sync_random_seed(seed)
split_huge_listfile_byrank = getattr(dataset,
'split_huge_listfile_byrank',
False)
if use_repeated_augment_sampler:
sampler = RASampler(dataset, world_size, rank, shuffle=shuffle)
elif hasattr(dataset, 'm_per_class') and dataset.m_per_class > 1:
sampler = DistributedMPSampler(
dataset,
world_size,
rank,
shuffle=shuffle,
split_huge_listfile_byrank=split_huge_listfile_byrank)
else:
sampler = DistributedSampler(
dataset,
world_size,
rank,
shuffle=shuffle,
seed=seed,
split_huge_listfile_byrank=split_huge_listfile_byrank)
batch_size = imgs_per_gpu
num_workers = workers_per_gpu
else:
if replace:
raise NotImplementedError
if use_repeated_augment_sampler:
sampler = RASampler(dataset, 1, 0, shuffle=shuffle)
elif hasattr(dataset, 'm_per_class') and dataset.m_per_class > 1:
sampler = DistributedMPSampler(
dataset, 1, 0, shuffle=shuffle, replace=replace)
else:
sampler = RandomSampler(
dataset) if shuffle else None # TODO: set replace
batch_size = num_gpus * imgs_per_gpu
num_workers = num_gpus * workers_per_gpu
default_sampler_args = dict(
dataset=dataset,
num_replicas=world_size,
rank=rank,
shuffle=shuffle,
seed=seed,
replace=replace)
split_huge_listfile_byrank = getattr(dataset, 'split_huge_listfile_byrank',
False)
if sampler is not None:
sampler_cfg = sampler
sampler_cfg.update(default_sampler_args)
elif use_repeated_augment_sampler:
sampler_cfg = dict(type='RASampler', **default_sampler_args)
elif hasattr(dataset, 'm_per_class') and dataset.m_per_class > 1:
sampler_cfg = dict(
type='DistributedMPSampler',
split_huge_listfile_byrank=split_huge_listfile_byrank,
**default_sampler_args)
else:
if dist:
sampler_cfg = dict(
type='DistributedSampler',
split_huge_listfile_byrank=split_huge_listfile_byrank,
**default_sampler_args)
else:
sampler_cfg = dict(
type='RandomSampler',
data_source=dataset) if shuffle else None # TODO: set replace
sampler = build_sampler(sampler_cfg) if sampler_cfg is not None else None
init_fn = partial(
worker_init_fn,
num_workers=num_workers,

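For illustration only, the new `sampler` argument lets a caller override the selection logic above with any registered sampler; build_dataloader merges the runtime defaults (dataset, num_replicas, rank, shuffle, seed, replace) into the passed dict before building it, so only the type and sampler-specific options need to be given. RASampler is just an example choice here, and `dataset` is a placeholder.

# Illustrative sketch: request a specific registered sampler per call.
from easycv.datasets.loader import build_dataloader

data_loader = build_dataloader(
    dataset,              # placeholder: an already built dataset
    imgs_per_gpu=2,
    workers_per_gpu=2,
    dist=True,
    shuffle=True,
    sampler=dict(type='RASampler', num_repeats=3))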
View File

@ -1,7 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import division
import math
import os
import random
import numpy as np
@ -9,11 +8,16 @@ import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info
from torch.utils.data import DistributedSampler as _DistributedSampler
from torch.utils.data import Sampler
from torch.utils.data import RandomSampler, Sampler
from easycv.datasets.registry import SAMPLERS
from easycv.framework.errors import ValueError
from easycv.utils.dist_utils import local_rank
SAMPLERS.register_module(RandomSampler)
@SAMPLERS.register_module()
class DistributedMPSampler(_DistributedSampler):
def __init__(self,
@ -21,7 +25,8 @@ class DistributedMPSampler(_DistributedSampler):
num_replicas=None,
rank=None,
shuffle=True,
split_huge_listfile_byrank=False):
split_huge_listfile_byrank=False,
**kwargs):
""" A Distribute sampler which support sample m instance from one class once for classification dataset
dataset: pytorch dataset object
num_replicas (optional): Number of processes participating in
@ -33,9 +38,7 @@ class DistributedMPSampler(_DistributedSampler):
"""
super().__init__(dataset, num_replicas=num_replicas, rank=rank)
current_env = os.environ.copy()
self.local_rank = int(current_env['LOCAL_RANK'])
self.local_rank = local_rank()
self.shuffle = shuffle
self.unif_sampling_flag = False
self.split_huge_listfile_byrank = split_huge_listfile_byrank
@ -158,6 +161,7 @@ class DistributedMPSampler(_DistributedSampler):
return self.length
@SAMPLERS.register_module()
class DistributedSampler(_DistributedSampler):
def __init__(
@ -256,6 +260,7 @@ class DistributedSampler(_DistributedSampler):
return self.num_samples if not self.split_huge_listfile_byrank else self.num_samples * self.num_replicas
@SAMPLERS.register_module()
class GroupSampler(Sampler):
def __init__(self, dataset, samples_per_gpu=1):
@ -297,6 +302,7 @@ class GroupSampler(Sampler):
return self.num_samples
@SAMPLERS.register_module()
class DistributedGroupSampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with
@ -389,6 +395,7 @@ class DistributedGroupSampler(Sampler):
self.epoch = epoch
@SAMPLERS.register_module()
class DistributedGivenIterationSampler(Sampler):
def __init__(self,
@ -476,6 +483,7 @@ class DistributedGivenIterationSampler(Sampler):
pass
@SAMPLERS.register_module()
class RASampler(torch.utils.data.Sampler):
"""Sampler that restricts data loading to a subset of the dataset for distributed,
with repeated augmentation.
@ -489,7 +497,8 @@ class RASampler(torch.utils.data.Sampler):
num_replicas=None,
rank=None,
shuffle=True,
num_repeats: int = 3):
num_repeats: int = 3,
**kwargs):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError(

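Since samplers are now looked up in a registry, a project-specific sampler only needs to register itself. The sketch below is hypothetical; it mirrors the **kwargs pattern used above so that the extra defaults injected by build_dataloader (num_replicas, rank, seed, replace, ...) are tolerated.

# Hypothetical example: register a custom sampler in SAMPLERS.
from torch.utils.data import Sampler

from easycv.datasets.registry import SAMPLERS


@SAMPLERS.register_module()
class EveryOtherSampler(Sampler):
    """Toy sampler that yields every second index of the dataset."""

    def __init__(self, dataset, shuffle=False, **kwargs):
        # **kwargs absorbs the defaults injected by build_dataloader
        # (num_replicas, rank, seed, replace, ...).
        self.dataset = dataset
        self.shuffle = shuffle

    def __iter__(self):
        return iter(range(0, len(self.dataset), 2))

    def __len__(self):
        return (len(self.dataset) + 1) // 2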
View File

@ -5,3 +5,4 @@ DATASOURCES = Registry('datasource')
DATASETS = Registry('dataset')
DALIDATASETS = Registry('dalidataset')
PIPELINES = Registry('pipeline')
SAMPLERS = Registry('sampler')

View File

@ -3,7 +3,9 @@ import os.path as osp
from collections import OrderedDict
import torch
import torch.distributed as dist
from mmcv.runner import Hook
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import DataLoader
from easycv.datasets.loader.loader_wrapper import TorchaccLoaderWrapper
@ -136,6 +138,9 @@ class DistEvalHook(EvalHook):
processes. Default: None.
gpu_collect (bool): Whether to use gpu or cpu to collect results.
Default: False.
broadcast_bn_buffer (bool): Whether to broadcast the buffers
(running_mean and running_var) of rank 0 to the other ranks
before evaluation. Default: True.
"""
def __init__(self,
@ -145,6 +150,7 @@ class DistEvalHook(EvalHook):
initial=False,
gpu_collect=False,
flush_buffer=True,
broadcast_bn_buffer=True,
**eval_kwargs):
super(DistEvalHook, self).__init__(
@ -155,11 +161,26 @@ class DistEvalHook(EvalHook):
flush_buffer=flush_buffer,
**eval_kwargs)
self.broadcast_bn_buffer = broadcast_bn_buffer
self.gpu_collect = self.eval_kwargs.pop('gpu_collect', gpu_collect)
def after_train_epoch(self, runner):
if not self.every_n_epochs(runner, self.interval):
return
# Synchronization of BatchNorm's buffers (running_mean
# and running_var) is not supported by PyTorch DDP,
# which can lead to inconsistent results across ranks,
# so we broadcast the buffers of rank 0 to the other
# ranks before evaluation to avoid this.
if self.broadcast_bn_buffer:
model = runner.model
for name, module in model.named_modules():
if isinstance(module,
_BatchNorm) and module.track_running_stats:
dist.broadcast(module.running_var, 0)
dist.broadcast(module.running_mean, 0)
from easycv.apis import multi_gpu_test
results = multi_gpu_test(
runner.model,

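A minimal construction sketch for the new flag; the import path and the surrounding objects `val_dataloader` / `runner` are assumptions for illustration.

# Illustrative only: disable BatchNorm buffer broadcasting explicitly,
# e.g. for models without BatchNorm layers.
from easycv.hooks.eval_hook import DistEvalHook  # import path is an assumption

eval_hook = DistEvalHook(
    val_dataloader,                # placeholder validation DataLoader
    gpu_collect=True,
    broadcast_bn_buffer=False)
runner.register_hook(eval_hook)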
View File

@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import logging
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Dict
@ -20,24 +21,34 @@ class BaseModel(nn.Module, metaclass=ABCMeta):
def __init__(self, init_cfg=None):
super(BaseModel, self).__init__()
self._is_init = False
self.init_cfg = copy.deepcopy(init_cfg)
@property
def is_init(self) -> bool:
return self._is_init
def init_weights(self):
module_name = self.__class__.__name__
if not self._is_init:
if self.init_cfg:
print_log(
f'initialize {module_name} with init_cfg {self.init_cfg}')
initialize(self, self.init_cfg)
if isinstance(self.init_cfg, dict):
# prevent the parameters of the pre-trained model from being overwritten by the `init_weights`
if self.init_cfg['type'] == 'Pretrained':
logging.warning(
'Skip `init_cfg` with `Pretrained` type!')
return
if self.init_cfg:
print_log(
f'initialize {module_name} with init_cfg {self.init_cfg}')
initialize(self, self.init_cfg)
if isinstance(self.init_cfg, dict):
# prevent the parameters of the pre-trained model from being overwritten by the `init_weights`
if self.init_cfg['type'] == 'Pretrained':
logging.warning('Skip `init_cfg` with `Pretrained` type!')
return
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights()
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights()
self._is_init = True
else:
warnings.warn(f'init_weights of {self.__class__.__name__} has '
f'been called more than once.')
@abstractmethod
def forward_train(self, img: Tensor, **kwargs) -> Dict[str, Tensor]:

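To make the new control flow concrete, here is a hypothetical subclass (the class, its layers, and the import path are assumptions, and forward_train is assumed to be the only required override): a non-`Pretrained` init_cfg is applied through mmcv's `initialize`, a `Pretrained` one skips re-initializing children so the loaded weights are kept, and a second call only warns.

# Hypothetical subclass to illustrate the rewritten init_weights flow.
import torch.nn as nn

from easycv.models.base import BaseModel  # import path is an assumption


class ToyModel(BaseModel):

    def __init__(self):
        super().__init__(init_cfg=dict(type='Normal', layer='Linear', std=0.01))
        self.fc = nn.Linear(4, 2)

    def forward_train(self, img, **kwargs):
        return {'loss': self.fc(img).sum()}

    def forward_test(self, img, **kwargs):
        return {}


model = ToyModel()
model.init_weights()   # applies the Normal init through mmcv `initialize`
model.init_weights()   # second call only warns that weights are already initialized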
View File

@ -104,6 +104,7 @@ class MVXTwoStageDetector(Base3DDetector):
'key, please consider using init_cfg')
self.pts_backbone.init_cfg = dict(
type='Pretrained', checkpoint=pts_pretrained)
self.init_weights()
@property
def with_img_shared_head(self):