| | """ |
| | 3D Voxel Shape Classifier — Complete Geometric Primitive Vocabulary |
| | 5×5×5 binary voxel grid → rigid cascade → curvature analysis → classify |
| | |
| | 38 shape classes covering: |
| | - Rigid 0D-3D: points, lines, joints, triangles, quads, polyhedra, prisms |
| | - Curved 1D: arcs, helices |
| | - Curved 2D: circles, ellipses, discs |
| | - Curved 3D solid: sphere, hemisphere, cylinder, cone, capsule, torus |
| | - Curved 3D hollow: shell, tube |
| | - Curved 3D open: bowl (concave), saddle (hyperbolic) |
| | |
| | Curvature types: none, convex, concave, cylindrical, conical, toroidal, hyperbolic, helical |
| | """ |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from typing import Optional |
| | import math |
| | from itertools import combinations |
| |
|
| |
|
| | |
| |
|
class SwiGLU(nn.Module):
    """Gated linear unit whose gate is a SiLU (Swish) activation.

        out = (x W1 + b1) * SiLU(x W2 + b2),    SiLU(z) = z * sigmoid(z)

    Both projections are ``nn.Linear(in_dim, out_dim)``, so the last dimension
    of ``x`` must be ``in_dim`` and the output's last dimension is ``out_dim``.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Creation order (w1, then w2) determines parameter initialisation
        # under a fixed torch seed and the state_dict key names.
        self.w1 = nn.Linear(in_dim, out_dim)
        self.w2 = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        gate = F.silu(self.w2(x))
        value = self.w1(x)
        return value * gate
| |
|
| |
|
| | |
| |
|
# Shape vocabulary: class name -> {"dim", "curved", "curvature"}.
# Insertion order defines the integer class labels (see CLASS_TO_IDX), so
# rows must never be reordered.  A class is "curved" exactly when its
# curvature type is anything other than "none".
SHAPE_CATALOG = {
    name: {"dim": dim, "curved": curvature != "none", "curvature": curvature}
    for name, dim, curvature in (
        # --- rigid 0D ---
        ("point", 0, "none"),
        # --- rigid 1D: lines ---
        ("line_x", 1, "none"),
        ("line_y", 1, "none"),
        ("line_z", 1, "none"),
        ("line_diag", 1, "none"),
        # --- rigid 1D: joints / collinear sets ---
        ("cross", 1, "none"),
        ("l_shape", 1, "none"),
        ("collinear", 1, "none"),
        # --- rigid 2D: triangles ---
        ("triangle_xy", 2, "none"),
        ("triangle_xz", 2, "none"),
        ("triangle_3d", 2, "none"),
        # --- rigid 2D: quads / coplanar sets ---
        ("square_xy", 2, "none"),
        ("square_xz", 2, "none"),
        ("rectangle", 2, "none"),
        ("coplanar", 2, "none"),
        # --- rigid 2D: filled plane ---
        ("plane", 2, "none"),
        # --- rigid 3D: simplex-like ---
        ("tetrahedron", 3, "none"),
        ("pyramid", 3, "none"),
        ("pentachoron", 3, "none"),
        # --- rigid 3D: prisms / polyhedra ---
        ("cube", 3, "none"),
        ("cuboid", 3, "none"),
        ("triangular_prism", 3, "none"),
        ("octahedron", 3, "none"),
        # --- curved 1D ---
        ("arc", 1, "convex"),
        ("helix", 1, "helical"),
        # --- curved 2D: outlines ---
        ("circle", 2, "convex"),
        ("ellipse", 2, "convex"),
        # --- curved 2D: filled ---
        ("disc", 2, "convex"),
        # --- curved 3D solids ---
        ("sphere", 3, "convex"),
        ("hemisphere", 3, "convex"),
        ("cylinder", 3, "cylindrical"),
        ("cone", 3, "conical"),
        ("capsule", 3, "convex"),
        ("torus", 3, "toroidal"),
        # --- curved 3D hollow ---
        ("shell", 3, "convex"),
        ("tube", 3, "cylindrical"),
        # --- curved 3D open surfaces ---
        ("bowl", 3, "concave"),
        ("saddle", 3, "hyperbolic"),
    )
}

CLASS_NAMES = list(SHAPE_CATALOG)
NUM_CLASSES = len(CLASS_NAMES)
CLASS_TO_IDX = {cls_name: idx for idx, cls_name in enumerate(CLASS_NAMES)}

# Curvature vocabulary; list position is the integer curvature label.
CURVATURE_TYPES = [
    "none", "convex", "concave", "cylindrical",
    "conical", "toroidal", "hyperbolic", "helical",
]
CURV_TO_IDX = {curv: idx for idx, curv in enumerate(CURVATURE_TYPES)}
NUM_CURVATURES = len(CURVATURE_TYPES)

# Side length of the cubic voxel grid (samples are GS x GS x GS).
GS = 5
| |
|
| |
|
| | |
| |
|
def cayley_menger_det(points: np.ndarray) -> float:
    """Return the determinant of the Cayley-Menger matrix of ``points``.

    For n points the matrix is (n+1) x (n+1):

        [[0, 1, ..., 1],
         [1,           ],
         [.,     D     ],
         [1,           ]]

    where ``D[i, j]`` is the squared Euclidean distance between points i and j.
    """
    pts = np.asarray(points)
    diffs = pts[:, None] - pts[None, :]
    # Sum over every coordinate axis (none for 1-D input, the last for (n, d)).
    sq_dist = np.square(diffs).sum(axis=tuple(range(2, diffs.ndim)))
    size = len(pts) + 1
    bordered = np.ones((size, size))
    bordered[0, 0] = 0.0
    bordered[1:, 1:] = sq_dist
    return np.linalg.det(bordered)
| |
|
| |
|
def simplex_volume(points: np.ndarray) -> float:
    """Return the (k-1)-dimensional volume of the simplex spanned by k points.

    Cayley-Menger formula:

        V^2 = (-1)^k * det(CM) / (2^(k-1) * ((k-1)!)^2)

    Negative V^2 values (e.g. from floating-point round-off on degenerate
    simplices) are clamped to 0.  Fewer than 2 points give 0.0.
    """
    num_vertices = len(points)
    if num_vertices < 2:
        return 0.0
    order = num_vertices - 1  # dimension of the simplex
    scale = (2 ** order) * math.factorial(order) ** 2
    volume_sq = (-1) ** num_vertices * cayley_menger_det(points) / scale
    return np.sqrt(max(0, volume_sq))
| |
|
| |
|
def effective_volume(points: np.ndarray, max_points: int = 8) -> float:
    """Return the largest simplex measure spanned by a point set.

    * fewer than 2 points -> 0.0
    * exactly 2 points    -> their Euclidean distance
    * exactly 3 points    -> the area of the triangle they span
    * 4 or more points    -> the largest tetrahedron volume over all 4-subsets
      of the first ``max_points`` points (0 for coplanar sets)

    Args:
        points: ``(k, d)`` array of coordinates.
        max_points: only the first ``max_points`` points are searched, which
            bounds the cost at C(max_points, 4) determinant evaluations.
            If fewer than the simplex size remain, 0.0 is returned.
    """
    k = len(points)
    if k < 2:
        return 0.0
    if k == 2:
        return np.linalg.norm(points[0] - points[1])
    # Triangles for exactly 3 points, tetrahedra for 4 or more; the smaller
    # simplices are never needed when k >= 4, so they are not evaluated.
    simplex_size = 3 if k == 3 else 4
    candidates = range(min(k, max_points))
    return max(
        (simplex_volume(points[list(idx)])
         for idx in combinations(candidates, simplex_size)),
        default=0.0,
    )
| |
|
| |
|
| | |
| |
|
class ShapeGenerator:
    """Seeded random generator of labelled voxel samples for SHAPE_CATALOG.

    Each shape is first produced as an integer ``(N, 3)`` array of voxel
    coordinates, then snapped to ``[0, GS - 1]``, de-duplicated and packed by
    ``_build`` into a sample dict (occupancy grid plus label, dimension and
    curvature targets).  All randomness flows through ``self.rng``, so a
    fixed ``seed`` reproduces the same sample stream.
    """

    def __init__(self, seed=42):
        self.rng = np.random.RandomState(seed)

    def generate(self, n_samples: int) -> list:
        """Return ``n_samples`` shuffled samples, roughly balanced across classes.

        Each class is first given ``n_samples // NUM_CLASSES`` samples (with at
        most ``5 * per_class`` generation attempts per class); any shortfall is
        topped up with samples of randomly chosen classes.
        """
        samples = []
        per_class = n_samples // NUM_CLASSES
        for name in CLASS_NAMES:
            count = 0
            attempts = 0
            while count < per_class and attempts < per_class * 5:
                s = self._make(name)
                attempts += 1
                if s is not None:
                    samples.append(s)
                    count += 1
        while len(samples) < n_samples:
            # rng.choice returns np.str_; keep class_name a plain str.
            name = str(self.rng.choice(CLASS_NAMES))
            s = self._make(name)
            if s is not None:
                samples.append(s)
        self.rng.shuffle(samples)
        return samples[:n_samples]

    def _make(self, name: str) -> Optional[dict]:
        """Generate one sample of class ``name``; ``None`` if generation failed."""
        info = SHAPE_CATALOG[name]
        if info["curved"]:
            voxels = self._curved(name)
        else:
            voxels = self._rigid(name)
        if voxels is None:
            return None
        # Snap to the grid and drop duplicates (np.unique also sorts rows).
        voxels = np.clip(voxels, 0, GS - 1).astype(int)
        voxels = np.unique(voxels, axis=0)
        if len(voxels) < 1:
            return None
        return self._build(name, info, voxels)

    def _rigid(self, name):
        """Return the voxels of rigid (uncurved) shape ``name``, or ``None``.

        Most rigid shapes are given by their vertex voxels only (a cube is its
        8 corners); ``l_shape``, ``plane`` and ``triangular_prism`` also
        include intermediate voxels.  Degenerate draws that would not match
        the class label (collinear triangles, coplanar solids, squares drawn
        as rectangles, ...) are rejected or redrawn.
        """
        rng = self.rng

        if name == "point":
            return rng.randint(0, GS, size=(1, 3))

        elif name == "line_x":
            y, z = rng.randint(0, GS, size=2)
            x1, x2 = sorted(rng.choice(GS, 2, replace=False))
            return np.array([[x1, y, z], [x2, y, z]])

        elif name == "line_y":
            x, z = rng.randint(0, GS, size=2)
            y1, y2 = sorted(rng.choice(GS, 2, replace=False))
            return np.array([[x, y1, z], [x, y2, z]])

        elif name == "line_z":
            x, y = rng.randint(0, GS, size=2)
            z1, z2 = sorted(rng.choice(GS, 2, replace=False))
            return np.array([[x, y, z1], [x, y, z2]])

        elif name == "line_diag":
            p1 = rng.randint(0, 3, size=3)
            step = rng.randint(1, 3)
            direction = rng.choice([-1, 1], size=3)
            # p1 <= 2 and step <= 2, so a +step move always stays on the grid,
            # but a -step move may not.  Flip those components instead of
            # clipping: clipping could zero an axis and emit an axis-aligned
            # line (indistinguishable from line_x/y/z) under this label.
            direction = np.where(p1 - step < 0, 1, direction)
            return np.array([p1, p1 + step * direction])

        elif name == "cross":
            # Centre voxel plus the tip voxel of four arms (+-length along two
            # distinct axes).
            cx, cy, cz = rng.randint(1, GS - 1, size=3)
            length = rng.randint(1, 3)
            axis1, axis2 = rng.choice(3, 2, replace=False)
            pts = [[cx, cy, cz]]
            for sign in [-1, 1]:
                p = [cx, cy, cz]
                p[axis1] = np.clip(p[axis1] + sign * length, 0, GS - 1)
                pts.append(list(p))
            for sign in [-1, 1]:
                p = [cx, cy, cz]
                p[axis2] = np.clip(p[axis2] + sign * length, 0, GS - 1)
                pts.append(list(p))
            return np.array(pts)

        elif name == "l_shape":
            # Corner voxel plus every voxel along two perpendicular arms.
            corner = rng.randint(1, GS - 1, size=3)
            axis1, axis2 = rng.choice(3, 2, replace=False)
            len1 = rng.randint(1, 3)
            len2 = rng.randint(1, 3)
            dir1 = rng.choice([-1, 1])
            dir2 = rng.choice([-1, 1])
            pts = [list(corner)]
            for i in range(1, len1 + 1):
                p = list(corner)
                p[axis1] = np.clip(p[axis1] + dir1 * i, 0, GS - 1)
                pts.append(p)
            for i in range(1, len2 + 1):
                p = list(corner)
                p[axis2] = np.clip(p[axis2] + dir2 * i, 0, GS - 1)
                pts.append(p)
            return np.array(pts)

        elif name == "collinear":
            axis = rng.randint(3)
            fixed = rng.randint(0, GS, size=2)
            vals = sorted(rng.choice(GS, 3, replace=False))
            pts = np.zeros((3, 3), dtype=int)
            for i, v in enumerate(vals):
                pts[i, axis] = v
                pts[i, (axis + 1) % 3] = fixed[0]
                pts[i, (axis + 2) % 3] = fixed[1]
            return pts

        elif name == "triangle_xy":
            z = rng.randint(0, GS)
            pts = self._rand_triangle_2d()
            if pts is None:
                return None
            return np.column_stack([pts, np.full(3, z)])

        elif name == "triangle_xz":
            y = rng.randint(0, GS)
            pts = self._rand_triangle_2d()
            if pts is None:
                return None
            return np.column_stack([pts[:, 0], np.full(3, y), pts[:, 1]])

        elif name == "triangle_3d":
            # Three points that are not collinear.
            return self._rand_spanning_3d(3, rank=2)

        elif name == "square_xy":
            z = rng.randint(0, GS)
            x1, y1 = rng.randint(0, 3, size=2)
            s = rng.randint(1, 3)
            pts = np.array([[x1, y1, z], [x1 + s, y1, z],
                            [x1, y1 + s, z], [x1 + s, y1 + s, z]])
            return np.clip(pts, 0, GS - 1)

        elif name == "square_xz":
            y = rng.randint(0, GS)
            x1, z1 = rng.randint(0, 3, size=2)
            s = rng.randint(1, 3)
            pts = np.array([[x1, y, z1], [x1 + s, y, z1],
                            [x1, y, z1 + s], [x1 + s, y, z1 + s]])
            return np.clip(pts, 0, GS - 1)

        elif name == "rectangle":
            axis = rng.randint(3)
            val = rng.randint(0, GS)
            w, h = rng.randint(1, 4), rng.randint(1, 3)
            if w == h:
                w += 1  # w <= 3 afterwards
            # Anchor so both sides fit without clipping; clipping could shrink
            # the long side to the short one and emit a square.
            a1, a2 = rng.randint(0, GS - w), rng.randint(0, GS - h)
            c = np.array([[a1, a2], [a1 + w, a2], [a1, a2 + h], [a1 + w, a2 + h]])
            if axis == 0:
                return np.column_stack([np.full(4, val), c])
            elif axis == 1:
                return np.column_stack([c[:, 0], np.full(4, val), c[:, 1]])
            else:
                return np.column_stack([c, np.full(4, val)])

        elif name == "coplanar":
            # Flatten four random points onto an axis-aligned plane; retry
            # when flattening merges points or collapses them onto a line.
            for _ in range(50):
                pts = self._rand_pts_3d(4, min_dist=1)
                if pts is None:
                    continue
                pts[:, rng.randint(3)] = pts[0, rng.randint(3)]
                if len(np.unique(pts, axis=0)) == 4 and self._affine_rank(pts) == 2:
                    return pts
            return None

        elif name == "plane":
            # Filled axis-aligned rectangle of voxels.
            axis = rng.randint(3)
            val = rng.randint(0, GS)
            a_start = rng.randint(0, 2)
            b_start = rng.randint(0, 2)
            a_size = rng.randint(2, GS - a_start + 1)
            b_size = rng.randint(2, GS - b_start + 1)
            pts = []
            for a in range(a_start, min(GS, a_start + a_size)):
                for b in range(b_start, min(GS, b_start + b_size)):
                    p = [0, 0, 0]
                    p[axis] = val
                    p[(axis + 1) % 3] = a
                    p[(axis + 2) % 3] = b
                    pts.append(p)
            # A filled 2x2 patch is voxel-identical to a unit square_xy /
            # square_xz sample, so require more than 4 voxels.
            return np.array(pts) if len(pts) > 4 else None

        elif name == "tetrahedron":
            # Retry until the four vertices are clearly non-coplanar: the
            # smallest singular value of the centred points measures how
            # "thick" the set is (0 for coplanar points).
            for _ in range(50):
                pts = self._rand_pts_3d(4, min_dist=1)
                if pts is None:
                    continue
                centered = (pts - pts.mean(axis=0)).astype(float)
                if np.linalg.svd(centered, compute_uv=False)[-1] >= 0.5:
                    return pts
            return None

        elif name == "pyramid":
            z_base = rng.randint(0, 3)
            x1, y1 = rng.randint(0, 3), rng.randint(0, 3)
            s = rng.randint(1, 3)
            base = np.array([[x1, y1, z_base], [x1 + s, y1, z_base],
                             [x1, y1 + s, z_base], [x1 + s, y1 + s, z_base]])
            apex = np.array([[x1 + s // 2, y1 + s // 2, z_base + rng.randint(1, 3)]])
            return np.clip(np.vstack([base, apex]), 0, GS - 1)

        elif name == "pentachoron":
            # Five points that are not all coplanar.
            return self._rand_spanning_3d(5, rank=3)

        elif name == "cube":
            x1, y1, z1 = rng.randint(0, 3, size=3)
            s = rng.randint(1, 3)
            pts = []
            for dx in [0, s]:
                for dy in [0, s]:
                    for dz in [0, s]:
                        pts.append([x1 + dx, y1 + dy, z1 + dz])
            return np.clip(np.array(pts), 0, GS - 1)

        elif name == "cuboid":
            x1, y1, z1 = rng.randint(0, 2, size=3)
            sx, sy, sz = rng.randint(1, 4, size=3)
            if sx == sy == sz:
                # Make one side differ without exceeding 3: x1 + sx must stay
                # <= GS - 1, otherwise clipping could shrink sx back to the
                # other sides' length and emit a cube.
                sx = sx + 1 if sx < 3 else sx - 1
            pts = []
            for dx in [0, sx]:
                for dy in [0, sy]:
                    for dz in [0, sz]:
                        pts.append([x1 + dx, y1 + dy, z1 + dz])
            return np.clip(np.array(pts), 0, GS - 1)

        elif name == "triangular_prism":
            # Triangle vertices repeated on every slice along the extrusion axis.
            axis = rng.randint(3)
            ext_start = rng.randint(0, 3)
            ext_len = rng.randint(1, 3)
            tri = self._rand_triangle_2d()
            if tri is None:
                return None
            pts = []
            for e in range(ext_start, min(GS, ext_start + ext_len + 1)):
                for t in tri:
                    p = [0, 0, 0]
                    p[axis] = e
                    p[(axis + 1) % 3] = t[0]
                    p[(axis + 2) % 3] = t[1]
                    pts.append(p)
            return np.clip(np.array(pts), 0, GS - 1) if len(pts) >= 6 else None

        elif name == "octahedron":
            # Six vertices at +-s along each axis from an interior centre.
            cx, cy, cz = rng.randint(1, GS - 1, size=3)
            s = rng.randint(1, 3)
            pts = [[cx, cy, cz + s], [cx, cy, cz - s],
                   [cx + s, cy, cz], [cx - s, cy, cz],
                   [cx, cy + s, cz], [cx, cy - s, cz]]
            return np.clip(np.array(pts), 0, GS - 1)

        return None

    def _curved(self, name):
        """Return the voxels of curved shape ``name``, or ``None`` on failure.

        Shapes are rasterised around a random centre in ``[1, 3]^3``: sampled
        curves (arc, helix, circle, ellipse) are rounded and clipped to the
        grid, while surfaces and solids test every lattice point.
        """
        rng = self.rng
        cx, cy, cz = rng.uniform(1.0, 3.0, size=3)

        if name == "arc":
            r = rng.uniform(1.2, 2.2)
            plane = rng.choice(["xy", "xz", "yz"])
            start = rng.uniform(0, 2 * np.pi)
            span = rng.uniform(np.pi * 0.4, np.pi * 1.2)
            n = rng.randint(6, 12)
            angles = np.linspace(start, start + span, n)
            pts = []
            for a in angles:
                if plane == "xy":
                    pts.append([cx + r * np.cos(a), cy + r * np.sin(a), cz])
                elif plane == "xz":
                    pts.append([cx + r * np.cos(a), cy, cz + r * np.sin(a)])
                else:
                    pts.append([cx, cy + r * np.cos(a), cz + r * np.sin(a)])
            pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
            return pts if len(pts) >= 3 else None

        elif name == "helix":
            r = rng.uniform(0.8, 1.8)
            axis = rng.randint(3)
            pitch = rng.uniform(0.3, 0.8)
            n = rng.randint(15, 30)
            t = np.linspace(0, 2 * np.pi * rng.uniform(1.0, 2.5), n)
            pts = []
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            start_h = rng.uniform(0, 1.0)
            # Cap the rise per radian so the final turn still ends inside the
            # grid.  Uncapped (pitch 0.8 over 2.5 turns) the helix climbs about
            # 12.6 voxels, and everything above GS - 1 is clipped onto the top
            # face, leaving mostly a flat circle.
            pitch = min(pitch, (GS - 1 - start_h) / t[-1])
            for ti in t:
                p = [0.0, 0.0, 0.0]
                p[axes[0]] = center[axes[0]] + r * np.cos(ti)
                p[axes[1]] = center[axes[1]] + r * np.sin(ti)
                p[axis] = start_h + pitch * ti
                pts.append(p)
            pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
            return pts if len(pts) >= 5 else None

        elif name == "circle":
            r = rng.uniform(1.0, 2.0)
            plane = rng.choice(["xy", "xz", "yz"])
            n = rng.randint(12, 20)
            angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
            pts = []
            for a in angles:
                if plane == "xy":
                    pts.append([cx + r * np.cos(a), cy + r * np.sin(a), cz])
                elif plane == "xz":
                    pts.append([cx + r * np.cos(a), cy, cz + r * np.sin(a)])
                else:
                    pts.append([cx, cy + r * np.cos(a), cz + r * np.sin(a)])
            pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
            return pts if len(pts) >= 5 else None

        elif name == "ellipse":
            rx, ry = rng.uniform(0.8, 2.0), rng.uniform(0.8, 2.0)
            if abs(rx - ry) < 0.3:
                rx *= 1.4  # keep the two radii visibly different
            plane = rng.choice(["xy", "xz", "yz"])
            n = rng.randint(12, 20)
            angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
            pts = []
            for a in angles:
                if plane == "xy":
                    pts.append([cx + rx * np.cos(a), cy + ry * np.sin(a), cz])
                elif plane == "xz":
                    pts.append([cx + rx * np.cos(a), cy, cz + ry * np.sin(a)])
                else:
                    pts.append([cx, cy + rx * np.cos(a), cz + ry * np.sin(a)])
            pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
            return pts if len(pts) >= 5 else None

        elif name == "disc":
            # Filled circle in an axis-aligned slice.
            r = rng.uniform(1.0, 2.2)
            axis = rng.randint(3)
            val = round(rng.uniform(0.5, 3.5))
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    p = [0, 0, 0]
                    p[axis] = val
                    p[axes[0]] = x
                    p[axes[1]] = y
                    dist = np.sqrt((x - center[axes[0]]) ** 2 + (y - center[axes[1]]) ** 2)
                    if dist <= r:
                        pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "sphere":
            r = rng.uniform(1.0, 2.2)
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        if (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2 <= r ** 2:
                            pts.append([x, y, z])
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "hemisphere":
            r = rng.uniform(1.0, 2.2)
            cut_axis = rng.randint(3)
            center = [cx, cy, cz]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        if (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2 <= r ** 2:
                            if p[cut_axis] >= center[cut_axis]:
                                pts.append(p)
            return np.array(pts) if len(pts) >= 3 else None

        elif name == "cylinder":
            r = rng.uniform(0.8, 1.8)
            axis = rng.randint(3)
            length = rng.randint(2, 5)
            start = rng.randint(0, GS - length + 1)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        if p[axis] < start or p[axis] >= start + length:
                            continue
                        dist_sq = sum((p[a] - center[a]) ** 2 for a in axes)
                        if dist_sq <= r ** 2:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "cone":
            r_base = rng.uniform(1.0, 2.0)
            axis = rng.randint(3)
            height = rng.randint(2, 5)
            base_pos = rng.randint(0, GS - height + 1)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        along = p[axis] - base_pos
                        if along < 0 or along >= height:
                            continue
                        # Radius shrinks linearly towards an apex one slice
                        # past the top, so every slice keeps a positive radius.
                        # (With the apex *on* the top slice its radius is ~0,
                        # that slice is always empty, and height 2 gives a disc.)
                        t = along / height
                        r_at = r_base * (1.0 - t)
                        dist_sq = sum((p[a] - center[a]) ** 2 for a in axes)
                        if dist_sq <= r_at ** 2:
                            pts.append(p)
            # A single occupied slice is a flat disc, not a cone.
            if len(pts) < 4 or len({p[axis] for p in pts}) < 2:
                return None
            return np.array(pts)

        elif name == "capsule":
            # Cylinder body with spherical end caps centred on its end slices.
            r = rng.uniform(0.8, 1.5)
            axis = rng.randint(3)
            body_len = rng.randint(1, 3)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            body_start = round(center[axis] - body_len / 2)
            body_end = body_start + body_len
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        radial_sq = sum((p[a] - center[a]) ** 2 for a in axes)
                        along = p[axis]
                        if body_start <= along <= body_end and radial_sq <= r ** 2:
                            pts.append(p)
                        elif along < body_start:
                            cap_center = list(center)
                            cap_center[axis] = body_start
                            dist_sq = sum((p[i] - cap_center[i]) ** 2 for i in range(3))
                            if dist_sq <= r ** 2:
                                pts.append(p)
                        elif along > body_end:
                            cap_center = list(center)
                            cap_center[axis] = body_end
                            dist_sq = sum((p[i] - cap_center[i]) ** 2 for i in range(3))
                            if dist_sq <= r ** 2:
                                pts.append(p)
            return np.array(pts) if len(pts) >= 5 else None

        elif name == "torus":
            R = rng.uniform(1.2, 2.0)  # ring radius
            r = rng.uniform(0.5, 0.9)  # tube radius
            axis = rng.randint(3)
            center = [cx, cy, cz]
            ring_axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        dist_in_plane = np.sqrt(
                            sum((p[a] - center[a]) ** 2 for a in ring_axes))
                        dist_from_ring = np.sqrt(
                            (dist_in_plane - R) ** 2 + (p[axis] - center[axis]) ** 2)
                        if dist_from_ring <= r:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "shell":
            # Hollow sphere: voxels between an inner and an outer radius.
            r_out = rng.uniform(1.5, 2.3)
            r_in = r_out - rng.uniform(0.4, 0.8)
            if r_in < 0.3:
                r_in = 0.3
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        d_sq = (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2
                        if r_in ** 2 <= d_sq <= r_out ** 2:
                            pts.append([x, y, z])
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "tube":
            # Hollow cylinder: radial distance between inner and outer radius.
            r_out = rng.uniform(1.0, 2.0)
            r_in = r_out - rng.uniform(0.3, 0.7)
            if r_in < 0.2:
                r_in = 0.2
            axis = rng.randint(3)
            length = rng.randint(2, 5)
            start = rng.randint(0, GS - length + 1)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        if p[axis] < start or p[axis] >= start + length:
                            continue
                        dist_sq = sum((p[a] - center[a]) ** 2 for a in axes)
                        if r_in ** 2 <= dist_sq <= r_out ** 2:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "bowl":
            # Thickened paraboloid h = k * d^2 opening along +axis.
            r = rng.uniform(1.2, 2.2)
            axis = rng.randint(3)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            thickness = 0.6
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        dist_planar = np.sqrt(
                            sum((p[a] - center[a]) ** 2 for a in axes))
                        if dist_planar > r:
                            continue
                        k = 1.0 / (r + 1e-6)
                        expected_h = center[axis] + k * dist_planar ** 2
                        actual_h = p[axis]
                        if abs(actual_h - expected_h) <= thickness:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "saddle":
            # Thickened hyperbolic paraboloid h = k * (da^2 - db^2), radius <= 2.
            axis = rng.randint(3)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            k = rng.uniform(0.3, 0.8)
            thickness = 0.7
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        da = p[axes[0]] - center[axes[0]]
                        db = p[axes[1]] - center[axes[1]]
                        expected_h = center[axis] + k * (da ** 2 - db ** 2)
                        if abs(p[axis] - expected_h) <= thickness:
                            dist_sq = da ** 2 + db ** 2
                            if dist_sq <= 4.0:
                                pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        return None

    @staticmethod
    def _affine_rank(pts):
        """Dimension of the affine hull of ``pts`` (0 point, 1 line, 2 plane, 3 solid)."""
        pts = np.asarray(pts, dtype=float)
        if len(pts) < 2:
            return 0
        return int(np.linalg.matrix_rank(pts[1:] - pts[0]))

    def _rand_triangle_2d(self, tries=50):
        """Draw 3 distinct, non-collinear 2D grid points; ``None`` after ``tries`` failures."""
        for _ in range(tries):
            pts = self._rand_pts_2d(3, min_dist=1)
            if pts is not None and self._affine_rank(pts) == 2:
                return pts
        return None

    def _rand_spanning_3d(self, n, rank, tries=50):
        """Draw ``n`` distinct 3D grid points whose affine hull has dimension ``rank``.

        Used to reject degenerate draws (collinear triangles, coplanar
        solids); returns ``None`` if ``tries`` draws all fail.
        """
        for _ in range(tries):
            pts = self._rand_pts_3d(n, min_dist=1)
            if pts is not None and self._affine_rank(pts) == rank:
                return pts
        return None

    def _rand_pts_2d(self, n, min_dist=0):
        """Draw ``n`` distinct 2D grid points with pairwise Manhattan distance >= ``min_dist``.

        Makes up to 50 draws; returns ``None`` if none satisfies ``min_dist``.
        """
        for _ in range(50):
            pts = set()
            while len(pts) < n:
                pts.add((self.rng.randint(0, GS), self.rng.randint(0, GS)))
            pts = np.array(list(pts)[:n])
            if min_dist <= 0 or self._check_dist(pts, min_dist):
                return pts
        return None

    def _rand_pts_3d(self, n, min_dist=0):
        """Draw ``n`` distinct 3D grid points with pairwise Manhattan distance >= ``min_dist``.

        Makes up to 100 draws; returns ``None`` if none satisfies ``min_dist``.
        """
        for _ in range(100):
            pts = set()
            while len(pts) < n:
                pts.add(tuple(self.rng.randint(0, GS, size=3)))
            pts = np.array(list(pts)[:n])
            if min_dist <= 0 or self._check_dist(pts, min_dist):
                return pts
        return None

    def _check_dist(self, pts, min_dist):
        """True when every pair of rows in ``pts`` is >= ``min_dist`` apart (Manhattan)."""
        for i in range(len(pts)):
            for j in range(i + 1, len(pts)):
                if np.sum(np.abs(pts[i] - pts[j])) < min_dist:
                    return False
        return True

    def _build(self, name, info, voxels):
        """Pack unique integer voxels of class ``name`` into a sample dict.

        Keys: ``grid`` (GS^3 float32 occupancy), ``label`` / ``class_name``,
        ``n_points`` / ``n_occupied``, Cayley-Menger ``cm_det`` and
        ``volume``, ``peak_dim``, ``dim_confidence`` (index 0 always 1,
        index 1 when there are >= 2 voxels, indices 2/3 when the class
        dimension is >= 2/3), ``is_curved`` and the ``curvature`` index.
        """
        n = len(voxels)
        # NOTE(review): only the first 6 voxels (lexicographic order, since
        # _make passes np.unique output) feed cm_det/volume, so for large
        # shapes these describe a corner slab rather than the whole shape.
        sub = voxels[:6].astype(float)
        cm_det = cayley_menger_det(sub)
        volume = effective_volume(sub)

        dim_conf = np.zeros(4, dtype=np.float32)
        dim_conf[0] = 1.0
        if n >= 2:
            dim_conf[1] = 1.0
        if info["dim"] >= 2:
            dim_conf[2] = 1.0
        if info["dim"] >= 3:
            dim_conf[3] = 1.0

        grid = np.zeros((GS, GS, GS), dtype=np.float32)
        for v in voxels:
            grid[v[0], v[1], v[2]] = 1.0

        return {
            "grid": grid, "label": CLASS_TO_IDX[name], "class_name": name,
            "n_points": n, "n_occupied": int(grid.sum()),
            "cm_det": float(cm_det), "volume": float(volume),
            "peak_dim": info["dim"], "dim_confidence": dim_conf,
            "is_curved": info["curved"], "curvature": CURV_TO_IDX[info["curvature"]],
        }
| |
|
| |
|
| | |
| |
|
def _generate_chunk(args):
    """Pool worker: generate one sample per class index in a chunk.

    ``args`` is ``(class_assignments, seed, start_idx)``; ``start_idx`` is
    accepted but not used.  Each class gets up to 10 generation attempts; if
    all fail, a ``"cube"`` sample is generated in its place (and dropped if
    that fails too), so the returned list can be shorter than the chunk.
    """
    class_indices, seed, _start_idx = args
    generator = ShapeGenerator(seed=seed)
    out = []
    for class_idx in class_indices:
        name = CLASS_NAMES[class_idx]
        # Lazy: stops calling _make as soon as one attempt succeeds.
        attempts = (generator._make(name) for _ in range(10))
        sample = next((s for s in attempts if s is not None), None)
        if sample is None:
            sample = generator._make("cube")
        if sample is not None:
            out.append(sample)
    return out
| |
|
| |
|
def generate_parallel(n_samples, seed=42, n_workers=8):
    """Pre-generate ``n_samples`` samples in parallel with ``multiprocessing``.

    Class labels are assigned up front (``n_samples // NUM_CLASSES`` per
    class, the remainder drawn at random, then shuffled) and split into
    contiguous chunks.  Each chunk is generated by ``_generate_chunk`` in a
    worker process seeded with ``seed + i * 1_000_000``, so the output is
    reproducible for a fixed ``seed`` and ``n_workers``.

    Raises:
        ValueError: if ``n_workers`` is less than 1.
    """
    import multiprocessing as mp
    import time

    if n_workers < 1:
        raise ValueError(f"n_workers must be >= 1, got {n_workers}")

    per_class = n_samples // NUM_CLASSES
    class_assignments = []
    for ci in range(NUM_CLASSES):
        class_assignments.extend([ci] * per_class)
    rng = np.random.RandomState(seed)
    while len(class_assignments) < n_samples:
        class_assignments.append(rng.randint(0, NUM_CLASSES))
    rng.shuffle(class_assignments)
    class_assignments = class_assignments[:n_samples]

    chunk_size = (n_samples + n_workers - 1) // n_workers
    chunks = []
    for i in range(n_workers):
        start = i * chunk_size
        end = min(start + chunk_size, n_samples)
        if start >= n_samples:
            break
        chunks.append((class_assignments[start:end], seed + i * 1000000, start))

    print(f"Generating {n_samples} shapes across {len(chunks)} workers...")
    t0 = time.perf_counter()
    samples = []
    if chunks:
        # Never spawn more processes than there are chunks to work on.
        with mp.Pool(len(chunks)) as pool:
            for chunk_samples in pool.map(_generate_chunk, chunks):
                samples.extend(chunk_samples)
    rng.shuffle(samples)
    dt = time.perf_counter() - t0
    rate = len(samples) / dt if dt > 0 else float("inf")
    print(f"Generated {len(samples)} samples in {dt:.1f}s ({rate:.0f} samples/s)")
    return samples
| |
|
| |
|
class ShapeDataset(torch.utils.data.Dataset):
    """In-memory torch dataset built from ShapeGenerator-style sample dicts.

    Every field is materialised once as a tensor; indexing returns
    ``(grid, label, dim_confidence, peak_dim, volume, cm_det, is_curved,
    curvature)``.
    """

    def __init__(self, samples):
        def stacked(key):
            # Array-valued field -> float32 tensor with a leading sample axis.
            return torch.tensor(np.stack([s[key] for s in samples]), dtype=torch.float32)

        def column(key, dtype):
            # Scalar field -> 1-D tensor of the requested dtype.
            return torch.tensor([s[key] for s in samples], dtype=dtype)

        self.grids = stacked("grid")
        self.labels = column("label", torch.long)
        self.dim_conf = stacked("dim_confidence")
        self.peak_dim = column("peak_dim", torch.long)
        self.volume = column("volume", torch.float32)
        self.cm_det = column("cm_det", torch.float32)
        self.is_curved = column("is_curved", torch.float32)
        self.curvature = column("curvature", torch.long)

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        fields = (self.grids, self.labels, self.dim_conf, self.peak_dim,
                  self.volume, self.cm_det, self.is_curved, self.curvature)
        return tuple(field[idx] for field in fields)
| |
|
| |
|
| |
|
# Import-time banner reporting the class count and voxel grid size.
print("Loaded {} shape classes, GS={}".format(NUM_CLASSES, GS))