mirror of https://github.com/alibaba/EasyCV.git
88 lines
2.6 KiB
Python
88 lines
2.6 KiB
Python
import torch
|
|
|
|
from .. import builder
|
|
from ..base import BaseModel
|
|
from ..registry import MODELS
|
|
|
|
|
|
@MODELS.register_module
|
|
class MAE(BaseModel):
|
|
|
|
def __init__(self,
|
|
backbone,
|
|
neck,
|
|
mask_ratio=0.75,
|
|
norm_pix_loss=True,
|
|
**kwargs):
|
|
super(MAE, self).__init__()
|
|
self.mask_ratio = mask_ratio
|
|
self.norm_pix_loss = norm_pix_loss
|
|
self.encoder = builder.build_backbone(backbone)
|
|
self.patch_size = self.encoder.patch_size
|
|
neck['num_patches'] = self.encoder.num_patches
|
|
self.decoder = builder.build_neck(neck)
|
|
self.init_weights()
|
|
|
|
def init_weights(self):
|
|
self.encoder.init_weights()
|
|
self.decoder.init_weights()
|
|
|
|
def patchify(self, imgs):
|
|
"""convert image to patch
|
|
|
|
Args:
|
|
imgs: (N, 3, H, W)
|
|
Returns:
|
|
x: (N, L, patch_size**2 *3)
|
|
"""
|
|
p = self.patch_size
|
|
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
|
|
|
|
h = w = imgs.shape[2] // p
|
|
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
|
|
x = torch.einsum('nchpwq->nhwpqc', x)
|
|
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
|
|
return x
|
|
|
|
def forward_loss(self, imgs, pred, mask):
|
|
"""compute loss
|
|
|
|
Args:
|
|
imgs: (N, 3, H, W)
|
|
pred: (N, L, p*p*3)
|
|
mask: (N, L), 0 is keep, 1 is remove,
|
|
"""
|
|
target = self.patchify(imgs)
|
|
if self.norm_pix_loss:
|
|
mean = target.mean(dim=-1, keepdim=True)
|
|
var = target.var(dim=-1, keepdim=True)
|
|
target = (target - mean) / (var + 1.e-6)**.5
|
|
|
|
# adapt to ConvMAE
|
|
assert pred.shape[0] % target.shape[0] == 0
|
|
target = torch.cat([target] * (pred.shape[0] // target.shape[0]))
|
|
|
|
loss = (pred - target)**2
|
|
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
|
|
|
|
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
|
|
return loss
|
|
|
|
def forward_train(self, img, **kwargs):
|
|
latent, mask, ids_restore = self.encoder(
|
|
img, mask_ratio=self.mask_ratio)
|
|
pred = self.decoder(latent, ids_restore)
|
|
loss = self.forward_loss(img, pred, mask)
|
|
return dict(loss=loss)
|
|
|
|
def forward_test(self, img, **kwargs):
|
|
pass
|
|
|
|
def forward(self, img, mode='train', **kwargs):
|
|
if mode == 'train':
|
|
return self.forward_train(img, **kwargs)
|
|
elif mode == 'test':
|
|
return self.forward_test(img, **kwargs)
|
|
else:
|
|
raise Exception('No such mode: {}'.format(mode))
|