[Feature] Support auto import modules from registry (#660)
* support auto import * update * fix lint * update * fix lint * update * refine * update * update colabpull/722/head
parent
824d62e2af
commit
447f4bb38d
|
@ -14,7 +14,7 @@ model = dict(
|
||||||
depth=50,
|
depth=50,
|
||||||
in_channels=3,
|
in_channels=3,
|
||||||
num_stages=4,
|
num_stages=4,
|
||||||
out_indices=(3),
|
out_indices=(3, ),
|
||||||
norm_cfg=dict(type='BN'),
|
norm_cfg=dict(type='BN'),
|
||||||
frozen_stages=-1),
|
frozen_stages=-1),
|
||||||
neck=dict(type='GlobalAveragePooling'),
|
neck=dict(type='GlobalAveragePooling'),
|
||||||
|
|
|
@ -62,8 +62,7 @@
|
||||||
"import torch\n",
|
"import torch\n",
|
||||||
"from mmengine.dataset import Compose, default_collate\n",
|
"from mmengine.dataset import Compose, default_collate\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from mmselfsup.apis import inference_model, init_model\n",
|
"from mmselfsup.apis import inference_model, init_model"
|
||||||
"from mmselfsup.utils import register_all_modules"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -238,7 +237,6 @@
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# make random mask reproducible (comment out to make it change)\n",
|
"# make random mask reproducible (comment out to make it change)\n",
|
||||||
"register_all_modules()\n",
|
|
||||||
"torch.manual_seed(2)"
|
"torch.manual_seed(2)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|
|
@ -157,7 +157,7 @@
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# Check PyTorch installation\n",
|
"# Check PyTorch installation\n",
|
||||||
"import torch, torchvision\n",
|
"import torch\n",
|
||||||
"print(torch.__version__)\n",
|
"print(torch.__version__)\n",
|
||||||
"print(torch.cuda.is_available())"
|
"print(torch.cuda.is_available())"
|
||||||
]
|
]
|
||||||
|
@ -313,7 +313,7 @@
|
||||||
"source": [
|
"source": [
|
||||||
"!pip3 install openmim\n",
|
"!pip3 install openmim\n",
|
||||||
"!pip install -U openmim\n",
|
"!pip install -U openmim\n",
|
||||||
"!mim install 'mmengine' 'mmcv>=2.0.0rc1'"
|
"!mim install 'mmengine' 'mmcv>=2.0.0rc4'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1418,15 +1418,8 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from mmengine.config import Config, DictAction\n",
|
|
||||||
"from mmengine.runner import Runner\n",
|
"from mmengine.runner import Runner\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from mmselfsup.utils import register_all_modules\n",
|
|
||||||
"\n",
|
|
||||||
"# register all modules in mmselfsup into the registries\n",
|
|
||||||
"# do not init the default scope here because it will be init in the runner\n",
|
|
||||||
"register_all_modules(init_default_scope=False)\n",
|
|
||||||
"\n",
|
|
||||||
"# build the runner from config\n",
|
"# build the runner from config\n",
|
||||||
"runner = Runner.from_cfg(cfg)\n",
|
"runner = Runner.from_cfg(cfg)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -2669,15 +2662,8 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from mmengine.config import Config, DictAction\n",
|
|
||||||
"from mmengine.runner import Runner\n",
|
"from mmengine.runner import Runner\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from mmselfsup.utils import register_all_modules\n",
|
|
||||||
"\n",
|
|
||||||
"# register all modules in mmselfsup into the registries\n",
|
|
||||||
"# do not init the default scope here because it will be init in the runner\n",
|
|
||||||
"register_all_modules(init_default_scope=False)\n",
|
|
||||||
"\n",
|
|
||||||
"# build the runner from config\n",
|
"# build the runner from config\n",
|
||||||
"runner = Runner.from_cfg(benchmark_cfg)\n",
|
"runner = Runner.from_cfg(benchmark_cfg)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
|
@ -10,7 +10,7 @@ mmengine_minimum_version = '0.4.0'
|
||||||
mmengine_maximum_version = '1.0.0'
|
mmengine_maximum_version = '1.0.0'
|
||||||
mmengine_version = digit_version(mmengine.__version__)
|
mmengine_version = digit_version(mmengine.__version__)
|
||||||
|
|
||||||
mmcv_minimum_version = '2.0.0rc1'
|
mmcv_minimum_version = '2.0.0rc4'
|
||||||
mmcv_maximum_version = '2.1.0'
|
mmcv_maximum_version = '2.1.0'
|
||||||
mmcv_version = digit_version(mmcv.__version__)
|
mmcv_version = digit_version(mmcv.__version__)
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from mmengine.config import Config
|
from mmengine.config import Config
|
||||||
from mmengine.dataset import Compose, default_collate
|
from mmengine.dataset import Compose, default_collate
|
||||||
|
from mmengine.registry import init_default_scope
|
||||||
from mmengine.runner import load_checkpoint
|
from mmengine.runner import load_checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
@ -36,8 +37,11 @@ def init_model(config: Union[str, Config],
|
||||||
elif not isinstance(config, Config):
|
elif not isinstance(config, Config):
|
||||||
raise TypeError('config must be a filename or Config object, '
|
raise TypeError('config must be a filename or Config object, '
|
||||||
f'but got {type(config)}')
|
f'but got {type(config)}')
|
||||||
|
|
||||||
if options is not None:
|
if options is not None:
|
||||||
config.merge_from_dict(options)
|
config.merge_from_dict(options)
|
||||||
|
init_default_scope(config.get('default_scope', 'mmselfsup'))
|
||||||
|
|
||||||
config.model.pretrained = None
|
config.model.pretrained = None
|
||||||
config.model.setdefault('data_preprocessor',
|
config.model.setdefault('data_preprocessor',
|
||||||
config.get('data_preprocessor', None))
|
config.get('data_preprocessor', None))
|
||||||
|
|
|
@ -54,55 +54,93 @@ __all__ = [
|
||||||
|
|
||||||
# Registries For Runner and the related
|
# Registries For Runner and the related
|
||||||
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
|
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
|
||||||
RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS)
|
RUNNERS = Registry(
|
||||||
|
'runner', parent=MMENGINE_RUNNERS, locations=['mmselfsup.engine.runner'])
|
||||||
# manage runner constructors that define how to initialize runners
|
# manage runner constructors that define how to initialize runners
|
||||||
RUNNER_CONSTRUCTORS = Registry(
|
RUNNER_CONSTRUCTORS = Registry(
|
||||||
'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS)
|
'runner constructor',
|
||||||
|
parent=MMENGINE_RUNNER_CONSTRUCTORS,
|
||||||
|
locations=['mmselfsup.engine.runner'])
|
||||||
# manage all kinds of loops like `EpochBasedTrainLoop`
|
# manage all kinds of loops like `EpochBasedTrainLoop`
|
||||||
LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
|
LOOPS = Registry(
|
||||||
|
'loop', parent=MMENGINE_LOOPS, locations=['mmselfsup.engine.runner'])
|
||||||
# manage all kinds of hooks like `CheckpointHook`
|
# manage all kinds of hooks like `CheckpointHook`
|
||||||
HOOKS = Registry('hook', parent=MMENGINE_HOOKS)
|
HOOKS = Registry(
|
||||||
|
'hook', parent=MMENGINE_HOOKS, locations=['mmselfsup.engine.hooks'])
|
||||||
|
|
||||||
# Registries For Data and the related
|
# Registries For Data and the related
|
||||||
# manage data-related modules
|
# manage data-related modules
|
||||||
DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
|
DATASETS = Registry(
|
||||||
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
|
'dataset', parent=MMENGINE_DATASETS, locations=['mmselfsup.datasets'])
|
||||||
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
|
DATA_SAMPLERS = Registry(
|
||||||
|
'data sampler',
|
||||||
|
parent=MMENGINE_DATA_SAMPLERS,
|
||||||
|
locations=['mmselfsup.datasets.samplers'])
|
||||||
|
TRANSFORMS = Registry(
|
||||||
|
'transform',
|
||||||
|
parent=MMENGINE_TRANSFORMS,
|
||||||
|
locations=['mmselfsup.datasets.transforms'])
|
||||||
|
|
||||||
# manage all kinds of modules inheriting `nn.Module`
|
# manage all kinds of modules inheriting `nn.Module`
|
||||||
MODELS = Registry('model', parent=MMENGINE_MODELS)
|
MODELS = Registry(
|
||||||
|
'model', parent=MMENGINE_MODELS, locations=['mmselfsup.models'])
|
||||||
# manage all kinds of model wrappers like 'MMDistributedDataParallel'
|
# manage all kinds of model wrappers like 'MMDistributedDataParallel'
|
||||||
MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
|
MODEL_WRAPPERS = Registry(
|
||||||
|
'model_wrapper',
|
||||||
|
parent=MMENGINE_MODEL_WRAPPERS,
|
||||||
|
locations=['mmselfsup.models'])
|
||||||
# manage all kinds of weight initialization modules like `Uniform`
|
# manage all kinds of weight initialization modules like `Uniform`
|
||||||
WEIGHT_INITIALIZERS = Registry(
|
WEIGHT_INITIALIZERS = Registry(
|
||||||
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
|
'weight initializer',
|
||||||
|
parent=MMENGINE_WEIGHT_INITIALIZERS,
|
||||||
|
locations=['mmselfsup.models'])
|
||||||
|
|
||||||
# Registries For Optimizer and the related
|
# Registries For Optimizer and the related
|
||||||
# manage all kinds of optimizers like `SGD` and `Adam`
|
# manage all kinds of optimizers like `SGD` and `Adam`
|
||||||
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
|
OPTIMIZERS = Registry(
|
||||||
|
'optimizer',
|
||||||
|
parent=MMENGINE_OPTIMIZERS,
|
||||||
|
locations=['mmselfsup.engine.optimizers'])
|
||||||
# manage optimizer wrapper
|
# manage optimizer wrapper
|
||||||
OPTIM_WRAPPERS = Registry('optimizer_wrapper', parent=MMENGINE_OPTIM_WRAPPERS)
|
OPTIM_WRAPPERS = Registry(
|
||||||
|
'optimizer_wrapper',
|
||||||
|
parent=MMENGINE_OPTIM_WRAPPERS,
|
||||||
|
locations=['mmselfsup.engine.optimizers'])
|
||||||
# manage constructors that customize the optimization hyperparameters.
|
# manage constructors that customize the optimization hyperparameters.
|
||||||
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
|
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
|
||||||
'optimizer wrapper constructor',
|
'optimizer wrapper constructor',
|
||||||
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
|
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS,
|
||||||
|
locations=['mmselfsup.engine.optimizers'])
|
||||||
# manage all kinds of parameter schedulers like `MultiStepLR`
|
# manage all kinds of parameter schedulers like `MultiStepLR`
|
||||||
PARAM_SCHEDULERS = Registry(
|
PARAM_SCHEDULERS = Registry(
|
||||||
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
|
'parameter scheduler',
|
||||||
|
parent=MMENGINE_PARAM_SCHEDULERS,
|
||||||
|
locations=['mmselfsup.engine.schedulers'])
|
||||||
|
|
||||||
# manage all kinds of metrics
|
# manage all kinds of metrics
|
||||||
METRICS = Registry('metric', parent=MMENGINE_METRICS)
|
METRICS = Registry(
|
||||||
|
'metric', parent=MMENGINE_METRICS, locations=['mmselfsup.evaluation'])
|
||||||
# manage evaluator
|
# manage evaluator
|
||||||
EVALUATOR = Registry('evaluator', parent=MMENGINE_EVALUATOR)
|
EVALUATOR = Registry(
|
||||||
|
'evaluator', parent=MMENGINE_EVALUATOR, locations=['mmselfsup.evaluation'])
|
||||||
|
|
||||||
# manage task-specific modules like anchor generators and box coders
|
# manage task-specific modules like anchor generators and box coders
|
||||||
TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
|
TASK_UTILS = Registry(
|
||||||
|
'task util', parent=MMENGINE_TASK_UTILS, locations=['mmselfsup.models'])
|
||||||
|
|
||||||
# Registries For Visualizer and the related
|
|
||||||
# manage visualizer
|
# manage visualizer
|
||||||
VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
|
VISUALIZERS = Registry(
|
||||||
|
'visualizer',
|
||||||
|
parent=MMENGINE_VISUALIZERS,
|
||||||
|
locations=['mmselfsup.visualization'])
|
||||||
# manage visualizer backend
|
# manage visualizer backend
|
||||||
VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)
|
VISBACKENDS = Registry(
|
||||||
|
'vis_backend',
|
||||||
|
parent=MMENGINE_VISBACKENDS,
|
||||||
|
locations=['mmselfsup.visualization'])
|
||||||
|
|
||||||
# manage logprocessor
|
# manage logprocessor
|
||||||
LOG_PROCESSORS = Registry('log_processor', parent=MMENGINE_LOG_PROCESSORS)
|
LOG_PROCESSORS = Registry(
|
||||||
|
'log_processor',
|
||||||
|
parent=MMENGINE_LOG_PROCESSORS,
|
||||||
|
locations=['mmselfsup.visualization'])
|
||||||
|
|
|
@ -7,11 +7,11 @@ import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmengine.config import Config
|
from mmengine.config import Config
|
||||||
|
from mmengine.registry import init_default_scope
|
||||||
|
|
||||||
from mmselfsup.apis import inference_model
|
from mmselfsup.apis import inference_model
|
||||||
from mmselfsup.models import BaseModel
|
from mmselfsup.models import BaseModel
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
backbone = dict(
|
backbone = dict(
|
||||||
type='ResNet',
|
type='ResNet',
|
||||||
|
@ -37,12 +37,11 @@ class ExampleModel(BaseModel):
|
||||||
|
|
||||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='')
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='')
|
||||||
def test_inference_model():
|
def test_inference_model():
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
# Specify the data settings
|
# Specify the data settings
|
||||||
cfg = Config.fromfile(
|
cfg = Config.fromfile(
|
||||||
'configs/selfsup/relative_loc/relative-loc_resnet50_8xb64-steplr-70e_in1k.py' # noqa: E501
|
'configs/selfsup/relative_loc/relative-loc_resnet50_8xb64-steplr-70e_in1k.py' # noqa: E501
|
||||||
)
|
)
|
||||||
|
init_default_scope(cfg.get('default_scope', 'mmselfsup'))
|
||||||
# Build the algorithm
|
# Build the algorithm
|
||||||
model = ExampleModel()
|
model = ExampleModel()
|
||||||
model.cfg = cfg
|
model.cfg = cfg
|
||||||
|
|
|
@ -2,9 +2,9 @@
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from mmengine.registry import init_default_scope
|
||||||
|
|
||||||
from mmselfsup.datasets import DeepClusterImageNet
|
from mmselfsup.datasets import DeepClusterImageNet
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
# dataset settings
|
# dataset settings
|
||||||
train_pipeline = [
|
train_pipeline = [
|
||||||
|
@ -14,7 +14,7 @@ train_pipeline = [
|
||||||
|
|
||||||
|
|
||||||
def test_deepcluster_dataset():
|
def test_deepcluster_dataset():
|
||||||
register_all_modules()
|
init_default_scope('mmselfsup')
|
||||||
|
|
||||||
data = dict(
|
data = dict(
|
||||||
ann_file=osp.join(
|
ann_file=osp.join(
|
||||||
|
|
|
@ -3,9 +3,9 @@ import os.path as osp
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
from mmengine.registry import init_default_scope
|
||||||
|
|
||||||
from mmselfsup.datasets import ImageList
|
from mmselfsup.datasets import ImageList
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
# dataset settings
|
# dataset settings
|
||||||
train_pipeline = [
|
train_pipeline = [
|
||||||
|
@ -15,7 +15,7 @@ train_pipeline = [
|
||||||
|
|
||||||
|
|
||||||
def test_image_list_dataset():
|
def test_image_list_dataset():
|
||||||
register_all_modules()
|
init_default_scope('mmselfsup')
|
||||||
|
|
||||||
data = dict(
|
data = dict(
|
||||||
ann_file='',
|
ann_file='',
|
||||||
|
|
|
@ -7,9 +7,6 @@ import torch
|
||||||
|
|
||||||
from mmselfsup.models import BarlowTwins
|
from mmselfsup.models import BarlowTwins
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
backbone = dict(
|
backbone = dict(
|
||||||
type='ResNet',
|
type='ResNet',
|
||||||
|
|
|
@ -7,7 +7,6 @@ from mmengine.structures import InstanceData
|
||||||
|
|
||||||
from mmselfsup.models import BEiT
|
from mmselfsup.models import BEiT
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
data_preprocessor = dict(
|
data_preprocessor = dict(
|
||||||
type='TwoNormDataPreprocessor',
|
type='TwoNormDataPreprocessor',
|
||||||
|
@ -37,8 +36,6 @@ target_generator = dict(type='DALL-E')
|
||||||
|
|
||||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||||
def test_beitv1():
|
def test_beitv1():
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
model = BEiT(
|
model = BEiT(
|
||||||
backbone=backbone,
|
backbone=backbone,
|
||||||
neck=neck,
|
neck=neck,
|
||||||
|
|
|
@ -7,7 +7,6 @@ from mmengine.structures import InstanceData
|
||||||
|
|
||||||
from mmselfsup.models import BEiT
|
from mmselfsup.models import BEiT
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
data_preprocessor = dict(
|
data_preprocessor = dict(
|
||||||
type='TwoNormDataPreprocessor',
|
type='TwoNormDataPreprocessor',
|
||||||
|
@ -70,8 +69,6 @@ target_generator = dict(type='VQKD', encoder_config=vqkd_encoder)
|
||||||
|
|
||||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||||
def test_beitv2():
|
def test_beitv2():
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
model = BEiT(
|
model = BEiT(
|
||||||
backbone=backbone,
|
backbone=backbone,
|
||||||
neck=neck,
|
neck=neck,
|
||||||
|
|
|
@ -7,9 +7,7 @@ import torch
|
||||||
|
|
||||||
from mmselfsup.models.algorithms.byol import BYOL
|
from mmselfsup.models.algorithms.byol import BYOL
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
backbone = dict(
|
backbone = dict(
|
||||||
type='ResNet',
|
type='ResNet',
|
||||||
depth=18,
|
depth=18,
|
||||||
|
|
|
@ -7,9 +7,6 @@ from mmengine.structures import InstanceData
|
||||||
|
|
||||||
from mmselfsup.models.algorithms import CAE
|
from mmselfsup.models.algorithms import CAE
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
# model settings
|
# model settings
|
||||||
backbone = dict(type='CAEViT', arch='b', patch_size=16, init_values=0.1)
|
backbone = dict(type='CAEViT', arch='b', patch_size=16, init_values=0.1)
|
||||||
|
|
|
@ -8,9 +8,6 @@ from mmengine.structures import InstanceData
|
||||||
|
|
||||||
from mmselfsup.models.algorithms import DeepCluster
|
from mmselfsup.models.algorithms import DeepCluster
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
num_classes = 5
|
num_classes = 5
|
||||||
with_sobel = True,
|
with_sobel = True,
|
||||||
|
|
|
@ -9,9 +9,6 @@ import torch
|
||||||
import mmselfsup
|
import mmselfsup
|
||||||
from mmselfsup.models.algorithms.densecl import DenseCL
|
from mmselfsup.models.algorithms.densecl import DenseCL
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
queue_len = 32
|
queue_len = 32
|
||||||
feat_dim = 2
|
feat_dim = 2
|
||||||
|
|
|
@ -8,9 +8,6 @@ import torch
|
||||||
|
|
||||||
from mmselfsup.models.algorithms import EVA
|
from mmselfsup.models.algorithms import EVA
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
backbone = dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75)
|
backbone = dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75)
|
||||||
neck = dict(
|
neck = dict(
|
||||||
|
|
|
@ -7,9 +7,6 @@ import torch
|
||||||
|
|
||||||
from mmselfsup.models.algorithms.mae import MAE
|
from mmselfsup.models.algorithms.mae import MAE
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
backbone = dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75)
|
backbone = dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75)
|
||||||
neck = dict(
|
neck = dict(
|
||||||
|
|
|
@ -9,9 +9,6 @@ from mmengine.utils import digit_version
|
||||||
|
|
||||||
from mmselfsup.models.algorithms.maskfeat import MaskFeat
|
from mmselfsup.models.algorithms.maskfeat import MaskFeat
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
backbone = dict(type='MaskFeatViT', arch='b', patch_size=16)
|
backbone = dict(type='MaskFeatViT', arch='b', patch_size=16)
|
||||||
neck = dict(
|
neck = dict(
|
||||||
|
|
|
@ -8,9 +8,6 @@ import torch
|
||||||
|
|
||||||
from mmselfsup.models.algorithms import MILAN
|
from mmselfsup.models.algorithms import MILAN
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
backbone = dict(type='MILANViT', arch='b', patch_size=16, mask_ratio=0.75)
|
backbone = dict(type='MILANViT', arch='b', patch_size=16, mask_ratio=0.75)
|
||||||
neck = dict(
|
neck = dict(
|
||||||
|
|
|
@ -7,9 +7,6 @@ import torch
|
||||||
|
|
||||||
from mmselfsup.models.algorithms import MoCo
|
from mmselfsup.models.algorithms import MoCo
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
queue_len = 32
|
queue_len = 32
|
||||||
feat_dim = 2
|
feat_dim = 2
|
||||||
|
|
|
@ -7,9 +7,6 @@ import torch
|
||||||
|
|
||||||
from mmselfsup.models import MoCoV3
|
from mmselfsup.models import MoCoV3
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
backbone = dict(
|
backbone = dict(
|
||||||
type='MoCoV3ViT',
|
type='MoCoV3ViT',
|
||||||
|
|
|
@ -8,9 +8,6 @@ from mmengine.structures import InstanceData
|
||||||
|
|
||||||
from mmselfsup.models.algorithms import NPID
|
from mmselfsup.models.algorithms import NPID
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
backbone = dict(
|
backbone = dict(
|
||||||
type='ResNet',
|
type='ResNet',
|
||||||
|
|
|
@ -8,9 +8,6 @@ from mmengine.structures import InstanceData
|
||||||
|
|
||||||
from mmselfsup.models.algorithms import ODC
|
from mmselfsup.models.algorithms import ODC
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
num_classes = 5
|
num_classes = 5
|
||||||
backbone = dict(
|
backbone = dict(
|
||||||
|
|
|
@ -7,9 +7,6 @@ from mmengine.structures import InstanceData
|
||||||
|
|
||||||
from mmselfsup.models.algorithms.relative_loc import RelativeLoc
|
from mmselfsup.models.algorithms.relative_loc import RelativeLoc
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
backbone = dict(
|
backbone = dict(
|
||||||
type='ResNet',
|
type='ResNet',
|
||||||
|
|
|
@ -8,9 +8,6 @@ from mmengine.structures import InstanceData
|
||||||
|
|
||||||
from mmselfsup.models.algorithms.rotation_pred import RotationPred
|
from mmselfsup.models.algorithms.rotation_pred import RotationPred
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
backbone = dict(
|
backbone = dict(
|
||||||
type='ResNet',
|
type='ResNet',
|
||||||
|
|
|
@ -7,9 +7,6 @@ import torch
|
||||||
|
|
||||||
from mmselfsup.models.algorithms.simclr import SimCLR
|
from mmselfsup.models.algorithms.simclr import SimCLR
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
backbone = dict(
|
backbone = dict(
|
||||||
type='ResNet',
|
type='ResNet',
|
||||||
|
|
|
@ -7,9 +7,6 @@ from mmengine.structures import InstanceData
|
||||||
|
|
||||||
from mmselfsup.models.algorithms import SimMIM
|
from mmselfsup.models.algorithms import SimMIM
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||||
|
|
|
@ -7,9 +7,6 @@ import torch
|
||||||
|
|
||||||
from mmselfsup.models.algorithms import SimSiam
|
from mmselfsup.models.algorithms import SimSiam
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
backbone = dict(
|
backbone = dict(
|
||||||
type='ResNet',
|
type='ResNet',
|
||||||
|
|
|
@ -7,9 +7,6 @@ import torch
|
||||||
|
|
||||||
from mmselfsup.models.algorithms.swav import SwAV
|
from mmselfsup.models.algorithms.swav import SwAV
|
||||||
from mmselfsup.structures import SelfSupDataSample
|
from mmselfsup.structures import SelfSupDataSample
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
nmb_crops = [2, 6]
|
nmb_crops = [2, 6]
|
||||||
backbone = dict(
|
backbone = dict(
|
||||||
|
|
|
@ -1,40 +0,0 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import datetime
|
|
||||||
import sys
|
|
||||||
from unittest import TestCase
|
|
||||||
|
|
||||||
from mmengine import DefaultScope
|
|
||||||
|
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
|
|
||||||
class TestSetupEnv(TestCase):
|
|
||||||
|
|
||||||
def test_register_all_modules(self):
|
|
||||||
from mmselfsup.registry import DATASETS
|
|
||||||
|
|
||||||
# not init default scope
|
|
||||||
sys.modules.pop('mmselfsup.datasets', None)
|
|
||||||
sys.modules.pop('mmselfsup.datasets.places205', None)
|
|
||||||
DATASETS._module_dict.pop('Places205', None)
|
|
||||||
self.assertFalse('Places205' in DATASETS.module_dict)
|
|
||||||
register_all_modules(init_default_scope=False)
|
|
||||||
self.assertTrue('Places205' in DATASETS.module_dict)
|
|
||||||
|
|
||||||
# init default scope
|
|
||||||
sys.modules.pop('mmselfsup.datasets')
|
|
||||||
sys.modules.pop('mmselfsup.datasets.places205')
|
|
||||||
DATASETS._module_dict.pop('Places205', None)
|
|
||||||
self.assertFalse('Places205' in DATASETS.module_dict)
|
|
||||||
register_all_modules(init_default_scope=True)
|
|
||||||
self.assertTrue('Places205' in DATASETS.module_dict)
|
|
||||||
self.assertEqual(DefaultScope.get_current_instance().scope_name,
|
|
||||||
'mmselfsup')
|
|
||||||
|
|
||||||
# init default scope when another scope is init
|
|
||||||
name = f'test-{datetime.datetime.now()}'
|
|
||||||
DefaultScope.get_instance(name, scope_name='test')
|
|
||||||
with self.assertWarnsRegex(
|
|
||||||
Warning,
|
|
||||||
'The current default scope "test" is not "mmselfsup"'):
|
|
||||||
register_all_modules(init_default_scope=True)
|
|
|
@ -5,10 +5,10 @@ import os.path as osp
|
||||||
import mmengine
|
import mmengine
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mmengine import Config, DictAction
|
from mmengine import Config, DictAction
|
||||||
|
from mmengine.registry import init_default_scope
|
||||||
|
|
||||||
from mmselfsup.datasets.builder import build_dataset
|
from mmselfsup.datasets.builder import build_dataset
|
||||||
from mmselfsup.registry import VISUALIZERS
|
from mmselfsup.registry import VISUALIZERS
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
@ -41,12 +41,12 @@ def parse_args():
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
cfg = Config.fromfile(args.config)
|
cfg = Config.fromfile(args.config)
|
||||||
if args.cfg_options is not None:
|
if args.cfg_options is not None:
|
||||||
cfg.merge_from_dict(args.cfg_options)
|
cfg.merge_from_dict(args.cfg_options)
|
||||||
|
|
||||||
# register all modules in mmselfsup into the registries
|
init_default_scope(cfg.get('default_scope', 'mmselfsup'))
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
dataset = build_dataset(cfg.train_dataloader.dataset)
|
dataset = build_dataset(cfg.train_dataloader.dataset)
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,6 @@ import torch
|
||||||
from mmengine.dataset import Compose, default_collate
|
from mmengine.dataset import Compose, default_collate
|
||||||
|
|
||||||
from mmselfsup.apis import inference_model, init_model
|
from mmselfsup.apis import inference_model, init_model
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
imagenet_mean = np.array([0.485, 0.456, 0.406])
|
imagenet_mean = np.array([0.485, 0.456, 0.406])
|
||||||
imagenet_std = np.array([0.229, 0.224, 0.225])
|
imagenet_std = np.array([0.229, 0.224, 0.225])
|
||||||
|
@ -115,8 +114,6 @@ def main():
|
||||||
help='The random seed for visualization')
|
help='The random seed for visualization')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
# build the model from a config file and a checkpoint file
|
# build the model from a config file and a checkpoint file
|
||||||
model = init_model(args.config, args.checkpoint, device=args.device)
|
model = init_model(args.config, args.checkpoint, device=args.device)
|
||||||
print('Model loaded.')
|
print('Model loaded.')
|
||||||
|
|
|
@ -14,13 +14,13 @@ from mmengine.config import Config, DictAction
|
||||||
from mmengine.dataset import default_collate, worker_init_fn
|
from mmengine.dataset import default_collate, worker_init_fn
|
||||||
from mmengine.dist import get_rank
|
from mmengine.dist import get_rank
|
||||||
from mmengine.logging import MMLogger
|
from mmengine.logging import MMLogger
|
||||||
|
from mmengine.registry import init_default_scope
|
||||||
from mmengine.utils import mkdir_or_exist
|
from mmengine.utils import mkdir_or_exist
|
||||||
from sklearn.manifold import TSNE
|
from sklearn.manifold import TSNE
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from mmselfsup.apis import init_model
|
from mmselfsup.apis import init_model
|
||||||
from mmselfsup.registry import DATA_SAMPLERS, DATASETS
|
from mmselfsup.registry import DATA_SAMPLERS, DATASETS
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
@ -99,19 +99,15 @@ def parse_args():
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def post_process():
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# register all modules in mmselfsup into the registries
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
cfg = Config.fromfile(args.config)
|
cfg = Config.fromfile(args.config)
|
||||||
if args.cfg_options is not None:
|
if args.cfg_options is not None:
|
||||||
cfg.merge_from_dict(args.cfg_options)
|
cfg.merge_from_dict(args.cfg_options)
|
||||||
|
|
||||||
|
init_default_scope(cfg.get('default_scope', 'mmselfsup'))
|
||||||
|
|
||||||
# set cudnn_benchmark
|
# set cudnn_benchmark
|
||||||
if cfg.get('cudnn_benchmark', False):
|
if cfg.get('cudnn_benchmark', False):
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
|
@ -10,13 +10,13 @@ from mmengine.config import Config, DictAction
|
||||||
from mmengine.dist import get_rank, init_dist
|
from mmengine.dist import get_rank, init_dist
|
||||||
from mmengine.logging import MMLogger
|
from mmengine.logging import MMLogger
|
||||||
from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper
|
from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper
|
||||||
|
from mmengine.registry import init_default_scope
|
||||||
from mmengine.runner import Runner, load_checkpoint
|
from mmengine.runner import Runner, load_checkpoint
|
||||||
from mmengine.utils import mkdir_or_exist
|
from mmengine.utils import mkdir_or_exist
|
||||||
|
|
||||||
from mmselfsup.evaluation.functional import knn_eval
|
from mmselfsup.evaluation.functional import knn_eval
|
||||||
from mmselfsup.models.utils import Extractor
|
from mmselfsup.models.utils import Extractor
|
||||||
from mmselfsup.registry import MODELS
|
from mmselfsup.registry import MODELS
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
@ -72,14 +72,13 @@ def parse_args():
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# register all modules in mmselfsup into the registries
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
# load config
|
# load config
|
||||||
cfg = Config.fromfile(args.config)
|
cfg = Config.fromfile(args.config)
|
||||||
if args.cfg_options is not None:
|
if args.cfg_options is not None:
|
||||||
cfg.merge_from_dict(args.cfg_options)
|
cfg.merge_from_dict(args.cfg_options)
|
||||||
|
|
||||||
|
init_default_scope(cfg.get('default_scope', 'mmselfsup'))
|
||||||
|
|
||||||
# set cudnn_benchmark
|
# set cudnn_benchmark
|
||||||
if cfg.env_cfg.get('cudnn_benchmark', False):
|
if cfg.env_cfg.get('cudnn_benchmark', False):
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
|
@ -13,13 +13,13 @@ from mmengine.dataset import pseudo_collate, worker_init_fn
|
||||||
from mmengine.dist import get_rank, init_dist
|
from mmengine.dist import get_rank, init_dist
|
||||||
from mmengine.logging import MMLogger
|
from mmengine.logging import MMLogger
|
||||||
from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper
|
from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper
|
||||||
|
from mmengine.registry import init_default_scope
|
||||||
from mmengine.runner import load_checkpoint
|
from mmengine.runner import load_checkpoint
|
||||||
from mmengine.utils import mkdir_or_exist
|
from mmengine.utils import mkdir_or_exist
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from mmselfsup.models.utils import Extractor
|
from mmselfsup.models.utils import Extractor
|
||||||
from mmselfsup.registry import DATA_SAMPLERS, DATASETS, MODELS
|
from mmselfsup.registry import DATA_SAMPLERS, DATASETS, MODELS
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
@ -66,14 +66,13 @@ def parse_args():
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# register all modules in mmselfsup into the registries
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
# load config
|
# load config
|
||||||
cfg = Config.fromfile(args.config)
|
cfg = Config.fromfile(args.config)
|
||||||
if args.cfg_options is not None:
|
if args.cfg_options is not None:
|
||||||
cfg.merge_from_dict(args.cfg_options)
|
cfg.merge_from_dict(args.cfg_options)
|
||||||
|
|
||||||
|
init_default_scope(cfg.get('default_scope', 'mmselfsup'))
|
||||||
|
|
||||||
# set cudnn_benchmark
|
# set cudnn_benchmark
|
||||||
if cfg.env_cfg.get('cudnn_benchmark', False):
|
if cfg.env_cfg.get('cudnn_benchmark', False):
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
|
@ -6,8 +6,6 @@ import os.path as osp
|
||||||
from mmengine.config import Config, DictAction
|
from mmengine.config import Config, DictAction
|
||||||
from mmengine.runner import Runner
|
from mmengine.runner import Runner
|
||||||
|
|
||||||
from mmselfsup.utils import register_all_modules
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='Train a model')
|
parser = argparse.ArgumentParser(description='Train a model')
|
||||||
|
@ -51,10 +49,6 @@ def parse_args():
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# register all modules in mmselfsup into the registries
|
|
||||||
# do not init the default scope here because it will be init in the runner
|
|
||||||
register_all_modules(init_default_scope=False)
|
|
||||||
|
|
||||||
# load config
|
# load config
|
||||||
cfg = Config.fromfile(args.config)
|
cfg = Config.fromfile(args.config)
|
||||||
cfg.launcher = args.launcher
|
cfg.launcher = args.launcher
|
||||||
|
|
Loading…
Reference in New Issue