From a0d4f95725f077c73a2a5d235b7a53b2215aaf81 Mon Sep 17 00:00:00 2001 From: yanglei Date: Wed, 27 May 2020 11:37:16 +0800 Subject: [PATCH] Add missing mmcls.core --- mmcls/core/__init__.py | 1 + mmcls/core/evaluation/__init__.py | 4 +-- mmcls/core/fp16/hooks.py | 2 +- mmcls/core/optimizer/__init__.py | 3 +++ mmcls/core/optimizer/builder.py | 42 +++++++++++++++++++++++++++++ mmcls/core/utils/__init__.py | 4 +-- mmcls/core/utils/dist_utils.py | 40 +++++++++++++++++++++++++++ mmcls/datasets/pipelines/compose.py | 2 +- tools/test.py | 3 ++- 9 files changed, 94 insertions(+), 7 deletions(-) create mode 100644 mmcls/core/optimizer/__init__.py create mode 100644 mmcls/core/optimizer/builder.py diff --git a/mmcls/core/__init__.py b/mmcls/core/__init__.py index ee0dac43..e2504c57 100644 --- a/mmcls/core/__init__.py +++ b/mmcls/core/__init__.py @@ -1,3 +1,4 @@ from .evaluation import * # noqa: F401, F403 from .fp16 import * # noqa: F401, F403 +from .optimizer import * # noqa: F401, F403 from .utils import * # noqa: F401, F403 diff --git a/mmcls/core/evaluation/__init__.py b/mmcls/core/evaluation/__init__.py index 3dca68df..3038e4b1 100644 --- a/mmcls/core/evaluation/__init__.py +++ b/mmcls/core/evaluation/__init__.py @@ -1,3 +1,3 @@ -from .eval_hooks import EvalHook +from .eval_hooks import DistEvalHook, EvalHook -__all__ = ['EvalHook'] +__all__ = ['DistEvalHook', 'EvalHook'] diff --git a/mmcls/core/fp16/hooks.py b/mmcls/core/fp16/hooks.py index be2a921c..c3d4e098 100644 --- a/mmcls/core/fp16/hooks.py +++ b/mmcls/core/fp16/hooks.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn from mmcv.runner import OptimizerHook -from ..dist_utils import allreduce_grads +from ..utils import allreduce_grads from .utils import cast_tensor_type diff --git a/mmcls/core/optimizer/__init__.py b/mmcls/core/optimizer/__init__.py new file mode 100644 index 00000000..d16477d7 --- /dev/null +++ b/mmcls/core/optimizer/__init__.py @@ -0,0 +1,3 @@ +from .builder import build_optimizer + +__all__ = ['build_optimizer'] diff --git a/mmcls/core/optimizer/builder.py b/mmcls/core/optimizer/builder.py new file mode 100644 index 00000000..539995a3 --- /dev/null +++ b/mmcls/core/optimizer/builder.py @@ -0,0 +1,42 @@ +import copy +import inspect + +import torch +from mmcv.utils import Registry, build_from_cfg + +OPTIMIZERS = Registry('optimizer') +OPTIMIZER_BUILDERS = Registry('optimizer builder') + + +def register_torch_optimizers(): + torch_optimizers = [] + for module_name in dir(torch.optim): + if module_name.startswith('__'): + continue + _optim = getattr(torch.optim, module_name) + if inspect.isclass(_optim) and issubclass(_optim, + torch.optim.Optimizer): + OPTIMIZERS.register_module()(_optim) + torch_optimizers.append(module_name) + return torch_optimizers + + +TORCH_OPTIMIZERS = register_torch_optimizers() + + +def build_optimizer_constructor(cfg): + return build_from_cfg(cfg, OPTIMIZER_BUILDERS) + + +def build_optimizer(model, cfg): + optimizer_cfg = copy.deepcopy(cfg) + constructor_type = optimizer_cfg.pop('constructor', + 'DefaultOptimizerConstructor') + paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) + optim_constructor = build_optimizer_constructor( + dict( + type=constructor_type, + optimizer_cfg=optimizer_cfg, + paramwise_cfg=paramwise_cfg)) + optimizer = optim_constructor(model) + return optimizer diff --git a/mmcls/core/utils/__init__.py b/mmcls/core/utils/__init__.py index 537c268b..0fb949d9 100644 --- a/mmcls/core/utils/__init__.py +++ b/mmcls/core/utils/__init__.py @@ -1,3 +1,3 @@ -from .dist_utils import DistOptimizerHook +from .dist_utils import DistOptimizerHook, allreduce_grads -__all__ = ['DistOptimizerHook'] +__all__ = ['allreduce_grads', 'DistOptimizerHook'] diff --git a/mmcls/core/utils/dist_utils.py b/mmcls/core/utils/dist_utils.py index e8c1b7c9..1c914b22 100644 --- a/mmcls/core/utils/dist_utils.py +++ b/mmcls/core/utils/dist_utils.py @@ -1,4 +1,44 @@ +from collections import OrderedDict + +import torch.distributed as dist from mmcv.runner import OptimizerHook +from torch._utils import (_flatten_dense_tensors, _take_tensors, + _unflatten_dense_tensors) + + +def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): + if bucket_size_mb > 0: + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + buckets = _take_tensors(tensors, bucket_size_bytes) + else: + buckets = OrderedDict() + for tensor in tensors: + tp = tensor.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(tensor) + buckets = buckets.values() + + for bucket in buckets: + flat_tensors = _flatten_dense_tensors(bucket) + dist.all_reduce(flat_tensors) + flat_tensors.div_(world_size) + for tensor, synced in zip( + bucket, _unflatten_dense_tensors(flat_tensors, bucket)): + tensor.copy_(synced) + + +def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): + grads = [ + param.grad.data for param in params + if param.requires_grad and param.grad is not None + ] + world_size = dist.get_world_size() + if coalesce: + _allreduce_coalesced(grads, world_size, bucket_size_mb) + else: + for tensor in grads: + dist.all_reduce(tensor.div_(world_size)) class DistOptimizerHook(OptimizerHook): diff --git a/mmcls/datasets/pipelines/compose.py b/mmcls/datasets/pipelines/compose.py index ef35ce2a..6f0b8480 100644 --- a/mmcls/datasets/pipelines/compose.py +++ b/mmcls/datasets/pipelines/compose.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from mmcv.utils import build_from_cfg -from ..registry import PIPELINES +from ..builder import PIPELINES @PIPELINES.register_module diff --git a/tools/test.py b/tools/test.py index 74b3afe9..eeb83131 100644 --- a/tools/test.py +++ b/tools/test.py @@ -6,7 +6,8 @@ import torch from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import get_dist_info, init_dist, load_checkpoint -from mmcls.core import multi_gpu_test, single_gpu_test, wrap_fp16_model +from mmcls.apis import multi_gpu_test, single_gpu_test +from mmcls.core import wrap_fp16_model from mmcls.datasets import build_dataloader, build_dataset from mmcls.models import build_model