mirror of https://github.com/alibaba/EasyCV.git

commit e7b69a8f63 ("fix code")
parent ff3c2bd2c1
(BEVFormer nuScenes config)
@@ -196,6 +196,9 @@ test_pipeline = [
 data = dict(
     imgs_per_gpu=1,
     workers_per_gpu=4,
+    # TODO: support custom sampler config
+    # shuffler_sampler=dict(type='DistributedGroupSampler'),
+    # nonshuffler_sampler=dict(type='DistributedSampler'),
     train=dict(
         type=dataset_type,
         data_root=data_root,
@@ -226,9 +229,7 @@ data = dict(
         pipeline=test_pipeline,
         bev_size=(bev_h_, bev_w_),
         classes=class_names,
-        modality=input_modality),
-    shuffler_sampler=dict(type='DistributedGroupSampler'),
-    nonshuffler_sampler=dict(type='DistributedSampler'))
+        modality=input_modality))
 
 paramwise_cfg = dict(custom_keys={
     'img_backbone': dict(lr_mult=0.1),
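The sampler keys removed from the data dict here are not lost functionality: as the build_loader.py changes below show, build_dataloader now falls back to DistributedSampler (distributed) or RandomSampler (non-distributed) on its own and accepts an explicit sampler config for anything else; the commented-out lines are kept as a TODO until the config-side plumbing lands.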
easycv/datasets/__init__.py
@@ -1,6 +1,9 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from . import (classification, detection, detection3d, face, ocr, pose,
                segmentation, selfsup, shared)
-from .builder import build_dali_dataset, build_dataset
-from .loader import DistributedGroupSampler, GroupSampler, build_dataloader
-from .registry import DATASETS
+from .builder import (build_dali_dataset, build_dataset, build_datasource,
+                      build_sampler)
+from .loader import (DistributedGivenIterationSampler, DistributedGroupSampler,
+                     DistributedMPSampler, DistributedSampler, GroupSampler,
+                     RASampler, build_dataloader)
+from .registry import DATASETS, DATASOURCES, PIPELINES, SAMPLERS
easycv/datasets/builder.py
@@ -4,7 +4,7 @@ import copy
 from easycv.datasets.shared.dataset_wrappers import (ConcatDataset,
                                                      RepeatDataset)
 from easycv.utils.registry import build_from_cfg
-from .registry import DALIDATASETS, DATASETS, DATASOURCES
+from .registry import DALIDATASETS, DATASETS, DATASOURCES, SAMPLERS
 
 
 def _concat_dataset(cfg, default_args=None):
@@ -47,3 +47,7 @@ def build_dali_dataset(cfg, default_args=None):
 
 def build_datasource(cfg):
     return build_from_cfg(cfg, DATASOURCES)
+
+
+def build_sampler(cfg, default_args=None):
+    return build_from_cfg(cfg, SAMPLERS, default_args)
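The new build_sampler mirrors the other build_* helpers: it instantiates whatever class is registered under cfg['type'] in the SAMPLERS registry. A minimal sketch of direct use, mirroring the keys build_dataloader packs into default_sampler_args below (the toy dataset and the single-process num_replicas/rank values are illustrative placeholders):

import torch
from torch.utils.data import TensorDataset

from easycv.datasets.builder import build_sampler

toy_dataset = TensorDataset(torch.arange(16))  # stand-in dataset

sampler = build_sampler(
    dict(
        type='DistributedSampler',  # any name registered in SAMPLERS
        dataset=toy_dataset,
        num_replicas=1,  # single process, so rank 0 sees every index
        rank=0,
        shuffle=True,
        seed=0,
        replace=False))
print(len(sampler))  # number of indices this rank draws per epoch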
easycv/datasets/loader/__init__.py
@@ -1,9 +1,11 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .build_loader import build_dataloader
 from .sampler import (DistributedGivenIterationSampler,
-                      DistributedGroupSampler, GroupSampler)
+                      DistributedGroupSampler, DistributedMPSampler,
+                      DistributedSampler, GroupSampler, RASampler)
 
 __all__ = [
     'GroupSampler', 'DistributedGroupSampler', 'build_dataloader',
-    'DistributedGivenIterationSampler'
+    'DistributedGivenIterationSampler', 'DistributedMPSampler', 'RASampler',
+    'DistributedSampler'
 ]
easycv/datasets/loader/build_loader.py
@@ -8,14 +8,14 @@ import numpy as np
 import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader
 
+from easycv.datasets.builder import build_sampler
 from easycv.datasets.shared.odps_reader import set_dataloader_workid
 from easycv.framework.errors import NotImplementedError
 from easycv.utils.dist_utils import sync_random_seed
 from easycv.utils.torchacc_util import is_torchacc_enabled
 from .collate import CollateWrapper
-from .sampler import DistributedMPSampler, DistributedSampler, RASampler
 
 if platform.system() != 'Windows':
     # https://github.com/pytorch/pytorch/issues/973
@@ -37,6 +37,7 @@ def build_dataloader(dataset,
                      persistent_workers=False,
                      collate_hooks=None,
                      use_repeated_augment_sampler=False,
+                     sampler=None,
                      **kwargs):
     """Build PyTorch DataLoader.
     In distributed training, each GPU/process has a dataloader.
@@ -68,44 +69,48 @@ def build_dataloader(dataset,
 
     if dist:
         seed = sync_random_seed(seed)
-        split_huge_listfile_byrank = getattr(dataset,
-                                             'split_huge_listfile_byrank',
-                                             False)
-
-        if use_repeated_augment_sampler:
-            sampler = RASampler(dataset, world_size, rank, shuffle=shuffle)
-        elif hasattr(dataset, 'm_per_class') and dataset.m_per_class > 1:
-            sampler = DistributedMPSampler(
-                dataset,
-                world_size,
-                rank,
-                shuffle=shuffle,
-                split_huge_listfile_byrank=split_huge_listfile_byrank)
-        else:
-            sampler = DistributedSampler(
-                dataset,
-                world_size,
-                rank,
-                shuffle=shuffle,
-                seed=seed,
-                split_huge_listfile_byrank=split_huge_listfile_byrank)
         batch_size = imgs_per_gpu
         num_workers = workers_per_gpu
     else:
         if replace:
             raise NotImplementedError
-
-        if use_repeated_augment_sampler:
-            sampler = RASampler(dataset, 1, 0, shuffle=shuffle)
-        elif hasattr(dataset, 'm_per_class') and dataset.m_per_class > 1:
-            sampler = DistributedMPSampler(
-                dataset, 1, 0, shuffle=shuffle, replace=replace)
-        else:
-            sampler = RandomSampler(
-                dataset) if shuffle else None  # TODO: set replace
         batch_size = num_gpus * imgs_per_gpu
         num_workers = num_gpus * workers_per_gpu
 
+    default_sampler_args = dict(
+        dataset=dataset,
+        num_replicas=world_size,
+        rank=rank,
+        shuffle=shuffle,
+        seed=seed,
+        replace=replace)
+
+    split_huge_listfile_byrank = getattr(dataset, 'split_huge_listfile_byrank',
+                                         False)
+
+    if sampler is not None:
+        sampler_cfg = sampler
+        sampler_cfg.update(default_sampler_args)
+    elif use_repeated_augment_sampler:
+        sampler_cfg = dict(type='RASampler', **default_sampler_args)
+    elif hasattr(dataset, 'm_per_class') and dataset.m_per_class > 1:
+        sampler_cfg = dict(
+            type='DistributedMPSampler',
+            split_huge_listfile_byrank=split_huge_listfile_byrank,
+            **default_sampler_args)
+    else:
+        if dist:
+            sampler_cfg = dict(
+                type='DistributedSampler',
+                split_huge_listfile_byrank=split_huge_listfile_byrank,
+                **default_sampler_args)
+        else:
+            sampler_cfg = dict(
+                type='RandomSampler',
+                data_source=dataset) if shuffle else None  # TODO: set replace
+
+    sampler = build_sampler(sampler_cfg) if sampler_cfg is not None else None
+
     init_fn = partial(
         worker_init_fn,
         num_workers=num_workers,
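With this refactor every branch funnels into a sampler_cfg dict that build_sampler resolves, and callers can inject their own config through the new sampler argument; build_dataloader completes it with default_sampler_args before building. A hedged sketch (RASampler tolerates the extra merged keys thanks to the **kwargs it gains in the sampler.py hunks below; the toy dataset is a placeholder):

import torch
from torch.utils.data import TensorDataset

from easycv.datasets import build_dataloader

toy_dataset = TensorDataset(torch.arange(32))

loader = build_dataloader(
    toy_dataset,
    imgs_per_gpu=4,
    workers_per_gpu=0,
    dist=False,
    shuffle=True,
    # forwarded as sampler_cfg, then updated with default_sampler_args
    sampler=dict(type='RASampler', num_repeats=3))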
easycv/datasets/loader/sampler.py
@@ -1,7 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from __future__ import division
 import math
-import os
 import random
 
 import numpy as np
@@ -9,11 +8,16 @@ import torch
 import torch.distributed as dist
 from mmcv.runner import get_dist_info
 from torch.utils.data import DistributedSampler as _DistributedSampler
-from torch.utils.data import Sampler
+from torch.utils.data import RandomSampler, Sampler
 
+from easycv.datasets.registry import SAMPLERS
 from easycv.framework.errors import ValueError
+from easycv.utils.dist_utils import local_rank
 
+SAMPLERS.register_module(RandomSampler)
+
 
+@SAMPLERS.register_module()
 class DistributedMPSampler(_DistributedSampler):
 
     def __init__(self,
@@ -21,7 +25,8 @@ class DistributedMPSampler(_DistributedSampler):
                  num_replicas=None,
                  rank=None,
                  shuffle=True,
-                 split_huge_listfile_byrank=False):
+                 split_huge_listfile_byrank=False,
+                 **kwargs):
         """ A Distribute sampler which support sample m instance from one class once for classification dataset
         dataset: pytorch dataset object
         num_replicas (optional): Number of processes participating in
@@ -33,9 +38,7 @@ class DistributedMPSampler(_DistributedSampler):
         """
         super().__init__(dataset, num_replicas=num_replicas, rank=rank)
 
-        current_env = os.environ.copy()
-        self.local_rank = int(current_env['LOCAL_RANK'])
-
+        self.local_rank = local_rank()
         self.shuffle = shuffle
         self.unif_sampling_flag = False
         self.split_huge_listfile_byrank = split_huge_listfile_byrank
@@ -158,6 +161,7 @@ class DistributedMPSampler(_DistributedSampler):
         return self.length
 
 
+@SAMPLERS.register_module()
 class DistributedSampler(_DistributedSampler):
 
     def __init__(
@@ -256,6 +260,7 @@ class DistributedSampler(_DistributedSampler):
         return self.num_samples if not self.split_huge_listfile_byrank else self.num_samples * self.num_replicas
 
 
+@SAMPLERS.register_module()
 class GroupSampler(Sampler):
 
     def __init__(self, dataset, samples_per_gpu=1):
@@ -297,6 +302,7 @@ class GroupSampler(Sampler):
         return self.num_samples
 
 
+@SAMPLERS.register_module()
 class DistributedGroupSampler(Sampler):
     """Sampler that restricts data loading to a subset of the dataset.
     It is especially useful in conjunction with
@@ -389,6 +395,7 @@ class DistributedGroupSampler(Sampler):
         self.epoch = epoch
 
 
+@SAMPLERS.register_module()
class DistributedGivenIterationSampler(Sampler):
 
     def __init__(self,
@@ -476,6 +483,7 @@ class DistributedGivenIterationSampler(Sampler):
         pass
 
 
+@SAMPLERS.register_module()
 class RASampler(torch.utils.data.Sampler):
     """Sampler that restricts data loading to a subset of the dataset for distributed,
     with repeated augmentation.
@@ -489,7 +497,8 @@ class RASampler(torch.utils.data.Sampler):
                  num_replicas=None,
                  rank=None,
                  shuffle=True,
-                 num_repeats: int = 3):
+                 num_repeats: int = 3,
+                 **kwargs):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError(
easycv/datasets/registry.py
@@ -5,3 +5,4 @@ DATASOURCES = Registry('datasource')
 DATASETS = Registry('dataset')
 DALIDATASETS = Registry('dalidataset')
 PIPELINES = Registry('pipeline')
+SAMPLERS = Registry('sampler')
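Together with the decorators added in sampler.py, the new SAMPLERS registry makes the sampler pluggable: anything registered under SAMPLERS can be selected by name through build_sampler or the sampler argument of build_dataloader. A minimal sketch of a user-defined sampler (the class and its behavior are hypothetical; the **kwargs follows the same pattern the commit adds to DistributedMPSampler and RASampler):

from torch.utils.data import Sampler

from easycv.datasets.registry import SAMPLERS


@SAMPLERS.register_module()
class SequentialToySampler(Sampler):
    """Hypothetical sampler: yields dataset indices in plain order."""

    def __init__(self, dataset, **kwargs):
        # **kwargs absorbs the num_replicas/rank/shuffle/seed/replace keys
        # that build_dataloader always merges in via default_sampler_args
        self.dataset = dataset

    def __iter__(self):
        return iter(range(len(self.dataset)))

    def __len__(self):
        return len(self.dataset)


# then, e.g.: build_dataloader(..., sampler=dict(type='SequentialToySampler'))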
easycv/hooks/eval_hook.py
@@ -3,7 +3,9 @@ import os.path as osp
 from collections import OrderedDict
 
 import torch
+import torch.distributed as dist
 from mmcv.runner import Hook
+from torch.nn.modules.batchnorm import _BatchNorm
 from torch.utils.data import DataLoader
 
 from easycv.datasets.loader.loader_wrapper import TorchaccLoaderWrapper
@@ -136,6 +138,9 @@ class DistEvalHook(EvalHook):
             processes. Default: None.
         gpu_collect (bool): Whether to use gpu or cpu to collect results.
             Default: False.
+        broadcast_bn_buffer (bool): Whether to broadcast the
+            buffer(running_mean and running_var) of rank 0 to other rank
+            before evaluation. Default: True.
     """
 
     def __init__(self,
@@ -145,6 +150,7 @@ class DistEvalHook(EvalHook):
                  initial=False,
                  gpu_collect=False,
                  flush_buffer=True,
+                 broadcast_bn_buffer=True,
                  **eval_kwargs):
 
         super(DistEvalHook, self).__init__(
@@ -155,11 +161,26 @@ class DistEvalHook(EvalHook):
             flush_buffer=flush_buffer,
             **eval_kwargs)
 
+        self.broadcast_bn_buffer = broadcast_bn_buffer
         self.gpu_collect = self.eval_kwargs.pop('gpu_collect', gpu_collect)
 
     def after_train_epoch(self, runner):
         if not self.every_n_epochs(runner, self.interval):
             return
 
+        # Synchronization of BatchNorm's buffer (running_mean
+        # and running_var) is not supported in the DDP of pytorch,
+        # which may cause the inconsistent performance of models in
+        # different ranks, so we broadcast BatchNorm's buffers
+        # of rank 0 to other ranks to avoid this.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for name, module in model.named_modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
+
         from easycv.apis import multi_gpu_test
         results = multi_gpu_test(
             runner.model,
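The broadcast keeps per-rank BatchNorm statistics identical before evaluation, since DDP only synchronizes gradients, not buffers. The same loop, extracted into a standalone helper for clarity (assumes an initialized torch.distributed process group):

import torch.distributed as dist
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm


def broadcast_bn_buffers(model: nn.Module, src: int = 0) -> None:
    """Push rank-`src` BatchNorm running stats to every other rank."""
    for _, module in model.named_modules():
        if isinstance(module, _BatchNorm) and module.track_running_stats:
            dist.broadcast(module.running_var, src)
            dist.broadcast(module.running_mean, src)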
easycv/models/base.py
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import copy
+import logging
 import warnings
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict
 from typing import Dict
@@ -20,24 +21,34 @@ class BaseModel(nn.Module, metaclass=ABCMeta):
 
     def __init__(self, init_cfg=None):
         super(BaseModel, self).__init__()
+        self._is_init = False
         self.init_cfg = copy.deepcopy(init_cfg)
 
+    @property
+    def is_init(self) -> bool:
+        return self._is_init
+
     def init_weights(self):
         module_name = self.__class__.__name__
-        if self.init_cfg:
-            print_log(
-                f'initialize {module_name} with init_cfg {self.init_cfg}')
-            initialize(self, self.init_cfg)
-            if isinstance(self.init_cfg, dict):
-                # prevent the parameters of the pre-trained model from being overwritten by the `init_weights`
-                if self.init_cfg['type'] == 'Pretrained':
-                    logging.warning(
-                        'Skip `init_cfg` with `Pretrained` type!')
-                    return
-
-        for m in self.children():
-            if hasattr(m, 'init_weights'):
-                m.init_weights()
+        if not self._is_init:
+            if self.init_cfg:
+                print_log(
+                    f'initialize {module_name} with init_cfg {self.init_cfg}')
+                initialize(self, self.init_cfg)
+                if isinstance(self.init_cfg, dict):
+                    # prevent the parameters of the pre-trained model from being overwritten by the `init_weights`
+                    if self.init_cfg['type'] == 'Pretrained':
+                        logging.warning('Skip `init_cfg` with `Pretrained` type!')
+                        return
+
+            for m in self.children():
+                if hasattr(m, 'init_weights'):
+                    m.init_weights()
+            self._is_init = True
+        else:
+            warnings.warn(f'init_weights of {self.__class__.__name__} has '
+                          f'been called more than once.')
 
     @abstractmethod
     def forward_train(self, img: Tensor, **kwargs) -> Dict[str, Tensor]:
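The guard makes init_weights idempotent: weights are initialized and children recursed into only on the first call, later calls merely warn; the hunk also adds the import logging that the old code's logging.warning call was missing. A hypothetical illustration (assuming, as in the hunk above, forward_train is the only abstract method; the Constant init_cfg is just an example mmcv initializer):

import torch.nn as nn

from easycv.models.base import BaseModel


class ToyModel(BaseModel):  # hypothetical subclass

    def __init__(self):
        super().__init__(
            init_cfg=dict(type='Constant', val=0.5, layer='Linear'))
        self.fc = nn.Linear(4, 2)

    def forward_train(self, img, **kwargs):
        return {'loss': self.fc(img).sum()}


model = ToyModel()
model.init_weights()  # applies init_cfg, recurses into children, sets _is_init
model.init_weights()  # second call only warns that it ran more than once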
easycv/models/detection3d/detectors/mvx_two_stage.py
@@ -104,6 +104,7 @@ class MVXTwoStageDetector(Base3DDetector):
                 'key, please consider using init_cfg')
             self.pts_backbone.init_cfg = dict(
                 type='Pretrained', checkpoint=pts_pretrained)
+        self.init_weights()
 
     @property
     def with_img_shared_head(self):