Spaces:
Running
Running
style: lint
Browse files
tools/inference/inference_pipeline.ipynb
CHANGED
|
@@ -240,7 +240,7 @@
|
|
| 240 |
"import random\n",
|
| 241 |
"\n",
|
| 242 |
"# create a random key\n",
|
| 243 |
-
"seed = random.randint(0, 2
|
| 244 |
"key = jax.random.PRNGKey(seed)"
|
| 245 |
]
|
| 246 |
},
|
|
|
|
| 240 |
"import random\n",
|
| 241 |
"\n",
|
| 242 |
"# create a random key\n",
|
| 243 |
+
"seed = random.randint(0, 2**32 - 1)\n",
|
| 244 |
"key = jax.random.PRNGKey(seed)"
|
| 245 |
]
|
| 246 |
},
|
tools/train/train.py
CHANGED
|
@@ -58,7 +58,7 @@ from dalle_mini.model import (
|
|
| 58 |
)
|
| 59 |
|
| 60 |
cc.initialize_cache(
|
| 61 |
-
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2
|
| 62 |
)
|
| 63 |
|
| 64 |
logger = logging.getLogger(__name__)
|
|
|
|
| 58 |
)
|
| 59 |
|
| 60 |
cc.initialize_cache(
|
| 61 |
+
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2**30
|
| 62 |
)
|
| 63 |
|
| 64 |
logger = logging.getLogger(__name__)
|