"""
Custom Student Model for Knowledge Distillation
"""
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from typing import Dict


class StudentModelConfig(PretrainedConfig):
    """Configuration for the distilled student model."""

    model_type = "distilled_student"

    def __init__(
        self,
        hidden_size=768,
        num_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        vocab_size=30522,
        max_position_embeddings=512,
        modalities=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        # Avoid a mutable default argument; fall back to a text-only student.
        self.modalities = modalities if modalities is not None else ["text"]


class StudentModel(PreTrainedModel):
    config_class = StudentModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.modalities = config.modalities

        # Build model layers based on config. Note that no positional
        # embeddings are applied here; config.max_position_embeddings is kept
        # only as metadata about the expected sequence length.
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                batch_first=True,
            )
            for _ in range(config.num_layers)
        ])
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
    def forward(self, input_ids=None, attention_mask=None, **kwargs) -> Dict[str, torch.Tensor]:
        if input_ids is not None:
            hidden_states = self.embeddings(input_ids)
        else:
            # Handle other modalities via pre-computed embeddings
            hidden_states = kwargs.get('inputs_embeds')
            if hidden_states is None:
                raise ValueError("Either input_ids or inputs_embeds must be provided")

        # nn.TransformerEncoderLayer expects True at positions to be ignored,
        # whereas an HF-style attention_mask uses 1 for kept tokens, so invert it.
        padding_mask = (attention_mask == 0) if attention_mask is not None else None
        for layer in self.layers:
            hidden_states = layer(hidden_states, src_key_padding_mask=padding_mask)

        # Mean-pool over the sequence dimension, then project
        pooled = self.pooler(hidden_states.mean(dim=1))
        return {
            'last_hidden_state': hidden_states,
            'pooler_output': pooled
        }
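

# Minimal usage sketch: builds a small student, runs a forward pass, and
# computes a hidden-state MSE loss, one common distillation objective. The
# tiny config values and the random teacher_hidden tensor are placeholders;
# a real distillation step would use an actual teacher model's outputs.
if __name__ == "__main__":
    config = StudentModelConfig(
        hidden_size=256, num_layers=2, num_attention_heads=4, intermediate_size=512
    )
    student = StudentModel(config)

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    attention_mask[:, -4:] = 0  # treat the last 4 positions as padding

    outputs = student(input_ids=input_ids, attention_mask=attention_mask)
    print(outputs['last_hidden_state'].shape)  # torch.Size([2, 16, 256])
    print(outputs['pooler_output'].shape)      # torch.Size([2, 256])

    # Match the teacher's hidden states with a simple MSE loss.
    teacher_hidden = torch.randn(batch_size, seq_len, config.hidden_size)
    loss = nn.functional.mse_loss(outputs['last_hidden_state'], teacher_hidden)
    loss.backward()
    print(f"hidden-state distillation loss: {loss.item():.4f}")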