mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] remote sensing inference (#3131)
## Motivation Supports inference for ultra-large-scale remote sensing images. ## Modification Add RSImageInference.py in demo. ## Use cases Taking the inference of Vaihingen dataset images using PSPNet as an example, the following settings are required: **img**: Specify the path of the image. **model**: Provide the configuration file for the model. **checkpoint**: Specify the weight file for the model. **out**: Set the output path for the results. **batch_size**: Determine the batch size used during inference. **win_size**: Specify the width and height(512x512) of the sliding window. **stride**: Set the stride(400x400) for sliding the window. **thread(default: 1)**: Specify the number of threads to be used for inference. **Inference device (default: cuda:0)**: Specify the device for inference (e.g., cuda:0 for CPU). ```shell python demo/rs_image_inference.py demo/demo.png projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny.py pp_mobileseg_mobilenetv3_2xb16_3rdparty-tiny_512x512-ade20k-a351ebf5.pth --batch-size 8 --device cpu --thread 2 ``` --------- Co-authored-by: xiexinch <xiexinch@outlook.com>
This commit is contained in:
parent
35ff78a07f
commit
72e20a8854
@ -73,7 +73,7 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: Skip timm unittests and generate coverage report
|
name: Skip timm unittests and generate coverage report
|
||||||
command: |
|
command: |
|
||||||
python -m coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
|
python -m coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py --ignore tests/test_apis/test_rs_inferencer.py
|
||||||
python -m coverage xml
|
python -m coverage xml
|
||||||
python -m coverage report -m
|
python -m coverage report -m
|
||||||
build_cuda:
|
build_cuda:
|
||||||
@ -119,7 +119,7 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: Run unittests but skip timm unittests
|
name: Run unittests but skip timm unittests
|
||||||
command: |
|
command: |
|
||||||
docker exec mmseg pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
|
docker exec mmseg pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py --ignore tests/test_models/test_backbones/test_timm_backbone.py --ignore tests/test_apis/test_rs_inferencer.py
|
||||||
workflows:
|
workflows:
|
||||||
pr_stage_lint:
|
pr_stage_lint:
|
||||||
when: << pipeline.parameters.lint_only >>
|
when: << pipeline.parameters.lint_only >>
|
||||||
|
50
demo/rs_image_inference.py
Normal file
50
demo/rs_image_inference.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
from mmseg.apis import RSImage, RSInferencer
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument('image', help='Image file path')
|
||||||
|
parser.add_argument('config', help='Config file')
|
||||||
|
parser.add_argument('checkpoint', help='Checkpoint file')
|
||||||
|
parser.add_argument(
|
||||||
|
'--output-path',
|
||||||
|
help='Path to save result image',
|
||||||
|
default='result.png')
|
||||||
|
parser.add_argument(
|
||||||
|
'--batch-size',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help='maximum number of windows inferred simultaneously')
|
||||||
|
parser.add_argument(
|
||||||
|
'--window-size',
|
||||||
|
help='window xsize,ysize',
|
||||||
|
default=(224, 224),
|
||||||
|
type=int,
|
||||||
|
nargs=2)
|
||||||
|
parser.add_argument(
|
||||||
|
'--stride',
|
||||||
|
help='window xstride,ystride',
|
||||||
|
default=(224, 224),
|
||||||
|
type=int,
|
||||||
|
nargs=2)
|
||||||
|
parser.add_argument(
|
||||||
|
'--thread', default=1, type=int, help='number of inference threads')
|
||||||
|
parser.add_argument(
|
||||||
|
'--device', default='cuda:0', help='Device used for inference')
|
||||||
|
args = parser.parse_args()
|
||||||
|
inferencer = RSInferencer.from_config_path(
|
||||||
|
args.config,
|
||||||
|
args.checkpoint,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
thread=args.thread,
|
||||||
|
device=args.device)
|
||||||
|
image = RSImage(args.image)
|
||||||
|
|
||||||
|
inferencer.run(image, args.window_size, args.stride, args.output_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
@ -1,7 +1,9 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .inference import inference_model, init_model, show_result_pyplot
|
from .inference import inference_model, init_model, show_result_pyplot
|
||||||
from .mmseg_inferencer import MMSegInferencer
|
from .mmseg_inferencer import MMSegInferencer
|
||||||
|
from .remote_sense_inferencer import RSImage, RSInferencer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer'
|
'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer',
|
||||||
|
'RSInferencer', 'RSImage'
|
||||||
]
|
]
|
||||||
|
@ -1,14 +1,12 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import warnings
|
import warnings
|
||||||
from collections import defaultdict
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Sequence, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from mmengine import Config
|
from mmengine import Config
|
||||||
from mmengine.dataset import Compose
|
|
||||||
from mmengine.registry import init_default_scope
|
from mmengine.registry import init_default_scope
|
||||||
from mmengine.runner import load_checkpoint
|
from mmengine.runner import load_checkpoint
|
||||||
from mmengine.utils import mkdir_or_exist
|
from mmengine.utils import mkdir_or_exist
|
||||||
@ -18,6 +16,7 @@ from mmseg.registry import MODELS
|
|||||||
from mmseg.structures import SegDataSample
|
from mmseg.structures import SegDataSample
|
||||||
from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
|
from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
|
||||||
from mmseg.visualization import SegLocalVisualizer
|
from mmseg.visualization import SegLocalVisualizer
|
||||||
|
from .utils import ImageType, _preprare_data
|
||||||
|
|
||||||
|
|
||||||
def init_model(config: Union[str, Path, Config],
|
def init_model(config: Union[str, Path, Config],
|
||||||
@ -90,41 +89,6 @@ def init_model(config: Union[str, Path, Config],
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
|
|
||||||
|
|
||||||
|
|
||||||
def _preprare_data(imgs: ImageType, model: BaseSegmentor):
|
|
||||||
|
|
||||||
cfg = model.cfg
|
|
||||||
for t in cfg.test_pipeline:
|
|
||||||
if t.get('type') == 'LoadAnnotations':
|
|
||||||
cfg.test_pipeline.remove(t)
|
|
||||||
|
|
||||||
is_batch = True
|
|
||||||
if not isinstance(imgs, (list, tuple)):
|
|
||||||
imgs = [imgs]
|
|
||||||
is_batch = False
|
|
||||||
|
|
||||||
if isinstance(imgs[0], np.ndarray):
|
|
||||||
cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray'
|
|
||||||
|
|
||||||
# TODO: Consider using the singleton pattern to avoid building
|
|
||||||
# a pipeline for each inference
|
|
||||||
pipeline = Compose(cfg.test_pipeline)
|
|
||||||
|
|
||||||
data = defaultdict(list)
|
|
||||||
for img in imgs:
|
|
||||||
if isinstance(img, np.ndarray):
|
|
||||||
data_ = dict(img=img)
|
|
||||||
else:
|
|
||||||
data_ = dict(img_path=img)
|
|
||||||
data_ = pipeline(data_)
|
|
||||||
data['inputs'].append(data_['inputs'])
|
|
||||||
data['data_samples'].append(data_['data_samples'])
|
|
||||||
|
|
||||||
return data, is_batch
|
|
||||||
|
|
||||||
|
|
||||||
def inference_model(model: BaseSegmentor,
|
def inference_model(model: BaseSegmentor,
|
||||||
img: ImageType) -> Union[SegDataSample, SampleList]:
|
img: ImageType) -> Union[SegDataSample, SampleList]:
|
||||||
"""Inference image(s) with the segmentor.
|
"""Inference image(s) with the segmentor.
|
||||||
|
279
mmseg/apis/remote_sense_inferencer.py
Normal file
279
mmseg/apis/remote_sense_inferencer.py
Normal file
@ -0,0 +1,279 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import threading
|
||||||
|
from queue import Queue
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from mmengine import Config
|
||||||
|
from mmengine.model import BaseModel
|
||||||
|
from mmengine.registry import init_default_scope
|
||||||
|
from mmengine.runner import load_checkpoint
|
||||||
|
|
||||||
|
try:
|
||||||
|
from osgeo import gdal
|
||||||
|
except ImportError:
|
||||||
|
gdal = None
|
||||||
|
|
||||||
|
from mmseg.registry import MODELS
|
||||||
|
from .utils import _preprare_data
|
||||||
|
|
||||||
|
|
||||||
|
class RSImage:
|
||||||
|
"""Remote sensing image class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (str or gdal.Dataset): Image file path or gdal.Dataset.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, image):
|
||||||
|
self.dataset = gdal.Open(image, gdal.GA_ReadOnly) if isinstance(
|
||||||
|
image, str) else image
|
||||||
|
assert isinstance(self.dataset, gdal.Dataset), \
|
||||||
|
f'{image} is not a image'
|
||||||
|
self.width = self.dataset.RasterXSize
|
||||||
|
self.height = self.dataset.RasterYSize
|
||||||
|
self.channel = self.dataset.RasterCount
|
||||||
|
self.trans = self.dataset.GetGeoTransform()
|
||||||
|
self.proj = self.dataset.GetProjection()
|
||||||
|
self.band_list = []
|
||||||
|
self.band_list.extend(
|
||||||
|
self.dataset.GetRasterBand(c + 1) for c in range(self.channel))
|
||||||
|
self.grids = []
|
||||||
|
|
||||||
|
def read(self, grid: Optional[List] = None) -> np.ndarray:
|
||||||
|
"""Read image data. If grid is None, read the whole image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid (Optional[List], optional): Grid to read. Defaults to None.
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Image data.
|
||||||
|
"""
|
||||||
|
if grid is None:
|
||||||
|
return np.einsum('ijk->jki', self.dataset.ReadAsArray())
|
||||||
|
assert len(
|
||||||
|
grid) >= 4, 'grid must be a list containing at least 4 elements'
|
||||||
|
data = self.dataset.ReadAsArray(*grid[:4])
|
||||||
|
if data.ndim == 2:
|
||||||
|
data = data[np.newaxis, ...]
|
||||||
|
return np.einsum('ijk->jki', data)
|
||||||
|
|
||||||
|
def write(self, data: Optional[np.ndarray], grid: Optional[List] = None):
|
||||||
|
"""Write image data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid (Optional[List], optional): Grid to write. Defaults to None.
|
||||||
|
data (Optional[np.ndarray], optional): Data to write.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Either grid or data must be provided.
|
||||||
|
"""
|
||||||
|
if grid is not None:
|
||||||
|
assert len(grid) == 8, 'grid must be a list of 8 elements'
|
||||||
|
for band in self.band_list:
|
||||||
|
band.WriteArray(
|
||||||
|
data[grid[5]:grid[5] + grid[7], grid[4]:grid[4] + grid[6]],
|
||||||
|
grid[0] + grid[4], grid[1] + grid[5])
|
||||||
|
elif data is not None:
|
||||||
|
for i in range(self.channel):
|
||||||
|
self.band_list[i].WriteArray(data[..., i])
|
||||||
|
else:
|
||||||
|
raise ValueError('Either grid or data must be provided.')
|
||||||
|
|
||||||
|
def create_seg_map(self, output_path: Optional[str] = None):
|
||||||
|
if output_path is None:
|
||||||
|
output_path = 'output_label.tif'
|
||||||
|
driver = gdal.GetDriverByName('GTiff')
|
||||||
|
seg_map = driver.Create(output_path, self.width, self.height, 1,
|
||||||
|
gdal.GDT_Byte)
|
||||||
|
seg_map.SetGeoTransform(self.trans)
|
||||||
|
seg_map.SetProjection(self.proj)
|
||||||
|
seg_map_img = RSImage(seg_map)
|
||||||
|
seg_map_img.path = output_path
|
||||||
|
return seg_map_img
|
||||||
|
|
||||||
|
def create_grids(self,
|
||||||
|
window_size: Tuple[int, int],
|
||||||
|
stride: Tuple[int, int] = (0, 0)):
|
||||||
|
"""Create grids for image inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
window_size (Tuple[int, int]): the size of the sliding window.
|
||||||
|
stride (Tuple[int, int], optional): the stride of the sliding
|
||||||
|
window. Defaults to (0, 0).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: window_size must be a tuple of 2 elements.
|
||||||
|
AssertionError: stride must be a tuple of 2 elements.
|
||||||
|
"""
|
||||||
|
assert len(
|
||||||
|
window_size) == 2, 'window_size must be a tuple of 2 elements'
|
||||||
|
assert len(stride) == 2, 'stride must be a tuple of 2 elements'
|
||||||
|
win_w, win_h = window_size
|
||||||
|
stride_x, stride_y = stride
|
||||||
|
|
||||||
|
stride_x = win_w if stride_x == 0 else stride_x
|
||||||
|
stride_y = win_h if stride_y == 0 else stride_y
|
||||||
|
|
||||||
|
x_half_overlap = (win_w - stride_x + 1) // 2
|
||||||
|
y_half_overlap = (win_h - stride_y + 1) // 2
|
||||||
|
|
||||||
|
for y in range(0, self.height, stride_y):
|
||||||
|
y_end = y + win_h >= self.height
|
||||||
|
y_offset = self.height - win_h if y_end else y
|
||||||
|
y_size = win_h
|
||||||
|
y_crop_off = 0 if y_offset == 0 else y_half_overlap
|
||||||
|
y_crop_size = y_size if y_end else win_h - y_crop_off
|
||||||
|
|
||||||
|
for x in range(0, self.width, stride_x):
|
||||||
|
x_end = x + win_w >= self.width
|
||||||
|
x_offset = self.width - win_w if x_end else x
|
||||||
|
x_size = win_w
|
||||||
|
x_crop_off = 0 if x_offset == 0 else x_half_overlap
|
||||||
|
x_crop_size = x_size if x_end else win_w - x_crop_off
|
||||||
|
|
||||||
|
self.grids.append([
|
||||||
|
x_offset, y_offset, x_size, y_size, x_crop_off, y_crop_off,
|
||||||
|
x_crop_size, y_crop_size
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
class RSInferencer:
|
||||||
|
"""Remote sensing inference class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (BaseModel): The loaded model.
|
||||||
|
batch_size (int, optional): Batch size. Defaults to 1.
|
||||||
|
thread (int, optional): Number of threads. Defaults to 1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model: BaseModel, batch_size: int = 1, thread: int = 1):
|
||||||
|
self.model = model
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.END_FLAG = object()
|
||||||
|
self.read_buffer = Queue(self.batch_size)
|
||||||
|
self.write_buffer = Queue(self.batch_size)
|
||||||
|
self.thread = thread
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config_path(cls,
|
||||||
|
config_path: str,
|
||||||
|
checkpoint_path: str,
|
||||||
|
batch_size: int = 1,
|
||||||
|
thread: int = 1,
|
||||||
|
device: Optional[str] = 'cpu'):
|
||||||
|
"""Initialize a segmentor from config file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path (str): Config file path.
|
||||||
|
checkpoint_path (str): Checkpoint path.
|
||||||
|
batch_size (int, optional): Batch size. Defaults to 1.
|
||||||
|
"""
|
||||||
|
init_default_scope('mmseg')
|
||||||
|
cfg = Config.fromfile(config_path)
|
||||||
|
model = MODELS.build(cfg.model)
|
||||||
|
model.cfg = cfg
|
||||||
|
load_checkpoint(model, checkpoint_path, map_location='cpu')
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
return cls(model, batch_size, thread)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_model(cls,
|
||||||
|
model: BaseModel,
|
||||||
|
checkpoint_path: Optional[str] = None,
|
||||||
|
batch_size: int = 1,
|
||||||
|
thread: int = 1,
|
||||||
|
device: Optional[str] = 'cpu'):
|
||||||
|
"""Initialize a segmentor from model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (BaseModel): The loaded model.
|
||||||
|
checkpoint_path (Optional[str]): Checkpoint path.
|
||||||
|
batch_size (int, optional): Batch size. Defaults to 1.
|
||||||
|
"""
|
||||||
|
if checkpoint_path is not None:
|
||||||
|
load_checkpoint(model, checkpoint_path, map_location='cpu')
|
||||||
|
model.to(device)
|
||||||
|
return cls(model, batch_size, thread)
|
||||||
|
|
||||||
|
def read(self,
|
||||||
|
image: RSImage,
|
||||||
|
window_size: Tuple[int, int],
|
||||||
|
strides: Tuple[int, int] = (0, 0)):
|
||||||
|
"""Load image data to read buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (RSImage): The image to read.
|
||||||
|
window_size (Tuple[int, int]): The size of the sliding window.
|
||||||
|
strides (Tuple[int, int], optional): The stride of the sliding
|
||||||
|
window. Defaults to (0, 0).
|
||||||
|
"""
|
||||||
|
image.create_grids(window_size, strides)
|
||||||
|
for grid in image.grids:
|
||||||
|
self.read_buffer.put([grid, image.read(grid=grid)])
|
||||||
|
self.read_buffer.put(self.END_FLAG)
|
||||||
|
|
||||||
|
def inference(self):
|
||||||
|
"""Inference image data from read buffer and put the result to write
|
||||||
|
buffer."""
|
||||||
|
while True:
|
||||||
|
item = self.read_buffer.get()
|
||||||
|
if item == self.END_FLAG:
|
||||||
|
self.read_buffer.put(self.END_FLAG)
|
||||||
|
self.write_buffer.put(item)
|
||||||
|
break
|
||||||
|
data, _ = _preprare_data(item[1], self.model)
|
||||||
|
with torch.no_grad():
|
||||||
|
result = self.model.test_step(data)
|
||||||
|
item[1] = result[0].pred_sem_seg.cpu().data.numpy()[0]
|
||||||
|
self.write_buffer.put(item)
|
||||||
|
self.read_buffer.task_done()
|
||||||
|
|
||||||
|
def write(self, image: RSImage, output_path: Optional[str] = None):
|
||||||
|
"""Write image data from write buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (RSImage): The image to write.
|
||||||
|
output_path (Optional[str], optional): The path to save the
|
||||||
|
segmentation map. Defaults to None.
|
||||||
|
"""
|
||||||
|
seg_map = image.create_seg_map(output_path)
|
||||||
|
while True:
|
||||||
|
item = self.write_buffer.get()
|
||||||
|
if item == self.END_FLAG:
|
||||||
|
break
|
||||||
|
seg_map.write(data=item[1], grid=item[0])
|
||||||
|
self.write_buffer.task_done()
|
||||||
|
|
||||||
|
def run(self,
|
||||||
|
image: RSImage,
|
||||||
|
window_size: Tuple[int, int],
|
||||||
|
strides: Tuple[int, int] = (0, 0),
|
||||||
|
output_path: Optional[str] = None):
|
||||||
|
"""Run inference with multi-threading.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (RSImage): The image to inference.
|
||||||
|
window_size (Tuple[int, int]): The size of the sliding window.
|
||||||
|
strides (Tuple[int, int], optional): The stride of the sliding
|
||||||
|
window. Defaults to (0, 0).
|
||||||
|
output_path (Optional[str], optional): The path to save the
|
||||||
|
segmentation map. Defaults to None.
|
||||||
|
"""
|
||||||
|
read_thread = threading.Thread(
|
||||||
|
target=self.read, args=(image, window_size, strides))
|
||||||
|
read_thread.start()
|
||||||
|
inference_threads = []
|
||||||
|
for _ in range(self.thread):
|
||||||
|
inference_thread = threading.Thread(target=self.inference)
|
||||||
|
inference_thread.start()
|
||||||
|
inference_threads.append(inference_thread)
|
||||||
|
write_thread = threading.Thread(
|
||||||
|
target=self.write, args=(image, output_path))
|
||||||
|
write_thread.start()
|
||||||
|
read_thread.join()
|
||||||
|
for inference_thread in inference_threads:
|
||||||
|
inference_thread.join()
|
||||||
|
write_thread.join()
|
41
mmseg/apis/utils.py
Normal file
41
mmseg/apis/utils.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from mmengine.dataset import Compose
|
||||||
|
from mmengine.model import BaseModel
|
||||||
|
|
||||||
|
ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
|
||||||
|
|
||||||
|
|
||||||
|
def _preprare_data(imgs: ImageType, model: BaseModel):
|
||||||
|
|
||||||
|
cfg = model.cfg
|
||||||
|
for t in cfg.test_pipeline:
|
||||||
|
if t.get('type') == 'LoadAnnotations':
|
||||||
|
cfg.test_pipeline.remove(t)
|
||||||
|
|
||||||
|
is_batch = True
|
||||||
|
if not isinstance(imgs, (list, tuple)):
|
||||||
|
imgs = [imgs]
|
||||||
|
is_batch = False
|
||||||
|
|
||||||
|
if isinstance(imgs[0], np.ndarray):
|
||||||
|
cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray'
|
||||||
|
|
||||||
|
# TODO: Consider using the singleton pattern to avoid building
|
||||||
|
# a pipeline for each inference
|
||||||
|
pipeline = Compose(cfg.test_pipeline)
|
||||||
|
|
||||||
|
data = defaultdict(list)
|
||||||
|
for img in imgs:
|
||||||
|
if isinstance(img, np.ndarray):
|
||||||
|
data_ = dict(img=img)
|
||||||
|
else:
|
||||||
|
data_ = dict(img_path=img)
|
||||||
|
data_ = pipeline(data_)
|
||||||
|
data['inputs'].append(data_['inputs'])
|
||||||
|
data['data_samples'].append(data_['data_samples'])
|
||||||
|
|
||||||
|
return data, is_batch
|
@ -3,48 +3,14 @@ import tempfile
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
from mmengine import ConfigDict
|
from mmengine import ConfigDict
|
||||||
|
from utils import * # noqa: F401, F403
|
||||||
|
|
||||||
from mmseg.apis import MMSegInferencer
|
from mmseg.apis import MMSegInferencer
|
||||||
from mmseg.models import EncoderDecoder
|
|
||||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
from mmseg.utils import register_all_modules
|
from mmseg.utils import register_all_modules
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module(name='InferExampleHead')
|
|
||||||
class ExampleDecodeHead(BaseDecodeHead):
|
|
||||||
|
|
||||||
def __init__(self, num_classes=19, out_channels=None):
|
|
||||||
super().__init__(
|
|
||||||
3, 3, num_classes=num_classes, out_channels=out_channels)
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
return self.cls_seg(inputs[0])
|
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module(name='InferExampleBackbone')
|
|
||||||
class ExampleBackbone(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = nn.Conv2d(3, 3, 3)
|
|
||||||
|
|
||||||
def init_weights(self, pretrained=None):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return [self.conv(x)]
|
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module(name='InferExampleModel')
|
|
||||||
class ExampleModel(EncoderDecoder):
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def test_inferencer():
|
def test_inferencer():
|
||||||
register_all_modules()
|
register_all_modules()
|
||||||
|
|
||||||
|
73
tests/test_apis/test_rs_inferencer.py
Normal file
73
tests/test_apis/test_rs_inferencer.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os.path as osp
|
||||||
|
from unittest import TestCase
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from mmengine import ConfigDict, init_default_scope
|
||||||
|
from utils import * # noqa: F401, F403
|
||||||
|
|
||||||
|
from mmseg.apis import RSImage, RSInferencer
|
||||||
|
from mmseg.registry import MODELS
|
||||||
|
|
||||||
|
|
||||||
|
class TestRSImage(TestCase):
|
||||||
|
|
||||||
|
def test_read_whole_image(self):
|
||||||
|
init_default_scope('mmseg')
|
||||||
|
img_path = osp.join(
|
||||||
|
osp.dirname(__file__),
|
||||||
|
'../data/pseudo_loveda_dataset/img_dir/0.png')
|
||||||
|
rs_image = RSImage(img_path)
|
||||||
|
window_size = (16, 16)
|
||||||
|
rs_image.create_grids(window_size)
|
||||||
|
image_data = rs_image.read(rs_image.grids[0])
|
||||||
|
self.assertIsNotNone(image_data)
|
||||||
|
|
||||||
|
def test_write_image_data(self):
|
||||||
|
init_default_scope('mmseg')
|
||||||
|
img_path = osp.join(
|
||||||
|
osp.dirname(__file__),
|
||||||
|
'../data/pseudo_loveda_dataset/img_dir/0.png')
|
||||||
|
rs_image = RSImage(img_path)
|
||||||
|
window_size = (16, 16)
|
||||||
|
rs_image.create_grids(window_size)
|
||||||
|
data = np.random.random((16, 16)).astype(np.int8)
|
||||||
|
rs_image.write(data, rs_image.grids[0])
|
||||||
|
|
||||||
|
|
||||||
|
class TestRSInferencer(TestCase):
|
||||||
|
|
||||||
|
def test_read_and_inference(self):
|
||||||
|
init_default_scope('mmseg')
|
||||||
|
cfg_dict = dict(
|
||||||
|
model=dict(
|
||||||
|
type='InferExampleModel',
|
||||||
|
data_preprocessor=dict(type='SegDataPreProcessor'),
|
||||||
|
backbone=dict(type='InferExampleBackbone'),
|
||||||
|
decode_head=dict(type='InferExampleHead'),
|
||||||
|
test_cfg=dict(mode='whole')),
|
||||||
|
test_dataloader=dict(
|
||||||
|
dataset=dict(
|
||||||
|
type='ExampleDataset',
|
||||||
|
pipeline=[
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(type='LoadAnnotations'),
|
||||||
|
dict(type='PackSegInputs')
|
||||||
|
])),
|
||||||
|
test_pipeline=[
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(type='LoadAnnotations'),
|
||||||
|
dict(type='PackSegInputs')
|
||||||
|
])
|
||||||
|
cfg = ConfigDict(cfg_dict)
|
||||||
|
model = MODELS.build(cfg.model)
|
||||||
|
model.cfg = cfg
|
||||||
|
inferencer = RSInferencer.from_model(model)
|
||||||
|
|
||||||
|
img_path = osp.join(
|
||||||
|
osp.dirname(__file__),
|
||||||
|
'../data/pseudo_loveda_dataset/img_dir/0.png')
|
||||||
|
rs_image = RSImage(img_path)
|
||||||
|
window_size = (16, 16)
|
||||||
|
stride = (16, 16)
|
||||||
|
inferencer.run(rs_image, window_size, stride)
|
38
tests/test_apis/utils.py
Normal file
38
tests/test_apis/utils.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from mmseg.models import EncoderDecoder
|
||||||
|
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||||
|
from mmseg.registry import MODELS
|
||||||
|
|
||||||
|
|
||||||
|
@MODELS.register_module(name='InferExampleHead')
|
||||||
|
class ExampleDecodeHead(BaseDecodeHead):
|
||||||
|
|
||||||
|
def __init__(self, num_classes=19, out_channels=None):
|
||||||
|
super().__init__(
|
||||||
|
3, 3, num_classes=num_classes, out_channels=out_channels)
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
return self.cls_seg(inputs[0])
|
||||||
|
|
||||||
|
|
||||||
|
@MODELS.register_module(name='InferExampleBackbone')
|
||||||
|
class ExampleBackbone(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(3, 3, 3)
|
||||||
|
|
||||||
|
def init_weights(self, pretrained=None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return [self.conv(x)]
|
||||||
|
|
||||||
|
|
||||||
|
@MODELS.register_module(name='InferExampleModel')
|
||||||
|
class ExampleModel(EncoderDecoder):
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
Loading…
x
Reference in New Issue
Block a user