# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torchmultimodal.models.albef.image_encoder import ALBEFVisionEncoder
from torchmultimodal.models.albef.model import ALBEFModel, ALBEFModelWithSimilarity
from torchmultimodal.models.albef.multimodal_encoder import ALBEFMultimodalEncoder
from torchmultimodal.modules.encoders.bert_text_encoder import bert_text_encoder
from torchmultimodal.modules.layers.text_embedding import BERTTextEmbeddings
from torchmultimodal.modules.losses.albef import (
    CausalLanguageModelingLoss,
    ImageTextContrastiveLoss,
)
from torchmultimodal.utils.attention import get_causal_attention_mask
from torchmultimodal.utils.common import momentum_update, remove_grad

_ALBEF_PRETRAINED_URLS = {
    "vqa": "https://download.pytorch.org/models/multimodal/albef/pretrained_vqa_checkpoint.pt",
    "retrieval": "https://download.pytorch.org/models/multimodal/albef/pretrained_retrieval_checkpoint.pt",
}

class PredictionHead(nn.Module):
    """
    Predict the following token autoregressively.

    Args:
        vocab_size (int): The number of different tokens the prediction_head can predict.
        hidden_size (int): The hidden size of the prediction_head.
        layer_norm_eps (float): The epsilon used by the prediction_head normalization layer.
        transform_act_fn (Callable[[Tensor], Tensor]): The activation function in the prediction_head.

    Inputs:
        hidden_states (Tensor): The hidden states of preceding tokens.

    Returns:
        Tensor: Prediction scores for the following token.
    """

    def __init__(
        self,
        vocab_size: int = 30522,
        hidden_size: int = 768,
        layer_norm_eps: float = 1e-12,
        transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
    ) -> None:
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.transform_act_fn = transform_act_fn
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.decoder = nn.Linear(hidden_size, vocab_size)

    def forward(self, hidden_states: Tensor) -> Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states
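
# A minimal shape check for PredictionHead (illustrative only; the sizes below simply mirror
# the default constructor arguments and are not an API contract):
#
#   >>> head = PredictionHead()
#   >>> scores = head(torch.randn(2, 5, 768))  # (batch_size, seq_len, hidden_size)
#   >>> scores.shape
#   torch.Size([2, 5, 30522])                  # one score per vocabulary token at each position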

class ALBEFDecoder(nn.Module):
    """
    Generate the prediction scores for answers from image and question hidden states.

    Args:
        text_embeddings (BERTTextEmbeddings): Instantiated BERTTextEmbeddings.
        multimodal_encoder (ALBEFMultimodalEncoder): Instantiated ALBEFMultimodalEncoder.
        prediction_head (PredictionHead): Instantiated PredictionHead.

    Inputs:
        input_ids (Tensor of shape (batch_size, seq_len)):
            Input ids for input text tokens.
        attention_mask (Tensor of shape (batch_size, seq_len)):
            Input attention mask to avoid performing attention on padding token indices.
        encoder_hidden_states (Tensor of shape (batch_size, encoder_seq_len, hidden_size)):
            The encoder hidden states.
        encoder_attention_mask (Tensor of shape (batch_size, encoder_seq_len)):
            The attention mask for encoder hidden states.

    Returns:
        Tensor: Prediction scores for answers.
    """

    def __init__(
        self,
        text_embeddings: BERTTextEmbeddings,
        multimodal_encoder: ALBEFMultimodalEncoder,
        prediction_head: PredictionHead,
    ) -> None:
        super().__init__()
        self.text_embeddings = text_embeddings
        self.multimodal_encoder = multimodal_encoder
        self.prediction_head = prediction_head

    def get_extended_attention_mask_for_decoder(self, attention_mask: Tensor) -> Tensor:
        """
        Apply a causal mask in addition to the padding mask and make the mask broadcastable,
        such that future and masked tokens are ignored.

        Args:
            attention_mask (Tensor):
                Padding mask with ones indicating tokens to attend to, zeros for tokens to ignore.

        Returns:
            extended_attention_mask (Tensor):
                The broadcastable attention mask, with the same dtype as ``attention_mask.dtype``.
        """
        device = attention_mask.device
        batch_size, seq_length = attention_mask.shape
        causal_mask = get_causal_attention_mask(seq_length).to(device)
        causal_mask = causal_mask.repeat(batch_size, 1).view(
            batch_size, seq_length, seq_length
        )
        extended_attention_mask = (
            causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
        )
        extended_attention_mask = extended_attention_mask.to(dtype=attention_mask.dtype)
        return extended_attention_mask

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor,
        encoder_hidden_states: Tensor,
        encoder_attention_mask: Tensor,
    ) -> Tensor:
        hidden_states = self.text_embeddings(input_ids)
        attention_mask = self.get_extended_attention_mask_for_decoder(attention_mask)
        decoder_output = self.multimodal_encoder(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        prediction_scores = self.prediction_head(decoder_output)
        return prediction_scores
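
# Illustrative sketch of how the decoder's extended attention mask combines a causal mask with
# the padding mask (toy values assumed: one sequence of length 3 whose last token is padding):
#
#   >>> padding_mask = torch.tensor([[1.0, 1.0, 0.0]])   # (batch_size=1, seq_len=3)
#   >>> causal_mask = torch.tril(torch.ones(3, 3))       # stand-in for get_causal_attention_mask(3)
#   >>> extended = causal_mask[None, None, :, :] * padding_mask[:, None, None, :]
#   >>> extended[0, 0]
#   tensor([[1., 0., 0.],
#           [1., 1., 0.],
#           [1., 1., 0.]])
#
# Each query position can only attend to earlier, non-padding positions.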

class ALBEFModelForVQA(nn.Module):
    """
    ALBEF Model for VQA finetuning and inference.

    Args:
        model (ALBEFModel): Instantiated ALBEFModel.
        answer_decoder (ALBEFDecoder): Instantiated ALBEFDecoder.
        loss (CausalLanguageModelingLoss): Instantiated CausalLanguageModelingLoss.

    Inputs:
        image (Tensor of shape (B, C, H, W)): Image features.
        question (Tensor of shape (B, L)): Question text features.
        question_atts (Tensor of shape (B, L)): Question attention mask.
        answers (Tensor of shape (N, M)): Answer text features.
        answers_atts (Tensor of shape (N, M)): Answer attention mask.
        ans_weights (Optional[Tensor] of shape (N)): Weights for each answer.
            Required if is_train is True.
        ans_lengths (Optional[List[int]] of length B): Number of answers for each question.
            ans_lengths should sum to N.
            Required if is_train is True.
        alpha (Optional[float]): The interpolation value between clm_loss and loss_distill.
            Required if is_train is True.
        k (Optional[int]): The number of answers to return for inference.
            Required if is_train is False.
        is_train (Optional[bool]): Whether the model is in training.

    Returns:
        is_train is True:
            Tensor: The masked language modeling loss for input.
        is_train is False:
            Tuple[Tensor, Tensor]: The ids and probabilities for the top k predicted answers.
    """

    def __init__(
        self,
        model: ALBEFModel,
        answer_decoder: ALBEFDecoder,
        loss: CausalLanguageModelingLoss,
    ) -> None:
        super().__init__()
        self.model = model
        self.answer_decoder = answer_decoder
        self.loss = loss
        self.answer_decoder_m = copy.deepcopy(self.answer_decoder)
        remove_grad(
            self.answer_decoder_m
        )  # remove gradient for the momentum decoder model

    def _train_forward(
        self,
        image: Tensor,
        question: Tensor,
        question_atts: Tensor,
        answers: Tensor,
        answers_atts: Tensor,
        ans_weights: Tensor,
        ans_lengths: List[int],
        alpha: float,
    ) -> Tensor:
        """
        Forward step for training. Encode the inputs with the ALBEFModel.
        Generate pseudo-targets using answer_decoder_m (momentum decoder model).
        Generate answer predictions using answer_decoder.
        Compute masked language modeling loss of the predictions using answers as labels,
        pseudo-targets as soft-labels, and alpha as their interpolation value.

        Inputs:
            image (Tensor of shape (B, C, H, W)): Image features.
            question (Tensor of shape (B, L)): Question text features.
            question_atts (Tensor of shape (B, L)): Question attention mask.
            answers (Tensor of shape (N, M)): Answer text features.
            answers_atts (Tensor of shape (N, M)): Answer attention mask.
            ans_weights (Tensor of shape (N)): Weights for each answer.
            ans_lengths (List[int] of length B): Number of answers for each question.
                ans_lengths should sum to N.
            alpha (float): The interpolation value between clm_loss and loss_distill.

        Returns:
            Tensor: The masked language modeling loss for input.
        """
        # get image-question embeddings from the ALBEFModel and format them to match ans_lengths
        encoder_outputs = self.model(image, question, question_atts)
        (
            encoder_hidden_states,
            encoder_hidden_states_m,
            encoder_attention_mask,
        ) = self._encoder_hidden_states(
            encoder_outputs.multimodal_embeddings,
            encoder_outputs.multimodal_embeddings_m,
            question_atts,
            ans_lengths,
        )

        # use the momentum model to generate pseudo-targets
        with torch.no_grad():
            momentum_update(
                self.answer_decoder, self.answer_decoder_m, self.model.momentum
            )
            prediction_scores_m = self.answer_decoder_m(
                input_ids=answers,
                attention_mask=answers_atts,
                encoder_hidden_states=encoder_hidden_states_m,
                encoder_attention_mask=encoder_attention_mask,
            )

        # generate answer predictions
        prediction_scores = self.answer_decoder(
            input_ids=answers,
            attention_mask=answers_atts,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

        # compute masked language modeling loss from the prediction scores
        labels = answers.masked_fill(answers == 0, self.loss.mask_token_id)
        loss = self.loss(labels, prediction_scores, prediction_scores_m, alpha)
        loss = ans_weights * loss
        loss = loss.sum() / image.size(0)
        return loss

    def _eval_forward(
        self,
        image: Tensor,
        question: Tensor,
        question_atts: Tensor,
        answers: Tensor,
        answer_atts: Tensor,
        k: int = 128,
    ) -> Tuple[Tensor, Tensor]:
        """
        Forward step for evaluation. Encode the inputs with the ALBEFModel.
        Generate answers autoregressively with the decoder, starting with the [CLS] token.
        Compute the answer ids and their respective probabilities for the top k predictions.

        Inputs:
            image (Tensor of shape (B, C, H, W)): Image features.
            question (Tensor of shape (B, L)): Question text features.
            question_atts (Tensor of shape (B, L)): Question attention mask.
            answers (Tensor of shape (N, M)): Answer text features.
            answer_atts (Tensor of shape (N, M)): Answer attention mask.
            k (int): The number of answers to return for inference.

        Returns:
            Tuple[Tensor, Tensor]: The ids and probabilities for the top k predicted answers.
        """
        # get multimodal embeddings from the ALBEFModel and
        # feed them to the decoder as cross attention
        encoder_outputs = self.model(image, question, question_atts)

        # use the cls token as the decoder's initial input token
        num_ques = question.size(0)
        start_ids = answers[0, 0].repeat(num_ques, 1)
        atts = torch.ones(start_ids.shape).to(image.device)

        # autoregressively generate the answer
        prediction_scores = self.answer_decoder(
            input_ids=start_ids,
            attention_mask=atts,
            encoder_hidden_states=encoder_outputs.multimodal_embeddings,
            encoder_attention_mask=question_atts,
        )

        logits = prediction_scores[:, 0, :]
        answer_first_token = answers[:, 1]
        prob_first_token = F.softmax(logits, dim=1).index_select(
            dim=1, index=answer_first_token
        )
        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)

        input_ids = []
        input_atts = []
        for topk_id in topk_ids:
            input_ids.append(answers.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids)
        input_atts = torch.cat(input_atts)

        targets_ids = input_ids.masked_fill(input_ids == 0, self.loss.mask_token_id)

        question_states = encoder_outputs.multimodal_embeddings.repeat_interleave(
            k, dim=0
        )
        question_atts = question_atts.repeat_interleave(k, dim=0)

        prediction_scores = self.answer_decoder(
            input_ids=input_ids,
            attention_mask=input_atts,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
        )

        answer_loss = self.loss(targets_ids, prediction_scores)
        answer_loss = answer_loss.view(input_ids.size(0), -1)

        # topk_probs: first token probability
        topk_probs = topk_probs.view(-1, 1)
        log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)

        # re-calculate log probabilities for the answer sequences using the chain rule
        log_probs_sum = log_probs.sum(1)
        log_probs_sum = log_probs_sum.view(num_ques, k)

        topk_probs = F.softmax(log_probs_sum, dim=-1)

        # get top-k after re-ranking
        topk_probs, rerank_id = topk_probs.topk(k, dim=1)
        topk_ids = torch.gather(topk_ids, 1, rerank_id)

        return topk_ids, topk_probs

    def _encoder_hidden_states(
        self,
        multimodal_embeds: Tensor,
        multimodal_embeds_m: Tensor,
        question_atts: Tensor,
        ans_lengths: List[int],
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Repeat each image-question embedding and its attention mask to match the number
        of answers that the image-question pair has.

        Args:
            multimodal_embeds (Tensor): Image-question embeddings.
            multimodal_embeds_m (Tensor): Image-question embeddings from the momentum model.
            question_atts (Tensor): Question attention mask.
            ans_lengths (List[int]): The number of answers each image-question input has.

        Returns:
            encoder_hidden_states (Tensor): Image-question embeddings after the repetition.
            encoder_hidden_states_m (Tensor): Image-question embeddings from the momentum model after the repetition.
            encoder_attention_mask (Tensor): Question attention mask after the repetition.
        """
        encoder_hidden_states = []
        encoder_attention_mask = []
        for b, n in enumerate(ans_lengths):
            encoder_hidden_states += [multimodal_embeds[b]] * n
            encoder_attention_mask += [question_atts[b]] * n
        encoder_hidden_states = torch.stack(encoder_hidden_states)
        encoder_attention_mask = torch.stack(encoder_attention_mask)

        with torch.no_grad():
            encoder_hidden_states_m = []
            for b, n in enumerate(ans_lengths):
                encoder_hidden_states_m += [multimodal_embeds_m[b]] * n
            encoder_hidden_states_m = torch.stack(encoder_hidden_states_m)

        return encoder_hidden_states, encoder_hidden_states_m, encoder_attention_mask

    def forward(
        self,
        image: Tensor,
        question: Tensor,
        question_atts: Tensor,
        answers: Tensor,
        answers_atts: Tensor,
        ans_weights: Optional[Tensor] = None,
        ans_lengths: Optional[List[int]] = None,
        alpha: Optional[float] = 0.0,
        k: Optional[int] = 128,
        is_train: Optional[bool] = True,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        if is_train:
            return self._train_forward(
                image,
                question,
                question_atts,
                answers,
                answers_atts,
                ans_weights,
                ans_lengths,
                alpha,
            )
        else:
            return self._eval_forward(
                image,
                question,
                question_atts,
                answers,
                answers_atts,
                k,
            )
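
# Illustrative training / inference calls for ALBEFModelForVQA (a sketch only; the `model`
# variable, tensor contents, and the alpha value are assumptions, and real inputs come from
# tokenized question/answer text and preprocessed images):
#
#   >>> loss = model(image, question, question_atts, answers, answers_atts,
#   ...              ans_weights=ans_weights, ans_lengths=ans_lengths, alpha=0.4, is_train=True)
#   >>> topk_ids, topk_probs = model(image, question, question_atts, answers, answers_atts,
#   ...                              k=128, is_train=False)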

class ALBEFModelForRetrieval(nn.Module):
    """
    ALBEF Model for Retrieval finetuning and inference.
    In training mode, the forward step computes image-text contrastive loss and
    image-text matching loss.
    In evaluation mode, the forward step takes 3 types of input:
        image: encode image input, project and normalize the embeddings.
        text: encode text input, project and normalize the embeddings.
        multimodal: create multimodal embeddings from image and text
            embeddings, and compute image-text matching scores.

    Args:
        model_with_similarity (ALBEFModelWithSimilarity): Instantiated ALBEFModelWithSimilarity.
        itc_loss (ImageTextContrastiveLoss): Instantiated ImageTextContrastiveLoss.
        hidden_size (int): Dimensionality of encoder outputs.

    Inputs:
        image (Optional[Tensor] of shape (B, C, H, W)): Image features.
            Required if is_train is True.
            Required if input_type is "image" or "multimodal".
        text (Optional[Tensor] of shape (B, L)): Text features.
            Required if is_train is True.
            Required if input_type is "text" or "multimodal".
        text_atts (Optional[Tensor] of shape (B, L)): Text attention mask.
            Required if is_train is True.
            Required if input_type is "text" or "multimodal".
        idx (Optional[Tensor] of shape (B)): Identifier for each image sample.
            Required if is_train is True.
        alpha (Optional[float]): The interpolation value for momentum distillation in the
            image-text contrastive loss. Default is 0.
        input_type (Optional[str]): "image", "text", or "multimodal" indicating the encoding type.
            Required if is_train is False.
        is_train (Optional[bool]): Whether the model is in training. Default is True.

    Returns:
        is_train is True:
            Tensor: The sum of itc loss and itm loss.
        is_train is False:
            input_type is "image":
                Tuple[Tensor, Tensor]: Image embeddings and projected image features.
            input_type is "text":
                Tuple[Tensor, Tensor]: Text embeddings and projected text features.
            input_type is "multimodal":
                Tensor: Scores for the retrieval task.
    """

    def __init__(
        self,
        model_with_similarity: ALBEFModelWithSimilarity,
        itc_loss: ImageTextContrastiveLoss,
        hidden_size: int,
    ) -> None:
        super().__init__()
        self.model_with_similarity = model_with_similarity
        self.itc_loss = itc_loss
        self.itm_head = nn.Linear(hidden_size, 2)

    def _train_forward(
        self,
        image: Tensor,
        text: Tensor,
        text_atts: Tensor,
        idx: Tensor,
        alpha: float,
    ) -> Tensor:
        encoder_output = self.model_with_similarity(image, text, text_atts, idx)

        # compute image-text contrastive loss
        similarity_outputs = encoder_output.similarity
        similarity_targets = encoder_output.sim_targets
        itc_loss = self.itc_loss(
            similarity_outputs.sim_i2t,
            similarity_outputs.sim_t2i,
            similarity_outputs.sim_i2t_m,
            similarity_outputs.sim_t2i_m,
            similarity_targets,
            alpha,
        )

        # compute image-text matching loss
        pos_embeddings = encoder_output.multimodal_embeddings[:, 0, :]
        neg_embeddings = encoder_output.multimodal_embeddings_neg[:, 0, :]
        vl_embeddings = torch.cat([pos_embeddings, neg_embeddings], dim=0)
        vl_output = self.itm_head(vl_embeddings)
        itm_labels = torch.cat(
            [
                torch.ones(pos_embeddings.size(0), dtype=torch.long),
                torch.zeros(neg_embeddings.size(0), dtype=torch.long),
            ],
            dim=0,
        ).to(vl_embeddings.device)
        itm_loss = F.cross_entropy(vl_output, itm_labels)

        loss = itc_loss + itm_loss
        return loss

    def _encode_image(
        self,
        image: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        image_embed = self.model_with_similarity.albef_model.vision_encoder(image)
        image_feat = F.normalize(
            self.model_with_similarity.vision_proj(image_embed[:, 0, :]), dim=-1
        )
        return image_embed, image_feat

    def _encode_text(
        self,
        text: Tensor,
        text_atts: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        text_embed = self.model_with_similarity.albef_model.text_encoder(
            text, text_atts
        ).last_hidden_state
        text_feat = F.normalize(
            self.model_with_similarity.text_proj(text_embed[:, 0, :]), dim=-1
        )
        return text_embed, text_feat

    def _image_text_matching_score(
        self,
        image: Tensor,
        text: Tensor,
        text_atts: Tensor,
    ) -> Tensor:
        multimodal_embeds = self.model_with_similarity.albef_model.multimodal_encoder(
            text,
            text_atts,
            image,
        )
        score = self.itm_head(multimodal_embeds[:, 0, :])[:, 1]
        return score

    def _eval_forward(
        self,
        input_type: str,
        image: Optional[Tensor],
        text: Optional[Tensor],
        text_atts: Optional[Tensor],
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        if input_type == "image":
            assert image is not None, "image input tensor cannot be None"
            return self._encode_image(image)

        elif input_type == "text":
            assert (
                text is not None and text_atts is not None
            ), "text and text attention mask cannot be None"
            return self._encode_text(text, text_atts)

        elif input_type == "multimodal":
            assert (
                image is not None and text is not None and text_atts is not None
            ), "image embeddings, text embeddings, and text attention mask cannot be None"
            return self._image_text_matching_score(image, text, text_atts)

        else:
            raise ValueError("input_type must be image, text, or multimodal")

    def forward(
        self,
        image: Optional[Tensor] = None,
        text: Optional[Tensor] = None,
        text_atts: Optional[Tensor] = None,
        idx: Optional[Tensor] = None,
        alpha: Optional[float] = 0.0,
        input_type: Optional[str] = None,
        is_train: Optional[bool] = True,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        if is_train:
            return self._train_forward(
                image,
                text,
                text_atts,
                idx,
                alpha,
            )
        else:
            return self._eval_forward(
                input_type,
                image,
                text,
                text_atts,
            )
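
# Illustrative calls for ALBEFModelForRetrieval (a sketch only; `model`, the tensors, and the
# alpha value below are assumptions rather than values prescribed by this module):
#
#   >>> loss = model(image, text, text_atts, idx, alpha=0.4, is_train=True)
#   >>> image_embed, image_feat = model(image=image, input_type="image", is_train=False)
#   >>> text_embed, text_feat = model(text=text, text_atts=text_atts, input_type="text", is_train=False)
#   >>> itm_score = model(image=image_embed, text=text_embed, text_atts=text_atts,
#   ...                   input_type="multimodal", is_train=False)  # scores candidate image-text pairs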

def albef_model_for_vqa(
    config: Dict[str, Any], pretrained: bool = False
) -> ALBEFModelForVQA:
    vision_encoder = ALBEFVisionEncoder(**config["vision_encoder_args"])
    text_encoder = bert_text_encoder(**config["text_encoder_args"])
    question_multimodal_encoder = ALBEFMultimodalEncoder(
        **config["multimodal_encoder_args"]
    )
    text_embeddings = BERTTextEmbeddings(**config["text_embeddings_args"])
    answer_multimodal_encoder = ALBEFMultimodalEncoder(
        **config["multimodal_encoder_args"]
    )
    prediction_head = PredictionHead(**config["prediction_head_args"])
    albef_model = ALBEFModel(vision_encoder, text_encoder, question_multimodal_encoder)
    decoder = ALBEFDecoder(text_embeddings, answer_multimodal_encoder, prediction_head)
    loss = CausalLanguageModelingLoss()
    model = ALBEFModelForVQA(albef_model, decoder, loss)
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            _ALBEF_PRETRAINED_URLS["vqa"], map_location="cpu"
        )
        model.load_state_dict(checkpoint)
    return model
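
# A sketch of the expected `config` layout for albef_model_for_vqa. Only the top-level keys are
# implied by the builder above; the nested contents are placeholders and must match the actual
# constructor signatures:
#
#   >>> config = {
#   ...     "vision_encoder_args": {...},      # kwargs for ALBEFVisionEncoder
#   ...     "text_encoder_args": {...},        # kwargs for bert_text_encoder
#   ...     "multimodal_encoder_args": {...},  # kwargs for ALBEFMultimodalEncoder
#   ...     "text_embeddings_args": {...},     # kwargs for BERTTextEmbeddings
#   ...     "prediction_head_args": {...},     # kwargs for PredictionHead
#   ... }
#   >>> model = albef_model_for_vqa(config, pretrained=True)  # downloads the VQA checkpoint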

def albef_model_for_retrieval(
    config: Dict[str, Any], pretrained: bool = False
) -> ALBEFModelForRetrieval:
    vision_encoder = ALBEFVisionEncoder(**config["vision_encoder_args"])
    text_encoder = bert_text_encoder(**config["text_encoder_args"])
    multimodal_encoder = ALBEFMultimodalEncoder(**config["multimodal_encoder_args"])
    vision_proj = nn.Linear(**config["projection_args"])
    text_proj = nn.Linear(**config["projection_args"])
    albef_model = ALBEFModel(vision_encoder, text_encoder, multimodal_encoder)
    albef_model_with_sim = ALBEFModelWithSimilarity(
        albef_model, vision_proj, text_proj, **config["similarity_args"]
    )
    itc_loss = ImageTextContrastiveLoss()
    model = ALBEFModelForRetrieval(
        albef_model_with_sim, itc_loss, config["hidden_size"]
    )
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            _ALBEF_PRETRAINED_URLS["retrieval"], map_location="cpu"
        )
        model.load_state_dict(checkpoint)
    return model
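
# A sketch of the expected `config` layout for albef_model_for_retrieval. The top-level keys
# follow from the builder above; nested contents and the example values are assumptions for
# illustration only:
#
#   >>> config = {
#   ...     "vision_encoder_args": {...},      # kwargs for ALBEFVisionEncoder
#   ...     "text_encoder_args": {...},        # kwargs for bert_text_encoder
#   ...     "multimodal_encoder_args": {...},  # kwargs for ALBEFMultimodalEncoder
#   ...     "projection_args": {"in_features": 768, "out_features": 256},  # nn.Linear kwargs (example values)
#   ...     "similarity_args": {...},          # kwargs for ALBEFModelWithSimilarity
#   ...     "hidden_size": 768,                # dimensionality of the multimodal encoder outputs
#   ... }
#   >>> model = albef_model_for_retrieval(config, pretrained=True)  # downloads the retrieval checkpoint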