from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


class PatchDropout(nn.Module):
    """ Drop a random subset of patch tokens during training; prefix tokens
    (e.g. the class token) are always kept.

    https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
    """
    return_indices: torch.jit.Final[bool]
    def __init__(
            self,
            prob: float = 0.5,
            num_prefix_tokens: int = 1,
            ordered: bool = False,
            return_indices: bool = False,
    ):
        """
        Args:
            prob: probability of dropping each patch token (fraction dropped per sample).
            num_prefix_tokens: number of leading prefix tokens (e.g. class token) exempt from dropout.
            ordered: if True, keep the retained tokens in their original order.
            return_indices: if True, forward() also returns the indices of kept tokens.
        """
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.num_prefix_tokens = num_prefix_tokens  # exclude CLS token (or other prefix tokens)
        self.ordered = ordered
        self.return_indices = return_indices

    def forward(self, x) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        # no-op at inference time or when dropout is disabled
        if not self.training or self.prob == 0.:
            if self.return_indices:
                return x, None
            return x

        # split off prefix tokens (e.g. class token) so they are never dropped
        if self.num_prefix_tokens:
            prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
        else:
            prefix_tokens = None

        B = x.shape[0]
        L = x.shape[1]
        num_keep = max(1, int(L * (1. - self.prob)))
        # draw a random permutation per sample via argsort of gaussian noise, keep the first num_keep
        keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep]
        if self.ordered:
            # NOTE does not need to maintain patch order in typical transformer use,
            # but possibly useful for debug / visualization
            keep_indices = keep_indices.sort(dim=-1)[0]
        x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))

        if prefix_tokens is not None:
            x = torch.cat((prefix_tokens, x), dim=1)

        if self.return_indices:
            return x, keep_indices
        return x
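

# A minimal usage sketch (an illustration, not part of the original module):
# apply PatchDropout to a dummy batch of ViT-style tokens and check the shapes.
if __name__ == '__main__':
    torch.manual_seed(0)
    drop = PatchDropout(prob=0.5, num_prefix_tokens=1, return_indices=True)
    drop.train()  # dropout is only active in training mode
    tokens = torch.randn(2, 197, 768)  # (batch, 1 CLS token + 196 patches, embed dim)
    out, keep_indices = drop(tokens)
    print(out.shape)           # torch.Size([2, 99, 768]): 98 kept patches + CLS token
    print(keep_indices.shape)  # torch.Size([2, 98])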