Spaces:
Running
Running
feat: avoid OOM
Browse files
dev/seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -475,6 +475,8 @@ def main():
|
|
| 475 |
|
| 476 |
# load model
|
| 477 |
model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
|
|
|
|
|
|
|
| 478 |
|
| 479 |
# load tokenizer
|
| 480 |
tokenizer = AutoTokenizer.from_pretrained(
|
|
@@ -529,7 +531,10 @@ def main():
|
|
| 529 |
config=config,
|
| 530 |
seed=training_args.seed_model,
|
| 531 |
dtype=getattr(jnp, model_args.dtype),
|
|
|
|
| 532 |
)
|
|
|
|
|
|
|
| 533 |
else:
|
| 534 |
model = CustomFlaxBartForConditionalGeneration(
|
| 535 |
config,
|
|
|
|
| 475 |
|
| 476 |
# load model
|
| 477 |
model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
|
| 478 |
+
# workaround to avoid OOM on TPU: the print of the params below is intentional, not debug output; see https://github.com/google/flax/issues/1658
|
| 479 |
+
print(model.params)
|
| 480 |
|
| 481 |
# load tokenizer
|
| 482 |
tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
| 531 |
config=config,
|
| 532 |
seed=training_args.seed_model,
|
| 533 |
dtype=getattr(jnp, model_args.dtype),
|
| 534 |
+
ignore_mismatched_sizes=True,
|
| 535 |
)
|
| 536 |
+
# workaround to avoid OOM on TPU: the print of the params below is intentional, not debug output; see https://github.com/google/flax/issues/1658
|
| 537 |
+
print(model.params)
|
| 538 |
else:
|
| 539 |
model = CustomFlaxBartForConditionalGeneration(
|
| 540 |
config,
|