87 lines
3.3 KiB
Python
87 lines
3.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from mmseg.models.builder import LOSSES
|
|
|
|
|
|
@LOSSES.register_module()
class TreeTripletLoss(nn.Module):
    """TreeTripletLoss. Modified from https://github.com/qhanghu/HSSN_pytorch/b
    lob/main/mmseg/models/losses/tree_triplet_loss.py.

    For every class present in the label map, pixels of that class act as
    anchors, pixels of the other classes that share the same parent group
    act as positives, and pixels whose class lies outside that group act as
    negatives. A hinge with a fixed margin is applied to the difference of
    the anchor-positive and anchor-negative distances.

    Args:
        num_classes (int): Number of categories.
        hiera_map (List[int]): Hierarchy map of each category, i.e.
            ``hiera_map[c]`` is the index of the parent group of class ``c``.
        hiera_index (List[List[int]]): Hierarchy indices of each hierarchy,
            i.e. ``hiera_index[g]`` is the half-open class-id range
            ``[start, end)`` covered by parent group ``g``.
        ignore_index (int): Specifies a target value that is ignored and
            does not contribute to the input gradients. Defaults: 255.

    Examples:
        >>> num_classes = 19
        >>> hiera_map = [
        ...     0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 5, 5, 6, 6, 6, 6, 6, 6]
        >>> hiera_index = [
        ...     [0, 2], [2, 5], [5, 8], [8, 10], [10, 11], [11, 13], [13, 19]]
    """

    def __init__(self, num_classes, hiera_map, hiera_index, ignore_index=255):
        super().__init__()

        self.ignore_label = ignore_index
        self.num_classes = num_classes
        self.hiera_map = hiera_map
        self.hiera_index = hiera_index

    def forward(self, feats: torch.Tensor, labels=None, max_triplet=200):
        """Compute the hierarchy-aware triplet loss.

        Args:
            feats (torch.Tensor): Pixel embeddings of shape (N, C, H, W).
                NOTE(review): the distance used is ``1 - <a, b>``, which
                presumably expects L2-normalised embeddings - confirm with
                the caller.
            labels (torch.Tensor): Label map of shape (N, H', W'). It is
                resized to (H, W) with nearest-neighbour interpolation.
            max_triplet (int): Upper bound on the number of triplets sampled
                per anchor class. Defaults: 200.

        Returns:
            tuple: ``(loss, class_count)``. ``loss`` is the mean over the
            anchor classes that produced at least one triplet, or ``None``
            when no class did. ``class_count`` is a one-element tensor (same
            dtype/device as ``feats``) holding the number of such classes.
        """
        # Nearest-neighbour resizing keeps label ids intact (no blending).
        labels = labels.unsqueeze(1).float().clone()
        labels = F.interpolate(
            labels, (feats.shape[2], feats.shape[3]), mode='nearest')
        labels = labels.squeeze(1).long()
        assert labels.shape[-1] == feats.shape[-1], '{} {}'.format(
            labels.shape, feats.shape)

        # Flatten to one row per pixel: labels -> (P,), feats -> (P, C).
        labels = labels.view(-1)
        feats = feats.permute(0, 2, 3, 1)
        feats = feats.contiguous().view(-1, feats.shape[-1])

        # Honour the configured ignore index (previously hard-coded to 255)
        # and keep ignored pixels out of every sample set so that they never
        # receive a gradient.
        valid = labels != self.ignore_label
        exist_classes = torch.unique(labels[valid]).tolist()

        # margin always 0.1 + (4-2)/4 since the hierarchy is three level
        # TODO: should include label of pos is the same as anchor
        margin = 0.6

        triplet_loss = 0
        class_count = 0

        for cls_id in exist_classes:
            index_range = self.hiera_index[self.hiera_map[cls_id]]
            in_group = (labels >= index_range[0]) & (
                labels < index_range[-1])

            index_anchor = labels == cls_id
            index_pos = valid & in_group & (~index_anchor)
            index_neg = valid & (~in_group)

            # Plain Python ints avoid mixing 0-dim tensors and ints in min().
            min_size = min(
                int(index_anchor.sum()), int(index_pos.sum()),
                int(index_neg.sum()), max_triplet)
            if min_size <= 0:
                # No complete (anchor, positive, negative) triplet available.
                continue

            feats_anchor = feats[index_anchor][:min_size]
            feats_pos = feats[index_pos][:min_size]
            feats_neg = feats[index_neg][:min_size]

            dist_pos = 1 - (feats_anchor * feats_pos).sum(1)
            dist_neg = 1 - (feats_anchor * feats_neg).sum(1)

            tl = F.relu(dist_pos - dist_neg + margin)

            triplet_loss = triplet_loss + tl.mean()
            class_count += 1

        if class_count == 0:
            return None, torch.tensor([0]).to(feats)

        triplet_loss = triplet_loss / class_count
        return triplet_loss, torch.tensor([class_count]).to(feats)
|