mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] fix inferencer ut (#3117)
This commit is contained in:
parent
c30d5060f1
commit
b2f4b4fe33
@ -5,7 +5,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmengine import ConfigDict
|
from mmengine import ConfigDict
|
||||||
from torch.utils.data import DataLoader, Dataset
|
|
||||||
|
|
||||||
from mmseg.apis import MMSegInferencer
|
from mmseg.apis import MMSegInferencer
|
||||||
from mmseg.models import EncoderDecoder
|
from mmseg.models import EncoderDecoder
|
||||||
@ -46,33 +45,8 @@ class ExampleModel(EncoderDecoder):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ExampleDataset(Dataset):
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.pipeline = [
|
|
||||||
dict(type='LoadImageFromFile'),
|
|
||||||
dict(type='LoadAnnotations'),
|
|
||||||
dict(type='PackSegInputs')
|
|
||||||
]
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
return dict(img=torch.tensor([1]), img_metas=dict())
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_inferencer():
|
def test_inferencer():
|
||||||
register_all_modules()
|
register_all_modules()
|
||||||
test_dataset = ExampleDataset()
|
|
||||||
data_loader = DataLoader(
|
|
||||||
test_dataset,
|
|
||||||
batch_size=1,
|
|
||||||
sampler=None,
|
|
||||||
num_workers=0,
|
|
||||||
shuffle=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
visualizer = dict(
|
visualizer = dict(
|
||||||
type='SegLocalVisualizer',
|
type='SegLocalVisualizer',
|
||||||
@ -87,7 +61,14 @@ def test_inferencer():
|
|||||||
decode_head=dict(type='InferExampleHead'),
|
decode_head=dict(type='InferExampleHead'),
|
||||||
test_cfg=dict(mode='whole')),
|
test_cfg=dict(mode='whole')),
|
||||||
visualizer=visualizer,
|
visualizer=visualizer,
|
||||||
test_dataloader=data_loader)
|
test_dataloader=dict(
|
||||||
|
dataset=dict(
|
||||||
|
type='ExampleDataset',
|
||||||
|
pipeline=[
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(type='LoadAnnotations'),
|
||||||
|
dict(type='PackSegInputs')
|
||||||
|
]), ))
|
||||||
cfg = ConfigDict(cfg_dict)
|
cfg = ConfigDict(cfg_dict)
|
||||||
model = MODELS.build(cfg.model)
|
model = MODELS.build(cfg.model)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user