mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Refactory] MMSegmentation Content
This commit is contained in:
parent
fba91957c0
commit
4b76f277a6
1
.gitignore
vendored
1
.gitignore
vendored
@ -105,7 +105,6 @@ venv.bak/
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
|
||||
data
|
||||
.vscode
|
||||
.idea
|
||||
|
||||
|
@ -145,7 +145,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from mmseg.apis import inference_model, init_model, show_result_pyplot\n",
|
||||
"from mmseg.core.evaluation import get_palette"
|
||||
"from mmseg.utils import get_palette"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -2,7 +2,7 @@
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from mmseg.apis import inference_model, init_model, show_result_pyplot
|
||||
from mmseg.core.evaluation import get_palette
|
||||
from mmseg.utils import get_palette
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -21,7 +21,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from mmseg.apis import init_model, inference_model, show_result_pyplot\n",
|
||||
"from mmseg.core.evaluation import get_palette"
|
||||
"from mmseg.utils import get_palette"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -4,7 +4,7 @@ from argparse import ArgumentParser
|
||||
import cv2
|
||||
|
||||
from mmseg.apis import inference_model, init_model
|
||||
from mmseg.core.evaluation import get_palette
|
||||
from mmseg.utils import get_palette
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -5,7 +5,7 @@ import torch
|
||||
from mmcv.parallel import collate, scatter
|
||||
from mmcv.runner import load_checkpoint
|
||||
|
||||
from mmseg.datasets.pipelines import Compose
|
||||
from mmseg.datasets.transforms import Compose
|
||||
from mmseg.models import build_segmentor
|
||||
|
||||
|
||||
|
@ -1,9 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .builder import build_optimizer
|
||||
from .data_structures import * # noqa: F401, F403
|
||||
from .evaluation import * # noqa: F401, F403
|
||||
from .optimizers import * # noqa: F401, F403
|
||||
from .seg import * # noqa: F401, F403
|
||||
from .utils import * # noqa: F401, F403
|
||||
|
||||
__all__ = ['build_optimizer']
|
@ -1,18 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
|
||||
from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
|
||||
|
||||
|
||||
def build_optimizer(model, cfg):
|
||||
optim_wrapper_cfg = copy.deepcopy(cfg)
|
||||
constructor_type = optim_wrapper_cfg.pop('constructor',
|
||||
'DefaultOptimWrapperConstructor')
|
||||
paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
|
||||
optim_wrapper_builder = OPTIM_WRAPPER_CONSTRUCTORS.build(
|
||||
dict(
|
||||
type=constructor_type,
|
||||
optim_wrapper_cfg=optim_wrapper_cfg,
|
||||
paramwise_cfg=paramwise_cfg))
|
||||
optim_wrapper = optim_wrapper_builder(model)
|
||||
return optim_wrapper
|
@ -1,4 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .seg_data_sample import SegDataSample
|
||||
|
||||
__all__ = ['SegDataSample']
|
@ -1,4 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .class_names import get_classes, get_palette
|
||||
|
||||
__all__ = ['get_classes', 'get_palette']
|
@ -1,5 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .base_pixel_sampler import BasePixelSampler
|
||||
from .ohem_pixel_sampler import OHEMPixelSampler
|
||||
|
||||
__all__ = ['BasePixelSampler', 'OHEMPixelSampler']
|
@ -1,12 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .dist_util import check_dist_init, sync_random_seed
|
||||
from .misc import add_prefix, stack_batch
|
||||
from .typing import (ConfigType, ForwardResults, MultiConfig, OptConfigType,
|
||||
OptMultiConfig, OptSampleList, SampleList, TensorDict,
|
||||
TensorList)
|
||||
|
||||
__all__ = [
|
||||
'add_prefix', 'check_dist_init', 'sync_random_seed', 'stack_batch',
|
||||
'ConfigType', 'OptConfigType', 'MultiConfig', 'OptMultiConfig',
|
||||
'SampleList', 'OptSampleList', 'TensorDict', 'TensorList', 'ForwardResults'
|
||||
]
|
@ -1,46 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcv.runner import get_dist_info
|
||||
|
||||
|
||||
def check_dist_init():
|
||||
return dist.is_available() and dist.is_initialized()
|
||||
|
||||
|
||||
def sync_random_seed(seed=None, device='cuda'):
|
||||
"""Make sure different ranks share the same seed. All workers must call
|
||||
this function, otherwise it will deadlock. This method is generally used in
|
||||
`DistributedSampler`, because the seed should be identical across all
|
||||
processes in the distributed group.
|
||||
|
||||
In distributed sampling, different ranks should sample non-overlapped
|
||||
data in the dataset. Therefore, this function is used to make sure that
|
||||
each rank shuffles the data indices in the same order based
|
||||
on the same seed. Then different ranks could use different indices
|
||||
to select non-overlapped data from the same data list.
|
||||
|
||||
Args:
|
||||
seed (int, Optional): The seed. Default to None.
|
||||
device (str): The device where the seed will be put on.
|
||||
Default to 'cuda'.
|
||||
Returns:
|
||||
int: Seed to be used.
|
||||
"""
|
||||
|
||||
if seed is None:
|
||||
seed = np.random.randint(2**31)
|
||||
assert isinstance(seed, int)
|
||||
|
||||
rank, world_size = get_dist_info()
|
||||
|
||||
if world_size == 1:
|
||||
return seed
|
||||
|
||||
if rank == 0:
|
||||
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
|
||||
else:
|
||||
random_num = torch.tensor(0, dtype=torch.int32, device=device)
|
||||
dist.broadcast(random_num, src=0)
|
||||
return random_num.item()
|
@ -1,108 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmseg.core.utils.typing import SampleList
|
||||
|
||||
|
||||
def add_prefix(inputs, prefix):
|
||||
"""Add prefix for dict.
|
||||
|
||||
Args:
|
||||
inputs (dict): The input dict with str keys.
|
||||
prefix (str): The prefix to add.
|
||||
|
||||
Returns:
|
||||
|
||||
dict: The dict with keys updated with ``prefix``.
|
||||
"""
|
||||
|
||||
outputs = dict()
|
||||
for name, value in inputs.items():
|
||||
outputs[f'{prefix}.{name}'] = value
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def stack_batch(inputs: List[torch.Tensor],
|
||||
batch_data_samples: Optional[SampleList] = None,
|
||||
size: Optional[tuple] = None,
|
||||
size_divisor: Optional[int] = None,
|
||||
pad_val: Union[int, float] = 0,
|
||||
seg_pad_val: Union[int, float] = 255) -> torch.Tensor:
|
||||
"""Stack multiple inputs to form a batch and pad the images and gt_sem_segs
|
||||
to the max shape use the right bottom padding mode.
|
||||
|
||||
Args:
|
||||
inputs (List[Tensor]): The input multiple tensors. each is a
|
||||
CHW 3D-tensor.
|
||||
batch_data_samples (list[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as `gt_sem_seg`.
|
||||
size (tuple, optional): Fixed padding size.
|
||||
size_divisor (int, optional): The divisor of padded size.
|
||||
pad_val (int, float): The padding value. Defaults to 0
|
||||
seg_pad_val (int, float): The padding value. Defaults to 255
|
||||
|
||||
Returns:
|
||||
Tensor: The 4D-tensor.
|
||||
batch_data_samples (list[:obj:`SegDataSample`]): After the padding of
|
||||
the gt_seg_map.
|
||||
"""
|
||||
assert isinstance(inputs, list), \
|
||||
f'Expected input type to be list, but got {type(inputs)}'
|
||||
assert len(set([tensor.ndim for tensor in inputs])) == 1, \
|
||||
f'Expected the dimensions of all inputs must be the same, ' \
|
||||
f'but got {[tensor.ndim for tensor in inputs]}'
|
||||
assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
|
||||
f'but got {inputs[0].ndim}'
|
||||
assert len(set([tensor.shape[0] for tensor in inputs])) == 1, \
|
||||
f'Expected the channels of all inputs must be the same, ' \
|
||||
f'but got {[tensor.shape[0] for tensor in inputs]}'
|
||||
|
||||
# only one of size and size_divisor should be valid
|
||||
assert (size is not None) ^ (size_divisor is not None), \
|
||||
'only one of size and size_divisor should be valid'
|
||||
|
||||
padded_inputs = []
|
||||
padded_samples = []
|
||||
inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs]
|
||||
max_size = np.stack(inputs_sizes).max(0)
|
||||
if size_divisor is not None and size_divisor > 1:
|
||||
# the last two dims are H,W, both subject to divisibility requirement
|
||||
max_size = (max_size +
|
||||
(size_divisor - 1)) // size_divisor * size_divisor
|
||||
|
||||
for i in range(len(inputs)):
|
||||
tensor = inputs[i]
|
||||
if size is not None:
|
||||
width = max(size[-1] - tensor.shape[-1], 0)
|
||||
height = max(size[-2] - tensor.shape[-2], 0)
|
||||
# (padding_left, padding_right, padding_top, padding_bottom)
|
||||
padding_size = (0, width, 0, height)
|
||||
elif size_divisor is not None:
|
||||
width = max(max_size[-1] - tensor.shape[-1], 0)
|
||||
height = max(max_size[-2] - tensor.shape[-2], 0)
|
||||
padding_size = (0, width, 0, height)
|
||||
else:
|
||||
padding_size = [0, 0, 0, 0]
|
||||
|
||||
# pad img
|
||||
pad_img = F.pad(tensor, padding_size, value=pad_val)
|
||||
padded_inputs.append(pad_img)
|
||||
# pad gt_sem_seg
|
||||
if batch_data_samples is not None:
|
||||
data_sample = batch_data_samples[i]
|
||||
gt_sem_seg = data_sample.gt_sem_seg.data
|
||||
del data_sample.gt_sem_seg.data
|
||||
data_sample.gt_sem_seg.data = F.pad(
|
||||
gt_sem_seg, padding_size, value=seg_pad_val)
|
||||
data_sample.set_metainfo(
|
||||
{'pad_shape': data_sample.gt_sem_seg.shape})
|
||||
padded_samples.append(data_sample)
|
||||
else:
|
||||
padded_samples = None
|
||||
|
||||
return torch.stack(padded_inputs, dim=0), padded_samples
|
8
mmseg/data/__init__.py
Normal file
8
mmseg/data/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .sampler import BasePixelSampler, OHEMPixelSampler, build_pixel_sampler
|
||||
from .seg_data_sample import SegDataSample
|
||||
|
||||
__all__ = [
|
||||
'SegDataSample', 'BasePixelSampler', 'OHEMPixelSampler',
|
||||
'build_pixel_sampler'
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .base_pixel_sampler import BasePixelSampler
|
||||
from .builder import build_pixel_sampler
|
||||
from .sampler import BasePixelSampler, OHEMPixelSampler
|
||||
from .ohem_pixel_sampler import OHEMPixelSampler
|
||||
|
||||
__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
|
@ -3,8 +3,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import PIXEL_SAMPLERS
|
||||
from .base_pixel_sampler import BasePixelSampler
|
||||
from .builder import PIXEL_SAMPLERS
|
||||
|
||||
|
||||
@PIXEL_SAMPLERS.register_module()
|
@ -1,9 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# flake8: noqa
|
||||
import warnings
|
||||
|
||||
from .formatting import *
|
||||
|
||||
warnings.warn('DeprecationWarning: mmseg.datasets.pipelines.formating will be '
|
||||
'deprecated in 2021, please replace it with '
|
||||
'mmseg.datasets.pipelines.formatting.')
|
@ -1,4 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .distributed_sampler import DistributedSampler
|
||||
|
||||
__all__ = ['DistributedSampler']
|
@ -1,73 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from __future__ import division
|
||||
from typing import Iterator, Optional
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DistributedSampler as _DistributedSampler
|
||||
|
||||
from mmseg.core.utils import sync_random_seed
|
||||
from mmseg.registry import DATA_SAMPLERS
|
||||
|
||||
|
||||
@DATA_SAMPLERS.register_module()
|
||||
class DistributedSampler(_DistributedSampler):
|
||||
"""DistributedSampler inheriting from
|
||||
`torch.utils.data.DistributedSampler`.
|
||||
|
||||
Args:
|
||||
datasets (Dataset): the dataset will be loaded.
|
||||
num_replicas (int, optional): Number of processes participating in
|
||||
distributed training. By default, world_size is retrieved from the
|
||||
current distributed group.
|
||||
rank (int, optional): Rank of the current process within num_replicas.
|
||||
By default, rank is retrieved from the current distributed group.
|
||||
shuffle (bool): If True (default), sampler will shuffle the indices.
|
||||
seed (int): random seed used to shuffle the sampler if
|
||||
:attr:`shuffle=True`. This number should be identical across all
|
||||
processes in the distributed group. Default: ``0``.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset: Dataset,
|
||||
num_replicas: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
shuffle: bool = True,
|
||||
seed=0) -> None:
|
||||
super().__init__(
|
||||
dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
|
||||
|
||||
# In distributed sampling, different ranks should sample
|
||||
# non-overlapped data in the dataset. Therefore, this function
|
||||
# is used to make sure that each rank shuffles the data indices
|
||||
# in the same order based on the same seed. Then different ranks
|
||||
# could use different indices to select non-overlapped data from the
|
||||
# same data list.
|
||||
self.seed = sync_random_seed(seed)
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
"""
|
||||
Yields:
|
||||
Iterator: iterator of indices for rank.
|
||||
"""
|
||||
# deterministically shuffle based on epoch
|
||||
if self.shuffle:
|
||||
g = torch.Generator()
|
||||
# When :attr:`shuffle=True`, this ensures all replicas
|
||||
# use a different random ordering for each epoch.
|
||||
# Otherwise, the next iteration of this sampler will
|
||||
# yield the same ordering.
|
||||
g.manual_seed(self.epoch + self.seed)
|
||||
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
||||
else:
|
||||
indices = torch.arange(len(self.dataset)).tolist()
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
indices += indices[:(self.total_size - len(indices))]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
@ -5,7 +5,7 @@ from mmcv.transforms import to_tensor
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from mmengine.data import PixelData
|
||||
|
||||
from mmseg.core import SegDataSample
|
||||
from mmseg.data import SegDataSample
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
||||
|
@ -3,10 +3,10 @@ import json
|
||||
import warnings
|
||||
|
||||
from mmengine.dist import get_dist_info
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.optim import DefaultOptimWrapperConstructor
|
||||
|
||||
from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
|
||||
from mmseg.utils import get_root_logger
|
||||
|
||||
|
||||
def get_layer_id_for_convnext(var_name, max_layer_id):
|
||||
@ -119,15 +119,14 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
|
||||
in place.
|
||||
module (nn.Module): The module to be added.
|
||||
"""
|
||||
logger = get_root_logger()
|
||||
|
||||
parameter_groups = {}
|
||||
logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}')
|
||||
print_log(f'self.paramwise_cfg is {self.paramwise_cfg}')
|
||||
num_layers = self.paramwise_cfg.get('num_layers') + 2
|
||||
decay_rate = self.paramwise_cfg.get('decay_rate')
|
||||
decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
|
||||
logger.info('Build LearningRateDecayOptimizerConstructor '
|
||||
f'{decay_type} {decay_rate} - {num_layers}')
|
||||
print_log('Build LearningRateDecayOptimizerConstructor '
|
||||
f'{decay_type} {decay_rate} - {num_layers}')
|
||||
weight_decay = self.base_wd
|
||||
for name, param in module.named_parameters():
|
||||
if not param.requires_grad:
|
||||
@ -143,17 +142,17 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
|
||||
if 'ConvNeXt' in module.backbone.__class__.__name__:
|
||||
layer_id = get_layer_id_for_convnext(
|
||||
name, self.paramwise_cfg.get('num_layers'))
|
||||
logger.info(f'set param {name} as id {layer_id}')
|
||||
print_log(f'set param {name} as id {layer_id}')
|
||||
elif 'BEiT' in module.backbone.__class__.__name__ or \
|
||||
'MAE' in module.backbone.__class__.__name__:
|
||||
layer_id = get_layer_id_for_vit(name, num_layers)
|
||||
logger.info(f'set param {name} as id {layer_id}')
|
||||
print_log(f'set param {name} as id {layer_id}')
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
elif decay_type == 'stage_wise':
|
||||
if 'ConvNeXt' in module.backbone.__class__.__name__:
|
||||
layer_id = get_stage_id_for_convnext(name, num_layers)
|
||||
logger.info(f'set param {name} as id {layer_id}')
|
||||
print_log(f'set param {name} as id {layer_id}')
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
group_name = f'layer_{layer_id}_{group_name}'
|
||||
@ -182,7 +181,7 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
|
||||
'lr': parameter_groups[key]['lr'],
|
||||
'weight_decay': parameter_groups[key]['weight_decay'],
|
||||
}
|
||||
logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
|
||||
print_log(f'Param groups = {json.dumps(to_display, indent=2)}')
|
||||
params.extend(parameter_groups.values())
|
||||
|
||||
|
@ -14,7 +14,6 @@ from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import get_root_logger
|
||||
from ..utils import PatchEmbed
|
||||
from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer
|
||||
|
||||
@ -500,9 +499,8 @@ class BEiT(BaseModule):
|
||||
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg.get('type') == 'Pretrained'):
|
||||
logger = get_root_logger()
|
||||
checkpoint = _load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
state_dict = self.resize_rel_pos_embed(checkpoint)
|
||||
self.load_state_dict(state_dict, False)
|
||||
elif self.init_cfg is not None:
|
||||
|
@ -9,7 +9,6 @@ from mmcv.runner import ModuleList, _load_checkpoint
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import get_root_logger
|
||||
from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer
|
||||
|
||||
|
||||
@ -180,9 +179,8 @@ class MAE(BEiT):
|
||||
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg.get('type') == 'Pretrained'):
|
||||
logger = get_root_logger()
|
||||
checkpoint = _load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
state_dict = self.resize_rel_pos_embed(checkpoint)
|
||||
state_dict = self.resize_abs_pos_embed(state_dict)
|
||||
self.load_state_dict(state_dict, False)
|
||||
|
@ -14,9 +14,9 @@ from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
|
||||
from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
|
||||
load_state_dict)
|
||||
from mmcv.utils import to_2tuple
|
||||
from mmengine.logging import print_log
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ...utils import get_root_logger
|
||||
from ..utils.embed import PatchEmbed, PatchMerging
|
||||
|
||||
|
||||
@ -662,11 +662,10 @@ class SwinTransformer(BaseModule):
|
||||
param.requires_grad = False
|
||||
|
||||
def init_weights(self):
|
||||
logger = get_root_logger()
|
||||
if self.init_cfg is None:
|
||||
logger.warn(f'No pre-trained weights for '
|
||||
f'{self.__class__.__name__}, '
|
||||
f'training start from scratch')
|
||||
print_log(f'No pre-trained weights for '
|
||||
f'{self.__class__.__name__}, '
|
||||
f'training start from scratch')
|
||||
if self.use_abs_pos_embed:
|
||||
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
||||
for m in self.modules():
|
||||
@ -680,7 +679,7 @@ class SwinTransformer(BaseModule):
|
||||
f'`init_cfg` in ' \
|
||||
f'{self.__class__.__name__} '
|
||||
ckpt = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
if 'state_dict' in ckpt:
|
||||
_state_dict = ckpt['state_dict']
|
||||
elif 'model' in ckpt:
|
||||
@ -705,7 +704,7 @@ class SwinTransformer(BaseModule):
|
||||
N1, L, C1 = absolute_pos_embed.size()
|
||||
N2, C2, H, W = self.absolute_pos_embed.size()
|
||||
if N1 != N2 or C1 != C2 or L != H * W:
|
||||
logger.warning('Error in loading absolute_pos_embed, pass')
|
||||
print_log('Error in loading absolute_pos_embed, pass')
|
||||
else:
|
||||
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
|
||||
N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
|
||||
@ -721,7 +720,7 @@ class SwinTransformer(BaseModule):
|
||||
L1, nH1 = table_pretrained.size()
|
||||
L2, nH2 = table_current.size()
|
||||
if nH1 != nH2:
|
||||
logger.warning(f'Error in loading {table_key}, pass')
|
||||
print_log(f'Error in loading {table_key}, pass')
|
||||
elif L1 != L2:
|
||||
S1 = int(L1**0.5)
|
||||
S2 = int(L2**0.5)
|
||||
@ -733,7 +732,7 @@ class SwinTransformer(BaseModule):
|
||||
nH2, L2).permute(1, 0).contiguous()
|
||||
|
||||
# load state_dict
|
||||
load_state_dict(self, state_dict, strict=False, logger=logger)
|
||||
load_state_dict(self, state_dict, strict=False, logger=None)
|
||||
|
||||
def forward(self, x):
|
||||
x, hw_shape = self.patch_embed(x)
|
||||
|
@ -11,12 +11,12 @@ from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
|
||||
load_state_dict)
|
||||
from mmengine.logging import print_log
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import get_root_logger
|
||||
from ..utils import PatchEmbed
|
||||
|
||||
|
||||
@ -293,9 +293,8 @@ class VisionTransformer(BaseModule):
|
||||
def init_weights(self):
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg.get('type') == 'Pretrained'):
|
||||
logger = get_root_logger()
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
@ -304,9 +303,9 @@ class VisionTransformer(BaseModule):
|
||||
|
||||
if 'pos_embed' in state_dict.keys():
|
||||
if self.pos_embed.shape != state_dict['pos_embed'].shape:
|
||||
logger.info(msg=f'Resize the pos_embed shape from '
|
||||
f'{state_dict["pos_embed"].shape} to '
|
||||
f'{self.pos_embed.shape}')
|
||||
print_log(msg=f'Resize the pos_embed shape from '
|
||||
f'{state_dict["pos_embed"].shape} to '
|
||||
f'{self.pos_embed.shape}')
|
||||
h, w = self.img_size
|
||||
pos_size = int(
|
||||
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
|
||||
@ -315,7 +314,7 @@ class VisionTransformer(BaseModule):
|
||||
(h // self.patch_size, w // self.patch_size),
|
||||
(pos_size, pos_size), self.interpolate_mode)
|
||||
|
||||
load_state_dict(self, state_dict, strict=False, logger=logger)
|
||||
load_state_dict(self, state_dict, strict=False, logger=None)
|
||||
elif self.init_cfg is not None:
|
||||
super(VisionTransformer, self).init_weights()
|
||||
else:
|
||||
|
@ -6,9 +6,8 @@ import torch
|
||||
from mmengine.model import BaseDataPreprocessor
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.core import stack_batch
|
||||
from mmseg.core.utils import OptSampleList
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import OptSampleList, stack_batch
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
@ -4,7 +4,7 @@ from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.core.utils import ConfigType
|
||||
from mmseg.utils import ConfigType
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -6,9 +6,8 @@ import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, Scale
|
||||
from torch import Tensor, nn
|
||||
|
||||
from mmseg.core import add_prefix
|
||||
from mmseg.core.utils import SampleList
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList, add_prefix
|
||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
@ -7,10 +7,9 @@ import torch.nn as nn
|
||||
from mmcv.runner import BaseModule
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.core import build_pixel_sampler
|
||||
from mmseg.core.utils import SampleList
|
||||
from mmseg.core.utils.typing import ConfigType
|
||||
from mmseg.data import build_pixel_sampler
|
||||
from mmseg.ops import resize
|
||||
from mmseg.utils import ConfigType, SampleList
|
||||
from ..builder import build_loss
|
||||
from ..losses import accuracy
|
||||
|
||||
|
@ -7,10 +7,9 @@ import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.core.utils import SampleList
|
||||
from mmseg.core.utils.typing import ConfigType
|
||||
from mmseg.ops import Encoding, resize
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import ConfigType, SampleList
|
||||
from ..builder import build_loss
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
@ -8,12 +8,12 @@ from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import (FFN, TRANSFORMER_LAYER,
|
||||
MultiheadAttention,
|
||||
build_transformer_layer)
|
||||
from mmengine.logging import print_log
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.core.utils import SampleList
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import get_root_logger
|
||||
from mmseg.utils import SampleList
|
||||
|
||||
|
||||
@TRANSFORMER_LAYER.register_module()
|
||||
@ -276,8 +276,7 @@ class KernelUpdateHead(nn.Module):
|
||||
# the weight and bias of the layer norm
|
||||
pass
|
||||
if self.kernel_init:
|
||||
logger = get_root_logger()
|
||||
logger.info(
|
||||
print_log(
|
||||
'mask kernel in mask head is normal initialized by std 0.01')
|
||||
nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)
|
||||
|
||||
|
@ -12,9 +12,9 @@ except ModuleNotFoundError:
|
||||
|
||||
from typing import List
|
||||
|
||||
from mmseg.core.utils import SampleList
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList
|
||||
from ..losses import accuracy
|
||||
from .cascade_decode_head import BaseCascadeDecodeHead
|
||||
|
||||
|
@ -4,9 +4,9 @@ import torch.nn.functional as F
|
||||
from mmengine.data import PixelData
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.core.data_structures.seg_data_sample import SegDataSample
|
||||
from mmseg.core.utils import SampleList
|
||||
from mmseg.data import SegDataSample
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
|
@ -6,10 +6,10 @@ from mmengine.data import PixelData
|
||||
from mmengine.model import BaseModel
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.core import SegDataSample
|
||||
from mmseg.core.utils import (ForwardResults, OptConfigType, OptMultiConfig,
|
||||
OptSampleList, SampleList)
|
||||
from mmseg.data import SegDataSample
|
||||
from mmseg.ops import resize
|
||||
from mmseg.utils import (ForwardResults, OptConfigType, OptMultiConfig,
|
||||
OptSampleList, SampleList)
|
||||
|
||||
|
||||
class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
||||
|
@ -3,10 +3,9 @@ from typing import List, Optional
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from mmseg.core import add_prefix
|
||||
from mmseg.core.utils import ConfigType, OptSampleList, SampleList
|
||||
from mmseg.core.utils.typing import OptConfigType, OptMultiConfig
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
|
||||
OptSampleList, SampleList, add_prefix)
|
||||
from .encoder_decoder import EncoderDecoder
|
||||
|
||||
|
||||
|
@ -6,10 +6,9 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.core import add_prefix
|
||||
from mmseg.core.utils import (ConfigType, OptConfigType, OptMultiConfig,
|
||||
OptSampleList, SampleList)
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
|
||||
OptSampleList, SampleList, add_prefix)
|
||||
from .base import BaseSegmentor
|
||||
|
||||
|
||||
|
@ -1,10 +1,29 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# yapf: disable
|
||||
from .class_names import (ade_classes, ade_palette, cityscapes_classes,
|
||||
cityscapes_palette, cocostuff_classes,
|
||||
cocostuff_palette, dataset_aliases, get_classes,
|
||||
get_palette, isaid_classes, isaid_palette,
|
||||
loveda_classes, loveda_palette, potsdam_classes,
|
||||
potsdam_palette, stare_classes, stare_palette,
|
||||
vaihingen_classes, vaihingen_palette, voc_classes,
|
||||
voc_palette)
|
||||
# yapf: enable
|
||||
from .collect_env import collect_env
|
||||
from .logger import get_root_logger
|
||||
from .misc import find_latest_checkpoint
|
||||
from .set_env import register_all_modules, setup_multi_processes
|
||||
from .misc import add_prefix, stack_batch
|
||||
from .set_env import register_all_modules
|
||||
from .typing import (ConfigType, ForwardResults, MultiConfig, OptConfigType,
|
||||
OptMultiConfig, OptSampleList, SampleList, TensorDict,
|
||||
TensorList)
|
||||
|
||||
__all__ = [
|
||||
'get_root_logger', 'collect_env', 'find_latest_checkpoint',
|
||||
'setup_multi_processes', 'register_all_modules'
|
||||
'collect_env', 'register_all_modules', 'stack_batch', 'add_prefix',
|
||||
'ConfigType', 'OptConfigType', 'MultiConfig', 'OptMultiConfig',
|
||||
'SampleList', 'OptSampleList', 'TensorDict', 'TensorList',
|
||||
'ForwardResults', 'cityscapes_classes', 'ade_classes', 'voc_classes',
|
||||
'cocostuff_classes', 'loveda_classes', 'potsdam_classes',
|
||||
'vaihingen_classes', 'isaid_classes', 'stare_classes',
|
||||
'cityscapes_palette', 'ade_palette', 'voc_palette', 'cocostuff_palette',
|
||||
'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette',
|
||||
'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette'
|
||||
]
|
||||
|
@ -1,28 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
|
||||
from mmcv.utils import get_logger
|
||||
|
||||
|
||||
def get_root_logger(log_file=None, log_level=logging.INFO):
|
||||
"""Get the root logger.
|
||||
|
||||
The logger will be initialized if it has not been initialized. By default a
|
||||
StreamHandler will be added. If `log_file` is specified, a FileHandler will
|
||||
also be added. The name of the root logger is the top-level package name,
|
||||
e.g., "mmseg".
|
||||
|
||||
Args:
|
||||
log_file (str | None): The log filename. If specified, a FileHandler
|
||||
will be added to the root logger.
|
||||
log_level (int): The root logger level. Note that only the process of
|
||||
rank 0 is affected, while other processes will set the level to
|
||||
"Error" and be silent most of the time.
|
||||
|
||||
Returns:
|
||||
logging.Logger: The root logger.
|
||||
"""
|
||||
|
||||
logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level)
|
||||
|
||||
return logger
|
@ -1,41 +1,108 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import glob
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .typing import SampleList
|
||||
|
||||
|
||||
def find_latest_checkpoint(path, suffix='pth'):
|
||||
"""This function is for finding the latest checkpoint.
|
||||
|
||||
It will be used when automatically resume, modified from
|
||||
https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py
|
||||
def add_prefix(inputs, prefix):
|
||||
"""Add prefix for dict.
|
||||
|
||||
Args:
|
||||
path (str): The path to find checkpoints.
|
||||
suffix (str): File extension for the checkpoint. Defaults to pth.
|
||||
inputs (dict): The input dict with str keys.
|
||||
prefix (str): The prefix to add.
|
||||
|
||||
Returns:
|
||||
latest_path(str | None): File path of the latest checkpoint.
|
||||
"""
|
||||
if not osp.exists(path):
|
||||
warnings.warn("The path of the checkpoints doesn't exist.")
|
||||
return None
|
||||
if osp.exists(osp.join(path, f'latest.{suffix}')):
|
||||
return osp.join(path, f'latest.{suffix}')
|
||||
|
||||
checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
|
||||
if len(checkpoints) == 0:
|
||||
warnings.warn('The are no checkpoints in the path')
|
||||
return None
|
||||
latest = -1
|
||||
latest_path = ''
|
||||
for checkpoint in checkpoints:
|
||||
if len(checkpoint) < len(latest_path):
|
||||
continue
|
||||
# `count` is iteration number, as checkpoints are saved as
|
||||
# 'iter_xx.pth' or 'epoch_xx.pth' and xx is iteration number.
|
||||
count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
|
||||
if count > latest:
|
||||
latest = count
|
||||
latest_path = checkpoint
|
||||
return latest_path
|
||||
dict: The dict with keys updated with ``prefix``.
|
||||
"""
|
||||
|
||||
outputs = dict()
|
||||
for name, value in inputs.items():
|
||||
outputs[f'{prefix}.{name}'] = value
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def stack_batch(inputs: List[torch.Tensor],
|
||||
batch_data_samples: Optional[SampleList] = None,
|
||||
size: Optional[tuple] = None,
|
||||
size_divisor: Optional[int] = None,
|
||||
pad_val: Union[int, float] = 0,
|
||||
seg_pad_val: Union[int, float] = 255) -> torch.Tensor:
|
||||
"""Stack multiple inputs to form a batch and pad the images and gt_sem_segs
|
||||
to the max shape use the right bottom padding mode.
|
||||
|
||||
Args:
|
||||
inputs (List[Tensor]): The input multiple tensors. each is a
|
||||
CHW 3D-tensor.
|
||||
batch_data_samples (list[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as `gt_sem_seg`.
|
||||
size (tuple, optional): Fixed padding size.
|
||||
size_divisor (int, optional): The divisor of padded size.
|
||||
pad_val (int, float): The padding value. Defaults to 0
|
||||
seg_pad_val (int, float): The padding value. Defaults to 255
|
||||
|
||||
Returns:
|
||||
Tensor: The 4D-tensor.
|
||||
batch_data_samples (list[:obj:`SegDataSample`]): After the padding of
|
||||
the gt_seg_map.
|
||||
"""
|
||||
assert isinstance(inputs, list), \
|
||||
f'Expected input type to be list, but got {type(inputs)}'
|
||||
assert len(set([tensor.ndim for tensor in inputs])) == 1, \
|
||||
f'Expected the dimensions of all inputs must be the same, ' \
|
||||
f'but got {[tensor.ndim for tensor in inputs]}'
|
||||
assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
|
||||
f'but got {inputs[0].ndim}'
|
||||
assert len(set([tensor.shape[0] for tensor in inputs])) == 1, \
|
||||
f'Expected the channels of all inputs must be the same, ' \
|
||||
f'but got {[tensor.shape[0] for tensor in inputs]}'
|
||||
|
||||
# only one of size and size_divisor should be valid
|
||||
assert (size is not None) ^ (size_divisor is not None), \
|
||||
'only one of size and size_divisor should be valid'
|
||||
|
||||
padded_inputs = []
|
||||
padded_samples = []
|
||||
inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs]
|
||||
max_size = np.stack(inputs_sizes).max(0)
|
||||
if size_divisor is not None and size_divisor > 1:
|
||||
# the last two dims are H,W, both subject to divisibility requirement
|
||||
max_size = (max_size +
|
||||
(size_divisor - 1)) // size_divisor * size_divisor
|
||||
|
||||
for i in range(len(inputs)):
|
||||
tensor = inputs[i]
|
||||
if size is not None:
|
||||
width = max(size[-1] - tensor.shape[-1], 0)
|
||||
height = max(size[-2] - tensor.shape[-2], 0)
|
||||
# (padding_left, padding_right, padding_top, padding_bottom)
|
||||
padding_size = (0, width, 0, height)
|
||||
elif size_divisor is not None:
|
||||
width = max(max_size[-1] - tensor.shape[-1], 0)
|
||||
height = max(max_size[-2] - tensor.shape[-2], 0)
|
||||
padding_size = (0, width, 0, height)
|
||||
else:
|
||||
padding_size = [0, 0, 0, 0]
|
||||
|
||||
# pad img
|
||||
pad_img = F.pad(tensor, padding_size, value=pad_val)
|
||||
padded_inputs.append(pad_img)
|
||||
# pad gt_sem_seg
|
||||
if batch_data_samples is not None:
|
||||
data_sample = batch_data_samples[i]
|
||||
gt_sem_seg = data_sample.gt_sem_seg.data
|
||||
del data_sample.gt_sem_seg.data
|
||||
data_sample.gt_sem_seg.data = F.pad(
|
||||
gt_sem_seg, padding_size, value=seg_pad_val)
|
||||
data_sample.set_metainfo(
|
||||
{'pad_shape': data_sample.gt_sem_seg.shape})
|
||||
padded_samples.append(data_sample)
|
||||
else:
|
||||
padded_samples = None
|
||||
|
||||
return torch.stack(padded_inputs, dim=0), padded_samples
|
||||
|
@ -1,62 +1,9 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import datetime
|
||||
import os
|
||||
import platform
|
||||
import warnings
|
||||
|
||||
import cv2
|
||||
import torch.multiprocessing as mp
|
||||
from mmengine import DefaultScope
|
||||
|
||||
from ..utils import get_root_logger
|
||||
|
||||
|
||||
def setup_multi_processes(cfg):
|
||||
"""Setup multi-processing environment variables."""
|
||||
logger = get_root_logger()
|
||||
|
||||
# set multi-process start method
|
||||
if platform.system() != 'Windows':
|
||||
mp_start_method = cfg.get('mp_start_method', None)
|
||||
current_method = mp.get_start_method(allow_none=True)
|
||||
if mp_start_method in ('fork', 'spawn', 'forkserver'):
|
||||
logger.info(
|
||||
f'Multi-processing start method `{mp_start_method}` is '
|
||||
f'different from the previous setting `{current_method}`.'
|
||||
f'It will be force set to `{mp_start_method}`.')
|
||||
mp.set_start_method(mp_start_method, force=True)
|
||||
else:
|
||||
logger.info(
|
||||
f'Multi-processing start method is `{mp_start_method}`')
|
||||
|
||||
# disable opencv multithreading to avoid system being overloaded
|
||||
opencv_num_threads = cfg.get('opencv_num_threads', None)
|
||||
if isinstance(opencv_num_threads, int):
|
||||
logger.info(f'OpenCV num_threads is `{opencv_num_threads}`')
|
||||
cv2.setNumThreads(opencv_num_threads)
|
||||
else:
|
||||
logger.info(f'OpenCV num_threads is `{cv2.getNumThreads}')
|
||||
|
||||
if cfg.data.workers_per_gpu > 1:
|
||||
# setup OMP threads
|
||||
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
|
||||
omp_num_threads = cfg.get('omp_num_threads', None)
|
||||
if 'OMP_NUM_THREADS' not in os.environ:
|
||||
if isinstance(omp_num_threads, int):
|
||||
logger.info(f'OMP num threads is {omp_num_threads}')
|
||||
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
|
||||
else:
|
||||
logger.info(f'OMP num threads is {os.environ["OMP_NUM_THREADS"] }')
|
||||
|
||||
# setup MKL threads
|
||||
if 'MKL_NUM_THREADS' not in os.environ:
|
||||
mkl_num_threads = cfg.get('mkl_num_threads', None)
|
||||
if isinstance(mkl_num_threads, int):
|
||||
logger.info(f'MKL num threads is {mkl_num_threads}')
|
||||
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
|
||||
else:
|
||||
logger.info(f'MKL num threads is {os.environ["MKL_NUM_THREADS"]}')
|
||||
|
||||
|
||||
def register_all_modules(init_default_scope: bool = True) -> None:
|
||||
"""Register all modules in mmseg into the registries.
|
||||
@ -69,9 +16,9 @@ def register_all_modules(init_default_scope: bool = True) -> None:
|
||||
to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
|
||||
Defaults to True.
|
||||
""" # noqa
|
||||
import mmseg.core # noqa: F401,F403
|
||||
import mmseg.data # noqa: F401,F403
|
||||
import mmseg.datasets # noqa: F401,F403
|
||||
import mmseg.datasets.pipelines # noqa: F401,F403
|
||||
import mmseg.engine # noqa: F401,F403
|
||||
import mmseg.metrics # noqa: F401,F403
|
||||
import mmseg.models # noqa: F401,F403
|
||||
|
||||
|
@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
import torch
|
||||
from mmengine.config import ConfigDict
|
||||
|
||||
from ..data_structures import SegDataSample
|
||||
from mmseg.data import SegDataSample
|
||||
|
||||
# Type hint of config data
|
||||
ConfigType = Union[ConfigDict, dict]
|
@ -1,7 +0,0 @@
|
||||
[pytest]
|
||||
addopts = --xdoctest --xdoctest-style=auto
|
||||
norecursedirs = .git ignore build __pycache__ data docker docs .eggs
|
||||
|
||||
filterwarnings= default
|
||||
ignore:.*No cfgstr given in Cacher constructor or call.*:Warning
|
||||
ignore:.*Define the __nice__ method for.*:Warning
|
@ -60,72 +60,72 @@ def test_config_build_segmentor():
|
||||
_check_decode_head(head_config, segmentor.decode_head)
|
||||
|
||||
|
||||
def test_config_data_pipeline():
|
||||
"""Test whether the data pipeline is valid and can process corner cases.
|
||||
# def test_config_data_pipeline():
|
||||
# """Test whether the data pipeline is valid and can process corner cases.
|
||||
|
||||
CommandLine:
|
||||
xdoctest -m tests/test_config.py test_config_build_data_pipeline
|
||||
"""
|
||||
import numpy as np
|
||||
from mmcv import Config
|
||||
# CommandLine:
|
||||
# xdoctest -m tests/test_config.py test_config_build_data_pipeline
|
||||
# """
|
||||
# import numpy as np
|
||||
# from mmcv import Config
|
||||
|
||||
from mmseg.datasets.pipelines import Compose
|
||||
# from mmseg.datasets.transforms import Compose
|
||||
|
||||
config_dpath = _get_config_directory()
|
||||
print('Found config_dpath = {!r}'.format(config_dpath))
|
||||
# config_dpath = _get_config_directory()
|
||||
# print('Found config_dpath = {!r}'.format(config_dpath))
|
||||
|
||||
import glob
|
||||
config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
|
||||
config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
|
||||
config_names = [relpath(p, config_dpath) for p in config_fpaths]
|
||||
# import glob
|
||||
# config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
|
||||
# config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
|
||||
# config_names = [relpath(p, config_dpath) for p in config_fpaths]
|
||||
|
||||
print('Using {} config files'.format(len(config_names)))
|
||||
# print('Using {} config files'.format(len(config_names)))
|
||||
|
||||
for config_fname in config_names:
|
||||
config_fpath = join(config_dpath, config_fname)
|
||||
print(
|
||||
'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
|
||||
config_mod = Config.fromfile(config_fpath)
|
||||
# for config_fname in config_names:
|
||||
# config_fpath = join(config_dpath, config_fname)
|
||||
# print(
|
||||
# 'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
|
||||
# config_mod = Config.fromfile(config_fpath)
|
||||
|
||||
# remove loading pipeline
|
||||
load_img_pipeline = config_mod.train_pipeline.pop(0)
|
||||
to_float32 = load_img_pipeline.get('to_float32', False)
|
||||
config_mod.train_pipeline.pop(0)
|
||||
config_mod.test_pipeline.pop(0)
|
||||
# remove loading annotation in test pipeline
|
||||
config_mod.test_pipeline.pop(1)
|
||||
# # remove loading pipeline
|
||||
# load_img_pipeline = config_mod.train_pipeline.pop(0)
|
||||
# to_float32 = load_img_pipeline.get('to_float32', False)
|
||||
# config_mod.train_pipeline.pop(0)
|
||||
# config_mod.test_pipeline.pop(0)
|
||||
# # remove loading annotation in test pipeline
|
||||
# config_mod.test_pipeline.pop(1)
|
||||
|
||||
train_pipeline = Compose(config_mod.train_pipeline)
|
||||
test_pipeline = Compose(config_mod.test_pipeline)
|
||||
# train_pipeline = Compose(config_mod.train_pipeline)
|
||||
# test_pipeline = Compose(config_mod.test_pipeline)
|
||||
|
||||
img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
|
||||
if to_float32:
|
||||
img = img.astype(np.float32)
|
||||
seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
|
||||
# img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
|
||||
# if to_float32:
|
||||
# img = img.astype(np.float32)
|
||||
# seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
|
||||
|
||||
results = dict(
|
||||
filename='test_img.png',
|
||||
ori_filename='test_img.png',
|
||||
img=img,
|
||||
img_shape=img.shape,
|
||||
ori_shape=img.shape,
|
||||
gt_seg_map=seg)
|
||||
results['seg_fields'] = ['gt_seg_map']
|
||||
# results = dict(
|
||||
# filename='test_img.png',
|
||||
# ori_filename='test_img.png',
|
||||
# img=img,
|
||||
# img_shape=img.shape,
|
||||
# ori_shape=img.shape,
|
||||
# gt_seg_map=seg)
|
||||
# results['seg_fields'] = ['gt_seg_map']
|
||||
|
||||
print('Test training data pipeline: \n{!r}'.format(train_pipeline))
|
||||
output_results = train_pipeline(results)
|
||||
assert output_results is not None
|
||||
# print('Test training data pipeline: \n{!r}'.format(train_pipeline))
|
||||
# output_results = train_pipeline(results)
|
||||
# assert output_results is not None
|
||||
|
||||
results = dict(
|
||||
filename='test_img.png',
|
||||
ori_filename='test_img.png',
|
||||
img=img,
|
||||
img_shape=img.shape,
|
||||
ori_shape=img.shape,
|
||||
)
|
||||
print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
|
||||
output_results = test_pipeline(results)
|
||||
assert output_results is not None
|
||||
# results = dict(
|
||||
# filename='test_img.png',
|
||||
# ori_filename='test_img.png',
|
||||
# img=img,
|
||||
# img_shape=img.shape,
|
||||
# ori_shape=img.shape,
|
||||
# )
|
||||
# print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
|
||||
# output_results = test_pipeline(results)
|
||||
# assert output_results is not None
|
||||
|
||||
|
||||
def _check_decode_head(decode_head_cfg, decode_head):
|
||||
|
@ -6,7 +6,7 @@ import pytest
|
||||
import torch
|
||||
from mmengine.data import PixelData
|
||||
|
||||
from mmseg.core import SegDataSample
|
||||
from mmseg.data import SegDataSample
|
||||
|
||||
|
||||
def _equal(a, b):
|
@ -6,11 +6,11 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from mmseg.core.evaluation import get_classes, get_palette
|
||||
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
|
||||
COCOStuffDataset, CustomDataset, ISPRSDataset,
|
||||
LoveDADataset, PascalVOCDataset, PotsdamDataset,
|
||||
iSAIDDataset)
|
||||
from mmseg.utils import get_classes, get_palette
|
||||
|
||||
|
||||
def test_classes():
|
||||
|
@ -6,8 +6,8 @@ import unittest
|
||||
import numpy as np
|
||||
from mmengine.data import BaseDataElement
|
||||
|
||||
from mmseg.core import SegDataSample
|
||||
from mmseg.datasets.pipelines import PackSegInputs
|
||||
from mmseg.data import SegDataSample
|
||||
from mmseg.datasets.transforms import PackSegInputs
|
||||
|
||||
|
||||
class TestPackSegInputs(unittest.TestCase):
|
||||
|
@ -7,7 +7,7 @@ import mmcv
|
||||
import numpy as np
|
||||
from mmcv.transforms import LoadImageFromFile
|
||||
|
||||
from mmseg.datasets.pipelines import LoadAnnotations
|
||||
from mmseg.datasets.transforms import LoadAnnotations
|
||||
|
||||
|
||||
class TestLoading(object):
|
||||
|
@ -7,7 +7,7 @@ import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.datasets.pipelines import PhotoMetricDistortion, RandomCrop
|
||||
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
||||
|
||||
|
@ -4,7 +4,7 @@ import os.path as osp
|
||||
import mmcv
|
||||
import pytest
|
||||
|
||||
from mmseg.datasets.pipelines import * # noqa
|
||||
from mmseg.datasets.transforms import * # noqa
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
||||
|
||||
|
@ -4,10 +4,13 @@ import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.optim.optimizer import build_optim_wrapper
|
||||
|
||||
from mmseg.core.builder import build_optimizer
|
||||
from mmseg.core.optimizers.layer_decay_optimizer_constructor import \
|
||||
from mmseg.engine.optimizers.layer_decay_optimizer_constructor import \
|
||||
LearningRateDecayOptimizerConstructor
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
base_lr = 1
|
||||
decay_rate = 2
|
||||
@ -221,7 +224,7 @@ def test_learning_rate_decay_optimizer_constructor():
|
||||
optimizer=optimizer_cfg,
|
||||
paramwise_cfg=stagewise_paramwise_cfg,
|
||||
constructor='LearningRateDecayOptimizerConstructor')
|
||||
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
|
||||
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
||||
expected_stage_wise_lr_wd_convnext)
|
||||
# layerwise decay
|
||||
@ -232,7 +235,7 @@ def test_learning_rate_decay_optimizer_constructor():
|
||||
optimizer=optimizer_cfg,
|
||||
paramwise_cfg=layerwise_paramwise_cfg,
|
||||
constructor='LearningRateDecayOptimizerConstructor')
|
||||
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
|
||||
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
||||
expected_layer_wise_lr_wd_convnext)
|
||||
|
||||
@ -247,7 +250,7 @@ def test_learning_rate_decay_optimizer_constructor():
|
||||
optimizer=optimizer_cfg,
|
||||
paramwise_cfg=layerwise_paramwise_cfg,
|
||||
constructor='LearningRateDecayOptimizerConstructor')
|
||||
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
|
||||
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
||||
expected_layer_wise_wd_lr_beit)
|
||||
|
||||
@ -274,7 +277,7 @@ def test_learning_rate_decay_optimizer_constructor():
|
||||
optimizer=optimizer_cfg,
|
||||
paramwise_cfg=layerwise_paramwise_cfg,
|
||||
constructor='LearningRateDecayOptimizerConstructor')
|
||||
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
|
||||
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
||||
expected_layer_wise_wd_lr_beit)
|
||||
|
||||
@ -291,7 +294,7 @@ def test_beit_layer_decay_optimizer_constructor():
|
||||
paramwise_cfg=paramwise_cfg,
|
||||
optimizer=dict(
|
||||
type='AdamW', lr=1, betas=(0.9, 0.999), weight_decay=0.05))
|
||||
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
|
||||
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||
# optimizer = optim_wrapper_builder(model)
|
||||
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
||||
expected_layer_wise_wd_lr_beit)
|
@ -1,8 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmseg.core.builder import build_optimizer
|
||||
from mmengine.optim import build_optim_wrapper
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
@ -29,6 +28,6 @@ def test_build_optimizer():
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum))
|
||||
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
|
||||
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||
# test whether optimizer is successfully built from parent.
|
||||
assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
|
@ -5,7 +5,7 @@ import numpy as np
|
||||
import torch
|
||||
from mmengine.data import BaseDataElement, PixelData
|
||||
|
||||
from mmseg.core import SegDataSample
|
||||
from mmseg.data import SegDataSample
|
||||
from mmseg.metrics import CitysMetric
|
||||
|
||||
|
||||
|
@ -5,7 +5,7 @@ import numpy as np
|
||||
import torch
|
||||
from mmengine.data import BaseDataElement, PixelData
|
||||
|
||||
from mmseg.core import SegDataSample
|
||||
from mmseg.data import SegDataSample
|
||||
from mmseg.metrics import IoUMetric
|
||||
|
||||
|
||||
|
@ -4,7 +4,7 @@ from unittest import TestCase
|
||||
import torch
|
||||
from mmengine.data import PixelData
|
||||
|
||||
from mmseg.core import SegDataSample
|
||||
from mmseg.data import SegDataSample
|
||||
from mmseg.models import SegDataPreProcessor
|
||||
|
||||
|
||||
|
@ -13,7 +13,7 @@ from mmcv.cnn.utils import revert_sync_batchnorm
|
||||
from mmengine.data import PixelData
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.core import SegDataSample
|
||||
from mmseg.data import SegDataSample
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
@ -2,7 +2,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.core import OHEMPixelSampler
|
||||
from mmseg.data import OHEMPixelSampler
|
||||
from mmseg.models.decode_heads import FCNHead
|
||||
|
||||
|
||||
|
@ -1,40 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
|
||||
from mmseg.utils import find_latest_checkpoint
|
||||
|
||||
|
||||
def test_find_latest_checkpoint():
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
# no checkpoints in the path
|
||||
path = tempdir
|
||||
latest = find_latest_checkpoint(path)
|
||||
assert latest is None
|
||||
|
||||
# The path doesn't exist
|
||||
path = osp.join(tempdir, 'none')
|
||||
latest = find_latest_checkpoint(path)
|
||||
assert latest is None
|
||||
|
||||
# test when latest.pth exists
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
with open(osp.join(tempdir, 'latest.pth'), 'w') as f:
|
||||
f.write('latest')
|
||||
path = tempdir
|
||||
latest = find_latest_checkpoint(path)
|
||||
assert latest == osp.join(tempdir, 'latest.pth')
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
for iter in range(1600, 160001, 1600):
|
||||
with open(osp.join(tempdir, f'iter_{iter}.pth'), 'w') as f:
|
||||
f.write(f'iter_{iter}.pth')
|
||||
latest = find_latest_checkpoint(tempdir)
|
||||
assert latest == osp.join(tempdir, 'iter_160000.pth')
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
for epoch in range(1, 21):
|
||||
with open(osp.join(tempdir, f'epoch_{epoch}.pth'), 'w') as f:
|
||||
f.write(f'epoch_{epoch}.pth')
|
||||
latest = find_latest_checkpoint(tempdir)
|
||||
assert latest == osp.join(tempdir, 'epoch_20.pth')
|
@ -1,92 +1,11 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import datetime
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
from unittest import TestCase
|
||||
|
||||
import cv2
|
||||
import pytest
|
||||
from mmcv import Config
|
||||
from mmengine import DefaultScope
|
||||
|
||||
from mmseg.utils import register_all_modules, setup_multi_processes
|
||||
|
||||
|
||||
@pytest.mark.parametrize('workers_per_gpu', (0, 2))
|
||||
@pytest.mark.parametrize(('valid', 'env_cfg'), [(True,
|
||||
dict(
|
||||
mp_start_method='fork',
|
||||
opencv_num_threads=0,
|
||||
omp_num_threads=1,
|
||||
mkl_num_threads=1)),
|
||||
(False,
|
||||
dict(
|
||||
mp_start_method=1,
|
||||
opencv_num_threads=0.1,
|
||||
omp_num_threads='s',
|
||||
mkl_num_threads='1'))])
|
||||
def test_setup_multi_processes(workers_per_gpu, valid, env_cfg):
|
||||
# temp save system setting
|
||||
sys_start_mehod = mp.get_start_method(allow_none=True)
|
||||
sys_cv_threads = cv2.getNumThreads()
|
||||
# pop and temp save system env vars
|
||||
sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None)
|
||||
sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None)
|
||||
|
||||
config = dict(data=dict(workers_per_gpu=workers_per_gpu))
|
||||
config.update(env_cfg)
|
||||
cfg = Config(config)
|
||||
setup_multi_processes(cfg)
|
||||
|
||||
# test when cfg is valid and workers_per_gpu > 0
|
||||
# setup_multi_processes will work
|
||||
if valid and workers_per_gpu > 0:
|
||||
# test config without setting env
|
||||
|
||||
assert os.getenv('OMP_NUM_THREADS') == str(env_cfg['omp_num_threads'])
|
||||
assert os.getenv('MKL_NUM_THREADS') == str(env_cfg['mkl_num_threads'])
|
||||
# when set to 0, the num threads will be 1
|
||||
assert cv2.getNumThreads() == env_cfg[
|
||||
'opencv_num_threads'] if env_cfg['opencv_num_threads'] > 0 else 1
|
||||
if platform.system() != 'Windows':
|
||||
assert mp.get_start_method() == env_cfg['mp_start_method']
|
||||
|
||||
# revert setting to avoid affecting other programs
|
||||
if sys_start_mehod:
|
||||
mp.set_start_method(sys_start_mehod, force=True)
|
||||
cv2.setNumThreads(sys_cv_threads)
|
||||
if sys_omp_threads:
|
||||
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
|
||||
else:
|
||||
os.environ.pop('OMP_NUM_THREADS')
|
||||
if sys_mkl_threads:
|
||||
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
|
||||
else:
|
||||
os.environ.pop('MKL_NUM_THREADS')
|
||||
|
||||
elif valid and workers_per_gpu == 0:
|
||||
|
||||
if platform.system() != 'Windows':
|
||||
assert mp.get_start_method() == env_cfg['mp_start_method']
|
||||
assert cv2.getNumThreads() == env_cfg[
|
||||
'opencv_num_threads'] if env_cfg['opencv_num_threads'] > 0 else 1
|
||||
assert 'OMP_NUM_THREADS' not in os.environ
|
||||
assert 'MKL_NUM_THREADS' not in os.environ
|
||||
if sys_start_mehod:
|
||||
mp.set_start_method(sys_start_mehod, force=True)
|
||||
cv2.setNumThreads(sys_cv_threads)
|
||||
if sys_omp_threads:
|
||||
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
|
||||
if sys_mkl_threads:
|
||||
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
|
||||
|
||||
else:
|
||||
assert mp.get_start_method() == sys_start_mehod
|
||||
assert cv2.getNumThreads() == sys_cv_threads
|
||||
assert 'OMP_NUM_THREADS' not in os.environ
|
||||
assert 'MKL_NUM_THREADS' not in os.environ
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
class TestSetupEnv(TestCase):
|
||||
|
@ -1,340 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import warnings
|
||||
from typing import Any, Iterable
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.parallel import MMDataParallel
|
||||
from mmcv.runner import get_dist_info
|
||||
from mmcv.utils import DictAction
|
||||
|
||||
# from mmseg.apis import single_gpu_test
|
||||
from mmseg.datasets import build_dataloader, build_dataset
|
||||
from mmseg.models.segmentors.base import BaseSegmentor
|
||||
from mmseg.ops import resize
|
||||
|
||||
single_gpu_test = None
|
||||
|
||||
|
||||
class ONNXRuntimeSegmentor(BaseSegmentor):
|
||||
|
||||
def __init__(self, onnx_file: str, cfg: Any, device_id: int):
|
||||
super(ONNXRuntimeSegmentor, self).__init__()
|
||||
import onnxruntime as ort
|
||||
|
||||
# get the custom op path
|
||||
ort_custom_op_path = ''
|
||||
try:
|
||||
from mmcv.ops import get_onnxruntime_op_path
|
||||
ort_custom_op_path = get_onnxruntime_op_path()
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
warnings.warn('If input model has custom op from mmcv, \
|
||||
you may have to build mmcv with ONNXRuntime from source.')
|
||||
session_options = ort.SessionOptions()
|
||||
# register custom op for onnxruntime
|
||||
if osp.exists(ort_custom_op_path):
|
||||
session_options.register_custom_ops_library(ort_custom_op_path)
|
||||
sess = ort.InferenceSession(onnx_file, session_options)
|
||||
providers = ['CPUExecutionProvider']
|
||||
options = [{}]
|
||||
is_cuda_available = ort.get_device() == 'GPU'
|
||||
if is_cuda_available:
|
||||
providers.insert(0, 'CUDAExecutionProvider')
|
||||
options.insert(0, {'device_id': device_id})
|
||||
|
||||
sess.set_providers(providers, options)
|
||||
|
||||
self.sess = sess
|
||||
self.device_id = device_id
|
||||
self.io_binding = sess.io_binding()
|
||||
self.output_names = [_.name for _ in sess.get_outputs()]
|
||||
for name in self.output_names:
|
||||
self.io_binding.bind_output(name)
|
||||
self.cfg = cfg
|
||||
self.test_mode = cfg.model.test_cfg.mode
|
||||
self.is_cuda_available = is_cuda_available
|
||||
|
||||
def extract_feat(self, imgs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def encode_decode(self, img, img_metas):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def forward_train(self, imgs, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
|
||||
**kwargs) -> list:
|
||||
if not self.is_cuda_available:
|
||||
img = img.detach().cpu()
|
||||
elif self.device_id >= 0:
|
||||
img = img.cuda(self.device_id)
|
||||
device_type = img.device.type
|
||||
self.io_binding.bind_input(
|
||||
name='input',
|
||||
device_type=device_type,
|
||||
device_id=self.device_id,
|
||||
element_type=np.float32,
|
||||
shape=img.shape,
|
||||
buffer_ptr=img.data_ptr())
|
||||
self.sess.run_with_iobinding(self.io_binding)
|
||||
seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
|
||||
# whole might support dynamic reshape
|
||||
ori_shape = img_meta[0]['ori_shape']
|
||||
if not (ori_shape[0] == seg_pred.shape[-2]
|
||||
and ori_shape[1] == seg_pred.shape[-1]):
|
||||
seg_pred = torch.from_numpy(seg_pred).float()
|
||||
seg_pred = resize(
|
||||
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
||||
seg_pred = seg_pred.long().detach().cpu().numpy()
|
||||
seg_pred = seg_pred[0]
|
||||
seg_pred = list(seg_pred)
|
||||
return seg_pred
|
||||
|
||||
def aug_test(self, imgs, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
|
||||
class TensorRTSegmentor(BaseSegmentor):
|
||||
|
||||
def __init__(self, trt_file: str, cfg: Any, device_id: int):
|
||||
super(TensorRTSegmentor, self).__init__()
|
||||
from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
|
||||
try:
|
||||
load_tensorrt_plugin()
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
warnings.warn('If input model has custom op from mmcv, \
|
||||
you may have to build mmcv with TensorRT from source.')
|
||||
model = TRTWraper(
|
||||
trt_file, input_names=['input'], output_names=['output'])
|
||||
|
||||
self.model = model
|
||||
self.device_id = device_id
|
||||
self.cfg = cfg
|
||||
self.test_mode = cfg.model.test_cfg.mode
|
||||
|
||||
def extract_feat(self, imgs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def encode_decode(self, img, img_metas):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def forward_train(self, imgs, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
|
||||
**kwargs) -> list:
|
||||
with torch.cuda.device(self.device_id), torch.no_grad():
|
||||
seg_pred = self.model({'input': img})['output']
|
||||
seg_pred = seg_pred.detach().cpu().numpy()
|
||||
# whole might support dynamic reshape
|
||||
ori_shape = img_meta[0]['ori_shape']
|
||||
if not (ori_shape[0] == seg_pred.shape[-2]
|
||||
and ori_shape[1] == seg_pred.shape[-1]):
|
||||
seg_pred = torch.from_numpy(seg_pred).float()
|
||||
seg_pred = resize(
|
||||
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
||||
seg_pred = seg_pred.long().detach().cpu().numpy()
|
||||
seg_pred = seg_pred[0]
|
||||
seg_pred = list(seg_pred)
|
||||
return seg_pred
|
||||
|
||||
def aug_test(self, imgs, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description='mmseg backend test (and eval)')
|
||||
parser.add_argument('config', help='test config file path')
|
||||
parser.add_argument('model', help='Input model file')
|
||||
parser.add_argument(
|
||||
'--backend',
|
||||
help='Backend of the model.',
|
||||
choices=['onnxruntime', 'tensorrt'])
|
||||
parser.add_argument('--out', help='output result file in pickle format')
|
||||
parser.add_argument(
|
||||
'--format-only',
|
||||
action='store_true',
|
||||
help='Format the output results without perform evaluation. It is'
|
||||
'useful when you want to format the result to a specific format and '
|
||||
'submit it to the test server')
|
||||
parser.add_argument(
|
||||
'--eval',
|
||||
type=str,
|
||||
nargs='+',
|
||||
help='evaluation metrics, which depends on the dataset, e.g., "mIoU"'
|
||||
' for generic datasets, and "cityscapes" for Cityscapes')
|
||||
parser.add_argument('--show', action='store_true', help='show results')
|
||||
parser.add_argument(
|
||||
'--show-dir', help='directory where painted images will be saved')
|
||||
parser.add_argument(
|
||||
'--options',
|
||||
nargs='+',
|
||||
action=DictAction,
|
||||
help="--options is deprecated in favor of --cfg_options' and it will "
|
||||
'not be supported in version v0.22.0. Override some settings in the '
|
||||
'used config, the key-value pair in xxx=yyy format will be merged '
|
||||
'into config file. If the value to be overwritten is a list, it '
|
||||
'should be like key="[a,b]" or key=a,b It also allows nested '
|
||||
'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
|
||||
'marks are necessary and that no white space is allowed.')
|
||||
parser.add_argument(
|
||||
'--cfg-options',
|
||||
nargs='+',
|
||||
action=DictAction,
|
||||
help='override some settings in the used config, the key-value pair '
|
||||
'in xxx=yyy format will be merged into config file. If the value to '
|
||||
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
|
||||
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
||||
'Note that the quotation marks are necessary and that no white space '
|
||||
'is allowed.')
|
||||
parser.add_argument(
|
||||
'--eval-options',
|
||||
nargs='+',
|
||||
action=DictAction,
|
||||
help='custom options for evaluation')
|
||||
parser.add_argument(
|
||||
'--opacity',
|
||||
type=float,
|
||||
default=0.5,
|
||||
help='Opacity of painted segmentation map. In (0, 1] range.')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
if 'LOCAL_RANK' not in os.environ:
|
||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||
|
||||
if args.options and args.cfg_options:
|
||||
raise ValueError(
|
||||
'--options and --cfg-options cannot be both '
|
||||
'specified, --options is deprecated in favor of --cfg-options. '
|
||||
'--options will not be supported in version v0.22.0.')
|
||||
if args.options:
|
||||
warnings.warn('--options is deprecated in favor of --cfg-options. '
|
||||
'--options will not be supported in version v0.22.0.')
|
||||
args.cfg_options = args.options
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
assert args.out or args.eval or args.format_only or args.show \
|
||||
or args.show_dir, \
|
||||
('Please specify at least one operation (save/eval/format/show the '
|
||||
'results / save the results) with the argument "--out", "--eval"'
|
||||
', "--format-only", "--show" or "--show-dir"')
|
||||
|
||||
if args.eval and args.format_only:
|
||||
raise ValueError('--eval and --format_only cannot be both specified')
|
||||
|
||||
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
|
||||
raise ValueError('The output file must be a pkl file.')
|
||||
|
||||
cfg = mmcv.Config.fromfile(args.config)
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
cfg.model.pretrained = None
|
||||
cfg.data.test.test_mode = True
|
||||
|
||||
# init distributed env first, since logger depends on the dist info.
|
||||
distributed = False
|
||||
|
||||
# build the dataloader
|
||||
# TODO: support multiple images per gpu (only minor changes are needed)
|
||||
dataset = build_dataset(cfg.data.test)
|
||||
data_loader = build_dataloader(
|
||||
dataset,
|
||||
samples_per_gpu=1,
|
||||
workers_per_gpu=cfg.data.workers_per_gpu,
|
||||
dist=distributed,
|
||||
shuffle=False)
|
||||
|
||||
# load onnx config and meta
|
||||
cfg.model.train_cfg = None
|
||||
|
||||
if args.backend == 'onnxruntime':
|
||||
model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0)
|
||||
elif args.backend == 'tensorrt':
|
||||
model = TensorRTSegmentor(args.model, cfg=cfg, device_id=0)
|
||||
|
||||
model.CLASSES = dataset.CLASSES
|
||||
model.PALETTE = dataset.PALETTE
|
||||
|
||||
# clean gpu memory when starting a new evaluation.
|
||||
torch.cuda.empty_cache()
|
||||
eval_kwargs = {} if args.eval_options is None else args.eval_options
|
||||
|
||||
# Deprecated
|
||||
efficient_test = eval_kwargs.get('efficient_test', False)
|
||||
if efficient_test:
|
||||
warnings.warn(
|
||||
'``efficient_test=True`` does not have effect in tools/test.py, '
|
||||
'the evaluation and format results are CPU memory efficient by '
|
||||
'default')
|
||||
|
||||
eval_on_format_results = (
|
||||
args.eval is not None and 'cityscapes' in args.eval)
|
||||
if eval_on_format_results:
|
||||
assert len(args.eval) == 1, 'eval on format results is not ' \
|
||||
'applicable for metrics other than ' \
|
||||
'cityscapes'
|
||||
if args.format_only or eval_on_format_results:
|
||||
if 'imgfile_prefix' in eval_kwargs:
|
||||
tmpdir = eval_kwargs['imgfile_prefix']
|
||||
else:
|
||||
tmpdir = '.format_cityscapes'
|
||||
eval_kwargs.setdefault('imgfile_prefix', tmpdir)
|
||||
mmcv.mkdir_or_exist(tmpdir)
|
||||
else:
|
||||
tmpdir = None
|
||||
|
||||
model = MMDataParallel(model, device_ids=[0])
|
||||
results = single_gpu_test(
|
||||
model,
|
||||
data_loader,
|
||||
args.show,
|
||||
args.show_dir,
|
||||
False,
|
||||
args.opacity,
|
||||
pre_eval=args.eval is not None and not eval_on_format_results,
|
||||
format_only=args.format_only or eval_on_format_results,
|
||||
format_args=eval_kwargs)
|
||||
|
||||
rank, _ = get_dist_info()
|
||||
if rank == 0:
|
||||
if args.out:
|
||||
warnings.warn(
|
||||
'The behavior of ``args.out`` has been changed since MMSeg '
|
||||
'v0.16, the pickled outputs could be seg map as type of '
|
||||
'np.array, pre-eval results or file paths for '
|
||||
'``dataset.format_results()``.')
|
||||
print(f'\nwriting results to {args.out}')
|
||||
mmcv.dump(results, args.out)
|
||||
if args.eval:
|
||||
dataset.evaluate(results, args.eval, **eval_kwargs)
|
||||
if tmpdir is not None and eval_on_format_results:
|
||||
# remove tmp dir when cityscapes evaluation
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
# Following strings of text style are from colorama package
|
||||
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
|
||||
red_text, blue_text = '\x1b[31m', '\x1b[34m'
|
||||
white_background = '\x1b[107m'
|
||||
|
||||
msg = white_background + bright_style + red_text
|
||||
msg += 'DeprecationWarning: This tool will be deprecated in future. '
|
||||
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
|
||||
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
|
||||
msg += reset_style
|
||||
warnings.warn(msg)
|
@ -1,289 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import os
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from typing import Iterable, Optional, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
from mmcv.ops import get_onnxruntime_op_path
|
||||
from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
|
||||
save_trt_engine)
|
||||
|
||||
from mmseg.apis.inference import LoadImage
|
||||
from mmseg.datasets import DATASETS
|
||||
from mmseg.datasets.pipelines import Compose
|
||||
|
||||
|
||||
def get_GiB(x: int):
|
||||
"""return x GiB."""
|
||||
return x * (1 << 30)
|
||||
|
||||
|
||||
def _prepare_input_img(img_path: str,
|
||||
test_pipeline: Iterable[dict],
|
||||
shape: Optional[Iterable] = None,
|
||||
rescale_shape: Optional[Iterable] = None) -> dict:
|
||||
# build the data pipeline
|
||||
if shape is not None:
|
||||
test_pipeline[1]['img_scale'] = (shape[1], shape[0])
|
||||
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
|
||||
test_pipeline = [LoadImage()] + test_pipeline[1:]
|
||||
test_pipeline = Compose(test_pipeline)
|
||||
# prepare data
|
||||
data = dict(img=img_path)
|
||||
data = test_pipeline(data)
|
||||
imgs = data['img']
|
||||
img_metas = [i.data for i in data['img_metas']]
|
||||
|
||||
if rescale_shape is not None:
|
||||
for img_meta in img_metas:
|
||||
img_meta['ori_shape'] = tuple(rescale_shape) + (3, )
|
||||
|
||||
mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
|
||||
|
||||
return mm_inputs
|
||||
|
||||
|
||||
def _update_input_img(img_list: Iterable, img_meta_list: Iterable):
|
||||
# update img and its meta list
|
||||
N = img_list[0].size(0)
|
||||
img_meta = img_meta_list[0][0]
|
||||
img_shape = img_meta['img_shape']
|
||||
ori_shape = img_meta['ori_shape']
|
||||
pad_shape = img_meta['pad_shape']
|
||||
new_img_meta_list = [[{
|
||||
'img_shape':
|
||||
img_shape,
|
||||
'ori_shape':
|
||||
ori_shape,
|
||||
'pad_shape':
|
||||
pad_shape,
|
||||
'filename':
|
||||
img_meta['filename'],
|
||||
'scale_factor':
|
||||
(img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
|
||||
'flip':
|
||||
False,
|
||||
} for _ in range(N)]]
|
||||
|
||||
return img_list, new_img_meta_list
|
||||
|
||||
|
||||
def show_result_pyplot(img: Union[str, np.ndarray],
|
||||
result: np.ndarray,
|
||||
palette: Optional[Iterable] = None,
|
||||
fig_size: Iterable[int] = (15, 10),
|
||||
opacity: float = 0.5,
|
||||
title: str = '',
|
||||
block: bool = True):
|
||||
img = mmcv.imread(img)
|
||||
img = img.copy()
|
||||
seg = result[0]
|
||||
seg = mmcv.imresize(seg, img.shape[:2][::-1])
|
||||
palette = np.array(palette)
|
||||
assert palette.shape[1] == 3
|
||||
assert len(palette.shape) == 2
|
||||
assert 0 < opacity <= 1.0
|
||||
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
|
||||
for label, color in enumerate(palette):
|
||||
color_seg[seg == label, :] = color
|
||||
# convert to BGR
|
||||
color_seg = color_seg[..., ::-1]
|
||||
|
||||
img = img * (1 - opacity) + color_seg * opacity
|
||||
img = img.astype(np.uint8)
|
||||
|
||||
plt.figure(figsize=fig_size)
|
||||
plt.imshow(mmcv.bgr2rgb(img))
|
||||
plt.title(title)
|
||||
plt.tight_layout()
|
||||
plt.show(block=block)
|
||||
|
||||
|
||||
def onnx2tensorrt(onnx_file: str,
|
||||
trt_file: str,
|
||||
config: dict,
|
||||
input_config: dict,
|
||||
fp16: bool = False,
|
||||
verify: bool = False,
|
||||
show: bool = False,
|
||||
dataset: str = 'CityscapesDataset',
|
||||
workspace_size: int = 1,
|
||||
verbose: bool = False):
|
||||
import tensorrt as trt
|
||||
min_shape = input_config['min_shape']
|
||||
max_shape = input_config['max_shape']
|
||||
# create trt engine and wrapper
|
||||
opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
|
||||
max_workspace_size = get_GiB(workspace_size)
|
||||
trt_engine = onnx2trt(
|
||||
onnx_file,
|
||||
opt_shape_dict,
|
||||
log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
|
||||
fp16_mode=fp16,
|
||||
max_workspace_size=max_workspace_size)
|
||||
save_dir, _ = osp.split(trt_file)
|
||||
if save_dir:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_trt_engine(trt_engine, trt_file)
|
||||
print(f'Successfully created TensorRT engine: {trt_file}')
|
||||
|
||||
if verify:
|
||||
inputs = _prepare_input_img(
|
||||
input_config['input_path'],
|
||||
config.data.test.pipeline,
|
||||
shape=min_shape[2:])
|
||||
|
||||
imgs = inputs['imgs']
|
||||
img_metas = inputs['img_metas']
|
||||
img_list = [img[None, :] for img in imgs]
|
||||
img_meta_list = [[img_meta] for img_meta in img_metas]
|
||||
# update img_meta
|
||||
img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
|
||||
|
||||
if max_shape[0] > 1:
|
||||
# concate flip image for batch test
|
||||
flip_img_list = [_.flip(-1) for _ in img_list]
|
||||
img_list = [
|
||||
torch.cat((ori_img, flip_img), 0)
|
||||
for ori_img, flip_img in zip(img_list, flip_img_list)
|
||||
]
|
||||
|
||||
# Get results from ONNXRuntime
|
||||
ort_custom_op_path = get_onnxruntime_op_path()
|
||||
session_options = ort.SessionOptions()
|
||||
if osp.exists(ort_custom_op_path):
|
||||
session_options.register_custom_ops_library(ort_custom_op_path)
|
||||
sess = ort.InferenceSession(onnx_file, session_options)
|
||||
sess.set_providers(['CPUExecutionProvider'], [{}]) # use cpu mode
|
||||
onnx_output = sess.run(['output'],
|
||||
{'input': img_list[0].detach().numpy()})[0][0]
|
||||
|
||||
# Get results from TensorRT
|
||||
trt_model = TRTWraper(trt_file, ['input'], ['output'])
|
||||
with torch.no_grad():
|
||||
trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()})
|
||||
trt_output = trt_outputs['output'][0].cpu().detach().numpy()
|
||||
|
||||
if show:
|
||||
dataset = DATASETS.get(dataset)
|
||||
assert dataset is not None
|
||||
palette = dataset.PALETTE
|
||||
|
||||
show_result_pyplot(
|
||||
input_config['input_path'],
|
||||
(onnx_output[0].astype(np.uint8), ),
|
||||
palette=palette,
|
||||
title='ONNXRuntime',
|
||||
block=False)
|
||||
show_result_pyplot(
|
||||
input_config['input_path'], (trt_output[0].astype(np.uint8), ),
|
||||
palette=palette,
|
||||
title='TensorRT')
|
||||
|
||||
np.testing.assert_allclose(
|
||||
onnx_output, trt_output, rtol=1e-03, atol=1e-05)
|
||||
print('TensorRT and ONNXRuntime output all close.')
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Convert MMSegmentation models from ONNX to TensorRT')
|
||||
parser.add_argument('config', help='Config file of the model')
|
||||
parser.add_argument('model', help='Path to the input ONNX model')
|
||||
parser.add_argument(
|
||||
'--trt-file', type=str, help='Path to the output TensorRT engine')
|
||||
parser.add_argument(
|
||||
'--max-shape',
|
||||
type=int,
|
||||
nargs=4,
|
||||
default=[1, 3, 400, 600],
|
||||
help='Maximum shape of model input.')
|
||||
parser.add_argument(
|
||||
'--min-shape',
|
||||
type=int,
|
||||
nargs=4,
|
||||
default=[1, 3, 400, 600],
|
||||
help='Minimum shape of model input.')
|
||||
parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
|
||||
parser.add_argument(
|
||||
'--workspace-size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Max workspace size in GiB')
|
||||
parser.add_argument(
|
||||
'--input-img', type=str, default='', help='Image for test')
|
||||
parser.add_argument(
|
||||
'--show', action='store_true', help='Whether to show output results')
|
||||
parser.add_argument(
|
||||
'--dataset',
|
||||
type=str,
|
||||
default='CityscapesDataset',
|
||||
help='Dataset name')
|
||||
parser.add_argument(
|
||||
'--verify',
|
||||
action='store_true',
|
||||
help='Verify the outputs of ONNXRuntime and TensorRT')
|
||||
parser.add_argument(
|
||||
'--verbose',
|
||||
action='store_true',
|
||||
help='Whether to verbose logging messages while creating \
|
||||
TensorRT engine.')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
|
||||
args = parse_args()
|
||||
|
||||
if not args.input_img:
|
||||
args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png')
|
||||
|
||||
# check arguments
|
||||
assert osp.exists(args.config), 'Config {} not found.'.format(args.config)
|
||||
assert osp.exists(args.model), \
|
||||
'ONNX model {} not found.'.format(args.model)
|
||||
assert args.workspace_size >= 0, 'Workspace size less than 0.'
|
||||
assert DATASETS.get(args.dataset) is not None, \
|
||||
'Dataset {} does not found.'.format(args.dataset)
|
||||
for max_value, min_value in zip(args.max_shape, args.min_shape):
|
||||
assert max_value >= min_value, \
|
||||
'max_shape should be larger than min shape'
|
||||
|
||||
input_config = {
|
||||
'min_shape': args.min_shape,
|
||||
'max_shape': args.max_shape,
|
||||
'input_path': args.input_img
|
||||
}
|
||||
|
||||
cfg = mmcv.Config.fromfile(args.config)
|
||||
onnx2tensorrt(
|
||||
args.model,
|
||||
args.trt_file,
|
||||
cfg,
|
||||
input_config,
|
||||
fp16=args.fp16,
|
||||
verify=args.verify,
|
||||
show=args.show,
|
||||
dataset=args.dataset,
|
||||
workspace_size=args.workspace_size,
|
||||
verbose=args.verbose)
|
||||
|
||||
# Following strings of text style are from colorama package
|
||||
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
|
||||
red_text, blue_text = '\x1b[31m', '\x1b[34m'
|
||||
white_background = '\x1b[107m'
|
||||
|
||||
msg = white_background + bright_style + red_text
|
||||
msg += 'DeprecationWarning: This tool will be deprecated in future. '
|
||||
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
|
||||
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
|
||||
msg += reset_style
|
||||
warnings.warn(msg)
|
@ -1,405 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import warnings
|
||||
from functools import partial
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import onnxruntime as rt
|
||||
import torch
|
||||
import torch._C
|
||||
import torch.serialization
|
||||
from mmcv import DictAction
|
||||
from mmcv.onnx import register_extra_symbolics
|
||||
from mmcv.runner import load_checkpoint
|
||||
from torch import nn
|
||||
|
||||
from mmseg.apis import show_result_pyplot
|
||||
from mmseg.apis.inference import LoadImage
|
||||
from mmseg.datasets.pipelines import Compose
|
||||
from mmseg.models import build_segmentor
|
||||
from mmseg.ops import resize
|
||||
|
||||
torch.manual_seed(3)
|
||||
|
||||
|
||||
def _convert_batchnorm(module):
|
||||
module_output = module
|
||||
if isinstance(module, torch.nn.SyncBatchNorm):
|
||||
module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
|
||||
module.momentum, module.affine,
|
||||
module.track_running_stats)
|
||||
if module.affine:
|
||||
module_output.weight.data = module.weight.data.clone().detach()
|
||||
module_output.bias.data = module.bias.data.clone().detach()
|
||||
# keep requires_grad unchanged
|
||||
module_output.weight.requires_grad = module.weight.requires_grad
|
||||
module_output.bias.requires_grad = module.bias.requires_grad
|
||||
module_output.running_mean = module.running_mean
|
||||
module_output.running_var = module.running_var
|
||||
module_output.num_batches_tracked = module.num_batches_tracked
|
||||
for name, child in module.named_children():
|
||||
module_output.add_module(name, _convert_batchnorm(child))
|
||||
del module
|
||||
return module_output
|
||||
|
||||
|
||||
def _demo_mm_inputs(input_shape, num_classes):
|
||||
"""Create a superset of inputs needed to run test or train batches.
|
||||
|
||||
Args:
|
||||
input_shape (tuple):
|
||||
input batch dimensions
|
||||
num_classes (int):
|
||||
number of semantic classes
|
||||
"""
|
||||
(N, C, H, W) = input_shape
|
||||
rng = np.random.RandomState(0)
|
||||
imgs = rng.rand(*input_shape)
|
||||
segs = rng.randint(
|
||||
low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
|
||||
img_metas = [{
|
||||
'img_shape': (H, W, C),
|
||||
'ori_shape': (H, W, C),
|
||||
'pad_shape': (H, W, C),
|
||||
'filename': '<demo>.png',
|
||||
'scale_factor': 1.0,
|
||||
'flip': False,
|
||||
} for _ in range(N)]
|
||||
mm_inputs = {
|
||||
'imgs': torch.FloatTensor(imgs).requires_grad_(True),
|
||||
'img_metas': img_metas,
|
||||
'gt_semantic_seg': torch.LongTensor(segs)
|
||||
}
|
||||
return mm_inputs
|
||||
|
||||
|
||||
def _prepare_input_img(img_path,
|
||||
test_pipeline,
|
||||
shape=None,
|
||||
rescale_shape=None):
|
||||
# build the data pipeline
|
||||
if shape is not None:
|
||||
test_pipeline[1]['img_scale'] = (shape[1], shape[0])
|
||||
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
|
||||
test_pipeline = [LoadImage()] + test_pipeline[1:]
|
||||
test_pipeline = Compose(test_pipeline)
|
||||
# prepare data
|
||||
data = dict(img=img_path)
|
||||
data = test_pipeline(data)
|
||||
imgs = data['img']
|
||||
img_metas = [i.data for i in data['img_metas']]
|
||||
|
||||
if rescale_shape is not None:
|
||||
for img_meta in img_metas:
|
||||
img_meta['ori_shape'] = tuple(rescale_shape) + (3, )
|
||||
|
||||
mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
|
||||
|
||||
return mm_inputs
|
||||
|
||||
|
||||
def _update_input_img(img_list, img_meta_list, update_ori_shape=False):
|
||||
# update img and its meta list
|
||||
N, C, H, W = img_list[0].shape
|
||||
img_meta = img_meta_list[0][0]
|
||||
img_shape = (H, W, C)
|
||||
if update_ori_shape:
|
||||
ori_shape = img_shape
|
||||
else:
|
||||
ori_shape = img_meta['ori_shape']
|
||||
pad_shape = img_shape
|
||||
new_img_meta_list = [[{
|
||||
'img_shape':
|
||||
img_shape,
|
||||
'ori_shape':
|
||||
ori_shape,
|
||||
'pad_shape':
|
||||
pad_shape,
|
||||
'filename':
|
||||
img_meta['filename'],
|
||||
'scale_factor':
|
||||
(img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
|
||||
'flip':
|
||||
False,
|
||||
} for _ in range(N)]]
|
||||
|
||||
return img_list, new_img_meta_list
|
||||
|
||||
|
||||
def pytorch2onnx(model,
|
||||
mm_inputs,
|
||||
opset_version=11,
|
||||
show=False,
|
||||
output_file='tmp.onnx',
|
||||
verify=False,
|
||||
dynamic_export=False):
|
||||
"""Export Pytorch model to ONNX model and verify the outputs are same
|
||||
between Pytorch and ONNX.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Pytorch model we want to export.
|
||||
mm_inputs (dict): Contain the input tensors and img_metas information.
|
||||
opset_version (int): The onnx op version. Default: 11.
|
||||
show (bool): Whether print the computation graph. Default: False.
|
||||
output_file (string): The path to where we store the output ONNX model.
|
||||
Default: `tmp.onnx`.
|
||||
verify (bool): Whether compare the outputs between Pytorch and ONNX.
|
||||
Default: False.
|
||||
dynamic_export (bool): Whether to export ONNX with dynamic axis.
|
||||
Default: False.
|
||||
"""
|
||||
model.cpu().eval()
|
||||
test_mode = model.test_cfg.mode
|
||||
|
||||
if isinstance(model.decode_head, nn.ModuleList):
|
||||
num_classes = model.decode_head[-1].num_classes
|
||||
else:
|
||||
num_classes = model.decode_head.num_classes
|
||||
|
||||
imgs = mm_inputs.pop('imgs')
|
||||
img_metas = mm_inputs.pop('img_metas')
|
||||
|
||||
img_list = [img[None, :] for img in imgs]
|
||||
img_meta_list = [[img_meta] for img_meta in img_metas]
|
||||
# update img_meta
|
||||
img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
|
||||
|
||||
# replace original forward function
|
||||
origin_forward = model.forward
|
||||
model.forward = partial(
|
||||
model.forward,
|
||||
img_metas=img_meta_list,
|
||||
return_loss=False,
|
||||
rescale=True)
|
||||
dynamic_axes = None
|
||||
if dynamic_export:
|
||||
if test_mode == 'slide':
|
||||
dynamic_axes = {'input': {0: 'batch'}, 'output': {1: 'batch'}}
|
||||
else:
|
||||
dynamic_axes = {
|
||||
'input': {
|
||||
0: 'batch',
|
||||
2: 'height',
|
||||
3: 'width'
|
||||
},
|
||||
'output': {
|
||||
1: 'batch',
|
||||
2: 'height',
|
||||
3: 'width'
|
||||
}
|
||||
}
|
||||
|
||||
register_extra_symbolics(opset_version)
|
||||
with torch.no_grad():
|
||||
torch.onnx.export(
|
||||
model, (img_list, ),
|
||||
output_file,
|
||||
input_names=['input'],
|
||||
output_names=['output'],
|
||||
export_params=True,
|
||||
keep_initializers_as_inputs=False,
|
||||
verbose=show,
|
||||
opset_version=opset_version,
|
||||
dynamic_axes=dynamic_axes)
|
||||
print(f'Successfully exported ONNX model: {output_file}')
|
||||
model.forward = origin_forward
|
||||
|
||||
if verify:
|
||||
# check by onnx
|
||||
import onnx
|
||||
onnx_model = onnx.load(output_file)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
|
||||
if dynamic_export and test_mode == 'whole':
|
||||
# scale image for dynamic shape test
|
||||
img_list = [resize(_, scale_factor=1.5) for _ in img_list]
|
||||
# concate flip image for batch test
|
||||
flip_img_list = [_.flip(-1) for _ in img_list]
|
||||
img_list = [
|
||||
torch.cat((ori_img, flip_img), 0)
|
||||
for ori_img, flip_img in zip(img_list, flip_img_list)
|
||||
]
|
||||
|
||||
# update img_meta
|
||||
img_list, img_meta_list = _update_input_img(
|
||||
img_list, img_meta_list, test_mode == 'whole')
|
||||
|
||||
# check the numerical value
|
||||
# get pytorch output
|
||||
with torch.no_grad():
|
||||
pytorch_result = model(img_list, img_meta_list, return_loss=False)
|
||||
pytorch_result = np.stack(pytorch_result, 0)
|
||||
|
||||
# get onnx output
|
||||
input_all = [node.name for node in onnx_model.graph.input]
|
||||
input_initializer = [
|
||||
node.name for node in onnx_model.graph.initializer
|
||||
]
|
||||
net_feed_input = list(set(input_all) - set(input_initializer))
|
||||
assert (len(net_feed_input) == 1)
|
||||
sess = rt.InferenceSession(output_file)
|
||||
onnx_result = sess.run(
|
||||
None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
|
||||
# show segmentation results
|
||||
if show:
|
||||
import os.path as osp
|
||||
|
||||
import cv2
|
||||
img = img_meta_list[0][0]['filename']
|
||||
if not osp.exists(img):
|
||||
img = imgs[0][:3, ...].permute(1, 2, 0) * 255
|
||||
img = img.detach().numpy().astype(np.uint8)
|
||||
ori_shape = img.shape[:2]
|
||||
else:
|
||||
ori_shape = LoadImage()({'img': img})['ori_shape']
|
||||
|
||||
# resize onnx_result to ori_shape
|
||||
onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
|
||||
(ori_shape[1], ori_shape[0]))
|
||||
show_result_pyplot(
|
||||
model,
|
||||
img, (onnx_result_, ),
|
||||
palette=model.PALETTE,
|
||||
block=False,
|
||||
title='ONNXRuntime',
|
||||
opacity=0.5)
|
||||
|
||||
# resize pytorch_result to ori_shape
|
||||
pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
|
||||
(ori_shape[1], ori_shape[0]))
|
||||
show_result_pyplot(
|
||||
model,
|
||||
img, (pytorch_result_, ),
|
||||
title='PyTorch',
|
||||
palette=model.PALETTE,
|
||||
opacity=0.5)
|
||||
# compare results
|
||||
np.testing.assert_allclose(
|
||||
pytorch_result.astype(np.float32) / num_classes,
|
||||
onnx_result.astype(np.float32) / num_classes,
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
err_msg='The outputs are different between Pytorch and ONNX')
|
||||
print('The outputs are same between Pytorch and ONNX')
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
|
||||
parser.add_argument('config', help='test config file path')
|
||||
parser.add_argument('--checkpoint', help='checkpoint file', default=None)
|
||||
parser.add_argument(
|
||||
'--input-img', type=str, help='Images for input', default=None)
|
||||
parser.add_argument(
|
||||
'--show',
|
||||
action='store_true',
|
||||
help='show onnx graph and segmentation results')
|
||||
parser.add_argument(
|
||||
'--verify', action='store_true', help='verify the onnx model')
|
||||
parser.add_argument('--output-file', type=str, default='tmp.onnx')
|
||||
parser.add_argument('--opset-version', type=int, default=11)
|
||||
parser.add_argument(
|
||||
'--shape',
|
||||
type=int,
|
||||
nargs='+',
|
||||
default=None,
|
||||
help='input image height and width.')
|
||||
parser.add_argument(
|
||||
'--rescale_shape',
|
||||
type=int,
|
||||
nargs='+',
|
||||
default=None,
|
||||
help='output image rescale height and width, work for slide mode.')
|
||||
parser.add_argument(
|
||||
'--cfg-options',
|
||||
nargs='+',
|
||||
action=DictAction,
|
||||
help='Override some settings in the used config, the key-value pair '
|
||||
'in xxx=yyy format will be merged into config file. If the value to '
|
||||
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
|
||||
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
||||
'Note that the quotation marks are necessary and that no white space '
|
||||
'is allowed.')
|
||||
parser.add_argument(
|
||||
'--dynamic-export',
|
||||
action='store_true',
|
||||
help='Whether to export onnx with dynamic axis.')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
cfg = mmcv.Config.fromfile(args.config)
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
cfg.model.pretrained = None
|
||||
|
||||
if args.shape is None:
|
||||
img_scale = cfg.test_pipeline[1]['img_scale']
|
||||
input_shape = (1, 3, img_scale[1], img_scale[0])
|
||||
elif len(args.shape) == 1:
|
||||
input_shape = (1, 3, args.shape[0], args.shape[0])
|
||||
elif len(args.shape) == 2:
|
||||
input_shape = (
|
||||
1,
|
||||
3,
|
||||
) + tuple(args.shape)
|
||||
else:
|
||||
raise ValueError('invalid input shape')
|
||||
|
||||
test_mode = cfg.model.test_cfg.mode
|
||||
|
||||
# build the model and load checkpoint
|
||||
cfg.model.train_cfg = None
|
||||
segmentor = build_segmentor(
|
||||
cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
|
||||
# convert SyncBN to BN
|
||||
segmentor = _convert_batchnorm(segmentor)
|
||||
|
||||
if args.checkpoint:
|
||||
checkpoint = load_checkpoint(
|
||||
segmentor, args.checkpoint, map_location='cpu')
|
||||
segmentor.CLASSES = checkpoint['meta']['CLASSES']
|
||||
segmentor.PALETTE = checkpoint['meta']['PALETTE']
|
||||
|
||||
# read input or create dummpy input
|
||||
if args.input_img is not None:
|
||||
preprocess_shape = (input_shape[2], input_shape[3])
|
||||
rescale_shape = None
|
||||
if args.rescale_shape is not None:
|
||||
rescale_shape = [args.rescale_shape[0], args.rescale_shape[1]]
|
||||
mm_inputs = _prepare_input_img(
|
||||
args.input_img,
|
||||
cfg.data.test.pipeline,
|
||||
shape=preprocess_shape,
|
||||
rescale_shape=rescale_shape)
|
||||
else:
|
||||
if isinstance(segmentor.decode_head, nn.ModuleList):
|
||||
num_classes = segmentor.decode_head[-1].num_classes
|
||||
else:
|
||||
num_classes = segmentor.decode_head.num_classes
|
||||
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
|
||||
|
||||
# convert model to onnx file
|
||||
pytorch2onnx(
|
||||
segmentor,
|
||||
mm_inputs,
|
||||
opset_version=args.opset_version,
|
||||
show=args.show,
|
||||
output_file=args.output_file,
|
||||
verify=args.verify,
|
||||
dynamic_export=args.dynamic_export)
|
||||
|
||||
# Following strings of text style are from colorama package
|
||||
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
|
||||
red_text, blue_text = '\x1b[31m', '\x1b[34m'
|
||||
white_background = '\x1b[107m'
|
||||
|
||||
msg = white_background + bright_style + red_text
|
||||
msg += 'DeprecationWarning: This tool will be deprecated in future. '
|
||||
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
|
||||
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
|
||||
msg += reset_style
|
||||
warnings.warn(msg)
|
Loading…
x
Reference in New Issue
Block a user