|
|
import torch |
|
|
import random |
|
|
|
|
|
# Special vocabulary ids used by the masking routine below.
MASK_TOKEN = 0  # substituted at ~80% of selected positions (BERT-style masking)


PAD_TOKEN = 1  # padding id; never selected for masking


CLS_TOKEN = 2  # sequence-start/classification id; never selected for masking
|
|
|
|
|
def complete_masking(batch, masking_p, n_tokens): |
|
|
"""Apply masking to input batch for masked language modeling. |
|
|
|
|
|
Args: |
|
|
batch (dict): Input batch containing 'input_ids' and 'attention_mask' |
|
|
masking_p (float): Probability of masking a token |
|
|
n_tokens (int): Total number of tokens in vocabulary |
|
|
|
|
|
Returns: |
|
|
dict: Batch with masked indices and masking information |
|
|
""" |
|
|
device = batch['input_ids'].device |
|
|
input_ids = batch['input_ids'] |
|
|
attention_mask = batch['attention_mask'] |
|
|
|
|
|
|
|
|
prob = torch.rand(input_ids.shape, device=device) |
|
|
mask = (prob < masking_p) & (input_ids != PAD_TOKEN) & (input_ids != CLS_TOKEN) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
masked_indices = input_ids.clone() |
|
|
|
|
|
|
|
|
num_tokens_to_mask = mask.sum().item() |
|
|
|
|
|
|
|
|
mask_mask = torch.rand(num_tokens_to_mask, device=device) < 0.8 |
|
|
random_mask = (torch.rand(num_tokens_to_mask, device=device) < 0.5) & ~mask_mask |
|
|
|
|
|
|
|
|
masked_indices[mask] = torch.where( |
|
|
mask_mask, |
|
|
torch.tensor(MASK_TOKEN, device=device, dtype=torch.long), |
|
|
masked_indices[mask] |
|
|
) |
|
|
|
|
|
|
|
|
random_tokens = torch.randint( |
|
|
3, n_tokens, |
|
|
(random_mask.sum(),), |
|
|
device=device, |
|
|
dtype=torch.long |
|
|
) |
|
|
masked_indices[mask][random_mask] = random_tokens |
|
|
|
|
|
|
|
|
|
|
|
return { |
|
|
'masked_indices': masked_indices, |
|
|
'attention_mask': attention_mask, |
|
|
'mask': mask, |
|
|
'input_ids': input_ids |
|
|
} |