adeebaldkheel committed
Commit 932c5f6 · verified · Parent(s): 2c968b4

Fix: Add Transformers 4.47+ compatibility

- Wraps the LlamaFlashAttention2 import in try-except for backward compatibility
- Falls back to LlamaAttention when FlashAttention2 is unavailable
- Tested on Transformers 4.46.3 with macOS MPS
- Minimal change: 1 file, 13 insertions, 5 deletions

Files changed (1):
  modeling_deepseekv2.py  +13 -5
modeling_deepseekv2.py CHANGED
@@ -34,10 +34,18 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaFlashAttention2
-)
+
+# Handle different transformers versions
+try:
+    from transformers.models.llama.modeling_llama import (
+        LlamaAttention,
+        LlamaFlashAttention2
+    )
+except ImportError:
+    # Newer transformers versions (4.47+) don't have LlamaFlashAttention2
+    from transformers.models.llama.modeling_llama import LlamaAttention
+    LlamaFlashAttention2 = None  # Will use fallback
+
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -1235,7 +1243,7 @@ ATTENTION_CLASSES = {
     "mla_flash_attention_2": DeepseekV2FlashAttention2,

     "mha_eager": LlamaAttention,
-    "mha_flash_attention_2": LlamaFlashAttention2
+    "mha_flash_attention_2": LlamaFlashAttention2 if LlamaFlashAttention2 is not None else LlamaAttention
 }

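
For context, the same fallback can be exercised outside the model file to confirm which class the "mha_flash_attention_2" key resolves to on a given install. The snippet below is an illustrative sketch, not part of the commit; it only assumes that LlamaAttention remains importable on newer Transformers versions, as the commit message states.

    # Illustrative check (not part of this commit): reproduce the import fallback
    # and report which attention class "mha_flash_attention_2" would map to.
    import transformers

    try:
        # Transformers <= 4.46 still exports LlamaFlashAttention2
        from transformers.models.llama.modeling_llama import (
            LlamaAttention,
            LlamaFlashAttention2,
        )
    except ImportError:
        # Transformers 4.47+ no longer exports LlamaFlashAttention2
        from transformers.models.llama.modeling_llama import LlamaAttention
        LlamaFlashAttention2 = None

    resolved = LlamaFlashAttention2 if LlamaFlashAttention2 is not None else LlamaAttention
    print(f"transformers {transformers.__version__}: mha_flash_attention_2 -> {resolved.__name__}")

On an older install this prints LlamaFlashAttention2; on 4.47+ it prints LlamaAttention, matching the fallback wired into ATTENTION_CLASSES above.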