[Feature] Add mask channel in MGD Loss (#461)
* [Feature] Add mask channel in MGD Loss * fix lint --------- Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com> Co-authored-by: pppppM <gjf_mail@126.com>pull/471/head
parent
0919f69287
commit
60469c0a58
|
@ -27,10 +27,12 @@ class MGDConnector(BaseConnector):
|
|||
student_channels: int,
|
||||
teacher_channels: int,
|
||||
lambda_mgd: float = 0.65,
|
||||
mask_on_channel: bool = False,
|
||||
init_cfg: Optional[Dict] = None,
|
||||
) -> None:
|
||||
super().__init__(init_cfg)
|
||||
self.lambda_mgd = lambda_mgd
|
||||
self.mask_on_channel = mask_on_channel
|
||||
if student_channels != teacher_channels:
|
||||
self.align = nn.Conv2d(
|
||||
student_channels,
|
||||
|
@ -55,7 +57,11 @@ class MGDConnector(BaseConnector):
|
|||
N, C, H, W = feature.shape
|
||||
|
||||
device = feature.device
|
||||
mat = torch.rand((N, 1, H, W)).to(device)
|
||||
if not self.mask_on_channel:
|
||||
mat = torch.rand((N, 1, H, W)).to(device)
|
||||
else:
|
||||
mat = torch.rand((N, C, 1, 1)).to(device)
|
||||
|
||||
mat = torch.where(mat > 1 - self.lambda_mgd,
|
||||
torch.zeros(1).to(device),
|
||||
torch.ones(1).to(device)).to(device)
|
||||
|
|
|
@ -144,6 +144,22 @@ class TestConnector(TestCase):
|
|||
assert s_output1.shape == torch.Size([1, 16, 8, 8])
|
||||
assert s_output2.shape == torch.Size([1, 32, 8, 8])
|
||||
|
||||
mgd_connector1 = MGDConnector(
|
||||
student_channels=16,
|
||||
teacher_channels=16,
|
||||
lambda_mgd=0.65,
|
||||
mask_on_channel=True)
|
||||
mgd_connector2 = MGDConnector(
|
||||
student_channels=16,
|
||||
teacher_channels=32,
|
||||
lambda_mgd=0.65,
|
||||
mask_on_channel=True)
|
||||
s_output1 = mgd_connector1.forward_train(s_feat)
|
||||
s_output2 = mgd_connector2.forward_train(s_feat)
|
||||
|
||||
assert s_output1.shape == torch.Size([1, 16, 8, 8])
|
||||
assert s_output2.shape == torch.Size([1, 32, 8, 8])
|
||||
|
||||
def test_norm_connector(self):
|
||||
s_feat = torch.randn(2, 3, 2, 2)
|
||||
norm_cfg = dict(type='BN', affine=False, track_running_stats=False)
|
||||
|
|
Loading…
Reference in New Issue