fix: accumulation vs lr
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -673,12 +673,12 @@ def main():
             grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
             grads = jax.lax.pmean(grads, "batch")
             new_state = state.apply_gradients(
-                grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step
+                grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step + 1
             )
             return new_state
 
         new_state = jax.lax.cond(
-            state.step % training_args.gradient_accumulation_steps == 0,
+            (state.step + 1) % training_args.gradient_accumulation_steps == 0,
             lambda _: update_fn(),
             lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
             None,
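
The two edits tie the optimizer step count to applied updates: optimizer_step now advances by one whenever the accumulated gradients are actually applied (so a learning-rate schedule keyed on it can move with real updates instead of staying flat), and the update condition becomes (state.step + 1) % gradient_accumulation_steps == 0, so the first update waits for a full accumulation window rather than firing at step 0. Below is a minimal, self-contained sketch of the same pattern in plain JAX/optax; the names accum_steps and lr_schedule, the toy loss, and the concrete values are illustrative assumptions, not the repository's code.

# Sketch (assumed names/values): with gradient accumulation, the learning
# rate should advance once per applied optimizer update, not per micro-batch.
import jax
import jax.numpy as jnp
import optax

accum_steps = 4  # assumed accumulation window
lr_schedule = optax.linear_schedule(init_value=1e-3, end_value=0.0, transition_steps=100)
optimizer = optax.sgd(learning_rate=lr_schedule)  # schedule ticks on each optimizer.update call

params = {"w": jnp.ones(3)}
opt_state = optimizer.init(params)
grad_accum = jax.tree_util.tree_map(jnp.zeros_like, params)

def loss_fn(p, x):
    return jnp.sum((p["w"] * x) ** 2)

for step in range(8):  # 8 micro-batches -> 2 applied updates
    x = jnp.full((3,), float(step))
    grads = jax.grad(loss_fn)(params, x)
    grad_accum = jax.tree_util.tree_map(lambda a, g: a + g, grad_accum, grads)
    if (step + 1) % accum_steps == 0:  # mirrors the fixed predicate
        mean_grads = jax.tree_util.tree_map(lambda a: a / accum_steps, grad_accum)
        updates, opt_state = optimizer.update(mean_grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        grad_accum = jax.tree_util.tree_map(jnp.zeros_like, params)  # reset the accumulator

Because the schedule's counter lives inside the optimizer state and only ticks when optimizer.update runs, the learning rate here decays once per applied update; the diff achieves the same effect by incrementing optimizer_step explicitly inside apply_gradients.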