mirror of
https://github.com/open-mmlab/mmrazor.git
synced 2025-06-03 15:02:54 +08:00
[fix] fix mmcv mmengine (#242)
* align_with_mmcv_and_mmengine * fix_mmcv.fileio
This commit is contained in:
parent
afb95a40e7
commit
ba71abf357
@ -5,7 +5,7 @@ import re
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import mmcv
|
import mmengine
|
||||||
import wget
|
import wget
|
||||||
from modelindex.load_model_index import load
|
from modelindex.load_model_index import load
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
@ -75,7 +75,7 @@ def create_test_job_batch(commands, model_info, args, port):
|
|||||||
|
|
||||||
http_prefix = 'https://download.openmmlab.com/mmrazor/'
|
http_prefix = 'https://download.openmmlab.com/mmrazor/'
|
||||||
if 's3://' in args.checkpoint_root:
|
if 's3://' in args.checkpoint_root:
|
||||||
from mmcv.fileio import FileClient
|
from mmengine.fileio import FileClient
|
||||||
from petrel_client.common.exception import AccessDeniedError
|
from petrel_client.common.exception import AccessDeniedError
|
||||||
file_client = FileClient.infer_client(uri=args.checkpoint_root)
|
file_client = FileClient.infer_client(uri=args.checkpoint_root)
|
||||||
checkpoint = file_client.join_path(
|
checkpoint = file_client.join_path(
|
||||||
@ -171,7 +171,7 @@ def summary(args):
|
|||||||
if not latest_json.exists():
|
if not latest_json.exists():
|
||||||
print(f'{model_name} has no results.')
|
print(f'{model_name} has no results.')
|
||||||
continue
|
continue
|
||||||
latest_result = mmcv.load(latest_json, 'json')
|
latest_result = mmengine.load(latest_json, 'json')
|
||||||
|
|
||||||
expect_result = model_info.results[0].metrics
|
expect_result = model_info.results[0].metrics
|
||||||
summary_result = {
|
summary_result = {
|
||||||
@ -182,8 +182,8 @@ def summary(args):
|
|||||||
}
|
}
|
||||||
model_results[model_name] = summary_result
|
model_results[model_name] = summary_result
|
||||||
|
|
||||||
mmcv.fileio.dump(model_results,
|
mmengine.fileio.dump(model_results,
|
||||||
Path(args.work_dir) / 'summary.yml', 'yaml')
|
Path(args.work_dir) / 'summary.yml', 'yaml')
|
||||||
print(f'Summary results saved in {Path(args.work_dir)}/summary.yml')
|
print(f'Summary results saved in {Path(args.work_dir)}/summary.yml')
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import re
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import mmcv
|
import mmengine
|
||||||
from modelindex.load_model_index import load
|
from modelindex.load_model_index import load
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.syntax import Syntax
|
from rich.syntax import Syntax
|
||||||
@ -145,7 +145,7 @@ def summary(args):
|
|||||||
if not latest_json.exists():
|
if not latest_json.exists():
|
||||||
print(f'{model_name} has no results.')
|
print(f'{model_name} has no results.')
|
||||||
continue
|
continue
|
||||||
latest_result = mmcv.load(latest_json, 'json')
|
latest_result = mmengine.load(latest_json, 'json')
|
||||||
|
|
||||||
expect_result = model_info.results[0].metrics
|
expect_result = model_info.results[0].metrics
|
||||||
summary_result = {
|
summary_result = {
|
||||||
@ -156,8 +156,8 @@ def summary(args):
|
|||||||
}
|
}
|
||||||
model_results[model_name] = summary_result
|
model_results[model_name] = summary_result
|
||||||
|
|
||||||
mmcv.fileio.dump(model_results,
|
mmengine.fileio.dump(model_results,
|
||||||
Path(args.work_dir) / 'summary.yml', 'yaml')
|
Path(args.work_dir) / 'summary.yml', 'yaml')
|
||||||
print(f'Summary results saved in {Path(args.work_dir)}/summary.yml')
|
print(f'Summary results saved in {Path(args.work_dir)}/summary.yml')
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,8 +3,8 @@ from typing import Dict, List, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmcv.runner import load_checkpoint
|
|
||||||
from mmengine.optim import OPTIMIZERS, OptimWrapper
|
from mmengine.optim import OPTIMIZERS, OptimWrapper
|
||||||
|
from mmengine.runner import load_checkpoint
|
||||||
|
|
||||||
from mmrazor.models.utils import add_prefix, set_requires_grad
|
from mmrazor.models.utils import add_prefix, set_requires_grad
|
||||||
from mmrazor.registry import MODELS
|
from mmrazor.registry import MODELS
|
||||||
|
@ -2,9 +2,9 @@
|
|||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from mmcv.runner import load_checkpoint
|
|
||||||
from mmengine import BaseDataElement
|
from mmengine import BaseDataElement
|
||||||
from mmengine.model import BaseModel
|
from mmengine.model import BaseModel
|
||||||
|
from mmengine.runner import load_checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
|
|||||||
from mmcls.models.backbones.base_backbone import BaseBackbone
|
from mmcls.models.backbones.base_backbone import BaseBackbone
|
||||||
from mmcls.models.utils import make_divisible
|
from mmcls.models.utils import make_divisible
|
||||||
from mmcv.cnn import ConvModule
|
from mmcv.cnn import ConvModule
|
||||||
from mmcv.runner import Sequential
|
from mmengine.model import Sequential
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
@ -166,7 +166,7 @@ class SearchableMobileNet(BaseBackbone):
|
|||||||
mutable_cfg (dict): Config of mutable.
|
mutable_cfg (dict): Config of mutable.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
mmcv.runner.Sequential: The layer made.
|
mmengine.model.Sequential: The layer made.
|
||||||
"""
|
"""
|
||||||
layers = []
|
layers = []
|
||||||
for i in range(num_blocks):
|
for i in range(num_blocks):
|
||||||
|
@ -4,8 +4,9 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
|
|||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmcls.models.backbones.base_backbone import BaseBackbone
|
from mmcls.models.backbones.base_backbone import BaseBackbone
|
||||||
from mmcv.cnn import ConvModule, constant_init, normal_init
|
from mmcv.cnn import ConvModule
|
||||||
from mmcv.runner import ModuleList, Sequential
|
from mmengine.model import ModuleList, Sequential
|
||||||
|
from mmengine.model.utils import constant_init, normal_init
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
@ -143,7 +144,7 @@ class SearchableShuffleNetV2(BaseBackbone):
|
|||||||
mutable_cfg (dict): Config of mutable.
|
mutable_cfg (dict): Config of mutable.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
mmcv.runner.Sequential: The layer made.
|
mmengine.model.Sequential: The layer made.
|
||||||
"""
|
"""
|
||||||
layers = []
|
layers = []
|
||||||
for i in range(num_blocks):
|
for i in range(num_blocks):
|
||||||
|
@ -3,7 +3,7 @@ from abc import ABCMeta, abstractmethod
|
|||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from mmcv.runner import BaseModule
|
from mmengine.model import BaseModule
|
||||||
|
|
||||||
|
|
||||||
class BaseConnector(BaseModule, metaclass=ABCMeta):
|
class BaseConnector(BaseModule, metaclass=ABCMeta):
|
||||||
|
@ -3,7 +3,6 @@ from typing import Callable, Dict
|
|||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from mmcv.cnn.bricks.registry import CONV_LAYERS
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from mmrazor.models.mutables.base_mutable import BaseMutable
|
from mmrazor.models.mutables.base_mutable import BaseMutable
|
||||||
@ -12,7 +11,6 @@ from .dynamic_conv_mixins import (BigNasConvMixin, DynamicConvMixin,
|
|||||||
OFAConvMixin)
|
OFAConvMixin)
|
||||||
|
|
||||||
|
|
||||||
@CONV_LAYERS.register_module()
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
class DynamicConv2d(nn.Conv2d, DynamicConvMixin):
|
class DynamicConv2d(nn.Conv2d, DynamicConvMixin):
|
||||||
"""Dynamic Conv2d OP.
|
"""Dynamic Conv2d OP.
|
||||||
@ -65,7 +63,6 @@ class DynamicConv2d(nn.Conv2d, DynamicConvMixin):
|
|||||||
return self.forward_mixin(x)
|
return self.forward_mixin(x)
|
||||||
|
|
||||||
|
|
||||||
@CONV_LAYERS.register_module()
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
class BigNasConv2d(nn.Conv2d, BigNasConvMixin):
|
class BigNasConv2d(nn.Conv2d, BigNasConvMixin):
|
||||||
"""Conv2d used in BigNas.
|
"""Conv2d used in BigNas.
|
||||||
@ -118,7 +115,6 @@ class BigNasConv2d(nn.Conv2d, BigNasConvMixin):
|
|||||||
return self.forward_mixin(x)
|
return self.forward_mixin(x)
|
||||||
|
|
||||||
|
|
||||||
@CONV_LAYERS.register_module()
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
class OFAConv2d(nn.Conv2d, OFAConvMixin):
|
class OFAConv2d(nn.Conv2d, OFAConvMixin):
|
||||||
"""Conv2d used in `Once-for-All`.
|
"""Conv2d used in `Once-for-All`.
|
||||||
|
@ -3,11 +3,11 @@ from typing import Dict, Optional
|
|||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from mmcv.cnn.bricks.registry import NORM_LAYERS
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
from mmrazor.models.mutables.base_mutable import BaseMutable
|
from mmrazor.models.mutables.base_mutable import BaseMutable
|
||||||
|
from mmrazor.registry import MODELS
|
||||||
from .dynamic_mixins import DynamicBatchNormMixin
|
from .dynamic_mixins import DynamicBatchNormMixin
|
||||||
|
|
||||||
|
|
||||||
@ -85,7 +85,7 @@ class _DynamicBatchNorm(_BatchNorm, DynamicBatchNormMixin):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@NORM_LAYERS.register_module()
|
@MODELS.register_module()
|
||||||
class DynamicBatchNorm1d(_DynamicBatchNorm):
|
class DynamicBatchNorm1d(_DynamicBatchNorm):
|
||||||
"""Dynamic BatchNorm1d OP."""
|
"""Dynamic BatchNorm1d OP."""
|
||||||
|
|
||||||
@ -100,7 +100,7 @@ class DynamicBatchNorm1d(_DynamicBatchNorm):
|
|||||||
input.dim()))
|
input.dim()))
|
||||||
|
|
||||||
|
|
||||||
@NORM_LAYERS.register_module()
|
@MODELS.register_module()
|
||||||
class DynamicBatchNorm2d(_DynamicBatchNorm):
|
class DynamicBatchNorm2d(_DynamicBatchNorm):
|
||||||
"""Dynamic BatchNorm2d OP."""
|
"""Dynamic BatchNorm2d OP."""
|
||||||
|
|
||||||
@ -115,7 +115,7 @@ class DynamicBatchNorm2d(_DynamicBatchNorm):
|
|||||||
input.dim()))
|
input.dim()))
|
||||||
|
|
||||||
|
|
||||||
@NORM_LAYERS.register_module()
|
@MODELS.register_module()
|
||||||
class DynamicBatchNorm3d(_DynamicBatchNorm):
|
class DynamicBatchNorm3d(_DynamicBatchNorm):
|
||||||
"""Dynamic BatchNorm3d OP."""
|
"""Dynamic BatchNorm3d OP."""
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from mmcv.runner import BaseModule
|
from mmengine.model import BaseModule
|
||||||
|
|
||||||
from mmrazor.models.utils import get_module_device
|
from mmrazor.models.utils import get_module_device
|
||||||
|
|
||||||
|
@ -2,9 +2,9 @@
|
|||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from mmcls.data import ClsDataSample
|
||||||
from mmcls.evaluation import Accuracy
|
from mmcls.evaluation import Accuracy
|
||||||
from mmcls.models.heads import LinearClsHead
|
from mmcls.models.heads import LinearClsHead
|
||||||
from mmcls.structures import ClsDataSample
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from mmrazor.models.utils import add_prefix
|
from mmrazor.models.utils import add_prefix
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from mmcv.runner import get_dist_info
|
from mmengine.dist import get_dist_info
|
||||||
|
|
||||||
from mmrazor.registry import MODELS
|
from mmrazor.registry import MODELS
|
||||||
from ..ops import GatherTensors
|
from ..ops import GatherTensors
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Generic, Optional, TypeVar
|
from typing import Dict, Generic, Optional, TypeVar
|
||||||
|
|
||||||
from mmcv.runner import BaseModule
|
from mmengine.model import BaseModule
|
||||||
|
|
||||||
CHOICE_TYPE = TypeVar('CHOICE_TYPE')
|
CHOICE_TYPE = TypeVar('CHOICE_TYPE')
|
||||||
CHOSEN_TYPE = TypeVar('CHOSEN_TYPE')
|
CHOSEN_TYPE = TypeVar('CHOSEN_TYPE')
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Generic, Optional, Type, TypeVar
|
from typing import Dict, Generic, Optional, Type, TypeVar
|
||||||
|
|
||||||
from mmcv.runner import BaseModule
|
from mmengine.model import BaseModule
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
|
|
||||||
from ..mutables.base_mutable import BaseMutable
|
from ..mutables.base_mutable import BaseMutable
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from mmcv.runner import BaseModule
|
from mmengine.model import BaseModule
|
||||||
|
|
||||||
|
|
||||||
class BaseOP(BaseModule):
|
class BaseOP(BaseModule):
|
||||||
|
@ -3,7 +3,7 @@ import functools
|
|||||||
from types import FunctionType
|
from types import FunctionType
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from mmcv.utils import import_modules_from_strings
|
from mmengine.utils import import_modules_from_strings
|
||||||
|
|
||||||
from mmrazor.registry import TASK_UTILS
|
from mmrazor.registry import TASK_UTILS
|
||||||
from .distill_delivery import DistillDelivery
|
from .distill_delivery import DistillDelivery
|
||||||
|
@ -3,7 +3,7 @@ import functools
|
|||||||
from types import FunctionType, ModuleType
|
from types import FunctionType, ModuleType
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from mmcv.utils import import_modules_from_strings
|
from mmengine.utils import import_modules_from_strings
|
||||||
|
|
||||||
from mmrazor.registry import TASK_UTILS
|
from mmrazor.registry import TASK_UTILS
|
||||||
from .distill_delivery import DistillDelivery
|
from .distill_delivery import DistillDelivery
|
||||||
|
@ -3,7 +3,7 @@ import functools
|
|||||||
from types import FunctionType, ModuleType
|
from types import FunctionType, ModuleType
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
from mmcv.utils import import_modules_from_strings
|
from mmengine.utils import import_modules_from_strings
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from mmrazor.registry import TASK_UTILS
|
from mmrazor.registry import TASK_UTILS
|
||||||
|
@ -3,7 +3,7 @@ import functools
|
|||||||
from types import FunctionType, ModuleType
|
from types import FunctionType, ModuleType
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
from mmcv.utils import import_modules_from_strings
|
from mmengine.utils import import_modules_from_strings
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from mmrazor.registry import TASK_UTILS
|
from mmrazor.registry import TASK_UTILS
|
||||||
|
@ -3,7 +3,7 @@ import copy
|
|||||||
import re
|
import re
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
from mmcv import ConfigDict
|
from mmengine import ConfigDict
|
||||||
from torch.nn import Conv2d, Linear
|
from torch.nn import Conv2d, Linear
|
||||||
from torch.nn.modules import GroupNorm
|
from torch.nn.modules import GroupNorm
|
||||||
from torch.nn.modules.batchnorm import _NormBase
|
from torch.nn.modules.batchnorm import _NormBase
|
||||||
|
@ -1,2 +1,2 @@
|
|||||||
mmcls
|
mmcls
|
||||||
mmcv-full>=1.3.13
|
mmcv-full>=2.0.0rc0
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
from mmcv import ConfigDict
|
from mmengine import ConfigDict
|
||||||
|
|
||||||
from mmrazor.structures import DistillDeliveryManager
|
from mmrazor.structures import DistillDeliveryManager
|
||||||
|
|
||||||
|
@ -5,8 +5,8 @@ from unittest import TestCase
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mmrazor.structures.graph import ModuleGraph
|
from mmrazor.structures.graph import ModuleGraph
|
||||||
from ...data.models import (AddCatModel, ConcatModel, LineModel,
|
from tests.data.models import (AddCatModel, ConcatModel, LineModel,
|
||||||
MultiConcatModel, MultiConcatModel2, ResBlock)
|
MultiConcatModel, MultiConcatModel2, ResBlock)
|
||||||
|
|
||||||
sys.setrecursionlimit(int(1e8))
|
sys.setrecursionlimit(int(1e8))
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from mmcv import ConfigDict
|
from mmengine import ConfigDict
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from toy_mod import Toy
|
from toy_mod import Toy
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ from unittest.mock import Mock
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from mmcls.structures import ClsDataSample
|
from mmcls.data import ClsDataSample
|
||||||
from mmengine.optim import build_optim_wrapper
|
from mmengine.optim import build_optim_wrapper
|
||||||
|
|
||||||
from mmrazor import digit_version
|
from mmrazor import digit_version
|
||||||
|
@ -3,7 +3,7 @@ import copy
|
|||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from mmcv import ConfigDict
|
from mmengine import ConfigDict
|
||||||
from mmengine.optim import build_optim_wrapper
|
from mmengine.optim import build_optim_wrapper
|
||||||
|
|
||||||
from mmrazor.models import DAFLDataFreeDistillation, DataFreeDistillation
|
from mmrazor.models import DAFLDataFreeDistillation, DataFreeDistillation
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from mmcv import ConfigDict
|
from mmengine import ConfigDict
|
||||||
|
|
||||||
from mmrazor.models import SelfDistill
|
from mmrazor.models import SelfDistill
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import copy
|
|||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from mmcv import ConfigDict
|
from mmengine import ConfigDict
|
||||||
from toy_models import ToyStudent
|
from toy_models import ToyStudent
|
||||||
|
|
||||||
from mmrazor.models import SingleTeacherDistill
|
from mmrazor.models import SingleTeacherDistill
|
||||||
|
@ -8,7 +8,7 @@ from unittest.mock import Mock
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from mmcls.structures import ClsDataSample
|
from mmcls.data import ClsDataSample
|
||||||
from mmengine import fileio
|
from mmengine import fileio
|
||||||
from mmengine.optim import build_optim_wrapper
|
from mmengine.optim import build_optim_wrapper
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
import copy
|
import copy
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
from mmcv import ConfigDict
|
from mmengine import ConfigDict
|
||||||
|
|
||||||
from mmrazor.models import BYOTDistiller
|
from mmrazor.models import BYOTDistiller
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
import copy
|
import copy
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
from mmcv import ConfigDict
|
from mmengine import ConfigDict
|
||||||
|
|
||||||
from mmrazor.models import ConfigurableDistiller
|
from mmrazor.models import ConfigurableDistiller
|
||||||
|
|
||||||
|
@ -4,8 +4,8 @@ import unittest
|
|||||||
from os.path import dirname
|
from os.path import dirname
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from mmcls.data import ClsDataSample
|
||||||
from mmcls.models import * # noqa: F401,F403
|
from mmcls.models import * # noqa: F401,F403
|
||||||
from mmcls.structures import ClsDataSample
|
|
||||||
|
|
||||||
from mmrazor import digit_version
|
from mmrazor import digit_version
|
||||||
from mmrazor.models.mutables import SlimmableMutableChannel
|
from mmrazor.models.mutables import SlimmableMutableChannel
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from mmcv import Config, DictAction
|
from mmengine import Config, DictAction
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
@ -4,9 +4,9 @@ import datetime
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import mmcv
|
import mmengine
|
||||||
import torch
|
import torch
|
||||||
from mmcv import digit_version
|
from mmengine import digit_version
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -70,9 +70,9 @@ def process_checkpoint(in_file: str,
|
|||||||
print(f'Successfully generated the publish-ckpt as {final_ckpt_file}.')
|
print(f'Successfully generated the publish-ckpt as {final_ckpt_file}.')
|
||||||
|
|
||||||
if subnet_cfg_file is not None:
|
if subnet_cfg_file is not None:
|
||||||
subnet_cfg = mmcv.fileio.load(subnet_cfg_file)
|
subnet_cfg = mmengine.fileio.load(subnet_cfg_file)
|
||||||
final_subnet_cfg_file = f'{final_file_prefix}_subnet_cfg.yaml'
|
final_subnet_cfg_file = f'{final_file_prefix}_subnet_cfg.yaml'
|
||||||
mmcv.fileio.dump(subnet_cfg, final_subnet_cfg_file)
|
mmengine.fileio.dump(subnet_cfg, final_subnet_cfg_file)
|
||||||
print(f'Successfully generated the publish-subnet-cfg as \
|
print(f'Successfully generated the publish-subnet-cfg as \
|
||||||
{final_subnet_cfg_file}.')
|
{final_subnet_cfg_file}.')
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user