pytorch-image-models/timm/layers/patch_dropout.py

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

class PatchDropout(nn.Module):
    """
    Patch Dropout: randomly drop a subset of patch tokens during training.

    https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
    """
    return_indices: torch.jit.Final[bool]

    def __init__(
            self,
            prob: float = 0.5,
            num_prefix_tokens: int = 1,
            ordered: bool = False,
            return_indices: bool = False,
    ):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.num_prefix_tokens = num_prefix_tokens  # exclude CLS token (or other prefix tokens)
        self.ordered = ordered
        self.return_indices = return_indices

    def forward(self, x) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        # no-op at inference time or when dropout is disabled
        if not self.training or self.prob == 0.:
            if self.return_indices:
                return x, None
            return x

        if self.num_prefix_tokens:
            # split off prefix (e.g. CLS) tokens so they are always kept
            prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
        else:
            prefix_tokens = None

        B = x.shape[0]
        L = x.shape[1]
        num_keep = max(1, int(L * (1. - self.prob)))
        # argsort of i.i.d. Gaussian noise yields a random permutation of token
        # indices per sample; keep the first num_keep entries of each
        keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep]
        if self.ordered:
            # NOTE does not need to maintain patch order in typical transformer use,
            # but possibly useful for debug / visualization
            keep_indices = keep_indices.sort(dim=-1)[0]
        x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))

        if prefix_tokens is not None:
            x = torch.cat((prefix_tokens, x), dim=1)

        if self.return_indices:
            return x, keep_indices
        return x
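

if __name__ == "__main__":
    # Minimal usage sketch, assuming ViT-style tokens with a single CLS prefix
    # token; the shapes below are arbitrary example values, not from the module.
    drop = PatchDropout(prob=0.25, num_prefix_tokens=1, return_indices=True)
    drop.train()  # dropout only applies in training mode
    tokens = torch.randn(8, 1 + 196, 768)  # (batch, CLS + 14x14 patches, embed dim)
    out, keep_indices = drop(tokens)
    # int(196 * (1 - 0.25)) == 147 patches survive, plus the CLS token
    assert out.shape == (8, 1 + 147, 768)
    assert keep_indices.shape == (8, 147)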