22 lines
683 B
Python
22 lines
683 B
Python
|
import pytest
|
||
|
import torch
|
||
|
|
||
|
from mmseg.core import OHEMPixelSampler
|
||
|
|
||
|
|
||
|
def test_ohem_sampler():
    """Check OHEMPixelSampler.sample() input validation and output shape.

    Two cases are exercised:

    1. A label map whose spatial size differs from the logits must make
       ``sample`` raise ``AssertionError``.
    2. With matching spatial sizes, the returned weight must keep the batch
       dimension of the logits and match their spatial dimensions.
    """
    # Case 1: seg_logit and seg_label must be of the same size.
    # Setup is kept outside the ``pytest.raises`` block so that only the
    # ``sample`` call itself can satisfy the expected AssertionError.
    sampler = OHEMPixelSampler()
    seg_logit = torch.randn(1, 19, 45, 45)
    mismatched_label = torch.randint(0, 19, size=(1, 1, 89, 89))
    with pytest.raises(AssertionError):
        sampler.sample(seg_logit, mismatched_label)

    # Case 2: matching sizes; use a fresh sampler as the original test did.
    sampler = OHEMPixelSampler()
    seg_logit = torch.randn(1, 19, 45, 45)
    seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
    seg_weight = sampler.sample(seg_logit, seg_label)

    # Batch dimension is preserved ...
    assert seg_weight.shape[0] == seg_logit.shape[0]
    # ... and the remaining dims equal the logits' spatial dims (the class
    # channel of seg_logit has no counterpart in seg_weight).
    assert seg_weight.shape[1:] == seg_logit.shape[2:]
|