Jgray21 committed (verified)
Commit 7e1830a · Parent(s): c01107e

Update src/streamlit_app.py

Files changed (1): src/streamlit_app.py +337 -446

src/streamlit_app.py CHANGED
@@ -7,30 +7,20 @@ from typing import Dict, List, Tuple, Optional
 
  import numpy as np
  import pandas as pd
-
  import torch
  from torch import nn
-
  import networkx as nx
  import streamlit as st
 
- # Transformers: Qwen tokenizer can be AutoTokenizer if Qwen2Tokenizer not present
  from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
-
- # Dimensionality reduction
  import umap
- from umap import UMAP
-
- # Neighbors & clustering
  from sklearn.neighbors import NearestNeighbors, KernelDensity
  from sklearn.cluster import KMeans, DBSCAN
- from sklearn.decomposition import PCA
  from sklearn.metrics import pairwise_distances
-
- # Plotly for interactive 3D
  import plotly.graph_objects as go
 
- import hashlib
 
  # Optional libs (use if present)
  try:
@@ -52,27 +42,20 @@ try:
  except Exception:
      HAS_PYVISTA = False
 
- from scipy.linalg import orthogonal_procrustes  # For optional per-layer orientation alignment
-
- # ====== 1. Configuration =========================================================================
  @dataclass
  class Config:
      # Model
      model_name: str = "Qwen/Qwen1.5-1.8B"
-     ### device: str = "cuda" if torch.cuda.is_available() else "cpu"
-     ### dtype: torch.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-
-     # Tokenization / generation
-     max_length: int = 64  # truncate inputs for speed/memory
 
      # Data
-     corpus: List[str] = None  # set below
-     # If None, uses DEFAULT_CORPUS defined below
 
-     # Graph building
-     graph_mode: str = "threshold"  # {"knn", "threshold"}
-     knn_k: int = 8  # neighbors per token (used if graph_mode="knn")
-     sim_threshold: float = 0.60  # used if graph_mode="threshold"
      use_cosine: bool = True
 
      # Anchors / LoT-style features (global)
@@ -84,104 +67,147 @@ class Config:
      n_clusters_kmeans: int = 6  # fallback for kmeans
      hdbscan_min_cluster_size: int = 4
 
-     # DR / embeddings
      umap_n_neighbors: int = 30
      umap_min_dist: float = 0.05
-     umap_metric: str = "cosine"  # hidden states are directional → cosine works well
-     use_global_3d_umap: bool = False  # if True, compute a single 3D manifold for all states
-
-     # Pooling for UMAP fit
      fit_pool_per_layer: int = 512  # number of states sampled per layer to fit UMAP
 
-     # Volume grid (MRI view)
-     grid_res: int = 128  # voxel resolution in x/y; z = num_layers
-     kde_bandwidth: float = 0.15  # KDE bandwidth in manifold space (if using KDE)
-     use_hist2d: bool = True  # if True, use histogram2d instead of KDE for speed
 
      # Output
      out_dir: str = "qwen_mri3d_outputs"
      plotly_html: str = "qwen_layers_3d.html"
-     volume_npz: str = "qwen_density_volume.npz"  # saved if PyVista isn't available
-     volume_screenshot: str = "qwen_volume.png"  # if PyVista is available
-
-     def validate(self):
-         if self.graph_mode not in {"knn", "threshold"}:
-             raise ValueError("graph_mode must be 'knn' or 'threshold'")
-         if self.knn_k < 2:
-             raise ValueError("knn_k must be >= 2")
-         if self.anchor_k < 2:
-             raise ValueError("anchor_k must be >= 2")
-         if self.anchor_temp <= 0:
-             raise ValueError("anchor_temp must be > 0")
-
-
 
  # Default corpus (small and diverse; adjust freely)
  DEFAULT_CORPUS = [
-     "The cat sat on the mat and watched.",
-     "Machine learning models process data using neural networks.",
-     "Climate change affects ecosystems around the world.",
-     "Quantum computers use superposition for parallel computation.",
-     "The universe contains billions of galaxies.",
-     "Artificial intelligence transforms how we work.",
-     "DNA stores genetic information in cells.",
-     "Ocean currents regulate Earth's climate system.",
-     "Photosynthesis converts sunlight into chemical energy.",
-     "Blockchain technology enables decentralized systems."
  ]
 
- # ====== 2. Utilities =============================================================================
  def seed_everything(seed: int = 42):
-     """Determinism for reproducibility in layouts/UMAP/kmeans."""
      np.random.seed(seed)
      torch.manual_seed(seed)
 
-
  def cosine_similarity_matrix(X: np.ndarray) -> np.ndarray:
-     """Compute pairwise cosine similarity for rows of X."""
-     # X: (N, D)
      norms = np.linalg.norm(X, axis=1, keepdims=True) + 1e-8
      Xn = X / norms
      return Xn @ Xn.T
 
- def build_knn_graph(coords: np.ndarray, k: int, metric: str = "cosine") -> nx.Graph:
      """
-     Build an undirected kNN graph for the points in coords.
-     coords: (N, D)
      """
-     nbrs = NearestNeighbors(n_neighbors=min(k+1, len(coords)), metric=metric)  # +1 to include self
      nbrs.fit(coords)
      distances, indices = nbrs.kneighbors(coords)
-
      G = nx.Graph()
      G.add_nodes_from(range(len(coords)))
-     # Connect i to its top-k neighbors (skip index 0 which is itself)
      for i in range(len(coords)):
-         for j in indices[i, 1:]:  # skip self
          G.add_edge(int(i), int(j))
      return G
 
-
- def build_threshold_graph(H: np.ndarray, threshold: float, use_cosine: bool = True) -> nx.Graph:
-     """
-     Build graph by thresholding pairwise similarities in the original hidden-state space.
-     H: (N, D) hidden states for a single layer
-     """
      if use_cosine:
          S = cosine_similarity_matrix(H)
      else:
-         S = H @ H.T  # dot product
 
      N = S.shape[0]
      G = nx.Graph()
      G.add_nodes_from(range(N))
-     for i in range(N):
-         for j in range(i + 1, N):
-             if S[i, j] > threshold:
-                 G.add_edge(i, j, weight=float(S[i, j]))
-     return G
 
  def percolation_stats(G: nx.Graph) -> Dict[str, float]:
      """
@@ -210,94 +236,47 @@ def percolation_stats(G: nx.Graph) -> Dict[str, float]:
                  largest_component_size=largest,
                  component_sizes=sorted(sizes, reverse=True))
 
 
  def leiden_communities(G: nx.Graph) -> np.ndarray:
-     """
-     Community detection using Leiden (igraph), if available.
-     Returns an array of cluster ids for nodes 0..N-1.
-     """
-     if not HAS_IGRAPH_LEIDEN:
-         raise RuntimeError("igraph+leidenalg not available")
-
-     # Convert nx → igraph
      mapping = {n: i for i, n in enumerate(G.nodes())}
      edges = [(mapping[u], mapping[v]) for u, v in G.edges()]
      ig_g = ig.Graph(n=len(mapping), edges=edges, directed=False)
-     part = la.find_partition(ig_g, la.RBConfigurationVertexPartition)  # robust default
      labels = np.zeros(len(mapping), dtype=int)
      for cid, comm in enumerate(part):
-         for node in comm:
-             labels[node] = cid
      return labels
 
 
- def cluster_layer(features: np.ndarray,
-                   G: Optional[nx.Graph],
-                   method: str,
-                   n_clusters_kmeans: int = 6,
-                   hdbscan_min_cluster_size: int = 4) -> np.ndarray:
-     """
-     Cluster layer states to get cluster labels.
-     - If Leiden: requires G (graph) and igraph/leidenalg
-     - If HDBSCAN: density-based clustering in feature space
-     - If DBSCAN: fallback density-based (scikit-learn)
-     - If KMeans: fallback centroid clustering
-     """
-     method = method.lower()
-     N = len(features)
-
-     if method == "auto":
-         # Prefer Leiden (graph) → HDBSCAN → KMeans
-         if HAS_IGRAPH_LEIDEN and G is not None and G.number_of_edges() > 0:
-             return leiden_communities(G)
-         elif HAS_HDBSCAN and N >= 5:
-             clusterer = hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size,
-                                         metric='euclidean')
-             labels = clusterer.fit_predict(features)
-             # HDBSCAN: -1 = noise. Keep as its own "noise" cluster id or remap
-             return labels
-         else:
-             km = KMeans(n_clusters=min(n_clusters_kmeans, max(2, N // 3)),
-                         n_init="auto", random_state=42)
-             return km.fit_predict(features)
-
-     if method == "leiden":
-         if G is None or not HAS_IGRAPH_LEIDEN:
-             raise RuntimeError("Leiden requires a graph and igraph+leidenalg.")
-         return leiden_communities(G)
-
-     if method == "hdbscan":
-         if not HAS_HDBSCAN:
-             raise RuntimeError("hdbscan not installed")
-         clusterer = hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size, metric='euclidean')
-         return clusterer.fit_predict(features)
-
-     if method == "dbscan":
-         db = DBSCAN(eps=0.5, min_samples=4, metric='euclidean')
-         return db.fit_predict(features)
-
-     if method == "kmeans":
-         km = KMeans(n_clusters=min(n_clusters_kmeans, max(2, N // 3)),
-                     n_init="auto", random_state=42)
-         return km.fit_predict(features)
-
-     raise ValueError(f"Unknown cluster method: {method}")
-
-
- def orthogonal_align(A_ref: np.ndarray, B: np.ndarray) -> np.ndarray:
-     """
-     Align B to A_ref by an orthogonal rotation (Procrustes),
-     preserving geometry but removing arbitrary orientation flips.
-     """
-     R, _ = orthogonal_procrustes(B - B.mean(0), A_ref - A_ref.mean(0))
-     return (B - B.mean(0)) @ R + A_ref.mean(0)
-
 
- def entropy_from_probs(p: np.ndarray, eps: float = 1e-12) -> np.ndarray:
-     """Shannon entropy for each row; p is (N, K) with rows summing ~1."""
-     return -np.sum(p * np.log(p + eps), axis=1)
 
- # ====== 3. Model I/O (hidden states) =============================================================
  @dataclass
  class HiddenStatesBundle:
      """
@@ -336,7 +315,8 @@ def extract_hidden_states(model, tokenizer, text: str, max_length: int, device:
      tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
      return HiddenStatesBundle(hidden_layers=hs, tokens=tokens)
 
- # ====== 4. LoT-style anchors & features ==========================================================
  def fit_global_anchors(all_states_sampled: np.ndarray, K: int, random_state: int = 42) -> np.ndarray:
      """
      Fit KMeans cluster centroids on a pooled set of states (from many layers/texts).
@@ -348,7 +328,6 @@ def fit_global_anchors(all_states_sampled: np.ndarray, K: int, random_state: int
      kmeans.fit(all_states_sampled)
      return kmeans.cluster_centers_  # (K, D)
 
-
  def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
      """
      For states H (N,D) and anchors A (K,D):
@@ -367,10 +346,12 @@ def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0
      P = np.exp(logits)
      P /= P.sum(axis=1, keepdims=True) + 1e-12
      # Uncertainty (entropy)
-     H_unc = entropy_from_probs(P)
      return dists, P, H_unc
 
- # ====== 5. Dimensionality reduction / embeddings ================================================
  def fit_umap_2d(pool: np.ndarray,
                  n_neighbors: int = 30,
                  min_dist: float = 0.05,
@@ -386,63 +367,6 @@ def fit_umap_2d(pool: np.ndarray,
      reducer.fit(pool)
      return reducer
 
- def _corpus_fingerprint(texts, max_items=5, max_chars=4000) -> str:
-     """Stable key so cache invalidates if DEFAULT_CORPUS changes."""
-     joined = "\n".join(texts[:max_items])
-     joined = joined[:max_chars]
-     return hashlib.sha256(joined.encode("utf-8")).hexdigest()
-
- @st.cache_data(show_spinner=False)
- def get_pool_artifacts(
-     model_name: str,
-     max_length: int,
-     anchor_k: int,
-     anchor_temp: float,  # not strictly needed for fitting anchors, but included if you want cache keys aligned
-     umap_n_neighbors: int,
-     umap_min_dist: float,
-     umap_metric: str,
-     fit_pool_per_layer: int,
-     corpus_hash: str,
- ):
-     """
-     Cached: build pooled hidden states on DEFAULT_CORPUS, fit anchors and a UMAP reducer once.
-     Returns:
-         anchors: (K, D) np.ndarray
-         reducer2d: fitted UMAP reducer object (must be pickleable; umap-learn's UMAP is)
-     """
-     # Use cached model loader (resource cache)
-     model, tok, device, dtype = get_model_and_tok(model_name)
-
-     texts = DEFAULT_CORPUS  # pooled set for stability
-
-     pool_states = []
-     for t in texts[: min(5, len(texts))]:
-         b = extract_hidden_states(model, tok, t, max_length, device)
-         for H in b.hidden_layers:
-             T = len(H)
-             take = min(fit_pool_per_layer, T)
-             if take <= 0:
-                 continue
-             idx = np.random.choice(T, size=take, replace=False)
-             pool_states.append(H[idx])
-
-     if not pool_states:
-         # fallback: this should rarely happen
-         raise RuntimeError("Pool construction produced no states.")
-
-     pool_states = np.vstack(pool_states)
-
-     anchors = fit_global_anchors(pool_states, anchor_k)
-
-     reducer2d = fit_umap_2d(
-         pool_states,
-         n_neighbors=umap_n_neighbors,
-         min_dist=umap_min_dist,
-         metric=umap_metric,
-     )
-
-     return anchors, reducer2d
-
 
  def fit_umap_3d(all_states: np.ndarray,
                  n_neighbors: int = 30,
@@ -457,296 +381,244 @@ def fit_umap_3d(all_states: np.ndarray,
                  metric=metric, random_state=random_state)
      return reducer.fit_transform(all_states)
 
- # ====== 6. Volume construction (MRI) ============================================================
- def stack_density_volume(xy_by_layer: List[np.ndarray],
-                          grid_res: int,
-                          use_hist2d: bool = True,
-                          kde_bandwidth: float = 0.15) -> np.ndarray:
-     """
-     Construct a 3D volume by estimating 2D density on the (x,y) manifold per layer (slice).
-     - If use_hist2d: fast uniform binning into grid_res x grid_res
-     - Else: KDE (slower but smoother)
-     Returns volume of shape (grid_res, grid_res, L) where L = #layers.
-     """
-     L = len(xy_by_layer)
-     vol = np.zeros((grid_res, grid_res, L), dtype=np.float32)
-
-     # Determine global bounds across layers to keep axes consistent
-     all_xy = np.vstack([xy for xy in xy_by_layer if len(xy) > 0]) if L > 0 else np.zeros((0, 2))
-     if len(all_xy) == 0:
-         return vol
-     x_min, y_min = all_xy.min(axis=0)
-     x_max, y_max = all_xy.max(axis=0)
-     # Slight padding
-     pad = 1e-6
-     x_edges = np.linspace(x_min - pad, x_max + pad, grid_res + 1)
-     y_edges = np.linspace(y_min - pad, y_max + pad, grid_res + 1)
-
-     for l, XY in enumerate(xy_by_layer):
-         if len(XY) == 0:
-             continue
-
-         if use_hist2d:
-             H, _, _ = np.histogram2d(XY[:, 0], XY[:, 1], bins=[x_edges, y_edges], density=False)
-             vol[:, :, l] = H.T  # histogram2d returns [x_bins, y_bins] → transpose to align
-         else:
-             kde = KernelDensity(bandwidth=kde_bandwidth, kernel="gaussian")
-             kde.fit(XY)
-             # Evaluate KDE on grid centers
-             xs = 0.5 * (x_edges[:-1] + x_edges[1:])
-             ys = 0.5 * (y_edges[:-1] + y_edges[1:])
-             xx, yy = np.meshgrid(xs, ys, indexing='xy')
-             grid_points = np.column_stack([xx.ravel(), yy.ravel()])
-             log_dens = kde.score_samples(grid_points)
-             dens = np.exp(log_dens).reshape(grid_res, grid_res)
-             vol[:, :, l] = dens
-
-     # Normalize volume to [0,1] for rendering convenience
-     if vol.max() > 0:
-         vol = vol / vol.max()
-     return vol
-
-
- def render_volume_with_pyvista(volume: np.ndarray,
-                                out_png: str,
-                                opacity="sigmoid") -> None:
-     """
-     Visualize the 3D volume using PyVista/VTK (if installed); save a screenshot.
-     """
-     if not HAS_PYVISTA:
-         raise RuntimeError("PyVista is not installed; cannot render volume.")
-     pl = pv.Plotter()
-     # Wrap NumPy array as a VTK image data; PyVista expects z as the 3rd axis
-     vol_vtk = pv.wrap(volume)
-     pl.add_volume(vol_vtk, opacity=opacity, shade=True)
-     pl.show(screenshot=out_png)  # headless environments will still save a screenshot (if offscreen support)
-
- # ====== 7. 3D Plotly visualization ==============================================================
  def plotly_3d_layers(xy_layers: List[np.ndarray],
                       layer_tokens: List[List[str]],
                       layer_cluster_labels: List[np.ndarray],
                       layer_uncertainty: List[np.ndarray],
                       layer_graphs: List[nx.Graph],
-                      connect_token_trajectories: bool = True,
-                      title: str = "Qwen: 3D Cluster Formation (UMAP2D + Layer as Z)") -> go.Figure:
-     """
-     Build an interactive 3D Plotly figure:
-     - Nodes per layer at (x, y, z=layer)
-     - Edge segments (kNN or threshold graph) per layer
-     - Trajectory lines: connect same token index across consecutive layers (optional)
-     - Color nodes by cluster label; hover shows token & uncertainty
-     """
      fig_data = []
 
-     # Build a color per layer node trace
-     for l, (xy, tokens, labels, unc, G) in enumerate(zip(xy_layers, layer_tokens, layer_cluster_labels, layer_uncertainty, layer_graphs)):
-         if len(xy) == 0:
-             continue
          x, y = xy[:, 0], xy[:, 1]
          z = np.full_like(x, l, dtype=float)
 
-         # --- Nodes
-         node_text = [f"layer={l} | idx={i}<br>token={tokens[i]}<br>cluster={int(labels[i])}<br>uncertainty={unc[i]:.3f}"
-                      for i in range(len(tokens))]
          node_trace = go.Scatter3d(
              x=x, y=y, z=z,
              mode='markers',
              name=f"Layer {l}",
              marker=dict(
-                 size=4,
-                 opacity=0.7,
-                 color=labels,  # cluster ID → color scale
-                 colorscale='Viridis',
-                 showscale=(l == 0)  # show scale once
              ),
              text=node_text,
              hovertemplate="%{text}<extra></extra>"
          )
          fig_data.append(node_trace)
 
-         # --- Intra-layer edges (kNN or threshold)
          if G is not None and G.number_of_edges() > 0:
              edge_x, edge_y, edge_z = [], [], []
              for u, v in G.edges():
                  edge_x += [x[u], x[v], None]
                  edge_y += [y[u], y[v], None]
                  edge_z += [z[u], z[v], None]
              edge_trace = go.Scatter3d(
                  x=edge_x, y=edge_y, z=edge_z,
                  mode='lines',
-                 line=dict(width=1),
-                 opacity=0.30,
-                 name=f"Edges L{l}"
              )
              fig_data.append(edge_trace)
 
-     # --- Trajectories: connect same token index across layers
-     if connect_token_trajectories:
-         # Only meaningful if tokenization length T is constant across layers (it is)
-         # We'll draw faint polylines for each position i across l=0..L-1
-         L = len(xy_layers)
-         if L > 1:
-             T = min(len(xy_layers[l]) for l in range(L))
-             for i in range(T):
-                 xs = [xy_layers[l][i, 0] for l in range(L)]
-                 ys = [xy_layers[l][i, 1] for l in range(L)]
-                 zs = list(range(L))
-                 traj = go.Scatter3d(
-                     x=xs, y=ys, z=zs,
-                     mode='lines',
-                     line=dict(width=1),
-                     opacity=0.15,
-                     name=f"traj_{i}",
-                     hoverinfo='skip'
                  )
-                 fig_data.append(traj)
 
      fig = go.Figure(data=fig_data)
      fig.update_layout(
-         title=title,
          scene=dict(
              xaxis_title="UMAP X",
             yaxis_title="UMAP Y",
-             zaxis_title="Layer (depth)"
          ),
          height=900,
-         showlegend=False
      )
      return fig
 
- # ====== 8. Orchestration ========================================================================
- def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts: bool = False):
      seed_everything(42)
 
-     # 8.2 Collect hidden states for one representative text (detailed viz) + for pool
-     # You can extend to many texts; we keep a single text for clarity & speed.
-     texts = cfg.corpus or DEFAULT_CORPUS
-     #print(f"[Input] Example text: {main_text!r}")
 
-     # Hidden states for main text
      main_bundle = extract_hidden_states(model, tok, main_text, cfg.max_length, device)
-     layers_np: List[np.ndarray] = main_bundle.hidden_layers  # list of (T,D), length L_all = num_layers+1
-     tokens = main_bundle.tokens  # list of length T
-
-     # Cached pool artifacts (anchors + fitted UMAP reducer)
-     corpus_hash = _corpus_fingerprint(texts)  # texts is cfg.corpus or DEFAULT_CORPUS
-
-     anchors, reducer2d = get_pool_artifacts(
-         model_name=cfg.model_name,
-         max_length=cfg.max_length,
-         anchor_k=cfg.anchor_k,
-         anchor_temp=cfg.anchor_temp,
-         umap_n_neighbors=cfg.umap_n_neighbors,
-         umap_min_dist=cfg.umap_min_dist,
-         umap_metric=cfg.umap_metric,
-         fit_pool_per_layer=cfg.fit_pool_per_layer,
-         corpus_hash=corpus_hash,
-     )
-
      L_all = len(layers_np)
-     #print(f"[Hidden] Layers (incl. embedding): {L_all}, Tokens: {len(tokens)}")
 
-     """
-     # 8.3 Build a pool of states (across a few texts & layers) to fit anchors + UMAP
-     pool_states = []
-     # Sample across first few texts to improve diversity (lightweight)
-     for t in texts[: min(5, len(texts))]:
-         b = extract_hidden_states(model, tok, t, cfg.max_length, device)
-         # Take a subset from each layer to limit pool size
-         for H in b.hidden_layers:
-             T = len(H)
-             take = min(cfg.fit_pool_per_layer, T)
-             idx = np.random.choice(T, size=take, replace=False)
-             pool_states.append(H[idx])
-     pool_states = np.vstack(pool_states) if len(pool_states) else layers_np[-1]
-     #print(f"[Pool] Pooled states for anchors/UMAP: {pool_states.shape}")
-
-     # 8.4 Fit global anchors (LoT-style features)
-     anchors = fit_global_anchors(pool_states, cfg.anchor_k)
-     # Save anchors for reproducibility
-     """
-
-     # 8.5 Build per-layer features for main text (LoT-style distances & uncertainty)
-     layer_features = []       # list of (T,K)
-     layer_uncertainties = []  # list of (T,)
-     layer_top_anchor = []     # list of (T,) argmin-id
-
-     for l, H in enumerate(layers_np):
-         dists, P, H_unc = anchor_features(H, anchors, cfg.anchor_temp)
-         layer_features.append(dists)        # N x K distances (lower = closer)
-         layer_uncertainties.append(H_unc)   # N
-         layer_top_anchor.append(np.argmin(dists, axis=1))  # closest anchor id per token
 
-     # 8.6 Consistency metric (LoT Eq. (5)): does layer's top anchor match final layer's?
-     final_top = layer_top_anchor[-1]
-     layer_consistency = []
-     for l in range(L_all):
-         cons = (layer_top_anchor[l] == final_top).astype(np.int32)  # 1 if matches, 0 otherwise
-         layer_consistency.append(cons)
 
-     # 8.7 Build per-layer graphs (kNN by default) on FEATURE space for stability
      layer_graphs = []
      for l in range(L_all):
-         feats = layer_features[l]
          if cfg.graph_mode == "knn":
-             G = build_knn_graph(feats, cfg.knn_k, metric="euclidean")  # kNN in feature space
          else:
-             # Threshold graph in original hidden space (as in your notebook)
-             G = build_threshold_graph(layers_np[l], cfg.sim_threshold, use_cosine=cfg.use_cosine)
          layer_graphs.append(G)
 
-     # 8.8 Cluster per layer
-     layer_cluster_labels = []
-     for l in range(L_all):
-         feats = layer_features[l]
-         labels = cluster_layer(
-             feats,
-             layer_graphs[l],
-             method=cfg.cluster_method,
-             n_clusters_kmeans=cfg.n_clusters_kmeans,
-             hdbscan_min_cluster_size=cfg.hdbscan_min_cluster_size
-         )
          layer_cluster_labels.append(labels)
 
-     # 8.9 Percolation statistics (φ, #clusters, χ) per layer (as in your notebook)
-     percolation = []
-     for l in range(L_all):
-         stats = percolation_stats(layer_graphs[l])
-         percolation.append(stats)
 
-     # 8.10 Common 2D manifold via UMAP (fit-once on the pool), then transform each layer
-     """reducer2d = fit_umap_2d(pool_states,
-                                n_neighbors=cfg.umap_n_neighbors,
-                                min_dist=cfg.umap_min_dist,
-                                metric=cfg.umap_metric)"""
-     xy_by_layer = [reducer2d.transform(layers_np[l]) for l in range(L_all)]
 
-     # OPTIONAL: orthogonal alignment across layers (helps if UMAP.transform still drifts)
-     # for l in range(1, L_all):
-     #     xy_by_layer[l] = orthogonal_align(xy_by_layer[l-1], xy_by_layer[l])
 
-     # 8.11 Plotly 3D point+graph view: X,Y from UMAP; Z = layer index
      fig = plotly_3d_layers(
          xy_layers=xy_by_layer,
-         layer_tokens=[tokens for _ in range(L_all)],
          layer_cluster_labels=layer_cluster_labels,
          layer_uncertainty=layer_uncertainties,
          layer_graphs=layer_graphs,
-         connect_token_trajectories=True,
-         title="Qwen: 3D Cluster Formation (UMAP2D + Layer as Z, LoT metrics on hover)"
      )
 
      if save_artifacts:
-         os.makedirs(cfg.out_dir, exist_ok=True)
-         html_path = os.path.join(cfg.out_dir, cfg.plotly_html)
-         fig.write_html(html_path)
-         # Save percolation series
-         with open(os.path.join(cfg.out_dir, "percolation_stats.json"), "w") as f:
-             json.dump(percolation, f, indent=2)
-         np.save(os.path.join(cfg.out_dir, "anchors.npy"), anchors)
-         #print(f"[Percolation] Saved per-layer stats → percolation_stats.json")
-         #print(f"[Plotly] 3D HTML saved → {html_path}")
 
      return fig, {"percolation": percolation, "tokens": tokens}
@@ -754,16 +626,32 @@ def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts
  def get_model_and_tok(model_name: str):
      device = "cuda" if torch.cuda.is_available() else "cpu"
      dtype = torch.float16 if device == "cuda" else torch.float32
-     model, tok = load_qwen(model_name, device, dtype)
      return model, tok, device, dtype
 
  def main():
-     st.set_page_config(page_title="Layer Explorer", layout="wide")
-     st.title("3D Token Embedding Explorer (Live Hidden States)")
 
      with st.sidebar:
          st.header("Model / Input")
-         model_name = st.selectbox("Model", ["Qwen/Qwen1.5-0.5B", "Qwen/Qwen1.5-1.8B", "Qwen/Qwen1.5-4B"], index=1)
          max_length = st.slider("Max tokens", 16, 256, 64, step=16)
 
          st.header("Graph")
@@ -787,13 +675,16 @@ def main():
          st.header("Outputs")
          save_artifacts = st.checkbox("Save artifacts to disk (HTML/CSV/NPZ)", value=False)
 
-     prompt_col, run_col = st.columns([4, 1])
      with prompt_col:
-         main_text = st.text_area(
-             "Text to visualize (hidden states computed on this text)",
-             value="Explain in one sentence what a transformer attention layer does.",
-             height=140
          )
      with run_col:
          st.write("")
          st.write("")

 
  import numpy as np
  import pandas as pd
  import torch
  from torch import nn
  import networkx as nx
  import streamlit as st
 
  from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
  import umap
  from sklearn.neighbors import NearestNeighbors, KernelDensity
  from sklearn.cluster import KMeans, DBSCAN
  from sklearn.metrics import pairwise_distances
+ from scipy.spatial import procrustes
+ from scipy.linalg import orthogonal_procrustes
  import plotly.graph_objects as go
 
 
  # Optional libs (use if present)
  try:
  ...
  except Exception:
      HAS_PYVISTA = False
 
+ # ====== Configuration =========================================================================
  @dataclass
  class Config:
      # Model
      model_name: str = "Qwen/Qwen1.5-1.8B"
+     max_length: int = 64
 
      # Data
+     corpus: List[str] = None
 
+     # Graph & Clustering
+     graph_mode: str = "threshold"
+     knn_k: int = 8
+     sim_threshold: float = 0.05  # fraction of edges kept: 0.05 = show the top 5% most-similar pairs
      use_cosine: bool = True
 
      # Anchors / LoT-style features (global)
  ...
      n_clusters_kmeans: int = 6  # fallback for kmeans
      hdbscan_min_cluster_size: int = 4
 
+     # UMAP & alignment
      umap_n_neighbors: int = 30
      umap_min_dist: float = 0.05
+     umap_metric: str = "cosine"
      fit_pool_per_layer: int = 512  # number of states sampled per layer to fit UMAP
+     align_layers: bool = True  # Procrustes-align each layer's 2D projection to the previous layer
 
+     # Visualization
+     color_by: str = "pos"  # "cluster" or "pos" (part of speech)
 
      # Output
      out_dir: str = "qwen_mri3d_outputs"
      plotly_html: str = "qwen_layers_3d.html"
 
  # Default corpus (small and diverse; adjust freely)
  DEFAULT_CORPUS = [
+     "Is a Universal Basic Income (UBI) a viable solution to poverty, or does it simply discourage people from working?",
+     "Explain the arguments for and against the independence of Taiwan from the perspective of both the US and China.",
+     "What are the ethical arguments surrounding the use of CRISPR technology to edit human embryos for non-medical enhancements?",
+     "Analyze the effectiveness of strict lockdowns versus herd immunity strategies during the COVID-19 pandemic.",
+     "Why is nuclear energy controversial despite being a low-carbon power source? Present both the safety concerns and the environmental benefits.",
+     "Does the existence of evil in the world disprove the existence of a benevolent God? Summarize the philosophical debate.",
+     "Summarize the main arguments used by gun rights advocates against stricter background checks in the United States.",
+     "Should autonomous weapons systems (killer robots) be banned internationally, even if they could reduce soldier casualties?",
+     "Was the dropping of the atomic bombs on Hiroshima and Nagasaki militarily necessary to end World War II?",
+     "What are the competing arguments regarding transgender women participating in biological women's sports categories?"
  ]
 
+ # Select from 4 different models
+ MODELS = ["Qwen/Qwen1.5-0.5B", "deepseek-ai/deepseek-coder-1.3b-instruct", "openai-community/gpt2", "prem-research/MiniGuard-v0.1"]
+
+
+ # ====== Utilities =========================================================================
  def seed_everything(seed: int = 42):
      np.random.seed(seed)
      torch.manual_seed(seed)
 
  def cosine_similarity_matrix(X: np.ndarray) -> np.ndarray:
      norms = np.linalg.norm(X, axis=1, keepdims=True) + 1e-8
      Xn = X / norms
      return Xn @ Xn.T
 
+ def orthogonal_align(A_ref: np.ndarray, B: np.ndarray) -> np.ndarray:
+     """
+     Align B to A_ref using Procrustes analysis (rotation/reflection only).
+     Preserves local geometry of B, but aligns global orientation to A.
+     """
+     # Center both
+     mu_a = A_ref.mean(0)
+     mu_b = B.mean(0)
+     A0 = A_ref - mu_a
+     B0 = B - mu_b
 
+     # Solve for the rotation R that minimizes ||A0 - B0 @ R||:
+     # M = B0.T @ A0; U, S, Vt = svd(M); R = U @ Vt
+     R, _ = orthogonal_procrustes(B0, A0)
+
+     # B_aligned = (B - mu_b) @ R + mu_a
+     # We essentially rotate B to match A's orientation, then shift to A's center
+     return B0 @ R + mu_a
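A quick way to sanity-check orthogonal_align (a minimal sketch, not part of the commit; the rotated point cloud is synthetic test data):

# Illustrative check: a rotated, shifted copy of a point cloud should map
# back onto the reference, since Procrustes solves for the best rotation.
import numpy as np
from scipy.stats import ortho_group

rng = np.random.default_rng(0)
A = rng.normal(size=(50, 2))            # reference layout
Q = ortho_group.rvs(2, random_state=0)  # random rotation/reflection
B = A @ Q + 3.0                         # rotated and shifted copy
print(np.allclose(orthogonal_align(A, B), A, atol=1e-8))  # expect True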
+
+ def get_pos_tags(text: str, tokenizer, tokens: List[str]) -> List[str]:
      """
+     Map LLM tokens to spaCy POS tags.
+     Heuristic: reconstruct the text, run spaCy, align based on char overlap.
      """
+     try:
+         import spacy  # local import: spaCy and its model are optional
+         nlp = spacy.load("en_core_web_sm")
+     except Exception:
+         # Fallback if spaCy or the model is not installed
+         return ["UNK"] * len(tokens)
+
+     doc = nlp(text)
+
+     # This is a simplified mapping. Real alignment is complex due to subwords.
+     # We approximate: find which word the subword belongs to.
+     pos_tags = []
+
+     # Re-building offsets for tokens would be the robust route:
+     # ideally, use tokenizer(..., return_offsets_mapping=True).
+     # Here we just iterate and approximate for the demo.
+
+     # Fast approximation: tag the token string itself
+     # (not perfect for subwords like "ing", but visually useful)
+     for t_str in tokens:
+         clean_t = t_str.replace("Ġ", "").replace("▁", "").strip()
+         if not clean_t:
+             pos_tags.append("SYM")  # likely special char
+             continue
+
+         # Tag the single token fragment
+         sub_doc = nlp(clean_t)
+         if len(sub_doc) > 0:
+             pos_tags.append(sub_doc[0].pos_)
+         else:
+             pos_tags.append("UNK")
+
+     return pos_tags
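The comments above point at the more robust route via offset mappings; a minimal sketch of that alternative (illustrative only, assuming a fast tokenizer and that en_core_web_sm is installed; get_pos_tags_by_offset is a hypothetical name, not part of the commit):

# Offset-based alignment: map each subword to the spaCy word covering its
# start offset, then reuse that word's POS tag.
def get_pos_tags_by_offset(text, tokenizer):
    import spacy
    nlp = spacy.load("en_core_web_sm")
    doc = nlp(text)
    enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    tags = []
    for start, end in enc["offset_mapping"]:
        tag = "UNK"
        for w in doc:
            if w.idx <= start < w.idx + len(w.text):
                tag = w.pos_
                break
        tags.append(tag)
    return tags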
+
+ def build_knn_graph(coords: np.ndarray, k: int, metric: str = "cosine") -> nx.Graph:
+     nbrs = NearestNeighbors(n_neighbors=min(k + 1, len(coords)), metric=metric)
      nbrs.fit(coords)
      distances, indices = nbrs.kneighbors(coords)
      G = nx.Graph()
      G.add_nodes_from(range(len(coords)))
      for i in range(len(coords)):
+         for j in indices[i, 1:]:
              G.add_edge(int(i), int(j))
      return G
 
+ def build_threshold_graph(H: np.ndarray, top_pct: float = 0.05, use_cosine: bool = True, include_ties: bool = True) -> nx.Graph:
      if use_cosine:
          S = cosine_similarity_matrix(H)
      else:
+         S = H @ H.T
 
      N = S.shape[0]
+     iu = np.triu_indices(N, k=1)
+     vals = S[iu]
+
+     # threshold at the (1 - top_pct) quantile
+     q = 1.0 - top_pct
+     thr = float(np.quantile(vals, q))
      G = nx.Graph()
      G.add_nodes_from(range(N))
 
+     if include_ties:
+         mask = vals >= thr
+     else:
+         # strictly greater than the threshold reduces tie-inflation
+         mask = vals > thr
+
+     rows = iu[0][mask]
+     cols = iu[1][mask]
+     wts = vals[mask]
+
+     for r, c, w in zip(rows, cols, wts):
+         G.add_edge(int(r), int(c), weight=float(w))
+     return G
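Concretely, the quantile rule works like this (a minimal sketch using the function above; the sizes are illustrative): N tokens give N·(N−1)/2 candidate pairs, and only the top top_pct fraction of similarities become edges.

# 5 points → 10 upper-triangle pairs; top_pct=0.2 keeps roughly the 2
# most-similar pairs (ties at the quantile can add a few more).
import numpy as np
rng = np.random.default_rng(1)
H = rng.normal(size=(5, 8))
G = build_threshold_graph(H, top_pct=0.2, use_cosine=True)
print(G.number_of_edges())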
 
  def percolation_stats(G: nx.Graph) -> Dict[str, float]:
      """
  ...
                  largest_component_size=largest,
                  component_sizes=sorted(sizes, reverse=True))
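The body of percolation_stats is elided by this hunk; only the docstring opening and the return tail are visible. For orientation, a sketch consistent with that tail and with the per-layer φ/#clusters/χ comment in the removed orchestration code might look like the following (every field except the two visible ones is an assumption):

# Hypothetical reconstruction, not the commit's code: connected-component
# statistics in the spirit of percolation analysis.
import networkx as nx

def percolation_stats_sketch(G: nx.Graph) -> dict:
    N = max(G.number_of_nodes(), 1)
    sizes = [len(c) for c in nx.connected_components(G)]
    largest = max(sizes) if sizes else 0
    rest = sorted(sizes, reverse=True)[1:]
    return dict(phi=largest / N,                   # giant-component fraction
                n_components=len(sizes),
                chi=sum(s * s for s in rest) / N,  # susceptibility-style sum
                largest_component_size=largest,
                component_sizes=sorted(sizes, reverse=True))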
 
+ def cluster_layer(features: np.ndarray, G: Optional[nx.Graph], method: str,
+                   n_clusters_kmeans: int = 6, hdbscan_min_cluster_size: int = 4) -> np.ndarray:
+     # (Same as original)
+     method = method.lower()
+     N = len(features)
+     if method == "auto":
+         if HAS_IGRAPH_LEIDEN and G and G.number_of_edges() > 0: return leiden_communities(G)
+         elif HAS_HDBSCAN: return hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size).fit_predict(features)
+         else: return KMeans(n_clusters=min(n_clusters_kmeans, N), n_init="auto").fit_predict(features)
+     # ... (rest of method dispatch unchanged)
+     return KMeans(n_clusters=min(n_clusters_kmeans, N), n_init="auto").fit_predict(features)
 
+ # Helper for Leiden (from original)
  def leiden_communities(G: nx.Graph) -> np.ndarray:
+     if not HAS_IGRAPH_LEIDEN: raise RuntimeError("Missing igraph")
      mapping = {n: i for i, n in enumerate(G.nodes())}
      edges = [(mapping[u], mapping[v]) for u, v in G.edges()]
      ig_g = ig.Graph(n=len(mapping), edges=edges, directed=False)
+     part = la.find_partition(ig_g, la.RBConfigurationVertexPartition)
      labels = np.zeros(len(mapping), dtype=int)
      for cid, comm in enumerate(part):
+         for node in comm: labels[node] = cid
      return labels
 
+ def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0):
+     dists = pairwise_distances(H, anchors, metric="euclidean")
+     logits = -dists / max(temperature, 1e-6)
+     logits = logits - logits.max(axis=1, keepdims=True)
+     P = np.exp(logits)
+     P /= P.sum(axis=1, keepdims=True) + 1e-12
+     # Entropy calculation
+     H_unc = -np.sum(P * np.log(P + 1e-12), axis=1)
+     return dists, P, H_unc
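The soft assignment is a softmax over negative anchor distances, so the entropy H_unc ranges from 0 (state pinned to one anchor) to ln K (state equidistant from all anchors). A numeric check (a minimal sketch using the function above; the coordinates are illustrative):

# (2, 2) is equidistant from the three anchors → uniform P, entropy ln(3);
# (0.1, 0) sits almost on anchor 0 → peaked P, much lower entropy.
import numpy as np
anchors = np.array([[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]])  # K = 3
H = np.array([[2.0, 2.0], [0.1, 0.0]])
dists, P, H_unc = anchor_features(H, anchors, temperature=1.0)
print(np.log(3), H_unc)  # H_unc[0] ≈ 1.099, H_unc[1] well below it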
 
+ def fit_global_anchors(pool: np.ndarray, K: int) -> np.ndarray:
+     km = KMeans(n_clusters=K, n_init="auto", random_state=42)
+     km.fit(pool)
+     return km.cluster_centers_
 
 
+ # ====== Model I/O (hidden states) =============================================================
  @dataclass
  class HiddenStatesBundle:
      """
  ...
      tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
      return HiddenStatesBundle(hidden_layers=hs, tokens=tokens)
 
+
+ # ====== LoT-style anchors & features ==========================================================
  def fit_global_anchors(all_states_sampled: np.ndarray, K: int, random_state: int = 42) -> np.ndarray:
      """
      Fit KMeans cluster centroids on a pooled set of states (from many layers/texts).
  ...
      kmeans.fit(all_states_sampled)
      return kmeans.cluster_centers_  # (K, D)
 
  def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
      """
      For states H (N,D) and anchors A (K,D):
  ...
      P = np.exp(logits)
      P /= P.sum(axis=1, keepdims=True) + 1e-12
      # Uncertainty (entropy)
+     H_unc = -np.sum(P * np.log(P + 1e-12), axis=1)
+
      return dists, P, H_unc
 
+
+ # ====== Dimensionality reduction / embeddings ================================================
  def fit_umap_2d(pool: np.ndarray,
                  n_neighbors: int = 30,
                  min_dist: float = 0.05,
  ...
      reducer.fit(pool)
      return reducer
 
  def fit_umap_3d(all_states: np.ndarray,
                  n_neighbors: int = 30,
  ...
                  metric=metric, random_state=random_state)
      return reducer.fit_transform(all_states)
 
+ # ====== Visualization ========================================================================
  def plotly_3d_layers(xy_layers: List[np.ndarray],
                       layer_tokens: List[List[str]],
                       layer_cluster_labels: List[np.ndarray],
+                      layer_pos_tags: List[List[str]],
                       layer_uncertainty: List[np.ndarray],
                       layer_graphs: List[nx.Graph],
+                      color_by: str = "cluster",
+                      title: str = "3D Cluster Formation",
+                      prompt: str = None) -> go.Figure:
+
      fig_data = []
 
+     # Define categorical colormap for POS
+     pos_map = {
+         "NOUN": "#1f77b4", "VERB": "#d62728", "ADJ": "#2ca02c",
+         "ADV": "#ff7f0e", "PRON": "#9467bd", "DET": "#8c564b",
+         "ADP": "#e377c2", "NUM": "#7f7f7f", "PUNCT": "#bcbd22",
+         "SYM": "#17becf", "UNK": "#bababa"
+     }
+
+     L = len(xy_layers)
+     for l, (xy, tokens, labels, pos, unc, G) in enumerate(zip(xy_layers, layer_tokens, layer_cluster_labels, layer_pos_tags, layer_uncertainty, layer_graphs)):
+         if len(xy) == 0: continue
          x, y = xy[:, 0], xy[:, 1]
          z = np.full_like(x, l, dtype=float)
 
+         # Color logic
+         if color_by == "pos":
+             # Map POS strings to colors
+             node_colors = [pos_map.get(p, "#333333") for p in pos]
+             show_scale = False
+             colorscale = None
+         else:
+             # Cluster ID
+             node_colors = labels
+             show_scale = (l == 0)
+             colorscale = 'Viridis'
+
+         # Hover text
+         node_text = [
+             f"L{l} | {tok}<br>POS: {p}<br>Cluster: {c}<br>Unc: {u:.2f}"
+             for tok, p, c, u in zip(tokens, pos, labels, unc)
+         ]
+
          node_trace = go.Scatter3d(
              x=x, y=y, z=z,
              mode='markers',
              name=f"Layer {l}",
+             showlegend=False,
              marker=dict(
+                 size=3,
+                 opacity=1,
+                 color=node_colors,
+                 colorscale=colorscale,
+                 showscale=show_scale,
+                 colorbar=dict(title="Cluster ID") if show_scale else None
              ),
              text=node_text,
              hovertemplate="%{text}<extra></extra>"
          )
          fig_data.append(node_trace)
 
+         # Edges
          if G is not None and G.number_of_edges() > 0:
              edge_x, edge_y, edge_z = [], [], []
              for u, v in G.edges():
                  edge_x += [x[u], x[v], None]
                  edge_y += [y[u], y[v], None]
                  edge_z += [z[u], z[v], None]
+
              edge_trace = go.Scatter3d(
                  x=edge_x, y=edge_y, z=edge_z,
                  mode='lines',
+                 line=dict(width=2, color='red'),
+                 opacity=0.6,
+                 hoverinfo='skip',
+                 showlegend=False
              )
              fig_data.append(edge_trace)
 
+     # Trajectories (connect same token across layers)
+     if L > 1:
+         T = len(xy_layers[0])
+         # Sample trajectories to avoid lag if T is huge
+         step = max(1, T // 100)
+         for i in range(0, T, step):
+             xs = [xy_layers[l][i, 0] for l in range(L)]
+             ys = [xy_layers[l][i, 1] for l in range(L)]
+             zs = list(range(L))
+             traj = go.Scatter3d(
+                 x=xs, y=ys, z=zs,
+                 mode='lines',
+                 line=dict(width=3, color='rgba(50,50,50,0.5)'),
+                 hoverinfo='skip',
+                 showlegend=False
+             )
+             fig_data.append(traj)
+     if color_by == "pos":
+         # Add legend-only traces for POS categories actually present
+         present_pos = sorted({p for layer in layer_pos_tags for p in layer})
+
+         for p in present_pos:
+             fig_data.append(
+                 go.Scatter3d(
+                     x=[None], y=[None], z=[None],  # legend-only
+                     mode="markers",
+                     name=p,
+                     marker=dict(size=8, color=pos_map.get(p, "#333333")),
+                     showlegend=True,
+                     hoverinfo="skip"
+                 )
+             )
 
      fig = go.Figure(data=fig_data)
      fig.update_layout(
+         title=dict(
+             text=title,
+             x=0.5,
+             xanchor="center",
+         ),
+         annotations=[
+             dict(
+                 text=f"<b>Prompt:</b> {prompt}",
+                 x=0.5,
+                 y=1.02,
+                 xref="paper",
+                 yref="paper",
+                 showarrow=False,
+                 font=dict(size=13),
+                 align="center"
+             )
+         ] if prompt else [],
          scene=dict(
              xaxis_title="UMAP X",
              yaxis_title="UMAP Y",
+             zaxis_title="Layer Depth",
+             aspectratio=dict(x=1, y=1, z=1.5)
          ),
          height=900,
+         margin=dict(l=0, r=0, b=0, t=40)
      )
      return fig
 
+ def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts: bool = False):
      seed_everything(42)
 
+     # 1. Extract Hidden States
+     from transformers import logging
+     logging.set_verbosity_error()
 
+     # Extract
      main_bundle = extract_hidden_states(model, tok, main_text, cfg.max_length, device)
+     layers_np = main_bundle.hidden_layers
+     tokens = main_bundle.tokens
      L_all = len(layers_np)
 
+     # 2. Get POS Tags
+     pos_tags = get_pos_tags(main_text, tok, tokens)
 
+     # 3. Pooling & Anchors (LoT)
+     # (Simplified: just pool from the main text for speed in the demo)
+     pool_states = np.vstack([layers_np[l] for l in range(0, L_all, 2)])
+     idx = np.random.choice(len(pool_states), min(len(pool_states), 2000), replace=False)
+     anchors = fit_global_anchors(pool_states[idx], cfg.anchor_k)
 
+     # 4. Process Layers
+     layer_features = []
+     layer_uncertainties = []
      layer_graphs = []
+     layer_cluster_labels = []
+     percolation = []
+
      for l in range(L_all):
+         H = layers_np[l]
+
+         # Features & Uncertainty
+         dists, P, H_unc = anchor_features(H, anchors, cfg.anchor_temp)
+         layer_features.append(dists)
+         layer_uncertainties.append(H_unc)
+
+         # Graphs
          if cfg.graph_mode == "knn":
+             G = build_knn_graph(dists, cfg.knn_k, metric="euclidean")
          else:
+             G = build_threshold_graph(H, cfg.sim_threshold, use_cosine=cfg.use_cosine)
          layer_graphs.append(G)
 
+         # Clusters
+         labels = cluster_layer(dists, G, cfg.cluster_method,
+                                cfg.n_clusters_kmeans, cfg.hdbscan_min_cluster_size)
          layer_cluster_labels.append(labels)
 
+         # Percolation
+         percolation.append(percolation_stats(G))
 
+     # 5. UMAP & Alignment
+     # Fit UMAP on the pool to establish a coordinate system
+     reducer = umap.UMAP(n_components=2, n_neighbors=cfg.umap_n_neighbors,
+                         min_dist=cfg.umap_min_dist, metric=cfg.umap_metric, random_state=42)
+     reducer.fit(pool_states[idx])
+
+     xy_by_layer = []
+     for l in range(L_all):
+         # Transform into 2D
+         xy = reducer.transform(layers_np[l])
 
+         # Procrustes alignment: align layer l to layer l-1
+         if cfg.align_layers and l > 0:
+             xy = orthogonal_align(xy_by_layer[l-1], xy)
 
+         xy_by_layer.append(xy)
 
+     # 6. Plot
      fig = plotly_3d_layers(
          xy_layers=xy_by_layer,
+         layer_tokens=[tokens] * L_all,
          layer_cluster_labels=layer_cluster_labels,
+         layer_pos_tags=[pos_tags] * L_all,
          layer_uncertainty=layer_uncertainties,
          layer_graphs=layer_graphs,
+         color_by=cfg.color_by,
+         title=f"{cfg.model_name.rsplit('/', 1)[-1]} 3D MRI | Color: {cfg.color_by.upper()} | Aligned: {cfg.align_layers}",
+         prompt=main_text
      )
 
+     # 7. Save artifacts
      if save_artifacts:
+         import os
+         # Create the directory if it doesn't exist
+         os.makedirs(cfg.out_dir, exist_ok=True)
+
+         # Construct the full path
+         out_path = os.path.join(cfg.out_dir, cfg.plotly_html)
+
+         # Write the HTML file
+         fig.write_html(out_path)
+         print(f"Successfully saved 3D plot to: {out_path}")
 
      return fig, {"percolation": percolation, "tokens": tokens}
 
 
  def get_model_and_tok(model_name: str):
      device = "cuda" if torch.cuda.is_available() else "cpu"
      dtype = torch.float16 if device == "cuda" else torch.float32
+     config = AutoConfig.from_pretrained(model_name, output_hidden_states=True, trust_remote_code=True)
+     tok = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
+     if tok.pad_token_id is None:
+         tok.pad_token = tok.eos_token
+
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         trust_remote_code=True,
+         config=config,
+         torch_dtype=dtype if device == "cuda" else None,
+         device_map="auto" if device == "cuda" else None
+     )
+     model.eval()
+
+     if device != "cuda":
+         model = model.to(device)
+
      return model, tok, device, dtype
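Since loading is expensive, a Streamlit app would normally reuse this loader through a resource cache; the removed code indeed referred to get_model_and_tok as a "cached model loader (resource cache)". A minimal sketch of that wiring (the decorator placement is an assumption, since the decorator line itself is outside this hunk):

# Cache the loaded model/tokenizer across Streamlit reruns.
import streamlit as st

@st.cache_resource(show_spinner=False)
def cached_model_and_tok(model_name: str):
    return get_model_and_tok(model_name)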
 
  def main():
+     st.set_page_config(page_title="LLM Hidden Layer Explorer", layout="wide")
+     st.title("Token Embedding Explorer (Live Hidden States)")
 
      with st.sidebar:
          st.header("Model / Input")
+         model_name = st.selectbox("Model", MODELS, index=1)
          max_length = st.slider("Max tokens", 16, 256, 64, step=16)
 
          st.header("Graph")
  ...
          st.header("Outputs")
          save_artifacts = st.checkbox("Save artifacts to disk (HTML/CSV/NPZ)", value=False)
 
+     prompt_col, run_col = st.columns([4, 1])
+
      with prompt_col:
+         main_text = st.selectbox(
+             "Prompt to visualize (hidden states computed on this text)",
+             options=DEFAULT_CORPUS,
+             index=0,
+             help="Select a predefined prompt for analysis"
          )
+
      with run_col:
          st.write("")
          st.write("")