Spaces:
Running
Running
feat: allow relative position (#156)
Browse files
src/dalle_mini/model/configuration.py
CHANGED
|
@@ -64,12 +64,14 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
|
|
| 64 |
use_head_scale=False, # used in NormFormer
|
| 65 |
use_cosine_attention=False, # used in Swin v2
|
| 66 |
tau_init=0.05, # used only in cosine attention (Swin v2)
|
|
|
|
|
|
|
| 67 |
use_deepnet_scaling=False, # used in Deepnet
|
| 68 |
use_glu=False, # "GLU Variants Improve Transformer"
|
| 69 |
use_alibi=False, # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
|
| 70 |
sinkhorn_iters=1, # used in SinkFormers
|
| 71 |
-
use_final_ln_encoder=
|
| 72 |
-
use_final_ln_decoder=
|
| 73 |
# parameters that should not be necessary but could affect results
|
| 74 |
force_ln_scale=False, # force scale in layernorm even when followed by dense layers
|
| 75 |
**kwargs,
|
|
@@ -98,6 +100,8 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
|
|
| 98 |
self.ln_positions = ln_positions
|
| 99 |
self.use_cosine_attention = use_cosine_attention
|
| 100 |
self.tau_init = tau_init
|
|
|
|
|
|
|
| 101 |
self.use_deepnet_scaling = use_deepnet_scaling
|
| 102 |
self.use_glu = use_glu
|
| 103 |
self.use_alibi = use_alibi
|
|
|
|
| 64 |
use_head_scale=False, # used in NormFormer
|
| 65 |
use_cosine_attention=False, # used in Swin v2
|
| 66 |
tau_init=0.05, # used only in cosine attention (Swin v2)
|
| 67 |
+
use_absolute_position_embeddings=True, # default
|
| 68 |
+
use_swin_position_embeddings=False, # used in Swin v1/v2
|
| 69 |
use_deepnet_scaling=False, # used in Deepnet
|
| 70 |
use_glu=False, # "GLU Variants Improve Transformer"
|
| 71 |
use_alibi=False, # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
|
| 72 |
sinkhorn_iters=1, # used in SinkFormers
|
| 73 |
+
use_final_ln_encoder=True, # final layer normalization in encoder
|
| 74 |
+
use_final_ln_decoder=True, # final layer normalization in decoder
|
| 75 |
# parameters that should not be necessary but could affect results
|
| 76 |
force_ln_scale=False, # force scale in layernorm even when followed by dense layers
|
| 77 |
**kwargs,
|
|
|
|
| 100 |
self.ln_positions = ln_positions
|
| 101 |
self.use_cosine_attention = use_cosine_attention
|
| 102 |
self.tau_init = tau_init
|
| 103 |
+
self.use_absolute_position_embeddings = use_absolute_position_embeddings
|
| 104 |
+
self.use_swin_position_embeddings = use_swin_position_embeddings
|
| 105 |
self.use_deepnet_scaling = use_deepnet_scaling
|
| 106 |
self.use_glu = use_glu
|
| 107 |
self.use_alibi = use_alibi
|
src/dalle_mini/model/modeling.py
CHANGED
|
@@ -25,6 +25,7 @@ import flax.linen as nn
|
|
| 25 |
import jax
|
| 26 |
import jax.numpy as jnp
|
| 27 |
import msgpack.exceptions
|
|
|
|
| 28 |
from flax.core.frozen_dict import unfreeze
|
| 29 |
from flax.linen import combine_masks, make_causal_mask
|
| 30 |
from flax.linen import partitioning as nn_partitioning
|
|
@@ -52,8 +53,6 @@ from transformers.modeling_flax_outputs import (
|
|
| 52 |
from transformers.modeling_flax_utils import ACT2FN
|
| 53 |
from transformers.models.bart.modeling_flax_bart import (
|
| 54 |
FlaxBartAttention,
|
| 55 |
-
FlaxBartDecoder,
|
| 56 |
-
FlaxBartEncoder,
|
| 57 |
FlaxBartForConditionalGeneration,
|
| 58 |
FlaxBartForConditionalGenerationModule,
|
| 59 |
FlaxBartModule,
|
|
@@ -180,6 +179,7 @@ def dot_product_attention_weights(
|
|
| 180 |
key: Any,
|
| 181 |
bias: Optional[Any] = None,
|
| 182 |
mask: Optional[Any] = None,
|
|
|
|
| 183 |
broadcast_dropout: bool = True,
|
| 184 |
dropout_rng: Optional[PRNGKey] = None,
|
| 185 |
dropout_rate: float = 0.0,
|
|
@@ -210,6 +210,10 @@ def dot_product_attention_weights(
|
|
| 210 |
if bias is not None:
|
| 211 |
attn_weights = attn_weights + bias
|
| 212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
# normalize the attention weights
|
| 214 |
if causal or sinkhorn_iters == 1:
|
| 215 |
# sinkhorn does not work for causal (leaks info of future tokens into past)
|
|
@@ -251,6 +255,8 @@ class FlaxBartAttention(FlaxBartAttention):
|
|
| 251 |
"""
|
| 252 |
|
| 253 |
is_encoder: bool = False
|
|
|
|
|
|
|
| 254 |
|
| 255 |
def setup(self) -> None:
|
| 256 |
self.head_dim = self.embed_dim // self.num_heads
|
|
@@ -305,6 +311,15 @@ class FlaxBartAttention(FlaxBartAttention):
|
|
| 305 |
(1, self.num_heads, 1, 1),
|
| 306 |
)
|
| 307 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
if self.causal:
|
| 309 |
# used only in decoder
|
| 310 |
self.causal_mask = make_causal_mask(
|
|
@@ -400,11 +415,21 @@ class FlaxBartAttention(FlaxBartAttention):
|
|
| 400 |
key_states = key_states / (
|
| 401 |
jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8
|
| 402 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
attn_weights = dot_product_attention_weights(
|
| 404 |
query_states,
|
| 405 |
key_states,
|
| 406 |
bias=attention_bias,
|
| 407 |
mask=attention_mask,
|
|
|
|
| 408 |
dropout_rng=dropout_rng,
|
| 409 |
dropout_rate=self.dropout,
|
| 410 |
broadcast_dropout=True,
|
|
@@ -593,6 +618,8 @@ class FlaxBartEncoderLayer(nn.Module):
|
|
| 593 |
bias=self.config.use_bias,
|
| 594 |
dtype=self.dtype,
|
| 595 |
is_encoder=True,
|
|
|
|
|
|
|
| 596 |
)(hidden_states=hidden_states, attention_mask=attention_mask)
|
| 597 |
|
| 598 |
if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
|
|
@@ -699,6 +726,8 @@ class FlaxBartDecoderLayer(nn.Module):
|
|
| 699 |
bias=self.config.use_bias,
|
| 700 |
dtype=self.dtype,
|
| 701 |
is_encoder=False,
|
|
|
|
|
|
|
| 702 |
)(
|
| 703 |
hidden_states=hidden_states,
|
| 704 |
attention_mask=attention_mask,
|
|
@@ -737,6 +766,8 @@ class FlaxBartDecoderLayer(nn.Module):
|
|
| 737 |
bias=self.config.use_bias,
|
| 738 |
dtype=self.dtype,
|
| 739 |
is_encoder=False,
|
|
|
|
|
|
|
| 740 |
)(
|
| 741 |
hidden_states=hidden_states,
|
| 742 |
key_value_states=encoder_hidden_states,
|
|
@@ -953,7 +984,10 @@ class FlaxBartDecoderLayerCollection(nn.Module):
|
|
| 953 |
)
|
| 954 |
|
| 955 |
|
| 956 |
-
class FlaxBartEncoder(
|
|
|
|
|
|
|
|
|
|
| 957 |
"""
|
| 958 |
Edits:
|
| 959 |
- offset set to 0 (no padding token)
|
|
@@ -972,18 +1006,62 @@ class FlaxBartEncoder(FlaxBartEncoder):
|
|
| 972 |
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
| 973 |
# and adjust num_embeddings appropriately. Other models don't have this hack
|
| 974 |
self.offset = 0
|
| 975 |
-
self.
|
| 976 |
-
self.
|
| 977 |
-
|
| 978 |
-
|
| 979 |
-
|
|
|
|
| 980 |
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
|
| 981 |
self.layernorm_embedding = norm(
|
| 982 |
self.config.ln_type, dtype=self.dtype, epsilon=1e-05
|
| 983 |
)
|
| 984 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 985 |
|
| 986 |
-
class FlaxBartDecoder(
|
|
|
|
|
|
|
|
|
|
| 987 |
"""
|
| 988 |
Edits:
|
| 989 |
- offset set to 0 (no padding token)
|
|
@@ -1004,17 +1082,65 @@ class FlaxBartDecoder(FlaxBartDecoder):
|
|
| 1004 |
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
| 1005 |
# and adjust num_embeddings appropriately. Other models don't have this hack
|
| 1006 |
self.offset = 0
|
| 1007 |
-
self.
|
| 1008 |
-
self.
|
| 1009 |
-
|
| 1010 |
-
|
| 1011 |
-
|
|
|
|
| 1012 |
|
| 1013 |
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
|
| 1014 |
self.layernorm_embedding = norm(
|
| 1015 |
self.config.ln_type, dtype=self.dtype, epsilon=1e-05
|
| 1016 |
)
|
| 1017 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1018 |
|
| 1019 |
class FlaxBartModule(FlaxBartModule):
|
| 1020 |
"""
|
|
|
|
| 25 |
import jax
|
| 26 |
import jax.numpy as jnp
|
| 27 |
import msgpack.exceptions
|
| 28 |
+
from einops import rearrange
|
| 29 |
from flax.core.frozen_dict import unfreeze
|
| 30 |
from flax.linen import combine_masks, make_causal_mask
|
| 31 |
from flax.linen import partitioning as nn_partitioning
|
|
|
|
| 53 |
from transformers.modeling_flax_utils import ACT2FN
|
| 54 |
from transformers.models.bart.modeling_flax_bart import (
|
| 55 |
FlaxBartAttention,
|
|
|
|
|
|
|
| 56 |
FlaxBartForConditionalGeneration,
|
| 57 |
FlaxBartForConditionalGenerationModule,
|
| 58 |
FlaxBartModule,
|
|
|
|
| 179 |
key: Any,
|
| 180 |
bias: Optional[Any] = None,
|
| 181 |
mask: Optional[Any] = None,
|
| 182 |
+
embed_pos: Optional[Any] = None,
|
| 183 |
broadcast_dropout: bool = True,
|
| 184 |
dropout_rng: Optional[PRNGKey] = None,
|
| 185 |
dropout_rate: float = 0.0,
|
|
|
|
| 210 |
if bias is not None:
|
| 211 |
attn_weights = attn_weights + bias
|
| 212 |
|
| 213 |
+
# add relative position
|
| 214 |
+
if embed_pos is not None:
|
| 215 |
+
attn_weights = attn_weights + embed_pos
|
| 216 |
+
|
| 217 |
# normalize the attention weights
|
| 218 |
if causal or sinkhorn_iters == 1:
|
| 219 |
# sinkhorn does not work for causal (leaks info of future tokens into past)
|
|
|
|
| 255 |
"""
|
| 256 |
|
| 257 |
is_encoder: bool = False
|
| 258 |
+
q_length: int = None
|
| 259 |
+
k_length: int = None
|
| 260 |
|
| 261 |
def setup(self) -> None:
|
| 262 |
self.head_dim = self.embed_dim // self.num_heads
|
|
|
|
| 311 |
(1, self.num_heads, 1, 1),
|
| 312 |
)
|
| 313 |
|
| 314 |
+
if self.config.use_swin_position_embeddings:
|
| 315 |
+
self.rel_bias = nn.Embed(
|
| 316 |
+
self.q_length,
|
| 317 |
+
self.k_length * self.num_heads,
|
| 318 |
+
embedding_init=deepnet_init()
|
| 319 |
+
if self.config.use_deepnet_scaling
|
| 320 |
+
else jax.nn.initializers.normal(self.config.init_std),
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
if self.causal:
|
| 324 |
# used only in decoder
|
| 325 |
self.causal_mask = make_causal_mask(
|
|
|
|
| 415 |
key_states = key_states / (
|
| 416 |
jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8
|
| 417 |
)
|
| 418 |
+
|
| 419 |
+
# relative position embeddings
|
| 420 |
+
if self.config.use_swin_position_embeddings:
|
| 421 |
+
position_ids = jnp.arange(self.q_length)
|
| 422 |
+
embed_pos = self.rel_bias(position_ids)
|
| 423 |
+
embed_pos = rearrange(embed_pos, "q (k h) -> 1 h q k", h=self.num_heads)
|
| 424 |
+
else:
|
| 425 |
+
embed_pos = None
|
| 426 |
+
|
| 427 |
attn_weights = dot_product_attention_weights(
|
| 428 |
query_states,
|
| 429 |
key_states,
|
| 430 |
bias=attention_bias,
|
| 431 |
mask=attention_mask,
|
| 432 |
+
embed_pos=embed_pos,
|
| 433 |
dropout_rng=dropout_rng,
|
| 434 |
dropout_rate=self.dropout,
|
| 435 |
broadcast_dropout=True,
|
|
|
|
| 618 |
bias=self.config.use_bias,
|
| 619 |
dtype=self.dtype,
|
| 620 |
is_encoder=True,
|
| 621 |
+
q_length=self.config.max_text_length,
|
| 622 |
+
k_length=self.config.max_text_length,
|
| 623 |
)(hidden_states=hidden_states, attention_mask=attention_mask)
|
| 624 |
|
| 625 |
if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
|
|
|
|
| 726 |
bias=self.config.use_bias,
|
| 727 |
dtype=self.dtype,
|
| 728 |
is_encoder=False,
|
| 729 |
+
q_length=self.config.image_length,
|
| 730 |
+
k_length=self.config.image_length,
|
| 731 |
)(
|
| 732 |
hidden_states=hidden_states,
|
| 733 |
attention_mask=attention_mask,
|
|
|
|
| 766 |
bias=self.config.use_bias,
|
| 767 |
dtype=self.dtype,
|
| 768 |
is_encoder=False,
|
| 769 |
+
q_length=self.config.image_length,
|
| 770 |
+
k_length=self.config.max_text_length,
|
| 771 |
)(
|
| 772 |
hidden_states=hidden_states,
|
| 773 |
key_value_states=encoder_hidden_states,
|
|
|
|
| 984 |
)
|
| 985 |
|
| 986 |
|
| 987 |
+
class FlaxBartEncoder(nn.Module):
|
| 988 |
+
config: DalleBartConfig
|
| 989 |
+
embed_tokens: nn.Embed
|
| 990 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 991 |
"""
|
| 992 |
Edits:
|
| 993 |
- offset set to 0 (no padding token)
|
|
|
|
| 1006 |
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
| 1007 |
# and adjust num_embeddings appropriately. Other models don't have this hack
|
| 1008 |
self.offset = 0
|
| 1009 |
+
if self.config.use_absolute_position_embeddings:
|
| 1010 |
+
self.embed_positions = nn.Embed(
|
| 1011 |
+
self.config.max_text_length + self.offset, # image length for BOS
|
| 1012 |
+
embed_dim,
|
| 1013 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
| 1014 |
+
)
|
| 1015 |
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
|
| 1016 |
self.layernorm_embedding = norm(
|
| 1017 |
self.config.ln_type, dtype=self.dtype, epsilon=1e-05
|
| 1018 |
)
|
| 1019 |
|
| 1020 |
+
def __call__(
|
| 1021 |
+
self,
|
| 1022 |
+
input_ids,
|
| 1023 |
+
attention_mask,
|
| 1024 |
+
position_ids,
|
| 1025 |
+
output_attentions: bool = False,
|
| 1026 |
+
output_hidden_states: bool = False,
|
| 1027 |
+
return_dict: bool = True,
|
| 1028 |
+
deterministic: bool = True,
|
| 1029 |
+
):
|
| 1030 |
+
input_shape = input_ids.shape
|
| 1031 |
+
input_ids = input_ids.reshape(-1, input_shape[-1])
|
| 1032 |
+
|
| 1033 |
+
hidden_states = self.embed_tokens(input_ids) * self.embed_scale
|
| 1034 |
+
|
| 1035 |
+
if self.config.use_absolute_position_embeddings:
|
| 1036 |
+
embed_pos = self.embed_positions(position_ids + self.offset)
|
| 1037 |
+
hidden_states = hidden_states + embed_pos
|
| 1038 |
+
|
| 1039 |
+
hidden_states = self.layernorm_embedding(hidden_states)
|
| 1040 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
| 1041 |
+
|
| 1042 |
+
outputs = self.layers(
|
| 1043 |
+
hidden_states,
|
| 1044 |
+
attention_mask,
|
| 1045 |
+
deterministic=deterministic,
|
| 1046 |
+
output_attentions=output_attentions,
|
| 1047 |
+
output_hidden_states=output_hidden_states,
|
| 1048 |
+
return_dict=return_dict,
|
| 1049 |
+
)
|
| 1050 |
+
|
| 1051 |
+
if not return_dict:
|
| 1052 |
+
return outputs
|
| 1053 |
+
|
| 1054 |
+
return FlaxBaseModelOutput(
|
| 1055 |
+
last_hidden_state=outputs.last_hidden_state,
|
| 1056 |
+
hidden_states=outputs.hidden_states,
|
| 1057 |
+
attentions=outputs.attentions,
|
| 1058 |
+
)
|
| 1059 |
+
|
| 1060 |
|
| 1061 |
+
class FlaxBartDecoder(nn.Module):
|
| 1062 |
+
config: DalleBartConfig
|
| 1063 |
+
embed_tokens: nn.Embed
|
| 1064 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 1065 |
"""
|
| 1066 |
Edits:
|
| 1067 |
- offset set to 0 (no padding token)
|
|
|
|
| 1082 |
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
| 1083 |
# and adjust num_embeddings appropriately. Other models don't have this hack
|
| 1084 |
self.offset = 0
|
| 1085 |
+
if self.config.use_absolute_position_embeddings:
|
| 1086 |
+
self.embed_positions = nn.Embed(
|
| 1087 |
+
self.config.image_length + self.offset, # image length for BOS
|
| 1088 |
+
embed_dim,
|
| 1089 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
| 1090 |
+
)
|
| 1091 |
|
| 1092 |
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
|
| 1093 |
self.layernorm_embedding = norm(
|
| 1094 |
self.config.ln_type, dtype=self.dtype, epsilon=1e-05
|
| 1095 |
)
|
| 1096 |
|
| 1097 |
+
def __call__(
|
| 1098 |
+
self,
|
| 1099 |
+
input_ids,
|
| 1100 |
+
attention_mask,
|
| 1101 |
+
position_ids,
|
| 1102 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
| 1103 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
| 1104 |
+
init_cache: bool = False,
|
| 1105 |
+
output_attentions: bool = False,
|
| 1106 |
+
output_hidden_states: bool = False,
|
| 1107 |
+
return_dict: bool = True,
|
| 1108 |
+
deterministic: bool = True,
|
| 1109 |
+
):
|
| 1110 |
+
input_shape = input_ids.shape
|
| 1111 |
+
input_ids = input_ids.reshape(-1, input_shape[-1])
|
| 1112 |
+
|
| 1113 |
+
hidden_states = self.embed_tokens(input_ids) * self.embed_scale
|
| 1114 |
+
|
| 1115 |
+
if self.config.use_absolute_position_embeddings:
|
| 1116 |
+
embed_pos = self.embed_positions(position_ids + self.offset)
|
| 1117 |
+
hidden_states = hidden_states + embed_pos
|
| 1118 |
+
|
| 1119 |
+
hidden_states = self.layernorm_embedding(hidden_states)
|
| 1120 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
| 1121 |
+
|
| 1122 |
+
outputs = self.layers(
|
| 1123 |
+
hidden_states,
|
| 1124 |
+
attention_mask,
|
| 1125 |
+
encoder_hidden_states,
|
| 1126 |
+
encoder_attention_mask,
|
| 1127 |
+
deterministic=deterministic,
|
| 1128 |
+
init_cache=init_cache,
|
| 1129 |
+
output_attentions=output_attentions,
|
| 1130 |
+
output_hidden_states=output_hidden_states,
|
| 1131 |
+
return_dict=return_dict,
|
| 1132 |
+
)
|
| 1133 |
+
|
| 1134 |
+
if not return_dict:
|
| 1135 |
+
return outputs
|
| 1136 |
+
|
| 1137 |
+
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
| 1138 |
+
last_hidden_state=outputs.last_hidden_state,
|
| 1139 |
+
hidden_states=outputs.hidden_states,
|
| 1140 |
+
attentions=outputs.attentions,
|
| 1141 |
+
cross_attentions=outputs.cross_attentions,
|
| 1142 |
+
)
|
| 1143 |
+
|
| 1144 |
|
| 1145 |
class FlaxBartModule(FlaxBartModule):
|
| 1146 |
"""
|
src/dalle_mini/model/partitions.py
CHANGED
|
@@ -38,6 +38,7 @@ def _get_partition_rules():
|
|
| 38 |
# embeddings
|
| 39 |
(("embed_positions", "embedding"), P("mp", None)),
|
| 40 |
(("embed_tokens", "embedding"), P("mp", None)),
|
|
|
|
| 41 |
# attention
|
| 42 |
(("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
|
| 43 |
(("out_proj", "kernel"), P("mp", None)),
|
|
|
|
| 38 |
# embeddings
|
| 39 |
(("embed_positions", "embedding"), P("mp", None)),
|
| 40 |
(("embed_tokens", "embedding"), P("mp", None)),
|
| 41 |
+
(("rel_bias", "embedding"), P(None, "mp")),
|
| 42 |
# attention
|
| 43 |
(("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
|
| 44 |
(("out_proj", "kernel"), P("mp", None)),
|