[Feature] remote sensing inference (#3131)
## Motivation Supports inference for ultra-large-scale remote sensing images. ## Modification Add RSImageInference.py in demo. ## Use cases Taking the inference of Vaihingen dataset images using PSPNet as an example, the following settings are required: **img**: Specify the path of the image. **model**: Provide the configuration file for the model. **checkpoint**: Specify the weight file for the model. **out**: Set the output path for the results. **batch_size**: Determine the batch size used during inference. **win_size**: Specify the width and height(512x512) of the sliding window. **stride**: Set the stride(400x400) for sliding the window. **thread(default: 1)**: Specify the number of threads to be used for inference. **Inference device (default: cuda:0)**: Specify the device for inference (e.g., cuda:0 for CPU). ```shell python demo/rs_image_inference.py demo/demo.png projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny.py pp_mobileseg_mobilenetv3_2xb16_3rdparty-tiny_512x512-ade20k-a351ebf5.pth --batch-size 8 --device cpu --thread 2 ``` --------- Co-authored-by: xiexinch <xiexinch@outlook.com>pull/3306/head
parent
35ff78a07f
commit
72e20a8854
|
@ -73,7 +73,7 @@ jobs:
|
|||
- run:
|
||||
name: Skip timm unittests and generate coverage report
|
||||
command: |
|
||||
python -m coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
|
||||
python -m coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py --ignore tests/test_apis/test_rs_inferencer.py
|
||||
python -m coverage xml
|
||||
python -m coverage report -m
|
||||
build_cuda:
|
||||
|
@ -119,7 +119,7 @@ jobs:
|
|||
- run:
|
||||
name: Run unittests but skip timm unittests
|
||||
command: |
|
||||
docker exec mmseg pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
|
||||
docker exec mmseg pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py --ignore tests/test_models/test_backbones/test_timm_backbone.py --ignore tests/test_apis/test_rs_inferencer.py
|
||||
workflows:
|
||||
pr_stage_lint:
|
||||
when: << pipeline.parameters.lint_only >>
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from mmseg.apis import RSImage, RSInferencer
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('image', help='Image file path')
|
||||
parser.add_argument('config', help='Config file')
|
||||
parser.add_argument('checkpoint', help='Checkpoint file')
|
||||
parser.add_argument(
|
||||
'--output-path',
|
||||
help='Path to save result image',
|
||||
default='result.png')
|
||||
parser.add_argument(
|
||||
'--batch-size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='maximum number of windows inferred simultaneously')
|
||||
parser.add_argument(
|
||||
'--window-size',
|
||||
help='window xsize,ysize',
|
||||
default=(224, 224),
|
||||
type=int,
|
||||
nargs=2)
|
||||
parser.add_argument(
|
||||
'--stride',
|
||||
help='window xstride,ystride',
|
||||
default=(224, 224),
|
||||
type=int,
|
||||
nargs=2)
|
||||
parser.add_argument(
|
||||
'--thread', default=1, type=int, help='number of inference threads')
|
||||
parser.add_argument(
|
||||
'--device', default='cuda:0', help='Device used for inference')
|
||||
args = parser.parse_args()
|
||||
inferencer = RSInferencer.from_config_path(
|
||||
args.config,
|
||||
args.checkpoint,
|
||||
batch_size=args.batch_size,
|
||||
thread=args.thread,
|
||||
device=args.device)
|
||||
image = RSImage(args.image)
|
||||
|
||||
inferencer.run(image, args.window_size, args.stride, args.output_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,7 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .inference import inference_model, init_model, show_result_pyplot
|
||||
from .mmseg_inferencer import MMSegInferencer
|
||||
from .remote_sense_inferencer import RSImage, RSInferencer
|
||||
|
||||
__all__ = [
|
||||
'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer'
|
||||
'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer',
|
||||
'RSInferencer', 'RSImage'
|
||||
]
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import Config
|
||||
from mmengine.dataset import Compose
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.runner import load_checkpoint
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
|
@ -18,6 +16,7 @@ from mmseg.registry import MODELS
|
|||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
from .utils import ImageType, _preprare_data
|
||||
|
||||
|
||||
def init_model(config: Union[str, Path, Config],
|
||||
|
@ -90,41 +89,6 @@ def init_model(config: Union[str, Path, Config],
|
|||
return model
|
||||
|
||||
|
||||
ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
|
||||
|
||||
|
||||
def _preprare_data(imgs: ImageType, model: BaseSegmentor):
|
||||
|
||||
cfg = model.cfg
|
||||
for t in cfg.test_pipeline:
|
||||
if t.get('type') == 'LoadAnnotations':
|
||||
cfg.test_pipeline.remove(t)
|
||||
|
||||
is_batch = True
|
||||
if not isinstance(imgs, (list, tuple)):
|
||||
imgs = [imgs]
|
||||
is_batch = False
|
||||
|
||||
if isinstance(imgs[0], np.ndarray):
|
||||
cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray'
|
||||
|
||||
# TODO: Consider using the singleton pattern to avoid building
|
||||
# a pipeline for each inference
|
||||
pipeline = Compose(cfg.test_pipeline)
|
||||
|
||||
data = defaultdict(list)
|
||||
for img in imgs:
|
||||
if isinstance(img, np.ndarray):
|
||||
data_ = dict(img=img)
|
||||
else:
|
||||
data_ = dict(img_path=img)
|
||||
data_ = pipeline(data_)
|
||||
data['inputs'].append(data_['inputs'])
|
||||
data['data_samples'].append(data_['data_samples'])
|
||||
|
||||
return data, is_batch
|
||||
|
||||
|
||||
def inference_model(model: BaseSegmentor,
|
||||
img: ImageType) -> Union[SegDataSample, SampleList]:
|
||||
"""Inference image(s) with the segmentor.
|
||||
|
|
|
@ -0,0 +1,279 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import threading
|
||||
from queue import Queue
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import Config
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.runner import load_checkpoint
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
except ImportError:
|
||||
gdal = None
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .utils import _preprare_data
|
||||
|
||||
|
||||
class RSImage:
|
||||
"""Remote sensing image class.
|
||||
|
||||
Args:
|
||||
img (str or gdal.Dataset): Image file path or gdal.Dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, image):
|
||||
self.dataset = gdal.Open(image, gdal.GA_ReadOnly) if isinstance(
|
||||
image, str) else image
|
||||
assert isinstance(self.dataset, gdal.Dataset), \
|
||||
f'{image} is not a image'
|
||||
self.width = self.dataset.RasterXSize
|
||||
self.height = self.dataset.RasterYSize
|
||||
self.channel = self.dataset.RasterCount
|
||||
self.trans = self.dataset.GetGeoTransform()
|
||||
self.proj = self.dataset.GetProjection()
|
||||
self.band_list = []
|
||||
self.band_list.extend(
|
||||
self.dataset.GetRasterBand(c + 1) for c in range(self.channel))
|
||||
self.grids = []
|
||||
|
||||
def read(self, grid: Optional[List] = None) -> np.ndarray:
|
||||
"""Read image data. If grid is None, read the whole image.
|
||||
|
||||
Args:
|
||||
grid (Optional[List], optional): Grid to read. Defaults to None.
|
||||
Returns:
|
||||
np.ndarray: Image data.
|
||||
"""
|
||||
if grid is None:
|
||||
return np.einsum('ijk->jki', self.dataset.ReadAsArray())
|
||||
assert len(
|
||||
grid) >= 4, 'grid must be a list containing at least 4 elements'
|
||||
data = self.dataset.ReadAsArray(*grid[:4])
|
||||
if data.ndim == 2:
|
||||
data = data[np.newaxis, ...]
|
||||
return np.einsum('ijk->jki', data)
|
||||
|
||||
def write(self, data: Optional[np.ndarray], grid: Optional[List] = None):
|
||||
"""Write image data.
|
||||
|
||||
Args:
|
||||
grid (Optional[List], optional): Grid to write. Defaults to None.
|
||||
data (Optional[np.ndarray], optional): Data to write.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: Either grid or data must be provided.
|
||||
"""
|
||||
if grid is not None:
|
||||
assert len(grid) == 8, 'grid must be a list of 8 elements'
|
||||
for band in self.band_list:
|
||||
band.WriteArray(
|
||||
data[grid[5]:grid[5] + grid[7], grid[4]:grid[4] + grid[6]],
|
||||
grid[0] + grid[4], grid[1] + grid[5])
|
||||
elif data is not None:
|
||||
for i in range(self.channel):
|
||||
self.band_list[i].WriteArray(data[..., i])
|
||||
else:
|
||||
raise ValueError('Either grid or data must be provided.')
|
||||
|
||||
def create_seg_map(self, output_path: Optional[str] = None):
|
||||
if output_path is None:
|
||||
output_path = 'output_label.tif'
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
seg_map = driver.Create(output_path, self.width, self.height, 1,
|
||||
gdal.GDT_Byte)
|
||||
seg_map.SetGeoTransform(self.trans)
|
||||
seg_map.SetProjection(self.proj)
|
||||
seg_map_img = RSImage(seg_map)
|
||||
seg_map_img.path = output_path
|
||||
return seg_map_img
|
||||
|
||||
def create_grids(self,
|
||||
window_size: Tuple[int, int],
|
||||
stride: Tuple[int, int] = (0, 0)):
|
||||
"""Create grids for image inference.
|
||||
|
||||
Args:
|
||||
window_size (Tuple[int, int]): the size of the sliding window.
|
||||
stride (Tuple[int, int], optional): the stride of the sliding
|
||||
window. Defaults to (0, 0).
|
||||
|
||||
Raises:
|
||||
AssertionError: window_size must be a tuple of 2 elements.
|
||||
AssertionError: stride must be a tuple of 2 elements.
|
||||
"""
|
||||
assert len(
|
||||
window_size) == 2, 'window_size must be a tuple of 2 elements'
|
||||
assert len(stride) == 2, 'stride must be a tuple of 2 elements'
|
||||
win_w, win_h = window_size
|
||||
stride_x, stride_y = stride
|
||||
|
||||
stride_x = win_w if stride_x == 0 else stride_x
|
||||
stride_y = win_h if stride_y == 0 else stride_y
|
||||
|
||||
x_half_overlap = (win_w - stride_x + 1) // 2
|
||||
y_half_overlap = (win_h - stride_y + 1) // 2
|
||||
|
||||
for y in range(0, self.height, stride_y):
|
||||
y_end = y + win_h >= self.height
|
||||
y_offset = self.height - win_h if y_end else y
|
||||
y_size = win_h
|
||||
y_crop_off = 0 if y_offset == 0 else y_half_overlap
|
||||
y_crop_size = y_size if y_end else win_h - y_crop_off
|
||||
|
||||
for x in range(0, self.width, stride_x):
|
||||
x_end = x + win_w >= self.width
|
||||
x_offset = self.width - win_w if x_end else x
|
||||
x_size = win_w
|
||||
x_crop_off = 0 if x_offset == 0 else x_half_overlap
|
||||
x_crop_size = x_size if x_end else win_w - x_crop_off
|
||||
|
||||
self.grids.append([
|
||||
x_offset, y_offset, x_size, y_size, x_crop_off, y_crop_off,
|
||||
x_crop_size, y_crop_size
|
||||
])
|
||||
|
||||
|
||||
class RSInferencer:
|
||||
"""Remote sensing inference class.
|
||||
|
||||
Args:
|
||||
model (BaseModel): The loaded model.
|
||||
batch_size (int, optional): Batch size. Defaults to 1.
|
||||
thread (int, optional): Number of threads. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(self, model: BaseModel, batch_size: int = 1, thread: int = 1):
|
||||
self.model = model
|
||||
self.batch_size = batch_size
|
||||
self.END_FLAG = object()
|
||||
self.read_buffer = Queue(self.batch_size)
|
||||
self.write_buffer = Queue(self.batch_size)
|
||||
self.thread = thread
|
||||
|
||||
@classmethod
|
||||
def from_config_path(cls,
|
||||
config_path: str,
|
||||
checkpoint_path: str,
|
||||
batch_size: int = 1,
|
||||
thread: int = 1,
|
||||
device: Optional[str] = 'cpu'):
|
||||
"""Initialize a segmentor from config file.
|
||||
|
||||
Args:
|
||||
config_path (str): Config file path.
|
||||
checkpoint_path (str): Checkpoint path.
|
||||
batch_size (int, optional): Batch size. Defaults to 1.
|
||||
"""
|
||||
init_default_scope('mmseg')
|
||||
cfg = Config.fromfile(config_path)
|
||||
model = MODELS.build(cfg.model)
|
||||
model.cfg = cfg
|
||||
load_checkpoint(model, checkpoint_path, map_location='cpu')
|
||||
model.to(device)
|
||||
model.eval()
|
||||
return cls(model, batch_size, thread)
|
||||
|
||||
@classmethod
|
||||
def from_model(cls,
|
||||
model: BaseModel,
|
||||
checkpoint_path: Optional[str] = None,
|
||||
batch_size: int = 1,
|
||||
thread: int = 1,
|
||||
device: Optional[str] = 'cpu'):
|
||||
"""Initialize a segmentor from model.
|
||||
|
||||
Args:
|
||||
model (BaseModel): The loaded model.
|
||||
checkpoint_path (Optional[str]): Checkpoint path.
|
||||
batch_size (int, optional): Batch size. Defaults to 1.
|
||||
"""
|
||||
if checkpoint_path is not None:
|
||||
load_checkpoint(model, checkpoint_path, map_location='cpu')
|
||||
model.to(device)
|
||||
return cls(model, batch_size, thread)
|
||||
|
||||
def read(self,
|
||||
image: RSImage,
|
||||
window_size: Tuple[int, int],
|
||||
strides: Tuple[int, int] = (0, 0)):
|
||||
"""Load image data to read buffer.
|
||||
|
||||
Args:
|
||||
image (RSImage): The image to read.
|
||||
window_size (Tuple[int, int]): The size of the sliding window.
|
||||
strides (Tuple[int, int], optional): The stride of the sliding
|
||||
window. Defaults to (0, 0).
|
||||
"""
|
||||
image.create_grids(window_size, strides)
|
||||
for grid in image.grids:
|
||||
self.read_buffer.put([grid, image.read(grid=grid)])
|
||||
self.read_buffer.put(self.END_FLAG)
|
||||
|
||||
def inference(self):
|
||||
"""Inference image data from read buffer and put the result to write
|
||||
buffer."""
|
||||
while True:
|
||||
item = self.read_buffer.get()
|
||||
if item == self.END_FLAG:
|
||||
self.read_buffer.put(self.END_FLAG)
|
||||
self.write_buffer.put(item)
|
||||
break
|
||||
data, _ = _preprare_data(item[1], self.model)
|
||||
with torch.no_grad():
|
||||
result = self.model.test_step(data)
|
||||
item[1] = result[0].pred_sem_seg.cpu().data.numpy()[0]
|
||||
self.write_buffer.put(item)
|
||||
self.read_buffer.task_done()
|
||||
|
||||
def write(self, image: RSImage, output_path: Optional[str] = None):
|
||||
"""Write image data from write buffer.
|
||||
|
||||
Args:
|
||||
image (RSImage): The image to write.
|
||||
output_path (Optional[str], optional): The path to save the
|
||||
segmentation map. Defaults to None.
|
||||
"""
|
||||
seg_map = image.create_seg_map(output_path)
|
||||
while True:
|
||||
item = self.write_buffer.get()
|
||||
if item == self.END_FLAG:
|
||||
break
|
||||
seg_map.write(data=item[1], grid=item[0])
|
||||
self.write_buffer.task_done()
|
||||
|
||||
def run(self,
|
||||
image: RSImage,
|
||||
window_size: Tuple[int, int],
|
||||
strides: Tuple[int, int] = (0, 0),
|
||||
output_path: Optional[str] = None):
|
||||
"""Run inference with multi-threading.
|
||||
|
||||
Args:
|
||||
image (RSImage): The image to inference.
|
||||
window_size (Tuple[int, int]): The size of the sliding window.
|
||||
strides (Tuple[int, int], optional): The stride of the sliding
|
||||
window. Defaults to (0, 0).
|
||||
output_path (Optional[str], optional): The path to save the
|
||||
segmentation map. Defaults to None.
|
||||
"""
|
||||
read_thread = threading.Thread(
|
||||
target=self.read, args=(image, window_size, strides))
|
||||
read_thread.start()
|
||||
inference_threads = []
|
||||
for _ in range(self.thread):
|
||||
inference_thread = threading.Thread(target=self.inference)
|
||||
inference_thread.start()
|
||||
inference_threads.append(inference_thread)
|
||||
write_thread = threading.Thread(
|
||||
target=self.write, args=(image, output_path))
|
||||
write_thread.start()
|
||||
read_thread.join()
|
||||
for inference_thread in inference_threads:
|
||||
inference_thread.join()
|
||||
write_thread.join()
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from collections import defaultdict
|
||||
from typing import Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
from mmengine.dataset import Compose
|
||||
from mmengine.model import BaseModel
|
||||
|
||||
ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
|
||||
|
||||
|
||||
def _preprare_data(imgs: ImageType, model: BaseModel):
|
||||
|
||||
cfg = model.cfg
|
||||
for t in cfg.test_pipeline:
|
||||
if t.get('type') == 'LoadAnnotations':
|
||||
cfg.test_pipeline.remove(t)
|
||||
|
||||
is_batch = True
|
||||
if not isinstance(imgs, (list, tuple)):
|
||||
imgs = [imgs]
|
||||
is_batch = False
|
||||
|
||||
if isinstance(imgs[0], np.ndarray):
|
||||
cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray'
|
||||
|
||||
# TODO: Consider using the singleton pattern to avoid building
|
||||
# a pipeline for each inference
|
||||
pipeline = Compose(cfg.test_pipeline)
|
||||
|
||||
data = defaultdict(list)
|
||||
for img in imgs:
|
||||
if isinstance(img, np.ndarray):
|
||||
data_ = dict(img=img)
|
||||
else:
|
||||
data_ = dict(img_path=img)
|
||||
data_ = pipeline(data_)
|
||||
data['inputs'].append(data_['inputs'])
|
||||
data['data_samples'].append(data_['data_samples'])
|
||||
|
||||
return data, is_batch
|
|
@ -3,48 +3,14 @@ import tempfile
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine import ConfigDict
|
||||
from utils import * # noqa: F401, F403
|
||||
|
||||
from mmseg.apis import MMSegInferencer
|
||||
from mmseg.models import EncoderDecoder
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
@MODELS.register_module(name='InferExampleHead')
|
||||
class ExampleDecodeHead(BaseDecodeHead):
|
||||
|
||||
def __init__(self, num_classes=19, out_channels=None):
|
||||
super().__init__(
|
||||
3, 3, num_classes=num_classes, out_channels=out_channels)
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.cls_seg(inputs[0])
|
||||
|
||||
|
||||
@MODELS.register_module(name='InferExampleBackbone')
|
||||
class ExampleBackbone(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(3, 3, 3)
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
pass
|
||||
|
||||
def forward(self, x):
|
||||
return [self.conv(x)]
|
||||
|
||||
|
||||
@MODELS.register_module(name='InferExampleModel')
|
||||
class ExampleModel(EncoderDecoder):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
def test_inferencer():
|
||||
register_all_modules()
|
||||
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
from mmengine import ConfigDict, init_default_scope
|
||||
from utils import * # noqa: F401, F403
|
||||
|
||||
from mmseg.apis import RSImage, RSInferencer
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class TestRSImage(TestCase):
|
||||
|
||||
def test_read_whole_image(self):
|
||||
init_default_scope('mmseg')
|
||||
img_path = osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_loveda_dataset/img_dir/0.png')
|
||||
rs_image = RSImage(img_path)
|
||||
window_size = (16, 16)
|
||||
rs_image.create_grids(window_size)
|
||||
image_data = rs_image.read(rs_image.grids[0])
|
||||
self.assertIsNotNone(image_data)
|
||||
|
||||
def test_write_image_data(self):
|
||||
init_default_scope('mmseg')
|
||||
img_path = osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_loveda_dataset/img_dir/0.png')
|
||||
rs_image = RSImage(img_path)
|
||||
window_size = (16, 16)
|
||||
rs_image.create_grids(window_size)
|
||||
data = np.random.random((16, 16)).astype(np.int8)
|
||||
rs_image.write(data, rs_image.grids[0])
|
||||
|
||||
|
||||
class TestRSInferencer(TestCase):
|
||||
|
||||
def test_read_and_inference(self):
|
||||
init_default_scope('mmseg')
|
||||
cfg_dict = dict(
|
||||
model=dict(
|
||||
type='InferExampleModel',
|
||||
data_preprocessor=dict(type='SegDataPreProcessor'),
|
||||
backbone=dict(type='InferExampleBackbone'),
|
||||
decode_head=dict(type='InferExampleHead'),
|
||||
test_cfg=dict(mode='whole')),
|
||||
test_dataloader=dict(
|
||||
dataset=dict(
|
||||
type='ExampleDataset',
|
||||
pipeline=[
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
])),
|
||||
test_pipeline=[
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
])
|
||||
cfg = ConfigDict(cfg_dict)
|
||||
model = MODELS.build(cfg.model)
|
||||
model.cfg = cfg
|
||||
inferencer = RSInferencer.from_model(model)
|
||||
|
||||
img_path = osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_loveda_dataset/img_dir/0.png')
|
||||
rs_image = RSImage(img_path)
|
||||
window_size = (16, 16)
|
||||
stride = (16, 16)
|
||||
inferencer.run(rs_image, window_size, stride)
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
|
||||
from mmseg.models import EncoderDecoder
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module(name='InferExampleHead')
|
||||
class ExampleDecodeHead(BaseDecodeHead):
|
||||
|
||||
def __init__(self, num_classes=19, out_channels=None):
|
||||
super().__init__(
|
||||
3, 3, num_classes=num_classes, out_channels=out_channels)
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.cls_seg(inputs[0])
|
||||
|
||||
|
||||
@MODELS.register_module(name='InferExampleBackbone')
|
||||
class ExampleBackbone(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(3, 3, 3)
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
pass
|
||||
|
||||
def forward(self, x):
|
||||
return [self.conv(x)]
|
||||
|
||||
|
||||
@MODELS.register_module(name='InferExampleModel')
|
||||
class ExampleModel(EncoderDecoder):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
Loading…
Reference in New Issue