Spaces:
Running
Running
feat: no need for default values
Browse files
dev/seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -280,9 +280,9 @@ class DataTrainingArguments:
|
|
| 280 |
|
| 281 |
|
| 282 |
class TrainState(train_state.TrainState):
|
| 283 |
-
dropout_rng: jnp.ndarray
|
| 284 |
-
grad_accum: jnp.ndarray
|
| 285 |
-
optimizer_step: int
|
| 286 |
|
| 287 |
def replicate(self):
|
| 288 |
return jax_utils.replicate(self).replace(
|
|
|
|
| 280 |
|
| 281 |
|
| 282 |
class TrainState(train_state.TrainState):
|
| 283 |
+
dropout_rng: jnp.ndarray
|
| 284 |
+
grad_accum: jnp.ndarray
|
| 285 |
+
optimizer_step: int
|
| 286 |
|
| 287 |
def replicate(self):
|
| 288 |
return jax_utils.replicate(self).replace(
|