fix(seq2seq): opt_state from ckpt + limit cache
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -20,10 +20,6 @@ Script adapted from run_summarization_flax.py
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
 import os
-# set a common huggingface cache folder (used with datasets and transformers) and wandb cache folder (used with artifacts)
-os.environ['HF_HOME'] = '/data/huggingface/'  # required before importing transformers & datasets
-os.environ['WANDB_CACHE_DIR'] = '/data/wandb/'  # required before importing wandb
-
 import logging as pylogging  # To avoid collision with transformers.utils.logging
 import sys
 import time

@@ -442,6 +438,7 @@ def main():
     if (Path(artifact_dir) / 'opt_state.msgpack').exists():
         with (Path(artifact_dir) / 'opt_state.msgpack').open('rb') as f:
             opt_state = from_bytes(state.opt_state, f.read())
+        state.replace(opt_state=opt_state)
 
     # restore steps
     if (Path(artifact_dir) / 'training_state.json').exists():
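Note on the opt_state restore above: flax.serialization.from_bytes deserializes the msgpack payload using the existing state.opt_state as a pytree template, and the train state is an immutable flax struct, so replace() returns a new object rather than mutating in place. A minimal, self-contained sketch of that pattern (the helper name restore_opt_state is illustrative, not part of the script):

from pathlib import Path
from flax.serialization import from_bytes

def restore_opt_state(state, artifact_dir):
    # Illustrative helper: load opt_state.msgpack and return a state carrying it.
    ckpt = Path(artifact_dir) / 'opt_state.msgpack'
    if ckpt.exists():
        with ckpt.open('rb') as f:
            # from_bytes uses state.opt_state as the template for the deserialized pytree
            opt_state = from_bytes(state.opt_state, f.read())
        # flax structs are immutable: replace() returns a new state instead of modifying the old one
        state = state.replace(opt_state=opt_state)
    return state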
@@ -836,6 +833,10 @@ def main():
     artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
     wandb.run.log_artifact(artifact)
 
+    # save some space
+    c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
+    c.cleanup(wandb.util.from_human_size("15GB"))
+
     # save to the hub
     if training_args.push_to_hub:
         model.save_pretrained(
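Note on the cache cleanup above: get_artifacts_cache() and from_human_size() come from wandb internals (wandb.wandb_sdk.wandb_artifacts and wandb.util), not the stable public API, and may move between wandb releases. A hedged sketch of the same idea that degrades gracefully if the internals differ (the helper name trim_artifact_cache is illustrative):

import wandb

def trim_artifact_cache(limit='15GB'):
    # Illustrative helper, assuming a wandb release that exposes these internals (as in the diff above).
    try:
        cache = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
        # free local artifact cache space down to roughly the requested limit
        cache.cleanup(wandb.util.from_human_size(limit))
    except AttributeError:
        # different wandb version: internals not available, skip trimming
        pass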
@@ -866,7 +867,7 @@ def main():
     # log metrics
     wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
 
-    if global_step % training_args.eval_steps == 0:
+    if training_args.eval_steps and global_step % training_args.eval_steps == 0:
         run_evaluation()
 
     if global_step % data_args.save_model_steps == 0: