| | from functools import partial |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from timm.models.vision_transformer import PatchEmbed, Block |
| |
|
| | from util.pos_embed import get_2d_sincos_pos_embed |
| |
|
| |
|
class MaskedAutoencoderViTNoCT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone (no class token)

    The encoder embeds an image into patch tokens, keeps a random subset of
    them and runs the kept tokens through a stack of transformer blocks.
    The decoder re-inserts a learned mask token at every dropped position and
    predicts the pixel values of every patch. Neither positional embedding
    reserves a slot for a class token.
    """
    def __init__(self, img_size=384, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        """
        img_size, patch_size, in_chans, embed_dim: forwarded to PatchEmbed
        depth, num_heads: number and head count of the encoder blocks
        decoder_embed_dim, decoder_depth, decoder_num_heads: same for the decoder
        mlp_ratio: forwarded to every Block (encoder and decoder)
        norm_layer: factory called with a feature width to build each norm layer
        norm_pix_loss: if True, forward_loss normalizes each target patch
        """
        super().__init__()

        # --------------------------------------------------------------------------
        # encoder
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        # frozen; filled by initialize_weights()
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)

        # qk_scale is deliberately not passed: the previous value (None) was the
        # default in older timm, and newer timm releases reject the keyword.
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        # learned placeholder inserted at every dropped position
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # frozen; filled by initialize_weights()
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=False)

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        # one output value per pixel per channel of a patch
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the frozen positional embeddings and initialize all weights."""
        grid_size = int(self.patch_embed.num_patches**.5)

        # positional embeddings come from get_2d_sincos_pos_embed, without a class-token row
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size, cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], grid_size, cls_token=False)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # xavier-initialize the patch projection on its weight flattened to 2-D
        # (the view shares storage, so this writes into the real weight)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.mask_token, std=.02)

        # nn.Linear and nn.LayerNorm submodules
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initializer used with nn.Module.apply."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # affine parameters are None when the layer was built without them
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, C, H, W) with H == W and H divisible by the patch size
        x: (N, L, patch_size**2 * C)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        # take the channel count from the input instead of assuming RGB, so the
        # result matches decoder_pred's width for any in_chans
        n, c = imgs.shape[0], imgs.shape[1]
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(n, c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(n, h * w, p**2 * c))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * C) with L a perfect square
        imgs: (N, C, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        # recover the channel count from the last dimension (3 for RGB patches)
        c = x.shape[2] // (p * p)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        returns: (x_masked [N, len_keep, D], mask [N, L], ids_restore [N, L])
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1)

        # sort noise for each sample: ascend, small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset; expand() broadcasts the index without copying it
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D))

        # generate the binary mask in shuffled order: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask in original patch order
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        """Encode the randomly kept patches; returns (latent, mask, ids_restore)."""
        # embed patches
        x = self.patch_embed(x)

        # add pos embed before masking so kept tokens carry their position
        x = x + self.pos_embed

        # masking: length -> int(length * (1 - mask_ratio))
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """Predict every patch from the encoder output; returns [N, L, p*p*C]."""
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens, then unshuffle back to original patch order
        num_masked = ids_restore.shape[1] - x.shape[1]
        mask_tokens = self.mask_token.repeat(x.shape[0], num_masked, 1)
        x = torch.cat([x, mask_tokens], dim=1)
        x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, x.shape[2]))

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, C, H, W]
        pred: [N, L, p*p*C]
        mask: [N, L], 0 is keep, 1 is remove; only its shape is used here
        returns: scalar mean squared error averaged over ALL patches
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            # normalize each target patch by its own mean and variance
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        # Uniform weights: kept and removed patches both contribute, so the
        # values stored in `mask` do not affect the result.
        N, L = mask.shape
        mask_s = torch.ones([N, L], device=imgs.device)
        loss = (loss * mask_s).sum() / mask_s.sum()

        return loss

    def forward(self, imgs, mask_ratio=0.75):
        """Run encoder, decoder and loss; returns (loss, pred, mask)."""
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*C]
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask
| |
|
| |
|
def mae_vit_base_patch16_dec512d8b(**kwargs):
    """Build the base variant: patch 16, encoder 768-d x 12 blocks x 12 heads,
    decoder 512-d x 8 blocks x 16 heads.

    Extra keyword arguments are forwarded to MaskedAutoencoderViTNoCT.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViTNoCT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs)
| |
|
| |
|
def mae_vit_large_patch16_dec512d8b(**kwargs):
    """Build the large variant: patch 16, encoder 1024-d x 24 blocks x 16 heads,
    decoder 512-d x 8 blocks x 16 heads.

    Extra keyword arguments are forwarded to MaskedAutoencoderViTNoCT.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViTNoCT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs)
| |
|
| |
|
def mae_vit_huge_patch14_dec512d8b(**kwargs):
    """Build the huge variant: patch 14, encoder 1280-d x 32 blocks x 16 heads,
    decoder 512-d x 8 blocks x 16 heads.

    Extra keyword arguments are forwarded to MaskedAutoencoderViTNoCT.
    """
    # NOTE(review): the class default img_size is 384, which is not a multiple
    # of 14, so patchify's divisibility assertion fails in forward_loss unless
    # the caller passes an img_size divisible by 14 -- confirm intended default.
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViTNoCT(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs)
| |
|
| |
|
| | |
# Short public aliases for the factories above; every variant uses the
# 512-d, 8-block decoder.
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b
| |
|