Spaces:
Running
Running
handle dtype for embeddings
Browse files
dalle_mini/modeling_bart_flax.py
CHANGED
|
@@ -461,8 +461,10 @@ class FlaxBartEncoder(nn.Module):
|
|
| 461 |
input_ids = input_ids.reshape(-1, input_shape[-1])
|
| 462 |
|
| 463 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
|
|
|
| 464 |
|
| 465 |
embed_pos = self.embed_positions(position_ids + self.offset)
|
|
|
|
| 466 |
|
| 467 |
hidden_states = inputs_embeds + embed_pos
|
| 468 |
hidden_states = self.layernorm_embedding(hidden_states)
|
|
@@ -521,9 +523,11 @@ class FlaxBartDecoder(nn.Module):
|
|
| 521 |
input_ids = input_ids.reshape(-1, input_shape[-1])
|
| 522 |
|
| 523 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
|
|
|
| 524 |
|
| 525 |
# embed positions
|
| 526 |
positions = self.embed_positions(position_ids + self.offset)
|
|
|
|
| 527 |
|
| 528 |
hidden_states = inputs_embeds + positions
|
| 529 |
hidden_states = self.layernorm_embedding(hidden_states)
|
|
|
|
| 461 |
input_ids = input_ids.reshape(-1, input_shape[-1])
|
| 462 |
|
| 463 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
| 464 |
+
inputs_embeds = inputs_embeds.astype(self.dtype)
|
| 465 |
|
| 466 |
embed_pos = self.embed_positions(position_ids + self.offset)
|
| 467 |
+
embed_pos = embed_pos.astype(self.dtype)
|
| 468 |
|
| 469 |
hidden_states = inputs_embeds + embed_pos
|
| 470 |
hidden_states = self.layernorm_embedding(hidden_states)
|
|
|
|
| 523 |
input_ids = input_ids.reshape(-1, input_shape[-1])
|
| 524 |
|
| 525 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
| 526 |
+
inputs_embeds = inputs_embeds.astype(self.dtype)
|
| 527 |
|
| 528 |
# embed positions
|
| 529 |
positions = self.embed_positions(position_ids + self.offset)
|
| 530 |
+
positions = positions.astype(self.dtype)
|
| 531 |
|
| 532 |
hidden_states = inputs_embeds + positions
|
| 533 |
hidden_states = self.layernorm_embedding(hidden_states)
|