| | import torch.nn as nn |
| | import torch |
| | from transformers import AutoModel |
| |
|
class BERT_FFNN(nn.Module):
    """BERT encoder + feed-forward network for text classification.

    Pools the encoder's token embeddings into one vector per sequence
    ('mean', 'max', 'attention', or a pooler/CLS fallback), then maps it
    through an MLP head to class logits.
    """

    def __init__(
        self,
        bert_model_name="microsoft/deberta-v3-base",
        hidden_dims=(192, 96),  # tuple, not list: mutable defaults are shared across calls
        output_dim=5,
        dropout=0.2,
        pooling='attention',
        freeze_bert=False,
        freeze_layers=0,
        use_layer_norm=True,
    ):
        """
        Args:
            bert_model_name: HF hub id passed to ``AutoModel.from_pretrained``.
            hidden_dims: widths of the hidden FFNN layers.
            output_dim: number of classes (logit dimension).
            dropout: dropout probability after each hidden layer.
            pooling: 'mean' | 'max' | 'attention'; any other value falls
                back to ``pooler_output`` (or the first token's embedding).
            freeze_bert: freeze the entire encoder.
            freeze_layers: freeze only the first N encoder layers
                (ignored when ``freeze_bert`` is True).
            use_layer_norm: insert LayerNorm after each hidden ReLU.
        """
        super().__init__()

        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.use_layer_norm = use_layer_norm
        self.pooling = pooling

        # Learned attention pooling is only needed (and only built) for
        # the 'attention' strategy.
        if pooling == 'attention':
            self.attention_pool = AttentionPooling(self.bert.config.hidden_size)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        elif freeze_layers > 0:
            # NOTE(review): assumes the backbone exposes `.encoder.layer`
            # (true for BERT/DeBERTa-style models) — confirm for other
            # architectures before relying on `freeze_layers`.
            for layer in self.bert.encoder.layer[:freeze_layers]:
                for param in layer.parameters():
                    param.requires_grad = False

        # Head: [Linear -> ReLU -> (LayerNorm) -> Dropout] per hidden dim,
        # then a final projection to the logits.
        layers = []
        in_dim = self.bert.config.hidden_size
        for h_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            if use_layer_norm:
                layers.append(nn.LayerNorm(h_dim))
            layers.append(nn.Dropout(dropout))
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, output_dim))
        self.classifier = nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask):
        """Return ``(batch, output_dim)`` logits for the token batch."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        if self.pooling == 'mean':
            # Mask-aware mean: zero out padding, divide by real token count
            # (clamped to avoid division by zero on an all-padding row).
            mask = attention_mask.unsqueeze(-1).float()
            sum_emb = (outputs.last_hidden_state * mask).sum(1)
            features = sum_emb / mask.sum(1).clamp(min=1e-9)
        elif self.pooling == 'max':
            # Padding positions become -inf so they can never win the max.
            mask = attention_mask.unsqueeze(-1).float()
            masked_emb = outputs.last_hidden_state.masked_fill(mask == 0, float('-inf'))
            features, _ = masked_emb.max(dim=1)
        elif self.pooling == 'attention':
            features = self.attention_pool(outputs.last_hidden_state, attention_mask)
        else:
            # Fallback: model's own pooled output when present, otherwise
            # the first (CLS-position) token embedding.
            features = outputs.pooler_output if getattr(outputs, 'pooler_output', None) is not None else outputs.last_hidden_state[:, 0]

        logits = self.classifier(features)
        return logits
| |
|
class AttentionPooling(nn.Module):
    """Collapse ``(batch, seq, hidden)`` token states to ``(batch, hidden)``
    via a single learned attention score per token; masked (padding)
    positions receive effectively zero weight.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # One scalar relevance score per token embedding.
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask):
        """Return the attention-weighted sum of *hidden_states*.

        Positions where ``attention_mask == 0`` are pushed to -1e9 before
        the softmax so their contribution vanishes.
        """
        raw = self.attention(hidden_states)      # (batch, seq, 1)
        raw = raw.squeeze(-1)                    # (batch, seq)
        raw = raw.masked_fill(attention_mask == 0, -1e9)
        probs = torch.softmax(raw, dim=-1)       # each row sums to 1
        pooled = (hidden_states * probs.unsqueeze(-1)).sum(dim=1)
        return pooled
| |
|
| |
|