[Refactor] Support passing arguments to loss from head. (#523)

pull/531/head
Ezra-Yu 2021-11-10 17:12:34 +08:00 committed by GitHub
parent 9ab9d4ff31
commit 34d5a25281
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 22 additions and 14 deletions

View File

@@ -39,11 +39,12 @@ class ClsHead(BaseHead):
self.compute_accuracy = Accuracy(topk=self.topk)
self.cal_acc = cal_acc
def loss(self, cls_score, gt_label):
def loss(self, cls_score, gt_label, **kwargs):
num_samples = len(cls_score)
losses = dict()
# compute loss
loss = self.compute_loss(cls_score, gt_label, avg_factor=num_samples)
loss = self.compute_loss(
cls_score, gt_label, avg_factor=num_samples, **kwargs)
if self.cal_acc:
# compute accuracy
acc = self.compute_accuracy(cls_score, gt_label)
@@ -55,10 +56,10 @@
losses['loss'] = loss
return losses
def forward_train(self, cls_score, gt_label):
def forward_train(self, cls_score, gt_label, **kwargs):
if isinstance(cls_score, tuple):
cls_score = cls_score[-1]
losses = self.loss(cls_score, gt_label)
losses = self.loss(cls_score, gt_label, **kwargs)
return losses
def simple_test(self, cls_score):

View File

@@ -46,9 +46,9 @@ class LinearClsHead(ClsHead):
return self.post_process(pred)
def forward_train(self, x, gt_label):
def forward_train(self, x, gt_label, **kwargs):
if isinstance(x, tuple):
x = x[-1]
cls_score = self.fc(x)
losses = self.loss(cls_score, gt_label)
losses = self.loss(cls_score, gt_label, **kwargs)
return losses

View File

@@ -40,11 +40,11 @@ class MultiLabelClsHead(BaseHead):
losses['loss'] = loss
return losses
def forward_train(self, cls_score, gt_label):
def forward_train(self, cls_score, gt_label, **kwargs):
if isinstance(cls_score, tuple):
cls_score = cls_score[-1]
gt_label = gt_label.type_as(cls_score)
losses = self.loss(cls_score, gt_label)
losses = self.loss(cls_score, gt_label, **kwargs)
return losses
def simple_test(self, x):

View File

@@ -39,12 +39,12 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
self.fc = nn.Linear(self.in_channels, self.num_classes)
def forward_train(self, x, gt_label):
def forward_train(self, x, gt_label, **kwargs):
if isinstance(x, tuple):
x = x[-1]
gt_label = gt_label.type_as(x)
cls_score = self.fc(x)
losses = self.loss(cls_score, gt_label)
losses = self.loss(cls_score, gt_label, **kwargs)
return losses
def simple_test(self, x):

View File

@@ -127,11 +127,11 @@ class StackedLinearClsHead(ClsHead):
return self.post_process(pred)
def forward_train(self, x, gt_label):
def forward_train(self, x, gt_label, **kwargs):
if isinstance(x, tuple):
x = x[-1]
cls_score = x
for layer in self.layers:
cls_score = layer(cls_score)
losses = self.loss(cls_score, gt_label)
losses = self.loss(cls_score, gt_label, **kwargs)
return losses

View File

@@ -79,9 +79,9 @@ class VisionTransformerClsHead(ClsHead):
return self.post_process(pred)
def forward_train(self, x, gt_label):
def forward_train(self, x, gt_label, **kwargs):
x = x[-1]
_, cls_token = x
cls_score = self.layers(cls_token)
losses = self.loss(cls_score, gt_label)
losses = self.loss(cls_score, gt_label, **kwargs)
return losses

View File

@@ -27,6 +27,13 @@ def test_cls_head(feat):
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0
# test ClsHead with weight
weight = torch.tensor([0.5, 0.5, 0.5, 0.5])
losses_ = head.forward_train(feat, fake_gt_label)
losses = head.forward_train(feat, fake_gt_label, weight=weight)
assert losses['loss'].item() == losses_['loss'].item() * 0.5
@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
def test_linear_head(feat):