[Refactory] MMSegmentation Content

This commit is contained in:
zhengmiao 2022-07-15 15:47:29 +00:00
parent fba91957c0
commit 4b76f277a6
71 changed files with 266 additions and 1719 deletions

1
.gitignore vendored
View File

@ -105,7 +105,6 @@ venv.bak/
# mypy
.mypy_cache/
data
.vscode
.idea

View File

@ -145,7 +145,7 @@
"outputs": [],
"source": [
"from mmseg.apis import inference_model, init_model, show_result_pyplot\n",
"from mmseg.core.evaluation import get_palette"
"from mmseg.utils import get_palette"
]
},
{

View File

@ -2,7 +2,7 @@
from argparse import ArgumentParser
from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.core.evaluation import get_palette
from mmseg.utils import get_palette
def main():

View File

@ -21,7 +21,7 @@
"outputs": [],
"source": [
"from mmseg.apis import init_model, inference_model, show_result_pyplot\n",
"from mmseg.core.evaluation import get_palette"
"from mmseg.utils import get_palette"
]
},
{

View File

@ -4,7 +4,7 @@ from argparse import ArgumentParser
import cv2
from mmseg.apis import inference_model, init_model
from mmseg.core.evaluation import get_palette
from mmseg.utils import get_palette
def main():

View File

@ -5,7 +5,7 @@ import torch
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmseg.datasets.pipelines import Compose
from mmseg.datasets.transforms import Compose
from mmseg.models import build_segmentor

View File

@ -1,9 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_optimizer
from .data_structures import * # noqa: F401, F403
from .evaluation import * # noqa: F401, F403
from .optimizers import * # noqa: F401, F403
from .seg import * # noqa: F401, F403
from .utils import * # noqa: F401, F403
__all__ = ['build_optimizer']

View File

@ -1,18 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
def build_optimizer(model, cfg):
optim_wrapper_cfg = copy.deepcopy(cfg)
constructor_type = optim_wrapper_cfg.pop('constructor',
'DefaultOptimWrapperConstructor')
paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
optim_wrapper_builder = OPTIM_WRAPPER_CONSTRUCTORS.build(
dict(
type=constructor_type,
optim_wrapper_cfg=optim_wrapper_cfg,
paramwise_cfg=paramwise_cfg))
optim_wrapper = optim_wrapper_builder(model)
return optim_wrapper

View File

@ -1,4 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .seg_data_sample import SegDataSample
__all__ = ['SegDataSample']

View File

@ -1,4 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .class_names import get_classes, get_palette
__all__ = ['get_classes', 'get_palette']

View File

@ -1,5 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_pixel_sampler import BasePixelSampler
from .ohem_pixel_sampler import OHEMPixelSampler
__all__ = ['BasePixelSampler', 'OHEMPixelSampler']

View File

@ -1,12 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dist_util import check_dist_init, sync_random_seed
from .misc import add_prefix, stack_batch
from .typing import (ConfigType, ForwardResults, MultiConfig, OptConfigType,
OptMultiConfig, OptSampleList, SampleList, TensorDict,
TensorList)
__all__ = [
'add_prefix', 'check_dist_init', 'sync_random_seed', 'stack_batch',
'ConfigType', 'OptConfigType', 'MultiConfig', 'OptMultiConfig',
'SampleList', 'OptSampleList', 'TensorDict', 'TensorList', 'ForwardResults'
]

View File

@ -1,46 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info
def check_dist_init():
return dist.is_available() and dist.is_initialized()
def sync_random_seed(seed=None, device='cuda'):
"""Make sure different ranks share the same seed. All workers must call
this function, otherwise it will deadlock. This method is generally used in
`DistributedSampler`, because the seed should be identical across all
processes in the distributed group.
In distributed sampling, different ranks should sample non-overlapped
data in the dataset. Therefore, this function is used to make sure that
each rank shuffles the data indices in the same order based
on the same seed. Then different ranks could use different indices
to select non-overlapped data from the same data list.
Args:
seed (int, Optional): The seed. Default to None.
device (str): The device where the seed will be put on.
Default to 'cuda'.
Returns:
int: Seed to be used.
"""
if seed is None:
seed = np.random.randint(2**31)
assert isinstance(seed, int)
rank, world_size = get_dist_info()
if world_size == 1:
return seed
if rank == 0:
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
else:
random_num = torch.tensor(0, dtype=torch.int32, device=device)
dist.broadcast(random_num, src=0)
return random_num.item()

View File

@ -1,108 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from mmseg.core.utils.typing import SampleList
def add_prefix(inputs, prefix):
"""Add prefix for dict.
Args:
inputs (dict): The input dict with str keys.
prefix (str): The prefix to add.
Returns:
dict: The dict with keys updated with ``prefix``.
"""
outputs = dict()
for name, value in inputs.items():
outputs[f'{prefix}.{name}'] = value
return outputs
def stack_batch(inputs: List[torch.Tensor],
batch_data_samples: Optional[SampleList] = None,
size: Optional[tuple] = None,
size_divisor: Optional[int] = None,
pad_val: Union[int, float] = 0,
seg_pad_val: Union[int, float] = 255) -> torch.Tensor:
"""Stack multiple inputs to form a batch and pad the images and gt_sem_segs
to the max shape use the right bottom padding mode.
Args:
inputs (List[Tensor]): The input multiple tensors. each is a
CHW 3D-tensor.
batch_data_samples (list[:obj:`SegDataSample`]): The Data
Samples. It usually includes information such as `gt_sem_seg`.
size (tuple, optional): Fixed padding size.
size_divisor (int, optional): The divisor of padded size.
pad_val (int, float): The padding value. Defaults to 0
seg_pad_val (int, float): The padding value. Defaults to 255
Returns:
Tensor: The 4D-tensor.
batch_data_samples (list[:obj:`SegDataSample`]): After the padding of
the gt_seg_map.
"""
assert isinstance(inputs, list), \
f'Expected input type to be list, but got {type(inputs)}'
assert len(set([tensor.ndim for tensor in inputs])) == 1, \
f'Expected the dimensions of all inputs must be the same, ' \
f'but got {[tensor.ndim for tensor in inputs]}'
assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
f'but got {inputs[0].ndim}'
assert len(set([tensor.shape[0] for tensor in inputs])) == 1, \
f'Expected the channels of all inputs must be the same, ' \
f'but got {[tensor.shape[0] for tensor in inputs]}'
# only one of size and size_divisor should be valid
assert (size is not None) ^ (size_divisor is not None), \
'only one of size and size_divisor should be valid'
padded_inputs = []
padded_samples = []
inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs]
max_size = np.stack(inputs_sizes).max(0)
if size_divisor is not None and size_divisor > 1:
# the last two dims are H,W, both subject to divisibility requirement
max_size = (max_size +
(size_divisor - 1)) // size_divisor * size_divisor
for i in range(len(inputs)):
tensor = inputs[i]
if size is not None:
width = max(size[-1] - tensor.shape[-1], 0)
height = max(size[-2] - tensor.shape[-2], 0)
# (padding_left, padding_right, padding_top, padding_bottom)
padding_size = (0, width, 0, height)
elif size_divisor is not None:
width = max(max_size[-1] - tensor.shape[-1], 0)
height = max(max_size[-2] - tensor.shape[-2], 0)
padding_size = (0, width, 0, height)
else:
padding_size = [0, 0, 0, 0]
# pad img
pad_img = F.pad(tensor, padding_size, value=pad_val)
padded_inputs.append(pad_img)
# pad gt_sem_seg
if batch_data_samples is not None:
data_sample = batch_data_samples[i]
gt_sem_seg = data_sample.gt_sem_seg.data
del data_sample.gt_sem_seg.data
data_sample.gt_sem_seg.data = F.pad(
gt_sem_seg, padding_size, value=seg_pad_val)
data_sample.set_metainfo(
{'pad_shape': data_sample.gt_sem_seg.shape})
padded_samples.append(data_sample)
else:
padded_samples = None
return torch.stack(padded_inputs, dim=0), padded_samples

8
mmseg/data/__init__.py Normal file
View File

@ -0,0 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .sampler import BasePixelSampler, OHEMPixelSampler, build_pixel_sampler
from .seg_data_sample import SegDataSample
__all__ = [
'SegDataSample', 'BasePixelSampler', 'OHEMPixelSampler',
'build_pixel_sampler'
]

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_pixel_sampler import BasePixelSampler
from .builder import build_pixel_sampler
from .sampler import BasePixelSampler, OHEMPixelSampler
from .ohem_pixel_sampler import OHEMPixelSampler
__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']

View File

@ -3,8 +3,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import PIXEL_SAMPLERS
from .base_pixel_sampler import BasePixelSampler
from .builder import PIXEL_SAMPLERS
@PIXEL_SAMPLERS.register_module()

View File

@ -1,9 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
import warnings
from .formatting import *
warnings.warn('DeprecationWarning: mmseg.datasets.pipelines.formating will be '
'deprecated in 2021, please replace it with '
'mmseg.datasets.pipelines.formatting.')

View File

@ -1,4 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .distributed_sampler import DistributedSampler
__all__ = ['DistributedSampler']

View File

@ -1,73 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import division
from typing import Iterator, Optional
import torch
from torch.utils.data import Dataset
from torch.utils.data import DistributedSampler as _DistributedSampler
from mmseg.core.utils import sync_random_seed
from mmseg.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class DistributedSampler(_DistributedSampler):
"""DistributedSampler inheriting from
`torch.utils.data.DistributedSampler`.
Args:
datasets (Dataset): the dataset will be loaded.
num_replicas (int, optional): Number of processes participating in
distributed training. By default, world_size is retrieved from the
current distributed group.
rank (int, optional): Rank of the current process within num_replicas.
By default, rank is retrieved from the current distributed group.
shuffle (bool): If True (default), sampler will shuffle the indices.
seed (int): random seed used to shuffle the sampler if
:attr:`shuffle=True`. This number should be identical across all
processes in the distributed group. Default: ``0``.
"""
def __init__(self,
dataset: Dataset,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed=0) -> None:
super().__init__(
dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
# In distributed sampling, different ranks should sample
# non-overlapped data in the dataset. Therefore, this function
# is used to make sure that each rank shuffles the data indices
# in the same order based on the same seed. Then different ranks
# could use different indices to select non-overlapped data from the
# same data list.
self.seed = sync_random_seed(seed)
def __iter__(self) -> Iterator:
"""
Yields:
Iterator: iterator of indices for rank.
"""
# deterministically shuffle based on epoch
if self.shuffle:
g = torch.Generator()
# When :attr:`shuffle=True`, this ensures all replicas
# use a different random ordering for each epoch.
# Otherwise, the next iteration of this sampler will
# yield the same ordering.
g.manual_seed(self.epoch + self.seed)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = torch.arange(len(self.dataset)).tolist()
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)

View File

@ -5,7 +5,7 @@ from mmcv.transforms import to_tensor
from mmcv.transforms.base import BaseTransform
from mmengine.data import PixelData
from mmseg.core import SegDataSample
from mmseg.data import SegDataSample
from mmseg.registry import TRANSFORMS

View File

@ -3,10 +3,10 @@ import json
import warnings
from mmengine.dist import get_dist_info
from mmengine.logging import print_log
from mmengine.optim import DefaultOptimWrapperConstructor
from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
from mmseg.utils import get_root_logger
def get_layer_id_for_convnext(var_name, max_layer_id):
@ -119,15 +119,14 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
in place.
module (nn.Module): The module to be added.
"""
logger = get_root_logger()
parameter_groups = {}
logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}')
print_log(f'self.paramwise_cfg is {self.paramwise_cfg}')
num_layers = self.paramwise_cfg.get('num_layers') + 2
decay_rate = self.paramwise_cfg.get('decay_rate')
decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
logger.info('Build LearningRateDecayOptimizerConstructor '
f'{decay_type} {decay_rate} - {num_layers}')
print_log('Build LearningRateDecayOptimizerConstructor '
f'{decay_type} {decay_rate} - {num_layers}')
weight_decay = self.base_wd
for name, param in module.named_parameters():
if not param.requires_grad:
@ -143,17 +142,17 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
if 'ConvNeXt' in module.backbone.__class__.__name__:
layer_id = get_layer_id_for_convnext(
name, self.paramwise_cfg.get('num_layers'))
logger.info(f'set param {name} as id {layer_id}')
print_log(f'set param {name} as id {layer_id}')
elif 'BEiT' in module.backbone.__class__.__name__ or \
'MAE' in module.backbone.__class__.__name__:
layer_id = get_layer_id_for_vit(name, num_layers)
logger.info(f'set param {name} as id {layer_id}')
print_log(f'set param {name} as id {layer_id}')
else:
raise NotImplementedError()
elif decay_type == 'stage_wise':
if 'ConvNeXt' in module.backbone.__class__.__name__:
layer_id = get_stage_id_for_convnext(name, num_layers)
logger.info(f'set param {name} as id {layer_id}')
print_log(f'set param {name} as id {layer_id}')
else:
raise NotImplementedError()
group_name = f'layer_{layer_id}_{group_name}'
@ -182,7 +181,7 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
'lr': parameter_groups[key]['lr'],
'weight_decay': parameter_groups[key]['weight_decay'],
}
logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
print_log(f'Param groups = {json.dumps(to_display, indent=2)}')
params.extend(parameter_groups.values())

View File

@ -14,7 +14,6 @@ from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.registry import MODELS
from mmseg.utils import get_root_logger
from ..utils import PatchEmbed
from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer
@ -500,9 +499,8 @@ class BEiT(BaseModule):
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
logger = get_root_logger()
checkpoint = _load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
state_dict = self.resize_rel_pos_embed(checkpoint)
self.load_state_dict(state_dict, False)
elif self.init_cfg is not None:

View File

@ -9,7 +9,6 @@ from mmcv.runner import ModuleList, _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from mmseg.registry import MODELS
from mmseg.utils import get_root_logger
from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer
@ -180,9 +179,8 @@ class MAE(BEiT):
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
logger = get_root_logger()
checkpoint = _load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
state_dict = self.resize_rel_pos_embed(checkpoint)
state_dict = self.resize_abs_pos_embed(state_dict)
self.load_state_dict(state_dict, False)

View File

@ -14,9 +14,9 @@ from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
load_state_dict)
from mmcv.utils import to_2tuple
from mmengine.logging import print_log
from mmseg.registry import MODELS
from ...utils import get_root_logger
from ..utils.embed import PatchEmbed, PatchMerging
@ -662,11 +662,10 @@ class SwinTransformer(BaseModule):
param.requires_grad = False
def init_weights(self):
logger = get_root_logger()
if self.init_cfg is None:
logger.warn(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
print_log(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
if self.use_abs_pos_embed:
trunc_normal_(self.absolute_pos_embed, std=0.02)
for m in self.modules():
@ -680,7 +679,7 @@ class SwinTransformer(BaseModule):
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
ckpt = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
if 'state_dict' in ckpt:
_state_dict = ckpt['state_dict']
elif 'model' in ckpt:
@ -705,7 +704,7 @@ class SwinTransformer(BaseModule):
N1, L, C1 = absolute_pos_embed.size()
N2, C2, H, W = self.absolute_pos_embed.size()
if N1 != N2 or C1 != C2 or L != H * W:
logger.warning('Error in loading absolute_pos_embed, pass')
print_log('Error in loading absolute_pos_embed, pass')
else:
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
@ -721,7 +720,7 @@ class SwinTransformer(BaseModule):
L1, nH1 = table_pretrained.size()
L2, nH2 = table_current.size()
if nH1 != nH2:
logger.warning(f'Error in loading {table_key}, pass')
print_log(f'Error in loading {table_key}, pass')
elif L1 != L2:
S1 = int(L1**0.5)
S2 = int(L2**0.5)
@ -733,7 +732,7 @@ class SwinTransformer(BaseModule):
nH2, L2).permute(1, 0).contiguous()
# load state_dict
load_state_dict(self, state_dict, strict=False, logger=logger)
load_state_dict(self, state_dict, strict=False, logger=None)
def forward(self, x):
x, hw_shape = self.patch_embed(x)

View File

@ -11,12 +11,12 @@ from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
trunc_normal_)
from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
load_state_dict)
from mmengine.logging import print_log
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.ops import resize
from mmseg.registry import MODELS
from mmseg.utils import get_root_logger
from ..utils import PatchEmbed
@ -293,9 +293,8 @@ class VisionTransformer(BaseModule):
def init_weights(self):
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
logger = get_root_logger()
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
@ -304,9 +303,9 @@ class VisionTransformer(BaseModule):
if 'pos_embed' in state_dict.keys():
if self.pos_embed.shape != state_dict['pos_embed'].shape:
logger.info(msg=f'Resize the pos_embed shape from '
f'{state_dict["pos_embed"].shape} to '
f'{self.pos_embed.shape}')
print_log(msg=f'Resize the pos_embed shape from '
f'{state_dict["pos_embed"].shape} to '
f'{self.pos_embed.shape}')
h, w = self.img_size
pos_size = int(
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
@ -315,7 +314,7 @@ class VisionTransformer(BaseModule):
(h // self.patch_size, w // self.patch_size),
(pos_size, pos_size), self.interpolate_mode)
load_state_dict(self, state_dict, strict=False, logger=logger)
load_state_dict(self, state_dict, strict=False, logger=None)
elif self.init_cfg is not None:
super(VisionTransformer, self).init_weights()
else:

View File

@ -6,9 +6,8 @@ import torch
from mmengine.model import BaseDataPreprocessor
from torch import Tensor
from mmseg.core import stack_batch
from mmseg.core.utils import OptSampleList
from mmseg.registry import MODELS
from mmseg.utils import OptSampleList, stack_batch
@MODELS.register_module()

View File

@ -4,7 +4,7 @@ from typing import List
from torch import Tensor
from mmseg.core.utils import ConfigType
from mmseg.utils import ConfigType
from .decode_head import BaseDecodeHead

View File

@ -6,9 +6,8 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule, Scale
from torch import Tensor, nn
from mmseg.core import add_prefix
from mmseg.core.utils import SampleList
from mmseg.registry import MODELS
from mmseg.utils import SampleList, add_prefix
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead

View File

@ -7,10 +7,9 @@ import torch.nn as nn
from mmcv.runner import BaseModule
from torch import Tensor
from mmseg.core import build_pixel_sampler
from mmseg.core.utils import SampleList
from mmseg.core.utils.typing import ConfigType
from mmseg.data import build_pixel_sampler
from mmseg.ops import resize
from mmseg.utils import ConfigType, SampleList
from ..builder import build_loss
from ..losses import accuracy

View File

@ -7,10 +7,9 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer
from torch import Tensor
from mmseg.core.utils import SampleList
from mmseg.core.utils.typing import ConfigType
from mmseg.ops import Encoding, resize
from mmseg.registry import MODELS
from mmseg.utils import ConfigType, SampleList
from ..builder import build_loss
from .decode_head import BaseDecodeHead

View File

@ -8,12 +8,12 @@ from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.transformer import (FFN, TRANSFORMER_LAYER,
MultiheadAttention,
build_transformer_layer)
from mmengine.logging import print_log
from torch import Tensor
from mmseg.core.utils import SampleList
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
from mmseg.utils import get_root_logger
from mmseg.utils import SampleList
@TRANSFORMER_LAYER.register_module()
@ -276,8 +276,7 @@ class KernelUpdateHead(nn.Module):
# the weight and bias of the layer norm
pass
if self.kernel_init:
logger = get_root_logger()
logger.info(
print_log(
'mask kernel in mask head is normal initialized by std 0.01')
nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)

View File

@ -12,9 +12,9 @@ except ModuleNotFoundError:
from typing import List
from mmseg.core.utils import SampleList
from mmseg.ops import resize
from mmseg.registry import MODELS
from mmseg.utils import SampleList
from ..losses import accuracy
from .cascade_decode_head import BaseCascadeDecodeHead

View File

@ -4,9 +4,9 @@ import torch.nn.functional as F
from mmengine.data import PixelData
from torch import Tensor
from mmseg.core.data_structures.seg_data_sample import SegDataSample
from mmseg.core.utils import SampleList
from mmseg.data import SegDataSample
from mmseg.registry import MODELS
from mmseg.utils import SampleList
from .fcn_head import FCNHead

View File

@ -6,10 +6,10 @@ from mmengine.data import PixelData
from mmengine.model import BaseModel
from torch import Tensor
from mmseg.core import SegDataSample
from mmseg.core.utils import (ForwardResults, OptConfigType, OptMultiConfig,
OptSampleList, SampleList)
from mmseg.data import SegDataSample
from mmseg.ops import resize
from mmseg.utils import (ForwardResults, OptConfigType, OptMultiConfig,
OptSampleList, SampleList)
class BaseSegmentor(BaseModel, metaclass=ABCMeta):

View File

@ -3,10 +3,9 @@ from typing import List, Optional
from torch import Tensor, nn
from mmseg.core import add_prefix
from mmseg.core.utils import ConfigType, OptSampleList, SampleList
from mmseg.core.utils.typing import OptConfigType, OptMultiConfig
from mmseg.registry import MODELS
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
OptSampleList, SampleList, add_prefix)
from .encoder_decoder import EncoderDecoder

View File

@ -6,10 +6,9 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from mmseg.core import add_prefix
from mmseg.core.utils import (ConfigType, OptConfigType, OptMultiConfig,
OptSampleList, SampleList)
from mmseg.registry import MODELS
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
OptSampleList, SampleList, add_prefix)
from .base import BaseSegmentor

View File

@ -1,10 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .class_names import (ade_classes, ade_palette, cityscapes_classes,
cityscapes_palette, cocostuff_classes,
cocostuff_palette, dataset_aliases, get_classes,
get_palette, isaid_classes, isaid_palette,
loveda_classes, loveda_palette, potsdam_classes,
potsdam_palette, stare_classes, stare_palette,
vaihingen_classes, vaihingen_palette, voc_classes,
voc_palette)
# yapf: enable
from .collect_env import collect_env
from .logger import get_root_logger
from .misc import find_latest_checkpoint
from .set_env import register_all_modules, setup_multi_processes
from .misc import add_prefix, stack_batch
from .set_env import register_all_modules
from .typing import (ConfigType, ForwardResults, MultiConfig, OptConfigType,
OptMultiConfig, OptSampleList, SampleList, TensorDict,
TensorList)
__all__ = [
'get_root_logger', 'collect_env', 'find_latest_checkpoint',
'setup_multi_processes', 'register_all_modules'
'collect_env', 'register_all_modules', 'stack_batch', 'add_prefix',
'ConfigType', 'OptConfigType', 'MultiConfig', 'OptMultiConfig',
'SampleList', 'OptSampleList', 'TensorDict', 'TensorList',
'ForwardResults', 'cityscapes_classes', 'ade_classes', 'voc_classes',
'cocostuff_classes', 'loveda_classes', 'potsdam_classes',
'vaihingen_classes', 'isaid_classes', 'stare_classes',
'cityscapes_palette', 'ade_palette', 'voc_palette', 'cocostuff_palette',
'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette',
'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette'
]

View File

@ -1,28 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from mmcv.utils import get_logger
def get_root_logger(log_file=None, log_level=logging.INFO):
"""Get the root logger.
The logger will be initialized if it has not been initialized. By default a
StreamHandler will be added. If `log_file` is specified, a FileHandler will
also be added. The name of the root logger is the top-level package name,
e.g., "mmseg".
Args:
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the root logger.
log_level (int): The root logger level. Note that only the process of
rank 0 is affected, while other processes will set the level to
"Error" and be silent most of the time.
Returns:
logging.Logger: The root logger.
"""
logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level)
return logger

View File

@ -1,41 +1,108 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
import warnings
from typing import List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from .typing import SampleList
def find_latest_checkpoint(path, suffix='pth'):
"""This function is for finding the latest checkpoint.
It will be used when automatically resume, modified from
https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py
def add_prefix(inputs, prefix):
"""Add prefix for dict.
Args:
path (str): The path to find checkpoints.
suffix (str): File extension for the checkpoint. Defaults to pth.
inputs (dict): The input dict with str keys.
prefix (str): The prefix to add.
Returns:
latest_path(str | None): File path of the latest checkpoint.
"""
if not osp.exists(path):
warnings.warn("The path of the checkpoints doesn't exist.")
return None
if osp.exists(osp.join(path, f'latest.{suffix}')):
return osp.join(path, f'latest.{suffix}')
checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
if len(checkpoints) == 0:
warnings.warn('The are no checkpoints in the path')
return None
latest = -1
latest_path = ''
for checkpoint in checkpoints:
if len(checkpoint) < len(latest_path):
continue
# `count` is iteration number, as checkpoints are saved as
# 'iter_xx.pth' or 'epoch_xx.pth' and xx is iteration number.
count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
if count > latest:
latest = count
latest_path = checkpoint
return latest_path
dict: The dict with keys updated with ``prefix``.
"""
outputs = dict()
for name, value in inputs.items():
outputs[f'{prefix}.{name}'] = value
return outputs
def stack_batch(inputs: List[torch.Tensor],
batch_data_samples: Optional[SampleList] = None,
size: Optional[tuple] = None,
size_divisor: Optional[int] = None,
pad_val: Union[int, float] = 0,
seg_pad_val: Union[int, float] = 255) -> torch.Tensor:
"""Stack multiple inputs to form a batch and pad the images and gt_sem_segs
to the max shape use the right bottom padding mode.
Args:
inputs (List[Tensor]): The input multiple tensors. each is a
CHW 3D-tensor.
batch_data_samples (list[:obj:`SegDataSample`]): The Data
Samples. It usually includes information such as `gt_sem_seg`.
size (tuple, optional): Fixed padding size.
size_divisor (int, optional): The divisor of padded size.
pad_val (int, float): The padding value. Defaults to 0
seg_pad_val (int, float): The padding value. Defaults to 255
Returns:
Tensor: The 4D-tensor.
batch_data_samples (list[:obj:`SegDataSample`]): After the padding of
the gt_seg_map.
"""
assert isinstance(inputs, list), \
f'Expected input type to be list, but got {type(inputs)}'
assert len(set([tensor.ndim for tensor in inputs])) == 1, \
f'Expected the dimensions of all inputs must be the same, ' \
f'but got {[tensor.ndim for tensor in inputs]}'
assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
f'but got {inputs[0].ndim}'
assert len(set([tensor.shape[0] for tensor in inputs])) == 1, \
f'Expected the channels of all inputs must be the same, ' \
f'but got {[tensor.shape[0] for tensor in inputs]}'
# only one of size and size_divisor should be valid
assert (size is not None) ^ (size_divisor is not None), \
'only one of size and size_divisor should be valid'
padded_inputs = []
padded_samples = []
inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs]
max_size = np.stack(inputs_sizes).max(0)
if size_divisor is not None and size_divisor > 1:
# the last two dims are H,W, both subject to divisibility requirement
max_size = (max_size +
(size_divisor - 1)) // size_divisor * size_divisor
for i in range(len(inputs)):
tensor = inputs[i]
if size is not None:
width = max(size[-1] - tensor.shape[-1], 0)
height = max(size[-2] - tensor.shape[-2], 0)
# (padding_left, padding_right, padding_top, padding_bottom)
padding_size = (0, width, 0, height)
elif size_divisor is not None:
width = max(max_size[-1] - tensor.shape[-1], 0)
height = max(max_size[-2] - tensor.shape[-2], 0)
padding_size = (0, width, 0, height)
else:
padding_size = [0, 0, 0, 0]
# pad img
pad_img = F.pad(tensor, padding_size, value=pad_val)
padded_inputs.append(pad_img)
# pad gt_sem_seg
if batch_data_samples is not None:
data_sample = batch_data_samples[i]
gt_sem_seg = data_sample.gt_sem_seg.data
del data_sample.gt_sem_seg.data
data_sample.gt_sem_seg.data = F.pad(
gt_sem_seg, padding_size, value=seg_pad_val)
data_sample.set_metainfo(
{'pad_shape': data_sample.gt_sem_seg.shape})
padded_samples.append(data_sample)
else:
padded_samples = None
return torch.stack(padded_inputs, dim=0), padded_samples

View File

@ -1,62 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import os
import platform
import warnings
import cv2
import torch.multiprocessing as mp
from mmengine import DefaultScope
from ..utils import get_root_logger
def setup_multi_processes(cfg):
"""Setup multi-processing environment variables."""
logger = get_root_logger()
# set multi-process start method
if platform.system() != 'Windows':
mp_start_method = cfg.get('mp_start_method', None)
current_method = mp.get_start_method(allow_none=True)
if mp_start_method in ('fork', 'spawn', 'forkserver'):
logger.info(
f'Multi-processing start method `{mp_start_method}` is '
f'different from the previous setting `{current_method}`.'
f'It will be force set to `{mp_start_method}`.')
mp.set_start_method(mp_start_method, force=True)
else:
logger.info(
f'Multi-processing start method is `{mp_start_method}`')
# disable opencv multithreading to avoid system being overloaded
opencv_num_threads = cfg.get('opencv_num_threads', None)
if isinstance(opencv_num_threads, int):
logger.info(f'OpenCV num_threads is `{opencv_num_threads}`')
cv2.setNumThreads(opencv_num_threads)
else:
logger.info(f'OpenCV num_threads is `{cv2.getNumThreads}')
if cfg.data.workers_per_gpu > 1:
# setup OMP threads
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
omp_num_threads = cfg.get('omp_num_threads', None)
if 'OMP_NUM_THREADS' not in os.environ:
if isinstance(omp_num_threads, int):
logger.info(f'OMP num threads is {omp_num_threads}')
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
else:
logger.info(f'OMP num threads is {os.environ["OMP_NUM_THREADS"] }')
# setup MKL threads
if 'MKL_NUM_THREADS' not in os.environ:
mkl_num_threads = cfg.get('mkl_num_threads', None)
if isinstance(mkl_num_threads, int):
logger.info(f'MKL num threads is {mkl_num_threads}')
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
else:
logger.info(f'MKL num threads is {os.environ["MKL_NUM_THREADS"]}')
def register_all_modules(init_default_scope: bool = True) -> None:
"""Register all modules in mmseg into the registries.
@ -69,9 +16,9 @@ def register_all_modules(init_default_scope: bool = True) -> None:
to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
Defaults to True.
""" # noqa
import mmseg.core # noqa: F401,F403
import mmseg.data # noqa: F401,F403
import mmseg.datasets # noqa: F401,F403
import mmseg.datasets.pipelines # noqa: F401,F403
import mmseg.engine # noqa: F401,F403
import mmseg.metrics # noqa: F401,F403
import mmseg.models # noqa: F401,F403

View File

@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
from mmengine.config import ConfigDict
from ..data_structures import SegDataSample
from mmseg.data import SegDataSample
# Type hint of config data
ConfigType = Union[ConfigDict, dict]

View File

@ -1,7 +0,0 @@
[pytest]
addopts = --xdoctest --xdoctest-style=auto
norecursedirs = .git ignore build __pycache__ data docker docs .eggs
filterwarnings= default
ignore:.*No cfgstr given in Cacher constructor or call.*:Warning
ignore:.*Define the __nice__ method for.*:Warning

View File

@ -60,72 +60,72 @@ def test_config_build_segmentor():
_check_decode_head(head_config, segmentor.decode_head)
def test_config_data_pipeline():
"""Test whether the data pipeline is valid and can process corner cases.
# def test_config_data_pipeline():
# """Test whether the data pipeline is valid and can process corner cases.
CommandLine:
xdoctest -m tests/test_config.py test_config_build_data_pipeline
"""
import numpy as np
from mmcv import Config
# CommandLine:
# xdoctest -m tests/test_config.py test_config_build_data_pipeline
# """
# import numpy as np
# from mmcv import Config
from mmseg.datasets.pipelines import Compose
# from mmseg.datasets.transforms import Compose
config_dpath = _get_config_directory()
print('Found config_dpath = {!r}'.format(config_dpath))
# config_dpath = _get_config_directory()
# print('Found config_dpath = {!r}'.format(config_dpath))
import glob
config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
config_names = [relpath(p, config_dpath) for p in config_fpaths]
# import glob
# config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
# config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
# config_names = [relpath(p, config_dpath) for p in config_fpaths]
print('Using {} config files'.format(len(config_names)))
# print('Using {} config files'.format(len(config_names)))
for config_fname in config_names:
config_fpath = join(config_dpath, config_fname)
print(
'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
config_mod = Config.fromfile(config_fpath)
# for config_fname in config_names:
# config_fpath = join(config_dpath, config_fname)
# print(
# 'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
# config_mod = Config.fromfile(config_fpath)
# remove loading pipeline
load_img_pipeline = config_mod.train_pipeline.pop(0)
to_float32 = load_img_pipeline.get('to_float32', False)
config_mod.train_pipeline.pop(0)
config_mod.test_pipeline.pop(0)
# remove loading annotation in test pipeline
config_mod.test_pipeline.pop(1)
# # remove loading pipeline
# load_img_pipeline = config_mod.train_pipeline.pop(0)
# to_float32 = load_img_pipeline.get('to_float32', False)
# config_mod.train_pipeline.pop(0)
# config_mod.test_pipeline.pop(0)
# # remove loading annotation in test pipeline
# config_mod.test_pipeline.pop(1)
train_pipeline = Compose(config_mod.train_pipeline)
test_pipeline = Compose(config_mod.test_pipeline)
# train_pipeline = Compose(config_mod.train_pipeline)
# test_pipeline = Compose(config_mod.test_pipeline)
img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
if to_float32:
img = img.astype(np.float32)
seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
# img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
# if to_float32:
# img = img.astype(np.float32)
# seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
results = dict(
filename='test_img.png',
ori_filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
gt_seg_map=seg)
results['seg_fields'] = ['gt_seg_map']
# results = dict(
# filename='test_img.png',
# ori_filename='test_img.png',
# img=img,
# img_shape=img.shape,
# ori_shape=img.shape,
# gt_seg_map=seg)
# results['seg_fields'] = ['gt_seg_map']
print('Test training data pipeline: \n{!r}'.format(train_pipeline))
output_results = train_pipeline(results)
assert output_results is not None
# print('Test training data pipeline: \n{!r}'.format(train_pipeline))
# output_results = train_pipeline(results)
# assert output_results is not None
results = dict(
filename='test_img.png',
ori_filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
)
print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
output_results = test_pipeline(results)
assert output_results is not None
# results = dict(
# filename='test_img.png',
# ori_filename='test_img.png',
# img=img,
# img_shape=img.shape,
# ori_shape=img.shape,
# )
# print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
# output_results = test_pipeline(results)
# assert output_results is not None
def _check_decode_head(decode_head_cfg, decode_head):

View File

@ -6,7 +6,7 @@ import pytest
import torch
from mmengine.data import PixelData
from mmseg.core import SegDataSample
from mmseg.data import SegDataSample
def _equal(a, b):

View File

@ -6,11 +6,11 @@ from unittest.mock import MagicMock
import pytest
from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
COCOStuffDataset, CustomDataset, ISPRSDataset,
LoveDADataset, PascalVOCDataset, PotsdamDataset,
iSAIDDataset)
from mmseg.utils import get_classes, get_palette
def test_classes():

View File

@ -6,8 +6,8 @@ import unittest
import numpy as np
from mmengine.data import BaseDataElement
from mmseg.core import SegDataSample
from mmseg.datasets.pipelines import PackSegInputs
from mmseg.data import SegDataSample
from mmseg.datasets.transforms import PackSegInputs
class TestPackSegInputs(unittest.TestCase):

View File

@ -7,7 +7,7 @@ import mmcv
import numpy as np
from mmcv.transforms import LoadImageFromFile
from mmseg.datasets.pipelines import LoadAnnotations
from mmseg.datasets.transforms import LoadAnnotations
class TestLoading(object):

View File

@ -7,7 +7,7 @@ import numpy as np
import pytest
from PIL import Image
from mmseg.datasets.pipelines import PhotoMetricDistortion, RandomCrop
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
from mmseg.registry import TRANSFORMS

View File

@ -4,7 +4,7 @@ import os.path as osp
import mmcv
import pytest
from mmseg.datasets.pipelines import * # noqa
from mmseg.datasets.transforms import * # noqa
from mmseg.registry import TRANSFORMS

View File

@ -4,10 +4,13 @@ import pytest
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.optim.optimizer import build_optim_wrapper
from mmseg.core.builder import build_optimizer
from mmseg.core.optimizers.layer_decay_optimizer_constructor import \
from mmseg.engine.optimizers.layer_decay_optimizer_constructor import \
LearningRateDecayOptimizerConstructor
from mmseg.utils import register_all_modules
register_all_modules()
base_lr = 1
decay_rate = 2
@ -221,7 +224,7 @@ def test_learning_rate_decay_optimizer_constructor():
optimizer=optimizer_cfg,
paramwise_cfg=stagewise_paramwise_cfg,
constructor='LearningRateDecayOptimizerConstructor')
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
check_optimizer_lr_wd(optim_wrapper.optimizer,
expected_stage_wise_lr_wd_convnext)
# layerwise decay
@ -232,7 +235,7 @@ def test_learning_rate_decay_optimizer_constructor():
optimizer=optimizer_cfg,
paramwise_cfg=layerwise_paramwise_cfg,
constructor='LearningRateDecayOptimizerConstructor')
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
check_optimizer_lr_wd(optim_wrapper.optimizer,
expected_layer_wise_lr_wd_convnext)
@ -247,7 +250,7 @@ def test_learning_rate_decay_optimizer_constructor():
optimizer=optimizer_cfg,
paramwise_cfg=layerwise_paramwise_cfg,
constructor='LearningRateDecayOptimizerConstructor')
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
check_optimizer_lr_wd(optim_wrapper.optimizer,
expected_layer_wise_wd_lr_beit)
@ -274,7 +277,7 @@ def test_learning_rate_decay_optimizer_constructor():
optimizer=optimizer_cfg,
paramwise_cfg=layerwise_paramwise_cfg,
constructor='LearningRateDecayOptimizerConstructor')
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
check_optimizer_lr_wd(optim_wrapper.optimizer,
expected_layer_wise_wd_lr_beit)
@ -291,7 +294,7 @@ def test_beit_layer_decay_optimizer_constructor():
paramwise_cfg=paramwise_cfg,
optimizer=dict(
type='AdamW', lr=1, betas=(0.9, 0.999), weight_decay=0.05))
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
# optimizer = optim_wrapper_builder(model)
check_optimizer_lr_wd(optim_wrapper.optimizer,
expected_layer_wise_wd_lr_beit)

View File

@ -1,8 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmseg.core.builder import build_optimizer
from mmengine.optim import build_optim_wrapper
class ExampleModel(nn.Module):
@ -29,6 +28,6 @@ def test_build_optimizer():
type='OptimWrapper',
optimizer=dict(
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum))
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
# test whether optimizer is successfully built from parent.
assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)

View File

@ -5,7 +5,7 @@ import numpy as np
import torch
from mmengine.data import BaseDataElement, PixelData
from mmseg.core import SegDataSample
from mmseg.data import SegDataSample
from mmseg.metrics import CitysMetric

View File

@ -5,7 +5,7 @@ import numpy as np
import torch
from mmengine.data import BaseDataElement, PixelData
from mmseg.core import SegDataSample
from mmseg.data import SegDataSample
from mmseg.metrics import IoUMetric

View File

@ -4,7 +4,7 @@ from unittest import TestCase
import torch
from mmengine.data import PixelData
from mmseg.core import SegDataSample
from mmseg.data import SegDataSample
from mmseg.models import SegDataPreProcessor

View File

@ -13,7 +13,7 @@ from mmcv.cnn.utils import revert_sync_batchnorm
from mmengine.data import PixelData
from torch import Tensor
from mmseg.core import SegDataSample
from mmseg.data import SegDataSample
from mmseg.utils import register_all_modules
register_all_modules()

View File

@ -2,7 +2,7 @@
import pytest
import torch
from mmseg.core import OHEMPixelSampler
from mmseg.data import OHEMPixelSampler
from mmseg.models.decode_heads import FCNHead

View File

@ -1,40 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from mmseg.utils import find_latest_checkpoint
def test_find_latest_checkpoint():
with tempfile.TemporaryDirectory() as tempdir:
# no checkpoints in the path
path = tempdir
latest = find_latest_checkpoint(path)
assert latest is None
# The path doesn't exist
path = osp.join(tempdir, 'none')
latest = find_latest_checkpoint(path)
assert latest is None
# test when latest.pth exists
with tempfile.TemporaryDirectory() as tempdir:
with open(osp.join(tempdir, 'latest.pth'), 'w') as f:
f.write('latest')
path = tempdir
latest = find_latest_checkpoint(path)
assert latest == osp.join(tempdir, 'latest.pth')
with tempfile.TemporaryDirectory() as tempdir:
for iter in range(1600, 160001, 1600):
with open(osp.join(tempdir, f'iter_{iter}.pth'), 'w') as f:
f.write(f'iter_{iter}.pth')
latest = find_latest_checkpoint(tempdir)
assert latest == osp.join(tempdir, 'iter_160000.pth')
with tempfile.TemporaryDirectory() as tempdir:
for epoch in range(1, 21):
with open(osp.join(tempdir, f'epoch_{epoch}.pth'), 'w') as f:
f.write(f'epoch_{epoch}.pth')
latest = find_latest_checkpoint(tempdir)
assert latest == osp.join(tempdir, 'epoch_20.pth')

View File

@ -1,92 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import multiprocessing as mp
import os
import platform
import sys
from unittest import TestCase
import cv2
import pytest
from mmcv import Config
from mmengine import DefaultScope
from mmseg.utils import register_all_modules, setup_multi_processes
@pytest.mark.parametrize('workers_per_gpu', (0, 2))
@pytest.mark.parametrize(('valid', 'env_cfg'), [(True,
dict(
mp_start_method='fork',
opencv_num_threads=0,
omp_num_threads=1,
mkl_num_threads=1)),
(False,
dict(
mp_start_method=1,
opencv_num_threads=0.1,
omp_num_threads='s',
mkl_num_threads='1'))])
def test_setup_multi_processes(workers_per_gpu, valid, env_cfg):
# temp save system setting
sys_start_mehod = mp.get_start_method(allow_none=True)
sys_cv_threads = cv2.getNumThreads()
# pop and temp save system env vars
sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None)
sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None)
config = dict(data=dict(workers_per_gpu=workers_per_gpu))
config.update(env_cfg)
cfg = Config(config)
setup_multi_processes(cfg)
# test when cfg is valid and workers_per_gpu > 0
# setup_multi_processes will work
if valid and workers_per_gpu > 0:
# test config without setting env
assert os.getenv('OMP_NUM_THREADS') == str(env_cfg['omp_num_threads'])
assert os.getenv('MKL_NUM_THREADS') == str(env_cfg['mkl_num_threads'])
# when set to 0, the num threads will be 1
assert cv2.getNumThreads() == env_cfg[
'opencv_num_threads'] if env_cfg['opencv_num_threads'] > 0 else 1
if platform.system() != 'Windows':
assert mp.get_start_method() == env_cfg['mp_start_method']
# revert setting to avoid affecting other programs
if sys_start_mehod:
mp.set_start_method(sys_start_mehod, force=True)
cv2.setNumThreads(sys_cv_threads)
if sys_omp_threads:
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
else:
os.environ.pop('OMP_NUM_THREADS')
if sys_mkl_threads:
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
else:
os.environ.pop('MKL_NUM_THREADS')
elif valid and workers_per_gpu == 0:
if platform.system() != 'Windows':
assert mp.get_start_method() == env_cfg['mp_start_method']
assert cv2.getNumThreads() == env_cfg[
'opencv_num_threads'] if env_cfg['opencv_num_threads'] > 0 else 1
assert 'OMP_NUM_THREADS' not in os.environ
assert 'MKL_NUM_THREADS' not in os.environ
if sys_start_mehod:
mp.set_start_method(sys_start_mehod, force=True)
cv2.setNumThreads(sys_cv_threads)
if sys_omp_threads:
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
if sys_mkl_threads:
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
else:
assert mp.get_start_method() == sys_start_mehod
assert cv2.getNumThreads() == sys_cv_threads
assert 'OMP_NUM_THREADS' not in os.environ
assert 'MKL_NUM_THREADS' not in os.environ
from mmseg.utils import register_all_modules
class TestSetupEnv(TestCase):

View File

@ -1,340 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import shutil
import warnings
from typing import Any, Iterable
import mmcv
import numpy as np
import torch
from mmcv.parallel import MMDataParallel
from mmcv.runner import get_dist_info
from mmcv.utils import DictAction
# from mmseg.apis import single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models.segmentors.base import BaseSegmentor
from mmseg.ops import resize
single_gpu_test = None
class ONNXRuntimeSegmentor(BaseSegmentor):
def __init__(self, onnx_file: str, cfg: Any, device_id: int):
super(ONNXRuntimeSegmentor, self).__init__()
import onnxruntime as ort
# get the custom op path
ort_custom_op_path = ''
try:
from mmcv.ops import get_onnxruntime_op_path
ort_custom_op_path = get_onnxruntime_op_path()
except (ImportError, ModuleNotFoundError):
warnings.warn('If input model has custom op from mmcv, \
you may have to build mmcv with ONNXRuntime from source.')
session_options = ort.SessionOptions()
# register custom op for onnxruntime
if osp.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path)
sess = ort.InferenceSession(onnx_file, session_options)
providers = ['CPUExecutionProvider']
options = [{}]
is_cuda_available = ort.get_device() == 'GPU'
if is_cuda_available:
providers.insert(0, 'CUDAExecutionProvider')
options.insert(0, {'device_id': device_id})
sess.set_providers(providers, options)
self.sess = sess
self.device_id = device_id
self.io_binding = sess.io_binding()
self.output_names = [_.name for _ in sess.get_outputs()]
for name in self.output_names:
self.io_binding.bind_output(name)
self.cfg = cfg
self.test_mode = cfg.model.test_cfg.mode
self.is_cuda_available = is_cuda_available
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')
def encode_decode(self, img, img_metas):
raise NotImplementedError('This method is not implemented.')
def forward_train(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
**kwargs) -> list:
if not self.is_cuda_available:
img = img.detach().cpu()
elif self.device_id >= 0:
img = img.cuda(self.device_id)
device_type = img.device.type
self.io_binding.bind_input(
name='input',
device_type=device_type,
device_id=self.device_id,
element_type=np.float32,
shape=img.shape,
buffer_ptr=img.data_ptr())
self.sess.run_with_iobinding(self.io_binding)
seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
# whole might support dynamic reshape
ori_shape = img_meta[0]['ori_shape']
if not (ori_shape[0] == seg_pred.shape[-2]
and ori_shape[1] == seg_pred.shape[-1]):
seg_pred = torch.from_numpy(seg_pred).float()
seg_pred = resize(
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
seg_pred = seg_pred.long().detach().cpu().numpy()
seg_pred = seg_pred[0]
seg_pred = list(seg_pred)
return seg_pred
def aug_test(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
class TensorRTSegmentor(BaseSegmentor):
def __init__(self, trt_file: str, cfg: Any, device_id: int):
super(TensorRTSegmentor, self).__init__()
from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
try:
load_tensorrt_plugin()
except (ImportError, ModuleNotFoundError):
warnings.warn('If input model has custom op from mmcv, \
you may have to build mmcv with TensorRT from source.')
model = TRTWraper(
trt_file, input_names=['input'], output_names=['output'])
self.model = model
self.device_id = device_id
self.cfg = cfg
self.test_mode = cfg.model.test_cfg.mode
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')
def encode_decode(self, img, img_metas):
raise NotImplementedError('This method is not implemented.')
def forward_train(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
**kwargs) -> list:
with torch.cuda.device(self.device_id), torch.no_grad():
seg_pred = self.model({'input': img})['output']
seg_pred = seg_pred.detach().cpu().numpy()
# whole might support dynamic reshape
ori_shape = img_meta[0]['ori_shape']
if not (ori_shape[0] == seg_pred.shape[-2]
and ori_shape[1] == seg_pred.shape[-1]):
seg_pred = torch.from_numpy(seg_pred).float()
seg_pred = resize(
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
seg_pred = seg_pred.long().detach().cpu().numpy()
seg_pred = seg_pred[0]
seg_pred = list(seg_pred)
return seg_pred
def aug_test(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description='mmseg backend test (and eval)')
parser.add_argument('config', help='test config file path')
parser.add_argument('model', help='Input model file')
parser.add_argument(
'--backend',
help='Backend of the model.',
choices=['onnxruntime', 'tensorrt'])
parser.add_argument('--out', help='output result file in pickle format')
parser.add_argument(
'--format-only',
action='store_true',
help='Format the output results without perform evaluation. It is'
'useful when you want to format the result to a specific format and '
'submit it to the test server')
parser.add_argument(
'--eval',
type=str,
nargs='+',
help='evaluation metrics, which depends on the dataset, e.g., "mIoU"'
' for generic datasets, and "cityscapes" for Cityscapes')
parser.add_argument('--show', action='store_true', help='show results')
parser.add_argument(
'--show-dir', help='directory where painted images will be saved')
parser.add_argument(
'--options',
nargs='+',
action=DictAction,
help="--options is deprecated in favor of --cfg_options' and it will "
'not be supported in version v0.22.0. Override some settings in the '
'used config, the key-value pair in xxx=yyy format will be merged '
'into config file. If the value to be overwritten is a list, it '
'should be like key="[a,b]" or key=a,b It also allows nested '
'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
'marks are necessary and that no white space is allowed.')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--eval-options',
nargs='+',
action=DictAction,
help='custom options for evaluation')
parser.add_argument(
'--opacity',
type=float,
default=0.5,
help='Opacity of painted segmentation map. In (0, 1] range.')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
if args.options and args.cfg_options:
raise ValueError(
'--options and --cfg-options cannot be both '
'specified, --options is deprecated in favor of --cfg-options. '
'--options will not be supported in version v0.22.0.')
if args.options:
warnings.warn('--options is deprecated in favor of --cfg-options. '
'--options will not be supported in version v0.22.0.')
args.cfg_options = args.options
return args
def main():
args = parse_args()
assert args.out or args.eval or args.format_only or args.show \
or args.show_dir, \
('Please specify at least one operation (save/eval/format/show the '
'results / save the results) with the argument "--out", "--eval"'
', "--format-only", "--show" or "--show-dir"')
if args.eval and args.format_only:
raise ValueError('--eval and --format_only cannot be both specified')
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
raise ValueError('The output file must be a pkl file.')
cfg = mmcv.Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
cfg.model.pretrained = None
cfg.data.test.test_mode = True
# init distributed env first, since logger depends on the dist info.
distributed = False
# build the dataloader
# TODO: support multiple images per gpu (only minor changes are needed)
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
dataset,
samples_per_gpu=1,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
# load onnx config and meta
cfg.model.train_cfg = None
if args.backend == 'onnxruntime':
model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0)
elif args.backend == 'tensorrt':
model = TensorRTSegmentor(args.model, cfg=cfg, device_id=0)
model.CLASSES = dataset.CLASSES
model.PALETTE = dataset.PALETTE
# clean gpu memory when starting a new evaluation.
torch.cuda.empty_cache()
eval_kwargs = {} if args.eval_options is None else args.eval_options
# Deprecated
efficient_test = eval_kwargs.get('efficient_test', False)
if efficient_test:
warnings.warn(
'``efficient_test=True`` does not have effect in tools/test.py, '
'the evaluation and format results are CPU memory efficient by '
'default')
eval_on_format_results = (
args.eval is not None and 'cityscapes' in args.eval)
if eval_on_format_results:
assert len(args.eval) == 1, 'eval on format results is not ' \
'applicable for metrics other than ' \
'cityscapes'
if args.format_only or eval_on_format_results:
if 'imgfile_prefix' in eval_kwargs:
tmpdir = eval_kwargs['imgfile_prefix']
else:
tmpdir = '.format_cityscapes'
eval_kwargs.setdefault('imgfile_prefix', tmpdir)
mmcv.mkdir_or_exist(tmpdir)
else:
tmpdir = None
model = MMDataParallel(model, device_ids=[0])
results = single_gpu_test(
model,
data_loader,
args.show,
args.show_dir,
False,
args.opacity,
pre_eval=args.eval is not None and not eval_on_format_results,
format_only=args.format_only or eval_on_format_results,
format_args=eval_kwargs)
rank, _ = get_dist_info()
if rank == 0:
if args.out:
warnings.warn(
'The behavior of ``args.out`` has been changed since MMSeg '
'v0.16, the pickled outputs could be seg map as type of '
'np.array, pre-eval results or file paths for '
'``dataset.format_results()``.')
print(f'\nwriting results to {args.out}')
mmcv.dump(results, args.out)
if args.eval:
dataset.evaluate(results, args.eval, **eval_kwargs)
if tmpdir is not None and eval_on_format_results:
# remove tmp dir when cityscapes evaluation
shutil.rmtree(tmpdir)
if __name__ == '__main__':
main()
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This tool will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)

View File

@ -1,289 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import warnings
from typing import Iterable, Optional, Union
import matplotlib.pyplot as plt
import mmcv
import numpy as np
import onnxruntime as ort
import torch
from mmcv.ops import get_onnxruntime_op_path
from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
save_trt_engine)
from mmseg.apis.inference import LoadImage
from mmseg.datasets import DATASETS
from mmseg.datasets.pipelines import Compose
def get_GiB(x: int):
"""return x GiB."""
return x * (1 << 30)
def _prepare_input_img(img_path: str,
test_pipeline: Iterable[dict],
shape: Optional[Iterable] = None,
rescale_shape: Optional[Iterable] = None) -> dict:
# build the data pipeline
if shape is not None:
test_pipeline[1]['img_scale'] = (shape[1], shape[0])
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
test_pipeline = [LoadImage()] + test_pipeline[1:]
test_pipeline = Compose(test_pipeline)
# prepare data
data = dict(img=img_path)
data = test_pipeline(data)
imgs = data['img']
img_metas = [i.data for i in data['img_metas']]
if rescale_shape is not None:
for img_meta in img_metas:
img_meta['ori_shape'] = tuple(rescale_shape) + (3, )
mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
return mm_inputs
def _update_input_img(img_list: Iterable, img_meta_list: Iterable):
# update img and its meta list
N = img_list[0].size(0)
img_meta = img_meta_list[0][0]
img_shape = img_meta['img_shape']
ori_shape = img_meta['ori_shape']
pad_shape = img_meta['pad_shape']
new_img_meta_list = [[{
'img_shape':
img_shape,
'ori_shape':
ori_shape,
'pad_shape':
pad_shape,
'filename':
img_meta['filename'],
'scale_factor':
(img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
'flip':
False,
} for _ in range(N)]]
return img_list, new_img_meta_list
def show_result_pyplot(img: Union[str, np.ndarray],
result: np.ndarray,
palette: Optional[Iterable] = None,
fig_size: Iterable[int] = (15, 10),
opacity: float = 0.5,
title: str = '',
block: bool = True):
img = mmcv.imread(img)
img = img.copy()
seg = result[0]
seg = mmcv.imresize(seg, img.shape[:2][::-1])
palette = np.array(palette)
assert palette.shape[1] == 3
assert len(palette.shape) == 2
assert 0 < opacity <= 1.0
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for label, color in enumerate(palette):
color_seg[seg == label, :] = color
# convert to BGR
color_seg = color_seg[..., ::-1]
img = img * (1 - opacity) + color_seg * opacity
img = img.astype(np.uint8)
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
plt.title(title)
plt.tight_layout()
plt.show(block=block)
def onnx2tensorrt(onnx_file: str,
trt_file: str,
config: dict,
input_config: dict,
fp16: bool = False,
verify: bool = False,
show: bool = False,
dataset: str = 'CityscapesDataset',
workspace_size: int = 1,
verbose: bool = False):
import tensorrt as trt
min_shape = input_config['min_shape']
max_shape = input_config['max_shape']
# create trt engine and wrapper
opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
max_workspace_size = get_GiB(workspace_size)
trt_engine = onnx2trt(
onnx_file,
opt_shape_dict,
log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
fp16_mode=fp16,
max_workspace_size=max_workspace_size)
save_dir, _ = osp.split(trt_file)
if save_dir:
os.makedirs(save_dir, exist_ok=True)
save_trt_engine(trt_engine, trt_file)
print(f'Successfully created TensorRT engine: {trt_file}')
if verify:
inputs = _prepare_input_img(
input_config['input_path'],
config.data.test.pipeline,
shape=min_shape[2:])
imgs = inputs['imgs']
img_metas = inputs['img_metas']
img_list = [img[None, :] for img in imgs]
img_meta_list = [[img_meta] for img_meta in img_metas]
# update img_meta
img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
if max_shape[0] > 1:
# concate flip image for batch test
flip_img_list = [_.flip(-1) for _ in img_list]
img_list = [
torch.cat((ori_img, flip_img), 0)
for ori_img, flip_img in zip(img_list, flip_img_list)
]
# Get results from ONNXRuntime
ort_custom_op_path = get_onnxruntime_op_path()
session_options = ort.SessionOptions()
if osp.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path)
sess = ort.InferenceSession(onnx_file, session_options)
sess.set_providers(['CPUExecutionProvider'], [{}]) # use cpu mode
onnx_output = sess.run(['output'],
{'input': img_list[0].detach().numpy()})[0][0]
# Get results from TensorRT
trt_model = TRTWraper(trt_file, ['input'], ['output'])
with torch.no_grad():
trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()})
trt_output = trt_outputs['output'][0].cpu().detach().numpy()
if show:
dataset = DATASETS.get(dataset)
assert dataset is not None
palette = dataset.PALETTE
show_result_pyplot(
input_config['input_path'],
(onnx_output[0].astype(np.uint8), ),
palette=palette,
title='ONNXRuntime',
block=False)
show_result_pyplot(
input_config['input_path'], (trt_output[0].astype(np.uint8), ),
palette=palette,
title='TensorRT')
np.testing.assert_allclose(
onnx_output, trt_output, rtol=1e-03, atol=1e-05)
print('TensorRT and ONNXRuntime output all close.')
def parse_args():
parser = argparse.ArgumentParser(
description='Convert MMSegmentation models from ONNX to TensorRT')
parser.add_argument('config', help='Config file of the model')
parser.add_argument('model', help='Path to the input ONNX model')
parser.add_argument(
'--trt-file', type=str, help='Path to the output TensorRT engine')
parser.add_argument(
'--max-shape',
type=int,
nargs=4,
default=[1, 3, 400, 600],
help='Maximum shape of model input.')
parser.add_argument(
'--min-shape',
type=int,
nargs=4,
default=[1, 3, 400, 600],
help='Minimum shape of model input.')
parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
parser.add_argument(
'--workspace-size',
type=int,
default=1,
help='Max workspace size in GiB')
parser.add_argument(
'--input-img', type=str, default='', help='Image for test')
parser.add_argument(
'--show', action='store_true', help='Whether to show output results')
parser.add_argument(
'--dataset',
type=str,
default='CityscapesDataset',
help='Dataset name')
parser.add_argument(
'--verify',
action='store_true',
help='Verify the outputs of ONNXRuntime and TensorRT')
parser.add_argument(
'--verbose',
action='store_true',
help='Whether to verbose logging messages while creating \
TensorRT engine.')
args = parser.parse_args()
return args
if __name__ == '__main__':
assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
args = parse_args()
if not args.input_img:
args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png')
# check arguments
assert osp.exists(args.config), 'Config {} not found.'.format(args.config)
assert osp.exists(args.model), \
'ONNX model {} not found.'.format(args.model)
assert args.workspace_size >= 0, 'Workspace size less than 0.'
assert DATASETS.get(args.dataset) is not None, \
'Dataset {} does not found.'.format(args.dataset)
for max_value, min_value in zip(args.max_shape, args.min_shape):
assert max_value >= min_value, \
'max_shape should be larger than min shape'
input_config = {
'min_shape': args.min_shape,
'max_shape': args.max_shape,
'input_path': args.input_img
}
cfg = mmcv.Config.fromfile(args.config)
onnx2tensorrt(
args.model,
args.trt_file,
cfg,
input_config,
fp16=args.fp16,
verify=args.verify,
show=args.show,
dataset=args.dataset,
workspace_size=args.workspace_size,
verbose=args.verbose)
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This tool will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)

View File

@ -1,405 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import warnings
from functools import partial
import mmcv
import numpy as np
import onnxruntime as rt
import torch
import torch._C
import torch.serialization
from mmcv import DictAction
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint
from torch import nn
from mmseg.apis import show_result_pyplot
from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor
from mmseg.ops import resize
torch.manual_seed(3)
def _convert_batchnorm(module):
module_output = module
if isinstance(module, torch.nn.SyncBatchNorm):
module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
module.momentum, module.affine,
module.track_running_stats)
if module.affine:
module_output.weight.data = module.weight.data.clone().detach()
module_output.bias.data = module.bias.data.clone().detach()
# keep requires_grad unchanged
module_output.weight.requires_grad = module.weight.requires_grad
module_output.bias.requires_grad = module.bias.requires_grad
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
for name, child in module.named_children():
module_output.add_module(name, _convert_batchnorm(child))
del module
return module_output
def _demo_mm_inputs(input_shape, num_classes):
"""Create a superset of inputs needed to run test or train batches.
Args:
input_shape (tuple):
input batch dimensions
num_classes (int):
number of semantic classes
"""
(N, C, H, W) = input_shape
rng = np.random.RandomState(0)
imgs = rng.rand(*input_shape)
segs = rng.randint(
low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
img_metas = [{
'img_shape': (H, W, C),
'ori_shape': (H, W, C),
'pad_shape': (H, W, C),
'filename': '<demo>.png',
'scale_factor': 1.0,
'flip': False,
} for _ in range(N)]
mm_inputs = {
'imgs': torch.FloatTensor(imgs).requires_grad_(True),
'img_metas': img_metas,
'gt_semantic_seg': torch.LongTensor(segs)
}
return mm_inputs
def _prepare_input_img(img_path,
test_pipeline,
shape=None,
rescale_shape=None):
# build the data pipeline
if shape is not None:
test_pipeline[1]['img_scale'] = (shape[1], shape[0])
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
test_pipeline = [LoadImage()] + test_pipeline[1:]
test_pipeline = Compose(test_pipeline)
# prepare data
data = dict(img=img_path)
data = test_pipeline(data)
imgs = data['img']
img_metas = [i.data for i in data['img_metas']]
if rescale_shape is not None:
for img_meta in img_metas:
img_meta['ori_shape'] = tuple(rescale_shape) + (3, )
mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
return mm_inputs
def _update_input_img(img_list, img_meta_list, update_ori_shape=False):
# update img and its meta list
N, C, H, W = img_list[0].shape
img_meta = img_meta_list[0][0]
img_shape = (H, W, C)
if update_ori_shape:
ori_shape = img_shape
else:
ori_shape = img_meta['ori_shape']
pad_shape = img_shape
new_img_meta_list = [[{
'img_shape':
img_shape,
'ori_shape':
ori_shape,
'pad_shape':
pad_shape,
'filename':
img_meta['filename'],
'scale_factor':
(img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
'flip':
False,
} for _ in range(N)]]
return img_list, new_img_meta_list
def pytorch2onnx(model,
mm_inputs,
opset_version=11,
show=False,
output_file='tmp.onnx',
verify=False,
dynamic_export=False):
"""Export Pytorch model to ONNX model and verify the outputs are same
between Pytorch and ONNX.
Args:
model (nn.Module): Pytorch model we want to export.
mm_inputs (dict): Contain the input tensors and img_metas information.
opset_version (int): The onnx op version. Default: 11.
show (bool): Whether print the computation graph. Default: False.
output_file (string): The path to where we store the output ONNX model.
Default: `tmp.onnx`.
verify (bool): Whether compare the outputs between Pytorch and ONNX.
Default: False.
dynamic_export (bool): Whether to export ONNX with dynamic axis.
Default: False.
"""
model.cpu().eval()
test_mode = model.test_cfg.mode
if isinstance(model.decode_head, nn.ModuleList):
num_classes = model.decode_head[-1].num_classes
else:
num_classes = model.decode_head.num_classes
imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas')
img_list = [img[None, :] for img in imgs]
img_meta_list = [[img_meta] for img_meta in img_metas]
# update img_meta
img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
# replace original forward function
origin_forward = model.forward
model.forward = partial(
model.forward,
img_metas=img_meta_list,
return_loss=False,
rescale=True)
dynamic_axes = None
if dynamic_export:
if test_mode == 'slide':
dynamic_axes = {'input': {0: 'batch'}, 'output': {1: 'batch'}}
else:
dynamic_axes = {
'input': {
0: 'batch',
2: 'height',
3: 'width'
},
'output': {
1: 'batch',
2: 'height',
3: 'width'
}
}
register_extra_symbolics(opset_version)
with torch.no_grad():
torch.onnx.export(
model, (img_list, ),
output_file,
input_names=['input'],
output_names=['output'],
export_params=True,
keep_initializers_as_inputs=False,
verbose=show,
opset_version=opset_version,
dynamic_axes=dynamic_axes)
print(f'Successfully exported ONNX model: {output_file}')
model.forward = origin_forward
if verify:
# check by onnx
import onnx
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)
if dynamic_export and test_mode == 'whole':
# scale image for dynamic shape test
img_list = [resize(_, scale_factor=1.5) for _ in img_list]
# concate flip image for batch test
flip_img_list = [_.flip(-1) for _ in img_list]
img_list = [
torch.cat((ori_img, flip_img), 0)
for ori_img, flip_img in zip(img_list, flip_img_list)
]
# update img_meta
img_list, img_meta_list = _update_input_img(
img_list, img_meta_list, test_mode == 'whole')
# check the numerical value
# get pytorch output
with torch.no_grad():
pytorch_result = model(img_list, img_meta_list, return_loss=False)
pytorch_result = np.stack(pytorch_result, 0)
# get onnx output
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [
node.name for node in onnx_model.graph.initializer
]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 1)
sess = rt.InferenceSession(output_file)
onnx_result = sess.run(
None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
# show segmentation results
if show:
import os.path as osp
import cv2
img = img_meta_list[0][0]['filename']
if not osp.exists(img):
img = imgs[0][:3, ...].permute(1, 2, 0) * 255
img = img.detach().numpy().astype(np.uint8)
ori_shape = img.shape[:2]
else:
ori_shape = LoadImage()({'img': img})['ori_shape']
# resize onnx_result to ori_shape
onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
(ori_shape[1], ori_shape[0]))
show_result_pyplot(
model,
img, (onnx_result_, ),
palette=model.PALETTE,
block=False,
title='ONNXRuntime',
opacity=0.5)
# resize pytorch_result to ori_shape
pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
(ori_shape[1], ori_shape[0]))
show_result_pyplot(
model,
img, (pytorch_result_, ),
title='PyTorch',
palette=model.PALETTE,
opacity=0.5)
# compare results
np.testing.assert_allclose(
pytorch_result.astype(np.float32) / num_classes,
onnx_result.astype(np.float32) / num_classes,
rtol=1e-5,
atol=1e-5,
err_msg='The outputs are different between Pytorch and ONNX')
print('The outputs are same between Pytorch and ONNX')
def parse_args():
parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
parser.add_argument('config', help='test config file path')
parser.add_argument('--checkpoint', help='checkpoint file', default=None)
parser.add_argument(
'--input-img', type=str, help='Images for input', default=None)
parser.add_argument(
'--show',
action='store_true',
help='show onnx graph and segmentation results')
parser.add_argument(
'--verify', action='store_true', help='verify the onnx model')
parser.add_argument('--output-file', type=str, default='tmp.onnx')
parser.add_argument('--opset-version', type=int, default=11)
parser.add_argument(
'--shape',
type=int,
nargs='+',
default=None,
help='input image height and width.')
parser.add_argument(
'--rescale_shape',
type=int,
nargs='+',
default=None,
help='output image rescale height and width, work for slide mode.')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='Override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--dynamic-export',
action='store_true',
help='Whether to export onnx with dynamic axis.')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
cfg = mmcv.Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
cfg.model.pretrained = None
if args.shape is None:
img_scale = cfg.test_pipeline[1]['img_scale']
input_shape = (1, 3, img_scale[1], img_scale[0])
elif len(args.shape) == 1:
input_shape = (1, 3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
input_shape = (
1,
3,
) + tuple(args.shape)
else:
raise ValueError('invalid input shape')
test_mode = cfg.model.test_cfg.mode
# build the model and load checkpoint
cfg.model.train_cfg = None
segmentor = build_segmentor(
cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
# convert SyncBN to BN
segmentor = _convert_batchnorm(segmentor)
if args.checkpoint:
checkpoint = load_checkpoint(
segmentor, args.checkpoint, map_location='cpu')
segmentor.CLASSES = checkpoint['meta']['CLASSES']
segmentor.PALETTE = checkpoint['meta']['PALETTE']
# read input or create dummpy input
if args.input_img is not None:
preprocess_shape = (input_shape[2], input_shape[3])
rescale_shape = None
if args.rescale_shape is not None:
rescale_shape = [args.rescale_shape[0], args.rescale_shape[1]]
mm_inputs = _prepare_input_img(
args.input_img,
cfg.data.test.pipeline,
shape=preprocess_shape,
rescale_shape=rescale_shape)
else:
if isinstance(segmentor.decode_head, nn.ModuleList):
num_classes = segmentor.decode_head[-1].num_classes
else:
num_classes = segmentor.decode_head.num_classes
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
# convert model to onnx file
pytorch2onnx(
segmentor,
mm_inputs,
opset_version=args.opset_version,
show=args.show,
output_file=args.output_file,
verify=args.verify,
dynamic_export=args.dynamic_export)
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This tool will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)