Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00) — 52 lines, 1.7 KiB, Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
from os.path import dirname, join
|
||
|
from unittest import TestCase
|
||
|
|
||
|
import torch
|
||
|
from mmengine import Config
|
||
|
|
||
|
import mmseg
|
||
|
from mmseg.models.backbones import VPD
|
||
|
|
||
|
|
||
|
class TestVPD(TestCase):
    """Smoke tests for the ``VPD`` backbone's forward pass.

    The backbone is built from the ``stable_diffusion_cfg`` entry of the
    repository's ``configs/_base_/models/vpd_sd.py`` config. Each test checks
    the number and shapes of the feature maps returned by ``forward``.
    """

    # Expected (C, H, W) of each returned feature map for a 1x3x60x60 input
    # with ``pad_shape=64``. NOTE(review): presumably the input is padded to
    # 64x64 before encoding, which yields the 8/4/2/1 spatial pyramid —
    # confirm against VPD.forward.
    EXPECTED_SHAPES = (
        (320, 8, 8),
        (640, 4, 4),
        (1280, 2, 2),
        (1280, 1, 1),
    )

    def setUp(self) -> None:
        """Build a ``VPD`` backbone from the repository's base VPD config."""
        repo_dpath = dirname(dirname(mmseg.__file__))
        config_dpath = join(repo_dpath, 'configs/_base_/models/vpd_sd.py')
        vpd_cfg = Config.fromfile(config_dpath).stable_diffusion_cfg
        # Drop the 'checkpoint' entry so the diffusion model is built without
        # it (presumably to avoid loading pretrained weights in tests).
        vpd_cfg.pop('checkpoint')

        # NOTE(review): class_embed_path is a remote URL, so this setUp needs
        # network access (or a warm download cache) to succeed.
        self.vpd_model = VPD(
            diffusion_cfg=vpd_cfg,
            class_embed_path='https://download.openmmlab.com/mmsegmentation/'
            'v0.5/vpd/nyu_class_embeddings.pth',
            class_embed_select=True,
            pad_shape=64,
            unet_cfg=dict(use_attn=False),
        )

    def _assert_feature_shapes(self, out) -> None:
        """Assert ``out`` holds one batch-1 feature map per expected shape.

        Args:
            out: Sequence of tensors returned by ``self.vpd_model``.
        """
        self.assertEqual(len(out), len(self.EXPECTED_SHAPES))
        for level, (feat, shape) in enumerate(zip(out, self.EXPECTED_SHAPES)):
            # subTest reports every mismatching level, not just the first.
            with self.subTest(level=level):
                self.assertListEqual(list(feat.shape), [1, *shape])

    def test_forward(self):
        """Check output shapes with and without an explicit class id."""
        # Only output shapes are asserted, so one random input serves both
        # forward calls.
        x = torch.randn(1, 3, 60, 60)

        # test forward without class_id
        with torch.no_grad():
            out = self.vpd_model(x)
        self._assert_feature_shapes(out)

        # test forward with class_id, passed as an (image, class_id) tuple
        with torch.no_grad():
            out = self.vpd_model((x, torch.tensor([2])))
        self._assert_feature_shapes(out)
|