mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Merge pull request #2362 from MengzhangLI/scipy_1.x
[Enhance] Make scipy as a default dependency in runtime in dev-1.x
This commit is contained in:
commit
383826fec9
@ -11,6 +11,7 @@ from mmengine.model import BaseModule, ModuleList
|
|||||||
from mmengine.model.weight_init import (constant_init, kaiming_init,
|
from mmengine.model.weight_init import (constant_init, kaiming_init,
|
||||||
trunc_normal_)
|
trunc_normal_)
|
||||||
from mmengine.runner.checkpoint import _load_checkpoint
|
from mmengine.runner.checkpoint import _load_checkpoint
|
||||||
|
from scipy import interpolate
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
from torch.nn.modules.utils import _pair as to_2tuple
|
from torch.nn.modules.utils import _pair as to_2tuple
|
||||||
|
|
||||||
@ -18,11 +19,6 @@ from mmseg.registry import MODELS
|
|||||||
from ..utils import PatchEmbed
|
from ..utils import PatchEmbed
|
||||||
from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer
|
from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer
|
||||||
|
|
||||||
try:
|
|
||||||
from scipy import interpolate
|
|
||||||
except ImportError:
|
|
||||||
interpolate = None
|
|
||||||
|
|
||||||
|
|
||||||
class BEiTAttention(BaseModule):
|
class BEiTAttention(BaseModule):
|
||||||
"""Window based multi-head self-attention (W-MSA) module with relative
|
"""Window based multi-head self-attention (W-MSA) module with relative
|
||||||
|
@ -3,3 +3,4 @@ mmcls>=1.0.0rc0
|
|||||||
numpy
|
numpy
|
||||||
packaging
|
packaging
|
||||||
prettytable
|
prettytable
|
||||||
|
scipy
|
||||||
|
@ -140,8 +140,11 @@ def test_beit_init():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
model = BEiT(img_size=(512, 512))
|
model = BEiT(img_size=(512, 512))
|
||||||
with pytest.raises(AttributeError):
|
# If scipy is installed, this AttributeError would not be raised.
|
||||||
model.resize_rel_pos_embed(ckpt)
|
from mmengine.utils import is_installed
|
||||||
|
if not is_installed('scipy'):
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
model.resize_rel_pos_embed(ckpt)
|
||||||
|
|
||||||
# pretrained=None
|
# pretrained=None
|
||||||
# init_cfg=123, whose type is unsupported
|
# init_cfg=123, whose type is unsupported
|
||||||
|
@ -138,8 +138,11 @@ def test_mae_init():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
model = MAE(img_size=(512, 512))
|
model = MAE(img_size=(512, 512))
|
||||||
with pytest.raises(AttributeError):
|
# If scipy is installed, this AttributeError would not be raised.
|
||||||
model.resize_rel_pos_embed(ckpt)
|
from mmengine.utils import is_installed
|
||||||
|
if not is_installed('scipy'):
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
model.resize_rel_pos_embed(ckpt)
|
||||||
|
|
||||||
# test resize abs pos embed
|
# test resize abs pos embed
|
||||||
ckpt = model.resize_abs_pos_embed(ckpt['state_dict'])
|
ckpt = model.resize_abs_pos_embed(ckpt['state_dict'])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user