Spaces:
Running
Running
fix: correct clip params
Browse files
tools/inference/log_inference_samples.ipynb
CHANGED
|
@@ -24,25 +24,6 @@
|
|
| 24 |
"from dalle_mini.text import TextNormalizer"
|
| 25 |
]
|
| 26 |
},
|
| 27 |
-
{
|
| 28 |
-
"cell_type": "code",
|
| 29 |
-
"execution_count": null,
|
| 30 |
-
"id": "23e00271-941c-4e1b-b6a9-107a1b77324d",
|
| 31 |
-
"metadata": {},
|
| 32 |
-
"outputs": [],
|
| 33 |
-
"source": [
|
| 34 |
-
"run_ids = ['3kaut6e8']\n",
|
| 35 |
-
"# Alamy - 3kaut6e8\n",
|
| 36 |
-
"# YFCC - to do\n",
|
| 37 |
-
"# HF spaces - 4oh3u7ca\n",
|
| 38 |
-
"ENTITY, PROJECT = 'wandb', 'hf-flax-dalle-mini'\n",
|
| 39 |
-
"VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
|
| 40 |
-
"normalize_text = False\n",
|
| 41 |
-
"latest_only = True # log only latest or all versions\n",
|
| 42 |
-
"suffix = '' # mainly for duplicate inference runs with a deleted version\n",
|
| 43 |
-
"add_clip_32 = False"
|
| 44 |
-
]
|
| 45 |
-
},
|
| 46 |
{
|
| 47 |
"cell_type": "code",
|
| 48 |
"execution_count": null,
|
|
@@ -50,13 +31,9 @@
|
|
| 50 |
"metadata": {},
|
| 51 |
"outputs": [],
|
| 52 |
"source": [
|
| 53 |
-
"run_ids = ['
|
| 54 |
-
"# poorly shuffled 1nj161cl\n",
|
| 55 |
-
"# well shuffled he9rrc3q\n",
|
| 56 |
-
"# non normalized 1fwxpyfh ! requires changing normalize_text\n",
|
| 57 |
"ENTITY, PROJECT = 'dalle-mini', 'dalle-mini' # used only for training run\n",
|
| 58 |
-
"VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384',
|
| 59 |
-
"normalize_text = True\n",
|
| 60 |
"latest_only = True # log only latest or all versions\n",
|
| 61 |
"suffix = '' # mainly for duplicate inference runs with a deleted version\n",
|
| 62 |
"add_clip_32 = False"
|
|
@@ -85,7 +62,7 @@
|
|
| 85 |
"batch_size = 8\n",
|
| 86 |
"num_images = 128\n",
|
| 87 |
"top_k = 8\n",
|
| 88 |
-
"text_normalizer = TextNormalizer()
|
| 89 |
"padding_item = 'NONE'\n",
|
| 90 |
"seed = random.randint(0, 2**32-1)\n",
|
| 91 |
"key = jax.random.PRNGKey(seed)\n",
|
|
@@ -230,7 +207,7 @@
|
|
| 230 |
"outputs": [],
|
| 231 |
"source": [
|
| 232 |
"run_id = run_ids[0]\n",
|
| 233 |
-
"# TODO:
|
| 234 |
]
|
| 235 |
},
|
| 236 |
{
|
|
@@ -287,7 +264,7 @@
|
|
| 287 |
"\n",
|
| 288 |
" # process one batch of captions\n",
|
| 289 |
" for batch in tqdm(samples):\n",
|
| 290 |
-
" processed_prompts = [text_normalizer(x) for x in batch] if normalize_text else list(batch)\n",
|
| 291 |
"\n",
|
| 292 |
" # repeat the prompts to distribute over each device and tokenize\n",
|
| 293 |
" processed_prompts = processed_prompts * jax.device_count()\n",
|
|
@@ -296,7 +273,7 @@
|
|
| 296 |
"\n",
|
| 297 |
" # generate images\n",
|
| 298 |
" images = []\n",
|
| 299 |
-
" pbar = tqdm(range(num_images // jax.device_count()), desc='Generating Images', leave=
|
| 300 |
" for i in pbar:\n",
|
| 301 |
" key, subkey = jax.random.split(key)\n",
|
| 302 |
" encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
|
|
@@ -312,7 +289,7 @@
|
|
| 312 |
" images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
|
| 313 |
" clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
|
| 314 |
" clip_inputs = shard(clip_inputs)\n",
|
| 315 |
-
" logits = p_clip(clip_inputs,
|
| 316 |
" logits = logits.reshape(-1, num_images)\n",
|
| 317 |
" top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
|
| 318 |
" logits = jax.device_get(logits)\n",
|
|
@@ -348,6 +325,14 @@
|
|
| 348 |
" wandb.finish()\n",
|
| 349 |
" run = None # ensure we don't log on this run"
|
| 350 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
}
|
| 352 |
],
|
| 353 |
"metadata": {
|
|
|
|
| 24 |
"from dalle_mini.text import TextNormalizer"
|
| 25 |
]
|
| 26 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
{
|
| 28 |
"cell_type": "code",
|
| 29 |
"execution_count": null,
|
|
|
|
| 31 |
"metadata": {},
|
| 32 |
"outputs": [],
|
| 33 |
"source": [
|
| 34 |
+
"run_ids = ['63otg87g']\n",
|
|
|
|
|
|
|
|
|
|
| 35 |
"ENTITY, PROJECT = 'dalle-mini', 'dalle-mini' # used only for training run\n",
|
| 36 |
+
"VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', 'e93a26e7707683d349bf5d5c41c5b0ef69b677a9'\n",
|
|
|
|
| 37 |
"latest_only = True # log only latest or all versions\n",
|
| 38 |
"suffix = '' # mainly for duplicate inference runs with a deleted version\n",
|
| 39 |
"add_clip_32 = False"
|
|
|
|
| 62 |
"batch_size = 8\n",
|
| 63 |
"num_images = 128\n",
|
| 64 |
"top_k = 8\n",
|
| 65 |
+
"text_normalizer = TextNormalizer()\n",
|
| 66 |
"padding_item = 'NONE'\n",
|
| 67 |
"seed = random.randint(0, 2**32-1)\n",
|
| 68 |
"key = jax.random.PRNGKey(seed)\n",
|
|
|
|
| 207 |
"outputs": [],
|
| 208 |
"source": [
|
| 209 |
"run_id = run_ids[0]\n",
|
| 210 |
+
"# TODO: loop over runs"
|
| 211 |
]
|
| 212 |
},
|
| 213 |
{
|
|
|
|
| 264 |
"\n",
|
| 265 |
" # process one batch of captions\n",
|
| 266 |
" for batch in tqdm(samples):\n",
|
| 267 |
+
" processed_prompts = [text_normalizer(x) for x in batch] if model.config.normalize_text else list(batch)\n",
|
| 268 |
"\n",
|
| 269 |
" # repeat the prompts to distribute over each device and tokenize\n",
|
| 270 |
" processed_prompts = processed_prompts * jax.device_count()\n",
|
|
|
|
| 273 |
"\n",
|
| 274 |
" # generate images\n",
|
| 275 |
" images = []\n",
|
| 276 |
+
" pbar = tqdm(range(num_images // jax.device_count()), desc='Generating Images', leave=True)\n",
|
| 277 |
" for i in pbar:\n",
|
| 278 |
" key, subkey = jax.random.split(key)\n",
|
| 279 |
" encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
|
|
|
|
| 289 |
" images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
|
| 290 |
" clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
|
| 291 |
" clip_inputs = shard(clip_inputs)\n",
|
| 292 |
+
" logits = p_clip(clip_inputs, clip_params)\n",
|
| 293 |
" logits = logits.reshape(-1, num_images)\n",
|
| 294 |
" top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
|
| 295 |
" logits = jax.device_get(logits)\n",
|
|
|
|
| 325 |
" wandb.finish()\n",
|
| 326 |
" run = None # ensure we don't log on this run"
|
| 327 |
]
|
| 328 |
+
},
|
| 329 |
+
{
|
| 330 |
+
"cell_type": "code",
|
| 331 |
+
"execution_count": null,
|
| 332 |
+
"id": "415d3f54-7226-43de-9eea-4283a948dc93",
|
| 333 |
+
"metadata": {},
|
| 334 |
+
"outputs": [],
|
| 335 |
+
"source": []
|
| 336 |
}
|
| 337 |
],
|
| 338 |
"metadata": {
|