[Fix] Fix ddp bugs caused by `out_type`. (#1570)
* set out_type to be 'raw' * update testpull/1503/merge
parent
034919d032
commit
770eb8e24a
|
@ -182,7 +182,7 @@ class BEiTPretrainViT(BEiTViT):
|
|||
drop_path_rate: float = 0,
|
||||
norm_cfg: dict = dict(type='LN', eps=1e-6),
|
||||
final_norm: bool = True,
|
||||
out_type: str = 'avg_featmap',
|
||||
out_type: str = 'raw',
|
||||
frozen_stages: int = -1,
|
||||
use_abs_pos_emb: bool = False,
|
||||
use_rel_pos_bias: bool = False,
|
||||
|
|
|
@ -251,7 +251,7 @@ class CAEPretrainViT(BEiTViT):
|
|||
bias: bool = 'qv_bias',
|
||||
norm_cfg: dict = dict(type='LN', eps=1e-6),
|
||||
final_norm: bool = True,
|
||||
out_type: str = 'avg_featmap',
|
||||
out_type: str = 'raw',
|
||||
frozen_stages: int = -1,
|
||||
use_abs_pos_emb: bool = True,
|
||||
use_rel_pos_bias: bool = False,
|
||||
|
|
|
@ -64,7 +64,7 @@ class MAEViT(VisionTransformer):
|
|||
drop_path_rate: float = 0,
|
||||
norm_cfg: dict = dict(type='LN', eps=1e-6),
|
||||
final_norm: bool = True,
|
||||
out_type: str = 'avg_featmap',
|
||||
out_type: str = 'raw',
|
||||
interpolate_mode: str = 'bicubic',
|
||||
patch_cfg: dict = dict(),
|
||||
layer_cfgs: dict = dict(),
|
||||
|
|
|
@ -207,7 +207,7 @@ class MaskFeatViT(VisionTransformer):
|
|||
drop_path_rate: float = 0,
|
||||
norm_cfg: dict = dict(type='LN', eps=1e-6),
|
||||
final_norm: bool = True,
|
||||
out_type: str = 'avg_featmap',
|
||||
out_type: str = 'raw',
|
||||
interpolate_mode: str = 'bicubic',
|
||||
patch_cfg: dict = dict(),
|
||||
layer_cfgs: dict = dict(),
|
||||
|
|
|
@ -35,7 +35,7 @@ class TestBEiT(TestCase):
|
|||
|
||||
# test without mask
|
||||
fake_outputs = beit_backbone(fake_inputs, None)
|
||||
assert fake_outputs[0].shape == torch.Size([2, 768])
|
||||
assert fake_outputs[0].shape == torch.Size([2, 197, 768])
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == 'Windows', reason='Windows mem limit')
|
||||
|
|
|
@ -25,7 +25,7 @@ def test_cae_vit():
|
|||
|
||||
# test without mask
|
||||
fake_outputs = cae_backbone(fake_inputs, None)
|
||||
assert fake_outputs[0].shape == torch.Size([1, 192])
|
||||
assert fake_outputs[0].shape == torch.Size([1, 197, 192])
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
|
|
|
@ -21,7 +21,7 @@ def test_mae_vit():
|
|||
|
||||
# test without mask
|
||||
fake_outputs = mae_backbone(fake_inputs, None)
|
||||
assert fake_outputs[0].shape == torch.Size([2, 768])
|
||||
assert fake_outputs[0].shape == torch.Size([2, 197, 768])
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
|
|
|
@ -22,7 +22,7 @@ def test_maskfeat_vit():
|
|||
|
||||
# test without mask
|
||||
fake_outputs = maskfeat_backbone(fake_inputs, None)
|
||||
assert fake_outputs[0].shape == torch.Size([2, 768])
|
||||
assert fake_outputs[0].shape == torch.Size([2, 197, 768])
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
|
|
@ -24,7 +24,7 @@ def test_milan_vit():
|
|||
|
||||
# test without mask
|
||||
fake_outputs = milan_backbone(fake_inputs, None)
|
||||
assert fake_outputs[0].shape == torch.Size([2, 768])
|
||||
assert fake_outputs[0].shape == torch.Size([2, 197, 768])
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
|
|
Loading…
Reference in New Issue