[Fix] Fix ddp bugs caused by `out_type`. (#1570)

* set out_type to be 'raw'

* update test
pull/1503/merge
Yixiao Fang 2023-05-17 17:32:10 +08:00 committed by GitHub
parent 034919d032
commit 770eb8e24a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 9 additions and 9 deletions

View File

@@ -182,7 +182,7 @@ class BEiTPretrainViT(BEiTViT):
 drop_path_rate: float = 0,
 norm_cfg: dict = dict(type='LN', eps=1e-6),
 final_norm: bool = True,
-out_type: str = 'avg_featmap',
+out_type: str = 'raw',
 frozen_stages: int = -1,
 use_abs_pos_emb: bool = False,
 use_rel_pos_bias: bool = False,

View File

@@ -251,7 +251,7 @@ class CAEPretrainViT(BEiTViT):
 bias: bool = 'qv_bias',
 norm_cfg: dict = dict(type='LN', eps=1e-6),
 final_norm: bool = True,
-out_type: str = 'avg_featmap',
+out_type: str = 'raw',
 frozen_stages: int = -1,
 use_abs_pos_emb: bool = True,
 use_rel_pos_bias: bool = False,

View File

@@ -64,7 +64,7 @@ class MAEViT(VisionTransformer):
 drop_path_rate: float = 0,
 norm_cfg: dict = dict(type='LN', eps=1e-6),
 final_norm: bool = True,
-out_type: str = 'avg_featmap',
+out_type: str = 'raw',
 interpolate_mode: str = 'bicubic',
 patch_cfg: dict = dict(),
 layer_cfgs: dict = dict(),

View File

@@ -207,7 +207,7 @@ class MaskFeatViT(VisionTransformer):
 drop_path_rate: float = 0,
 norm_cfg: dict = dict(type='LN', eps=1e-6),
 final_norm: bool = True,
-out_type: str = 'avg_featmap',
+out_type: str = 'raw',
 interpolate_mode: str = 'bicubic',
 patch_cfg: dict = dict(),
 layer_cfgs: dict = dict(),

View File

@@ -35,7 +35,7 @@ class TestBEiT(TestCase):
 # test without mask
 fake_outputs = beit_backbone(fake_inputs, None)
-assert fake_outputs[0].shape == torch.Size([2, 768])
+assert fake_outputs[0].shape == torch.Size([2, 197, 768])
 @pytest.mark.skipif(
 platform.system() == 'Windows', reason='Windows mem limit')

View File

@@ -25,7 +25,7 @@ def test_cae_vit():
 # test without mask
 fake_outputs = cae_backbone(fake_inputs, None)
-assert fake_outputs[0].shape == torch.Size([1, 192])
+assert fake_outputs[0].shape == torch.Size([1, 197, 192])
 @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')

View File

@@ -21,7 +21,7 @@ def test_mae_vit():
 # test without mask
 fake_outputs = mae_backbone(fake_inputs, None)
-assert fake_outputs[0].shape == torch.Size([2, 768])
+assert fake_outputs[0].shape == torch.Size([2, 197, 768])
 @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')

View File

@@ -22,7 +22,7 @@ def test_maskfeat_vit():
 # test without mask
 fake_outputs = maskfeat_backbone(fake_inputs, None)
-assert fake_outputs[0].shape == torch.Size([2, 768])
+assert fake_outputs[0].shape == torch.Size([2, 197, 768])
 @pytest.mark.skipif(

View File

@@ -24,7 +24,7 @@ def test_milan_vit():
 # test without mask
 fake_outputs = milan_backbone(fake_inputs, None)
-assert fake_outputs[0].shape == torch.Size([2, 768])
+assert fake_outputs[0].shape == torch.Size([2, 197, 768])
 @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')