fast-reid/fastreid/layers/batch_drop.py

33 lines
808 B
Python
Raw Normal View History

2020-02-10 07:38:56 +08:00
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import random
2020-03-25 10:58:26 +08:00
2020-02-10 07:38:56 +08:00
from torch import nn
class BatchDrop(nn.Module):
    """Batch DropBlock: zero one random rectangle shared by the whole batch.

    ref: https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py

    In training mode a single rectangle of size
    ``round(h_ratio * H) x round(w_ratio * W)`` is placed at a random
    position and zeroed in every sample and channel of the input. In eval
    mode the input is returned unchanged.

    Args:
        h_ratio: fraction of the height (second-to-last dim) to drop.
        w_ratio: fraction of the width (last dim) to drop.
    """

    def __init__(self, h_ratio, w_ratio):
        super(BatchDrop, self).__init__()
        self.h_ratio = h_ratio
        self.w_ratio = w_ratio

    def forward(self, x):
        """Apply the batch drop mask to ``x`` when training.

        Args:
            x: tensor whose last two dimensions are (H, W); presumably
                (N, C, H, W) feature maps, but any number of leading
                dimensions is accepted.

        Returns:
            ``x`` itself in eval mode; otherwise ``x`` multiplied by a mask
            that is zero inside the chosen rectangle and one elsewhere.

        Raises:
            ValueError: if the rounded rectangle is larger than the map
                (``random.randint`` receives an empty range).
        """
        if not self.training:
            return x
        h, w = x.size()[-2:]
        rh = round(self.h_ratio * h)
        rw = round(self.w_ratio * w)
        # Top-left corner of the rectangle. This draws from Python's
        # `random` module, so seeding torch alone does not fix the position.
        top = random.randint(0, h - rh)
        left = random.randint(0, w - rw)
        # An (H, W) mask broadcasts over every leading dim: the same region
        # is dropped for all samples/channels without allocating a mask as
        # large as x. new_ones keeps x's dtype and device.
        mask = x.new_ones((h, w))
        mask[top:top + rh, left:left + rw] = 0
        return x * mask