mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Merge branch 'zhengmiao/refactory-content' into 'refactor_dev'
[Refactory] MMSegmentation Content See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!69
This commit is contained in:
commit
d0fb6cc833
1
.gitignore
vendored
1
.gitignore
vendored
@ -105,7 +105,6 @@ venv.bak/
|
|||||||
# mypy
|
# mypy
|
||||||
.mypy_cache/
|
.mypy_cache/
|
||||||
|
|
||||||
data
|
|
||||||
.vscode
|
.vscode
|
||||||
.idea
|
.idea
|
||||||
|
|
||||||
|
@ -145,7 +145,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from mmseg.apis import inference_model, init_model, show_result_pyplot\n",
|
"from mmseg.apis import inference_model, init_model, show_result_pyplot\n",
|
||||||
"from mmseg.core.evaluation import get_palette"
|
"from mmseg.utils import get_palette"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
from mmseg.apis import inference_model, init_model, show_result_pyplot
|
from mmseg.apis import inference_model, init_model, show_result_pyplot
|
||||||
from mmseg.core.evaluation import get_palette
|
from mmseg.utils import get_palette
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -21,7 +21,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from mmseg.apis import init_model, inference_model, show_result_pyplot\n",
|
"from mmseg.apis import init_model, inference_model, show_result_pyplot\n",
|
||||||
"from mmseg.core.evaluation import get_palette"
|
"from mmseg.utils import get_palette"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -4,7 +4,7 @@ from argparse import ArgumentParser
|
|||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
from mmseg.apis import inference_model, init_model
|
from mmseg.apis import inference_model, init_model
|
||||||
from mmseg.core.evaluation import get_palette
|
from mmseg.utils import get_palette
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -5,7 +5,7 @@ import torch
|
|||||||
from mmcv.parallel import collate, scatter
|
from mmcv.parallel import collate, scatter
|
||||||
from mmcv.runner import load_checkpoint
|
from mmcv.runner import load_checkpoint
|
||||||
|
|
||||||
from mmseg.datasets.pipelines import Compose
|
from mmseg.datasets.transforms import Compose
|
||||||
from mmseg.models import build_segmentor
|
from mmseg.models import build_segmentor
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,9 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
from .builder import build_optimizer
|
|
||||||
from .data_structures import * # noqa: F401, F403
|
|
||||||
from .evaluation import * # noqa: F401, F403
|
|
||||||
from .optimizers import * # noqa: F401, F403
|
|
||||||
from .seg import * # noqa: F401, F403
|
|
||||||
from .utils import * # noqa: F401, F403
|
|
||||||
|
|
||||||
__all__ = ['build_optimizer']
|
|
@ -1,18 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import copy
|
|
||||||
|
|
||||||
from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
|
|
||||||
|
|
||||||
|
|
||||||
def build_optimizer(model, cfg):
|
|
||||||
optim_wrapper_cfg = copy.deepcopy(cfg)
|
|
||||||
constructor_type = optim_wrapper_cfg.pop('constructor',
|
|
||||||
'DefaultOptimWrapperConstructor')
|
|
||||||
paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
|
|
||||||
optim_wrapper_builder = OPTIM_WRAPPER_CONSTRUCTORS.build(
|
|
||||||
dict(
|
|
||||||
type=constructor_type,
|
|
||||||
optim_wrapper_cfg=optim_wrapper_cfg,
|
|
||||||
paramwise_cfg=paramwise_cfg))
|
|
||||||
optim_wrapper = optim_wrapper_builder(model)
|
|
||||||
return optim_wrapper
|
|
@ -1,4 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
from .seg_data_sample import SegDataSample
|
|
||||||
|
|
||||||
__all__ = ['SegDataSample']
|
|
@ -1,4 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
from .class_names import get_classes, get_palette
|
|
||||||
|
|
||||||
__all__ = ['get_classes', 'get_palette']
|
|
@ -1,5 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
from .base_pixel_sampler import BasePixelSampler
|
|
||||||
from .ohem_pixel_sampler import OHEMPixelSampler
|
|
||||||
|
|
||||||
__all__ = ['BasePixelSampler', 'OHEMPixelSampler']
|
|
@ -1,12 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
from .dist_util import check_dist_init, sync_random_seed
|
|
||||||
from .misc import add_prefix, stack_batch
|
|
||||||
from .typing import (ConfigType, ForwardResults, MultiConfig, OptConfigType,
|
|
||||||
OptMultiConfig, OptSampleList, SampleList, TensorDict,
|
|
||||||
TensorList)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'add_prefix', 'check_dist_init', 'sync_random_seed', 'stack_batch',
|
|
||||||
'ConfigType', 'OptConfigType', 'MultiConfig', 'OptMultiConfig',
|
|
||||||
'SampleList', 'OptSampleList', 'TensorDict', 'TensorList', 'ForwardResults'
|
|
||||||
]
|
|
@ -1,46 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from mmcv.runner import get_dist_info
|
|
||||||
|
|
||||||
|
|
||||||
def check_dist_init():
|
|
||||||
return dist.is_available() and dist.is_initialized()
|
|
||||||
|
|
||||||
|
|
||||||
def sync_random_seed(seed=None, device='cuda'):
|
|
||||||
"""Make sure different ranks share the same seed. All workers must call
|
|
||||||
this function, otherwise it will deadlock. This method is generally used in
|
|
||||||
`DistributedSampler`, because the seed should be identical across all
|
|
||||||
processes in the distributed group.
|
|
||||||
|
|
||||||
In distributed sampling, different ranks should sample non-overlapped
|
|
||||||
data in the dataset. Therefore, this function is used to make sure that
|
|
||||||
each rank shuffles the data indices in the same order based
|
|
||||||
on the same seed. Then different ranks could use different indices
|
|
||||||
to select non-overlapped data from the same data list.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed (int, Optional): The seed. Default to None.
|
|
||||||
device (str): The device where the seed will be put on.
|
|
||||||
Default to 'cuda'.
|
|
||||||
Returns:
|
|
||||||
int: Seed to be used.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if seed is None:
|
|
||||||
seed = np.random.randint(2**31)
|
|
||||||
assert isinstance(seed, int)
|
|
||||||
|
|
||||||
rank, world_size = get_dist_info()
|
|
||||||
|
|
||||||
if world_size == 1:
|
|
||||||
return seed
|
|
||||||
|
|
||||||
if rank == 0:
|
|
||||||
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
|
|
||||||
else:
|
|
||||||
random_num = torch.tensor(0, dtype=torch.int32, device=device)
|
|
||||||
dist.broadcast(random_num, src=0)
|
|
||||||
return random_num.item()
|
|
@ -1,108 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from mmseg.core.utils.typing import SampleList
|
|
||||||
|
|
||||||
|
|
||||||
def add_prefix(inputs, prefix):
|
|
||||||
"""Add prefix for dict.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs (dict): The input dict with str keys.
|
|
||||||
prefix (str): The prefix to add.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
dict: The dict with keys updated with ``prefix``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
outputs = dict()
|
|
||||||
for name, value in inputs.items():
|
|
||||||
outputs[f'{prefix}.{name}'] = value
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
def stack_batch(inputs: List[torch.Tensor],
|
|
||||||
batch_data_samples: Optional[SampleList] = None,
|
|
||||||
size: Optional[tuple] = None,
|
|
||||||
size_divisor: Optional[int] = None,
|
|
||||||
pad_val: Union[int, float] = 0,
|
|
||||||
seg_pad_val: Union[int, float] = 255) -> torch.Tensor:
|
|
||||||
"""Stack multiple inputs to form a batch and pad the images and gt_sem_segs
|
|
||||||
to the max shape use the right bottom padding mode.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs (List[Tensor]): The input multiple tensors. each is a
|
|
||||||
CHW 3D-tensor.
|
|
||||||
batch_data_samples (list[:obj:`SegDataSample`]): The Data
|
|
||||||
Samples. It usually includes information such as `gt_sem_seg`.
|
|
||||||
size (tuple, optional): Fixed padding size.
|
|
||||||
size_divisor (int, optional): The divisor of padded size.
|
|
||||||
pad_val (int, float): The padding value. Defaults to 0
|
|
||||||
seg_pad_val (int, float): The padding value. Defaults to 255
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: The 4D-tensor.
|
|
||||||
batch_data_samples (list[:obj:`SegDataSample`]): After the padding of
|
|
||||||
the gt_seg_map.
|
|
||||||
"""
|
|
||||||
assert isinstance(inputs, list), \
|
|
||||||
f'Expected input type to be list, but got {type(inputs)}'
|
|
||||||
assert len(set([tensor.ndim for tensor in inputs])) == 1, \
|
|
||||||
f'Expected the dimensions of all inputs must be the same, ' \
|
|
||||||
f'but got {[tensor.ndim for tensor in inputs]}'
|
|
||||||
assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
|
|
||||||
f'but got {inputs[0].ndim}'
|
|
||||||
assert len(set([tensor.shape[0] for tensor in inputs])) == 1, \
|
|
||||||
f'Expected the channels of all inputs must be the same, ' \
|
|
||||||
f'but got {[tensor.shape[0] for tensor in inputs]}'
|
|
||||||
|
|
||||||
# only one of size and size_divisor should be valid
|
|
||||||
assert (size is not None) ^ (size_divisor is not None), \
|
|
||||||
'only one of size and size_divisor should be valid'
|
|
||||||
|
|
||||||
padded_inputs = []
|
|
||||||
padded_samples = []
|
|
||||||
inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs]
|
|
||||||
max_size = np.stack(inputs_sizes).max(0)
|
|
||||||
if size_divisor is not None and size_divisor > 1:
|
|
||||||
# the last two dims are H,W, both subject to divisibility requirement
|
|
||||||
max_size = (max_size +
|
|
||||||
(size_divisor - 1)) // size_divisor * size_divisor
|
|
||||||
|
|
||||||
for i in range(len(inputs)):
|
|
||||||
tensor = inputs[i]
|
|
||||||
if size is not None:
|
|
||||||
width = max(size[-1] - tensor.shape[-1], 0)
|
|
||||||
height = max(size[-2] - tensor.shape[-2], 0)
|
|
||||||
# (padding_left, padding_right, padding_top, padding_bottom)
|
|
||||||
padding_size = (0, width, 0, height)
|
|
||||||
elif size_divisor is not None:
|
|
||||||
width = max(max_size[-1] - tensor.shape[-1], 0)
|
|
||||||
height = max(max_size[-2] - tensor.shape[-2], 0)
|
|
||||||
padding_size = (0, width, 0, height)
|
|
||||||
else:
|
|
||||||
padding_size = [0, 0, 0, 0]
|
|
||||||
|
|
||||||
# pad img
|
|
||||||
pad_img = F.pad(tensor, padding_size, value=pad_val)
|
|
||||||
padded_inputs.append(pad_img)
|
|
||||||
# pad gt_sem_seg
|
|
||||||
if batch_data_samples is not None:
|
|
||||||
data_sample = batch_data_samples[i]
|
|
||||||
gt_sem_seg = data_sample.gt_sem_seg.data
|
|
||||||
del data_sample.gt_sem_seg.data
|
|
||||||
data_sample.gt_sem_seg.data = F.pad(
|
|
||||||
gt_sem_seg, padding_size, value=seg_pad_val)
|
|
||||||
data_sample.set_metainfo(
|
|
||||||
{'pad_shape': data_sample.gt_sem_seg.shape})
|
|
||||||
padded_samples.append(data_sample)
|
|
||||||
else:
|
|
||||||
padded_samples = None
|
|
||||||
|
|
||||||
return torch.stack(padded_inputs, dim=0), padded_samples
|
|
8
mmseg/data/__init__.py
Normal file
8
mmseg/data/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from .sampler import BasePixelSampler, OHEMPixelSampler, build_pixel_sampler
|
||||||
|
from .seg_data_sample import SegDataSample
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'SegDataSample', 'BasePixelSampler', 'OHEMPixelSampler',
|
||||||
|
'build_pixel_sampler'
|
||||||
|
]
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from .base_pixel_sampler import BasePixelSampler
|
||||||
from .builder import build_pixel_sampler
|
from .builder import build_pixel_sampler
|
||||||
from .sampler import BasePixelSampler, OHEMPixelSampler
|
from .ohem_pixel_sampler import OHEMPixelSampler
|
||||||
|
|
||||||
__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
|
__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
|
@ -3,8 +3,8 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from ..builder import PIXEL_SAMPLERS
|
|
||||||
from .base_pixel_sampler import BasePixelSampler
|
from .base_pixel_sampler import BasePixelSampler
|
||||||
|
from .builder import PIXEL_SAMPLERS
|
||||||
|
|
||||||
|
|
||||||
@PIXEL_SAMPLERS.register_module()
|
@PIXEL_SAMPLERS.register_module()
|
@ -1,9 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
# flake8: noqa
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from .formatting import *
|
|
||||||
|
|
||||||
warnings.warn('DeprecationWarning: mmseg.datasets.pipelines.formating will be '
|
|
||||||
'deprecated in 2021, please replace it with '
|
|
||||||
'mmseg.datasets.pipelines.formatting.')
|
|
@ -1,4 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
from .distributed_sampler import DistributedSampler
|
|
||||||
|
|
||||||
__all__ = ['DistributedSampler']
|
|
@ -1,73 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
from __future__ import division
|
|
||||||
from typing import Iterator, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from torch.utils.data import DistributedSampler as _DistributedSampler
|
|
||||||
|
|
||||||
from mmseg.core.utils import sync_random_seed
|
|
||||||
from mmseg.registry import DATA_SAMPLERS
|
|
||||||
|
|
||||||
|
|
||||||
@DATA_SAMPLERS.register_module()
|
|
||||||
class DistributedSampler(_DistributedSampler):
|
|
||||||
"""DistributedSampler inheriting from
|
|
||||||
`torch.utils.data.DistributedSampler`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
datasets (Dataset): the dataset will be loaded.
|
|
||||||
num_replicas (int, optional): Number of processes participating in
|
|
||||||
distributed training. By default, world_size is retrieved from the
|
|
||||||
current distributed group.
|
|
||||||
rank (int, optional): Rank of the current process within num_replicas.
|
|
||||||
By default, rank is retrieved from the current distributed group.
|
|
||||||
shuffle (bool): If True (default), sampler will shuffle the indices.
|
|
||||||
seed (int): random seed used to shuffle the sampler if
|
|
||||||
:attr:`shuffle=True`. This number should be identical across all
|
|
||||||
processes in the distributed group. Default: ``0``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
dataset: Dataset,
|
|
||||||
num_replicas: Optional[int] = None,
|
|
||||||
rank: Optional[int] = None,
|
|
||||||
shuffle: bool = True,
|
|
||||||
seed=0) -> None:
|
|
||||||
super().__init__(
|
|
||||||
dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
|
|
||||||
|
|
||||||
# In distributed sampling, different ranks should sample
|
|
||||||
# non-overlapped data in the dataset. Therefore, this function
|
|
||||||
# is used to make sure that each rank shuffles the data indices
|
|
||||||
# in the same order based on the same seed. Then different ranks
|
|
||||||
# could use different indices to select non-overlapped data from the
|
|
||||||
# same data list.
|
|
||||||
self.seed = sync_random_seed(seed)
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator:
|
|
||||||
"""
|
|
||||||
Yields:
|
|
||||||
Iterator: iterator of indices for rank.
|
|
||||||
"""
|
|
||||||
# deterministically shuffle based on epoch
|
|
||||||
if self.shuffle:
|
|
||||||
g = torch.Generator()
|
|
||||||
# When :attr:`shuffle=True`, this ensures all replicas
|
|
||||||
# use a different random ordering for each epoch.
|
|
||||||
# Otherwise, the next iteration of this sampler will
|
|
||||||
# yield the same ordering.
|
|
||||||
g.manual_seed(self.epoch + self.seed)
|
|
||||||
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
||||||
else:
|
|
||||||
indices = torch.arange(len(self.dataset)).tolist()
|
|
||||||
|
|
||||||
# add extra samples to make it evenly divisible
|
|
||||||
indices += indices[:(self.total_size - len(indices))]
|
|
||||||
assert len(indices) == self.total_size
|
|
||||||
|
|
||||||
# subsample
|
|
||||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
|
||||||
assert len(indices) == self.num_samples
|
|
||||||
|
|
||||||
return iter(indices)
|
|
@ -5,7 +5,7 @@ from mmcv.transforms import to_tensor
|
|||||||
from mmcv.transforms.base import BaseTransform
|
from mmcv.transforms.base import BaseTransform
|
||||||
from mmengine.data import PixelData
|
from mmengine.data import PixelData
|
||||||
|
|
||||||
from mmseg.core import SegDataSample
|
from mmseg.data import SegDataSample
|
||||||
from mmseg.registry import TRANSFORMS
|
from mmseg.registry import TRANSFORMS
|
||||||
|
|
||||||
|
|
@ -3,10 +3,10 @@ import json
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from mmengine.dist import get_dist_info
|
from mmengine.dist import get_dist_info
|
||||||
|
from mmengine.logging import print_log
|
||||||
from mmengine.optim import DefaultOptimWrapperConstructor
|
from mmengine.optim import DefaultOptimWrapperConstructor
|
||||||
|
|
||||||
from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
|
from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
|
||||||
from mmseg.utils import get_root_logger
|
|
||||||
|
|
||||||
|
|
||||||
def get_layer_id_for_convnext(var_name, max_layer_id):
|
def get_layer_id_for_convnext(var_name, max_layer_id):
|
||||||
@ -119,14 +119,13 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
|
|||||||
in place.
|
in place.
|
||||||
module (nn.Module): The module to be added.
|
module (nn.Module): The module to be added.
|
||||||
"""
|
"""
|
||||||
logger = get_root_logger()
|
|
||||||
|
|
||||||
parameter_groups = {}
|
parameter_groups = {}
|
||||||
logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}')
|
print_log(f'self.paramwise_cfg is {self.paramwise_cfg}')
|
||||||
num_layers = self.paramwise_cfg.get('num_layers') + 2
|
num_layers = self.paramwise_cfg.get('num_layers') + 2
|
||||||
decay_rate = self.paramwise_cfg.get('decay_rate')
|
decay_rate = self.paramwise_cfg.get('decay_rate')
|
||||||
decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
|
decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
|
||||||
logger.info('Build LearningRateDecayOptimizerConstructor '
|
print_log('Build LearningRateDecayOptimizerConstructor '
|
||||||
f'{decay_type} {decay_rate} - {num_layers}')
|
f'{decay_type} {decay_rate} - {num_layers}')
|
||||||
weight_decay = self.base_wd
|
weight_decay = self.base_wd
|
||||||
for name, param in module.named_parameters():
|
for name, param in module.named_parameters():
|
||||||
@ -143,17 +142,17 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
|
|||||||
if 'ConvNeXt' in module.backbone.__class__.__name__:
|
if 'ConvNeXt' in module.backbone.__class__.__name__:
|
||||||
layer_id = get_layer_id_for_convnext(
|
layer_id = get_layer_id_for_convnext(
|
||||||
name, self.paramwise_cfg.get('num_layers'))
|
name, self.paramwise_cfg.get('num_layers'))
|
||||||
logger.info(f'set param {name} as id {layer_id}')
|
print_log(f'set param {name} as id {layer_id}')
|
||||||
elif 'BEiT' in module.backbone.__class__.__name__ or \
|
elif 'BEiT' in module.backbone.__class__.__name__ or \
|
||||||
'MAE' in module.backbone.__class__.__name__:
|
'MAE' in module.backbone.__class__.__name__:
|
||||||
layer_id = get_layer_id_for_vit(name, num_layers)
|
layer_id = get_layer_id_for_vit(name, num_layers)
|
||||||
logger.info(f'set param {name} as id {layer_id}')
|
print_log(f'set param {name} as id {layer_id}')
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
elif decay_type == 'stage_wise':
|
elif decay_type == 'stage_wise':
|
||||||
if 'ConvNeXt' in module.backbone.__class__.__name__:
|
if 'ConvNeXt' in module.backbone.__class__.__name__:
|
||||||
layer_id = get_stage_id_for_convnext(name, num_layers)
|
layer_id = get_stage_id_for_convnext(name, num_layers)
|
||||||
logger.info(f'set param {name} as id {layer_id}')
|
print_log(f'set param {name} as id {layer_id}')
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
group_name = f'layer_{layer_id}_{group_name}'
|
group_name = f'layer_{layer_id}_{group_name}'
|
||||||
@ -182,7 +181,7 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
|
|||||||
'lr': parameter_groups[key]['lr'],
|
'lr': parameter_groups[key]['lr'],
|
||||||
'weight_decay': parameter_groups[key]['weight_decay'],
|
'weight_decay': parameter_groups[key]['weight_decay'],
|
||||||
}
|
}
|
||||||
logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
|
print_log(f'Param groups = {json.dumps(to_display, indent=2)}')
|
||||||
params.extend(parameter_groups.values())
|
params.extend(parameter_groups.values())
|
||||||
|
|
||||||
|
|
@ -14,7 +14,6 @@ from torch.nn.modules.batchnorm import _BatchNorm
|
|||||||
from torch.nn.modules.utils import _pair as to_2tuple
|
from torch.nn.modules.utils import _pair as to_2tuple
|
||||||
|
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
from mmseg.utils import get_root_logger
|
|
||||||
from ..utils import PatchEmbed
|
from ..utils import PatchEmbed
|
||||||
from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer
|
from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer
|
||||||
|
|
||||||
@ -500,9 +499,8 @@ class BEiT(BaseModule):
|
|||||||
|
|
||||||
if (isinstance(self.init_cfg, dict)
|
if (isinstance(self.init_cfg, dict)
|
||||||
and self.init_cfg.get('type') == 'Pretrained'):
|
and self.init_cfg.get('type') == 'Pretrained'):
|
||||||
logger = get_root_logger()
|
|
||||||
checkpoint = _load_checkpoint(
|
checkpoint = _load_checkpoint(
|
||||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||||
state_dict = self.resize_rel_pos_embed(checkpoint)
|
state_dict = self.resize_rel_pos_embed(checkpoint)
|
||||||
self.load_state_dict(state_dict, False)
|
self.load_state_dict(state_dict, False)
|
||||||
elif self.init_cfg is not None:
|
elif self.init_cfg is not None:
|
||||||
|
@ -9,7 +9,6 @@ from mmcv.runner import ModuleList, _load_checkpoint
|
|||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
from mmseg.utils import get_root_logger
|
|
||||||
from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer
|
from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer
|
||||||
|
|
||||||
|
|
||||||
@ -180,9 +179,8 @@ class MAE(BEiT):
|
|||||||
|
|
||||||
if (isinstance(self.init_cfg, dict)
|
if (isinstance(self.init_cfg, dict)
|
||||||
and self.init_cfg.get('type') == 'Pretrained'):
|
and self.init_cfg.get('type') == 'Pretrained'):
|
||||||
logger = get_root_logger()
|
|
||||||
checkpoint = _load_checkpoint(
|
checkpoint = _load_checkpoint(
|
||||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||||
state_dict = self.resize_rel_pos_embed(checkpoint)
|
state_dict = self.resize_rel_pos_embed(checkpoint)
|
||||||
state_dict = self.resize_abs_pos_embed(state_dict)
|
state_dict = self.resize_abs_pos_embed(state_dict)
|
||||||
self.load_state_dict(state_dict, False)
|
self.load_state_dict(state_dict, False)
|
||||||
|
@ -14,9 +14,9 @@ from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
|
|||||||
from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
|
from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
|
||||||
load_state_dict)
|
load_state_dict)
|
||||||
from mmcv.utils import to_2tuple
|
from mmcv.utils import to_2tuple
|
||||||
|
from mmengine.logging import print_log
|
||||||
|
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
from ...utils import get_root_logger
|
|
||||||
from ..utils.embed import PatchEmbed, PatchMerging
|
from ..utils.embed import PatchEmbed, PatchMerging
|
||||||
|
|
||||||
|
|
||||||
@ -662,9 +662,8 @@ class SwinTransformer(BaseModule):
|
|||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
logger = get_root_logger()
|
|
||||||
if self.init_cfg is None:
|
if self.init_cfg is None:
|
||||||
logger.warn(f'No pre-trained weights for '
|
print_log(f'No pre-trained weights for '
|
||||||
f'{self.__class__.__name__}, '
|
f'{self.__class__.__name__}, '
|
||||||
f'training start from scratch')
|
f'training start from scratch')
|
||||||
if self.use_abs_pos_embed:
|
if self.use_abs_pos_embed:
|
||||||
@ -680,7 +679,7 @@ class SwinTransformer(BaseModule):
|
|||||||
f'`init_cfg` in ' \
|
f'`init_cfg` in ' \
|
||||||
f'{self.__class__.__name__} '
|
f'{self.__class__.__name__} '
|
||||||
ckpt = CheckpointLoader.load_checkpoint(
|
ckpt = CheckpointLoader.load_checkpoint(
|
||||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||||
if 'state_dict' in ckpt:
|
if 'state_dict' in ckpt:
|
||||||
_state_dict = ckpt['state_dict']
|
_state_dict = ckpt['state_dict']
|
||||||
elif 'model' in ckpt:
|
elif 'model' in ckpt:
|
||||||
@ -705,7 +704,7 @@ class SwinTransformer(BaseModule):
|
|||||||
N1, L, C1 = absolute_pos_embed.size()
|
N1, L, C1 = absolute_pos_embed.size()
|
||||||
N2, C2, H, W = self.absolute_pos_embed.size()
|
N2, C2, H, W = self.absolute_pos_embed.size()
|
||||||
if N1 != N2 or C1 != C2 or L != H * W:
|
if N1 != N2 or C1 != C2 or L != H * W:
|
||||||
logger.warning('Error in loading absolute_pos_embed, pass')
|
print_log('Error in loading absolute_pos_embed, pass')
|
||||||
else:
|
else:
|
||||||
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
|
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
|
||||||
N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
|
N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
|
||||||
@ -721,7 +720,7 @@ class SwinTransformer(BaseModule):
|
|||||||
L1, nH1 = table_pretrained.size()
|
L1, nH1 = table_pretrained.size()
|
||||||
L2, nH2 = table_current.size()
|
L2, nH2 = table_current.size()
|
||||||
if nH1 != nH2:
|
if nH1 != nH2:
|
||||||
logger.warning(f'Error in loading {table_key}, pass')
|
print_log(f'Error in loading {table_key}, pass')
|
||||||
elif L1 != L2:
|
elif L1 != L2:
|
||||||
S1 = int(L1**0.5)
|
S1 = int(L1**0.5)
|
||||||
S2 = int(L2**0.5)
|
S2 = int(L2**0.5)
|
||||||
@ -733,7 +732,7 @@ class SwinTransformer(BaseModule):
|
|||||||
nH2, L2).permute(1, 0).contiguous()
|
nH2, L2).permute(1, 0).contiguous()
|
||||||
|
|
||||||
# load state_dict
|
# load state_dict
|
||||||
load_state_dict(self, state_dict, strict=False, logger=logger)
|
load_state_dict(self, state_dict, strict=False, logger=None)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x, hw_shape = self.patch_embed(x)
|
x, hw_shape = self.patch_embed(x)
|
||||||
|
@ -11,12 +11,12 @@ from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
|
|||||||
trunc_normal_)
|
trunc_normal_)
|
||||||
from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
|
from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
|
||||||
load_state_dict)
|
load_state_dict)
|
||||||
|
from mmengine.logging import print_log
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
from torch.nn.modules.utils import _pair as to_2tuple
|
from torch.nn.modules.utils import _pair as to_2tuple
|
||||||
|
|
||||||
from mmseg.ops import resize
|
from mmseg.ops import resize
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
from mmseg.utils import get_root_logger
|
|
||||||
from ..utils import PatchEmbed
|
from ..utils import PatchEmbed
|
||||||
|
|
||||||
|
|
||||||
@ -293,9 +293,8 @@ class VisionTransformer(BaseModule):
|
|||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
if (isinstance(self.init_cfg, dict)
|
if (isinstance(self.init_cfg, dict)
|
||||||
and self.init_cfg.get('type') == 'Pretrained'):
|
and self.init_cfg.get('type') == 'Pretrained'):
|
||||||
logger = get_root_logger()
|
|
||||||
checkpoint = CheckpointLoader.load_checkpoint(
|
checkpoint = CheckpointLoader.load_checkpoint(
|
||||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||||
|
|
||||||
if 'state_dict' in checkpoint:
|
if 'state_dict' in checkpoint:
|
||||||
state_dict = checkpoint['state_dict']
|
state_dict = checkpoint['state_dict']
|
||||||
@ -304,7 +303,7 @@ class VisionTransformer(BaseModule):
|
|||||||
|
|
||||||
if 'pos_embed' in state_dict.keys():
|
if 'pos_embed' in state_dict.keys():
|
||||||
if self.pos_embed.shape != state_dict['pos_embed'].shape:
|
if self.pos_embed.shape != state_dict['pos_embed'].shape:
|
||||||
logger.info(msg=f'Resize the pos_embed shape from '
|
print_log(msg=f'Resize the pos_embed shape from '
|
||||||
f'{state_dict["pos_embed"].shape} to '
|
f'{state_dict["pos_embed"].shape} to '
|
||||||
f'{self.pos_embed.shape}')
|
f'{self.pos_embed.shape}')
|
||||||
h, w = self.img_size
|
h, w = self.img_size
|
||||||
@ -315,7 +314,7 @@ class VisionTransformer(BaseModule):
|
|||||||
(h // self.patch_size, w // self.patch_size),
|
(h // self.patch_size, w // self.patch_size),
|
||||||
(pos_size, pos_size), self.interpolate_mode)
|
(pos_size, pos_size), self.interpolate_mode)
|
||||||
|
|
||||||
load_state_dict(self, state_dict, strict=False, logger=logger)
|
load_state_dict(self, state_dict, strict=False, logger=None)
|
||||||
elif self.init_cfg is not None:
|
elif self.init_cfg is not None:
|
||||||
super(VisionTransformer, self).init_weights()
|
super(VisionTransformer, self).init_weights()
|
||||||
else:
|
else:
|
||||||
|
@ -6,9 +6,8 @@ import torch
|
|||||||
from mmengine.model import BaseDataPreprocessor
|
from mmengine.model import BaseDataPreprocessor
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from mmseg.core import stack_batch
|
|
||||||
from mmseg.core.utils import OptSampleList
|
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
|
from mmseg.utils import OptSampleList, stack_batch
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -4,7 +4,7 @@ from typing import List
|
|||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from mmseg.core.utils import ConfigType
|
from mmseg.utils import ConfigType
|
||||||
from .decode_head import BaseDecodeHead
|
from .decode_head import BaseDecodeHead
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,9 +6,8 @@ import torch.nn.functional as F
|
|||||||
from mmcv.cnn import ConvModule, Scale
|
from mmcv.cnn import ConvModule, Scale
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from mmseg.core import add_prefix
|
|
||||||
from mmseg.core.utils import SampleList
|
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
|
from mmseg.utils import SampleList, add_prefix
|
||||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||||
from .decode_head import BaseDecodeHead
|
from .decode_head import BaseDecodeHead
|
||||||
|
|
||||||
|
@ -7,10 +7,9 @@ import torch.nn as nn
|
|||||||
from mmcv.runner import BaseModule
|
from mmcv.runner import BaseModule
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from mmseg.core import build_pixel_sampler
|
from mmseg.data import build_pixel_sampler
|
||||||
from mmseg.core.utils import SampleList
|
|
||||||
from mmseg.core.utils.typing import ConfigType
|
|
||||||
from mmseg.ops import resize
|
from mmseg.ops import resize
|
||||||
|
from mmseg.utils import ConfigType, SampleList
|
||||||
from ..builder import build_loss
|
from ..builder import build_loss
|
||||||
from ..losses import accuracy
|
from ..losses import accuracy
|
||||||
|
|
||||||
|
@ -7,10 +7,9 @@ import torch.nn.functional as F
|
|||||||
from mmcv.cnn import ConvModule, build_norm_layer
|
from mmcv.cnn import ConvModule, build_norm_layer
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from mmseg.core.utils import SampleList
|
|
||||||
from mmseg.core.utils.typing import ConfigType
|
|
||||||
from mmseg.ops import Encoding, resize
|
from mmseg.ops import Encoding, resize
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
|
from mmseg.utils import ConfigType, SampleList
|
||||||
from ..builder import build_loss
|
from ..builder import build_loss
|
||||||
from .decode_head import BaseDecodeHead
|
from .decode_head import BaseDecodeHead
|
||||||
|
|
||||||
|
@ -8,12 +8,12 @@ from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
|
|||||||
from mmcv.cnn.bricks.transformer import (FFN, TRANSFORMER_LAYER,
|
from mmcv.cnn.bricks.transformer import (FFN, TRANSFORMER_LAYER,
|
||||||
MultiheadAttention,
|
MultiheadAttention,
|
||||||
build_transformer_layer)
|
build_transformer_layer)
|
||||||
|
from mmengine.logging import print_log
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from mmseg.core.utils import SampleList
|
|
||||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
from mmseg.utils import get_root_logger
|
from mmseg.utils import SampleList
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMER_LAYER.register_module()
|
@TRANSFORMER_LAYER.register_module()
|
||||||
@ -276,8 +276,7 @@ class KernelUpdateHead(nn.Module):
|
|||||||
# the weight and bias of the layer norm
|
# the weight and bias of the layer norm
|
||||||
pass
|
pass
|
||||||
if self.kernel_init:
|
if self.kernel_init:
|
||||||
logger = get_root_logger()
|
print_log(
|
||||||
logger.info(
|
|
||||||
'mask kernel in mask head is normal initialized by std 0.01')
|
'mask kernel in mask head is normal initialized by std 0.01')
|
||||||
nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)
|
nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)
|
||||||
|
|
||||||
|
@ -12,9 +12,9 @@ except ModuleNotFoundError:
|
|||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from mmseg.core.utils import SampleList
|
|
||||||
from mmseg.ops import resize
|
from mmseg.ops import resize
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
|
from mmseg.utils import SampleList
|
||||||
from ..losses import accuracy
|
from ..losses import accuracy
|
||||||
from .cascade_decode_head import BaseCascadeDecodeHead
|
from .cascade_decode_head import BaseCascadeDecodeHead
|
||||||
|
|
||||||
|
@ -4,9 +4,9 @@ import torch.nn.functional as F
|
|||||||
from mmengine.data import PixelData
|
from mmengine.data import PixelData
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from mmseg.core.data_structures.seg_data_sample import SegDataSample
|
from mmseg.data import SegDataSample
|
||||||
from mmseg.core.utils import SampleList
|
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
|
from mmseg.utils import SampleList
|
||||||
from .fcn_head import FCNHead
|
from .fcn_head import FCNHead
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,10 +6,10 @@ from mmengine.data import PixelData
|
|||||||
from mmengine.model import BaseModel
|
from mmengine.model import BaseModel
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from mmseg.core import SegDataSample
|
from mmseg.data import SegDataSample
|
||||||
from mmseg.core.utils import (ForwardResults, OptConfigType, OptMultiConfig,
|
|
||||||
OptSampleList, SampleList)
|
|
||||||
from mmseg.ops import resize
|
from mmseg.ops import resize
|
||||||
|
from mmseg.utils import (ForwardResults, OptConfigType, OptMultiConfig,
|
||||||
|
OptSampleList, SampleList)
|
||||||
|
|
||||||
|
|
||||||
class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
||||||
|
@ -3,10 +3,9 @@ from typing import List, Optional
|
|||||||
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from mmseg.core import add_prefix
|
|
||||||
from mmseg.core.utils import ConfigType, OptSampleList, SampleList
|
|
||||||
from mmseg.core.utils.typing import OptConfigType, OptMultiConfig
|
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
|
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
|
||||||
|
OptSampleList, SampleList, add_prefix)
|
||||||
from .encoder_decoder import EncoderDecoder
|
from .encoder_decoder import EncoderDecoder
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,10 +6,9 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from mmseg.core import add_prefix
|
|
||||||
from mmseg.core.utils import (ConfigType, OptConfigType, OptMultiConfig,
|
|
||||||
OptSampleList, SampleList)
|
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
|
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
|
||||||
|
OptSampleList, SampleList, add_prefix)
|
||||||
from .base import BaseSegmentor
|
from .base import BaseSegmentor
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,10 +1,29 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
# yapf: disable
|
||||||
|
from .class_names import (ade_classes, ade_palette, cityscapes_classes,
|
||||||
|
cityscapes_palette, cocostuff_classes,
|
||||||
|
cocostuff_palette, dataset_aliases, get_classes,
|
||||||
|
get_palette, isaid_classes, isaid_palette,
|
||||||
|
loveda_classes, loveda_palette, potsdam_classes,
|
||||||
|
potsdam_palette, stare_classes, stare_palette,
|
||||||
|
vaihingen_classes, vaihingen_palette, voc_classes,
|
||||||
|
voc_palette)
|
||||||
|
# yapf: enable
|
||||||
from .collect_env import collect_env
|
from .collect_env import collect_env
|
||||||
from .logger import get_root_logger
|
from .misc import add_prefix, stack_batch
|
||||||
from .misc import find_latest_checkpoint
|
from .set_env import register_all_modules
|
||||||
from .set_env import register_all_modules, setup_multi_processes
|
from .typing import (ConfigType, ForwardResults, MultiConfig, OptConfigType,
|
||||||
|
OptMultiConfig, OptSampleList, SampleList, TensorDict,
|
||||||
|
TensorList)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'get_root_logger', 'collect_env', 'find_latest_checkpoint',
|
'collect_env', 'register_all_modules', 'stack_batch', 'add_prefix',
|
||||||
'setup_multi_processes', 'register_all_modules'
|
'ConfigType', 'OptConfigType', 'MultiConfig', 'OptMultiConfig',
|
||||||
|
'SampleList', 'OptSampleList', 'TensorDict', 'TensorList',
|
||||||
|
'ForwardResults', 'cityscapes_classes', 'ade_classes', 'voc_classes',
|
||||||
|
'cocostuff_classes', 'loveda_classes', 'potsdam_classes',
|
||||||
|
'vaihingen_classes', 'isaid_classes', 'stare_classes',
|
||||||
|
'cityscapes_palette', 'ade_palette', 'voc_palette', 'cocostuff_palette',
|
||||||
|
'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette',
|
||||||
|
'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette'
|
||||||
]
|
]
|
||||||
|
@ -1,28 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from mmcv.utils import get_logger
|
|
||||||
|
|
||||||
|
|
||||||
def get_root_logger(log_file=None, log_level=logging.INFO):
|
|
||||||
"""Get the root logger.
|
|
||||||
|
|
||||||
The logger will be initialized if it has not been initialized. By default a
|
|
||||||
StreamHandler will be added. If `log_file` is specified, a FileHandler will
|
|
||||||
also be added. The name of the root logger is the top-level package name,
|
|
||||||
e.g., "mmseg".
|
|
||||||
|
|
||||||
Args:
|
|
||||||
log_file (str | None): The log filename. If specified, a FileHandler
|
|
||||||
will be added to the root logger.
|
|
||||||
log_level (int): The root logger level. Note that only the process of
|
|
||||||
rank 0 is affected, while other processes will set the level to
|
|
||||||
"Error" and be silent most of the time.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
logging.Logger: The root logger.
|
|
||||||
"""
|
|
||||||
|
|
||||||
logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level)
|
|
||||||
|
|
||||||
return logger
|
|
@ -1,41 +1,108 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import glob
|
from typing import List, Optional, Union
|
||||||
import os.path as osp
|
|
||||||
import warnings
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .typing import SampleList
|
||||||
|
|
||||||
|
|
||||||
def find_latest_checkpoint(path, suffix='pth'):
|
def add_prefix(inputs, prefix):
|
||||||
"""This function is for finding the latest checkpoint.
|
"""Add prefix for dict.
|
||||||
|
|
||||||
It will be used when automatically resume, modified from
|
|
||||||
https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (str): The path to find checkpoints.
|
inputs (dict): The input dict with str keys.
|
||||||
suffix (str): File extension for the checkpoint. Defaults to pth.
|
prefix (str): The prefix to add.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
latest_path(str | None): File path of the latest checkpoint.
|
|
||||||
"""
|
|
||||||
if not osp.exists(path):
|
|
||||||
warnings.warn("The path of the checkpoints doesn't exist.")
|
|
||||||
return None
|
|
||||||
if osp.exists(osp.join(path, f'latest.{suffix}')):
|
|
||||||
return osp.join(path, f'latest.{suffix}')
|
|
||||||
|
|
||||||
checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
|
dict: The dict with keys updated with ``prefix``.
|
||||||
if len(checkpoints) == 0:
|
"""
|
||||||
warnings.warn('The are no checkpoints in the path')
|
|
||||||
return None
|
outputs = dict()
|
||||||
latest = -1
|
for name, value in inputs.items():
|
||||||
latest_path = ''
|
outputs[f'{prefix}.{name}'] = value
|
||||||
for checkpoint in checkpoints:
|
|
||||||
if len(checkpoint) < len(latest_path):
|
return outputs
|
||||||
continue
|
|
||||||
# `count` is iteration number, as checkpoints are saved as
|
|
||||||
# 'iter_xx.pth' or 'epoch_xx.pth' and xx is iteration number.
|
def stack_batch(inputs: List[torch.Tensor],
|
||||||
count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
|
batch_data_samples: Optional[SampleList] = None,
|
||||||
if count > latest:
|
size: Optional[tuple] = None,
|
||||||
latest = count
|
size_divisor: Optional[int] = None,
|
||||||
latest_path = checkpoint
|
pad_val: Union[int, float] = 0,
|
||||||
return latest_path
|
seg_pad_val: Union[int, float] = 255) -> torch.Tensor:
|
||||||
|
"""Stack multiple inputs to form a batch and pad the images and gt_sem_segs
|
||||||
|
to the max shape use the right bottom padding mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs (List[Tensor]): The input multiple tensors. each is a
|
||||||
|
CHW 3D-tensor.
|
||||||
|
batch_data_samples (list[:obj:`SegDataSample`]): The Data
|
||||||
|
Samples. It usually includes information such as `gt_sem_seg`.
|
||||||
|
size (tuple, optional): Fixed padding size.
|
||||||
|
size_divisor (int, optional): The divisor of padded size.
|
||||||
|
pad_val (int, float): The padding value. Defaults to 0
|
||||||
|
seg_pad_val (int, float): The padding value. Defaults to 255
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: The 4D-tensor.
|
||||||
|
batch_data_samples (list[:obj:`SegDataSample`]): After the padding of
|
||||||
|
the gt_seg_map.
|
||||||
|
"""
|
||||||
|
assert isinstance(inputs, list), \
|
||||||
|
f'Expected input type to be list, but got {type(inputs)}'
|
||||||
|
assert len(set([tensor.ndim for tensor in inputs])) == 1, \
|
||||||
|
f'Expected the dimensions of all inputs must be the same, ' \
|
||||||
|
f'but got {[tensor.ndim for tensor in inputs]}'
|
||||||
|
assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
|
||||||
|
f'but got {inputs[0].ndim}'
|
||||||
|
assert len(set([tensor.shape[0] for tensor in inputs])) == 1, \
|
||||||
|
f'Expected the channels of all inputs must be the same, ' \
|
||||||
|
f'but got {[tensor.shape[0] for tensor in inputs]}'
|
||||||
|
|
||||||
|
# only one of size and size_divisor should be valid
|
||||||
|
assert (size is not None) ^ (size_divisor is not None), \
|
||||||
|
'only one of size and size_divisor should be valid'
|
||||||
|
|
||||||
|
padded_inputs = []
|
||||||
|
padded_samples = []
|
||||||
|
inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs]
|
||||||
|
max_size = np.stack(inputs_sizes).max(0)
|
||||||
|
if size_divisor is not None and size_divisor > 1:
|
||||||
|
# the last two dims are H,W, both subject to divisibility requirement
|
||||||
|
max_size = (max_size +
|
||||||
|
(size_divisor - 1)) // size_divisor * size_divisor
|
||||||
|
|
||||||
|
for i in range(len(inputs)):
|
||||||
|
tensor = inputs[i]
|
||||||
|
if size is not None:
|
||||||
|
width = max(size[-1] - tensor.shape[-1], 0)
|
||||||
|
height = max(size[-2] - tensor.shape[-2], 0)
|
||||||
|
# (padding_left, padding_right, padding_top, padding_bottom)
|
||||||
|
padding_size = (0, width, 0, height)
|
||||||
|
elif size_divisor is not None:
|
||||||
|
width = max(max_size[-1] - tensor.shape[-1], 0)
|
||||||
|
height = max(max_size[-2] - tensor.shape[-2], 0)
|
||||||
|
padding_size = (0, width, 0, height)
|
||||||
|
else:
|
||||||
|
padding_size = [0, 0, 0, 0]
|
||||||
|
|
||||||
|
# pad img
|
||||||
|
pad_img = F.pad(tensor, padding_size, value=pad_val)
|
||||||
|
padded_inputs.append(pad_img)
|
||||||
|
# pad gt_sem_seg
|
||||||
|
if batch_data_samples is not None:
|
||||||
|
data_sample = batch_data_samples[i]
|
||||||
|
gt_sem_seg = data_sample.gt_sem_seg.data
|
||||||
|
del data_sample.gt_sem_seg.data
|
||||||
|
data_sample.gt_sem_seg.data = F.pad(
|
||||||
|
gt_sem_seg, padding_size, value=seg_pad_val)
|
||||||
|
data_sample.set_metainfo(
|
||||||
|
{'pad_shape': data_sample.gt_sem_seg.shape})
|
||||||
|
padded_samples.append(data_sample)
|
||||||
|
else:
|
||||||
|
padded_samples = None
|
||||||
|
|
||||||
|
return torch.stack(padded_inputs, dim=0), padded_samples
|
||||||
|
@ -1,62 +1,9 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import datetime
|
import datetime
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import cv2
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
from mmengine import DefaultScope
|
from mmengine import DefaultScope
|
||||||
|
|
||||||
from ..utils import get_root_logger
|
|
||||||
|
|
||||||
|
|
||||||
def setup_multi_processes(cfg):
|
|
||||||
"""Setup multi-processing environment variables."""
|
|
||||||
logger = get_root_logger()
|
|
||||||
|
|
||||||
# set multi-process start method
|
|
||||||
if platform.system() != 'Windows':
|
|
||||||
mp_start_method = cfg.get('mp_start_method', None)
|
|
||||||
current_method = mp.get_start_method(allow_none=True)
|
|
||||||
if mp_start_method in ('fork', 'spawn', 'forkserver'):
|
|
||||||
logger.info(
|
|
||||||
f'Multi-processing start method `{mp_start_method}` is '
|
|
||||||
f'different from the previous setting `{current_method}`.'
|
|
||||||
f'It will be force set to `{mp_start_method}`.')
|
|
||||||
mp.set_start_method(mp_start_method, force=True)
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
f'Multi-processing start method is `{mp_start_method}`')
|
|
||||||
|
|
||||||
# disable opencv multithreading to avoid system being overloaded
|
|
||||||
opencv_num_threads = cfg.get('opencv_num_threads', None)
|
|
||||||
if isinstance(opencv_num_threads, int):
|
|
||||||
logger.info(f'OpenCV num_threads is `{opencv_num_threads}`')
|
|
||||||
cv2.setNumThreads(opencv_num_threads)
|
|
||||||
else:
|
|
||||||
logger.info(f'OpenCV num_threads is `{cv2.getNumThreads}')
|
|
||||||
|
|
||||||
if cfg.data.workers_per_gpu > 1:
|
|
||||||
# setup OMP threads
|
|
||||||
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
|
|
||||||
omp_num_threads = cfg.get('omp_num_threads', None)
|
|
||||||
if 'OMP_NUM_THREADS' not in os.environ:
|
|
||||||
if isinstance(omp_num_threads, int):
|
|
||||||
logger.info(f'OMP num threads is {omp_num_threads}')
|
|
||||||
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
|
|
||||||
else:
|
|
||||||
logger.info(f'OMP num threads is {os.environ["OMP_NUM_THREADS"] }')
|
|
||||||
|
|
||||||
# setup MKL threads
|
|
||||||
if 'MKL_NUM_THREADS' not in os.environ:
|
|
||||||
mkl_num_threads = cfg.get('mkl_num_threads', None)
|
|
||||||
if isinstance(mkl_num_threads, int):
|
|
||||||
logger.info(f'MKL num threads is {mkl_num_threads}')
|
|
||||||
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
|
|
||||||
else:
|
|
||||||
logger.info(f'MKL num threads is {os.environ["MKL_NUM_THREADS"]}')
|
|
||||||
|
|
||||||
|
|
||||||
def register_all_modules(init_default_scope: bool = True) -> None:
|
def register_all_modules(init_default_scope: bool = True) -> None:
|
||||||
"""Register all modules in mmseg into the registries.
|
"""Register all modules in mmseg into the registries.
|
||||||
@ -69,9 +16,9 @@ def register_all_modules(init_default_scope: bool = True) -> None:
|
|||||||
to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
|
to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
""" # noqa
|
""" # noqa
|
||||||
import mmseg.core # noqa: F401,F403
|
import mmseg.data # noqa: F401,F403
|
||||||
import mmseg.datasets # noqa: F401,F403
|
import mmseg.datasets # noqa: F401,F403
|
||||||
import mmseg.datasets.pipelines # noqa: F401,F403
|
import mmseg.engine # noqa: F401,F403
|
||||||
import mmseg.metrics # noqa: F401,F403
|
import mmseg.metrics # noqa: F401,F403
|
||||||
import mmseg.models # noqa: F401,F403
|
import mmseg.models # noqa: F401,F403
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
from mmengine.config import ConfigDict
|
from mmengine.config import ConfigDict
|
||||||
|
|
||||||
from ..data_structures import SegDataSample
|
from mmseg.data import SegDataSample
|
||||||
|
|
||||||
# Type hint of config data
|
# Type hint of config data
|
||||||
ConfigType = Union[ConfigDict, dict]
|
ConfigType = Union[ConfigDict, dict]
|
@ -1,7 +0,0 @@
|
|||||||
[pytest]
|
|
||||||
addopts = --xdoctest --xdoctest-style=auto
|
|
||||||
norecursedirs = .git ignore build __pycache__ data docker docs .eggs
|
|
||||||
|
|
||||||
filterwarnings= default
|
|
||||||
ignore:.*No cfgstr given in Cacher constructor or call.*:Warning
|
|
||||||
ignore:.*Define the __nice__ method for.*:Warning
|
|
@ -60,72 +60,72 @@ def test_config_build_segmentor():
|
|||||||
_check_decode_head(head_config, segmentor.decode_head)
|
_check_decode_head(head_config, segmentor.decode_head)
|
||||||
|
|
||||||
|
|
||||||
def test_config_data_pipeline():
|
# def test_config_data_pipeline():
|
||||||
"""Test whether the data pipeline is valid and can process corner cases.
|
# """Test whether the data pipeline is valid and can process corner cases.
|
||||||
|
|
||||||
CommandLine:
|
# CommandLine:
|
||||||
xdoctest -m tests/test_config.py test_config_build_data_pipeline
|
# xdoctest -m tests/test_config.py test_config_build_data_pipeline
|
||||||
"""
|
# """
|
||||||
import numpy as np
|
# import numpy as np
|
||||||
from mmcv import Config
|
# from mmcv import Config
|
||||||
|
|
||||||
from mmseg.datasets.pipelines import Compose
|
# from mmseg.datasets.transforms import Compose
|
||||||
|
|
||||||
config_dpath = _get_config_directory()
|
# config_dpath = _get_config_directory()
|
||||||
print('Found config_dpath = {!r}'.format(config_dpath))
|
# print('Found config_dpath = {!r}'.format(config_dpath))
|
||||||
|
|
||||||
import glob
|
# import glob
|
||||||
config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
|
# config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
|
||||||
config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
|
# config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
|
||||||
config_names = [relpath(p, config_dpath) for p in config_fpaths]
|
# config_names = [relpath(p, config_dpath) for p in config_fpaths]
|
||||||
|
|
||||||
print('Using {} config files'.format(len(config_names)))
|
# print('Using {} config files'.format(len(config_names)))
|
||||||
|
|
||||||
for config_fname in config_names:
|
# for config_fname in config_names:
|
||||||
config_fpath = join(config_dpath, config_fname)
|
# config_fpath = join(config_dpath, config_fname)
|
||||||
print(
|
# print(
|
||||||
'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
|
# 'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
|
||||||
config_mod = Config.fromfile(config_fpath)
|
# config_mod = Config.fromfile(config_fpath)
|
||||||
|
|
||||||
# remove loading pipeline
|
# # remove loading pipeline
|
||||||
load_img_pipeline = config_mod.train_pipeline.pop(0)
|
# load_img_pipeline = config_mod.train_pipeline.pop(0)
|
||||||
to_float32 = load_img_pipeline.get('to_float32', False)
|
# to_float32 = load_img_pipeline.get('to_float32', False)
|
||||||
config_mod.train_pipeline.pop(0)
|
# config_mod.train_pipeline.pop(0)
|
||||||
config_mod.test_pipeline.pop(0)
|
# config_mod.test_pipeline.pop(0)
|
||||||
# remove loading annotation in test pipeline
|
# # remove loading annotation in test pipeline
|
||||||
config_mod.test_pipeline.pop(1)
|
# config_mod.test_pipeline.pop(1)
|
||||||
|
|
||||||
train_pipeline = Compose(config_mod.train_pipeline)
|
# train_pipeline = Compose(config_mod.train_pipeline)
|
||||||
test_pipeline = Compose(config_mod.test_pipeline)
|
# test_pipeline = Compose(config_mod.test_pipeline)
|
||||||
|
|
||||||
img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
|
# img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
|
||||||
if to_float32:
|
# if to_float32:
|
||||||
img = img.astype(np.float32)
|
# img = img.astype(np.float32)
|
||||||
seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
|
# seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
|
||||||
|
|
||||||
results = dict(
|
# results = dict(
|
||||||
filename='test_img.png',
|
# filename='test_img.png',
|
||||||
ori_filename='test_img.png',
|
# ori_filename='test_img.png',
|
||||||
img=img,
|
# img=img,
|
||||||
img_shape=img.shape,
|
# img_shape=img.shape,
|
||||||
ori_shape=img.shape,
|
# ori_shape=img.shape,
|
||||||
gt_seg_map=seg)
|
# gt_seg_map=seg)
|
||||||
results['seg_fields'] = ['gt_seg_map']
|
# results['seg_fields'] = ['gt_seg_map']
|
||||||
|
|
||||||
print('Test training data pipeline: \n{!r}'.format(train_pipeline))
|
# print('Test training data pipeline: \n{!r}'.format(train_pipeline))
|
||||||
output_results = train_pipeline(results)
|
# output_results = train_pipeline(results)
|
||||||
assert output_results is not None
|
# assert output_results is not None
|
||||||
|
|
||||||
results = dict(
|
# results = dict(
|
||||||
filename='test_img.png',
|
# filename='test_img.png',
|
||||||
ori_filename='test_img.png',
|
# ori_filename='test_img.png',
|
||||||
img=img,
|
# img=img,
|
||||||
img_shape=img.shape,
|
# img_shape=img.shape,
|
||||||
ori_shape=img.shape,
|
# ori_shape=img.shape,
|
||||||
)
|
# )
|
||||||
print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
|
# print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
|
||||||
output_results = test_pipeline(results)
|
# output_results = test_pipeline(results)
|
||||||
assert output_results is not None
|
# assert output_results is not None
|
||||||
|
|
||||||
|
|
||||||
def _check_decode_head(decode_head_cfg, decode_head):
|
def _check_decode_head(decode_head_cfg, decode_head):
|
||||||
|
@ -6,7 +6,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
from mmengine.data import PixelData
|
from mmengine.data import PixelData
|
||||||
|
|
||||||
from mmseg.core import SegDataSample
|
from mmseg.data import SegDataSample
|
||||||
|
|
||||||
|
|
||||||
def _equal(a, b):
|
def _equal(a, b):
|
@ -6,11 +6,11 @@ from unittest.mock import MagicMock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mmseg.core.evaluation import get_classes, get_palette
|
|
||||||
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
|
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
|
||||||
COCOStuffDataset, CustomDataset, ISPRSDataset,
|
COCOStuffDataset, CustomDataset, ISPRSDataset,
|
||||||
LoveDADataset, PascalVOCDataset, PotsdamDataset,
|
LoveDADataset, PascalVOCDataset, PotsdamDataset,
|
||||||
iSAIDDataset)
|
iSAIDDataset)
|
||||||
|
from mmseg.utils import get_classes, get_palette
|
||||||
|
|
||||||
|
|
||||||
def test_classes():
|
def test_classes():
|
||||||
|
@ -6,8 +6,8 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from mmengine.data import BaseDataElement
|
from mmengine.data import BaseDataElement
|
||||||
|
|
||||||
from mmseg.core import SegDataSample
|
from mmseg.data import SegDataSample
|
||||||
from mmseg.datasets.pipelines import PackSegInputs
|
from mmseg.datasets.transforms import PackSegInputs
|
||||||
|
|
||||||
|
|
||||||
class TestPackSegInputs(unittest.TestCase):
|
class TestPackSegInputs(unittest.TestCase):
|
||||||
|
@ -7,7 +7,7 @@ import mmcv
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from mmcv.transforms import LoadImageFromFile
|
from mmcv.transforms import LoadImageFromFile
|
||||||
|
|
||||||
from mmseg.datasets.pipelines import LoadAnnotations
|
from mmseg.datasets.transforms import LoadAnnotations
|
||||||
|
|
||||||
|
|
||||||
class TestLoading(object):
|
class TestLoading(object):
|
||||||
|
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from mmseg.datasets.pipelines import PhotoMetricDistortion, RandomCrop
|
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
|
||||||
from mmseg.registry import TRANSFORMS
|
from mmseg.registry import TRANSFORMS
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ import os.path as osp
|
|||||||
import mmcv
|
import mmcv
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mmseg.datasets.pipelines import * # noqa
|
from mmseg.datasets.transforms import * # noqa
|
||||||
from mmseg.registry import TRANSFORMS
|
from mmseg.registry import TRANSFORMS
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,10 +4,13 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmcv.cnn import ConvModule
|
from mmcv.cnn import ConvModule
|
||||||
|
from mmengine.optim.optimizer import build_optim_wrapper
|
||||||
|
|
||||||
from mmseg.core.builder import build_optimizer
|
from mmseg.engine.optimizers.layer_decay_optimizer_constructor import \
|
||||||
from mmseg.core.optimizers.layer_decay_optimizer_constructor import \
|
|
||||||
LearningRateDecayOptimizerConstructor
|
LearningRateDecayOptimizerConstructor
|
||||||
|
from mmseg.utils import register_all_modules
|
||||||
|
|
||||||
|
register_all_modules()
|
||||||
|
|
||||||
base_lr = 1
|
base_lr = 1
|
||||||
decay_rate = 2
|
decay_rate = 2
|
||||||
@ -221,7 +224,7 @@ def test_learning_rate_decay_optimizer_constructor():
|
|||||||
optimizer=optimizer_cfg,
|
optimizer=optimizer_cfg,
|
||||||
paramwise_cfg=stagewise_paramwise_cfg,
|
paramwise_cfg=stagewise_paramwise_cfg,
|
||||||
constructor='LearningRateDecayOptimizerConstructor')
|
constructor='LearningRateDecayOptimizerConstructor')
|
||||||
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
|
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||||
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
||||||
expected_stage_wise_lr_wd_convnext)
|
expected_stage_wise_lr_wd_convnext)
|
||||||
# layerwise decay
|
# layerwise decay
|
||||||
@ -232,7 +235,7 @@ def test_learning_rate_decay_optimizer_constructor():
|
|||||||
optimizer=optimizer_cfg,
|
optimizer=optimizer_cfg,
|
||||||
paramwise_cfg=layerwise_paramwise_cfg,
|
paramwise_cfg=layerwise_paramwise_cfg,
|
||||||
constructor='LearningRateDecayOptimizerConstructor')
|
constructor='LearningRateDecayOptimizerConstructor')
|
||||||
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
|
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||||
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
||||||
expected_layer_wise_lr_wd_convnext)
|
expected_layer_wise_lr_wd_convnext)
|
||||||
|
|
||||||
@ -247,7 +250,7 @@ def test_learning_rate_decay_optimizer_constructor():
|
|||||||
optimizer=optimizer_cfg,
|
optimizer=optimizer_cfg,
|
||||||
paramwise_cfg=layerwise_paramwise_cfg,
|
paramwise_cfg=layerwise_paramwise_cfg,
|
||||||
constructor='LearningRateDecayOptimizerConstructor')
|
constructor='LearningRateDecayOptimizerConstructor')
|
||||||
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
|
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||||
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
||||||
expected_layer_wise_wd_lr_beit)
|
expected_layer_wise_wd_lr_beit)
|
||||||
|
|
||||||
@ -274,7 +277,7 @@ def test_learning_rate_decay_optimizer_constructor():
|
|||||||
optimizer=optimizer_cfg,
|
optimizer=optimizer_cfg,
|
||||||
paramwise_cfg=layerwise_paramwise_cfg,
|
paramwise_cfg=layerwise_paramwise_cfg,
|
||||||
constructor='LearningRateDecayOptimizerConstructor')
|
constructor='LearningRateDecayOptimizerConstructor')
|
||||||
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
|
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||||
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
||||||
expected_layer_wise_wd_lr_beit)
|
expected_layer_wise_wd_lr_beit)
|
||||||
|
|
||||||
@ -291,7 +294,7 @@ def test_beit_layer_decay_optimizer_constructor():
|
|||||||
paramwise_cfg=paramwise_cfg,
|
paramwise_cfg=paramwise_cfg,
|
||||||
optimizer=dict(
|
optimizer=dict(
|
||||||
type='AdamW', lr=1, betas=(0.9, 0.999), weight_decay=0.05))
|
type='AdamW', lr=1, betas=(0.9, 0.999), weight_decay=0.05))
|
||||||
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
|
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||||
# optimizer = optim_wrapper_builder(model)
|
# optimizer = optim_wrapper_builder(model)
|
||||||
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
||||||
expected_layer_wise_wd_lr_beit)
|
expected_layer_wise_wd_lr_beit)
|
@ -1,8 +1,7 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from mmengine.optim import build_optim_wrapper
|
||||||
from mmseg.core.builder import build_optimizer
|
|
||||||
|
|
||||||
|
|
||||||
class ExampleModel(nn.Module):
|
class ExampleModel(nn.Module):
|
||||||
@ -29,6 +28,6 @@ def test_build_optimizer():
|
|||||||
type='OptimWrapper',
|
type='OptimWrapper',
|
||||||
optimizer=dict(
|
optimizer=dict(
|
||||||
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum))
|
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum))
|
||||||
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
|
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||||
# test whether optimizer is successfully built from parent.
|
# test whether optimizer is successfully built from parent.
|
||||||
assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
|
assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from mmengine.data import BaseDataElement, PixelData
|
from mmengine.data import BaseDataElement, PixelData
|
||||||
|
|
||||||
from mmseg.core import SegDataSample
|
from mmseg.data import SegDataSample
|
||||||
from mmseg.metrics import CitysMetric
|
from mmseg.metrics import CitysMetric
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from mmengine.data import BaseDataElement, PixelData
|
from mmengine.data import BaseDataElement, PixelData
|
||||||
|
|
||||||
from mmseg.core import SegDataSample
|
from mmseg.data import SegDataSample
|
||||||
from mmseg.metrics import IoUMetric
|
from mmseg.metrics import IoUMetric
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ from unittest import TestCase
|
|||||||
import torch
|
import torch
|
||||||
from mmengine.data import PixelData
|
from mmengine.data import PixelData
|
||||||
|
|
||||||
from mmseg.core import SegDataSample
|
from mmseg.data import SegDataSample
|
||||||
from mmseg.models import SegDataPreProcessor
|
from mmseg.models import SegDataPreProcessor
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ from mmcv.cnn.utils import revert_sync_batchnorm
|
|||||||
from mmengine.data import PixelData
|
from mmengine.data import PixelData
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from mmseg.core import SegDataSample
|
from mmseg.data import SegDataSample
|
||||||
from mmseg.utils import register_all_modules
|
from mmseg.utils import register_all_modules
|
||||||
|
|
||||||
register_all_modules()
|
register_all_modules()
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mmseg.core import OHEMPixelSampler
|
from mmseg.data import OHEMPixelSampler
|
||||||
from mmseg.models.decode_heads import FCNHead
|
from mmseg.models.decode_heads import FCNHead
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,40 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import os.path as osp
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from mmseg.utils import find_latest_checkpoint
|
|
||||||
|
|
||||||
|
|
||||||
def test_find_latest_checkpoint():
|
|
||||||
with tempfile.TemporaryDirectory() as tempdir:
|
|
||||||
# no checkpoints in the path
|
|
||||||
path = tempdir
|
|
||||||
latest = find_latest_checkpoint(path)
|
|
||||||
assert latest is None
|
|
||||||
|
|
||||||
# The path doesn't exist
|
|
||||||
path = osp.join(tempdir, 'none')
|
|
||||||
latest = find_latest_checkpoint(path)
|
|
||||||
assert latest is None
|
|
||||||
|
|
||||||
# test when latest.pth exists
|
|
||||||
with tempfile.TemporaryDirectory() as tempdir:
|
|
||||||
with open(osp.join(tempdir, 'latest.pth'), 'w') as f:
|
|
||||||
f.write('latest')
|
|
||||||
path = tempdir
|
|
||||||
latest = find_latest_checkpoint(path)
|
|
||||||
assert latest == osp.join(tempdir, 'latest.pth')
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tempdir:
|
|
||||||
for iter in range(1600, 160001, 1600):
|
|
||||||
with open(osp.join(tempdir, f'iter_{iter}.pth'), 'w') as f:
|
|
||||||
f.write(f'iter_{iter}.pth')
|
|
||||||
latest = find_latest_checkpoint(tempdir)
|
|
||||||
assert latest == osp.join(tempdir, 'iter_160000.pth')
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tempdir:
|
|
||||||
for epoch in range(1, 21):
|
|
||||||
with open(osp.join(tempdir, f'epoch_{epoch}.pth'), 'w') as f:
|
|
||||||
f.write(f'epoch_{epoch}.pth')
|
|
||||||
latest = find_latest_checkpoint(tempdir)
|
|
||||||
assert latest == osp.join(tempdir, 'epoch_20.pth')
|
|
@ -1,92 +1,11 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import datetime
|
import datetime
|
||||||
import multiprocessing as mp
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import sys
|
import sys
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
import cv2
|
|
||||||
import pytest
|
|
||||||
from mmcv import Config
|
|
||||||
from mmengine import DefaultScope
|
from mmengine import DefaultScope
|
||||||
|
|
||||||
from mmseg.utils import register_all_modules, setup_multi_processes
|
from mmseg.utils import register_all_modules
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('workers_per_gpu', (0, 2))
|
|
||||||
@pytest.mark.parametrize(('valid', 'env_cfg'), [(True,
|
|
||||||
dict(
|
|
||||||
mp_start_method='fork',
|
|
||||||
opencv_num_threads=0,
|
|
||||||
omp_num_threads=1,
|
|
||||||
mkl_num_threads=1)),
|
|
||||||
(False,
|
|
||||||
dict(
|
|
||||||
mp_start_method=1,
|
|
||||||
opencv_num_threads=0.1,
|
|
||||||
omp_num_threads='s',
|
|
||||||
mkl_num_threads='1'))])
|
|
||||||
def test_setup_multi_processes(workers_per_gpu, valid, env_cfg):
|
|
||||||
# temp save system setting
|
|
||||||
sys_start_mehod = mp.get_start_method(allow_none=True)
|
|
||||||
sys_cv_threads = cv2.getNumThreads()
|
|
||||||
# pop and temp save system env vars
|
|
||||||
sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None)
|
|
||||||
sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None)
|
|
||||||
|
|
||||||
config = dict(data=dict(workers_per_gpu=workers_per_gpu))
|
|
||||||
config.update(env_cfg)
|
|
||||||
cfg = Config(config)
|
|
||||||
setup_multi_processes(cfg)
|
|
||||||
|
|
||||||
# test when cfg is valid and workers_per_gpu > 0
|
|
||||||
# setup_multi_processes will work
|
|
||||||
if valid and workers_per_gpu > 0:
|
|
||||||
# test config without setting env
|
|
||||||
|
|
||||||
assert os.getenv('OMP_NUM_THREADS') == str(env_cfg['omp_num_threads'])
|
|
||||||
assert os.getenv('MKL_NUM_THREADS') == str(env_cfg['mkl_num_threads'])
|
|
||||||
# when set to 0, the num threads will be 1
|
|
||||||
assert cv2.getNumThreads() == env_cfg[
|
|
||||||
'opencv_num_threads'] if env_cfg['opencv_num_threads'] > 0 else 1
|
|
||||||
if platform.system() != 'Windows':
|
|
||||||
assert mp.get_start_method() == env_cfg['mp_start_method']
|
|
||||||
|
|
||||||
# revert setting to avoid affecting other programs
|
|
||||||
if sys_start_mehod:
|
|
||||||
mp.set_start_method(sys_start_mehod, force=True)
|
|
||||||
cv2.setNumThreads(sys_cv_threads)
|
|
||||||
if sys_omp_threads:
|
|
||||||
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
|
|
||||||
else:
|
|
||||||
os.environ.pop('OMP_NUM_THREADS')
|
|
||||||
if sys_mkl_threads:
|
|
||||||
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
|
|
||||||
else:
|
|
||||||
os.environ.pop('MKL_NUM_THREADS')
|
|
||||||
|
|
||||||
elif valid and workers_per_gpu == 0:
|
|
||||||
|
|
||||||
if platform.system() != 'Windows':
|
|
||||||
assert mp.get_start_method() == env_cfg['mp_start_method']
|
|
||||||
assert cv2.getNumThreads() == env_cfg[
|
|
||||||
'opencv_num_threads'] if env_cfg['opencv_num_threads'] > 0 else 1
|
|
||||||
assert 'OMP_NUM_THREADS' not in os.environ
|
|
||||||
assert 'MKL_NUM_THREADS' not in os.environ
|
|
||||||
if sys_start_mehod:
|
|
||||||
mp.set_start_method(sys_start_mehod, force=True)
|
|
||||||
cv2.setNumThreads(sys_cv_threads)
|
|
||||||
if sys_omp_threads:
|
|
||||||
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
|
|
||||||
if sys_mkl_threads:
|
|
||||||
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
|
|
||||||
|
|
||||||
else:
|
|
||||||
assert mp.get_start_method() == sys_start_mehod
|
|
||||||
assert cv2.getNumThreads() == sys_cv_threads
|
|
||||||
assert 'OMP_NUM_THREADS' not in os.environ
|
|
||||||
assert 'MKL_NUM_THREADS' not in os.environ
|
|
||||||
|
|
||||||
|
|
||||||
class TestSetupEnv(TestCase):
|
class TestSetupEnv(TestCase):
|
||||||
|
@ -1,340 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import os.path as osp
|
|
||||||
import shutil
|
|
||||||
import warnings
|
|
||||||
from typing import Any, Iterable
|
|
||||||
|
|
||||||
import mmcv
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from mmcv.parallel import MMDataParallel
|
|
||||||
from mmcv.runner import get_dist_info
|
|
||||||
from mmcv.utils import DictAction
|
|
||||||
|
|
||||||
# from mmseg.apis import single_gpu_test
|
|
||||||
from mmseg.datasets import build_dataloader, build_dataset
|
|
||||||
from mmseg.models.segmentors.base import BaseSegmentor
|
|
||||||
from mmseg.ops import resize
|
|
||||||
|
|
||||||
single_gpu_test = None
|
|
||||||
|
|
||||||
|
|
||||||
class ONNXRuntimeSegmentor(BaseSegmentor):
|
|
||||||
|
|
||||||
def __init__(self, onnx_file: str, cfg: Any, device_id: int):
|
|
||||||
super(ONNXRuntimeSegmentor, self).__init__()
|
|
||||||
import onnxruntime as ort
|
|
||||||
|
|
||||||
# get the custom op path
|
|
||||||
ort_custom_op_path = ''
|
|
||||||
try:
|
|
||||||
from mmcv.ops import get_onnxruntime_op_path
|
|
||||||
ort_custom_op_path = get_onnxruntime_op_path()
|
|
||||||
except (ImportError, ModuleNotFoundError):
|
|
||||||
warnings.warn('If input model has custom op from mmcv, \
|
|
||||||
you may have to build mmcv with ONNXRuntime from source.')
|
|
||||||
session_options = ort.SessionOptions()
|
|
||||||
# register custom op for onnxruntime
|
|
||||||
if osp.exists(ort_custom_op_path):
|
|
||||||
session_options.register_custom_ops_library(ort_custom_op_path)
|
|
||||||
sess = ort.InferenceSession(onnx_file, session_options)
|
|
||||||
providers = ['CPUExecutionProvider']
|
|
||||||
options = [{}]
|
|
||||||
is_cuda_available = ort.get_device() == 'GPU'
|
|
||||||
if is_cuda_available:
|
|
||||||
providers.insert(0, 'CUDAExecutionProvider')
|
|
||||||
options.insert(0, {'device_id': device_id})
|
|
||||||
|
|
||||||
sess.set_providers(providers, options)
|
|
||||||
|
|
||||||
self.sess = sess
|
|
||||||
self.device_id = device_id
|
|
||||||
self.io_binding = sess.io_binding()
|
|
||||||
self.output_names = [_.name for _ in sess.get_outputs()]
|
|
||||||
for name in self.output_names:
|
|
||||||
self.io_binding.bind_output(name)
|
|
||||||
self.cfg = cfg
|
|
||||||
self.test_mode = cfg.model.test_cfg.mode
|
|
||||||
self.is_cuda_available = is_cuda_available
|
|
||||||
|
|
||||||
def extract_feat(self, imgs):
|
|
||||||
raise NotImplementedError('This method is not implemented.')
|
|
||||||
|
|
||||||
def encode_decode(self, img, img_metas):
|
|
||||||
raise NotImplementedError('This method is not implemented.')
|
|
||||||
|
|
||||||
def forward_train(self, imgs, img_metas, **kwargs):
|
|
||||||
raise NotImplementedError('This method is not implemented.')
|
|
||||||
|
|
||||||
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
|
|
||||||
**kwargs) -> list:
|
|
||||||
if not self.is_cuda_available:
|
|
||||||
img = img.detach().cpu()
|
|
||||||
elif self.device_id >= 0:
|
|
||||||
img = img.cuda(self.device_id)
|
|
||||||
device_type = img.device.type
|
|
||||||
self.io_binding.bind_input(
|
|
||||||
name='input',
|
|
||||||
device_type=device_type,
|
|
||||||
device_id=self.device_id,
|
|
||||||
element_type=np.float32,
|
|
||||||
shape=img.shape,
|
|
||||||
buffer_ptr=img.data_ptr())
|
|
||||||
self.sess.run_with_iobinding(self.io_binding)
|
|
||||||
seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
|
|
||||||
# whole might support dynamic reshape
|
|
||||||
ori_shape = img_meta[0]['ori_shape']
|
|
||||||
if not (ori_shape[0] == seg_pred.shape[-2]
|
|
||||||
and ori_shape[1] == seg_pred.shape[-1]):
|
|
||||||
seg_pred = torch.from_numpy(seg_pred).float()
|
|
||||||
seg_pred = resize(
|
|
||||||
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
|
||||||
seg_pred = seg_pred.long().detach().cpu().numpy()
|
|
||||||
seg_pred = seg_pred[0]
|
|
||||||
seg_pred = list(seg_pred)
|
|
||||||
return seg_pred
|
|
||||||
|
|
||||||
def aug_test(self, imgs, img_metas, **kwargs):
|
|
||||||
raise NotImplementedError('This method is not implemented.')
|
|
||||||
|
|
||||||
|
|
||||||
class TensorRTSegmentor(BaseSegmentor):
|
|
||||||
|
|
||||||
def __init__(self, trt_file: str, cfg: Any, device_id: int):
|
|
||||||
super(TensorRTSegmentor, self).__init__()
|
|
||||||
from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
|
|
||||||
try:
|
|
||||||
load_tensorrt_plugin()
|
|
||||||
except (ImportError, ModuleNotFoundError):
|
|
||||||
warnings.warn('If input model has custom op from mmcv, \
|
|
||||||
you may have to build mmcv with TensorRT from source.')
|
|
||||||
model = TRTWraper(
|
|
||||||
trt_file, input_names=['input'], output_names=['output'])
|
|
||||||
|
|
||||||
self.model = model
|
|
||||||
self.device_id = device_id
|
|
||||||
self.cfg = cfg
|
|
||||||
self.test_mode = cfg.model.test_cfg.mode
|
|
||||||
|
|
||||||
def extract_feat(self, imgs):
|
|
||||||
raise NotImplementedError('This method is not implemented.')
|
|
||||||
|
|
||||||
def encode_decode(self, img, img_metas):
|
|
||||||
raise NotImplementedError('This method is not implemented.')
|
|
||||||
|
|
||||||
def forward_train(self, imgs, img_metas, **kwargs):
|
|
||||||
raise NotImplementedError('This method is not implemented.')
|
|
||||||
|
|
||||||
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
|
|
||||||
**kwargs) -> list:
|
|
||||||
with torch.cuda.device(self.device_id), torch.no_grad():
|
|
||||||
seg_pred = self.model({'input': img})['output']
|
|
||||||
seg_pred = seg_pred.detach().cpu().numpy()
|
|
||||||
# whole might support dynamic reshape
|
|
||||||
ori_shape = img_meta[0]['ori_shape']
|
|
||||||
if not (ori_shape[0] == seg_pred.shape[-2]
|
|
||||||
and ori_shape[1] == seg_pred.shape[-1]):
|
|
||||||
seg_pred = torch.from_numpy(seg_pred).float()
|
|
||||||
seg_pred = resize(
|
|
||||||
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
|
||||||
seg_pred = seg_pred.long().detach().cpu().numpy()
|
|
||||||
seg_pred = seg_pred[0]
|
|
||||||
seg_pred = list(seg_pred)
|
|
||||||
return seg_pred
|
|
||||||
|
|
||||||
def aug_test(self, imgs, img_metas, **kwargs):
|
|
||||||
raise NotImplementedError('This method is not implemented.')
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description='mmseg backend test (and eval)')
|
|
||||||
parser.add_argument('config', help='test config file path')
|
|
||||||
parser.add_argument('model', help='Input model file')
|
|
||||||
parser.add_argument(
|
|
||||||
'--backend',
|
|
||||||
help='Backend of the model.',
|
|
||||||
choices=['onnxruntime', 'tensorrt'])
|
|
||||||
parser.add_argument('--out', help='output result file in pickle format')
|
|
||||||
parser.add_argument(
|
|
||||||
'--format-only',
|
|
||||||
action='store_true',
|
|
||||||
help='Format the output results without perform evaluation. It is'
|
|
||||||
'useful when you want to format the result to a specific format and '
|
|
||||||
'submit it to the test server')
|
|
||||||
parser.add_argument(
|
|
||||||
'--eval',
|
|
||||||
type=str,
|
|
||||||
nargs='+',
|
|
||||||
help='evaluation metrics, which depends on the dataset, e.g., "mIoU"'
|
|
||||||
' for generic datasets, and "cityscapes" for Cityscapes')
|
|
||||||
parser.add_argument('--show', action='store_true', help='show results')
|
|
||||||
parser.add_argument(
|
|
||||||
'--show-dir', help='directory where painted images will be saved')
|
|
||||||
parser.add_argument(
|
|
||||||
'--options',
|
|
||||||
nargs='+',
|
|
||||||
action=DictAction,
|
|
||||||
help="--options is deprecated in favor of --cfg_options' and it will "
|
|
||||||
'not be supported in version v0.22.0. Override some settings in the '
|
|
||||||
'used config, the key-value pair in xxx=yyy format will be merged '
|
|
||||||
'into config file. If the value to be overwritten is a list, it '
|
|
||||||
'should be like key="[a,b]" or key=a,b It also allows nested '
|
|
||||||
'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
|
|
||||||
'marks are necessary and that no white space is allowed.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--cfg-options',
|
|
||||||
nargs='+',
|
|
||||||
action=DictAction,
|
|
||||||
help='override some settings in the used config, the key-value pair '
|
|
||||||
'in xxx=yyy format will be merged into config file. If the value to '
|
|
||||||
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
|
|
||||||
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
|
||||||
'Note that the quotation marks are necessary and that no white space '
|
|
||||||
'is allowed.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--eval-options',
|
|
||||||
nargs='+',
|
|
||||||
action=DictAction,
|
|
||||||
help='custom options for evaluation')
|
|
||||||
parser.add_argument(
|
|
||||||
'--opacity',
|
|
||||||
type=float,
|
|
||||||
default=0.5,
|
|
||||||
help='Opacity of painted segmentation map. In (0, 1] range.')
|
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
|
||||||
args = parser.parse_args()
|
|
||||||
if 'LOCAL_RANK' not in os.environ:
|
|
||||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
|
||||||
|
|
||||||
if args.options and args.cfg_options:
|
|
||||||
raise ValueError(
|
|
||||||
'--options and --cfg-options cannot be both '
|
|
||||||
'specified, --options is deprecated in favor of --cfg-options. '
|
|
||||||
'--options will not be supported in version v0.22.0.')
|
|
||||||
if args.options:
|
|
||||||
warnings.warn('--options is deprecated in favor of --cfg-options. '
|
|
||||||
'--options will not be supported in version v0.22.0.')
|
|
||||||
args.cfg_options = args.options
|
|
||||||
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
args = parse_args()
|
|
||||||
|
|
||||||
assert args.out or args.eval or args.format_only or args.show \
|
|
||||||
or args.show_dir, \
|
|
||||||
('Please specify at least one operation (save/eval/format/show the '
|
|
||||||
'results / save the results) with the argument "--out", "--eval"'
|
|
||||||
', "--format-only", "--show" or "--show-dir"')
|
|
||||||
|
|
||||||
if args.eval and args.format_only:
|
|
||||||
raise ValueError('--eval and --format_only cannot be both specified')
|
|
||||||
|
|
||||||
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
|
|
||||||
raise ValueError('The output file must be a pkl file.')
|
|
||||||
|
|
||||||
cfg = mmcv.Config.fromfile(args.config)
|
|
||||||
if args.cfg_options is not None:
|
|
||||||
cfg.merge_from_dict(args.cfg_options)
|
|
||||||
cfg.model.pretrained = None
|
|
||||||
cfg.data.test.test_mode = True
|
|
||||||
|
|
||||||
# init distributed env first, since logger depends on the dist info.
|
|
||||||
distributed = False
|
|
||||||
|
|
||||||
# build the dataloader
|
|
||||||
# TODO: support multiple images per gpu (only minor changes are needed)
|
|
||||||
dataset = build_dataset(cfg.data.test)
|
|
||||||
data_loader = build_dataloader(
|
|
||||||
dataset,
|
|
||||||
samples_per_gpu=1,
|
|
||||||
workers_per_gpu=cfg.data.workers_per_gpu,
|
|
||||||
dist=distributed,
|
|
||||||
shuffle=False)
|
|
||||||
|
|
||||||
# load onnx config and meta
|
|
||||||
cfg.model.train_cfg = None
|
|
||||||
|
|
||||||
if args.backend == 'onnxruntime':
|
|
||||||
model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0)
|
|
||||||
elif args.backend == 'tensorrt':
|
|
||||||
model = TensorRTSegmentor(args.model, cfg=cfg, device_id=0)
|
|
||||||
|
|
||||||
model.CLASSES = dataset.CLASSES
|
|
||||||
model.PALETTE = dataset.PALETTE
|
|
||||||
|
|
||||||
# clean gpu memory when starting a new evaluation.
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
eval_kwargs = {} if args.eval_options is None else args.eval_options
|
|
||||||
|
|
||||||
# Deprecated
|
|
||||||
efficient_test = eval_kwargs.get('efficient_test', False)
|
|
||||||
if efficient_test:
|
|
||||||
warnings.warn(
|
|
||||||
'``efficient_test=True`` does not have effect in tools/test.py, '
|
|
||||||
'the evaluation and format results are CPU memory efficient by '
|
|
||||||
'default')
|
|
||||||
|
|
||||||
eval_on_format_results = (
|
|
||||||
args.eval is not None and 'cityscapes' in args.eval)
|
|
||||||
if eval_on_format_results:
|
|
||||||
assert len(args.eval) == 1, 'eval on format results is not ' \
|
|
||||||
'applicable for metrics other than ' \
|
|
||||||
'cityscapes'
|
|
||||||
if args.format_only or eval_on_format_results:
|
|
||||||
if 'imgfile_prefix' in eval_kwargs:
|
|
||||||
tmpdir = eval_kwargs['imgfile_prefix']
|
|
||||||
else:
|
|
||||||
tmpdir = '.format_cityscapes'
|
|
||||||
eval_kwargs.setdefault('imgfile_prefix', tmpdir)
|
|
||||||
mmcv.mkdir_or_exist(tmpdir)
|
|
||||||
else:
|
|
||||||
tmpdir = None
|
|
||||||
|
|
||||||
model = MMDataParallel(model, device_ids=[0])
|
|
||||||
results = single_gpu_test(
|
|
||||||
model,
|
|
||||||
data_loader,
|
|
||||||
args.show,
|
|
||||||
args.show_dir,
|
|
||||||
False,
|
|
||||||
args.opacity,
|
|
||||||
pre_eval=args.eval is not None and not eval_on_format_results,
|
|
||||||
format_only=args.format_only or eval_on_format_results,
|
|
||||||
format_args=eval_kwargs)
|
|
||||||
|
|
||||||
rank, _ = get_dist_info()
|
|
||||||
if rank == 0:
|
|
||||||
if args.out:
|
|
||||||
warnings.warn(
|
|
||||||
'The behavior of ``args.out`` has been changed since MMSeg '
|
|
||||||
'v0.16, the pickled outputs could be seg map as type of '
|
|
||||||
'np.array, pre-eval results or file paths for '
|
|
||||||
'``dataset.format_results()``.')
|
|
||||||
print(f'\nwriting results to {args.out}')
|
|
||||||
mmcv.dump(results, args.out)
|
|
||||||
if args.eval:
|
|
||||||
dataset.evaluate(results, args.eval, **eval_kwargs)
|
|
||||||
if tmpdir is not None and eval_on_format_results:
|
|
||||||
# remove tmp dir when cityscapes evaluation
|
|
||||||
shutil.rmtree(tmpdir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
|
||||||
# Following strings of text style are from colorama package
|
|
||||||
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
|
|
||||||
red_text, blue_text = '\x1b[31m', '\x1b[34m'
|
|
||||||
white_background = '\x1b[107m'
|
|
||||||
|
|
||||||
msg = white_background + bright_style + red_text
|
|
||||||
msg += 'DeprecationWarning: This tool will be deprecated in future. '
|
|
||||||
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
|
|
||||||
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
|
|
||||||
msg += reset_style
|
|
||||||
warnings.warn(msg)
|
|
@ -1,289 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import os.path as osp
|
|
||||||
import warnings
|
|
||||||
from typing import Iterable, Optional, Union
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import mmcv
|
|
||||||
import numpy as np
|
|
||||||
import onnxruntime as ort
|
|
||||||
import torch
|
|
||||||
from mmcv.ops import get_onnxruntime_op_path
|
|
||||||
from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
|
|
||||||
save_trt_engine)
|
|
||||||
|
|
||||||
from mmseg.apis.inference import LoadImage
|
|
||||||
from mmseg.datasets import DATASETS
|
|
||||||
from mmseg.datasets.pipelines import Compose
|
|
||||||
|
|
||||||
|
|
||||||
def get_GiB(x: int):
|
|
||||||
"""return x GiB."""
|
|
||||||
return x * (1 << 30)
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_input_img(img_path: str,
|
|
||||||
test_pipeline: Iterable[dict],
|
|
||||||
shape: Optional[Iterable] = None,
|
|
||||||
rescale_shape: Optional[Iterable] = None) -> dict:
|
|
||||||
# build the data pipeline
|
|
||||||
if shape is not None:
|
|
||||||
test_pipeline[1]['img_scale'] = (shape[1], shape[0])
|
|
||||||
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
|
|
||||||
test_pipeline = [LoadImage()] + test_pipeline[1:]
|
|
||||||
test_pipeline = Compose(test_pipeline)
|
|
||||||
# prepare data
|
|
||||||
data = dict(img=img_path)
|
|
||||||
data = test_pipeline(data)
|
|
||||||
imgs = data['img']
|
|
||||||
img_metas = [i.data for i in data['img_metas']]
|
|
||||||
|
|
||||||
if rescale_shape is not None:
|
|
||||||
for img_meta in img_metas:
|
|
||||||
img_meta['ori_shape'] = tuple(rescale_shape) + (3, )
|
|
||||||
|
|
||||||
mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
|
|
||||||
|
|
||||||
return mm_inputs
|
|
||||||
|
|
||||||
|
|
||||||
def _update_input_img(img_list: Iterable, img_meta_list: Iterable):
|
|
||||||
# update img and its meta list
|
|
||||||
N = img_list[0].size(0)
|
|
||||||
img_meta = img_meta_list[0][0]
|
|
||||||
img_shape = img_meta['img_shape']
|
|
||||||
ori_shape = img_meta['ori_shape']
|
|
||||||
pad_shape = img_meta['pad_shape']
|
|
||||||
new_img_meta_list = [[{
|
|
||||||
'img_shape':
|
|
||||||
img_shape,
|
|
||||||
'ori_shape':
|
|
||||||
ori_shape,
|
|
||||||
'pad_shape':
|
|
||||||
pad_shape,
|
|
||||||
'filename':
|
|
||||||
img_meta['filename'],
|
|
||||||
'scale_factor':
|
|
||||||
(img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
|
|
||||||
'flip':
|
|
||||||
False,
|
|
||||||
} for _ in range(N)]]
|
|
||||||
|
|
||||||
return img_list, new_img_meta_list
|
|
||||||
|
|
||||||
|
|
||||||
def show_result_pyplot(img: Union[str, np.ndarray],
|
|
||||||
result: np.ndarray,
|
|
||||||
palette: Optional[Iterable] = None,
|
|
||||||
fig_size: Iterable[int] = (15, 10),
|
|
||||||
opacity: float = 0.5,
|
|
||||||
title: str = '',
|
|
||||||
block: bool = True):
|
|
||||||
img = mmcv.imread(img)
|
|
||||||
img = img.copy()
|
|
||||||
seg = result[0]
|
|
||||||
seg = mmcv.imresize(seg, img.shape[:2][::-1])
|
|
||||||
palette = np.array(palette)
|
|
||||||
assert palette.shape[1] == 3
|
|
||||||
assert len(palette.shape) == 2
|
|
||||||
assert 0 < opacity <= 1.0
|
|
||||||
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
|
|
||||||
for label, color in enumerate(palette):
|
|
||||||
color_seg[seg == label, :] = color
|
|
||||||
# convert to BGR
|
|
||||||
color_seg = color_seg[..., ::-1]
|
|
||||||
|
|
||||||
img = img * (1 - opacity) + color_seg * opacity
|
|
||||||
img = img.astype(np.uint8)
|
|
||||||
|
|
||||||
plt.figure(figsize=fig_size)
|
|
||||||
plt.imshow(mmcv.bgr2rgb(img))
|
|
||||||
plt.title(title)
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.show(block=block)
|
|
||||||
|
|
||||||
|
|
||||||
def onnx2tensorrt(onnx_file: str,
|
|
||||||
trt_file: str,
|
|
||||||
config: dict,
|
|
||||||
input_config: dict,
|
|
||||||
fp16: bool = False,
|
|
||||||
verify: bool = False,
|
|
||||||
show: bool = False,
|
|
||||||
dataset: str = 'CityscapesDataset',
|
|
||||||
workspace_size: int = 1,
|
|
||||||
verbose: bool = False):
|
|
||||||
import tensorrt as trt
|
|
||||||
min_shape = input_config['min_shape']
|
|
||||||
max_shape = input_config['max_shape']
|
|
||||||
# create trt engine and wrapper
|
|
||||||
opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
|
|
||||||
max_workspace_size = get_GiB(workspace_size)
|
|
||||||
trt_engine = onnx2trt(
|
|
||||||
onnx_file,
|
|
||||||
opt_shape_dict,
|
|
||||||
log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
|
|
||||||
fp16_mode=fp16,
|
|
||||||
max_workspace_size=max_workspace_size)
|
|
||||||
save_dir, _ = osp.split(trt_file)
|
|
||||||
if save_dir:
|
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
|
||||||
save_trt_engine(trt_engine, trt_file)
|
|
||||||
print(f'Successfully created TensorRT engine: {trt_file}')
|
|
||||||
|
|
||||||
if verify:
|
|
||||||
inputs = _prepare_input_img(
|
|
||||||
input_config['input_path'],
|
|
||||||
config.data.test.pipeline,
|
|
||||||
shape=min_shape[2:])
|
|
||||||
|
|
||||||
imgs = inputs['imgs']
|
|
||||||
img_metas = inputs['img_metas']
|
|
||||||
img_list = [img[None, :] for img in imgs]
|
|
||||||
img_meta_list = [[img_meta] for img_meta in img_metas]
|
|
||||||
# update img_meta
|
|
||||||
img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
|
|
||||||
|
|
||||||
if max_shape[0] > 1:
|
|
||||||
# concate flip image for batch test
|
|
||||||
flip_img_list = [_.flip(-1) for _ in img_list]
|
|
||||||
img_list = [
|
|
||||||
torch.cat((ori_img, flip_img), 0)
|
|
||||||
for ori_img, flip_img in zip(img_list, flip_img_list)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Get results from ONNXRuntime
|
|
||||||
ort_custom_op_path = get_onnxruntime_op_path()
|
|
||||||
session_options = ort.SessionOptions()
|
|
||||||
if osp.exists(ort_custom_op_path):
|
|
||||||
session_options.register_custom_ops_library(ort_custom_op_path)
|
|
||||||
sess = ort.InferenceSession(onnx_file, session_options)
|
|
||||||
sess.set_providers(['CPUExecutionProvider'], [{}]) # use cpu mode
|
|
||||||
onnx_output = sess.run(['output'],
|
|
||||||
{'input': img_list[0].detach().numpy()})[0][0]
|
|
||||||
|
|
||||||
# Get results from TensorRT
|
|
||||||
trt_model = TRTWraper(trt_file, ['input'], ['output'])
|
|
||||||
with torch.no_grad():
|
|
||||||
trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()})
|
|
||||||
trt_output = trt_outputs['output'][0].cpu().detach().numpy()
|
|
||||||
|
|
||||||
if show:
|
|
||||||
dataset = DATASETS.get(dataset)
|
|
||||||
assert dataset is not None
|
|
||||||
palette = dataset.PALETTE
|
|
||||||
|
|
||||||
show_result_pyplot(
|
|
||||||
input_config['input_path'],
|
|
||||||
(onnx_output[0].astype(np.uint8), ),
|
|
||||||
palette=palette,
|
|
||||||
title='ONNXRuntime',
|
|
||||||
block=False)
|
|
||||||
show_result_pyplot(
|
|
||||||
input_config['input_path'], (trt_output[0].astype(np.uint8), ),
|
|
||||||
palette=palette,
|
|
||||||
title='TensorRT')
|
|
||||||
|
|
||||||
np.testing.assert_allclose(
|
|
||||||
onnx_output, trt_output, rtol=1e-03, atol=1e-05)
|
|
||||||
print('TensorRT and ONNXRuntime output all close.')
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description='Convert MMSegmentation models from ONNX to TensorRT')
|
|
||||||
parser.add_argument('config', help='Config file of the model')
|
|
||||||
parser.add_argument('model', help='Path to the input ONNX model')
|
|
||||||
parser.add_argument(
|
|
||||||
'--trt-file', type=str, help='Path to the output TensorRT engine')
|
|
||||||
parser.add_argument(
|
|
||||||
'--max-shape',
|
|
||||||
type=int,
|
|
||||||
nargs=4,
|
|
||||||
default=[1, 3, 400, 600],
|
|
||||||
help='Maximum shape of model input.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--min-shape',
|
|
||||||
type=int,
|
|
||||||
nargs=4,
|
|
||||||
default=[1, 3, 400, 600],
|
|
||||||
help='Minimum shape of model input.')
|
|
||||||
parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
|
|
||||||
parser.add_argument(
|
|
||||||
'--workspace-size',
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help='Max workspace size in GiB')
|
|
||||||
parser.add_argument(
|
|
||||||
'--input-img', type=str, default='', help='Image for test')
|
|
||||||
parser.add_argument(
|
|
||||||
'--show', action='store_true', help='Whether to show output results')
|
|
||||||
parser.add_argument(
|
|
||||||
'--dataset',
|
|
||||||
type=str,
|
|
||||||
default='CityscapesDataset',
|
|
||||||
help='Dataset name')
|
|
||||||
parser.add_argument(
|
|
||||||
'--verify',
|
|
||||||
action='store_true',
|
|
||||||
help='Verify the outputs of ONNXRuntime and TensorRT')
|
|
||||||
parser.add_argument(
|
|
||||||
'--verbose',
|
|
||||||
action='store_true',
|
|
||||||
help='Whether to verbose logging messages while creating \
|
|
||||||
TensorRT engine.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
|
||||||
assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
|
|
||||||
args = parse_args()
|
|
||||||
|
|
||||||
if not args.input_img:
|
|
||||||
args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png')
|
|
||||||
|
|
||||||
# check arguments
|
|
||||||
assert osp.exists(args.config), 'Config {} not found.'.format(args.config)
|
|
||||||
assert osp.exists(args.model), \
|
|
||||||
'ONNX model {} not found.'.format(args.model)
|
|
||||||
assert args.workspace_size >= 0, 'Workspace size less than 0.'
|
|
||||||
assert DATASETS.get(args.dataset) is not None, \
|
|
||||||
'Dataset {} does not found.'.format(args.dataset)
|
|
||||||
for max_value, min_value in zip(args.max_shape, args.min_shape):
|
|
||||||
assert max_value >= min_value, \
|
|
||||||
'max_shape should be larger than min shape'
|
|
||||||
|
|
||||||
input_config = {
|
|
||||||
'min_shape': args.min_shape,
|
|
||||||
'max_shape': args.max_shape,
|
|
||||||
'input_path': args.input_img
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg = mmcv.Config.fromfile(args.config)
|
|
||||||
onnx2tensorrt(
|
|
||||||
args.model,
|
|
||||||
args.trt_file,
|
|
||||||
cfg,
|
|
||||||
input_config,
|
|
||||||
fp16=args.fp16,
|
|
||||||
verify=args.verify,
|
|
||||||
show=args.show,
|
|
||||||
dataset=args.dataset,
|
|
||||||
workspace_size=args.workspace_size,
|
|
||||||
verbose=args.verbose)
|
|
||||||
|
|
||||||
# Following strings of text style are from colorama package
|
|
||||||
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
|
|
||||||
red_text, blue_text = '\x1b[31m', '\x1b[34m'
|
|
||||||
white_background = '\x1b[107m'
|
|
||||||
|
|
||||||
msg = white_background + bright_style + red_text
|
|
||||||
msg += 'DeprecationWarning: This tool will be deprecated in future. '
|
|
||||||
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
|
|
||||||
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
|
|
||||||
msg += reset_style
|
|
||||||
warnings.warn(msg)
|
|
@ -1,405 +0,0 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import argparse
|
|
||||||
import warnings
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import mmcv
|
|
||||||
import numpy as np
|
|
||||||
import onnxruntime as rt
|
|
||||||
import torch
|
|
||||||
import torch._C
|
|
||||||
import torch.serialization
|
|
||||||
from mmcv import DictAction
|
|
||||||
from mmcv.onnx import register_extra_symbolics
|
|
||||||
from mmcv.runner import load_checkpoint
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from mmseg.apis import show_result_pyplot
|
|
||||||
from mmseg.apis.inference import LoadImage
|
|
||||||
from mmseg.datasets.pipelines import Compose
|
|
||||||
from mmseg.models import build_segmentor
|
|
||||||
from mmseg.ops import resize
|
|
||||||
|
|
||||||
torch.manual_seed(3)
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_batchnorm(module):
|
|
||||||
module_output = module
|
|
||||||
if isinstance(module, torch.nn.SyncBatchNorm):
|
|
||||||
module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
|
|
||||||
module.momentum, module.affine,
|
|
||||||
module.track_running_stats)
|
|
||||||
if module.affine:
|
|
||||||
module_output.weight.data = module.weight.data.clone().detach()
|
|
||||||
module_output.bias.data = module.bias.data.clone().detach()
|
|
||||||
# keep requires_grad unchanged
|
|
||||||
module_output.weight.requires_grad = module.weight.requires_grad
|
|
||||||
module_output.bias.requires_grad = module.bias.requires_grad
|
|
||||||
module_output.running_mean = module.running_mean
|
|
||||||
module_output.running_var = module.running_var
|
|
||||||
module_output.num_batches_tracked = module.num_batches_tracked
|
|
||||||
for name, child in module.named_children():
|
|
||||||
module_output.add_module(name, _convert_batchnorm(child))
|
|
||||||
del module
|
|
||||||
return module_output
|
|
||||||
|
|
||||||
|
|
||||||
def _demo_mm_inputs(input_shape, num_classes):
|
|
||||||
"""Create a superset of inputs needed to run test or train batches.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_shape (tuple):
|
|
||||||
input batch dimensions
|
|
||||||
num_classes (int):
|
|
||||||
number of semantic classes
|
|
||||||
"""
|
|
||||||
(N, C, H, W) = input_shape
|
|
||||||
rng = np.random.RandomState(0)
|
|
||||||
imgs = rng.rand(*input_shape)
|
|
||||||
segs = rng.randint(
|
|
||||||
low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
|
|
||||||
img_metas = [{
|
|
||||||
'img_shape': (H, W, C),
|
|
||||||
'ori_shape': (H, W, C),
|
|
||||||
'pad_shape': (H, W, C),
|
|
||||||
'filename': '<demo>.png',
|
|
||||||
'scale_factor': 1.0,
|
|
||||||
'flip': False,
|
|
||||||
} for _ in range(N)]
|
|
||||||
mm_inputs = {
|
|
||||||
'imgs': torch.FloatTensor(imgs).requires_grad_(True),
|
|
||||||
'img_metas': img_metas,
|
|
||||||
'gt_semantic_seg': torch.LongTensor(segs)
|
|
||||||
}
|
|
||||||
return mm_inputs
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_input_img(img_path,
|
|
||||||
test_pipeline,
|
|
||||||
shape=None,
|
|
||||||
rescale_shape=None):
|
|
||||||
# build the data pipeline
|
|
||||||
if shape is not None:
|
|
||||||
test_pipeline[1]['img_scale'] = (shape[1], shape[0])
|
|
||||||
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
|
|
||||||
test_pipeline = [LoadImage()] + test_pipeline[1:]
|
|
||||||
test_pipeline = Compose(test_pipeline)
|
|
||||||
# prepare data
|
|
||||||
data = dict(img=img_path)
|
|
||||||
data = test_pipeline(data)
|
|
||||||
imgs = data['img']
|
|
||||||
img_metas = [i.data for i in data['img_metas']]
|
|
||||||
|
|
||||||
if rescale_shape is not None:
|
|
||||||
for img_meta in img_metas:
|
|
||||||
img_meta['ori_shape'] = tuple(rescale_shape) + (3, )
|
|
||||||
|
|
||||||
mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
|
|
||||||
|
|
||||||
return mm_inputs
|
|
||||||
|
|
||||||
|
|
||||||
def _update_input_img(img_list, img_meta_list, update_ori_shape=False):
|
|
||||||
# update img and its meta list
|
|
||||||
N, C, H, W = img_list[0].shape
|
|
||||||
img_meta = img_meta_list[0][0]
|
|
||||||
img_shape = (H, W, C)
|
|
||||||
if update_ori_shape:
|
|
||||||
ori_shape = img_shape
|
|
||||||
else:
|
|
||||||
ori_shape = img_meta['ori_shape']
|
|
||||||
pad_shape = img_shape
|
|
||||||
new_img_meta_list = [[{
|
|
||||||
'img_shape':
|
|
||||||
img_shape,
|
|
||||||
'ori_shape':
|
|
||||||
ori_shape,
|
|
||||||
'pad_shape':
|
|
||||||
pad_shape,
|
|
||||||
'filename':
|
|
||||||
img_meta['filename'],
|
|
||||||
'scale_factor':
|
|
||||||
(img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
|
|
||||||
'flip':
|
|
||||||
False,
|
|
||||||
} for _ in range(N)]]
|
|
||||||
|
|
||||||
return img_list, new_img_meta_list
|
|
||||||
|
|
||||||
|
|
||||||
def pytorch2onnx(model,
|
|
||||||
mm_inputs,
|
|
||||||
opset_version=11,
|
|
||||||
show=False,
|
|
||||||
output_file='tmp.onnx',
|
|
||||||
verify=False,
|
|
||||||
dynamic_export=False):
|
|
||||||
"""Export Pytorch model to ONNX model and verify the outputs are same
|
|
||||||
between Pytorch and ONNX.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (nn.Module): Pytorch model we want to export.
|
|
||||||
mm_inputs (dict): Contain the input tensors and img_metas information.
|
|
||||||
opset_version (int): The onnx op version. Default: 11.
|
|
||||||
show (bool): Whether print the computation graph. Default: False.
|
|
||||||
output_file (string): The path to where we store the output ONNX model.
|
|
||||||
Default: `tmp.onnx`.
|
|
||||||
verify (bool): Whether compare the outputs between Pytorch and ONNX.
|
|
||||||
Default: False.
|
|
||||||
dynamic_export (bool): Whether to export ONNX with dynamic axis.
|
|
||||||
Default: False.
|
|
||||||
"""
|
|
||||||
model.cpu().eval()
|
|
||||||
test_mode = model.test_cfg.mode
|
|
||||||
|
|
||||||
if isinstance(model.decode_head, nn.ModuleList):
|
|
||||||
num_classes = model.decode_head[-1].num_classes
|
|
||||||
else:
|
|
||||||
num_classes = model.decode_head.num_classes
|
|
||||||
|
|
||||||
imgs = mm_inputs.pop('imgs')
|
|
||||||
img_metas = mm_inputs.pop('img_metas')
|
|
||||||
|
|
||||||
img_list = [img[None, :] for img in imgs]
|
|
||||||
img_meta_list = [[img_meta] for img_meta in img_metas]
|
|
||||||
# update img_meta
|
|
||||||
img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
|
|
||||||
|
|
||||||
# replace original forward function
|
|
||||||
origin_forward = model.forward
|
|
||||||
model.forward = partial(
|
|
||||||
model.forward,
|
|
||||||
img_metas=img_meta_list,
|
|
||||||
return_loss=False,
|
|
||||||
rescale=True)
|
|
||||||
dynamic_axes = None
|
|
||||||
if dynamic_export:
|
|
||||||
if test_mode == 'slide':
|
|
||||||
dynamic_axes = {'input': {0: 'batch'}, 'output': {1: 'batch'}}
|
|
||||||
else:
|
|
||||||
dynamic_axes = {
|
|
||||||
'input': {
|
|
||||||
0: 'batch',
|
|
||||||
2: 'height',
|
|
||||||
3: 'width'
|
|
||||||
},
|
|
||||||
'output': {
|
|
||||||
1: 'batch',
|
|
||||||
2: 'height',
|
|
||||||
3: 'width'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
register_extra_symbolics(opset_version)
|
|
||||||
with torch.no_grad():
|
|
||||||
torch.onnx.export(
|
|
||||||
model, (img_list, ),
|
|
||||||
output_file,
|
|
||||||
input_names=['input'],
|
|
||||||
output_names=['output'],
|
|
||||||
export_params=True,
|
|
||||||
keep_initializers_as_inputs=False,
|
|
||||||
verbose=show,
|
|
||||||
opset_version=opset_version,
|
|
||||||
dynamic_axes=dynamic_axes)
|
|
||||||
print(f'Successfully exported ONNX model: {output_file}')
|
|
||||||
model.forward = origin_forward
|
|
||||||
|
|
||||||
if verify:
|
|
||||||
# check by onnx
|
|
||||||
import onnx
|
|
||||||
onnx_model = onnx.load(output_file)
|
|
||||||
onnx.checker.check_model(onnx_model)
|
|
||||||
|
|
||||||
if dynamic_export and test_mode == 'whole':
|
|
||||||
# scale image for dynamic shape test
|
|
||||||
img_list = [resize(_, scale_factor=1.5) for _ in img_list]
|
|
||||||
# concate flip image for batch test
|
|
||||||
flip_img_list = [_.flip(-1) for _ in img_list]
|
|
||||||
img_list = [
|
|
||||||
torch.cat((ori_img, flip_img), 0)
|
|
||||||
for ori_img, flip_img in zip(img_list, flip_img_list)
|
|
||||||
]
|
|
||||||
|
|
||||||
# update img_meta
|
|
||||||
img_list, img_meta_list = _update_input_img(
|
|
||||||
img_list, img_meta_list, test_mode == 'whole')
|
|
||||||
|
|
||||||
# check the numerical value
|
|
||||||
# get pytorch output
|
|
||||||
with torch.no_grad():
|
|
||||||
pytorch_result = model(img_list, img_meta_list, return_loss=False)
|
|
||||||
pytorch_result = np.stack(pytorch_result, 0)
|
|
||||||
|
|
||||||
# get onnx output
|
|
||||||
input_all = [node.name for node in onnx_model.graph.input]
|
|
||||||
input_initializer = [
|
|
||||||
node.name for node in onnx_model.graph.initializer
|
|
||||||
]
|
|
||||||
net_feed_input = list(set(input_all) - set(input_initializer))
|
|
||||||
assert (len(net_feed_input) == 1)
|
|
||||||
sess = rt.InferenceSession(output_file)
|
|
||||||
onnx_result = sess.run(
|
|
||||||
None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
|
|
||||||
# show segmentation results
|
|
||||||
if show:
|
|
||||||
import os.path as osp
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
img = img_meta_list[0][0]['filename']
|
|
||||||
if not osp.exists(img):
|
|
||||||
img = imgs[0][:3, ...].permute(1, 2, 0) * 255
|
|
||||||
img = img.detach().numpy().astype(np.uint8)
|
|
||||||
ori_shape = img.shape[:2]
|
|
||||||
else:
|
|
||||||
ori_shape = LoadImage()({'img': img})['ori_shape']
|
|
||||||
|
|
||||||
# resize onnx_result to ori_shape
|
|
||||||
onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
|
|
||||||
(ori_shape[1], ori_shape[0]))
|
|
||||||
show_result_pyplot(
|
|
||||||
model,
|
|
||||||
img, (onnx_result_, ),
|
|
||||||
palette=model.PALETTE,
|
|
||||||
block=False,
|
|
||||||
title='ONNXRuntime',
|
|
||||||
opacity=0.5)
|
|
||||||
|
|
||||||
# resize pytorch_result to ori_shape
|
|
||||||
pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
|
|
||||||
(ori_shape[1], ori_shape[0]))
|
|
||||||
show_result_pyplot(
|
|
||||||
model,
|
|
||||||
img, (pytorch_result_, ),
|
|
||||||
title='PyTorch',
|
|
||||||
palette=model.PALETTE,
|
|
||||||
opacity=0.5)
|
|
||||||
# compare results
|
|
||||||
np.testing.assert_allclose(
|
|
||||||
pytorch_result.astype(np.float32) / num_classes,
|
|
||||||
onnx_result.astype(np.float32) / num_classes,
|
|
||||||
rtol=1e-5,
|
|
||||||
atol=1e-5,
|
|
||||||
err_msg='The outputs are different between Pytorch and ONNX')
|
|
||||||
print('The outputs are same between Pytorch and ONNX')
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
|
|
||||||
parser.add_argument('config', help='test config file path')
|
|
||||||
parser.add_argument('--checkpoint', help='checkpoint file', default=None)
|
|
||||||
parser.add_argument(
|
|
||||||
'--input-img', type=str, help='Images for input', default=None)
|
|
||||||
parser.add_argument(
|
|
||||||
'--show',
|
|
||||||
action='store_true',
|
|
||||||
help='show onnx graph and segmentation results')
|
|
||||||
parser.add_argument(
|
|
||||||
'--verify', action='store_true', help='verify the onnx model')
|
|
||||||
parser.add_argument('--output-file', type=str, default='tmp.onnx')
|
|
||||||
parser.add_argument('--opset-version', type=int, default=11)
|
|
||||||
parser.add_argument(
|
|
||||||
'--shape',
|
|
||||||
type=int,
|
|
||||||
nargs='+',
|
|
||||||
default=None,
|
|
||||||
help='input image height and width.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--rescale_shape',
|
|
||||||
type=int,
|
|
||||||
nargs='+',
|
|
||||||
default=None,
|
|
||||||
help='output image rescale height and width, work for slide mode.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--cfg-options',
|
|
||||||
nargs='+',
|
|
||||||
action=DictAction,
|
|
||||||
help='Override some settings in the used config, the key-value pair '
|
|
||||||
'in xxx=yyy format will be merged into config file. If the value to '
|
|
||||||
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
|
|
||||||
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
|
||||||
'Note that the quotation marks are necessary and that no white space '
|
|
||||||
'is allowed.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--dynamic-export',
|
|
||||||
action='store_true',
|
|
||||||
help='Whether to export onnx with dynamic axis.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
args = parse_args()
|
|
||||||
|
|
||||||
cfg = mmcv.Config.fromfile(args.config)
|
|
||||||
if args.cfg_options is not None:
|
|
||||||
cfg.merge_from_dict(args.cfg_options)
|
|
||||||
cfg.model.pretrained = None
|
|
||||||
|
|
||||||
if args.shape is None:
|
|
||||||
img_scale = cfg.test_pipeline[1]['img_scale']
|
|
||||||
input_shape = (1, 3, img_scale[1], img_scale[0])
|
|
||||||
elif len(args.shape) == 1:
|
|
||||||
input_shape = (1, 3, args.shape[0], args.shape[0])
|
|
||||||
elif len(args.shape) == 2:
|
|
||||||
input_shape = (
|
|
||||||
1,
|
|
||||||
3,
|
|
||||||
) + tuple(args.shape)
|
|
||||||
else:
|
|
||||||
raise ValueError('invalid input shape')
|
|
||||||
|
|
||||||
test_mode = cfg.model.test_cfg.mode
|
|
||||||
|
|
||||||
# build the model and load checkpoint
|
|
||||||
cfg.model.train_cfg = None
|
|
||||||
segmentor = build_segmentor(
|
|
||||||
cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
|
|
||||||
# convert SyncBN to BN
|
|
||||||
segmentor = _convert_batchnorm(segmentor)
|
|
||||||
|
|
||||||
if args.checkpoint:
|
|
||||||
checkpoint = load_checkpoint(
|
|
||||||
segmentor, args.checkpoint, map_location='cpu')
|
|
||||||
segmentor.CLASSES = checkpoint['meta']['CLASSES']
|
|
||||||
segmentor.PALETTE = checkpoint['meta']['PALETTE']
|
|
||||||
|
|
||||||
# read input or create dummpy input
|
|
||||||
if args.input_img is not None:
|
|
||||||
preprocess_shape = (input_shape[2], input_shape[3])
|
|
||||||
rescale_shape = None
|
|
||||||
if args.rescale_shape is not None:
|
|
||||||
rescale_shape = [args.rescale_shape[0], args.rescale_shape[1]]
|
|
||||||
mm_inputs = _prepare_input_img(
|
|
||||||
args.input_img,
|
|
||||||
cfg.data.test.pipeline,
|
|
||||||
shape=preprocess_shape,
|
|
||||||
rescale_shape=rescale_shape)
|
|
||||||
else:
|
|
||||||
if isinstance(segmentor.decode_head, nn.ModuleList):
|
|
||||||
num_classes = segmentor.decode_head[-1].num_classes
|
|
||||||
else:
|
|
||||||
num_classes = segmentor.decode_head.num_classes
|
|
||||||
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
|
|
||||||
|
|
||||||
# convert model to onnx file
|
|
||||||
pytorch2onnx(
|
|
||||||
segmentor,
|
|
||||||
mm_inputs,
|
|
||||||
opset_version=args.opset_version,
|
|
||||||
show=args.show,
|
|
||||||
output_file=args.output_file,
|
|
||||||
verify=args.verify,
|
|
||||||
dynamic_export=args.dynamic_export)
|
|
||||||
|
|
||||||
# Following strings of text style are from colorama package
|
|
||||||
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
|
|
||||||
red_text, blue_text = '\x1b[31m', '\x1b[34m'
|
|
||||||
white_background = '\x1b[107m'
|
|
||||||
|
|
||||||
msg = white_background + bright_style + red_text
|
|
||||||
msg += 'DeprecationWarning: This tool will be deprecated in future. '
|
|
||||||
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
|
|
||||||
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
|
|
||||||
msg += reset_style
|
|
||||||
warnings.warn(msg)
|
|
Loading…
x
Reference in New Issue
Block a user