import copy

import torch
from mmcls.models import VisionTransformer
from torch import nn
from torch.utils.checkpoint import checkpoint


def build_2d_sincos_position_embedding(patches_resolution,
                                        embed_dims,
                                        temperature=10000.,
                                        cls_token=False):
    """Build the fixed 2D sine-cosine position embedding that gives the model
    the spatial location of each image patch."""
    if isinstance(patches_resolution, int):
        patches_resolution = (patches_resolution, patches_resolution)

    h, w = patches_resolution
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
    assert embed_dims % 4 == 0, \
        'Embed dimension must be divisible by 4.'
    pos_dim = embed_dims // 4

    # Frequency bands used for the sine-cosine encoding.
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature**omega)
    out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
    out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])

    # Concatenate the sin/cos terms along the channel dim: (1, h * w, embed_dims).
    pos_emb = torch.cat(
        [
            torch.sin(out_w),
            torch.cos(out_w),
            torch.sin(out_h),
            torch.cos(out_h)
        ],
        dim=1,
    )[None, :, :]

    # Prepend an all-zero embedding for the cls token if required.
    if cls_token:
        cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32)
        pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)

    return pos_emb
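

# Example (illustrative only): for a 14x14 patch grid and 768-dim embeddings
# with a cls token, the returned embedding has shape (1, 1 + 14 * 14, 768):
#
#     pos_emb = build_2d_sincos_position_embedding((14, 14), 768, cls_token=True)
#     assert pos_emb.shape == (1, 197, 768)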


class MAEViT(VisionTransformer):
    """Vision Transformer for MAE pre-training.

    A PyTorch implementation of: `An Image is Worth 16x16 Words: Transformers
    for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_

    Args:
        arch (str | dict): Vision Transformer architecture.
            Defaults to 'b'.
        img_size (int | tuple): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, meaning the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for the normalization layer.
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize the
            final feature map. Defaults to True.
        output_cls_token (bool): Whether to output the cls_token. If set to
            True, ``with_cls_token`` must be True. Defaults to False.
        interpolate_mode (str): Interpolation mode used to resize the position
            embedding vector. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in the
            encoder. Defaults to an empty dict.
        gradientCKPT (bool): Whether to apply gradient checkpointing to the
            transformer layers to save memory. Defaults to False.
        mask_ratio (float): The ratio of the total number of patches to be
            masked. Defaults to 0.75.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    arch_zoo = {
        **dict.fromkeys(
            ['mocov3-s', 'mocov3-small'], {
                'embed_dims': 384,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 1536,
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 3072
            }),
    }

    def __init__(self,
                 arch='b',
                 img_size=224,
                 patch_size=16,
                 out_indices=-1,
                 drop_rate=0,
                 drop_path_rate=0,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 final_norm=True,
                 output_cls_token=False,
                 interpolate_mode='bicubic',
                 patch_cfg=dict(),
                 layer_cfgs=dict(),
                 gradientCKPT=False,
                 mask_ratio=0.75,
                 init_cfg=None):
        super().__init__(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            out_indices=out_indices,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
            output_cls_token=output_cls_token,
            interpolate_mode=interpolate_mode,
            patch_cfg=patch_cfg,
            layer_cfgs=layer_cfgs,
            init_cfg=init_cfg)
        self.gradientCKPT = gradientCKPT
        # The 2D sin-cos position embedding is fixed, so it is not trained.
        self.pos_embed.requires_grad = False
        self.mask_ratio = mask_ratio
        # Number of patch tokens (excluding the cls token).
        self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]

    def init_weights(self):
        super(MAEViT, self).init_weights()
        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):

            # Overwrite the position embedding with the fixed 2D sin-cos
            # table.
            pos_embed = build_2d_sincos_position_embedding(
                self.patch_resolution,
                self.pos_embed.shape[-1],
                cls_token=True)
            self.pos_embed.data.copy_(pos_embed.float())

            # Initialize the patch embedding projection with xavier uniform
            # applied to the flattened weight.
            w = self.patch_embed.projection.weight.data
            torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

            torch.nn.init.normal_(self.cls_token, std=.02)

            self.apply(self._init_weights)

    def _init_mask_embedding(self, m):
        if hasattr(m, 'weight') and m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
        if hasattr(m, 'bias') and m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x, mask_ratio=0.75, attn_mask=None):
        """Generate the mask for MAE pre-training.

        Args:
            x (torch.Tensor): Patch tokens of shape (N, L, D), after patch
                embedding and position embedding.
            mask_ratio (float): The mask ratio of total patches.
                Defaults to 0.75.
            attn_mask (torch.Tensor, optional): Currently unused; kept for
                interface compatibility.

        Returns:
            tuple[Tensor, Tensor, Tensor]: masked image, mask and the ids
            to restore the original image.

            - x_masked (Tensor): masked image.
            - mask (Tensor): mask used to mask the image.
            - ids_restore (Tensor): ids to restore the original image.
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        # Sample per-patch noise and keep the patches with the smallest noise.
        noise = torch.rand(N, L, device=x.device)

        # ids_shuffle gives the keep/remove order; ids_restore maps back to
        # the original patch order.
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep the first len_keep patches.
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Build the binary mask: 0 means kept, 1 means removed.
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
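
    # Example (illustrative, ``model`` is an MAEViT instance): for a batch of
    # 2 images tokenized into 196 patch tokens of dim 768, a mask ratio of
    # 0.75 keeps 49 tokens:
    #
    #     tokens = torch.randn(2, 196, 768)
    #     x_masked, mask, ids_restore = model.random_masking(tokens, 0.75)
    #     # x_masked.shape == (2, 49, 768)
    #     # mask.shape == ids_restore.shape == (2, 196)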

    def generate_mask(self, pixel_level_attn_mask):
        """Generate an attention mask from ``pixel_level_attn_mask``, a (0, 1)
        mask with the same shape as the input image.

        Only the ``None`` pass-through is implemented here.
        """
        if pixel_level_attn_mask is None:
            return None
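
    # A minimal sketch (an assumption, not the original implementation) of how
    # a pixel-level mask of shape (N, H, W) could be reduced to one value per
    # patch token by average-pooling at patch granularity and thresholding:
    #
    #     patch = self.patch_embed.projection.kernel_size[0]
    #     pooled = torch.nn.functional.avg_pool2d(
    #         pixel_level_attn_mask.float(), kernel_size=patch)
    #     token_mask = pooled.flatten(1) > 0.5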

    def extract_feat(self, img, attn_mask=None):
        """Return one feature vector per image: the cls token when
        ``output_cls_token`` is True, otherwise the mean over all tokens."""
        out = self.forward(img, attn_mask)
        # ``forward`` returns the token sequence directly when masking is
        # disabled, or a (tokens, mask, ids_restore) tuple otherwise.
        x = out[0] if isinstance(out, tuple) else out
        if self.output_cls_token:
            return x[:, 0, :]
        else:
            return torch.mean(x, dim=1)
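
    # Example (illustrative, with placeholder names): with masking disabled
    # the encoder can be used as a plain feature extractor:
    #
    #     model = MAEViT(arch='b', mask_ratio=0.)
    #     model.init_weights()
    #     feats = model.extract_feat(torch.randn(2, 3, 224, 224))
    #     # feats.shape == (2, 768) since output_cls_token defaults to False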

    def forward(self, x, attn_mask=None):
        if attn_mask is not None:
            # An attention mask is only supported together with the cls token;
            # it is not applied further in this implementation.
            assert self.output_cls_token

        B = x.shape[0]
        x = self.patch_embed(x)[0]

        # Add the (fixed) position embedding to the patch tokens; index 0 of
        # the table is reserved for the cls token.
        x = x + self.pos_embed[:, 1:1 + x.shape[1], :]

        # Masking is disabled in this variant, so the full token sequence is
        # kept. The masked path would be
        # ``x, mask, ids_restore = self.random_masking(x, self.mask_ratio)``.
        assert self.mask_ratio == 0.

        # Prepend the cls token with its position embedding.
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.drop_after_pos(x)

        for i, layer in enumerate(self.layers):
            if self.gradientCKPT:
                x = checkpoint(layer, x)
            else:
                x = layer(x)
            if i == len(self.layers) - 1 and self.final_norm:
                x = self.norm1(x)

        # With masking enabled this would instead return
        # ``(x, mask, ids_restore)``.
        return x

    def forward_generator(self, x, attn_mask=None):
        """Generator variant of ``forward`` that yields the hidden states
        after every transformer layer.

        The caller may ``send`` a tensor back to the generator; if it does,
        that tensor replaces the hidden state before the next layer.
        """
        if attn_mask is not None:
            assert self.output_cls_token

        B = x.shape[0]
        x = self.patch_embed(x)[0]

        x = x + self.pos_embed[:, 1:1 + x.shape[1], :]

        # Prepend the cls token with its position embedding.
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.drop_after_pos(x)

        for i, layer in enumerate(self.layers):
            if self.gradientCKPT:
                x = checkpoint(layer, x)
            else:
                x = layer(x)

            if i == len(self.layers) - 1 and self.final_norm:
                x = self.norm1(x)

            # Yield the current hidden state; if the caller sends a tensor
            # back, continue with that tensor instead.
            x = x if (new_x := (yield x)) is None else new_x

            debug = False
            if debug:
                print(f'layer {i} forwarded')
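
    # Example (illustrative, ``model`` and ``img`` are placeholder names):
    # drive the generator layer by layer, optionally sending a modified
    # hidden state back before the next layer:
    #
    #     gen = model.forward_generator(img)
    #     feat = next(gen)                # output of the first layer
    #     while True:
    #         try:
    #             feat = gen.send(None)   # or gen.send(edited_feat)
    #         except StopIteration:
    #             break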
|