mirror of https://github.com/RE-OWOD/RE-OWOD
62 lines
2.1 KiB
Python
62 lines
2.1 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
from torch import nn
|
|
from torch.autograd import Function
|
|
from torch.autograd.function import once_differentiable
|
|
|
|
from tensormask import _C
|
|
|
|
|
|
class _SwapAlign2Nat(Function):
    """Autograd bridge to the compiled SwapAlign2Nat kernels.

    ``forward`` delegates to ``_C.swap_align2nat_forward`` and ``backward`` to
    ``_C.swap_align2nat_backward``. Only the input tensor ``X`` receives a
    gradient; ``lambda_val`` and ``pad_val`` get ``None``.
    """

    @staticmethod
    def forward(ctx, X, lambda_val, pad_val):
        """Run the forward kernel and save what the backward kernel needs.

        Args:
            ctx: autograd context used to carry state to ``backward``.
            X: input tensor. ``backward`` unpacks ``X.size()`` into four
                values, so a 4-D input is expected.
            lambda_val: unit-length ratio passed straight to the kernel.
            pad_val: padding value passed straight to the kernel.

        Returns:
            The tensor produced by ``_C.swap_align2nat_forward``.
        """
        ctx.lambda_val = lambda_val
        ctx.input_shape = X.size()
        return _C.swap_align2nat_forward(X, lambda_val, pad_val)

    @staticmethod
    @once_differentiable
    def backward(ctx, gY):
        """Compute the gradient w.r.t. ``X`` from the output gradient ``gY``.

        Returns:
            A 3-tuple matching ``forward``'s inputs: the gradient for ``X``,
            then ``None`` for ``lambda_val`` and ``pad_val``.
        """
        # The saved input dimensions are handed to the kernel explicitly
        # (presumably so it can size the input gradient).
        n_batch, n_chan, height, width = ctx.input_shape
        grad_X = _C.swap_align2nat_backward(
            gY, ctx.lambda_val, n_batch, n_chan, height, width
        )
        return grad_X, None, None


# Functional form of the op: swap_align2nat(X, lambda_val, pad_val).
swap_align2nat = _SwapAlign2Nat.apply
|
|
|
|
|
|
class SwapAlign2Nat(nn.Module):
    """
    The op `SwapAlign2Nat` described in https://arxiv.org/abs/1903.12174.

    Given an input tensor of mask predictions with shape (N, C=VxU, H, W),
    this op returns masks of shape (N, V'xU', H', W'). The unit lengths of
    (V, U) and (H, W) are swapped, and the mask representation is converted
    from aligned to natural.

    Args:
        lambda_val (int): the relative unit-length ratio between (V, U) and
            (H, W). Because (V, U) always has the larger unit length,
            lambda_val is always >= 1.
        pad_val (float): value used for positions that fall outside the input
            tensor. Defaults to -6 because sigmoid(-6) is ~0, which means
            there are no masks outside of the tensor.
    """

    def __init__(self, lambda_val, pad_val=-6.0):
        super().__init__()
        self.lambda_val = lambda_val
        self.pad_val = pad_val

    def forward(self, X):
        """Apply ``swap_align2nat`` to ``X`` using this module's settings."""
        return swap_align2nat(X, self.lambda_val, self.pad_val)

    def __repr__(self):
        # `!s` keeps the exact str() rendering of the stored values.
        return (
            f"{type(self).__name__}("
            f"lambda_val={self.lambda_val!s}, pad_val={self.pad_val!s})"
        )
|