Update modeling_chatglm.py for inputs_embeds
#45 opened by Xipotzzz
modeling_chatglm.py CHANGED  (+22 -11)
@@ -914,11 +914,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 use_cache = False

         if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
+            logger.warning("You passed both `inputs_embeds` and `input_ids`. Will use `inputs_embeds`")
+        if input_ids is not None:
             batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape[:2]
+            batch_size, seq_length = inputs_embeds.shape[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

@@ -972,9 +972,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):

         if attention_mask is None:
             attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
-
         else:
-            attention_mask = attention_mask.to(input_ids.device)
+            attention_mask = attention_mask.to(hidden_states.device)

         for i, layer in enumerate(self.layers):

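Taken together, the two ChatGLMModel hunks above are what make an embeddings-driven forward pass usable: batch_size and seq_length now fall back to inputs_embeds when input_ids is absent, and the attention mask is moved to the device of the already-computed hidden states rather than of input_ids. A minimal sketch of the call pattern this enables follows; the THUDM/chatglm-6b checkpoint name, the trust_remote_code loading path, and reaching the patched ChatGLMModel through model.transformer are illustrative assumptions, not part of this PR.

import torch
from transformers import AutoModel, AutoTokenizer

# Assumed checkpoint; any ChatGLM checkpoint shipping the patched modeling_chatglm.py behaves the same.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()  # GPU assumed

inputs = tokenizer("你好", return_tensors="pt").to(model.device)

# Build the embeddings outside the model, e.g. to splice in soft-prompt or image features later.
inputs_embeds = model.get_input_embeddings()(inputs["input_ids"])

# model.transformer is the underlying ChatGLMModel. Passing input_ids alongside inputs_embeds keeps
# position-id and mask construction working; the patched forward warns and uses inputs_embeds as the
# layer input while reading the shapes from input_ids.
with torch.no_grad():
    out = model.transformer(
        input_ids=inputs["input_ids"],
        inputs_embeds=inputs_embeds,
        return_dict=True,
    )
print(out.last_hidden_state.shape)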
@@ -1105,6 +1104,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     def prepare_inputs_for_generation(
             self,
             input_ids: torch.LongTensor,
+            inputs_embeds: Optional[torch.Tensor] = None,
             past: Optional[torch.Tensor] = None,
             past_key_values: Optional[torch.Tensor] = None,
             attention_mask: Optional[torch.Tensor] = None,
@@ -1165,12 +1165,23 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 use_gmasks=use_gmasks
             )

-            return {
-                "input_ids": input_ids,
-                "past_key_values": past,
-                "position_ids": position_ids,
-                "attention_mask": attention_mask
-            }
+            if inputs_embeds is not None:
+                assert input_ids.size(1) == inputs_embeds.size(
+                    1
+                ), f"Make sure that both input_ids ({input_ids.size(1)}) and inputs_embeds ({inputs_embeds.size(1)}) have the same length."
+                return {
+                    "inputs_embeds": inputs_embeds,
+                    "past_key_values": past,
+                    "position_ids": position_ids,
+                    "attention_mask": attention_mask,
+                }
+            else:
+                return {
+                    "input_ids": input_ids,
+                    "past_key_values": past,
+                    "position_ids": position_ids,
+                    "attention_mask": attention_mask,
+                }

     def forward(
             self,
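The prepare_inputs_for_generation hunk is what lets those externally built embeddings reach the generation loop: when inputs_embeds is supplied and matches input_ids in length (the new assert), the prepared inputs carry inputs_embeds and drop input_ids; otherwise the original input_ids dict is returned unchanged. Continuing the sketch above, a rough way to see that routing is to call the helper directly; this is for illustration only and assumes the tokenizer has already appended the [gMASK]/BOS markers the helper searches for.

# Reuses model, inputs and inputs_embeds from the previous sketch.
model_inputs = model.prepare_inputs_for_generation(
    inputs["input_ids"],
    inputs_embeds=inputs_embeds,
)

# With inputs_embeds given, the prepared dict routes around input_ids entirely.
print(sorted(model_inputs.keys()))
# expected: ['attention_mask', 'inputs_embeds', 'past_key_values', 'position_ids']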