feat: model config not hardcoded
Former-commit-id: 8cc773f8dfaee95469a926d907c006873922e1c6
seq2seq/run_seq2seq_flax.py  CHANGED  +12 -5
@@ -271,6 +271,10 @@ class TrainState(train_state.TrainState):
 
 class CustomFlaxBartModule(FlaxBartModule):
     def setup(self):
+        # check config is valid, otherwise set default values
+        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
+        self.config.max_position_embeddings_decoder = getattr(self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)
+
         # we keep shared to easily load pre-trained weights
         self.shared = nn.Embed(
             self.config.vocab_size,
@@ -280,7 +284,7 @@ class CustomFlaxBartModule(FlaxBartModule):
         )
         # a separate embedding is used for the decoder
         self.decoder_embed = nn.Embed(
-            OUTPUT_VOCAB_SIZE,
+            self.config.vocab_size_output,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
             dtype=self.dtype,
@@ -289,20 +293,23 @@ class CustomFlaxBartModule(FlaxBartModule):
 
         # the decoder has a different config
         decoder_config = BartConfig(self.config.to_dict())
-        decoder_config.max_position_embeddings = OUTPUT_LENGTH
-        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
+        decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
+        decoder_config.vocab_size = self.config.vocab_size_output
         self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
 
 class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
     def setup(self):
+        # check config is valid, otherwise set default values
+        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
+
         self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
         self.lm_head = nn.Dense(
-            OUTPUT_VOCAB_SIZE,
+            self.config.vocab_size_output,
             use_bias=False,
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
         )
-        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
+        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.vocab_size_output))
 
 class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
     module_class = CustomFlaxBartForConditionalGenerationModule
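For reference, the pattern this commit introduces reads optional fields from the checkpoint's BartConfig and falls back to the script's module-level defaults (OUTPUT_VOCAB_SIZE, OUTPUT_LENGTH) when a pre-trained config does not define them. The sketch below reproduces that fallback logic outside the training script; the default values are illustrative stand-ins rather than the script's actual constants, and make_decoder_config is a hypothetical helper, not part of the script.

from transformers import BartConfig

# Illustrative stand-ins for the module-level defaults in
# run_seq2seq_flax.py; the script's real values may differ.
OUTPUT_VOCAB_SIZE = 16385
OUTPUT_LENGTH = 257

def make_decoder_config(config: BartConfig) -> BartConfig:
    # Hypothetical helper mirroring the commit's fallback pattern:
    # use the custom fields when the loaded config defines them,
    # otherwise fall back to the module-level defaults.
    vocab_size_output = getattr(config, "vocab_size_output", OUTPUT_VOCAB_SIZE)
    max_position_embeddings_decoder = getattr(
        config, "max_position_embeddings_decoder", OUTPUT_LENGTH
    )
    # Copy the encoder config, then override the decoder-specific fields.
    decoder_config = BartConfig(**config.to_dict())
    decoder_config.vocab_size = vocab_size_output
    decoder_config.max_position_embeddings = max_position_embeddings_decoder
    return decoder_config

# A plain pre-trained config without the custom fields: defaults apply.
base = BartConfig()
print(make_decoder_config(base).vocab_size)               # 16385
print(make_decoder_config(base).max_position_embeddings)  # 257

# A config that defines the fields: its values win over the defaults.
base.vocab_size_output = 1024
base.max_position_embeddings_decoder = 64
print(make_decoder_config(base).vocab_size)               # 1024
print(make_decoder_config(base).max_position_embeddings)  # 64

Since the commit writes the resolved values back onto self.config in setup, a config serialized after fine-tuning should carry vocab_size_output and max_position_embeddings_decoder explicitly, so reloading such a checkpoint no longer depends on the script's hardcoded constants.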