[Feature] Add mask channel in MGD Loss (#461)

* [Feature]

Add mask channel in MGD Loss

* fix lint

---------

Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>
Co-authored-by: pppppM <gjf_mail@126.com>
Ming-Hsuan-Tu 2023-03-01 18:34:38 +08:00 committed by GitHub
parent 0919f69287
commit 60469c0a58
2 changed files with 23 additions and 1 deletion


@@ -27,10 +27,12 @@ class MGDConnector(BaseConnector):
         student_channels: int,
         teacher_channels: int,
         lambda_mgd: float = 0.65,
+        mask_on_channel: bool = False,
         init_cfg: Optional[Dict] = None,
     ) -> None:
         super().__init__(init_cfg)
         self.lambda_mgd = lambda_mgd
+        self.mask_on_channel = mask_on_channel
         if student_channels != teacher_channels:
             self.align = nn.Conv2d(
                 student_channels,
@@ -55,7 +57,11 @@ class MGDConnector(BaseConnector):
         N, C, H, W = feature.shape
         device = feature.device
-        mat = torch.rand((N, 1, H, W)).to(device)
+        if not self.mask_on_channel:
+            mat = torch.rand((N, 1, H, W)).to(device)
+        else:
+            mat = torch.rand((N, C, 1, 1)).to(device)
         mat = torch.where(mat > 1 - self.lambda_mgd,
                           torch.zeros(1).to(device),
                           torch.ones(1).to(device)).to(device)
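For readers skimming the diff, a minimal self-contained sketch of the masking behaviour this hunk introduces. The helper name mgd_random_mask is hypothetical, and the final multiplication of the feature by the mask is assumed (it happens outside the lines shown); only the mask shapes and the torch.where-style binarization mirror the diff:

    import torch


    def mgd_random_mask(feature: torch.Tensor,
                        lambda_mgd: float = 0.65,
                        mask_on_channel: bool = False) -> torch.Tensor:
        """Sketch: randomly zero out parts of ``feature`` as in MGD.

        mask_on_channel=False -> mask shape (N, 1, H, W): spatial positions
        are dropped, shared across channels.
        mask_on_channel=True  -> mask shape (N, C, 1, 1): whole channels
        are dropped, shared across spatial positions.
        """
        N, C, H, W = feature.shape
        device = feature.device
        if not mask_on_channel:
            mat = torch.rand((N, 1, H, W), device=device)
        else:
            mat = torch.rand((N, C, 1, 1), device=device)
        # Entries above 1 - lambda_mgd become 0, the rest become 1, so roughly
        # a lambda_mgd fraction of positions (or channels) is masked out.
        mask = torch.where(mat > 1 - lambda_mgd,
                           torch.zeros(1, device=device),
                           torch.ones(1, device=device))
        return feature * mask  # assumed application step, not part of this hunk


    feat = torch.randn(2, 16, 8, 8)
    assert mgd_random_mask(feat, mask_on_channel=True).shape == feat.shape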


@@ -144,6 +144,22 @@ class TestConnector(TestCase):
         assert s_output1.shape == torch.Size([1, 16, 8, 8])
         assert s_output2.shape == torch.Size([1, 32, 8, 8])
+        mgd_connector1 = MGDConnector(
+            student_channels=16,
+            teacher_channels=16,
+            lambda_mgd=0.65,
+            mask_on_channel=True)
+        mgd_connector2 = MGDConnector(
+            student_channels=16,
+            teacher_channels=32,
+            lambda_mgd=0.65,
+            mask_on_channel=True)
+        s_output1 = mgd_connector1.forward_train(s_feat)
+        s_output2 = mgd_connector2.forward_train(s_feat)
+        assert s_output1.shape == torch.Size([1, 16, 8, 8])
+        assert s_output2.shape == torch.Size([1, 32, 8, 8])

     def test_norm_connector(self):
         s_feat = torch.randn(2, 3, 2, 2)
         norm_cfg = dict(type='BN', affine=False, track_running_stats=False)
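For completeness, a usage sketch mirroring the new tests above. The import path is an assumption (MGDConnector is expected to be exposed from mmrazor's connectors package; adjust to your installation):

    import torch
    from mmrazor.models.architectures.connectors import MGDConnector  # assumed import path

    s_feat = torch.randn(1, 16, 8, 8)

    # With mask_on_channel=True, whole channels are randomly zeroed before the
    # generation block, instead of spatial positions.
    connector = MGDConnector(
        student_channels=16,
        teacher_channels=32,
        lambda_mgd=0.65,
        mask_on_channel=True)

    out = connector.forward_train(s_feat)
    assert out.shape == torch.Size([1, 32, 8, 8])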