mirror of https://github.com/JDAI-CV/fast-reid.git
30 lines
787 B
Python
30 lines
787 B
Python
|
# encoding: utf-8
|
||
|
"""
|
||
|
@author: liaoxingyu
|
||
|
@contact: sherlockliao01@gmail.com
|
||
|
"""
|
||
|
|
||
|
import random
|
||
|
from torch import nn
|
||
|
|
||
|
|
||
|
class BatchDrop(nn.Module):
    """Batch DropBlock: zero the same random rectangle across a whole batch.

    Copied from https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py

    In training mode, one rectangle is chosen at a random position in the
    last two dimensions of the input. It is ``round(h_ratio * H)`` rows by
    ``round(w_ratio * W)`` columns. That region is zeroed for every sample
    and every other leading index (e.g. channel). In eval mode the input is
    returned unchanged.

    Args:
        h_ratio: Fraction of the height (second-to-last dim) to drop, in [0, 1].
        w_ratio: Fraction of the width (last dim) to drop, in [0, 1].

    Raises:
        ValueError: If either ratio lies outside [0, 1].
    """

    def __init__(self, h_ratio, w_ratio):
        super().__init__()
        # Reject bad ratios up front. Previously a ratio > 1 only failed later,
        # inside random.randint during forward(). A negative ratio produced a
        # negative slice bound that silently zeroed the wrong rows.
        for name, ratio in (("h_ratio", h_ratio), ("w_ratio", w_ratio)):
            if not 0 <= ratio <= 1:
                raise ValueError(f"{name} must be in [0, 1], got {ratio!r}")
        self.h_ratio = h_ratio
        self.w_ratio = w_ratio

    def forward(self, x):
        """Apply the batch drop mask during training; identity otherwise.

        Args:
            x: Tensor with at least two dimensions. The last two are treated
                as height and width; the original implementation indexed it
                as a 4-D (N, C, H, W) feature map.

        Returns:
            In training mode, ``x`` multiplied by a {0, 1} mask with one
            zeroed rectangle shared by all leading indices. Otherwise ``x``
            itself.
        """
        if self.training:
            h, w = x.size()[-2:]
            rh = round(self.h_ratio * h)
            rw = round(self.w_ratio * w)
            # The top-left corner comes from Python's ``random``, not torch's
            # RNG, so torch.manual_seed alone does not fix the drop location.
            sx = random.randint(0, h - rh)
            sy = random.randint(0, w - rw)
            # An (H, W) mask broadcasts over every leading dim. This drops the
            # same region for the whole batch without a full-size mask, and
            # works for any input rank >= 2.
            mask = x.new_ones((h, w))
            mask[..., sx:sx + rh, sy:sy + rw] = 0
            x = x * mask
        return x
|