[Refactor] Support passing arguments to loss from head. (#523)
parent 9ab9d4ff31
commit 34d5a25281
@@ -39,11 +39,12 @@ class ClsHead(BaseHead):
         self.compute_accuracy = Accuracy(topk=self.topk)
         self.cal_acc = cal_acc
 
-    def loss(self, cls_score, gt_label):
+    def loss(self, cls_score, gt_label, **kwargs):
         num_samples = len(cls_score)
         losses = dict()
         # compute loss
-        loss = self.compute_loss(cls_score, gt_label, avg_factor=num_samples)
+        loss = self.compute_loss(
+            cls_score, gt_label, avg_factor=num_samples, **kwargs)
         if self.cal_acc:
             # compute accuracy
             acc = self.compute_accuracy(cls_score, gt_label)
@@ -55,10 +56,10 @@ class ClsHead(BaseHead):
         losses['loss'] = loss
         return losses
 
-    def forward_train(self, cls_score, gt_label):
+    def forward_train(self, cls_score, gt_label, **kwargs):
         if isinstance(cls_score, tuple):
             cls_score = cls_score[-1]
-        losses = self.loss(cls_score, gt_label)
+        losses = self.loss(cls_score, gt_label, **kwargs)
         return losses
 
     def simple_test(self, cls_score):
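The two hunks above contain the whole mechanism: `forward_train` now threads `**kwargs` through `loss` into the configured loss module. A minimal standalone sketch of that pass-through pattern (toy code for illustration only, not the mmcls implementation; `ToyHead` and its mean-absolute-error `compute_loss` are invented here):

def compute_loss(scores, labels, avg_factor=None, weight=None):
    """Toy loss: mean absolute error, optionally weighted per sample."""
    per_sample = [abs(s - t) for s, t in zip(scores, labels)]
    if weight is not None:
        per_sample = [p * w for p, w in zip(per_sample, weight)]
    return sum(per_sample) / (avg_factor or len(per_sample))


class ToyHead:
    def loss(self, cls_score, gt_label, **kwargs):
        num_samples = len(cls_score)
        # extra kwargs (e.g. weight=...) reach the loss callable unchanged
        return {'loss': compute_loss(cls_score, gt_label,
                                     avg_factor=num_samples, **kwargs)}

    def forward_train(self, cls_score, gt_label, **kwargs):
        return self.loss(cls_score, gt_label, **kwargs)


head = ToyHead()
plain = head.forward_train([1.0, 2.0], [0.0, 0.0])
halved = head.forward_train([1.0, 2.0], [0.0, 0.0], weight=[0.5, 0.5])
assert halved['loss'] == plain['loss'] * 0.5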
@@ -46,9 +46,9 @@ class LinearClsHead(ClsHead):
 
         return self.post_process(pred)
 
-    def forward_train(self, x, gt_label):
+    def forward_train(self, x, gt_label, **kwargs):
         if isinstance(x, tuple):
             x = x[-1]
         cls_score = self.fc(x)
-        losses = self.loss(cls_score, gt_label)
+        losses = self.loss(cls_score, gt_label, **kwargs)
         return losses
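At the call site this lets training code hand per-sample arguments to the head. A usage sketch, assuming an mmcls `LinearClsHead` with its default `CrossEntropyLoss`, whose forward accepts a `weight` keyword; the import path and constructor arguments below are assumptions about the mmcls version of this commit and may differ elsewhere:

import torch
from mmcls.models.heads import LinearClsHead  # assumed import path

head = LinearClsHead(num_classes=5, in_channels=16)  # assumed constructor args
feat = torch.rand(4, 16)
gt_label = torch.randint(0, 5, (4,))
sample_weight = torch.full((4,), 0.5)

# Before this refactor the extra keyword raised a TypeError in forward_train;
# now it is forwarded through loss() into compute_loss.
losses = head.forward_train(feat, gt_label, weight=sample_weight)
print(losses['loss'])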
@@ -40,11 +40,11 @@ class MultiLabelClsHead(BaseHead):
         losses['loss'] = loss
         return losses
 
-    def forward_train(self, cls_score, gt_label):
+    def forward_train(self, cls_score, gt_label, **kwargs):
         if isinstance(cls_score, tuple):
             cls_score = cls_score[-1]
         gt_label = gt_label.type_as(cls_score)
-        losses = self.loss(cls_score, gt_label)
+        losses = self.loss(cls_score, gt_label, **kwargs)
         return losses
 
     def simple_test(self, x):
@@ -39,12 +39,12 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
 
         self.fc = nn.Linear(self.in_channels, self.num_classes)
 
-    def forward_train(self, x, gt_label):
+    def forward_train(self, x, gt_label, **kwargs):
         if isinstance(x, tuple):
             x = x[-1]
         gt_label = gt_label.type_as(x)
         cls_score = self.fc(x)
-        losses = self.loss(cls_score, gt_label)
+        losses = self.loss(cls_score, gt_label, **kwargs)
         return losses
 
     def simple_test(self, x):
@@ -127,11 +127,11 @@ class StackedLinearClsHead(ClsHead):
 
         return self.post_process(pred)
 
-    def forward_train(self, x, gt_label):
+    def forward_train(self, x, gt_label, **kwargs):
         if isinstance(x, tuple):
             x = x[-1]
         cls_score = x
         for layer in self.layers:
             cls_score = layer(cls_score)
-        losses = self.loss(cls_score, gt_label)
+        losses = self.loss(cls_score, gt_label, **kwargs)
         return losses
@@ -79,9 +79,9 @@ class VisionTransformerClsHead(ClsHead):
 
         return self.post_process(pred)
 
-    def forward_train(self, x, gt_label):
+    def forward_train(self, x, gt_label, **kwargs):
         x = x[-1]
         _, cls_token = x
         cls_score = self.layers(cls_token)
-        losses = self.loss(cls_score, gt_label)
+        losses = self.loss(cls_score, gt_label, **kwargs)
         return losses
@@ -27,6 +27,13 @@ def test_cls_head(feat):
     losses = head.forward_train(feat, fake_gt_label)
     assert losses['loss'].item() > 0
 
+    # test ClsHead with weight
+    weight = torch.tensor([0.5, 0.5, 0.5, 0.5])
+
+    losses_ = head.forward_train(feat, fake_gt_label)
+    losses = head.forward_train(feat, fake_gt_label, weight=weight)
+    assert losses['loss'].item() == losses_['loss'].item() * 0.5
+
 
 @pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
 def test_linear_head(feat):
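The new assertion holds because the loss is reduced as a weighted sum divided by `avg_factor` (the number of samples), so a uniform per-sample weight of 0.5 scales the averaged loss by exactly 0.5. A short self-contained check of that arithmetic in plain PyTorch (this mirrors the reduction only; it is not the mmcls loss module itself):

import torch
import torch.nn.functional as F

cls_score = torch.randn(4, 10)
gt_label = torch.randint(0, 10, (4,))
weight = torch.tensor([0.5, 0.5, 0.5, 0.5])

per_sample = F.cross_entropy(cls_score, gt_label, reduction='none')
unweighted = per_sample.sum() / 4           # avg_factor = num_samples
weighted = (per_sample * weight).sum() / 4  # same reduction, weights applied

assert torch.allclose(weighted, unweighted * 0.5)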