Merge branch 'main' into pr/45
Browse files- modeling_chatglm.py +2 -0
modeling_chatglm.py
CHANGED
|
@@ -970,6 +970,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 970 |
|
| 971 |
if attention_mask is None:
|
| 972 |
attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
|
|
|
|
|
|
|
| 973 |
|
| 974 |
for i, layer in enumerate(self.layers):
|
| 975 |
|
|
|
|
| 970 |
|
| 971 |
if attention_mask is None:
|
| 972 |
attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
|
| 973 |
+
else:
|
| 974 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
| 975 |
|
| 976 |
for i, layer in enumerate(self.layers):
|
| 977 |
|