mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
17 lines
401 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import platform
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from mmselfsup.models.utils import Encoder
|
|
|
|
|
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_dalle():
    """Verify the output shape of a default-constructed DALL-E ``Encoder``.

    A random batch of two 3-channel 112x112 images is fed through the
    encoder. The result must have shape ``[2, 8192, 14, 14]``, i.e. the
    batch size is kept, there are 8192 output channels, and each spatial
    side shrinks from 112 to 14.

    Skipped on Windows (the marker's reason is 'Windows mem limit').
    """
    encoder = Encoder()

    batch_size = 2
    images = torch.rand((batch_size, 3, 112, 112))
    expected_shape = [batch_size, 8192, 14, 14]

    features = encoder(images)
    assert list(features.shape) == expected_shape
|