Upload 288 files

This view is limited to 50 files because it contains too many changes.
- .gitattributes +4 -65
- .gitignore +2 -0
- README.md +16 -12
- app.py +274 -232
- app_img.py +414 -0
- configs/generation/slat_flow_img_dit_L_64l8p2_fp16.json +102 -0
- configs/generation/slat_flow_txt_dit_B_64l8p2_fp16.json +101 -0
- configs/generation/slat_flow_txt_dit_L_64l8p2_fp16.json +101 -0
- configs/generation/slat_flow_txt_dit_XL_64l8p2_fp16.json +101 -0
- configs/generation/ss_flow_img_dit_L_16l8_fp16.json +70 -0
- configs/generation/ss_flow_txt_dit_B_16l8_fp16.json +69 -0
- configs/generation/ss_flow_txt_dit_L_16l8_fp16.json +69 -0
- configs/generation/ss_flow_txt_dit_XL_16l8_fp16.json +70 -0
- configs/vae/slat_vae_dec_mesh_swin8_B_64l8_fp16.json +73 -0
- configs/vae/slat_vae_dec_rf_swin8_B_64l8_fp16.json +71 -0
- configs/vae/slat_vae_enc_dec_gs_swin8_B_64l8_fp16.json +105 -0
- configs/vae/ss_vae_conv3d_16l8_fp16.json +65 -0
- dataset_toolkits/blender_script/io_scene_usdz.zip +3 -0
- dataset_toolkits/blender_script/render.py +528 -0
- dataset_toolkits/build_metadata.py +270 -0
- dataset_toolkits/datasets/3D-FUTURE.py +97 -0
- dataset_toolkits/datasets/ABO.py +96 -0
- dataset_toolkits/datasets/HSSD.py +103 -0
- dataset_toolkits/datasets/ObjaverseXL.py +92 -0
- dataset_toolkits/datasets/Toys4k.py +92 -0
- dataset_toolkits/download.py +52 -0
- dataset_toolkits/encode_latent.py +127 -0
- dataset_toolkits/encode_ss_latent.py +128 -0
- dataset_toolkits/extract_feature.py +179 -0
- dataset_toolkits/render.py +121 -0
- dataset_toolkits/render_cond.py +125 -0
- dataset_toolkits/setup.sh +1 -0
- dataset_toolkits/stat_latent.py +66 -0
- dataset_toolkits/utils.py +43 -0
- dataset_toolkits/voxelize.py +86 -0
- env.py +10 -10
- extensions/vox2seq/benchmark.py +45 -0
- extensions/vox2seq/setup.py +34 -0
- extensions/vox2seq/src/api.cu +92 -0
- extensions/vox2seq/src/api.h +76 -0
- extensions/vox2seq/src/ext.cpp +10 -0
- extensions/vox2seq/src/hilbert.cu +133 -0
- extensions/vox2seq/src/hilbert.h +35 -0
- extensions/vox2seq/src/z_order.cu +66 -0
- extensions/vox2seq/src/z_order.h +35 -0
- extensions/vox2seq/test.py +25 -0
- extensions/vox2seq/vox2seq/__init__.py +50 -0
- extensions/vox2seq/vox2seq/pytorch/__init__.py +48 -0
- extensions/vox2seq/vox2seq/pytorch/default.py +59 -0
- extensions/vox2seq/vox2seq/pytorch/hilbert.py +303 -0
.gitattributes
CHANGED
@@ -34,68 +34,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
-
-
-
-
-assets/example_image/typical_building_maya_pyramid.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_building_mushroom.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_building_space_station.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_creature_dragon.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_creature_elephant.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_creature_furry.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_creature_quadruped.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_creature_robot_crab.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_creature_robot_dinosour.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_creature_rock_monster.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_humanoid_block_robot.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_humanoid_dragonborn.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_humanoid_dwarf.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_humanoid_goblin.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_humanoid_mech.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_crate.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_fireplace.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_gate.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_lantern.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_magicbook.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_mailbox.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_monster_chest.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_paper_machine.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_phonograph.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_portal2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_storage_chest.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_telephone.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_television.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_misc_workbench.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_vehicle_biplane.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_vehicle_bulldozer.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_vehicle_cart.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_vehicle_excavator.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_vehicle_helicopter.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_vehicle_locomotive.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/typical_vehicle_pirate_ship.png filter=lfs diff=lfs merge=lfs -text
-assets/example_image/weatherworn_misc_paper_machine3.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/character_1.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/character_2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/character_3.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/mushroom_1.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/mushroom_2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/mushroom_3.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/orangeguy_1.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/orangeguy_2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/orangeguy_3.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/popmart_1.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/popmart_2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/popmart_3.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/rabbit_1.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/rabbit_2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/rabbit_3.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/tiger_1.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/tiger_2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/tiger_3.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/yoimiya_1.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/yoimiya_2.png filter=lfs diff=lfs merge=lfs -text
-assets/example_multi_image/yoimiya_3.png filter=lfs diff=lfs merge=lfs -text
-assets/logo.webp filter=lfs diff=lfs merge=lfs -text
-assets/T.ply filter=lfs diff=lfs merge=lfs -text
-assets/teaser.png filter=lfs diff=lfs merge=lfs -text
+*.ply filter=lfs diff=lfs merge=lfs -text
+*.webp filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,2 @@
.idea/*
__pycache__
README.md
CHANGED
@@ -1,12 +1,16 @@
----
-title: TRELLIS
-emoji:
-colorFrom: indigo
-colorTo:
-sdk: gradio
-sdk_version: 5.
-app_file: app.py
-pinned:
-license: mit
-short_description: 3D Generation from text
----
+---
+title: TRELLIS Text To 3D
+emoji: 🏢
+colorFrom: indigo
+colorTo: blue
+sdk: gradio
+sdk_version: 5.25.0
+app_file: app.py
+pinned: true
+license: mit
+short_description: Scalable and Versatile 3D Generation from text prompt
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+Paper: https://huggingface.co/papers/2412.01506
app.py
CHANGED
@@ -1,232 +1,274 @@
-import gradio as gr
-import spaces
-
-import os
-import shutil
-os.environ['TOKENIZERS_PARALLELISM'] = 'true'
-os.environ['SPCONV_ALGO'] = 'native'
-from typing import *
-import torch
-import numpy as np
-import imageio
-from easydict import EasyDict as edict
-from trellis.pipelines import TrellisTextTo3DPipeline
-from trellis.representations import Gaussian, MeshExtractResult
-from trellis.utils import render_utils, postprocessing_utils
-
-import traceback
-import sys
-
-[removed lines 20-232 of the previous app.py are not captured in this view]
+import gradio as gr
+import spaces
+
+import os
+import shutil
+os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+os.environ['SPCONV_ALGO'] = 'native'
+from typing import *
+import torch
+import numpy as np
+import imageio
+from easydict import EasyDict as edict
+from trellis.pipelines import TrellisTextTo3DPipeline
+from trellis.representations import Gaussian, MeshExtractResult
+from trellis.utils import render_utils, postprocessing_utils
+
+import traceback
+import sys
+
+
+MAX_SEED = np.iinfo(np.int32).max
+TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
+os.makedirs(TMP_DIR, exist_ok=True)
+
+
+def start_session(req: gr.Request):
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    os.makedirs(user_dir, exist_ok=True)
+
+
+def end_session(req: gr.Request):
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    shutil.rmtree(user_dir)
+
+
+def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
+    return {
+        'gaussian': {
+            **gs.init_params,
+            '_xyz': gs._xyz.cpu().numpy(),
+            '_features_dc': gs._features_dc.cpu().numpy(),
+            '_scaling': gs._scaling.cpu().numpy(),
+            '_rotation': gs._rotation.cpu().numpy(),
+            '_opacity': gs._opacity.cpu().numpy(),
+        },
+        'mesh': {
+            'vertices': mesh.vertices.cpu().numpy(),
+            'faces': mesh.faces.cpu().numpy(),
+        },
+    }
+
+
+def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
+    gs = Gaussian(
+        aabb=state['gaussian']['aabb'],
+        sh_degree=state['gaussian']['sh_degree'],
+        mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
+        scaling_bias=state['gaussian']['scaling_bias'],
+        opacity_bias=state['gaussian']['opacity_bias'],
+        scaling_activation=state['gaussian']['scaling_activation'],
+    )
+    gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
+    gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
+    gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
+    gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
+    gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
+
+    mesh = edict(
+        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
+        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
+    )
+
+    return gs, mesh
+
+
+def get_seed(randomize_seed: bool, seed: int) -> int:
+    """
+    Get the random seed.
+    """
+    return np.random.randint(0, MAX_SEED) if randomize_seed else seed
+
+
+@spaces.GPU
+def text_to_3d(
+    prompt: str,
+    seed: int,
+    ss_guidance_strength: float,
+    ss_sampling_steps: int,
+    slat_guidance_strength: float,
+    slat_sampling_steps: int,
+    req: gr.Request,
+) -> Tuple[dict, str]:
+    """
+    Convert a text prompt to a 3D model.
+
+    Args:
+        prompt (str): The text prompt.
+        seed (int): The random seed.
+        ss_guidance_strength (float): The guidance strength for sparse structure generation.
+        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
+        slat_guidance_strength (float): The guidance strength for structured latent generation.
+        slat_sampling_steps (int): The number of sampling steps for structured latent generation.
+
+    Returns:
+        dict: The information of the generated 3D model.
+        str: The path to the video of the 3D model.
+    """
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    outputs = pipeline.run(
+        prompt,
+        seed=seed,
+        formats=["gaussian", "mesh"],
+        sparse_structure_sampler_params={
+            "steps": ss_sampling_steps,
+            "cfg_strength": ss_guidance_strength,
+        },
+        slat_sampler_params={
+            "steps": slat_sampling_steps,
+            "cfg_strength": slat_guidance_strength,
+        },
+    )
+    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
+    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
+    video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
+    video_path = os.path.join(user_dir, 'sample.mp4')
+    imageio.mimsave(video_path, video, fps=15)
+    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
+    torch.cuda.empty_cache()
+    return state, video_path
+
+
+@spaces.GPU(duration=90)
+def extract_glb(
+    state: dict,
+    mesh_simplify: float,
+    texture_size: int,
+    req: gr.Request,
+) -> Tuple[str, str]:
+    """
+    Extract a GLB file from the 3D model.
+
+    Args:
+        state (dict): The state of the generated 3D model.
+        mesh_simplify (float): The mesh simplification factor.
+        texture_size (int): The texture resolution.
+
+    Returns:
+        str: The path to the extracted GLB file.
+    """
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    gs, mesh = unpack_state(state)
+    glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
+    glb_path = os.path.join(user_dir, 'sample.glb')
+    glb.export(glb_path)
+    torch.cuda.empty_cache()
+    return glb_path, glb_path
+
+
+@spaces.GPU
+def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
+    """
+    Extract a Gaussian file from the 3D model.
+
+    Args:
+        state (dict): The state of the generated 3D model.
+
+    Returns:
+        str: The path to the extracted Gaussian file.
+    """
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    gs, _ = unpack_state(state)
+    gaussian_path = os.path.join(user_dir, 'sample.ply')
+    gs.save_ply(gaussian_path)
+    torch.cuda.empty_cache()
+    return gaussian_path, gaussian_path
+
+
+with gr.Blocks(delete_cache=(600, 600)) as demo:
+    gr.Markdown("""
+    ## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
+    * Type a text prompt and click "Generate" to create a 3D asset.
+    * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
+    """)
+
+    with gr.Row():
+        with gr.Column():
+            text_prompt = gr.Textbox(label="Text Prompt", lines=5)
+
+            with gr.Accordion(label="Generation Settings", open=False):
+                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
+                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                gr.Markdown("Stage 1: Sparse Structure Generation")
+                with gr.Row():
+                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
+                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
+                gr.Markdown("Stage 2: Structured Latent Generation")
+                with gr.Row():
+                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
+                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
+
+            generate_btn = gr.Button("Generate")
+
+            with gr.Accordion(label="GLB Extraction Settings", open=False):
+                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
+                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
+
+            with gr.Row():
+                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
+                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
+            gr.Markdown("""
+            *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
+            """)
+
+        with gr.Column():
+            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
+            model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
+
+            with gr.Row():
+                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
+                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
+
+    output_buf = gr.State()
+
+    # Handlers
+    demo.load(start_session)
+    demo.unload(end_session)
+
+    generate_btn.click(
+        get_seed,
+        inputs=[randomize_seed, seed],
+        outputs=[seed],
+    ).then(
+        text_to_3d,
+        inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
+        outputs=[output_buf, video_output],
+    ).then(
+        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
+        outputs=[extract_glb_btn, extract_gs_btn],
+    )
+
+    video_output.clear(
+        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
+        outputs=[extract_glb_btn, extract_gs_btn],
+    )
+
+    extract_glb_btn.click(
+        extract_glb,
+        inputs=[output_buf, mesh_simplify, texture_size],
+        outputs=[model_output, download_glb],
+    ).then(
+        lambda: gr.Button(interactive=True),
+        outputs=[download_glb],
+    )
+
+    extract_gs_btn.click(
+        extract_gaussian,
+        inputs=[output_buf],
+        outputs=[model_output, download_gs],
+    ).then(
+        lambda: gr.Button(interactive=True),
+        outputs=[download_gs],
+    )
+
+    model_output.clear(
+        lambda: gr.Button(interactive=False),
+        outputs=[download_glb],
+    )
+
+
+# Launch the Gradio app
+if __name__ == "__main__":
+    pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
+    pipeline.cuda()
+    demo.launch()
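The new app.py wires the TRELLIS text-to-3D pipeline into a Gradio UI. For reference, here is a minimal headless sketch assembled only from calls that appear in the diff above; the prompt, output paths, and sampler settings are illustrative (they mirror the UI slider defaults) and are not part of the commit itself.

```python
# Minimal headless sketch of the text-to-3D flow used by app.py above.
# Assumptions: prompt text and output file names are illustrative only.
import imageio
import numpy as np
from trellis.pipelines import TrellisTextTo3DPipeline
from trellis.utils import render_utils, postprocessing_utils

pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
pipeline.cuda()

outputs = pipeline.run(
    "a wooden treasure chest",  # illustrative prompt
    seed=0,
    formats=["gaussian", "mesh"],
    sparse_structure_sampler_params={"steps": 25, "cfg_strength": 7.5},
    slat_sampler_params={"steps": 25, "cfg_strength": 7.5},
)

# Render the same turntable preview the app shows: color and mesh normals side by side.
color = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
normal = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
frames = [np.concatenate([c, n], axis=1) for c, n in zip(color, normal)]
imageio.mimsave("sample.mp4", frames, fps=15)

# Bake a textured GLB from the Gaussian + mesh outputs, as extract_glb() does.
glb = postprocessing_utils.to_glb(
    outputs['gaussian'][0], outputs['mesh'][0],
    simplify=0.95, texture_size=1024, verbose=False,
)
glb.export("sample.glb")
```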
app_img.py
ADDED
@@ -0,0 +1,414 @@
import gradio as gr
import spaces
from gradio_litmodel3d import LitModel3D

import os
import shutil
os.environ['SPCONV_ALGO'] = 'native'
from typing import *
import torch
import numpy as np
import imageio
from easydict import EasyDict as edict
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils


MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)


def start_session(req: gr.Request):
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)


def end_session(req: gr.Request):
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(user_dir)


def preprocess_image(image: Image.Image) -> Image.Image:
    """
    Preprocess the input image.

    Args:
        image (Image.Image): The input image.

    Returns:
        Image.Image: The preprocessed image.
    """
    processed_image = pipeline.preprocess_image(image)
    return processed_image


def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """
    Preprocess a list of input images.

    Args:
        images (List[Tuple[Image.Image, str]]): The input images.

    Returns:
        List[Image.Image]: The preprocessed images.
    """
    images = [image[0] for image in images]
    processed_images = [pipeline.preprocess_image(image) for image in images]
    return processed_images


def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    return {
        'gaussian': {
            **gs.init_params,
            '_xyz': gs._xyz.cpu().numpy(),
            '_features_dc': gs._features_dc.cpu().numpy(),
            '_scaling': gs._scaling.cpu().numpy(),
            '_rotation': gs._rotation.cpu().numpy(),
            '_opacity': gs._opacity.cpu().numpy(),
        },
        'mesh': {
            'vertices': mesh.vertices.cpu().numpy(),
            'faces': mesh.faces.cpu().numpy(),
        },
    }


def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
    gs = Gaussian(
        aabb=state['gaussian']['aabb'],
        sh_degree=state['gaussian']['sh_degree'],
        mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
        scaling_bias=state['gaussian']['scaling_bias'],
        opacity_bias=state['gaussian']['opacity_bias'],
        scaling_activation=state['gaussian']['scaling_activation'],
    )
    gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
    gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
    gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
    gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
    gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')

    mesh = edict(
        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
    )

    return gs, mesh


def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Get the random seed.
    """
    return np.random.randint(0, MAX_SEED) if randomize_seed else seed


@spaces.GPU
def image_to_3d(
    image: Image.Image,
    multiimages: List[Tuple[Image.Image, str]],
    is_multiimage: bool,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    req: gr.Request,
) -> Tuple[dict, str]:
    """
    Convert an image to a 3D model.

    Args:
        image (Image.Image): The input image.
        multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
        is_multiimage (bool): Whether it is in multi-image mode.
        seed (int): The random seed.
        ss_guidance_strength (float): The guidance strength for sparse structure generation.
        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
        slat_guidance_strength (float): The guidance strength for structured latent generation.
        slat_sampling_steps (int): The number of sampling steps for structured latent generation.
        multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.

    Returns:
        dict: The information of the generated 3D model.
        str: The path to the video of the 3D model.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    if not is_multiimage:
        outputs = pipeline.run(
            image,
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
        )
    else:
        outputs = pipeline.run_multi_image(
            [image[0] for image in multiimages],
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
            mode=multiimage_algo,
        )
    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
    video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, video, fps=15)
    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
    torch.cuda.empty_cache()
    return state, video_path


@spaces.GPU(duration=90)
def extract_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    Extract a GLB file from the 3D model.

    Args:
        state (dict): The state of the generated 3D model.
        mesh_simplify (float): The mesh simplification factor.
        texture_size (int): The texture resolution.

    Returns:
        str: The path to the extracted GLB file.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    gs, mesh = unpack_state(state)
    glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
    glb_path = os.path.join(user_dir, 'sample.glb')
    glb.export(glb_path)
    torch.cuda.empty_cache()
    return glb_path, glb_path


@spaces.GPU
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
    """
    Extract a Gaussian file from the 3D model.

    Args:
        state (dict): The state of the generated 3D model.

    Returns:
        str: The path to the extracted Gaussian file.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    gs, _ = unpack_state(state)
    gaussian_path = os.path.join(user_dir, 'sample.ply')
    gs.save_ply(gaussian_path)
    torch.cuda.empty_cache()
    return gaussian_path, gaussian_path


def prepare_multi_example() -> List[Image.Image]:
    multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
    images = []
    for case in multi_case:
        _images = []
        for i in range(1, 4):
            img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
            W, H = img.size
            img = img.resize((int(W / H * 512), 512))
            _images.append(np.array(img))
        images.append(Image.fromarray(np.concatenate(_images, axis=1)))
    return images


def split_image(image: Image.Image) -> List[Image.Image]:
    """
    Split an image into multiple views.
    """
    image = np.array(image)
    alpha = image[..., 3]
    alpha = np.any(alpha>0, axis=0)
    start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
    end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
    images = []
    for s, e in zip(start_pos, end_pos):
        images.append(Image.fromarray(image[:, s:e+1]))
    return [preprocess_image(image) for image in images]


with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
    * Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background.
    * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.

    ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
    """)

    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
                    image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
                with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
                    multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
                    gr.Markdown("""
                    Input different views of the object in separate images.

                    *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
                    """)

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")

            generate_btn = gr.Button("Generate")

            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)

            with gr.Row():
                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
            gr.Markdown("""
            *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
            """)

        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)

            with gr.Row():
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)

    is_multiimage = gr.State(False)
    output_buf = gr.State()

    # Example images at the bottom of the page
    with gr.Row() as single_image_example:
        examples = gr.Examples(
            examples=[
                f'assets/example_image/{image}'
                for image in os.listdir("assets/example_image")
            ],
            inputs=[image_prompt],
            fn=preprocess_image,
            outputs=[image_prompt],
            run_on_click=True,
            examples_per_page=64,
        )
    with gr.Row(visible=False) as multiimage_example:
        examples_multi = gr.Examples(
            examples=prepare_multi_example(),
            inputs=[image_prompt],
            fn=split_image,
            outputs=[multiimage_prompt],
            run_on_click=True,
            examples_per_page=8,
        )

    # Handlers
    demo.load(start_session)
    demo.unload(end_session)

    single_image_input_tab.select(
        lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
        outputs=[is_multiimage, single_image_example, multiimage_example]
    )
    multiimage_input_tab.select(
        lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
        outputs=[is_multiimage, single_image_example, multiimage_example]
    )

    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )
    multiimage_prompt.upload(
        preprocess_images,
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt],
    )

    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
        outputs=[output_buf, video_output],
    ).then(
        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
        outputs=[extract_glb_btn, extract_gs_btn],
    )

    video_output.clear(
        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
        outputs=[extract_glb_btn, extract_gs_btn],
    )

    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )

    extract_gs_btn.click(
        extract_gaussian,
        inputs=[output_buf],
        outputs=[model_output, download_gs],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_gs],
    )

    model_output.clear(
        lambda: gr.Button(interactive=False),
        outputs=[download_glb],
    )


# Launch the Gradio app
if __name__ == "__main__":
    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    pipeline.cuda()
    try:
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))    # Preload rembg
    except:
        pass
    demo.launch()
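app_img.py is the image-conditioned counterpart: it preprocesses the input (alpha mask or rembg background removal), runs TrellisImageTo3DPipeline, and offers GLB and Gaussian PLY exports. Below is a minimal single-image sketch built only from calls that appear in the file above; the example image path exists in the repo's assets, but the output names and sampler values are illustrative (they mirror the UI defaults).

```python
# Minimal headless sketch of the single-image path in app_img.py above.
# Assumptions: output file names and sampler settings are illustrative only.
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.utils import postprocessing_utils

pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
pipeline.cuda()

# Background removal / alpha masking, as the app does on upload.
image = Image.open("assets/example_image/typical_misc_crate.png")
image = pipeline.preprocess_image(image)

outputs = pipeline.run(
    image,
    seed=0,
    formats=["gaussian", "mesh"],
    preprocess_image=False,  # already preprocessed above
    sparse_structure_sampler_params={"steps": 12, "cfg_strength": 7.5},
    slat_sampler_params={"steps": 12, "cfg_strength": 3.0},
)

# Export the two artifacts the app offers for download.
outputs['gaussian'][0].save_ply("sample.ply")
glb = postprocessing_utils.to_glb(
    outputs['gaussian'][0], outputs['mesh'][0],
    simplify=0.95, texture_size=1024, verbose=False,
)
glb.export("sample.glb")
```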
configs/generation/slat_flow_img_dit_L_64l8p2_fp16.json
ADDED
@@ -0,0 +1,102 @@
{
    "models": {
        "denoiser": {
            "name": "ElasticSLatFlowModel",
            "args": {
                "resolution": 64,
                "in_channels": 8,
                "out_channels": 8,
                "model_channels": 1024,
                "cond_channels": 1024,
                "num_blocks": 24,
                "num_heads": 16,
                "mlp_ratio": 4,
                "patch_size": 2,
                "num_io_res_blocks": 2,
                "io_block_channels": [128],
                "pe_mode": "ape",
                "qk_rms_norm": true,
                "use_fp16": true
            }
        }
    },
    "dataset": {
        "name": "ImageConditionedSLat",
        "args": {
            "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
            "min_aesthetic_score": 4.5,
            "max_num_voxels": 32768,
            "image_size": 518,
            "normalization": {
                "mean": [
                    -2.1687545776367188,
                    -0.004347046371549368,
                    -0.13352349400520325,
                    -0.08418072760105133,
                    -0.5271206498146057,
                    0.7238689064979553,
                    -1.1414450407028198,
                    1.2039363384246826
                ],
                "std": [
                    2.377650737762451,
                    2.386378288269043,
                    2.124418020248413,
                    2.1748552322387695,
                    2.663944721221924,
                    2.371192216873169,
                    2.6217446327209473,
                    2.684523105621338
                ]
            },
            "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
        }
    },
    "trainer": {
        "name": "ImageConditionedSparseFlowMatchingCFGTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 8,
            "batch_split": 4,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 0.0001,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "elastic": {
                "name": "LinearMemoryController",
                "args": {
                    "target_ratio": 0.75,
                    "max_mem_ratio_start": 0.5
                }
            },
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "p_uncond": 0.1,
            "t_schedule": {
                "name": "logitNormal",
                "args": {
                    "mean": 1.0,
                    "std": 1.0
                }
            },
            "sigma_min": 1e-5,
            "image_cond_model": "dinov2_vitl14_reg"
        }
    }
}
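The added generation configs share a common layout: a "models" block (denoiser architecture arguments), a "dataset" block (latent model and normalization statistics), and a "trainer" block (optimizer, EMA, fp16, and CFG settings). The training scripts that consume them are part of the 288-file upload but fall outside this 50-file view, so the sketch below only illustrates reading that layout as plain JSON; it is not the project's actual config loader.

```python
# Minimal sketch of inspecting one of the added training configs as plain JSON.
# Assumption: this is for illustration only; the real trainer builds objects from these blocks.
import json

with open("configs/generation/slat_flow_img_dit_L_64l8p2_fp16.json") as f:
    cfg = json.load(f)

denoiser = cfg["models"]["denoiser"]   # {"name": "ElasticSLatFlowModel", "args": {...}}
print(denoiser["name"], denoiser["args"]["model_channels"])            # ElasticSLatFlowModel 1024

trainer = cfg["trainer"]
print(trainer["name"], trainer["args"]["batch_size_per_gpu"])          # ...CFGTrainer 8
```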
configs/generation/slat_flow_txt_dit_B_64l8p2_fp16.json
ADDED
@@ -0,0 +1,101 @@
{
    "models": {
        "denoiser": {
            "name": "ElasticSLatFlowModel",
            "args": {
                "resolution": 64,
                "in_channels": 8,
                "out_channels": 8,
                "model_channels": 768,
                "cond_channels": 768,
                "num_blocks": 12,
                "num_heads": 12,
                "mlp_ratio": 4,
                "patch_size": 2,
                "num_io_res_blocks": 2,
                "io_block_channels": [128],
                "pe_mode": "ape",
                "qk_rms_norm": true,
                "use_fp16": true
            }
        }
    },
    "dataset": {
        "name": "TextConditionedSLat",
        "args": {
            "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
            "min_aesthetic_score": 4.5,
            "max_num_voxels": 32768,
            "normalization": {
                "mean": [
                    -2.1687545776367188,
                    -0.004347046371549368,
                    -0.13352349400520325,
                    -0.08418072760105133,
                    -0.5271206498146057,
                    0.7238689064979553,
                    -1.1414450407028198,
                    1.2039363384246826
                ],
                "std": [
                    2.377650737762451,
                    2.386378288269043,
                    2.124418020248413,
                    2.1748552322387695,
                    2.663944721221924,
                    2.371192216873169,
                    2.6217446327209473,
                    2.684523105621338
                ]
            },
            "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
        }
    },
    "trainer": {
        "name": "TextConditionedSparseFlowMatchingCFGTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 16,
            "batch_split": 4,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 0.0001,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "elastic": {
                "name": "LinearMemoryController",
                "args": {
                    "target_ratio": 0.75,
                    "max_mem_ratio_start": 0.5
                }
            },
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "p_uncond": 0.1,
            "t_schedule": {
                "name": "logitNormal",
                "args": {
                    "mean": 1.0,
                    "std": 1.0
                }
            },
            "sigma_min": 1e-5,
            "text_cond_model": "openai/clip-vit-large-patch14"
        }
    }
}
configs/generation/slat_flow_txt_dit_L_64l8p2_fp16.json
ADDED
@@ -0,0 +1,101 @@
{
    "models": {
        "denoiser": {
            "name": "ElasticSLatFlowModel",
            "args": {
                "resolution": 64,
                "in_channels": 8,
                "out_channels": 8,
                "model_channels": 1024,
                "cond_channels": 768,
                "num_blocks": 24,
                "num_heads": 16,
                "mlp_ratio": 4,
                "patch_size": 2,
                "num_io_res_blocks": 2,
                "io_block_channels": [128],
                "pe_mode": "ape",
                "qk_rms_norm": true,
                "use_fp16": true
            }
        }
    },
    "dataset": {
        "name": "TextConditionedSLat",
        "args": {
            "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
            "min_aesthetic_score": 4.5,
            "max_num_voxels": 32768,
            "normalization": {
                "mean": [
                    -2.1687545776367188,
                    -0.004347046371549368,
                    -0.13352349400520325,
                    -0.08418072760105133,
                    -0.5271206498146057,
                    0.7238689064979553,
                    -1.1414450407028198,
                    1.2039363384246826
                ],
                "std": [
                    2.377650737762451,
                    2.386378288269043,
                    2.124418020248413,
                    2.1748552322387695,
                    2.663944721221924,
                    2.371192216873169,
                    2.6217446327209473,
                    2.684523105621338
                ]
            },
            "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
        }
    },
    "trainer": {
        "name": "TextConditionedSparseFlowMatchingCFGTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 8,
            "batch_split": 4,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 0.0001,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "elastic": {
                "name": "LinearMemoryController",
                "args": {
                    "target_ratio": 0.75,
                    "max_mem_ratio_start": 0.5
                }
            },
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "p_uncond": 0.1,
            "t_schedule": {
                "name": "logitNormal",
                "args": {
                    "mean": 1.0,
                    "std": 1.0
                }
            },
            "sigma_min": 1e-5,
            "text_cond_model": "openai/clip-vit-large-patch14"
        }
    }
}
configs/generation/slat_flow_txt_dit_XL_64l8p2_fp16.json
ADDED
@@ -0,0 +1,101 @@
{
    "models": {
        "denoiser": {
            "name": "ElasticSLatFlowModel",
            "args": {
                "resolution": 64,
                "in_channels": 8,
                "out_channels": 8,
                "model_channels": 1280,
                "cond_channels": 768,
                "num_blocks": 28,
                "num_heads": 16,
                "mlp_ratio": 4,
                "patch_size": 2,
                "num_io_res_blocks": 3,
                "io_block_channels": [256],
                "pe_mode": "ape",
                "qk_rms_norm": true,
                "use_fp16": true
            }
        }
    },
    "dataset": {
        "name": "TextConditionedSLat",
        "args": {
            "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
            "min_aesthetic_score": 4.5,
            "max_num_voxels": 32768,
            "normalization": {
                "mean": [
                    -2.1687545776367188,
                    -0.004347046371549368,
                    -0.13352349400520325,
                    -0.08418072760105133,
                    -0.5271206498146057,
                    0.7238689064979553,
                    -1.1414450407028198,
                    1.2039363384246826
                ],
                "std": [
                    2.377650737762451,
                    2.386378288269043,
                    2.124418020248413,
                    2.1748552322387695,
                    2.663944721221924,
                    2.371192216873169,
                    2.6217446327209473,
                    2.684523105621338
                ]
            },
            "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
        }
    },
    "trainer": {
        "name": "TextConditionedSparseFlowMatchingCFGTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 4,
            "batch_split": 4,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 0.0001,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "elastic": {
                "name": "LinearMemoryController",
                "args": {
                    "target_ratio": 0.75,
                    "max_mem_ratio_start": 0.5
                }
            },
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "p_uncond": 0.1,
            "t_schedule": {
                "name": "logitNormal",
                "args": {
                    "mean": 1.0,
                    "std": 1.0
                }
            },
            "sigma_min": 1e-5,
            "text_cond_model": "openai/clip-vit-large-patch14"
        }
    }
}
configs/generation/ss_flow_img_dit_L_16l8_fp16.json
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"models": {
|
| 3 |
+
"denoiser": {
|
| 4 |
+
"name": "SparseStructureFlowModel",
|
| 5 |
+
"args": {
|
| 6 |
+
"resolution": 16,
|
| 7 |
+
"in_channels": 8,
|
| 8 |
+
"out_channels": 8,
|
| 9 |
+
"model_channels": 1024,
|
| 10 |
+
"cond_channels": 1024,
|
| 11 |
+
"num_blocks": 24,
|
| 12 |
+
"num_heads": 16,
|
| 13 |
+
"mlp_ratio": 4,
|
| 14 |
+
"patch_size": 1,
|
| 15 |
+
"pe_mode": "ape",
|
| 16 |
+
"qk_rms_norm": true,
|
| 17 |
+
"use_fp16": true
|
| 18 |
+
}
|
| 19 |
+
}
|
| 20 |
+
},
|
| 21 |
+
"dataset": {
|
| 22 |
+
"name": "ImageConditionedSparseStructureLatent",
|
| 23 |
+
"args": {
|
| 24 |
+
"latent_model": "ss_enc_conv3d_16l8_fp16",
|
| 25 |
+
"min_aesthetic_score": 4.5,
|
| 26 |
+
"image_size": 518,
|
| 27 |
+
"pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"trainer": {
|
| 31 |
+
"name": "ImageConditionedFlowMatchingCFGTrainer",
|
| 32 |
+
"args": {
|
| 33 |
+
"max_steps": 1000000,
|
| 34 |
+
"batch_size_per_gpu": 8,
|
| 35 |
+
"batch_split": 1,
|
| 36 |
+
"optimizer": {
|
| 37 |
+
"name": "AdamW",
|
| 38 |
+
"args": {
|
| 39 |
+
"lr": 0.0001,
|
| 40 |
+
"weight_decay": 0.0
|
| 41 |
+
}
|
| 42 |
+
},
|
| 43 |
+
"ema_rate": [
|
| 44 |
+
0.9999
|
| 45 |
+
],
|
| 46 |
+
"fp16_mode": "inflat_all",
|
| 47 |
+
"fp16_scale_growth": 0.001,
|
| 48 |
+
"grad_clip": {
|
| 49 |
+
"name": "AdaptiveGradClipper",
|
| 50 |
+
"args": {
|
| 51 |
+
"max_norm": 1.0,
|
| 52 |
+
"clip_percentile": 95
|
| 53 |
+
}
|
| 54 |
+
},
|
| 55 |
+
"i_log": 500,
|
| 56 |
+
"i_sample": 10000,
|
| 57 |
+
"i_save": 10000,
|
| 58 |
+
"p_uncond": 0.1,
|
| 59 |
+
"t_schedule": {
|
| 60 |
+
"name": "logitNormal",
|
| 61 |
+
"args": {
|
| 62 |
+
"mean": 1.0,
|
| 63 |
+
"std": 1.0
|
| 64 |
+
}
|
| 65 |
+
},
|
| 66 |
+
"sigma_min": 1e-5,
|
| 67 |
+
"image_cond_model": "dinov2_vitl14_reg"
|
| 68 |
+
}
|
| 69 |
+
}
|
| 70 |
+
}
|
configs/generation/ss_flow_txt_dit_B_16l8_fp16.json
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"models": {
|
| 3 |
+
"denoiser": {
|
| 4 |
+
"name": "SparseStructureFlowModel",
|
| 5 |
+
"args": {
|
| 6 |
+
"resolution": 16,
|
| 7 |
+
"in_channels": 8,
|
| 8 |
+
"out_channels": 8,
|
| 9 |
+
"model_channels": 768,
|
| 10 |
+
"cond_channels": 768,
|
| 11 |
+
"num_blocks": 12,
|
| 12 |
+
"num_heads": 12,
|
| 13 |
+
"mlp_ratio": 4,
|
| 14 |
+
"patch_size": 1,
|
| 15 |
+
"pe_mode": "ape",
|
| 16 |
+
"qk_rms_norm": true,
|
| 17 |
+
"use_fp16": true
|
| 18 |
+
}
|
| 19 |
+
}
|
| 20 |
+
},
|
| 21 |
+
"dataset": {
|
| 22 |
+
"name": "TextConditionedSparseStructureLatent",
|
| 23 |
+
"args": {
|
| 24 |
+
"latent_model": "ss_enc_conv3d_16l8_fp16",
|
| 25 |
+
"min_aesthetic_score": 4.5,
|
| 26 |
+
"pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
|
| 27 |
+
}
|
| 28 |
+
},
|
| 29 |
+
"trainer": {
|
| 30 |
+
"name": "TextConditionedFlowMatchingCFGTrainer",
|
| 31 |
+
"args": {
|
| 32 |
+
"max_steps": 1000000,
|
| 33 |
+
"batch_size_per_gpu": 16,
|
| 34 |
+
"batch_split": 1,
|
| 35 |
+
"optimizer": {
|
| 36 |
+
"name": "AdamW",
|
| 37 |
+
"args": {
|
| 38 |
+
"lr": 0.0001,
|
| 39 |
+
"weight_decay": 0.0
|
| 40 |
+
}
|
| 41 |
+
},
|
| 42 |
+
"ema_rate": [
|
| 43 |
+
0.9999
|
| 44 |
+
],
|
| 45 |
+
"fp16_mode": "inflat_all",
|
| 46 |
+
"fp16_scale_growth": 0.001,
|
| 47 |
+
"grad_clip": {
|
| 48 |
+
"name": "AdaptiveGradClipper",
|
| 49 |
+
"args": {
|
| 50 |
+
"max_norm": 1.0,
|
| 51 |
+
"clip_percentile": 95
|
| 52 |
+
}
|
| 53 |
+
},
|
| 54 |
+
"i_log": 500,
|
| 55 |
+
"i_sample": 10000,
|
| 56 |
+
"i_save": 10000,
|
| 57 |
+
"p_uncond": 0.1,
|
| 58 |
+
"t_schedule": {
|
| 59 |
+
"name": "logitNormal",
|
| 60 |
+
"args": {
|
| 61 |
+
"mean": 1.0,
|
| 62 |
+
"std": 1.0
|
| 63 |
+
}
|
| 64 |
+
},
|
| 65 |
+
"sigma_min": 1e-5,
|
| 66 |
+
"text_cond_model": "openai/clip-vit-large-patch14"
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
}
|
configs/generation/ss_flow_txt_dit_L_16l8_fp16.json
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"models": {
|
| 3 |
+
"denoiser": {
|
| 4 |
+
"name": "SparseStructureFlowModel",
|
| 5 |
+
"args": {
|
| 6 |
+
"resolution": 16,
|
| 7 |
+
"in_channels": 8,
|
| 8 |
+
"out_channels": 8,
|
| 9 |
+
"model_channels": 1024,
|
| 10 |
+
"cond_channels": 768,
|
| 11 |
+
"num_blocks": 24,
|
| 12 |
+
"num_heads": 16,
|
| 13 |
+
"mlp_ratio": 4,
|
| 14 |
+
"patch_size": 1,
|
| 15 |
+
"pe_mode": "ape",
|
| 16 |
+
"qk_rms_norm": true,
|
| 17 |
+
"use_fp16": true
|
| 18 |
+
}
|
| 19 |
+
}
|
| 20 |
+
},
|
| 21 |
+
"dataset": {
|
| 22 |
+
"name": "TextConditionedSparseStructureLatent",
|
| 23 |
+
"args": {
|
| 24 |
+
"latent_model": "ss_enc_conv3d_16l8_fp16",
|
| 25 |
+
"min_aesthetic_score": 4.5,
|
| 26 |
+
"pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
|
| 27 |
+
}
|
| 28 |
+
},
|
| 29 |
+
"trainer": {
|
| 30 |
+
"name": "TextConditionedFlowMatchingCFGTrainer",
|
| 31 |
+
"args": {
|
| 32 |
+
"max_steps": 1000000,
|
| 33 |
+
"batch_size_per_gpu": 8,
|
| 34 |
+
"batch_split": 1,
|
| 35 |
+
"optimizer": {
|
| 36 |
+
"name": "AdamW",
|
| 37 |
+
"args": {
|
| 38 |
+
"lr": 0.0001,
|
| 39 |
+
"weight_decay": 0.0
|
| 40 |
+
}
|
| 41 |
+
},
|
| 42 |
+
"ema_rate": [
|
| 43 |
+
0.9999
|
| 44 |
+
],
|
| 45 |
+
"fp16_mode": "inflat_all",
|
| 46 |
+
"fp16_scale_growth": 0.001,
|
| 47 |
+
"grad_clip": {
|
| 48 |
+
"name": "AdaptiveGradClipper",
|
| 49 |
+
"args": {
|
| 50 |
+
"max_norm": 1.0,
|
| 51 |
+
"clip_percentile": 95
|
| 52 |
+
}
|
| 53 |
+
},
|
| 54 |
+
"i_log": 500,
|
| 55 |
+
"i_sample": 10000,
|
| 56 |
+
"i_save": 10000,
|
| 57 |
+
"p_uncond": 0.1,
|
| 58 |
+
"t_schedule": {
|
| 59 |
+
"name": "logitNormal",
|
| 60 |
+
"args": {
|
| 61 |
+
"mean": 1.0,
|
| 62 |
+
"std": 1.0
|
| 63 |
+
}
|
| 64 |
+
},
|
| 65 |
+
"sigma_min": 1e-5,
|
| 66 |
+
"text_cond_model": "openai/clip-vit-large-patch14"
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
}
|
configs/generation/ss_flow_txt_dit_XL_16l8_fp16.json
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"models": {
|
| 3 |
+
"denoiser": {
|
| 4 |
+
"name": "SparseStructureFlowModel",
|
| 5 |
+
"args": {
|
| 6 |
+
"resolution": 16,
|
| 7 |
+
"in_channels": 8,
|
| 8 |
+
"out_channels": 8,
|
| 9 |
+
"model_channels": 1280,
|
| 10 |
+
"cond_channels": 768,
|
| 11 |
+
"num_blocks": 28,
|
| 12 |
+
"num_heads": 16,
|
| 13 |
+
"mlp_ratio": 4,
|
| 14 |
+
"patch_size": 1,
|
| 15 |
+
"pe_mode": "ape",
|
| 16 |
+
"qk_rms_norm": true,
|
| 17 |
+
"qk_rms_norm_cross": true,
|
| 18 |
+
"use_fp16": true
|
| 19 |
+
}
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"dataset": {
|
| 23 |
+
"name": "TextConditionedSparseStructureLatent",
|
| 24 |
+
"args": {
|
| 25 |
+
"latent_model": "ss_enc_conv3d_16l8_fp16",
|
| 26 |
+
"min_aesthetic_score": 4.5,
|
| 27 |
+
"pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"trainer": {
|
| 31 |
+
"name": "TextConditionedFlowMatchingCFGTrainer",
|
| 32 |
+
"args": {
|
| 33 |
+
"max_steps": 1000000,
|
| 34 |
+
"batch_size_per_gpu": 4,
|
| 35 |
+
"batch_split": 1,
|
| 36 |
+
"optimizer": {
|
| 37 |
+
"name": "AdamW",
|
| 38 |
+
"args": {
|
| 39 |
+
"lr": 0.0001,
|
| 40 |
+
"weight_decay": 0.0
|
| 41 |
+
}
|
| 42 |
+
},
|
| 43 |
+
"ema_rate": [
|
| 44 |
+
0.9999
|
| 45 |
+
],
|
| 46 |
+
"fp16_mode": "inflat_all",
|
| 47 |
+
"fp16_scale_growth": 0.001,
|
| 48 |
+
"grad_clip": {
|
| 49 |
+
"name": "AdaptiveGradClipper",
|
| 50 |
+
"args": {
|
| 51 |
+
"max_norm": 1.0,
|
| 52 |
+
"clip_percentile": 95
|
| 53 |
+
}
|
| 54 |
+
},
|
| 55 |
+
"i_log": 500,
|
| 56 |
+
"i_sample": 10000,
|
| 57 |
+
"i_save": 10000,
|
| 58 |
+
"p_uncond": 0.1,
|
| 59 |
+
"t_schedule": {
|
| 60 |
+
"name": "logitNormal",
|
| 61 |
+
"args": {
|
| 62 |
+
"mean": 1.0,
|
| 63 |
+
"std": 1.0
|
| 64 |
+
}
|
| 65 |
+
},
|
| 66 |
+
"sigma_min": 1e-5,
|
| 67 |
+
"text_cond_model": "openai/clip-vit-large-patch14"
|
| 68 |
+
}
|
| 69 |
+
}
|
| 70 |
+
}
|
configs/vae/slat_vae_dec_mesh_swin8_B_64l8_fp16.json
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"models": {
|
| 3 |
+
"decoder": {
|
| 4 |
+
"name": "ElasticSLatMeshDecoder",
|
| 5 |
+
"args": {
|
| 6 |
+
"resolution": 64,
|
| 7 |
+
"model_channels": 768,
|
| 8 |
+
"latent_channels": 8,
|
| 9 |
+
"num_blocks": 12,
|
| 10 |
+
"num_heads": 12,
|
| 11 |
+
"mlp_ratio": 4,
|
| 12 |
+
"attn_mode": "swin",
|
| 13 |
+
"window_size": 8,
|
| 14 |
+
"use_fp16": true,
|
| 15 |
+
"representation_config": {
|
| 16 |
+
"use_color": true
|
| 17 |
+
}
|
| 18 |
+
}
|
| 19 |
+
}
|
| 20 |
+
},
|
| 21 |
+
"dataset": {
|
| 22 |
+
"name": "Slat2RenderGeo",
|
| 23 |
+
"args": {
|
| 24 |
+
"image_size": 512,
|
| 25 |
+
"latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
|
| 26 |
+
"min_aesthetic_score": 4.5,
|
| 27 |
+
"max_num_voxels": 32768
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"trainer": {
|
| 31 |
+
"name": "SLatVaeMeshDecoderTrainer",
|
| 32 |
+
"args": {
|
| 33 |
+
"max_steps": 1000000,
|
| 34 |
+
"batch_size_per_gpu": 4,
|
| 35 |
+
"batch_split": 4,
|
| 36 |
+
"optimizer": {
|
| 37 |
+
"name": "AdamW",
|
| 38 |
+
"args": {
|
| 39 |
+
"lr": 1e-4,
|
| 40 |
+
"weight_decay": 0.0
|
| 41 |
+
}
|
| 42 |
+
},
|
| 43 |
+
"ema_rate": [
|
| 44 |
+
0.9999
|
| 45 |
+
],
|
| 46 |
+
"fp16_mode": "inflat_all",
|
| 47 |
+
"fp16_scale_growth": 0.001,
|
| 48 |
+
"elastic": {
|
| 49 |
+
"name": "LinearMemoryController",
|
| 50 |
+
"args": {
|
| 51 |
+
"target_ratio": 0.75,
|
| 52 |
+
"max_mem_ratio_start": 0.5
|
| 53 |
+
}
|
| 54 |
+
},
|
| 55 |
+
"grad_clip": {
|
| 56 |
+
"name": "AdaptiveGradClipper",
|
| 57 |
+
"args": {
|
| 58 |
+
"max_norm": 1.0,
|
| 59 |
+
"clip_percentile": 95
|
| 60 |
+
}
|
| 61 |
+
},
|
| 62 |
+
"i_log": 500,
|
| 63 |
+
"i_sample": 10000,
|
| 64 |
+
"i_save": 10000,
|
| 65 |
+
"lambda_ssim": 0.2,
|
| 66 |
+
"lambda_lpips": 0.2,
|
| 67 |
+
"lambda_tsdf": 0.01,
|
| 68 |
+
"lambda_depth": 10.0,
|
| 69 |
+
"lambda_color": 0.1,
|
| 70 |
+
"depth_loss_type": "smooth_l1"
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
}
|
configs/vae/slat_vae_dec_rf_swin8_B_64l8_fp16.json
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"models": {
|
| 3 |
+
"decoder": {
|
| 4 |
+
"name": "ElasticSLatRadianceFieldDecoder",
|
| 5 |
+
"args": {
|
| 6 |
+
"resolution": 64,
|
| 7 |
+
"model_channels": 768,
|
| 8 |
+
"latent_channels": 8,
|
| 9 |
+
"num_blocks": 12,
|
| 10 |
+
"num_heads": 12,
|
| 11 |
+
"mlp_ratio": 4,
|
| 12 |
+
"attn_mode": "swin",
|
| 13 |
+
"window_size": 8,
|
| 14 |
+
"use_fp16": true,
|
| 15 |
+
"representation_config": {
|
| 16 |
+
"rank": 16,
|
| 17 |
+
"dim": 8
|
| 18 |
+
}
|
| 19 |
+
}
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"dataset": {
|
| 23 |
+
"name": "SLat2Render",
|
| 24 |
+
"args": {
|
| 25 |
+
"image_size": 512,
|
| 26 |
+
"latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
|
| 27 |
+
"min_aesthetic_score": 4.5,
|
| 28 |
+
"max_num_voxels": 32768
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
"trainer": {
|
| 32 |
+
"name": "SLatVaeRadianceFieldDecoderTrainer",
|
| 33 |
+
"args": {
|
| 34 |
+
"max_steps": 1000000,
|
| 35 |
+
"batch_size_per_gpu": 4,
|
| 36 |
+
"batch_split": 2,
|
| 37 |
+
"optimizer": {
|
| 38 |
+
"name": "AdamW",
|
| 39 |
+
"args": {
|
| 40 |
+
"lr": 1e-4,
|
| 41 |
+
"weight_decay": 0.0
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"ema_rate": [
|
| 45 |
+
0.9999
|
| 46 |
+
],
|
| 47 |
+
"fp16_mode": "inflat_all",
|
| 48 |
+
"fp16_scale_growth": 0.001,
|
| 49 |
+
"elastic": {
|
| 50 |
+
"name": "LinearMemoryController",
|
| 51 |
+
"args": {
|
| 52 |
+
"target_ratio": 0.75,
|
| 53 |
+
"max_mem_ratio_start": 0.5
|
| 54 |
+
}
|
| 55 |
+
},
|
| 56 |
+
"grad_clip": {
|
| 57 |
+
"name": "AdaptiveGradClipper",
|
| 58 |
+
"args": {
|
| 59 |
+
"max_norm": 1.0,
|
| 60 |
+
"clip_percentile": 95
|
| 61 |
+
}
|
| 62 |
+
},
|
| 63 |
+
"i_log": 500,
|
| 64 |
+
"i_sample": 10000,
|
| 65 |
+
"i_save": 10000,
|
| 66 |
+
"loss_type": "l1",
|
| 67 |
+
"lambda_ssim": 0.2,
|
| 68 |
+
"lambda_lpips": 0.2
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
}
|
configs/vae/slat_vae_enc_dec_gs_swin8_B_64l8_fp16.json
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"models": {
|
| 3 |
+
"encoder": {
|
| 4 |
+
"name": "ElasticSLatEncoder",
|
| 5 |
+
"args": {
|
| 6 |
+
"resolution": 64,
|
| 7 |
+
"in_channels": 1024,
|
| 8 |
+
"model_channels": 768,
|
| 9 |
+
"latent_channels": 8,
|
| 10 |
+
"num_blocks": 12,
|
| 11 |
+
"num_heads": 12,
|
| 12 |
+
"mlp_ratio": 4,
|
| 13 |
+
"attn_mode": "swin",
|
| 14 |
+
"window_size": 8,
|
| 15 |
+
"use_fp16": true
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
"decoder": {
|
| 19 |
+
"name": "ElasticSLatGaussianDecoder",
|
| 20 |
+
"args": {
|
| 21 |
+
"resolution": 64,
|
| 22 |
+
"model_channels": 768,
|
| 23 |
+
"latent_channels": 8,
|
| 24 |
+
"num_blocks": 12,
|
| 25 |
+
"num_heads": 12,
|
| 26 |
+
"mlp_ratio": 4,
|
| 27 |
+
"attn_mode": "swin",
|
| 28 |
+
"window_size": 8,
|
| 29 |
+
"use_fp16": true,
|
| 30 |
+
"representation_config": {
|
| 31 |
+
"lr": {
|
| 32 |
+
"_xyz": 1.0,
|
| 33 |
+
"_features_dc": 1.0,
|
| 34 |
+
"_opacity": 1.0,
|
| 35 |
+
"_scaling": 1.0,
|
| 36 |
+
"_rotation": 0.1
|
| 37 |
+
},
|
| 38 |
+
"perturb_offset": true,
|
| 39 |
+
"voxel_size": 1.5,
|
| 40 |
+
"num_gaussians": 32,
|
| 41 |
+
"2d_filter_kernel_size": 0.1,
|
| 42 |
+
"3d_filter_kernel_size": 9e-4,
|
| 43 |
+
"scaling_bias": 4e-3,
|
| 44 |
+
"opacity_bias": 0.1,
|
| 45 |
+
"scaling_activation": "softplus"
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
},
|
| 50 |
+
"dataset": {
|
| 51 |
+
"name": "SparseFeat2Render",
|
| 52 |
+
"args": {
|
| 53 |
+
"image_size": 512,
|
| 54 |
+
"model": "dinov2_vitl14_reg",
|
| 55 |
+
"resolution": 64,
|
| 56 |
+
"min_aesthetic_score": 4.5,
|
| 57 |
+
"max_num_voxels": 32768
|
| 58 |
+
}
|
| 59 |
+
},
|
| 60 |
+
"trainer": {
|
| 61 |
+
"name": "SLatVaeGaussianTrainer",
|
| 62 |
+
"args": {
|
| 63 |
+
"max_steps": 1000000,
|
| 64 |
+
"batch_size_per_gpu": 4,
|
| 65 |
+
"batch_split": 2,
|
| 66 |
+
"optimizer": {
|
| 67 |
+
"name": "AdamW",
|
| 68 |
+
"args": {
|
| 69 |
+
"lr": 1e-4,
|
| 70 |
+
"weight_decay": 0.0
|
| 71 |
+
}
|
| 72 |
+
},
|
| 73 |
+
"ema_rate": [
|
| 74 |
+
0.9999
|
| 75 |
+
],
|
| 76 |
+
"fp16_mode": "inflat_all",
|
| 77 |
+
"fp16_scale_growth": 0.001,
|
| 78 |
+
"elastic": {
|
| 79 |
+
"name": "LinearMemoryController",
|
| 80 |
+
"args": {
|
| 81 |
+
"target_ratio": 0.75,
|
| 82 |
+
"max_mem_ratio_start": 0.5
|
| 83 |
+
}
|
| 84 |
+
},
|
| 85 |
+
"grad_clip": {
|
| 86 |
+
"name": "AdaptiveGradClipper",
|
| 87 |
+
"args": {
|
| 88 |
+
"max_norm": 1.0,
|
| 89 |
+
"clip_percentile": 95
|
| 90 |
+
}
|
| 91 |
+
},
|
| 92 |
+
"i_log": 500,
|
| 93 |
+
"i_sample": 10000,
|
| 94 |
+
"i_save": 10000,
|
| 95 |
+
"loss_type": "l1",
|
| 96 |
+
"lambda_ssim": 0.2,
|
| 97 |
+
"lambda_lpips": 0.2,
|
| 98 |
+
"lambda_kl": 1e-06,
|
| 99 |
+
"regularizations": {
|
| 100 |
+
"lambda_vol": 10000.0,
|
| 101 |
+
"lambda_opacity": 0.001
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
}
|
configs/vae/ss_vae_conv3d_16l8_fp16.json
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"models": {
|
| 3 |
+
"encoder": {
|
| 4 |
+
"name": "SparseStructureEncoder",
|
| 5 |
+
"args": {
|
| 6 |
+
"in_channels": 1,
|
| 7 |
+
"latent_channels": 8,
|
| 8 |
+
"num_res_blocks": 2,
|
| 9 |
+
"num_res_blocks_middle": 2,
|
| 10 |
+
"channels": [32, 128, 512],
|
| 11 |
+
"use_fp16": true
|
| 12 |
+
}
|
| 13 |
+
},
|
| 14 |
+
"decoder": {
|
| 15 |
+
"name": "SparseStructureDecoder",
|
| 16 |
+
"args": {
|
| 17 |
+
"out_channels": 1,
|
| 18 |
+
"latent_channels": 8,
|
| 19 |
+
"num_res_blocks": 2,
|
| 20 |
+
"num_res_blocks_middle": 2,
|
| 21 |
+
"channels": [512, 128, 32],
|
| 22 |
+
"use_fp16": true
|
| 23 |
+
}
|
| 24 |
+
}
|
| 25 |
+
},
|
| 26 |
+
"dataset": {
|
| 27 |
+
"name": "SparseStructure",
|
| 28 |
+
"args": {
|
| 29 |
+
"resolution": 64,
|
| 30 |
+
"min_aesthetic_score": 4.5
|
| 31 |
+
}
|
| 32 |
+
},
|
| 33 |
+
"trainer": {
|
| 34 |
+
"name": "SparseStructureVaeTrainer",
|
| 35 |
+
"args": {
|
| 36 |
+
"max_steps": 1000000,
|
| 37 |
+
"batch_size_per_gpu": 4,
|
| 38 |
+
"batch_split": 1,
|
| 39 |
+
"optimizer": {
|
| 40 |
+
"name": "AdamW",
|
| 41 |
+
"args": {
|
| 42 |
+
"lr": 1e-4,
|
| 43 |
+
"weight_decay": 0.0
|
| 44 |
+
}
|
| 45 |
+
},
|
| 46 |
+
"ema_rate": [
|
| 47 |
+
0.9999
|
| 48 |
+
],
|
| 49 |
+
"fp16_mode": "inflat_all",
|
| 50 |
+
"fp16_scale_growth": 0.001,
|
| 51 |
+
"grad_clip": {
|
| 52 |
+
"name": "AdaptiveGradClipper",
|
| 53 |
+
"args": {
|
| 54 |
+
"max_norm": 1.0,
|
| 55 |
+
"clip_percentile": 95
|
| 56 |
+
}
|
| 57 |
+
},
|
| 58 |
+
"i_log": 500,
|
| 59 |
+
"i_sample": 10000,
|
| 60 |
+
"i_save": 10000,
|
| 61 |
+
"loss_type": "dice",
|
| 62 |
+
"lambda_kl": 0.001
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
}
|
dataset_toolkits/blender_script/io_scene_usdz.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ec07ab6125fe0a021ed08c64169eceda126330401aba3d494d5203d26ac4b093
|
| 3 |
+
size 34685
|
dataset_toolkits/blender_script/render.py
ADDED
|
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse, sys, os, math, re, glob
|
| 2 |
+
from typing import *
|
| 3 |
+
import bpy
|
| 4 |
+
from mathutils import Vector, Matrix
|
| 5 |
+
import numpy as np
|
| 6 |
+
import json
|
| 7 |
+
import glob
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
"""=============== BLENDER ==============="""
|
| 11 |
+
|
| 12 |
+
IMPORT_FUNCTIONS: Dict[str, Callable] = {
|
| 13 |
+
"obj": bpy.ops.import_scene.obj,
|
| 14 |
+
"glb": bpy.ops.import_scene.gltf,
|
| 15 |
+
"gltf": bpy.ops.import_scene.gltf,
|
| 16 |
+
"usd": bpy.ops.import_scene.usd,
|
| 17 |
+
"fbx": bpy.ops.import_scene.fbx,
|
| 18 |
+
"stl": bpy.ops.import_mesh.stl,
|
| 19 |
+
"usda": bpy.ops.import_scene.usda,
|
| 20 |
+
"dae": bpy.ops.wm.collada_import,
|
| 21 |
+
"ply": bpy.ops.import_mesh.ply,
|
| 22 |
+
"abc": bpy.ops.wm.alembic_import,
|
| 23 |
+
"blend": bpy.ops.wm.append,
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
EXT = {
|
| 27 |
+
'PNG': 'png',
|
| 28 |
+
'JPEG': 'jpg',
|
| 29 |
+
'OPEN_EXR': 'exr',
|
| 30 |
+
'TIFF': 'tiff',
|
| 31 |
+
'BMP': 'bmp',
|
| 32 |
+
'HDR': 'hdr',
|
| 33 |
+
'TARGA': 'tga'
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
def init_render(engine='CYCLES', resolution=512, geo_mode=False):
|
| 37 |
+
bpy.context.scene.render.engine = engine
|
| 38 |
+
bpy.context.scene.render.resolution_x = resolution
|
| 39 |
+
bpy.context.scene.render.resolution_y = resolution
|
| 40 |
+
bpy.context.scene.render.resolution_percentage = 100
|
| 41 |
+
bpy.context.scene.render.image_settings.file_format = 'PNG'
|
| 42 |
+
bpy.context.scene.render.image_settings.color_mode = 'RGBA'
|
| 43 |
+
bpy.context.scene.render.film_transparent = True
|
| 44 |
+
|
| 45 |
+
bpy.context.scene.cycles.device = 'GPU'
|
| 46 |
+
bpy.context.scene.cycles.samples = 128 if not geo_mode else 1
|
| 47 |
+
bpy.context.scene.cycles.filter_type = 'BOX'
|
| 48 |
+
bpy.context.scene.cycles.filter_width = 1
|
| 49 |
+
bpy.context.scene.cycles.diffuse_bounces = 1
|
| 50 |
+
bpy.context.scene.cycles.glossy_bounces = 1
|
| 51 |
+
bpy.context.scene.cycles.transparent_max_bounces = 3 if not geo_mode else 0
|
| 52 |
+
bpy.context.scene.cycles.transmission_bounces = 3 if not geo_mode else 1
|
| 53 |
+
bpy.context.scene.cycles.use_denoising = True
|
| 54 |
+
|
| 55 |
+
bpy.context.preferences.addons['cycles'].preferences.get_devices()
|
| 56 |
+
bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
|
| 57 |
+
|
| 58 |
+
def init_nodes(save_depth=False, save_normal=False, save_albedo=False, save_mist=False):
|
| 59 |
+
if not any([save_depth, save_normal, save_albedo, save_mist]):
|
| 60 |
+
return {}, {}
|
| 61 |
+
outputs = {}
|
| 62 |
+
spec_nodes = {}
|
| 63 |
+
|
| 64 |
+
bpy.context.scene.use_nodes = True
|
| 65 |
+
bpy.context.scene.view_layers['View Layer'].use_pass_z = save_depth
|
| 66 |
+
bpy.context.scene.view_layers['View Layer'].use_pass_normal = save_normal
|
| 67 |
+
bpy.context.scene.view_layers['View Layer'].use_pass_diffuse_color = save_albedo
|
| 68 |
+
bpy.context.scene.view_layers['View Layer'].use_pass_mist = save_mist
|
| 69 |
+
|
| 70 |
+
nodes = bpy.context.scene.node_tree.nodes
|
| 71 |
+
links = bpy.context.scene.node_tree.links
|
| 72 |
+
for n in nodes:
|
| 73 |
+
nodes.remove(n)
|
| 74 |
+
|
| 75 |
+
render_layers = nodes.new('CompositorNodeRLayers')
|
| 76 |
+
|
| 77 |
+
if save_depth:
|
| 78 |
+
depth_file_output = nodes.new('CompositorNodeOutputFile')
|
| 79 |
+
depth_file_output.base_path = ''
|
| 80 |
+
depth_file_output.file_slots[0].use_node_format = True
|
| 81 |
+
depth_file_output.format.file_format = 'PNG'
|
| 82 |
+
depth_file_output.format.color_depth = '16'
|
| 83 |
+
depth_file_output.format.color_mode = 'BW'
|
| 84 |
+
# Remap to 0-1
|
| 85 |
+
map = nodes.new(type="CompositorNodeMapRange")
|
| 86 |
+
map.inputs[1].default_value = 0 # (min value you will be getting)
|
| 87 |
+
map.inputs[2].default_value = 10 # (max value you will be getting)
|
| 88 |
+
map.inputs[3].default_value = 0 # (min value you will map to)
|
| 89 |
+
map.inputs[4].default_value = 1 # (max value you will map to)
|
| 90 |
+
|
| 91 |
+
links.new(render_layers.outputs['Depth'], map.inputs[0])
|
| 92 |
+
links.new(map.outputs[0], depth_file_output.inputs[0])
|
| 93 |
+
|
| 94 |
+
outputs['depth'] = depth_file_output
|
| 95 |
+
spec_nodes['depth_map'] = map
|
| 96 |
+
|
| 97 |
+
if save_normal:
|
| 98 |
+
normal_file_output = nodes.new('CompositorNodeOutputFile')
|
| 99 |
+
normal_file_output.base_path = ''
|
| 100 |
+
normal_file_output.file_slots[0].use_node_format = True
|
| 101 |
+
normal_file_output.format.file_format = 'OPEN_EXR'
|
| 102 |
+
normal_file_output.format.color_mode = 'RGB'
|
| 103 |
+
normal_file_output.format.color_depth = '16'
|
| 104 |
+
|
| 105 |
+
links.new(render_layers.outputs['Normal'], normal_file_output.inputs[0])
|
| 106 |
+
|
| 107 |
+
outputs['normal'] = normal_file_output
|
| 108 |
+
|
| 109 |
+
if save_albedo:
|
| 110 |
+
albedo_file_output = nodes.new('CompositorNodeOutputFile')
|
| 111 |
+
albedo_file_output.base_path = ''
|
| 112 |
+
albedo_file_output.file_slots[0].use_node_format = True
|
| 113 |
+
albedo_file_output.format.file_format = 'PNG'
|
| 114 |
+
albedo_file_output.format.color_mode = 'RGBA'
|
| 115 |
+
albedo_file_output.format.color_depth = '8'
|
| 116 |
+
|
| 117 |
+
alpha_albedo = nodes.new('CompositorNodeSetAlpha')
|
| 118 |
+
|
| 119 |
+
links.new(render_layers.outputs['DiffCol'], alpha_albedo.inputs['Image'])
|
| 120 |
+
links.new(render_layers.outputs['Alpha'], alpha_albedo.inputs['Alpha'])
|
| 121 |
+
links.new(alpha_albedo.outputs['Image'], albedo_file_output.inputs[0])
|
| 122 |
+
|
| 123 |
+
outputs['albedo'] = albedo_file_output
|
| 124 |
+
|
| 125 |
+
if save_mist:
|
| 126 |
+
bpy.data.worlds['World'].mist_settings.start = 0
|
| 127 |
+
bpy.data.worlds['World'].mist_settings.depth = 10
|
| 128 |
+
|
| 129 |
+
mist_file_output = nodes.new('CompositorNodeOutputFile')
|
| 130 |
+
mist_file_output.base_path = ''
|
| 131 |
+
mist_file_output.file_slots[0].use_node_format = True
|
| 132 |
+
mist_file_output.format.file_format = 'PNG'
|
| 133 |
+
mist_file_output.format.color_mode = 'BW'
|
| 134 |
+
mist_file_output.format.color_depth = '16'
|
| 135 |
+
|
| 136 |
+
links.new(render_layers.outputs['Mist'], mist_file_output.inputs[0])
|
| 137 |
+
|
| 138 |
+
outputs['mist'] = mist_file_output
|
| 139 |
+
|
| 140 |
+
return outputs, spec_nodes
|
| 141 |
+
|
| 142 |
+
def init_scene() -> None:
|
| 143 |
+
"""Resets the scene to a clean state.
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
None
|
| 147 |
+
"""
|
| 148 |
+
# delete everything
|
| 149 |
+
for obj in bpy.data.objects:
|
| 150 |
+
bpy.data.objects.remove(obj, do_unlink=True)
|
| 151 |
+
|
| 152 |
+
# delete all the materials
|
| 153 |
+
for material in bpy.data.materials:
|
| 154 |
+
bpy.data.materials.remove(material, do_unlink=True)
|
| 155 |
+
|
| 156 |
+
# delete all the textures
|
| 157 |
+
for texture in bpy.data.textures:
|
| 158 |
+
bpy.data.textures.remove(texture, do_unlink=True)
|
| 159 |
+
|
| 160 |
+
# delete all the images
|
| 161 |
+
for image in bpy.data.images:
|
| 162 |
+
bpy.data.images.remove(image, do_unlink=True)
|
| 163 |
+
|
| 164 |
+
def init_camera():
|
| 165 |
+
cam = bpy.data.objects.new('Camera', bpy.data.cameras.new('Camera'))
|
| 166 |
+
bpy.context.collection.objects.link(cam)
|
| 167 |
+
bpy.context.scene.camera = cam
|
| 168 |
+
cam.data.sensor_height = cam.data.sensor_width = 32
|
| 169 |
+
cam_constraint = cam.constraints.new(type='TRACK_TO')
|
| 170 |
+
cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'
|
| 171 |
+
cam_constraint.up_axis = 'UP_Y'
|
| 172 |
+
cam_empty = bpy.data.objects.new("Empty", None)
|
| 173 |
+
cam_empty.location = (0, 0, 0)
|
| 174 |
+
bpy.context.scene.collection.objects.link(cam_empty)
|
| 175 |
+
cam_constraint.target = cam_empty
|
| 176 |
+
return cam
|
| 177 |
+
|
| 178 |
+
def init_lighting():
|
| 179 |
+
# Clear existing lights
|
| 180 |
+
bpy.ops.object.select_all(action="DESELECT")
|
| 181 |
+
bpy.ops.object.select_by_type(type="LIGHT")
|
| 182 |
+
bpy.ops.object.delete()
|
| 183 |
+
|
| 184 |
+
# Create key light
|
| 185 |
+
default_light = bpy.data.objects.new("Default_Light", bpy.data.lights.new("Default_Light", type="POINT"))
|
| 186 |
+
bpy.context.collection.objects.link(default_light)
|
| 187 |
+
default_light.data.energy = 1000
|
| 188 |
+
default_light.location = (4, 1, 6)
|
| 189 |
+
default_light.rotation_euler = (0, 0, 0)
|
| 190 |
+
|
| 191 |
+
# create top light
|
| 192 |
+
top_light = bpy.data.objects.new("Top_Light", bpy.data.lights.new("Top_Light", type="AREA"))
|
| 193 |
+
bpy.context.collection.objects.link(top_light)
|
| 194 |
+
top_light.data.energy = 10000
|
| 195 |
+
top_light.location = (0, 0, 10)
|
| 196 |
+
top_light.scale = (100, 100, 100)
|
| 197 |
+
|
| 198 |
+
# create bottom light
|
| 199 |
+
bottom_light = bpy.data.objects.new("Bottom_Light", bpy.data.lights.new("Bottom_Light", type="AREA"))
|
| 200 |
+
bpy.context.collection.objects.link(bottom_light)
|
| 201 |
+
bottom_light.data.energy = 1000
|
| 202 |
+
bottom_light.location = (0, 0, -10)
|
| 203 |
+
bottom_light.rotation_euler = (0, 0, 0)
|
| 204 |
+
|
| 205 |
+
return {
|
| 206 |
+
"default_light": default_light,
|
| 207 |
+
"top_light": top_light,
|
| 208 |
+
"bottom_light": bottom_light
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def load_object(object_path: str) -> None:
|
| 213 |
+
"""Loads a model with a supported file extension into the scene.
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
object_path (str): Path to the model file.
|
| 217 |
+
|
| 218 |
+
Raises:
|
| 219 |
+
ValueError: If the file extension is not supported.
|
| 220 |
+
|
| 221 |
+
Returns:
|
| 222 |
+
None
|
| 223 |
+
"""
|
| 224 |
+
file_extension = object_path.split(".")[-1].lower()
|
| 225 |
+
if file_extension is None:
|
| 226 |
+
raise ValueError(f"Unsupported file type: {object_path}")
|
| 227 |
+
|
| 228 |
+
if file_extension == "usdz":
|
| 229 |
+
# install usdz io package
|
| 230 |
+
dirname = os.path.dirname(os.path.realpath(__file__))
|
| 231 |
+
usdz_package = os.path.join(dirname, "io_scene_usdz.zip")
|
| 232 |
+
bpy.ops.preferences.addon_install(filepath=usdz_package)
|
| 233 |
+
# enable it
|
| 234 |
+
addon_name = "io_scene_usdz"
|
| 235 |
+
bpy.ops.preferences.addon_enable(module=addon_name)
|
| 236 |
+
# import the usdz
|
| 237 |
+
from io_scene_usdz.import_usdz import import_usdz
|
| 238 |
+
|
| 239 |
+
import_usdz(context, filepath=object_path, materials=True, animations=True)
|
| 240 |
+
return None
|
| 241 |
+
|
| 242 |
+
# load from existing import functions
|
| 243 |
+
import_function = IMPORT_FUNCTIONS[file_extension]
|
| 244 |
+
|
| 245 |
+
print(f"Loading object from {object_path}")
|
| 246 |
+
if file_extension == "blend":
|
| 247 |
+
import_function(directory=object_path, link=False)
|
| 248 |
+
elif file_extension in {"glb", "gltf"}:
|
| 249 |
+
import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS')
|
| 250 |
+
else:
|
| 251 |
+
import_function(filepath=object_path)
|
| 252 |
+
|
| 253 |
+
def delete_invisible_objects() -> None:
|
| 254 |
+
"""Deletes all invisible objects in the scene.
|
| 255 |
+
|
| 256 |
+
Returns:
|
| 257 |
+
None
|
| 258 |
+
"""
|
| 259 |
+
# bpy.ops.object.mode_set(mode="OBJECT")
|
| 260 |
+
bpy.ops.object.select_all(action="DESELECT")
|
| 261 |
+
for obj in bpy.context.scene.objects:
|
| 262 |
+
if obj.hide_viewport or obj.hide_render:
|
| 263 |
+
obj.hide_viewport = False
|
| 264 |
+
obj.hide_render = False
|
| 265 |
+
obj.hide_select = False
|
| 266 |
+
obj.select_set(True)
|
| 267 |
+
bpy.ops.object.delete()
|
| 268 |
+
|
| 269 |
+
# Delete invisible collections
|
| 270 |
+
invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
|
| 271 |
+
for col in invisible_collections:
|
| 272 |
+
bpy.data.collections.remove(col)
|
| 273 |
+
|
| 274 |
+
def split_mesh_normal():
|
| 275 |
+
bpy.ops.object.select_all(action="DESELECT")
|
| 276 |
+
objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
|
| 277 |
+
bpy.context.view_layer.objects.active = objs[0]
|
| 278 |
+
for obj in objs:
|
| 279 |
+
obj.select_set(True)
|
| 280 |
+
bpy.ops.object.mode_set(mode="EDIT")
|
| 281 |
+
bpy.ops.mesh.select_all(action='SELECT')
|
| 282 |
+
bpy.ops.mesh.split_normals()
|
| 283 |
+
bpy.ops.object.mode_set(mode='OBJECT')
|
| 284 |
+
bpy.ops.object.select_all(action="DESELECT")
|
| 285 |
+
|
| 286 |
+
def delete_custom_normals():
|
| 287 |
+
for this_obj in bpy.data.objects:
|
| 288 |
+
if this_obj.type == "MESH":
|
| 289 |
+
bpy.context.view_layer.objects.active = this_obj
|
| 290 |
+
bpy.ops.mesh.customdata_custom_splitnormals_clear()
|
| 291 |
+
|
| 292 |
+
def override_material():
|
| 293 |
+
new_mat = bpy.data.materials.new(name="Override0123456789")
|
| 294 |
+
new_mat.use_nodes = True
|
| 295 |
+
new_mat.node_tree.nodes.clear()
|
| 296 |
+
bsdf = new_mat.node_tree.nodes.new('ShaderNodeBsdfDiffuse')
|
| 297 |
+
bsdf.inputs[0].default_value = (0.5, 0.5, 0.5, 1)
|
| 298 |
+
bsdf.inputs[1].default_value = 1
|
| 299 |
+
output = new_mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
|
| 300 |
+
new_mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
|
| 301 |
+
bpy.context.scene.view_layers['View Layer'].material_override = new_mat
|
| 302 |
+
|
| 303 |
+
def unhide_all_objects() -> None:
|
| 304 |
+
"""Unhides all objects in the scene.
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
None
|
| 308 |
+
"""
|
| 309 |
+
for obj in bpy.context.scene.objects:
|
| 310 |
+
obj.hide_set(False)
|
| 311 |
+
|
| 312 |
+
def convert_to_meshes() -> None:
|
| 313 |
+
"""Converts all objects in the scene to meshes.
|
| 314 |
+
|
| 315 |
+
Returns:
|
| 316 |
+
None
|
| 317 |
+
"""
|
| 318 |
+
bpy.ops.object.select_all(action="DESELECT")
|
| 319 |
+
bpy.context.view_layer.objects.active = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"][0]
|
| 320 |
+
for obj in bpy.context.scene.objects:
|
| 321 |
+
obj.select_set(True)
|
| 322 |
+
bpy.ops.object.convert(target="MESH")
|
| 323 |
+
|
| 324 |
+
def triangulate_meshes() -> None:
|
| 325 |
+
"""Triangulates all meshes in the scene.
|
| 326 |
+
|
| 327 |
+
Returns:
|
| 328 |
+
None
|
| 329 |
+
"""
|
| 330 |
+
bpy.ops.object.select_all(action="DESELECT")
|
| 331 |
+
objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
|
| 332 |
+
bpy.context.view_layer.objects.active = objs[0]
|
| 333 |
+
for obj in objs:
|
| 334 |
+
obj.select_set(True)
|
| 335 |
+
bpy.ops.object.mode_set(mode="EDIT")
|
| 336 |
+
bpy.ops.mesh.reveal()
|
| 337 |
+
bpy.ops.mesh.select_all(action="SELECT")
|
| 338 |
+
bpy.ops.mesh.quads_convert_to_tris(quad_method="BEAUTY", ngon_method="BEAUTY")
|
| 339 |
+
bpy.ops.object.mode_set(mode="OBJECT")
|
| 340 |
+
bpy.ops.object.select_all(action="DESELECT")
|
| 341 |
+
|
| 342 |
+
def scene_bbox() -> Tuple[Vector, Vector]:
|
| 343 |
+
"""Returns the bounding box of the scene.
|
| 344 |
+
|
| 345 |
+
Taken from Shap-E rendering script
|
| 346 |
+
(https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)
|
| 347 |
+
|
| 348 |
+
Returns:
|
| 349 |
+
Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
|
| 350 |
+
"""
|
| 351 |
+
bbox_min = (math.inf,) * 3
|
| 352 |
+
bbox_max = (-math.inf,) * 3
|
| 353 |
+
found = False
|
| 354 |
+
scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)]
|
| 355 |
+
for obj in scene_meshes:
|
| 356 |
+
found = True
|
| 357 |
+
for coord in obj.bound_box:
|
| 358 |
+
coord = Vector(coord)
|
| 359 |
+
coord = obj.matrix_world @ coord
|
| 360 |
+
bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
|
| 361 |
+
bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
|
| 362 |
+
if not found:
|
| 363 |
+
raise RuntimeError("no objects in scene to compute bounding box for")
|
| 364 |
+
return Vector(bbox_min), Vector(bbox_max)
|
| 365 |
+
|
| 366 |
+
def normalize_scene() -> Tuple[float, Vector]:
|
| 367 |
+
"""Normalizes the scene by scaling and translating it to fit in a unit cube centered
|
| 368 |
+
at the origin.
|
| 369 |
+
|
| 370 |
+
Mostly taken from the Point-E / Shap-E rendering script
|
| 371 |
+
(https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
|
| 372 |
+
but fix for multiple root objects: (see bug report here:
|
| 373 |
+
https://github.com/openai/shap-e/pull/60).
|
| 374 |
+
|
| 375 |
+
Returns:
|
| 376 |
+
Tuple[float, Vector]: The scale factor and the offset applied to the scene.
|
| 377 |
+
"""
|
| 378 |
+
scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent]
|
| 379 |
+
if len(scene_root_objects) > 1:
|
| 380 |
+
# create an empty object to be used as a parent for all root objects
|
| 381 |
+
scene = bpy.data.objects.new("ParentEmpty", None)
|
| 382 |
+
bpy.context.scene.collection.objects.link(scene)
|
| 383 |
+
|
| 384 |
+
# parent all root objects to the empty object
|
| 385 |
+
for obj in scene_root_objects:
|
| 386 |
+
obj.parent = scene
|
| 387 |
+
else:
|
| 388 |
+
scene = scene_root_objects[0]
|
| 389 |
+
|
| 390 |
+
bbox_min, bbox_max = scene_bbox()
|
| 391 |
+
scale = 1 / max(bbox_max - bbox_min)
|
| 392 |
+
scene.scale = scene.scale * scale
|
| 393 |
+
|
| 394 |
+
# Apply scale to matrix_world.
|
| 395 |
+
bpy.context.view_layer.update()
|
| 396 |
+
bbox_min, bbox_max = scene_bbox()
|
| 397 |
+
offset = -(bbox_min + bbox_max) / 2
|
| 398 |
+
scene.matrix_world.translation += offset
|
| 399 |
+
bpy.ops.object.select_all(action="DESELECT")
|
| 400 |
+
|
| 401 |
+
return scale, offset
|
| 402 |
+
|
| 403 |
+
def get_transform_matrix(obj: bpy.types.Object) -> list:
|
| 404 |
+
pos, rt, _ = obj.matrix_world.decompose()
|
| 405 |
+
rt = rt.to_matrix()
|
| 406 |
+
matrix = []
|
| 407 |
+
for ii in range(3):
|
| 408 |
+
a = []
|
| 409 |
+
for jj in range(3):
|
| 410 |
+
a.append(rt[ii][jj])
|
| 411 |
+
a.append(pos[ii])
|
| 412 |
+
matrix.append(a)
|
| 413 |
+
matrix.append([0, 0, 0, 1])
|
| 414 |
+
return matrix
|
| 415 |
+
|
| 416 |
+
def main(arg):
|
| 417 |
+
os.makedirs(arg.output_folder, exist_ok=True)
|
| 418 |
+
|
| 419 |
+
# Initialize context
|
| 420 |
+
init_render(engine=arg.engine, resolution=arg.resolution, geo_mode=arg.geo_mode)
|
| 421 |
+
outputs, spec_nodes = init_nodes(
|
| 422 |
+
save_depth=arg.save_depth,
|
| 423 |
+
save_normal=arg.save_normal,
|
| 424 |
+
save_albedo=arg.save_albedo,
|
| 425 |
+
save_mist=arg.save_mist
|
| 426 |
+
)
|
| 427 |
+
if arg.object.endswith(".blend"):
|
| 428 |
+
delete_invisible_objects()
|
| 429 |
+
else:
|
| 430 |
+
init_scene()
|
| 431 |
+
load_object(arg.object)
|
| 432 |
+
if arg.split_normal:
|
| 433 |
+
split_mesh_normal()
|
| 434 |
+
# delete_custom_normals()
|
| 435 |
+
print('[INFO] Scene initialized.')
|
| 436 |
+
|
| 437 |
+
# normalize scene
|
| 438 |
+
scale, offset = normalize_scene()
|
| 439 |
+
print('[INFO] Scene normalized.')
|
| 440 |
+
|
| 441 |
+
# Initialize camera and lighting
|
| 442 |
+
cam = init_camera()
|
| 443 |
+
init_lighting()
|
| 444 |
+
print('[INFO] Camera and lighting initialized.')
|
| 445 |
+
|
| 446 |
+
# Override material
|
| 447 |
+
if arg.geo_mode:
|
| 448 |
+
override_material()
|
| 449 |
+
|
| 450 |
+
# Create a list of views
|
| 451 |
+
to_export = {
|
| 452 |
+
"aabb": [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
|
| 453 |
+
"scale": scale,
|
| 454 |
+
"offset": [offset.x, offset.y, offset.z],
|
| 455 |
+
"frames": []
|
| 456 |
+
}
|
| 457 |
+
views = json.loads(arg.views)
|
| 458 |
+
for i, view in enumerate(views):
|
| 459 |
+
cam.location = (
|
| 460 |
+
view['radius'] * np.cos(view['yaw']) * np.cos(view['pitch']),
|
| 461 |
+
view['radius'] * np.sin(view['yaw']) * np.cos(view['pitch']),
|
| 462 |
+
view['radius'] * np.sin(view['pitch'])
|
| 463 |
+
)
|
| 464 |
+
cam.data.lens = 16 / np.tan(view['fov'] / 2)
|
| 465 |
+
|
| 466 |
+
if arg.save_depth:
|
| 467 |
+
spec_nodes['depth_map'].inputs[1].default_value = view['radius'] - 0.5 * np.sqrt(3)
|
| 468 |
+
spec_nodes['depth_map'].inputs[2].default_value = view['radius'] + 0.5 * np.sqrt(3)
|
| 469 |
+
|
| 470 |
+
bpy.context.scene.render.filepath = os.path.join(arg.output_folder, f'{i:03d}.png')
|
| 471 |
+
for name, output in outputs.items():
|
| 472 |
+
output.file_slots[0].path = os.path.join(arg.output_folder, f'{i:03d}_{name}')
|
| 473 |
+
|
| 474 |
+
# Render the scene
|
| 475 |
+
bpy.ops.render.render(write_still=True)
|
| 476 |
+
bpy.context.view_layer.update()
|
| 477 |
+
for name, output in outputs.items():
|
| 478 |
+
ext = EXT[output.format.file_format]
|
| 479 |
+
path = glob.glob(f'{output.file_slots[0].path}*.{ext}')[0]
|
| 480 |
+
os.rename(path, f'{output.file_slots[0].path}.{ext}')
|
| 481 |
+
|
| 482 |
+
# Save camera parameters
|
| 483 |
+
metadata = {
|
| 484 |
+
"file_path": f'{i:03d}.png',
|
| 485 |
+
"camera_angle_x": view['fov'],
|
| 486 |
+
"transform_matrix": get_transform_matrix(cam)
|
| 487 |
+
}
|
| 488 |
+
if arg.save_depth:
|
| 489 |
+
metadata['depth'] = {
|
| 490 |
+
'min': view['radius'] - 0.5 * np.sqrt(3),
|
| 491 |
+
'max': view['radius'] + 0.5 * np.sqrt(3)
|
| 492 |
+
}
|
| 493 |
+
to_export["frames"].append(metadata)
|
| 494 |
+
|
| 495 |
+
# Save the camera parameters
|
| 496 |
+
with open(os.path.join(arg.output_folder, 'transforms.json'), 'w') as f:
|
| 497 |
+
json.dump(to_export, f, indent=4)
|
| 498 |
+
|
| 499 |
+
if arg.save_mesh:
|
| 500 |
+
# triangulate meshes
|
| 501 |
+
unhide_all_objects()
|
| 502 |
+
convert_to_meshes()
|
| 503 |
+
triangulate_meshes()
|
| 504 |
+
print('[INFO] Meshes triangulated.')
|
| 505 |
+
|
| 506 |
+
# export ply mesh
|
| 507 |
+
bpy.ops.export_mesh.ply(filepath=os.path.join(arg.output_folder, 'mesh.ply'))
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
if __name__ == '__main__':
|
| 511 |
+
parser = argparse.ArgumentParser(description='Renders given obj file by rotation a camera around it.')
|
| 512 |
+
parser.add_argument('--views', type=str, help='JSON string of views. Contains a list of {yaw, pitch, radius, fov} object.')
|
| 513 |
+
parser.add_argument('--object', type=str, help='Path to the 3D model file to be rendered.')
|
| 514 |
+
parser.add_argument('--output_folder', type=str, default='/tmp', help='The path the output will be dumped to.')
|
| 515 |
+
parser.add_argument('--resolution', type=int, default=512, help='Resolution of the images.')
|
| 516 |
+
parser.add_argument('--engine', type=str, default='CYCLES', help='Blender internal engine for rendering. E.g. CYCLES, BLENDER_EEVEE, ...')
|
| 517 |
+
parser.add_argument('--geo_mode', action='store_true', help='Geometry mode for rendering.')
|
| 518 |
+
parser.add_argument('--save_depth', action='store_true', help='Save the depth maps.')
|
| 519 |
+
parser.add_argument('--save_normal', action='store_true', help='Save the normal maps.')
|
| 520 |
+
parser.add_argument('--save_albedo', action='store_true', help='Save the albedo maps.')
|
| 521 |
+
parser.add_argument('--save_mist', action='store_true', help='Save the mist distance maps.')
|
| 522 |
+
parser.add_argument('--split_normal', action='store_true', help='Split the normals of the mesh.')
|
| 523 |
+
parser.add_argument('--save_mesh', action='store_true', help='Save the mesh as a .ply file.')
|
| 524 |
+
argv = sys.argv[sys.argv.index("--") + 1:]
|
| 525 |
+
args = parser.parse_args(argv)
|
| 526 |
+
|
| 527 |
+
main(args)
|
| 528 |
+
|
dataset_toolkits/build_metadata.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
import importlib
|
| 6 |
+
import argparse
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from easydict import EasyDict as edict
|
| 11 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 12 |
+
import utils3d
|
| 13 |
+
|
| 14 |
+
def get_first_directory(path):
|
| 15 |
+
with os.scandir(path) as it:
|
| 16 |
+
for entry in it:
|
| 17 |
+
if entry.is_dir():
|
| 18 |
+
return entry.name
|
| 19 |
+
return None
|
| 20 |
+
|
| 21 |
+
def need_process(key):
|
| 22 |
+
return key in opt.field or opt.field == ['all']
|
| 23 |
+
|
| 24 |
+
if __name__ == '__main__':
|
| 25 |
+
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
|
| 26 |
+
|
| 27 |
+
parser = argparse.ArgumentParser()
|
| 28 |
+
parser.add_argument('--output_dir', type=str, required=True,
|
| 29 |
+
help='Directory to save the metadata')
|
| 30 |
+
parser.add_argument('--field', type=str, default='all',
|
| 31 |
+
help='Fields to process, separated by commas')
|
| 32 |
+
parser.add_argument('--from_file', action='store_true',
|
| 33 |
+
help='Build metadata from file instead of from records of processings.' +
|
| 34 |
+
'Useful when some processing fail to generate records but file already exists.')
|
| 35 |
+
dataset_utils.add_args(parser)
|
| 36 |
+
opt = parser.parse_args(sys.argv[2:])
|
| 37 |
+
opt = edict(vars(opt))
|
| 38 |
+
|
| 39 |
+
os.makedirs(opt.output_dir, exist_ok=True)
|
| 40 |
+
os.makedirs(os.path.join(opt.output_dir, 'merged_records'), exist_ok=True)
|
| 41 |
+
|
| 42 |
+
opt.field = opt.field.split(',')
|
| 43 |
+
|
| 44 |
+
timestamp = str(int(time.time()))
|
| 45 |
+
|
| 46 |
+
# get file list
|
| 47 |
+
if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
|
| 48 |
+
print('Loading previous metadata...')
|
| 49 |
+
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
|
| 50 |
+
else:
|
| 51 |
+
    metadata = dataset_utils.get_metadata(**opt)
    metadata.set_index('sha256', inplace=True)

    # merge downloaded
    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('downloaded_') and f.endswith('.csv')]
    df_parts = []
    for f in df_files:
        try:
            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
        except:
            pass
    if len(df_parts) > 0:
        df = pd.concat(df_parts)
        df.set_index('sha256', inplace=True)
        if 'local_path' in metadata.columns:
            metadata.update(df, overwrite=True)
        else:
            metadata = metadata.join(df, on='sha256', how='left')
        for f in df_files:
            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # detect models
    image_models = []
    if os.path.exists(os.path.join(opt.output_dir, 'features')):
        image_models = os.listdir(os.path.join(opt.output_dir, 'features'))
    latent_models = []
    if os.path.exists(os.path.join(opt.output_dir, 'latents')):
        latent_models = os.listdir(os.path.join(opt.output_dir, 'latents'))
    ss_latent_models = []
    if os.path.exists(os.path.join(opt.output_dir, 'ss_latents')):
        ss_latent_models = os.listdir(os.path.join(opt.output_dir, 'ss_latents'))
    print(f'Image models: {image_models}')
    print(f'Latent models: {latent_models}')
    print(f'Sparse Structure latent models: {ss_latent_models}')

    if 'rendered' not in metadata.columns:
        metadata['rendered'] = [False] * len(metadata)
    if 'voxelized' not in metadata.columns:
        metadata['voxelized'] = [False] * len(metadata)
    if 'num_voxels' not in metadata.columns:
        metadata['num_voxels'] = [0] * len(metadata)
    if 'cond_rendered' not in metadata.columns:
        metadata['cond_rendered'] = [False] * len(metadata)
    for model in image_models:
        if f'feature_{model}' not in metadata.columns:
            metadata[f'feature_{model}'] = [False] * len(metadata)
    for model in latent_models:
        if f'latent_{model}' not in metadata.columns:
            metadata[f'latent_{model}'] = [False] * len(metadata)
    for model in ss_latent_models:
        if f'ss_latent_{model}' not in metadata.columns:
            metadata[f'ss_latent_{model}'] = [False] * len(metadata)

    # merge rendered
    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('rendered_') and f.endswith('.csv')]
    df_parts = []
    for f in df_files:
        try:
            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
        except:
            pass
    if len(df_parts) > 0:
        df = pd.concat(df_parts)
        df.set_index('sha256', inplace=True)
        metadata.update(df, overwrite=True)
        for f in df_files:
            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # merge voxelized
    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('voxelized_') and f.endswith('.csv')]
    df_parts = []
    for f in df_files:
        try:
            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
        except:
            pass
    if len(df_parts) > 0:
        df = pd.concat(df_parts)
        df.set_index('sha256', inplace=True)
        metadata.update(df, overwrite=True)
        for f in df_files:
            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # merge cond_rendered
    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('cond_rendered_') and f.endswith('.csv')]
    df_parts = []
    for f in df_files:
        try:
            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
        except:
            pass
    if len(df_parts) > 0:
        df = pd.concat(df_parts)
        df.set_index('sha256', inplace=True)
        metadata.update(df, overwrite=True)
        for f in df_files:
            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # merge features
    for model in image_models:
        df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'feature_{model}_') and f.endswith('.csv')]
        df_parts = []
        for f in df_files:
            try:
                df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
            except:
                pass
        if len(df_parts) > 0:
            df = pd.concat(df_parts)
            df.set_index('sha256', inplace=True)
            metadata.update(df, overwrite=True)
            for f in df_files:
                shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # merge latents
    for model in latent_models:
        df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'latent_{model}_') and f.endswith('.csv')]
        df_parts = []
        for f in df_files:
            try:
                df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
            except:
                pass
        if len(df_parts) > 0:
            df = pd.concat(df_parts)
            df.set_index('sha256', inplace=True)
            metadata.update(df, overwrite=True)
            for f in df_files:
                shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # merge sparse structure latents
    for model in ss_latent_models:
        df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'ss_latent_{model}_') and f.endswith('.csv')]
        df_parts = []
        for f in df_files:
            try:
                df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
            except:
                pass
        if len(df_parts) > 0:
            df = pd.concat(df_parts)
            df.set_index('sha256', inplace=True)
            metadata.update(df, overwrite=True)
            for f in df_files:
                shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # build metadata from files
    if opt.from_file:
        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
            tqdm(total=len(metadata), desc="Building metadata") as pbar:
            def worker(sha256):
                try:
                    if need_process('rendered') and metadata.loc[sha256, 'rendered'] == False and \
                        os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')):
                        metadata.loc[sha256, 'rendered'] = True
                    if need_process('voxelized') and metadata.loc[sha256, 'rendered'] == True and metadata.loc[sha256, 'voxelized'] == False and \
                        os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')):
                        try:
                            pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
                            metadata.loc[sha256, 'voxelized'] = True
                            metadata.loc[sha256, 'num_voxels'] = len(pts)
                        except Exception as e:
                            pass
                    if need_process('cond_rendered') and metadata.loc[sha256, 'cond_rendered'] == False and \
                        os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')):
                        metadata.loc[sha256, 'cond_rendered'] = True
                    for model in image_models:
                        if need_process(f'feature_{model}') and \
                            metadata.loc[sha256, f'feature_{model}'] == False and \
                            metadata.loc[sha256, 'rendered'] == True and \
                            metadata.loc[sha256, 'voxelized'] == True and \
                            os.path.exists(os.path.join(opt.output_dir, 'features', model, f'{sha256}.npz')):
                            metadata.loc[sha256, f'feature_{model}'] = True
                    for model in latent_models:
                        if need_process(f'latent_{model}') and \
                            metadata.loc[sha256, f'latent_{model}'] == False and \
                            metadata.loc[sha256, 'rendered'] == True and \
                            metadata.loc[sha256, 'voxelized'] == True and \
                            os.path.exists(os.path.join(opt.output_dir, 'latents', model, f'{sha256}.npz')):
                            metadata.loc[sha256, f'latent_{model}'] = True
                    for model in ss_latent_models:
                        if need_process(f'ss_latent_{model}') and \
                            metadata.loc[sha256, f'ss_latent_{model}'] == False and \
                            metadata.loc[sha256, 'voxelized'] == True and \
                            os.path.exists(os.path.join(opt.output_dir, 'ss_latents', model, f'{sha256}.npz')):
                            metadata.loc[sha256, f'ss_latent_{model}'] = True
                    pbar.update()
                except Exception as e:
                    print(f'Error processing {sha256}: {e}')
                    pbar.update()

            executor.map(worker, metadata.index)
            executor.shutdown(wait=True)

    # statistics
    metadata.to_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    num_downloaded = metadata['local_path'].count() if 'local_path' in metadata.columns else 0
    with open(os.path.join(opt.output_dir, 'statistics.txt'), 'w') as f:
        f.write('Statistics:\n')
        f.write(f' - Number of assets: {len(metadata)}\n')
        f.write(f' - Number of assets downloaded: {num_downloaded}\n')
        f.write(f' - Number of assets rendered: {metadata["rendered"].sum()}\n')
        f.write(f' - Number of assets voxelized: {metadata["voxelized"].sum()}\n')
        if len(image_models) != 0:
            f.write(f' - Number of assets with image features extracted:\n')
            for model in image_models:
                f.write(f'   - {model}: {metadata[f"feature_{model}"].sum()}\n')
        if len(latent_models) != 0:
            f.write(f' - Number of assets with latents extracted:\n')
            for model in latent_models:
                f.write(f'   - {model}: {metadata[f"latent_{model}"].sum()}\n')
        if len(ss_latent_models) != 0:
            f.write(f' - Number of assets with sparse structure latents extracted:\n')
            for model in ss_latent_models:
                f.write(f'   - {model}: {metadata[f"ss_latent_{model}"].sum()}\n')
        f.write(f' - Number of assets with captions: {metadata["captions"].count()}\n')
        f.write(f' - Number of assets with image conditions: {metadata["cond_rendered"].sum()}\n')

    with open(os.path.join(opt.output_dir, 'statistics.txt'), 'r') as f:
        print(f.read())
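The dataset modules added below (3D-FUTURE, ABO, HSSD, ObjaverseXL, Toys4k) all expose the same small interface that build_metadata.py, download.py and render.py import via importlib: add_args, get_metadata, download, and foreach_instance. The following is a minimal sketch of a new module following that contract; the module name MyDataset, its CSV path, and its column handling are hypothetical illustrations, not part of the uploaded files.

# Hypothetical dataset_toolkits/datasets/MyDataset.py -- illustrative sketch only.
import argparse
import os
import pandas as pd


def add_args(parser: argparse.ArgumentParser):
    # extra CLI flags for this dataset (none needed in this sketch)
    pass


def get_metadata(**kwargs):
    # one row per asset, with at least 'sha256' and 'file_identifier' columns (hypothetical path)
    return pd.read_csv("path/to/MyDataset.csv")


def download(metadata, output_dir, **kwargs):
    # fetch or extract assets, then report where each sha256 ended up on disk
    downloaded = {row.sha256: os.path.join('raw', row.file_identifier)
                  for row in metadata.itertuples()}
    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    # apply func(file, sha256) to every downloaded asset and collect the returned records
    records = [func(os.path.join(output_dir, m['local_path']), m['sha256'])
               for m in metadata.to_dict('records')]
    return pd.DataFrame.from_records([r for r in records if r is not None])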
dataset_toolkits/datasets/3D-FUTURE.py
ADDED
@@ -0,0 +1,97 @@
import os
import re
import argparse
import zipfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    pass


def get_metadata(**kwargs):
    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/3D-FUTURE.csv")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(output_dir, exist_ok=True)

    if not os.path.exists(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')):
        print("\033[93m")
        print("3D-FUTURE have to be downloaded manually")
        print(f"Please download the 3D-FUTURE-model.zip file and place it in the {output_dir}/raw directory")
        print("Visit https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future for more information")
        print("\033[0m")
        raise FileNotFoundError("3D-FUTURE-model.zip not found")

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with zipfile.ZipFile(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')) as zip_ref:
        all_names = zip_ref.namelist()
        instances = [instance[:-1] for instance in all_names if re.match(r"^3D-FUTURE-model/[^/]+/$", instance)]
        instances = list(filter(lambda x: x in metadata.index, instances))

        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
            tqdm(total=len(instances), desc="Extracting") as pbar:
            def worker(instance: str) -> str:
                try:
                    instance_files = list(filter(lambda x: x.startswith(f"{instance}/") and not x.endswith("/"), all_names))
                    zip_ref.extractall(os.path.join(output_dir, 'raw'), members=instance_files)
                    sha256 = get_file_hash(os.path.join(output_dir, 'raw', f"{instance}/image.jpg"))
                    pbar.update()
                    return sha256
                except Exception as e:
                    pbar.update()
                    print(f"Error extracting for {instance}: {e}")
                    return None

            sha256s = executor.map(worker, instances)
            executor.shutdown(wait=True)

    for k, sha256 in zip(instances, sha256s):
        if sha256 is not None:
            if sha256 == metadata.loc[k, "sha256"]:
                downloaded[sha256] = os.path.join("raw", f"{k}/raw_model.obj")
            else:
                print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
            tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                try:
                    local_path = metadatum['local_path']
                    sha256 = metadatum['sha256']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    return pd.DataFrame.from_records(records)
dataset_toolkits/datasets/ABO.py
ADDED
@@ -0,0 +1,96 @@
import os
import re
import argparse
import tarfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    pass


def get_metadata(**kwargs):
    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ABO.csv")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)

    if not os.path.exists(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')):
        try:
            os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
            os.system(f"wget -O {output_dir}/raw/abo-3dmodels.tar https://amazon-berkeley-objects.s3.amazonaws.com/archives/abo-3dmodels.tar")
        except:
            print("\033[93m")
            print("Error downloading ABO dataset. Please check your internet connection and try again.")
            print("Or, you can manually download the abo-3dmodels.tar file and place it in the {output_dir}/raw directory")
            print("Visit https://amazon-berkeley-objects.s3.amazonaws.com/index.html for more information")
            print("\033[0m")
            raise FileNotFoundError("Error downloading ABO dataset")

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with tarfile.open(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')) as tar:
        with ThreadPoolExecutor(max_workers=1) as executor, \
            tqdm(total=len(metadata), desc="Extracting") as pbar:
            def worker(instance: str) -> str:
                try:
                    tar.extract(f"3dmodels/original/{instance}", path=os.path.join(output_dir, 'raw'))
                    sha256 = get_file_hash(os.path.join(output_dir, 'raw/3dmodels/original', instance))
                    pbar.update()
                    return sha256
                except Exception as e:
                    pbar.update()
                    print(f"Error extracting for {instance}: {e}")
                    return None

            sha256s = executor.map(worker, metadata.index)
            executor.shutdown(wait=True)

    for k, sha256 in zip(metadata.index, sha256s):
        if sha256 is not None:
            if sha256 == metadata.loc[k, "sha256"]:
                downloaded[sha256] = os.path.join('raw/3dmodels/original', k)
            else:
                print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
            tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                try:
                    local_path = metadatum['local_path']
                    sha256 = metadatum['sha256']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    return pd.DataFrame.from_records(records)
dataset_toolkits/datasets/HSSD.py
ADDED
@@ -0,0 +1,103 @@
import os
import re
import argparse
import tarfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
import huggingface_hub
from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    pass


def get_metadata(**kwargs):
    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/HSSD.csv")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)

    # check login
    try:
        huggingface_hub.whoami()
    except:
        print("\033[93m")
        print("Haven't logged in to the Hugging Face Hub.")
        print("Visit https://huggingface.co/settings/tokens to get a token.")
        print("\033[0m")
        huggingface_hub.login()

    try:
        huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename="README.md", repo_type="dataset")
    except:
        print("\033[93m")
        print("Error downloading HSSD dataset.")
        print("Check if you have access to the HSSD dataset.")
        print("Visit https://huggingface.co/datasets/hssd/hssd-models for more information")
        print("\033[0m")

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
        tqdm(total=len(metadata), desc="Downloading") as pbar:
        def worker(instance: str) -> str:
            try:
                huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename=instance, repo_type="dataset", local_dir=os.path.join(output_dir, 'raw'))
                sha256 = get_file_hash(os.path.join(output_dir, 'raw', instance))
                pbar.update()
                return sha256
            except Exception as e:
                pbar.update()
                print(f"Error extracting for {instance}: {e}")
                return None

        sha256s = executor.map(worker, metadata.index)
        executor.shutdown(wait=True)

    for k, sha256 in zip(metadata.index, sha256s):
        if sha256 is not None:
            if sha256 == metadata.loc[k, "sha256"]:
                downloaded[sha256] = os.path.join('raw', k)
            else:
                print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
            tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                try:
                    local_path = metadatum['local_path']
                    sha256 = metadatum['sha256']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    return pd.DataFrame.from_records(records)
dataset_toolkits/datasets/ObjaverseXL.py
ADDED
@@ -0,0 +1,92 @@
import os
import argparse
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
import objaverse.xl as oxl
from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    parser.add_argument('--source', type=str, default='sketchfab',
                        help='Data source to download annotations from (github, sketchfab)')


def get_metadata(source, **kwargs):
    if source == 'sketchfab':
        metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_sketchfab.csv")
    elif source == 'github':
        metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_github.csv")
    else:
        raise ValueError(f"Invalid source: {source}")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)

    # download annotations
    annotations = oxl.get_annotations()
    annotations = annotations[annotations['sha256'].isin(metadata['sha256'].values)]

    # download and render objects
    file_paths = oxl.download_objects(
        annotations,
        download_dir=os.path.join(output_dir, "raw"),
        save_repo_format="zip",
    )

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    for k, v in file_paths.items():
        sha256 = metadata.loc[k, "sha256"]
        downloaded[sha256] = os.path.relpath(v, output_dir)

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm
    import tempfile
    import zipfile

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
            tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                try:
                    local_path = metadatum['local_path']
                    sha256 = metadatum['sha256']
                    if local_path.startswith('raw/github/repos/'):
                        path_parts = local_path.split('/')
                        file_name = os.path.join(*path_parts[5:])
                        zip_file = os.path.join(output_dir, *path_parts[:5])
                        with tempfile.TemporaryDirectory() as tmp_dir:
                            with zipfile.ZipFile(zip_file, 'r') as zip_ref:
                                zip_ref.extractall(tmp_dir)
                            file = os.path.join(tmp_dir, file_name)
                            record = func(file, sha256)
                    else:
                        file = os.path.join(output_dir, local_path)
                        record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    return pd.DataFrame.from_records(records)
dataset_toolkits/datasets/Toys4k.py
ADDED
@@ -0,0 +1,92 @@
import os
import re
import argparse
import zipfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    pass


def get_metadata(**kwargs):
    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/Toys4k.csv")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(output_dir, exist_ok=True)

    if not os.path.exists(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')):
        print("\033[93m")
        print("Toys4k have to be downloaded manually")
        print(f"Please download the toys4k_blend_files.zip file and place it in the {output_dir}/raw directory")
        print("Visit https://github.com/rehg-lab/lowshot-shapebias/tree/main/toys4k for more information")
        print("\033[0m")
        raise FileNotFoundError("toys4k_blend_files.zip not found")

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with zipfile.ZipFile(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')) as zip_ref:
        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
            tqdm(total=len(metadata), desc="Extracting") as pbar:
            def worker(instance: str) -> str:
                try:
                    zip_ref.extract(os.path.join('toys4k_blend_files', instance), os.path.join(output_dir, 'raw'))
                    sha256 = get_file_hash(os.path.join(output_dir, 'raw/toys4k_blend_files', instance))
                    pbar.update()
                    return sha256
                except Exception as e:
                    pbar.update()
                    print(f"Error extracting for {instance}: {e}")
                    return None

            sha256s = executor.map(worker, metadata.index)
            executor.shutdown(wait=True)

    for k, sha256 in zip(metadata.index, sha256s):
        if sha256 is not None:
            if sha256 == metadata.loc[k, "sha256"]:
                downloaded[sha256] = os.path.join("raw/toys4k_blend_files", k)
            else:
                print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
            tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                try:
                    local_path = metadatum['local_path']
                    sha256 = metadatum['sha256']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    return pd.DataFrame.from_records(records)
dataset_toolkits/download.py
ADDED
@@ -0,0 +1,52 @@
import os
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict

if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(opt.output_dir, exist_ok=True)

    # get file list
    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    if opt.instances is None:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'local_path' in metadata.columns:
            metadata = metadata[metadata['local_path'].isna()]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    downloaded = dataset_utils.download(metadata, **opt)
    downloaded.to_csv(os.path.join(opt.output_dir, f'downloaded_{opt.rank}.csv'), index=False)
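Like the other toolkit scripts above and below, download.py shards its work across processes with integer arithmetic on --rank and --world_size. A small illustrative check of how that slicing partitions a 10-row metadata table (toy numbers, not from the dataset):

# Illustrative only: the rank/world_size slicing used by the scripts in this commit.
n = 10           # pretend len(metadata) == 10
world_size = 4
for rank in range(world_size):
    start = n * rank // world_size
    end = n * (rank + 1) // world_size
    print(rank, start, end)   # 0: [0, 2)   1: [2, 5)   2: [5, 7)   3: [7, 10)
# the slices cover all rows exactly once, so independent workers never collide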
dataset_toolkits/encode_latent.py
ADDED
@@ -0,0 +1,127 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import copy
import json
import argparse
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue

import trellis.models as models
import trellis.modules.sparse as sp


torch.set_grad_enabled(False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--feat_model', type=str, default='dinov2_vitl14_reg',
                        help='Feature model')
    parser.add_argument('--enc_pretrained', type=str, default='JeffreyXiang/TRELLIS-image-large/ckpts/slat_enc_swin8_B_64l8_fp16',
                        help='Pretrained encoder model')
    parser.add_argument('--model_root', type=str, default='results',
                        help='Root directory of models')
    parser.add_argument('--enc_model', type=str, default=None,
                        help='Encoder model. if specified, use this model instead of pretrained model')
    parser.add_argument('--ckpt', type=str, default=None,
                        help='Checkpoint to load')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args()
    opt = edict(vars(opt))

    if opt.enc_model is None:
        latent_name = f'{opt.feat_model}_{opt.enc_pretrained.split("/")[-1]}'
        encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
    else:
        latent_name = f'{opt.feat_model}_{opt.enc_model}_{opt.ckpt}'
        cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
        encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
        ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
        encoder.load_state_dict(torch.load(ckpt_path), strict=False)
        encoder.eval()
        print(f'Loaded model from {ckpt_path}')

    os.makedirs(os.path.join(opt.output_dir, 'latents', latent_name), exist_ok=True)

    # get file list
    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    else:
        raise ValueError('metadata.csv not found')
    if opt.instances is not None:
        with open(opt.instances, 'r') as f:
            sha256s = [line.strip() for line in f]
        metadata = metadata[metadata['sha256'].isin(sha256s)]
    else:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        metadata = metadata[metadata[f'feature_{opt.feat_model}'] == True]
        if f'latent_{latent_name}' in metadata.columns:
            metadata = metadata[metadata[f'latent_{latent_name}'] == False]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    sha256s = list(metadata['sha256'].values)
    for sha256 in copy.copy(sha256s):
        if os.path.exists(os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')):
            records.append({'sha256': sha256, f'latent_{latent_name}': True})
            sha256s.remove(sha256)

    # encode latents
    load_queue = Queue(maxsize=4)
    try:
        with ThreadPoolExecutor(max_workers=32) as loader_executor, \
            ThreadPoolExecutor(max_workers=32) as saver_executor:
            def loader(sha256):
                try:
                    feats = np.load(os.path.join(opt.output_dir, 'features', opt.feat_model, f'{sha256}.npz'))
                    load_queue.put((sha256, feats))
                except Exception as e:
                    print(f"Error loading features for {sha256}: {e}")
            loader_executor.map(loader, sha256s)

            def saver(sha256, pack):
                save_path = os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')
                np.savez_compressed(save_path, **pack)
                records.append({'sha256': sha256, f'latent_{latent_name}': True})

            for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
                sha256, feats = load_queue.get()
                feats = sp.SparseTensor(
                    feats = torch.from_numpy(feats['patchtokens']).float(),
                    coords = torch.cat([
                        torch.zeros(feats['patchtokens'].shape[0], 1).int(),
                        torch.from_numpy(feats['indices']).int(),
                    ], dim=1),
                ).cuda()
                latent = encoder(feats, sample_posterior=False)
                assert torch.isfinite(latent.feats).all(), "Non-finite latent"
                pack = {
                    'feats': latent.feats.cpu().numpy().astype(np.float32),
                    'coords': latent.coords[:, 1:].cpu().numpy().astype(np.uint8),
                }
                saver_executor.submit(saver, sha256, pack)

            saver_executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    records = pd.DataFrame.from_records(records)
    records.to_csv(os.path.join(opt.output_dir, f'latent_{latent_name}_{opt.rank}.csv'), index=False)
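For the default arguments in encode_latent.py above, the output folder name is derived from the feature model and the pretrained encoder id. A quick check of that string construction, with the default values taken from the argparse definitions above:

# Illustrative only: latent_name as built by encode_latent.py under its defaults.
feat_model = 'dinov2_vitl14_reg'
enc_pretrained = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_enc_swin8_B_64l8_fp16'
latent_name = f'{feat_model}_{enc_pretrained.split("/")[-1]}'
print(latent_name)  # dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16
# encoded latents are then written to <output_dir>/latents/<latent_name>/<sha256>.npz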
dataset_toolkits/encode_ss_latent.py
ADDED
@@ -0,0 +1,128 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import copy
import json
import argparse
import torch
import numpy as np
import pandas as pd
import utils3d
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue

import trellis.models as models


torch.set_grad_enabled(False)


def get_voxels(instance):
    position = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{instance}.ply'))[0]
    coords = ((torch.tensor(position) + 0.5) * opt.resolution).int().contiguous()
    ss = torch.zeros(1, opt.resolution, opt.resolution, opt.resolution, dtype=torch.long)
    ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
    return ss


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--enc_pretrained', type=str, default='JeffreyXiang/TRELLIS-image-large/ckpts/ss_enc_conv3d_16l8_fp16',
                        help='Pretrained encoder model')
    parser.add_argument('--model_root', type=str, default='results',
                        help='Root directory of models')
    parser.add_argument('--enc_model', type=str, default=None,
                        help='Encoder model. if specified, use this model instead of pretrained model')
    parser.add_argument('--ckpt', type=str, default=None,
                        help='Checkpoint to load')
    parser.add_argument('--resolution', type=int, default=64,
                        help='Resolution')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args()
    opt = edict(vars(opt))

    if opt.enc_model is None:
        latent_name = f'{opt.enc_pretrained.split("/")[-1]}'
        encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
    else:
        latent_name = f'{opt.enc_model}_{opt.ckpt}'
        cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
        encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
        ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
        encoder.load_state_dict(torch.load(ckpt_path), strict=False)
        encoder.eval()
        print(f'Loaded model from {ckpt_path}')

    os.makedirs(os.path.join(opt.output_dir, 'ss_latents', latent_name), exist_ok=True)

    # get file list
    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    else:
        raise ValueError('metadata.csv not found')
    if opt.instances is not None:
        with open(opt.instances, 'r') as f:
            instances = f.read().splitlines()
        metadata = metadata[metadata['sha256'].isin(instances)]
    else:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        metadata = metadata[metadata['voxelized'] == True]
        if f'ss_latent_{latent_name}' in metadata.columns:
            metadata = metadata[metadata[f'ss_latent_{latent_name}'] == False]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    sha256s = list(metadata['sha256'].values)
    for sha256 in copy.copy(sha256s):
        if os.path.exists(os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')):
            records.append({'sha256': sha256, f'ss_latent_{latent_name}': True})
            sha256s.remove(sha256)

    # encode latents
    load_queue = Queue(maxsize=4)
    try:
        with ThreadPoolExecutor(max_workers=32) as loader_executor, \
            ThreadPoolExecutor(max_workers=32) as saver_executor:
            def loader(sha256):
                try:
                    ss = get_voxels(sha256)[None].float()
                    load_queue.put((sha256, ss))
                except Exception as e:
                    print(f"Error loading features for {sha256}: {e}")
            loader_executor.map(loader, sha256s)

            def saver(sha256, pack):
                save_path = os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')
                np.savez_compressed(save_path, **pack)
                records.append({'sha256': sha256, f'ss_latent_{latent_name}': True})

            for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
                sha256, ss = load_queue.get()
                ss = ss.cuda().float()
                latent = encoder(ss, sample_posterior=False)
                assert torch.isfinite(latent).all(), "Non-finite latent"
                pack = {
                    'mean': latent[0].cpu().numpy(),
                }
                saver_executor.submit(saver, sha256, pack)

            saver_executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    records = pd.DataFrame.from_records(records)
    records.to_csv(os.path.join(opt.output_dir, f'ss_latent_{latent_name}_{opt.rank}.csv'), index=False)
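In encode_ss_latent.py above, get_voxels maps point coordinates stored in [-0.5, 0.5) to indices of a dense binary occupancy grid before encoding. A minimal sketch of that mapping at the default 64^3 resolution, with toy coordinates rather than real dataset values:

# Illustrative only: the coordinate-to-index mapping used by get_voxels above.
import torch

resolution = 64
position = torch.tensor([[-0.5, 0.0, 0.4921875]])     # one toy point in [-0.5, 0.5)
coords = ((position + 0.5) * resolution).int()        # -> [[0, 32, 63]]
ss = torch.zeros(1, resolution, resolution, resolution, dtype=torch.long)
ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1   # mark occupied voxels
print(coords.tolist(), int(ss.sum()))                 # [[0, 32, 63]] 1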
dataset_toolkits/extract_feature.py
ADDED
@@ -0,0 +1,179 @@
import os
import copy
import sys
import json
import importlib
import argparse
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import utils3d
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from torchvision import transforms
from PIL import Image


torch.set_grad_enabled(False)


def get_data(frames, sha256):
    with ThreadPoolExecutor(max_workers=16) as executor:
        def worker(view):
            image_path = os.path.join(opt.output_dir, 'renders', sha256, view['file_path'])
            try:
                image = Image.open(image_path)
            except:
                print(f"Error loading image {image_path}")
                return None
            image = image.resize((518, 518), Image.Resampling.LANCZOS)
            image = np.array(image).astype(np.float32) / 255
            image = image[:, :, :3] * image[:, :, 3:]
            image = torch.from_numpy(image).permute(2, 0, 1).float()

            c2w = torch.tensor(view['transform_matrix'])
            c2w[:3, 1:3] *= -1
            extrinsics = torch.inverse(c2w)
            fov = view['camera_angle_x']
            intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))

            return {
                'image': image,
                'extrinsics': extrinsics,
                'intrinsics': intrinsics
            }

        datas = executor.map(worker, frames)
        for data in datas:
            if data is not None:
                yield data


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--model', type=str, default='dinov2_vitl14_reg',
                        help='Feature extraction model')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args()
    opt = edict(vars(opt))

    feature_name = opt.model
    os.makedirs(os.path.join(opt.output_dir, 'features', feature_name), exist_ok=True)

    # load model
    dinov2_model = torch.hub.load('facebookresearch/dinov2', opt.model)
    dinov2_model.eval().cuda()
    transform = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    n_patch = 518 // 14

    # get file list
    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    else:
        raise ValueError('metadata.csv not found')
    if opt.instances is not None:
        with open(opt.instances, 'r') as f:
            instances = f.read().splitlines()
        metadata = metadata[metadata['sha256'].isin(instances)]
    else:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if f'feature_{feature_name}' in metadata.columns:
            metadata = metadata[metadata[f'feature_{feature_name}'] == False]
        metadata = metadata[metadata['voxelized'] == True]
        metadata = metadata[metadata['rendered'] == True]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    sha256s = list(metadata['sha256'].values)
    for sha256 in copy.copy(sha256s):
        if os.path.exists(os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')):
            records.append({'sha256': sha256, f'feature_{feature_name}' : True})
            sha256s.remove(sha256)

    # extract features
    load_queue = Queue(maxsize=4)
    try:
        with ThreadPoolExecutor(max_workers=8) as loader_executor, \
            ThreadPoolExecutor(max_workers=8) as saver_executor:
            def loader(sha256):
                try:
                    with open(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json'), 'r') as f:
                        metadata = json.load(f)
                    frames = metadata['frames']
                    data = []
                    for datum in get_data(frames, sha256):
                        datum['image'] = transform(datum['image'])
                        data.append(datum)
                    positions = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
                    load_queue.put((sha256, data, positions))
                except Exception as e:
                    print(f"Error loading data for {sha256}: {e}")

            loader_executor.map(loader, sha256s)

            def saver(sha256, pack, patchtokens, uv):
                pack['patchtokens'] = F.grid_sample(
                    patchtokens,
                    uv.unsqueeze(1),
                    mode='bilinear',
                    align_corners=False,
                ).squeeze(2).permute(0, 2, 1).cpu().numpy()
                pack['patchtokens'] = np.mean(pack['patchtokens'], axis=0).astype(np.float16)
                save_path = os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')
                np.savez_compressed(save_path, **pack)
                records.append({'sha256': sha256, f'feature_{feature_name}' : True})

            for _ in tqdm(range(len(sha256s)), desc="Extracting features"):
                sha256, data, positions = load_queue.get()
                positions = torch.from_numpy(positions).float().cuda()
                indices = ((positions + 0.5) * 64).long()
                assert torch.all(indices >= 0) and torch.all(indices < 64), "Some vertices are out of bounds"
                n_views = len(data)
                N = positions.shape[0]
                pack = {
                    'indices': indices.cpu().numpy().astype(np.uint8),
                }
                patchtokens_lst = []
                uv_lst = []
                for i in range(0, n_views, opt.batch_size):
                    batch_data = data[i:i+opt.batch_size]
                    bs = len(batch_data)
                    batch_images = torch.stack([d['image'] for d in batch_data]).cuda()
                    batch_extrinsics = torch.stack([d['extrinsics'] for d in batch_data]).cuda()
                    batch_intrinsics = torch.stack([d['intrinsics'] for d in batch_data]).cuda()
                    features = dinov2_model(batch_images, is_training=True)
                    uv = utils3d.torch.project_cv(positions, batch_extrinsics, batch_intrinsics)[0] * 2 - 1
                    patchtokens = features['x_prenorm'][:, dinov2_model.num_register_tokens + 1:].permute(0, 2, 1).reshape(bs, 1024, n_patch, n_patch)
                    patchtokens_lst.append(patchtokens)
                    uv_lst.append(uv)
                patchtokens = torch.cat(patchtokens_lst, dim=0)
                uv = torch.cat(uv_lst, dim=0)

                # save features
                saver_executor.submit(saver, sha256, pack, patchtokens, uv)

            saver_executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    records = pd.DataFrame.from_records(records)
    records.to_csv(os.path.join(opt.output_dir, f'feature_{feature_name}_{opt.rank}.csv'), index=False)
dataset_toolkits/render.py
ADDED
|
@@ -0,0 +1,121 @@
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import copy
|
| 4 |
+
import sys
|
| 5 |
+
import importlib
|
| 6 |
+
import argparse
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from easydict import EasyDict as edict
|
| 9 |
+
from functools import partial
|
| 10 |
+
from subprocess import DEVNULL, call
|
| 11 |
+
import numpy as np
|
| 12 |
+
from utils import sphere_hammersley_sequence
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
|
| 16 |
+
BLENDER_INSTALLATION_PATH = '/tmp'
|
| 17 |
+
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'
|
| 18 |
+
|
| 19 |
+
def _install_blender():
|
| 20 |
+
if not os.path.exists(BLENDER_PATH):
|
| 21 |
+
os.system('sudo apt-get update')
|
| 22 |
+
os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6')
|
| 23 |
+
os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
|
| 24 |
+
os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _render(file_path, sha256, output_dir, num_views):
|
| 28 |
+
output_folder = os.path.join(output_dir, 'renders', sha256)
|
| 29 |
+
|
| 30 |
+
# Build camera {yaw, pitch, radius, fov}
|
| 31 |
+
yaws = []
|
| 32 |
+
pitchs = []
|
| 33 |
+
offset = (np.random.rand(), np.random.rand())
|
| 34 |
+
for i in range(num_views):
|
| 35 |
+
y, p = sphere_hammersley_sequence(i, num_views, offset)
|
| 36 |
+
yaws.append(y)
|
| 37 |
+
pitchs.append(p)
|
| 38 |
+
radius = [2] * num_views
|
| 39 |
+
fov = [40 / 180 * np.pi] * num_views
|
| 40 |
+
views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)]
|
| 41 |
+
|
| 42 |
+
args = [
|
| 43 |
+
BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'),
|
| 44 |
+
'--',
|
| 45 |
+
'--views', json.dumps(views),
|
| 46 |
+
'--object', os.path.expanduser(file_path),
|
| 47 |
+
'--resolution', '512',
|
| 48 |
+
'--output_folder', output_folder,
|
| 49 |
+
'--engine', 'CYCLES',
|
| 50 |
+
'--save_mesh',
|
| 51 |
+
]
|
| 52 |
+
if file_path.endswith('.blend'):
|
| 53 |
+
args.insert(1, file_path)
|
| 54 |
+
|
| 55 |
+
call(args, stdout=DEVNULL, stderr=DEVNULL)
|
| 56 |
+
|
| 57 |
+
if os.path.exists(os.path.join(output_folder, 'transforms.json')):
|
| 58 |
+
return {'sha256': sha256, 'rendered': True}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == '__main__':
|
| 62 |
+
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
|
| 63 |
+
|
| 64 |
+
parser = argparse.ArgumentParser()
|
| 65 |
+
parser.add_argument('--output_dir', type=str, required=True,
|
| 66 |
+
help='Directory to save the metadata')
|
| 67 |
+
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
|
| 68 |
+
help='Filter objects with aesthetic score lower than this value')
|
| 69 |
+
parser.add_argument('--instances', type=str, default=None,
|
| 70 |
+
help='Instances to process')
|
| 71 |
+
parser.add_argument('--num_views', type=int, default=150,
|
| 72 |
+
help='Number of views to render')
|
| 73 |
+
dataset_utils.add_args(parser)
|
| 74 |
+
parser.add_argument('--rank', type=int, default=0)
|
| 75 |
+
parser.add_argument('--world_size', type=int, default=1)
|
| 76 |
+
parser.add_argument('--max_workers', type=int, default=8)
|
| 77 |
+
opt = parser.parse_args(sys.argv[2:])
|
| 78 |
+
opt = edict(vars(opt))
|
| 79 |
+
|
| 80 |
+
os.makedirs(os.path.join(opt.output_dir, 'renders'), exist_ok=True)
|
| 81 |
+
|
| 82 |
+
# install blender
|
| 83 |
+
print('Checking blender...', flush=True)
|
| 84 |
+
_install_blender()
|
| 85 |
+
|
| 86 |
+
# get file list
|
| 87 |
+
if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
|
| 88 |
+
raise ValueError('metadata.csv not found')
|
| 89 |
+
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
|
| 90 |
+
if opt.instances is None:
|
| 91 |
+
metadata = metadata[metadata['local_path'].notna()]
|
| 92 |
+
if opt.filter_low_aesthetic_score is not None:
|
| 93 |
+
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
|
| 94 |
+
if 'rendered' in metadata.columns:
|
| 95 |
+
metadata = metadata[metadata['rendered'] == False]
|
| 96 |
+
else:
|
| 97 |
+
if os.path.exists(opt.instances):
|
| 98 |
+
with open(opt.instances, 'r') as f:
|
| 99 |
+
instances = f.read().splitlines()
|
| 100 |
+
else:
|
| 101 |
+
instances = opt.instances.split(',')
|
| 102 |
+
metadata = metadata[metadata['sha256'].isin(instances)]
|
| 103 |
+
|
| 104 |
+
start = len(metadata) * opt.rank // opt.world_size
|
| 105 |
+
end = len(metadata) * (opt.rank + 1) // opt.world_size
|
| 106 |
+
metadata = metadata[start:end]
|
| 107 |
+
records = []
|
| 108 |
+
|
| 109 |
+
# filter out objects that are already processed
|
| 110 |
+
for sha256 in copy.copy(metadata['sha256'].values):
|
| 111 |
+
if os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')):
|
| 112 |
+
records.append({'sha256': sha256, 'rendered': True})
|
| 113 |
+
metadata = metadata[metadata['sha256'] != sha256]
|
| 114 |
+
|
| 115 |
+
print(f'Processing {len(metadata)} objects...')
|
| 116 |
+
|
| 117 |
+
# process objects
|
| 118 |
+
func = partial(_render, output_dir=opt.output_dir, num_views=opt.num_views)
|
| 119 |
+
rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects')
|
| 120 |
+
rendered = pd.concat([rendered, pd.DataFrame.from_records(records)])
|
| 121 |
+
rendered.to_csv(os.path.join(opt.output_dir, f'rendered_{opt.rank}.csv'), index=False)
|
dataset_toolkits/render_cond.py
ADDED
|
@@ -0,0 +1,125 @@
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import copy
|
| 4 |
+
import sys
|
| 5 |
+
import importlib
|
| 6 |
+
import argparse
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from easydict import EasyDict as edict
|
| 9 |
+
from functools import partial
|
| 10 |
+
from subprocess import DEVNULL, call
|
| 11 |
+
import numpy as np
|
| 12 |
+
from utils import sphere_hammersley_sequence
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
|
| 16 |
+
BLENDER_INSTALLATION_PATH = '/tmp'
|
| 17 |
+
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'
|
| 18 |
+
|
| 19 |
+
def _install_blender():
|
| 20 |
+
if not os.path.exists(BLENDER_PATH):
|
| 21 |
+
os.system('sudo apt-get update')
|
| 22 |
+
os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6')
|
| 23 |
+
os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
|
| 24 |
+
os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _render_cond(file_path, sha256, output_dir, num_views):
|
| 28 |
+
output_folder = os.path.join(output_dir, 'renders_cond', sha256)
|
| 29 |
+
|
| 30 |
+
# Build camera {yaw, pitch, radius, fov}
|
| 31 |
+
yaws = []
|
| 32 |
+
pitchs = []
|
| 33 |
+
offset = (np.random.rand(), np.random.rand())
|
| 34 |
+
for i in range(num_views):
|
| 35 |
+
y, p = sphere_hammersley_sequence(i, num_views, offset)
|
| 36 |
+
yaws.append(y)
|
| 37 |
+
pitchs.append(p)
|
| 38 |
+
fov_min, fov_max = 10, 70
|
| 39 |
+
radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi)
|
| 40 |
+
radius_max = np.sqrt(3) / 2 / np.sin(fov_min / 360 * np.pi)
|
| 41 |
+
k_min = 1 / radius_max**2
|
| 42 |
+
k_max = 1 / radius_min**2
|
| 43 |
+
ks = np.random.uniform(k_min, k_max, (1000000,))
|
| 44 |
+
radius = [1 / np.sqrt(k) for k in ks]
|
| 45 |
+
fov = [2 * np.arcsin(np.sqrt(3) / 2 / r) for r in radius]
|
| 46 |
+
views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)]
|
| 47 |
+
|
| 48 |
+
args = [
|
| 49 |
+
BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'),
|
| 50 |
+
'--',
|
| 51 |
+
'--views', json.dumps(views),
|
| 52 |
+
'--object', os.path.expanduser(file_path),
|
| 53 |
+
'--output_folder', os.path.expanduser(output_folder),
|
| 54 |
+
'--resolution', '1024',
|
| 55 |
+
]
|
| 56 |
+
if file_path.endswith('.blend'):
|
| 57 |
+
args.insert(1, file_path)
|
| 58 |
+
|
| 59 |
+
call(args, stdout=DEVNULL)
|
| 60 |
+
|
| 61 |
+
if os.path.exists(os.path.join(output_folder, 'transforms.json')):
|
| 62 |
+
return {'sha256': sha256, 'cond_rendered': True}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if __name__ == '__main__':
|
| 66 |
+
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
|
| 67 |
+
|
| 68 |
+
parser = argparse.ArgumentParser()
|
| 69 |
+
parser.add_argument('--output_dir', type=str, required=True,
|
| 70 |
+
help='Directory to save the metadata')
|
| 71 |
+
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
|
| 72 |
+
help='Filter objects with aesthetic score lower than this value')
|
| 73 |
+
parser.add_argument('--instances', type=str, default=None,
|
| 74 |
+
help='Instances to process')
|
| 75 |
+
parser.add_argument('--num_views', type=int, default=24,
|
| 76 |
+
help='Number of views to render')
|
| 77 |
+
dataset_utils.add_args(parser)
|
| 78 |
+
parser.add_argument('--rank', type=int, default=0)
|
| 79 |
+
parser.add_argument('--world_size', type=int, default=1)
|
| 80 |
+
parser.add_argument('--max_workers', type=int, default=8)
|
| 81 |
+
opt = parser.parse_args(sys.argv[2:])
|
| 82 |
+
opt = edict(vars(opt))
|
| 83 |
+
|
| 84 |
+
os.makedirs(os.path.join(opt.output_dir, 'renders_cond'), exist_ok=True)
|
| 85 |
+
|
| 86 |
+
# install blender
|
| 87 |
+
print('Checking blender...', flush=True)
|
| 88 |
+
_install_blender()
|
| 89 |
+
|
| 90 |
+
# get file list
|
| 91 |
+
if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
|
| 92 |
+
raise ValueError('metadata.csv not found')
|
| 93 |
+
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
|
| 94 |
+
if opt.instances is None:
|
| 95 |
+
metadata = metadata[metadata['local_path'].notna()]
|
| 96 |
+
if opt.filter_low_aesthetic_score is not None:
|
| 97 |
+
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
|
| 98 |
+
if 'cond_rendered' in metadata.columns:
|
| 99 |
+
metadata = metadata[metadata['cond_rendered'] == False]
|
| 100 |
+
else:
|
| 101 |
+
if os.path.exists(opt.instances):
|
| 102 |
+
with open(opt.instances, 'r') as f:
|
| 103 |
+
instances = f.read().splitlines()
|
| 104 |
+
else:
|
| 105 |
+
instances = opt.instances.split(',')
|
| 106 |
+
metadata = metadata[metadata['sha256'].isin(instances)]
|
| 107 |
+
|
| 108 |
+
start = len(metadata) * opt.rank // opt.world_size
|
| 109 |
+
end = len(metadata) * (opt.rank + 1) // opt.world_size
|
| 110 |
+
metadata = metadata[start:end]
|
| 111 |
+
records = []
|
| 112 |
+
|
| 113 |
+
# filter out objects that are already processed
|
| 114 |
+
for sha256 in copy.copy(metadata['sha256'].values):
|
| 115 |
+
if os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')):
|
| 116 |
+
records.append({'sha256': sha256, 'cond_rendered': True})
|
| 117 |
+
metadata = metadata[metadata['sha256'] != sha256]
|
| 118 |
+
|
| 119 |
+
print(f'Processing {len(metadata)} objects...')
|
| 120 |
+
|
| 121 |
+
# process objects
|
| 122 |
+
func = partial(_render_cond, output_dir=opt.output_dir, num_views=opt.num_views)
|
| 123 |
+
cond_rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects')
|
| 124 |
+
cond_rendered = pd.concat([cond_rendered, pd.DataFrame.from_records(records)])
|
| 125 |
+
cond_rendered.to_csv(os.path.join(opt.output_dir, f'cond_rendered_{opt.rank}.csv'), index=False)
|
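The radius/FOV sampling in _render_cond couples the two so that a bounding sphere of radius sqrt(3)/2 (the sphere circumscribing the unit cube the assets are normalized to) always spans the vertical field of view; drawing k = 1/r^2 uniformly then spreads the cameras between the 10 and 70 degree extremes. A small numpy check of that geometry, reusing the same constants:

import numpy as np

fov_min, fov_max = 10, 70                                    # degrees, as in the script
radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi)
radius_max = np.sqrt(3) / 2 / np.sin(fov_min / 360 * np.pi)

ks = np.random.uniform(1 / radius_max**2, 1 / radius_min**2, 5)
for k in ks:
    r = 1 / np.sqrt(k)
    fov = 2 * np.arcsin(np.sqrt(3) / 2 / r)                  # radians
    # at this distance the sqrt(3)/2 bounding sphere subtends exactly the FOV
    print(f"radius={r:.3f}  fov={np.degrees(fov):.1f} deg")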
dataset_toolkits/setup.sh
ADDED
|
@@ -0,0 +1 @@
pip install pillow imageio imageio-ffmpeg tqdm easydict opencv-python-headless pandas open3d objaverse huggingface_hub
dataset_toolkits/stat_latent.py
ADDED
|
@@ -0,0 +1,66 @@
import os
import json
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--model', type=str, default='dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16',
                        help='Latent model to use')
    parser.add_argument('--num_samples', type=int, default=50000,
                        help='Number of samples to use for calculating stats')
    opt = parser.parse_args()
    opt = edict(vars(opt))

    # get file list
    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    else:
        raise ValueError('metadata.csv not found')
    if opt.filter_low_aesthetic_score is not None:
        metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
    metadata = metadata[metadata[f'latent_{opt.model}'] == True]
    sha256s = metadata['sha256'].values
    sha256s = np.random.choice(sha256s, min(opt.num_samples, len(sha256s)), replace=False)

    # stats
    means = []
    mean2s = []
    with ThreadPoolExecutor(max_workers=16) as executor, \
        tqdm(total=len(sha256s), desc="Extracting features") as pbar:
        def worker(sha256):
            try:
                feats = np.load(os.path.join(opt.output_dir, 'latents', opt.model, f'{sha256}.npz'))
                feats = feats['feats']
                means.append(feats.mean(axis=0))
                mean2s.append((feats ** 2).mean(axis=0))
                pbar.update()
            except Exception as e:
                print(f"Error extracting features for {sha256}: {e}")
                pbar.update()

        executor.map(worker, sha256s)
        executor.shutdown(wait=True)

    mean = np.array(means).mean(axis=0)
    mean2 = np.array(mean2s).mean(axis=0)
    std = np.sqrt(mean2 - mean ** 2)

    print('mean:', mean)
    print('std:', std)

    with open(os.path.join(opt.output_dir, 'latents', opt.model, 'stats.json'), 'w') as f:
        json.dump({
            'mean': mean.tolist(),
            'std': std.tolist(),
        }, f, indent=4)
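The statistics above are accumulated as per-object first and second moments, with std recovered as sqrt(mean2 - mean^2). A quick check of that identity on hypothetical random data, exact when every object contributes the same number of latent samples:

import numpy as np

# Var[x] = E[x^2] - E[x]^2, applied chunk-wise with equal chunk sizes.
chunks = [np.random.randn(128, 8) for _ in range(10)]
mean  = np.mean([c.mean(axis=0) for c in chunks], axis=0)
mean2 = np.mean([(c ** 2).mean(axis=0) for c in chunks], axis=0)
std = np.sqrt(mean2 - mean ** 2)

all_feats = np.concatenate(chunks, axis=0)
assert np.allclose(std, all_feats.std(axis=0))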
dataset_toolkits/utils.py
ADDED
|
@@ -0,0 +1,43 @@
from typing import *
import hashlib
import numpy as np


def get_file_hash(file: str) -> str:
    sha256 = hashlib.sha256()
    # Read the file from the path
    with open(file, "rb") as f:
        # Update the hash with the file content
        for byte_block in iter(lambda: f.read(4096), b""):
            sha256.update(byte_block)
    return sha256.hexdigest()

# ===============LOW DISCREPANCY SEQUENCES================

PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]

def radical_inverse(base, n):
    val = 0
    inv_base = 1.0 / base
    inv_base_n = inv_base
    while n > 0:
        digit = n % base
        val += digit * inv_base_n
        n //= base
        inv_base_n *= inv_base
    return val

def halton_sequence(dim, n):
    return [radical_inverse(PRIMES[dim], n) for dim in range(dim)]

def hammersley_sequence(dim, n, num_samples):
    return [n / num_samples] + halton_sequence(dim - 1, n)

def sphere_hammersley_sequence(n, num_samples, offset=(0, 0)):
    u, v = hammersley_sequence(2, n, num_samples)
    u += offset[0] / num_samples
    v += offset[1]
    u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
    theta = np.arccos(1 - 2 * u) - np.pi / 2
    phi = v * 2 * np.pi
    return [phi, theta]
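As a quick illustration, sphere_hammersley_sequence is consumed directly as (yaw, pitch) pairs by the render scripts; this sketch assumes dataset_toolkits is on sys.path:

import numpy as np
from utils import sphere_hammersley_sequence  # assumes dataset_toolkits/ is importable

num_views = 8
offset = (np.random.rand(), np.random.rand())   # per-object jitter, as in render.py
views = [sphere_hammersley_sequence(i, num_views, offset) for i in range(num_views)]
for yaw, pitch in views:
    print(f"yaw={yaw:.3f} rad  pitch={pitch:.3f} rad")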
dataset_toolkits/voxelize.py
ADDED
|
@@ -0,0 +1,86 @@
import os
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
import numpy as np
import open3d as o3d
import utils3d


def _voxelize(file, sha256, output_dir):
    mesh = o3d.io.read_triangle_mesh(os.path.join(output_dir, 'renders', sha256, 'mesh.ply'))
    # clamp vertices to the range [-0.5, 0.5]
    vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
    mesh.vertices = o3d.utility.Vector3dVector(vertices)
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(mesh, voxel_size=1/64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
    vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
    assert np.all(vertices >= 0) and np.all(vertices < 64), "Some vertices are out of bounds"
    vertices = (vertices + 0.5) / 64 - 0.5
    utils3d.io.write_ply(os.path.join(output_dir, 'voxels', f'{sha256}.ply'), vertices)
    return {'sha256': sha256, 'voxelized': True, 'num_voxels': len(vertices)}


if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--num_views', type=int, default=150,
                        help='Number of views to render')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=None)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(os.path.join(opt.output_dir, 'voxels'), exist_ok=True)

    # get file list
    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    if opt.instances is None:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'rendered' not in metadata.columns:
            raise ValueError('metadata.csv does not have "rendered" column, please run "build_metadata.py" first')
        metadata = metadata[metadata['rendered'] == True]
        if 'voxelized' in metadata.columns:
            metadata = metadata[metadata['voxelized'] == False]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    for sha256 in copy.copy(metadata['sha256'].values):
        if os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')):
            pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
            records.append({'sha256': sha256, 'voxelized': True, 'num_voxels': len(pts)})
            metadata = metadata[metadata['sha256'] != sha256]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    func = partial(_voxelize, output_dir=opt.output_dir)
    voxelized = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Voxelizing')
    voxelized = pd.concat([voxelized, pd.DataFrame.from_records(records)])
    voxelized.to_csv(os.path.join(opt.output_dir, f'voxelized_{opt.rank}.csv'), index=False)
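The voxel convention used here pairs with extract_feature.py above: grid indices in [0, 64) are written out as voxel-center positions in [-0.5, 0.5) and recovered later with the inverse mapping. A small round-trip check:

import numpy as np

grid_index = np.array([[0, 31, 63]])
position = (grid_index + 0.5) / 64 - 0.5               # what voxelize.py writes to the .ply
recovered = ((position + 0.5) * 64).astype(np.int64)   # what extract_feature.py computes
assert (recovered == grid_index).all()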
env.py
CHANGED
|
@@ -1,10 +1,10 @@
-import spaces
-import os
-import torch
-
-@spaces.GPU(duration=5)
-def check():
-    print(os.system("nvidia-smi"))
-    print(torch.version.cuda)
-
-check()
+import spaces
+import os
+import torch
+
+@spaces.GPU(duration=5)
+def check():
+    print(os.system("nvidia-smi"))
+    print(torch.version.cuda)
+
+check()
extensions/vox2seq/benchmark.py
ADDED
|
@@ -0,0 +1,45 @@
import time
import torch
import vox2seq


if __name__ == "__main__":
    stats = {
        'z_order_cuda': [],
        'z_order_pytorch': [],
        'hilbert_cuda': [],
        'hilbert_pytorch': [],
    }
    RES = [16, 32, 64, 128, 256]
    for res in RES:
        coords = torch.meshgrid(torch.arange(res), torch.arange(res), torch.arange(res))
        coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda()

        start = time.time()
        for _ in range(100):
            code_z_cuda = vox2seq.encode(coords, mode='z_order').cuda()
        torch.cuda.synchronize()
        stats['z_order_cuda'].append((time.time() - start) / 100)

        start = time.time()
        for _ in range(100):
            code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order').cuda()
        torch.cuda.synchronize()
        stats['z_order_pytorch'].append((time.time() - start) / 100)

        start = time.time()
        for _ in range(100):
            code_h_cuda = vox2seq.encode(coords, mode='hilbert').cuda()
        torch.cuda.synchronize()
        stats['hilbert_cuda'].append((time.time() - start) / 100)

        start = time.time()
        for _ in range(100):
            code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert').cuda()
        torch.cuda.synchronize()
        stats['hilbert_pytorch'].append((time.time() - start) / 100)

    print(f"{'Resolution':<12}{'Z-Order (CUDA)':<24}{'Z-Order (PyTorch)':<24}{'Hilbert (CUDA)':<24}{'Hilbert (PyTorch)':<24}")
    for res, z_order_cuda, z_order_pytorch, hilbert_cuda, hilbert_pytorch in zip(RES, stats['z_order_cuda'], stats['z_order_pytorch'], stats['hilbert_cuda'], stats['hilbert_pytorch']):
        print(f"{res:<12}{z_order_cuda:<24.6f}{z_order_pytorch:<24.6f}{hilbert_cuda:<24.6f}{hilbert_pytorch:<24.6f}")
extensions/vox2seq/setup.py
ADDED
|
@@ -0,0 +1,34 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#

from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
import os
os.path.dirname(os.path.abspath(__file__))

setup(
    name="vox2seq",
    packages=['vox2seq', 'vox2seq.pytorch'],
    ext_modules=[
        CUDAExtension(
            name="vox2seq._C",
            sources=[
                "src/api.cu",
                "src/z_order.cu",
                "src/hilbert.cu",
                "src/ext.cpp",
            ],
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)
extensions/vox2seq/src/api.cu
ADDED
|
@@ -0,0 +1,92 @@
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
#include "api.h"
|
| 3 |
+
#include "z_order.h"
|
| 4 |
+
#include "hilbert.h"
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
torch::Tensor
|
| 8 |
+
z_order_encode(
|
| 9 |
+
const torch::Tensor& x,
|
| 10 |
+
const torch::Tensor& y,
|
| 11 |
+
const torch::Tensor& z
|
| 12 |
+
) {
|
| 13 |
+
// Allocate output tensor
|
| 14 |
+
torch::Tensor codes = torch::empty_like(x);
|
| 15 |
+
|
| 16 |
+
// Call CUDA kernel
|
| 17 |
+
z_order_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
|
| 18 |
+
x.size(0),
|
| 19 |
+
reinterpret_cast<uint32_t*>(x.contiguous().data_ptr<int>()),
|
| 20 |
+
reinterpret_cast<uint32_t*>(y.contiguous().data_ptr<int>()),
|
| 21 |
+
reinterpret_cast<uint32_t*>(z.contiguous().data_ptr<int>()),
|
| 22 |
+
reinterpret_cast<uint32_t*>(codes.data_ptr<int>())
|
| 23 |
+
);
|
| 24 |
+
|
| 25 |
+
return codes;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
|
| 30 |
+
z_order_decode(
|
| 31 |
+
const torch::Tensor& codes
|
| 32 |
+
) {
|
| 33 |
+
// Allocate output tensors
|
| 34 |
+
torch::Tensor x = torch::empty_like(codes);
|
| 35 |
+
torch::Tensor y = torch::empty_like(codes);
|
| 36 |
+
torch::Tensor z = torch::empty_like(codes);
|
| 37 |
+
|
| 38 |
+
// Call CUDA kernel
|
| 39 |
+
z_order_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
|
| 40 |
+
codes.size(0),
|
| 41 |
+
reinterpret_cast<uint32_t*>(codes.contiguous().data_ptr<int>()),
|
| 42 |
+
reinterpret_cast<uint32_t*>(x.data_ptr<int>()),
|
| 43 |
+
reinterpret_cast<uint32_t*>(y.data_ptr<int>()),
|
| 44 |
+
reinterpret_cast<uint32_t*>(z.data_ptr<int>())
|
| 45 |
+
);
|
| 46 |
+
|
| 47 |
+
return std::make_tuple(x, y, z);
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
torch::Tensor
|
| 52 |
+
hilbert_encode(
|
| 53 |
+
const torch::Tensor& x,
|
| 54 |
+
const torch::Tensor& y,
|
| 55 |
+
const torch::Tensor& z
|
| 56 |
+
) {
|
| 57 |
+
// Allocate output tensor
|
| 58 |
+
torch::Tensor codes = torch::empty_like(x);
|
| 59 |
+
|
| 60 |
+
// Call CUDA kernel
|
| 61 |
+
hilbert_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
|
| 62 |
+
x.size(0),
|
| 63 |
+
reinterpret_cast<uint32_t*>(x.contiguous().data_ptr<int>()),
|
| 64 |
+
reinterpret_cast<uint32_t*>(y.contiguous().data_ptr<int>()),
|
| 65 |
+
reinterpret_cast<uint32_t*>(z.contiguous().data_ptr<int>()),
|
| 66 |
+
reinterpret_cast<uint32_t*>(codes.data_ptr<int>())
|
| 67 |
+
);
|
| 68 |
+
|
| 69 |
+
return codes;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
|
| 74 |
+
hilbert_decode(
|
| 75 |
+
const torch::Tensor& codes
|
| 76 |
+
) {
|
| 77 |
+
// Allocate output tensors
|
| 78 |
+
torch::Tensor x = torch::empty_like(codes);
|
| 79 |
+
torch::Tensor y = torch::empty_like(codes);
|
| 80 |
+
torch::Tensor z = torch::empty_like(codes);
|
| 81 |
+
|
| 82 |
+
// Call CUDA kernel
|
| 83 |
+
hilbert_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
|
| 84 |
+
codes.size(0),
|
| 85 |
+
reinterpret_cast<uint32_t*>(codes.contiguous().data_ptr<int>()),
|
| 86 |
+
reinterpret_cast<uint32_t*>(x.data_ptr<int>()),
|
| 87 |
+
reinterpret_cast<uint32_t*>(y.data_ptr<int>()),
|
| 88 |
+
reinterpret_cast<uint32_t*>(z.data_ptr<int>())
|
| 89 |
+
);
|
| 90 |
+
|
| 91 |
+
return std::make_tuple(x, y, z);
|
| 92 |
+
}
|
extensions/vox2seq/src/api.h
ADDED
|
@@ -0,0 +1,76 @@
| 1 |
+
/*
|
| 2 |
+
* Serialize a voxel grid
|
| 3 |
+
*
|
| 4 |
+
* Copyright (C) 2024, Jianfeng XIANG <belljig@outlook.com>
|
| 5 |
+
* All rights reserved.
|
| 6 |
+
*
|
| 7 |
+
* Licensed under The MIT License [see LICENSE for details]
|
| 8 |
+
*
|
| 9 |
+
* Written by Jianfeng XIANG
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#pragma once
|
| 13 |
+
#include <torch/extension.h>
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
#define BLOCK_SIZE 256
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
/**
|
| 20 |
+
* Z-order encode 3D points
|
| 21 |
+
*
|
| 22 |
+
* @param x [N] tensor containing the x coordinates
|
| 23 |
+
* @param y [N] tensor containing the y coordinates
|
| 24 |
+
* @param z [N] tensor containing the z coordinates
|
| 25 |
+
*
|
| 26 |
+
* @return [N] tensor containing the z-order encoded values
|
| 27 |
+
*/
|
| 28 |
+
torch::Tensor
|
| 29 |
+
z_order_encode(
|
| 30 |
+
const torch::Tensor& x,
|
| 31 |
+
const torch::Tensor& y,
|
| 32 |
+
const torch::Tensor& z
|
| 33 |
+
);
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
/**
|
| 37 |
+
* Z-order decode 3D points
|
| 38 |
+
*
|
| 39 |
+
* @param codes [N] tensor containing the z-order encoded values
|
| 40 |
+
*
|
| 41 |
+
* @return 3 tensors [N] containing the x, y, z coordinates
|
| 42 |
+
*/
|
| 43 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
|
| 44 |
+
z_order_decode(
|
| 45 |
+
const torch::Tensor& codes
|
| 46 |
+
);
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
/**
|
| 50 |
+
* Hilbert encode 3D points
|
| 51 |
+
*
|
| 52 |
+
* @param x [N] tensor containing the x coordinates
|
| 53 |
+
* @param y [N] tensor containing the y coordinates
|
| 54 |
+
* @param z [N] tensor containing the z coordinates
|
| 55 |
+
*
|
| 56 |
+
* @return [N] tensor containing the Hilbert encoded values
|
| 57 |
+
*/
|
| 58 |
+
torch::Tensor
|
| 59 |
+
hilbert_encode(
|
| 60 |
+
const torch::Tensor& x,
|
| 61 |
+
const torch::Tensor& y,
|
| 62 |
+
const torch::Tensor& z
|
| 63 |
+
);
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
/**
|
| 67 |
+
* Hilbert decode 3D points
|
| 68 |
+
*
|
| 69 |
+
* @param codes [N] tensor containing the Hilbert encoded values
|
| 70 |
+
*
|
| 71 |
+
* @return 3 tensors [N] containing the x, y, z coordinates
|
| 72 |
+
*/
|
| 73 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
|
| 74 |
+
hilbert_decode(
|
| 75 |
+
const torch::Tensor& codes
|
| 76 |
+
);
|
extensions/vox2seq/src/ext.cpp
ADDED
|
@@ -0,0 +1,10 @@
#include <torch/extension.h>
#include "api.h"


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("z_order_encode", &z_order_encode);
    m.def("z_order_decode", &z_order_decode);
    m.def("hilbert_encode", &hilbert_encode);
    m.def("hilbert_decode", &hilbert_decode);
}
extensions/vox2seq/src/hilbert.cu
ADDED
|
@@ -0,0 +1,133 @@
| 1 |
+
#include <cuda.h>
|
| 2 |
+
#include <cuda_runtime.h>
|
| 3 |
+
#include <device_launch_parameters.h>
|
| 4 |
+
|
| 5 |
+
#include <cooperative_groups.h>
|
| 6 |
+
#include <cooperative_groups/memcpy_async.h>
|
| 7 |
+
namespace cg = cooperative_groups;
|
| 8 |
+
|
| 9 |
+
#include "hilbert.h"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
// Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit.
|
| 13 |
+
static __device__ uint32_t expandBits(uint32_t v)
|
| 14 |
+
{
|
| 15 |
+
v = (v * 0x00010001u) & 0xFF0000FFu;
|
| 16 |
+
v = (v * 0x00000101u) & 0x0F00F00Fu;
|
| 17 |
+
v = (v * 0x00000011u) & 0xC30C30C3u;
|
| 18 |
+
v = (v * 0x00000005u) & 0x49249249u;
|
| 19 |
+
return v;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
// Removes 2 zeros after each bit in a 30-bit integer.
|
| 24 |
+
static __device__ uint32_t extractBits(uint32_t v)
|
| 25 |
+
{
|
| 26 |
+
v = v & 0x49249249;
|
| 27 |
+
v = (v ^ (v >> 2)) & 0x030C30C3u;
|
| 28 |
+
v = (v ^ (v >> 4)) & 0x0300F00Fu;
|
| 29 |
+
v = (v ^ (v >> 8)) & 0x030000FFu;
|
| 30 |
+
v = (v ^ (v >> 16)) & 0x000003FFu;
|
| 31 |
+
return v;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
__global__ void hilbert_encode_cuda(
|
| 36 |
+
size_t N,
|
| 37 |
+
const uint32_t* x,
|
| 38 |
+
const uint32_t* y,
|
| 39 |
+
const uint32_t* z,
|
| 40 |
+
uint32_t* codes
|
| 41 |
+
) {
|
| 42 |
+
size_t thread_id = cg::this_grid().thread_rank();
|
| 43 |
+
if (thread_id >= N) return;
|
| 44 |
+
|
| 45 |
+
uint32_t point[3] = {x[thread_id], y[thread_id], z[thread_id]};
|
| 46 |
+
|
| 47 |
+
uint32_t m = 1 << 9, q, p, t;
|
| 48 |
+
|
| 49 |
+
// Inverse undo excess work
|
| 50 |
+
q = m;
|
| 51 |
+
while (q > 1) {
|
| 52 |
+
p = q - 1;
|
| 53 |
+
for (int i = 0; i < 3; i++) {
|
| 54 |
+
if (point[i] & q) {
|
| 55 |
+
point[0] ^= p; // invert
|
| 56 |
+
} else {
|
| 57 |
+
t = (point[0] ^ point[i]) & p;
|
| 58 |
+
point[0] ^= t;
|
| 59 |
+
point[i] ^= t;
|
| 60 |
+
}
|
| 61 |
+
}
|
| 62 |
+
q >>= 1;
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
// Gray encode
|
| 66 |
+
for (int i = 1; i < 3; i++) {
|
| 67 |
+
point[i] ^= point[i - 1];
|
| 68 |
+
}
|
| 69 |
+
t = 0;
|
| 70 |
+
q = m;
|
| 71 |
+
while (q > 1) {
|
| 72 |
+
if (point[2] & q) {
|
| 73 |
+
t ^= q - 1;
|
| 74 |
+
}
|
| 75 |
+
q >>= 1;
|
| 76 |
+
}
|
| 77 |
+
for (int i = 0; i < 3; i++) {
|
| 78 |
+
point[i] ^= t;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
// Convert to 3D Hilbert code
|
| 82 |
+
uint32_t xx = expandBits(point[0]);
|
| 83 |
+
uint32_t yy = expandBits(point[1]);
|
| 84 |
+
uint32_t zz = expandBits(point[2]);
|
| 85 |
+
|
| 86 |
+
codes[thread_id] = xx * 4 + yy * 2 + zz;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
__global__ void hilbert_decode_cuda(
|
| 91 |
+
size_t N,
|
| 92 |
+
const uint32_t* codes,
|
| 93 |
+
uint32_t* x,
|
| 94 |
+
uint32_t* y,
|
| 95 |
+
uint32_t* z
|
| 96 |
+
) {
|
| 97 |
+
size_t thread_id = cg::this_grid().thread_rank();
|
| 98 |
+
if (thread_id >= N) return;
|
| 99 |
+
|
| 100 |
+
uint32_t point[3];
|
| 101 |
+
point[0] = extractBits(codes[thread_id] >> 2);
|
| 102 |
+
point[1] = extractBits(codes[thread_id] >> 1);
|
| 103 |
+
point[2] = extractBits(codes[thread_id]);
|
| 104 |
+
|
| 105 |
+
uint32_t m = 2 << 9, q, p, t;
|
| 106 |
+
|
| 107 |
+
// Gray decode by H ^ (H/2)
|
| 108 |
+
t = point[2] >> 1;
|
| 109 |
+
for (int i = 2; i > 0; i--) {
|
| 110 |
+
point[i] ^= point[i - 1];
|
| 111 |
+
}
|
| 112 |
+
point[0] ^= t;
|
| 113 |
+
|
| 114 |
+
// Undo excess work
|
| 115 |
+
q = 2;
|
| 116 |
+
while (q != m) {
|
| 117 |
+
p = q - 1;
|
| 118 |
+
for (int i = 2; i >= 0; i--) {
|
| 119 |
+
if (point[i] & q) {
|
| 120 |
+
point[0] ^= p;
|
| 121 |
+
} else {
|
| 122 |
+
t = (point[0] ^ point[i]) & p;
|
| 123 |
+
point[0] ^= t;
|
| 124 |
+
point[i] ^= t;
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
q <<= 1;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
x[thread_id] = point[0];
|
| 131 |
+
y[thread_id] = point[1];
|
| 132 |
+
z[thread_id] = point[2];
|
| 133 |
+
}
|
extensions/vox2seq/src/hilbert.h
ADDED
|
@@ -0,0 +1,35 @@
#pragma once

/**
 * Hilbert encode 3D points
 *
 * @param x [N] tensor containing the x coordinates
 * @param y [N] tensor containing the y coordinates
 * @param z [N] tensor containing the z coordinates
 *
 * @return [N] tensor containing the Hilbert encoded values
 */
__global__ void hilbert_encode_cuda(
    size_t N,
    const uint32_t* x,
    const uint32_t* y,
    const uint32_t* z,
    uint32_t* codes
);


/**
 * Hilbert decode 3D points
 *
 * @param codes [N] tensor containing the Hilbert encoded values
 * @param x [N] tensor containing the x coordinates
 * @param y [N] tensor containing the y coordinates
 * @param z [N] tensor containing the z coordinates
 */
__global__ void hilbert_decode_cuda(
    size_t N,
    const uint32_t* codes,
    uint32_t* x,
    uint32_t* y,
    uint32_t* z
);
|
extensions/vox2seq/src/z_order.cu
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <cuda.h>
|
| 2 |
+
#include <cuda_runtime.h>
|
| 3 |
+
#include <device_launch_parameters.h>
|
| 4 |
+
|
| 5 |
+
#include <cooperative_groups.h>
|
| 6 |
+
#include <cooperative_groups/memcpy_async.h>
|
| 7 |
+
namespace cg = cooperative_groups;
|
| 8 |
+
|
| 9 |
+
#include "z_order.h"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
// Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit.
|
| 13 |
+
static __device__ uint32_t expandBits(uint32_t v)
|
| 14 |
+
{
|
| 15 |
+
v = (v * 0x00010001u) & 0xFF0000FFu;
|
| 16 |
+
v = (v * 0x00000101u) & 0x0F00F00Fu;
|
| 17 |
+
v = (v * 0x00000011u) & 0xC30C30C3u;
|
| 18 |
+
v = (v * 0x00000005u) & 0x49249249u;
|
| 19 |
+
return v;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
// Removes 2 zeros after each bit in a 30-bit integer.
|
| 24 |
+
static __device__ uint32_t extractBits(uint32_t v)
|
| 25 |
+
{
|
| 26 |
+
v = v & 0x49249249;
|
| 27 |
+
v = (v ^ (v >> 2)) & 0x030C30C3u;
|
| 28 |
+
v = (v ^ (v >> 4)) & 0x0300F00Fu;
|
| 29 |
+
v = (v ^ (v >> 8)) & 0x030000FFu;
|
| 30 |
+
v = (v ^ (v >> 16)) & 0x000003FFu;
|
| 31 |
+
return v;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
__global__ void z_order_encode_cuda(
|
| 36 |
+
size_t N,
|
| 37 |
+
const uint32_t* x,
|
| 38 |
+
const uint32_t* y,
|
| 39 |
+
const uint32_t* z,
|
| 40 |
+
uint32_t* codes
|
| 41 |
+
) {
|
| 42 |
+
size_t thread_id = cg::this_grid().thread_rank();
|
| 43 |
+
if (thread_id >= N) return;
|
| 44 |
+
|
| 45 |
+
uint32_t xx = expandBits(x[thread_id]);
|
| 46 |
+
uint32_t yy = expandBits(y[thread_id]);
|
| 47 |
+
uint32_t zz = expandBits(z[thread_id]);
|
| 48 |
+
|
| 49 |
+
codes[thread_id] = xx * 4 + yy * 2 + zz;
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
__global__ void z_order_decode_cuda(
|
| 54 |
+
size_t N,
|
| 55 |
+
const uint32_t* codes,
|
| 56 |
+
uint32_t* x,
|
| 57 |
+
uint32_t* y,
|
| 58 |
+
uint32_t* z
|
| 59 |
+
) {
|
| 60 |
+
size_t thread_id = cg::this_grid().thread_rank();
|
| 61 |
+
if (thread_id >= N) return;
|
| 62 |
+
|
| 63 |
+
x[thread_id] = extractBits(codes[thread_id] >> 2);
|
| 64 |
+
y[thread_id] = extractBits(codes[thread_id] >> 1);
|
| 65 |
+
z[thread_id] = extractBits(codes[thread_id]);
|
| 66 |
+
}
|
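The expandBits routine above is the standard multiply-and-mask bit interleaving for 10-bit inputs; a pure-Python reference that checks it against a naive per-bit placement:

def expand_bits(v: int) -> int:
    # same constants as the CUDA kernel; masks keep only the relevant 32 bits
    v = (v * 0x00010001) & 0xFF0000FF
    v = (v * 0x00000101) & 0x0F00F00F
    v = (v * 0x00000011) & 0xC30C30C3
    v = (v * 0x00000005) & 0x49249249
    return v

def expand_bits_naive(v: int) -> int:
    # place bit i of v at position 3*i
    return sum(((v >> i) & 1) << (3 * i) for i in range(10))

for v in range(1024):
    assert expand_bits(v) == expand_bits_naive(v)
print("expandBits matches naive bit interleaving for all 10-bit inputs")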
extensions/vox2seq/src/z_order.h
ADDED
|
@@ -0,0 +1,35 @@
#pragma once

/**
 * Z-order encode 3D points
 *
 * @param x [N] tensor containing the x coordinates
 * @param y [N] tensor containing the y coordinates
 * @param z [N] tensor containing the z coordinates
 *
 * @return [N] tensor containing the z-order encoded values
 */
__global__ void z_order_encode_cuda(
    size_t N,
    const uint32_t* x,
    const uint32_t* y,
    const uint32_t* z,
    uint32_t* codes
);


/**
 * Z-order decode 3D points
 *
 * @param codes [N] tensor containing the z-order encoded values
 * @param x [N] tensor containing the x coordinates
 * @param y [N] tensor containing the y coordinates
 * @param z [N] tensor containing the z coordinates
 */
__global__ void z_order_decode_cuda(
    size_t N,
    const uint32_t* codes,
    uint32_t* x,
    uint32_t* y,
    uint32_t* z
);
|
extensions/vox2seq/test.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import vox2seq
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
if __name__ == "__main__":
|
| 6 |
+
RES = 256
|
| 7 |
+
coords = torch.meshgrid(torch.arange(RES), torch.arange(RES), torch.arange(RES))
|
| 8 |
+
coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda()
|
| 9 |
+
code_z_cuda = vox2seq.encode(coords, mode='z_order')
|
| 10 |
+
code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order')
|
| 11 |
+
code_h_cuda = vox2seq.encode(coords, mode='hilbert')
|
| 12 |
+
code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert')
|
| 13 |
+
assert torch.equal(code_z_cuda, code_z_pytorch)
|
| 14 |
+
assert torch.equal(code_h_cuda, code_h_pytorch)
|
| 15 |
+
|
| 16 |
+
code = torch.arange(RES**3).int().cuda()
|
| 17 |
+
coords_z_cuda = vox2seq.decode(code, mode='z_order')
|
| 18 |
+
coords_z_pytorch = vox2seq.pytorch.decode(code, mode='z_order')
|
| 19 |
+
coords_h_cuda = vox2seq.decode(code, mode='hilbert')
|
| 20 |
+
coords_h_pytorch = vox2seq.pytorch.decode(code, mode='hilbert')
|
| 21 |
+
assert torch.equal(coords_z_cuda, coords_z_pytorch)
|
| 22 |
+
assert torch.equal(coords_h_cuda, coords_h_pytorch)
|
| 23 |
+
|
| 24 |
+
print("All tests passed.")
|
| 25 |
+
|
extensions/vox2seq/vox2seq/__init__.py
ADDED
|
@@ -0,0 +1,50 @@
from typing import *
import torch
from . import _C
from . import pytorch


@torch.no_grad()
def encode(coords: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
    """
    Encodes 3D coordinates into a 30-bit code.

    Args:
        coords: a tensor of shape [N, 3] containing the 3D coordinates.
        permute: the permutation of the coordinates.
        mode: the encoding mode to use.
    """
    assert coords.shape[-1] == 3 and coords.ndim == 2, "Input coordinates must be of shape [N, 3]"
    x = coords[:, permute[0]].int()
    y = coords[:, permute[1]].int()
    z = coords[:, permute[2]].int()
    if mode == 'z_order':
        return _C.z_order_encode(x, y, z)
    elif mode == 'hilbert':
        return _C.hilbert_encode(x, y, z)
    else:
        raise ValueError(f"Unknown encoding mode: {mode}")


@torch.no_grad()
def decode(code: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
    """
    Decodes a 30-bit code into 3D coordinates.

    Args:
        code: a tensor of shape [N] containing the 30-bit code.
        permute: the permutation of the coordinates.
        mode: the decoding mode to use.
    """
    assert code.ndim == 1, "Input code must be of shape [N]"
    if mode == 'z_order':
        coords = _C.z_order_decode(code)
    elif mode == 'hilbert':
        coords = _C.hilbert_decode(code)
    else:
        raise ValueError(f"Unknown decoding mode: {mode}")
    x = coords[permute.index(0)]
    y = coords[permute.index(1)]
    z = coords[permute.index(2)]
    return torch.stack([x, y, z], dim=-1)
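A typical use of this API is ordering sparse voxel coordinates along a space-filling curve before serialization; a short sketch, assuming the compiled _C extension is available, a CUDA device is present, and coordinates fit in 10 bits per axis:

import torch
import vox2seq

coords = torch.randint(0, 64, (10000, 3), dtype=torch.int32, device='cuda')
code = vox2seq.encode(coords, mode='hilbert')   # [N] curve index per voxel
coords_sorted = coords[torch.argsort(code)]     # spatially nearby voxels end up adjacent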
extensions/vox2seq/vox2seq/pytorch/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
import torch
from typing import *

from .default import (
    encode,
    decode,
    z_order_encode,
    z_order_decode,
    hilbert_encode,
    hilbert_decode,
)


@torch.no_grad()
def encode(coords: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
    """
    Encodes 3D coordinates into a 30-bit code.

    Args:
        coords: a tensor of shape [N, 3] containing the 3D coordinates.
        permute: the permutation of the coordinates.
        mode: the encoding mode to use.
    """
    if mode == 'z_order':
        return z_order_encode(coords[:, permute], depth=10).int()
    elif mode == 'hilbert':
        return hilbert_encode(coords[:, permute], depth=10).int()
    else:
        raise ValueError(f"Unknown encoding mode: {mode}")


@torch.no_grad()
def decode(code: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
    """
    Decodes a 30-bit code into 3D coordinates.

    Args:
        code: a tensor of shape [N] containing the 30-bit code.
        permute: the permutation of the coordinates.
        mode: the decoding mode to use.
    """
    if mode == 'z_order':
        return z_order_decode(code, depth=10)[:, permute].float()
    elif mode == 'hilbert':
        return hilbert_decode(code, depth=10)[:, permute].float()
    else:
        raise ValueError(f"Unknown decoding mode: {mode}")
extensions/vox2seq/vox2seq/pytorch/default.py
ADDED
|
@@ -0,0 +1,59 @@
import torch
from .z_order import xyz2key as z_order_encode_
from .z_order import key2xyz as z_order_decode_
from .hilbert import encode as hilbert_encode_
from .hilbert import decode as hilbert_decode_


@torch.inference_mode()
def encode(grid_coord, batch=None, depth=16, order="z"):
    assert order in {"z", "z-trans", "hilbert", "hilbert-trans"}
    if order == "z":
        code = z_order_encode(grid_coord, depth=depth)
    elif order == "z-trans":
        code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
    elif order == "hilbert":
        code = hilbert_encode(grid_coord, depth=depth)
    elif order == "hilbert-trans":
        code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
    else:
        raise NotImplementedError
    if batch is not None:
        batch = batch.long()
        code = batch << depth * 3 | code
    return code


@torch.inference_mode()
def decode(code, depth=16, order="z"):
    assert order in {"z", "hilbert"}
    batch = code >> depth * 3
    code = code & ((1 << depth * 3) - 1)
    if order == "z":
        grid_coord = z_order_decode(code, depth=depth)
    elif order == "hilbert":
        grid_coord = hilbert_decode(code, depth=depth)
    else:
        raise NotImplementedError
    return grid_coord, batch


def z_order_encode(grid_coord: torch.Tensor, depth: int = 16):
    x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long()
    # we block the support to batch, maintain batched code in Point class
    code = z_order_encode_(x, y, z, b=None, depth=depth)
    return code


def z_order_decode(code: torch.Tensor, depth):
    x, y, z, _ = z_order_decode_(code, depth=depth)
    grid_coord = torch.stack([x, y, z], dim=-1)  # (N, 3)
    return grid_coord


def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16):
    return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth)


def hilbert_decode(code: torch.Tensor, depth: int = 16):
    return hilbert_decode_(code, num_dims=3, num_bits=depth)
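The batch packing in encode above shifts the batch index past the depth*3 coordinate bits, so sorting packed codes groups points by batch first and by curve order within each batch. A small demonstration with stand-in values:

import torch

depth = 16
code = torch.tensor([7, 3, 5])            # per-point curve codes (stand-in values)
batch = torch.tensor([1, 0, 1])
packed = batch.long() << depth * 3 | code
unpacked_batch = packed >> depth * 3
unpacked_code = packed & ((1 << depth * 3) - 1)
assert torch.equal(unpacked_batch, batch.long())
assert torch.equal(unpacked_code, code.long())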
extensions/vox2seq/vox2seq/pytorch/hilbert.py
ADDED
|
@@ -0,0 +1,303 @@
"""
Hilbert Order
Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu
Please cite our work if the code is helpful to you.
"""

import torch


def right_shift(binary, k=1, axis=-1):
    """Right shift an array of binary values.

    Parameters:
    -----------
    binary: An ndarray of binary values.

    k: The number of bits to shift. Default 1.

    axis: The axis along which to shift. Default -1.

    Returns:
    --------
    Returns an ndarray with zero prepended and the ends truncated, along
    whatever axis was specified."""

    # If we're shifting the whole thing, just return zeros.
    if binary.shape[axis] <= k:
        return torch.zeros_like(binary)

    # Determine the padding pattern.
    # padding = [(0,0)] * len(binary.shape)
    # padding[axis] = (k,0)

    # Determine the slicing pattern to eliminate just the last one.
    slicing = [slice(None)] * len(binary.shape)
    slicing[axis] = slice(None, -k)
    shifted = torch.nn.functional.pad(
        binary[tuple(slicing)], (k, 0), mode="constant", value=0
    )

    return shifted


def binary2gray(binary, axis=-1):
    """Convert an array of binary values into Gray codes.

    This uses the classic X ^ (X >> 1) trick to compute the Gray code.

    Parameters:
    -----------
    binary: An ndarray of binary values.

    axis: The axis along which to compute the gray code. Default=-1.

    Returns:
    --------
    Returns an ndarray of Gray codes.
    """
    shifted = right_shift(binary, axis=axis)

    # Do the X ^ (X >> 1) trick.
    gray = torch.logical_xor(binary, shifted)

    return gray


def gray2binary(gray, axis=-1):
    """Convert an array of Gray codes back into binary values.

    Parameters:
    -----------
    gray: An ndarray of gray codes.

    axis: The axis along which to perform Gray decoding. Default=-1.

    Returns:
    --------
    Returns an ndarray of binary values.
    """

    # Loop the log2(bits) number of times necessary, with shift and xor.
    shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
    while shift > 0:
        gray = torch.logical_xor(gray, right_shift(gray, shift))
        shift = torch.div(shift, 2, rounding_mode="floor")
    return gray


def encode(locs, num_dims, num_bits):
    """Encode an array of locations in a hypercube into a Hilbert integer.

    This is a vectorized-ish version of the Hilbert curve implementation by John
    Skilling as described in:

    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.

    Params:
    -------
    locs - An ndarray of locations in a hypercube of num_dims dimensions, in
           which each dimension runs from 0 to 2**num_bits-1. The shape can
           be arbitrary, as long as the last dimension has size num_dims.

    num_dims - The dimensionality of the hypercube. Integer.

    num_bits - The number of bits for each dimension. Integer.

    Returns:
    --------
    The output is an ndarray of uint64 integers with the same shape as the
    input, excluding the last dimension, which needs to be num_dims.
    """

    # Keep around the original shape for later.
    orig_shape = locs.shape
    bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
    bitpack_mask_rev = bitpack_mask.flip(-1)

    if orig_shape[-1] != num_dims:
        raise ValueError(
            """
      The shape of locs was surprising in that the last dimension was of size
      %d, but num_dims=%d. These need to be equal.
      """
            % (orig_shape[-1], num_dims)
        )

    if num_dims * num_bits > 63:
        raise ValueError(
            """
      num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
      into an int64. Are you sure you need that many points on your Hilbert
      curve?
      """
            % (num_dims, num_bits, num_dims * num_bits)
        )

    # Treat the location integers as 64-bit unsigned and then split them up into
    # a sequence of uint8s. Preserve the association by dimension.
    locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)

    # Now turn these into bits and truncate to num_bits.
    gray = (
        locs_uint8.unsqueeze(-1)
        .bitwise_and(bitpack_mask_rev)
        .ne(0)
        .byte()
        .flatten(-2, -1)[..., -num_bits:]
    )

    # Run the decoding process the other way.
    # Iterate forwards through the bits.
    for bit in range(0, num_bits):
        # Iterate forwards through the dimensions.
        for dim in range(0, num_dims):
            # Identify which ones have this bit active.
            mask = gray[:, dim, bit]

            # Where this bit is on, invert the 0 dimension for lower bits.
            gray[:, 0, bit + 1 :] = torch.logical_xor(
                gray[:, 0, bit + 1 :], mask[:, None]
            )

            # Where the bit is off, exchange the lower bits with the 0 dimension.
            to_flip = torch.logical_and(
                torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
            )
            gray[:, dim, bit + 1 :] = torch.logical_xor(
                gray[:, dim, bit + 1 :], to_flip
            )
            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)

    # Now flatten out.
    gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))

    # Convert Gray back to binary.
    hh_bin = gray2binary(gray)

    # Pad back out to 64 bits.
    extra_dims = 64 - num_bits * num_dims
    padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)

    # Convert binary values into uint8s.
    hh_uint8 = (
        (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
        .sum(2)
        .squeeze()
        .type(torch.uint8)
    )

    # Convert uint8s into uint64s.
    hh_uint64 = hh_uint8.view(torch.int64).squeeze()

    return hh_uint64


def decode(hilberts, num_dims, num_bits):
    """Decode an array of Hilbert integers into locations in a hypercube.

    This is a vectorized-ish version of the Hilbert curve implementation by John
    Skilling as described in:

    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.

    Params:
    -------
    hilberts - An ndarray of Hilbert integers. Must be an integer dtype and
               cannot have fewer bits than num_dims * num_bits.

    num_dims - The dimensionality of the hypercube. Integer.

    num_bits - The number of bits for each dimension. Integer.

    Returns:
    --------
    The output is an ndarray of unsigned integers with the same shape as hilberts
    but with an additional dimension of size num_dims.
    """

    if num_dims * num_bits > 64:
        raise ValueError(
            """
      num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
      into a uint64. Are you sure you need that many points on your Hilbert
      curve?
      """
            % (num_dims, num_bits, num_dims * num_bits)
        )

    # Handle the case where we got handed a naked integer.
    hilberts = torch.atleast_1d(hilberts)

    # Keep around the shape for later.
    orig_shape = hilberts.shape
    bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device)
    bitpack_mask_rev = bitpack_mask.flip(-1)

    # Treat each of the hilberts as a sequence of eight uint8.
    # This treats all of the inputs as uint64 and makes things uniform.
    hh_uint8 = (
        hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1)
    )

    # Turn these lists of uints into lists of bits and then truncate to the size
    # we actually need for using Skilling's procedure.
    hh_bits = (
        hh_uint8.unsqueeze(-1)
        .bitwise_and(bitpack_mask_rev)
        .ne(0)
        .byte()
        .flatten(-2, -1)[:, -num_dims * num_bits :]
    )

    # Take the sequence of bits and Gray-code it.
    gray = binary2gray(hh_bits)

    # There has got to be a better way to do this.
    # I could index them differently, but the eventual packbits likes it this way.
    gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2)

    # Iterate backwards through the bits.
    for bit in range(num_bits - 1, -1, -1):
        # Iterate backwards through the dimensions.
        for dim in range(num_dims - 1, -1, -1):
            # Identify which ones have this bit active.
            mask = gray[:, dim, bit]

            # Where this bit is on, invert the 0 dimension for lower bits.
            gray[:, 0, bit + 1 :] = torch.logical_xor(
                gray[:, 0, bit + 1 :], mask[:, None]
            )

            # Where the bit is off, exchange the lower bits with the 0 dimension.
            to_flip = torch.logical_and(
                torch.logical_not(mask[:, None]),
                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
            )
            gray[:, dim, bit + 1 :] = torch.logical_xor(
                gray[:, dim, bit + 1 :], to_flip
            )
            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)

    # Pad back out to 64 bits.
    extra_dims = 64 - num_bits
    padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0)

    # Now chop these up into blocks of 8.
    locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8))

    # Take those blocks and turn them into uint8s.
    # from IPython import embed; embed()
    locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8)

    # Finally, treat these as uint64s.
    flat_locs = locs_uint8.view(torch.int64)

    # Return them in the expected shape.
    return flat_locs.reshape((*orig_shape, num_dims))
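As a sanity check, encode and decode above should be mutual inverses for coordinates that fit within the requested bit budget. The snippet below is a rough sketch under that assumption; the point count and bit depth are arbitrary example values, not taken from the repository or its tests.

    import torch

    locs = torch.randint(0, 2 ** 10, (1000, 3))         # 3-D points, 10 bits per axis
    codes = encode(locs, num_dims=3, num_bits=10)        # (1000,) int64 Hilbert integers
    recovered = decode(codes, num_dims=3, num_bits=10)   # (1000, 3) locations
    assert torch.equal(recovered, locs)                  # round trip recovers the inputs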