Jgray21 committed on
Commit c01107e · verified · 1 Parent(s): 638275a

Update src/streamlit_app.py

Files changed (1): src/streamlit_app.py (+485 −346)
src/streamlit_app.py CHANGED
@@ -7,20 +7,30 @@ from typing import Dict, List, Tuple, Optional
import numpy as np
import pandas as pd
import torch
from torch import nn
import networkx as nx
import streamlit as st
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import umap
from sklearn.neighbors import NearestNeighbors, KernelDensity
from sklearn.cluster import KMeans, DBSCAN
from sklearn.metrics import pairwise_distances
- from scipy.spatial import procrustes
- from scipy.linalg import orthogonal_procrustes
import plotly.graph_objects as go

# Optional libs (use if present)
try:
@@ -41,20 +51,28 @@ try:
    HAS_PYVISTA = True
except Exception:
    HAS_PYVISTA = False
- # ====== Configuration =========================================================================

@dataclass
class Config:
    # Model
    model_name: str = "Qwen/Qwen1.5-1.8B"
-     max_length: int = 64

    # Data
-     corpus: List[str] = None

-     # Graph & Clustering
-     graph_mode: str = "threshold"
-     knn_k: int = 8
-     sim_threshold: float = 0.05  # Percentile of edges shown 0.05 = Show top 5% of edges
    use_cosine: bool = True

    # Anchors / LoT-style features (global)
@@ -66,149 +84,111 @@ class Config:
    n_clusters_kmeans: int = 6  # fallback for kmeans
    hdbscan_min_cluster_size: int = 4

-     # UMAP & alignment
    umap_n_neighbors: int = 30
    umap_min_dist: float = 0.05
-     umap_metric: str = "cosine"
    fit_pool_per_layer: int = 512  # number of states sampled per layer to fit UMAP
-     align_layers: bool = True  # aligning procrustes to layers

-     # Visualization
-     color_by: str = "pos"  # "cluster" or "pos" (Part of Speech)

    # Output
    out_dir: str = "qwen_mri3d_outputs"
    plotly_html: str = "qwen_layers_3d.html"

# Default corpus (small and diverse; adjust freely)
DEFAULT_CORPUS = [
-     "Is a Universal Basic Income (UBI) a viable solution to poverty, or does it simply discourage people from working?",
-     "Explain the arguments for and against the independence of Taiwan from the perspective of both the US and China.",
-     "What are the ethical arguments surrounding the use of CRISPR technology to edit human embryos for non-medical enhancements?",
-     "Analyze the effectiveness of strict lockdowns versus herd immunity strategies during the COVID-19 pandemic.",
-     "Why is nuclear energy controversial despite being a low-carbon power source? Present both the safety concerns and the environmental benefits.",
-     "Does the existence of evil in the world disprove the existence of a benevolent God? Summarize the philosophical debate.",
-     "Summarize the main arguments used by gun rights advocates against stricter background checks in the United States.",
-     "Should autonomous weapons systems (killer robots) be banned internationally, even if they could reduce soldier casualties?",
-     "Was the dropping of the atomic bombs on Hiroshima and Nagasaki militarily necessary to end World War II?",
-     "What are the competing arguments regarding transgender women participating in biological women's sports categories?"
]

- # Select from 4 different models
- MODELS = ["Qwen/Qwen1.5-0.5B", "deepseek-ai/deepseek-coder-1.3b-instruct", "openai-community/gpt2", "prem-research/MiniGuard-v0.1"]
-
-
- # ====== Utilities =========================================================================
def seed_everything(seed: int = 42):
    np.random.seed(seed)
    torch.manual_seed(seed)

def cosine_similarity_matrix(X: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(X, axis=1, keepdims=True) + 1e-8
    Xn = X / norms
    return Xn @ Xn.T

- def orthogonal_align(A_ref: np.ndarray, B: np.ndarray) -> np.ndarray:
-     """
-     Align B to A_ref using Procrustes analysis (rotation/reflection only).
-     Preserves local geometry of B, but aligns global orientation to A.
-     """
-     # Center both
-     mu_a = A_ref.mean(0)
-     mu_b = B.mean(0)
-     A0 = A_ref - mu_a
-     B0 = B - mu_b
-
-     # Solve for Rotation R that minimizes ||A0 - B0 @ R||
-     # M = B0.T @ A0
-     # U, S, Vt = svd(M)
-     # R = U @ Vt
-     R, _ = orthogonal_procrustes(B0, A0)

-     # B_aligned = (B - mu_b) @ R + mu_a
-     # We essentially rotate B to match A's orientation, then shift to A's center
-     return B0 @ R + mu_a
-
- def get_pos_tags(text: str, tokenizer, tokens: List[str]) -> List[str]:
    """
-     Map LLM tokens to Spacy POS tags.
-     Heuristic: Reconstruct text, run Spacy, align based on char overlap.
    """
-     try:
-         nlp = spacy.load("en_core_web_sm")
-     except:
-         # Fallback if model not downloaded
-         return ["UNK"] * len(tokens)
-
-     doc = nlp(text)
-
-     # This is a simplified mapping. Real alignment is complex due to subwords.
-     # We will approximate: Find which word the subword belongs to.
-     pos_tags = []
-
-     # Re-build offsets for tokens (simplified)
-     # Ideally, we use tokenizer(return_offsets_mapping=True)
-     # Here we will just iterate and approximate for the demo.
-
-     # Fast approximation: tag the token string itself
-     # (Not perfect for subwords like "ing", but visually useful)
-     for t_str in tokens:
-         clean_t = t_str.replace("Ġ", "").replace("▁", "").strip()
-         if not clean_t:
-             pos_tags.append("SYM")  # likely special char
-             continue
-
-         # Tag the single token fragment
-         sub_doc = nlp(clean_t)
-         if len(sub_doc) > 0:
-             pos_tags.append(sub_doc[0].pos_)
-         else:
-             pos_tags.append("UNK")
-
-     return pos_tags
-
- def build_knn_graph(coords: np.ndarray, k: int, metric: str = "cosine") -> nx.Graph:
-     nbrs = NearestNeighbors(n_neighbors=min(k+1, len(coords)), metric=metric)
    nbrs.fit(coords)
    distances, indices = nbrs.kneighbors(coords)
    G = nx.Graph()
    G.add_nodes_from(range(len(coords)))
    for i in range(len(coords)):
-         for j in indices[i, 1:]:
            G.add_edge(int(i), int(j))
    return G

- def build_threshold_graph(H: np.ndarray, top_pct: float = 0.05, use_cosine: bool = True, include_ties: bool = True,) -> nx.Graph:
    if use_cosine:
        S = cosine_similarity_matrix(H)
    else:
-         S = H @ H.T

    N = S.shape[0]
-     iu = np.triu_indices(N, k=1)
-     vals = S[iu]
-
-     # threshold at (1 - top_pct) quantile
-     q = 1.0 - top_pct
-     thr = float(np.quantile(vals, q))
    G = nx.Graph()
    G.add_nodes_from(range(N))
-
-     if include_ties:
-         mask = vals >= thr
-     else:
-         # strictly greater than threshold reduces tie-inflation
-         mask = vals > thr
-
-     rows = iu[0][mask]
-     cols = iu[1][mask]
-     wts = vals[mask]
-
-     for r, c, w in zip(rows, cols, wts):
-         G.add_edge(int(r), int(c), weight=float(w))
    return G

def percolation_stats(G: nx.Graph) -> Dict[str, float]:
    n = G.number_of_nodes()
    if n == 0:
        return dict(phi=0.0, num_clusters=0, chi=0.0, largest_component_size=0, component_sizes=[])
@@ -230,55 +210,109 @@ def percolation_stats(G: nx.Graph) -> Dict[str, float]:
                largest_component_size=largest,
                component_sizes=sorted(sizes, reverse=True))

- def cluster_layer(features: np.ndarray, G: Optional[nx.Graph], method: str,
-                   n_clusters_kmeans: int=6, hdbscan_min_cluster_size: int=4) -> np.ndarray:
-     # (Same as original)
-     method = method.lower()
-     N = len(features)
-     if method == "auto":
-         if HAS_IGRAPH_LEIDEN and G and G.number_of_edges() > 0: return leiden_communities(G)
-         elif HAS_HDBSCAN: return hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size).fit_predict(features)
-         else: return KMeans(n_clusters=min(n_clusters_kmeans, N), n_init="auto").fit_predict(features)
-     # ... (rest of method dispatch unchanged)
-     return KMeans(n_clusters=min(n_clusters_kmeans, N), n_init="auto").fit_predict(features)

- # Helper for Leiden (from original)
def leiden_communities(G: nx.Graph) -> np.ndarray:
-     if not HAS_IGRAPH_LEIDEN: raise RuntimeError("Missing igraph")
    mapping = {n: i for i, n in enumerate(G.nodes())}
    edges = [(mapping[u], mapping[v]) for u, v in G.edges()]
    ig_g = ig.Graph(n=len(mapping), edges=edges, directed=False)
-     part = la.find_partition(ig_g, la.RBConfigurationVertexPartition)
    labels = np.zeros(len(mapping), dtype=int)
    for cid, comm in enumerate(part):
-         for node in comm: labels[node] = cid
    return labels

- def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0):
-     dists = pairwise_distances(H, anchors, metric="euclidean")
-     logits = -dists / max(temperature, 1e-6)
-     logits = logits - logits.max(axis=1, keepdims=True)
-     P = np.exp(logits)
-     P /= P.sum(axis=1, keepdims=True) + 1e-12
-     # Entropy calculation
-     H_unc = -np.sum(P * np.log(P + 1e-12), axis=1)
-     return dists, P, H_unc

- def fit_global_anchors(pool: np.ndarray, K: int) -> np.ndarray:
-     km = KMeans(n_clusters=K, n_init="auto", random_state=42)
-     km.fit(pool)
-     return km.cluster_centers_

- # ====== Model I/O (hidden states) =============================================================
@dataclass
class HiddenStatesBundle:
    hidden_layers: List[np.ndarray]
    tokens: List[str]

def load_qwen(model_name: str, device: str, dtype: torch.dtype):
-
    print(f"[Load] {model_name} on {device} ({dtype})")
    config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
@@ -291,7 +325,10 @@ def load_qwen(model_name: str, device: str, dtype: torch.dtype):

@torch.no_grad()
def extract_hidden_states(model, tokenizer, text: str, max_length: int, device: str) -> HiddenStatesBundle:
-
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    out = model(**inputs)
    # Tuple length = num_layers + 1 (embedding)
@@ -299,17 +336,28 @@ def extract_hidden_states(model, tokenizer, text: str, max_length: int, device:
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    return HiddenStatesBundle(hidden_layers=hs, tokens=tokens)

-
- # ====== LoT-style anchors & features ==========================================================
def fit_global_anchors(all_states_sampled: np.ndarray, K: int, random_state: int = 42) -> np.ndarray:
-
    print(f"[Anchors] Fitting {K} global centroids on {len(all_states_sampled)} states ...")
    kmeans = KMeans(n_clusters=K, n_init="auto", random_state=random_state)
    kmeans.fit(all_states_sampled)
    return kmeans.cluster_centers_  # (K, D)

- def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    # Distances (N, K)
    dists = pairwise_distances(H, anchors, metric="euclidean")  # (N,K)
    # Soft assignments
@@ -319,315 +367,409 @@ def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0
    P = np.exp(logits)
    P /= P.sum(axis=1, keepdims=True) + 1e-12
    # Uncertainty (entropy)
-     H_unc = -np.sum(P * np.log(P + 1e-12), axis=1)
-
    return dists, P, H_unc

-
- # ====== Dimensionality reduction / embeddings ================================================
def fit_umap_2d(pool: np.ndarray,
                n_neighbors: int = 30,
                min_dist: float = 0.05,
                metric: str = "cosine",
                random_state: int = 42) -> umap.UMAP:
-
    reducer = umap.UMAP(n_components=2, n_neighbors=n_neighbors, min_dist=min_dist,
                        metric=metric, random_state=random_state)
    reducer.fit(pool)
    return reducer

def fit_umap_3d(all_states: np.ndarray,
                n_neighbors: int = 30,
                min_dist: float = 0.05,
                metric: str = "cosine",
                random_state: int = 42) -> np.ndarray:
-
    reducer = umap.UMAP(n_components=3, n_neighbors=n_neighbors, min_dist=min_dist,
                        metric=metric, random_state=random_state)
    return reducer.fit_transform(all_states)

- # ====== Visualization ========================================================================
def plotly_3d_layers(xy_layers: List[np.ndarray],
                     layer_tokens: List[List[str]],
                     layer_cluster_labels: List[np.ndarray],
-                      layer_pos_tags: List[List[str]],
                     layer_uncertainty: List[np.ndarray],
                     layer_graphs: List[nx.Graph],
-                      color_by: str = "cluster",
-                      title: str = "3D Cluster Formation",
-                      prompt: str = None,) -> go.Figure:
-
    fig_data = []

-     # Define categorical colormap for POS
-     pos_map = {
-         "NOUN": "#1f77b4", "VERB": "#d62728", "ADJ": "#2ca02c",
-         "ADV": "#ff7f0e", "PRON": "#9467bd", "DET": "#8c564b",
-         "ADP": "#e377c2", "NUM": "#7f7f7f", "PUNCT": "#bcbd22",
-         "SYM": "#17becf", "UNK": "#bababa"
-     }
-
-     L = len(xy_layers)
-     for l, (xy, tokens, labels, pos, unc, G) in enumerate(zip(xy_layers, layer_tokens, layer_cluster_labels, layer_pos_tags, layer_uncertainty, layer_graphs)):
-         if len(xy) == 0: continue
        x, y = xy[:, 0], xy[:, 1]
        z = np.full_like(x, l, dtype=float)

-         # Color Logic
-         if color_by == "pos":
-             # Map POS strings to colors
-             node_colors = [pos_map.get(p, "#333333") for p in pos]
-             show_scale = False
-             colorscale = None
-         else:
-             # Cluster ID
-             node_colors = labels
-             show_scale = (l == 0)
-             colorscale = 'Viridis'
-
-         # Hover Text
-         node_text = [
-             f"L{l} | {tok}<br>POS: {p}<br>Cluster: {c}<br>Unc: {u:.2f}"
-             for tok, p, c, u in zip(tokens, pos, labels, unc)
-         ]
-
        node_trace = go.Scatter3d(
            x=x, y=y, z=z,
            mode='markers',
            name=f"Layer {l}",
-             showlegend=False,
            marker=dict(
-                 size=3,
-                 opacity=1,
-                 color=node_colors,
-                 colorscale=colorscale,
-                 showscale=show_scale,
-                 colorbar=dict(title="Cluster ID") if show_scale else None
            ),
            text=node_text,
            hovertemplate="%{text}<extra></extra>"
        )
        fig_data.append(node_trace)

-         # Edges
        if G is not None and G.number_of_edges() > 0:
            edge_x, edge_y, edge_z = [], [], []
            for u, v in G.edges():
                edge_x += [x[u], x[v], None]
                edge_y += [y[u], y[v], None]
                edge_z += [z[u], z[v], None]
-
            edge_trace = go.Scatter3d(
                x=edge_x, y=edge_y, z=edge_z,
                mode='lines',
-                 line=dict(width=2, color='red'),
-                 opacity=0.6,
-                 hoverinfo='skip',
-                 showlegend=False
            )
            fig_data.append(edge_trace)

-     # Trajectories (connect same token across layers)
-     if L > 1:
-         T = len(xy_layers[0])
-         # Sample trajectories to avoid lag if T is huge
-         step = max(1, T // 100)
-         for i in range(0, T, step):
-             xs = [xy_layers[l][i, 0] for l in range(L)]
-             ys = [xy_layers[l][i, 1] for l in range(L)]
-             zs = list(range(L))
-             traj = go.Scatter3d(
-                 x=xs, y=ys, z=zs,
-                 mode='lines',
-                 line=dict(width=3, color='rgba(50,50,50,0.5)'),
-                 hoverinfo='skip',
-                 showlegend=False
-             )
-             fig_data.append(traj)
-     if color_by == "pos":
-         # Add legend-only traces for POS categories actually present
-         present_pos = sorted({p for layer in layer_pos_tags for p in layer})
-
-         for p in present_pos:
-             fig_data.append(
-                 go.Scatter3d(
-                     x=[None], y=[None], z=[None],  # legend-only
-                     mode="markers",
-                     name=p,
-                     marker=dict(size=8, color=pos_map.get(p, "#333333")),
-                     showlegend=True,
-                     hoverinfo="skip"
                )
-             )

    fig = go.Figure(data=fig_data)
    fig.update_layout(
-         title=dict(
-             text=title,
-             x=0.5,
-             xanchor="center",
-         ),
-         annotations=[
-             dict(
-                 text=f"<b>Prompt:</b> {prompt}",
-                 x=0.5,
-                 y=1.02,
-                 xref="paper",
-                 yref="paper",
-                 showarrow=False,
-                 font=dict(size=13),
-                 align="center"
-             )
-         ] if prompt else [],
        scene=dict(
            xaxis_title="UMAP X",
            yaxis_title="UMAP Y",
-             zaxis_title="Layer Depth",
-             aspectratio=dict(x=1, y=1, z=1.5)
        ),
        height=900,
-         margin=dict(l=0, r=0, b=0, t=40)
    )
    return fig

-
- def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts: bool = False):
    seed_everything(42)

-     # 1. Extract Hidden States
-     from transformers import logging
-     logging.set_verbosity_error()

-     # Extract
    main_bundle = extract_hidden_states(model, tok, main_text, cfg.max_length, device)
-     layers_np = main_bundle.hidden_layers
-     tokens = main_bundle.tokens
    L_all = len(layers_np)

-     # 2. Get POS Tags
-     pos_tags = get_pos_tags(main_text, tok, tokens)

-     # 3. Pooling & Anchors (LoT)
-     # (Simplified: just pool from the main text for speed in demo)
-     pool_states = np.vstack([layers_np[l] for l in range(0, L_all, 2)])
-     idx = np.random.choice(len(pool_states), min(len(pool_states), 2000), replace=False)
-     anchors = fit_global_anchors(pool_states[idx], cfg.anchor_k)

-     # 4. Process Layers
-     layer_features = []
-     layer_uncertainties = []
    layer_graphs = []
-     layer_cluster_labels = []
-     percolation = []
-
    for l in range(L_all):
-         H = layers_np[l]
-
-         # Features & Uncertainty
-         dists, P, H_unc = anchor_features(H, anchors, cfg.anchor_temp)
-         layer_features.append(dists)
-         layer_uncertainties.append(H_unc)
-
-         # Graphs
        if cfg.graph_mode == "knn":
-             G = build_knn_graph(dists, cfg.knn_k, metric="euclidean")
        else:
-             G = build_threshold_graph(H, cfg.sim_threshold, use_cosine=cfg.use_cosine)
        layer_graphs.append(G)

-         # Clusters
-         labels = cluster_layer(dists, G, cfg.cluster_method,
-                                cfg.n_clusters_kmeans, cfg.hdbscan_min_cluster_size)
        layer_cluster_labels.append(labels)

-         # Percolation
-         percolation.append(percolation_stats(G))
-
-     # 5. UMAP & Alignment
-     # Fit UMAP on the pool to establish a coordinate system
-     reducer = umap.UMAP(n_components=2, n_neighbors=cfg.umap_n_neighbors,
-                         min_dist=cfg.umap_min_dist, metric=cfg.umap_metric, random_state=42)
-     reducer.fit(pool_states[idx])
-
-     xy_by_layer = []
    for l in range(L_all):
-         # Transform into 2D
-         xy = reducer.transform(layers_np[l])

-         # Procrustes Alignment: Align layer L to L-1
-         if cfg.align_layers and l > 0:
-             xy = orthogonal_align(xy_by_layer[l-1], xy)

-         xy_by_layer.append(xy)

-     # 6. Plot
    fig = plotly_3d_layers(
        xy_layers=xy_by_layer,
-         layer_tokens=[tokens] * L_all,
        layer_cluster_labels=layer_cluster_labels,
-         layer_pos_tags=[pos_tags] * L_all,
        layer_uncertainty=layer_uncertainties,
        layer_graphs=layer_graphs,
-         color_by=cfg.color_by,
-         title=f"{cfg.model_name.rsplit("/", 1)[-1]} 3D MRI | Color: {cfg.color_by.upper()} | Aligned: {cfg.align_layers}",
-         prompt=main_text
    )

-     # 7. Save Artifacts (This is the missing part)
    if save_artifacts:
-         import os
-         # Create the directory if it doesn't exist
-         os.makedirs(cfg.out_dir, exist_ok=True)
-
-         # Construct the full path
-         out_path = os.path.join(cfg.out_dir, cfg.plotly_html)
-
-         # Write the HTML file
-         fig.write_html(out_path)
-         print(f"Successfully saved 3D plot to: {out_path}")

    return fig, {"percolation": percolation, "tokens": tokens}

-
@st.cache_resource(show_spinner=False)
def get_model_and_tok(model_name: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
-     config = AutoConfig.from_pretrained(model_name, output_hidden_states=True, trust_remote_code=True)
-     tok = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
-     if tok.pad_token_id is None:
-         tok.pad_token = tok.eos_token
-
-     model = AutoModelForCausalLM.from_pretrained(
-         model_name,
-         trust_remote_code=True,
-         config=config,
-         torch_dtype=dtype if device == "cuda" else None,
-         device_map="auto" if device == "cuda" else None
-     )
-     model.eval()
-
-     if device != "cuda":
-         model = model.to(device)
-
    return model, tok, device, dtype

def main():
-     st.set_page_config(page_title="LLM Hidden Layer Explorer", layout="wide")
-     st.title("Token Embedding Explorer (Live Hidden States)")

    with st.sidebar:
        st.header("Model / Input")
-         model_name = st.selectbox("Model", MODELS, index=1)
        max_length = st.slider("Max tokens", 16, 256, 64, step=16)

        st.header("Graph")
        graph_mode = st.selectbox("Graph mode", ["knn", "threshold"], index=0)
        knn_k = st.slider("k (kNN)", 2, 50, 8) if graph_mode == "knn" else 8
-         sim_threshold = st.slider("Similarity threshold", 0.0, 0.99, 0.05, step=0.01) if graph_mode == "threshold" else 0.70
        use_cosine = st.checkbox("Use cosine similarity", value=True)

        st.header("Anchors / LoT")
@@ -645,16 +787,13 @@ def main():
        st.header("Outputs")
        save_artifacts = st.checkbox("Save artifacts to disk (HTML/CSV/NPZ)", value=False)

-     prompt_col, run_col = st.columns([4, 1])
-
    with prompt_col:
-         main_text = st.selectbox(
-             "Prompt to visualize (hidden states computed on this text)",
-             options=DEFAULT_CORPUS,
-             index=0,
-             help="Select a predefined prompt for analysis"
        )
-
    with run_col:
        st.write("")
        st.write("")
import numpy as np
import pandas as pd
+
import torch
from torch import nn
+
import networkx as nx
import streamlit as st

+ # Transformers: Qwen tokenizer can be AutoTokenizer if Qwen2Tokenizer not present
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+ # Dimensionality reduction
import umap
+ from umap import UMAP
+
+ # Neighbors & clustering
from sklearn.neighbors import NearestNeighbors, KernelDensity
from sklearn.cluster import KMeans, DBSCAN
+ from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances
+
+ # Plotly for interactive 3D
import plotly.graph_objects as go

+ import hashlib

# Optional libs (use if present)
try:
    HAS_PYVISTA = True
except Exception:
    HAS_PYVISTA = False
+
+ from scipy.linalg import orthogonal_procrustes  # For optional per-layer orientation alignment
+
+ # ====== 1. Configuration =========================================================================
@dataclass
class Config:
    # Model
    model_name: str = "Qwen/Qwen1.5-1.8B"
+     ### device: str = "cuda" if torch.cuda.is_available() else "cpu"
+     ### dtype: torch.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+     # Tokenization / generation
+     max_length: int = 64  # truncate inputs for speed/memory

    # Data
+     corpus: List[str] = None  # set below
+     # If None, uses DEFAULT_CORPUS defined below

+     # Graph building
+     graph_mode: str = "threshold"  # {"knn", "threshold"}
+     knn_k: int = 8  # neighbors per token (used if graph_mode="knn")
+     sim_threshold: float = 0.60  # used if graph_mode="threshold"
    use_cosine: bool = True

    # Anchors / LoT-style features (global)

    n_clusters_kmeans: int = 6  # fallback for kmeans
    hdbscan_min_cluster_size: int = 4

+     # DR / embeddings
    umap_n_neighbors: int = 30
    umap_min_dist: float = 0.05
+     umap_metric: str = "cosine"  # hidden states are directional → cosine works well
+     use_global_3d_umap: bool = False  # if True, compute a single 3D manifold for all states
+
+     # Pooling for UMAP fit
    fit_pool_per_layer: int = 512  # number of states sampled per layer to fit UMAP

+     # Volume grid (MRI view)
+     grid_res: int = 128  # voxel resolution in x/y; z = num_layers
+     kde_bandwidth: float = 0.15  # KDE bandwidth in manifold space (if using KDE)
+     use_hist2d: bool = True  # if True, use histogram2d instead of KDE for speed

    # Output
    out_dir: str = "qwen_mri3d_outputs"
    plotly_html: str = "qwen_layers_3d.html"
+     volume_npz: str = "qwen_density_volume.npz"  # saved if PyVista isn't available
+     volume_screenshot: str = "qwen_volume.png"  # if PyVista is available
+
+     def validate(self):
+         if self.graph_mode not in {"knn", "threshold"}:
+             raise ValueError("graph_mode must be 'knn' or 'threshold'")
+         if self.knn_k < 2:
+             raise ValueError("knn_k must be >= 2")
+         if self.anchor_k < 2:
+             raise ValueError("anchor_k must be >= 2")
+         if self.anchor_temp <= 0:
+             raise ValueError("anchor_temp must be > 0")

# Default corpus (small and diverse; adjust freely)
DEFAULT_CORPUS = [
+     "The cat sat on the mat and watched.",
+     "Machine learning models process data using neural networks.",
+     "Climate change affects ecosystems around the world.",
+     "Quantum computers use superposition for parallel computation.",
+     "The universe contains billions of galaxies.",
+     "Artificial intelligence transforms how we work.",
+     "DNA stores genetic information in cells.",
+     "Ocean currents regulate Earth's climate system.",
+     "Photosynthesis converts sunlight into chemical energy.",
+     "Blockchain technology enables decentralized systems."
]
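The new `Config.validate()` guard only takes effect if it is actually invoked; a minimal usage sketch (illustrative, not part of this commit) assuming it is called right after construction:

    cfg = Config()
    cfg.corpus = DEFAULT_CORPUS      # optional; None falls back to DEFAULT_CORPUS
    cfg.graph_mode = "knn"           # or "threshold"
    cfg.validate()                   # raises ValueError on inconsistent settings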
+ # ====== 2. Utilities =============================================================================
def seed_everything(seed: int = 42):
+     """Determinism for reproducibility in layouts/UMAP/kmeans."""
    np.random.seed(seed)
    torch.manual_seed(seed)

+
def cosine_similarity_matrix(X: np.ndarray) -> np.ndarray:
+     """Compute pairwise cosine similarity for rows of X."""
+     # X: (N, D)
    norms = np.linalg.norm(X, axis=1, keepdims=True) + 1e-8
    Xn = X / norms
    return Xn @ Xn.T

+ def build_knn_graph(coords: np.ndarray, k: int, metric: str = "cosine") -> nx.Graph:
    """
+     Build an undirected kNN graph for the points in coords.
+     coords: (N, D)
    """
+     nbrs = NearestNeighbors(n_neighbors=min(k+1, len(coords)), metric=metric)  # +1 to include self
    nbrs.fit(coords)
    distances, indices = nbrs.kneighbors(coords)
+
    G = nx.Graph()
    G.add_nodes_from(range(len(coords)))
+     # Connect i to its top-k neighbors (skip index 0 which is itself)
    for i in range(len(coords)):
+         for j in indices[i, 1:]:  # skip self
            G.add_edge(int(i), int(j))
    return G
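For reference, a small self-contained sketch of how `build_knn_graph` behaves on toy data (the 5-point array and k=2 are made up for illustration; the function itself is the one defined above):

    import numpy as np

    pts = np.random.RandomState(0).randn(5, 3)        # 5 toy "hidden states" in 3-D
    G = build_knn_graph(pts, k=2, metric="euclidean")
    print(G.number_of_nodes(), G.number_of_edges())   # 5 nodes; up to 5*2 edges before deduplication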
 
+
+ def build_threshold_graph(H: np.ndarray, threshold: float, use_cosine: bool = True) -> nx.Graph:
+     """
+     Build graph by thresholding pairwise similarities in the original hidden-state space.
+     H: (N, D) hidden states for a single layer
+     """
    if use_cosine:
        S = cosine_similarity_matrix(H)
    else:
+         S = H @ H.T  # dot product

    N = S.shape[0]
    G = nx.Graph()
    G.add_nodes_from(range(N))
+     for i in range(N):
+         for j in range(i + 1, N):
+             if S[i, j] > threshold:
+                 G.add_edge(i, j, weight=float(S[i, j]))
    return G
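Note the semantics change here: the removed version interpreted its parameter as a top-percentile of edges (0.05 = keep the top 5% of pairs), while this version treats `threshold` as an absolute similarity cutoff (default 0.60 in `Config`). If the percentile behaviour is ever wanted again, an absolute cutoff can be derived first; a sketch reusing the old quantile logic:

    import numpy as np

    def threshold_from_top_pct(H: np.ndarray, top_pct: float = 0.05) -> float:
        # Reproduces the removed percentile logic: keep roughly the top `top_pct` of pairs.
        S = cosine_similarity_matrix(H)
        vals = S[np.triu_indices(S.shape[0], k=1)]
        return float(np.quantile(vals, 1.0 - top_pct))

    # G = build_threshold_graph(H, threshold_from_top_pct(H, 0.05), use_cosine=True)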
 
+
def percolation_stats(G: nx.Graph) -> Dict[str, float]:
+     """
+     Compute percolation observables (φ, #clusters, χ) as in your notebook.
+     φ : fraction of nodes in the Giant Connected Component (GCC)
+     χ : mean size of components excluding GCC
+     """
    n = G.number_of_nodes()
    if n == 0:
        return dict(phi=0.0, num_clusters=0, chi=0.0, largest_component_size=0, component_sizes=[])

                largest_component_size=largest,
                component_sizes=sorted(sizes, reverse=True))
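The middle of `percolation_stats` falls outside the hunks shown above; purely as an illustration, a minimal reimplementation consistent with the docstring and the visible return statement (the actual body in the file may differ in details):

    import numpy as np
    import networkx as nx

    def percolation_stats_sketch(G: nx.Graph) -> dict:
        n = G.number_of_nodes()
        sizes = sorted((len(c) for c in nx.connected_components(G)), reverse=True)
        largest = sizes[0] if sizes else 0
        rest = sizes[1:]                                   # components excluding the GCC
        return dict(phi=largest / max(n, 1),               # φ: GCC fraction
                    num_clusters=len(sizes),
                    chi=float(np.mean(rest)) if rest else 0.0,   # χ: mean non-GCC size
                    largest_component_size=largest,
                    component_sizes=sizes)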
 
 
 
 
 
 
 
 
 
 
 
 

def leiden_communities(G: nx.Graph) -> np.ndarray:
+     """
+     Community detection using Leiden (igraph), if available.
+     Returns an array of cluster ids for nodes 0..N-1.
+     """
+     if not HAS_IGRAPH_LEIDEN:
+         raise RuntimeError("igraph+leidenalg not available")
+
+     # Convert nx → igraph
    mapping = {n: i for i, n in enumerate(G.nodes())}
    edges = [(mapping[u], mapping[v]) for u, v in G.edges()]
    ig_g = ig.Graph(n=len(mapping), edges=edges, directed=False)
+     part = la.find_partition(ig_g, la.RBConfigurationVertexPartition)  # robust default
    labels = np.zeros(len(mapping), dtype=int)
    for cid, comm in enumerate(part):
+         for node in comm:
+             labels[node] = cid
    return labels

+ def cluster_layer(features: np.ndarray,
+                   G: Optional[nx.Graph],
+                   method: str,
+                   n_clusters_kmeans: int = 6,
+                   hdbscan_min_cluster_size: int = 4) -> np.ndarray:
+     """
+     Cluster layer states to get cluster labels.
+     - If Leiden: requires G (graph) and igraph/leidenalg
+     - If HDBSCAN: density-based clustering in feature space
+     - If DBSCAN: fallback density-based (scikit-learn)
+     - If KMeans: fallback centroid clustering
+     """
+     method = method.lower()
+     N = len(features)
+
+     if method == "auto":
+         # Prefer Leiden (graph) → HDBSCAN → KMeans
+         if HAS_IGRAPH_LEIDEN and G is not None and G.number_of_edges() > 0:
+             return leiden_communities(G)
+         elif HAS_HDBSCAN and N >= 5:
+             clusterer = hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size,
+                                         metric='euclidean')
+             labels = clusterer.fit_predict(features)
+             # HDBSCAN: -1 = noise. Keep as its own "noise" cluster id or remap
+             return labels
+         else:
+             km = KMeans(n_clusters=min(n_clusters_kmeans, max(2, N // 3)),
+                         n_init="auto", random_state=42)
+             return km.fit_predict(features)
+
+     if method == "leiden":
+         if G is None or not HAS_IGRAPH_LEIDEN:
+             raise RuntimeError("Leiden requires a graph and igraph+leidenalg.")
+         return leiden_communities(G)
+
+     if method == "hdbscan":
+         if not HAS_HDBSCAN:
+             raise RuntimeError("hdbscan not installed")
+         clusterer = hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size, metric='euclidean')
+         return clusterer.fit_predict(features)
+
+     if method == "dbscan":
+         db = DBSCAN(eps=0.5, min_samples=4, metric='euclidean')
+         return db.fit_predict(features)

+     if method == "kmeans":
+         km = KMeans(n_clusters=min(n_clusters_kmeans, max(2, N // 3)),
+                     n_init="auto", random_state=42)
+         return km.fit_predict(features)

+     raise ValueError(f"Unknown cluster method: {method}")
+
+
+ def orthogonal_align(A_ref: np.ndarray, B: np.ndarray) -> np.ndarray:
+     """
+     Align B to A_ref by an orthogonal rotation (Procrustes),
+     preserving geometry but removing arbitrary orientation flips.
+     """
+     R, _ = orthogonal_procrustes(B - B.mean(0), A_ref - A_ref.mean(0))
+     return (B - B.mean(0)) @ R + A_ref.mean(0)
+
+
+ def entropy_from_probs(p: np.ndarray, eps: float = 1e-12) -> np.ndarray:
+     """Shannon entropy for each row; p is (N, K) with rows summing ~1."""
+     return -np.sum(p * np.log(p + eps), axis=1)
+
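A quick toy check of `orthogonal_align` (illustrative only): a rotated and shifted copy of a point cloud should be mapped back onto the reference up to numerical noise.

    import numpy as np

    rng = np.random.RandomState(0)
    A = rng.randn(50, 2)                               # reference layout
    theta = np.deg2rad(30.0)
    Rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta),  np.cos(theta)]])
    B = A @ Rot + 5.0                                  # rotated + shifted copy
    print(np.allclose(orthogonal_align(A, B), A, atol=1e-6))   # True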
+ # ====== 3. Model I/O (hidden states) =============================================================
@dataclass
class HiddenStatesBundle:
+     """
+     Encapsulates a single input's hidden states and metadata.
+     hidden_layers: list of np.ndarray of shape (T, D), length = num_layers+1 (incl. embedding)
+     tokens : list of token strings of length T
+     """
    hidden_layers: List[np.ndarray]
    tokens: List[str]

def load_qwen(model_name: str, device: str, dtype: torch.dtype):
+     """
+     Load Qwen with output_hidden_states=True. We use AutoTokenizer for broader compatibility.
+     """
    print(f"[Load] {model_name} on {device} ({dtype})")
    config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)

@torch.no_grad()
def extract_hidden_states(model, tokenizer, text: str, max_length: int, device: str) -> HiddenStatesBundle:
+     """
+     Run a single forward pass to collect all hidden states (incl. embedding layer).
+     Returns CPU numpy arrays to keep GPU memory low.
+     """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    out = model(**inputs)
    # Tuple length = num_layers + 1 (embedding)

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    return HiddenStatesBundle(hidden_layers=hs, tokens=tokens)
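The line that actually builds `hs` sits outside the hunk shown above. With `output_hidden_states=True`, `out.hidden_states` is a tuple of `(1, T, D)` tensors, so a conversion along these lines (a sketch, not necessarily the exact line in the file) yields the per-layer `(T, D)` NumPy arrays the bundle stores:

    # hs: list of (T, D) float32 arrays, one per layer incl. the embedding layer
    hs = [h.squeeze(0).float().cpu().numpy() for h in out.hidden_states]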
 
+ # ====== 4. LoT-style anchors & features ==========================================================
def fit_global_anchors(all_states_sampled: np.ndarray, K: int, random_state: int = 42) -> np.ndarray:
+     """
+     Fit KMeans cluster centroids on a pooled set of states (from many layers/texts).
+     These centroids are "anchors" (LoT-like choices) to build low-dim features:
+         f(state) = [dist(state, anchor_j)]_{j=1..K}
+     """
    print(f"[Anchors] Fitting {K} global centroids on {len(all_states_sampled)} states ...")
    kmeans = KMeans(n_clusters=K, n_init="auto", random_state=random_state)
    kmeans.fit(all_states_sampled)
    return kmeans.cluster_centers_  # (K, D)

+ def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+     """
+     For states H (N,D) and anchors A (K,D):
+       - Compute Euclidean distances to each anchor → Dists (N,K)
+       - Convert to soft probabilities with exp(-Dist/T), normalize row-wise → P (N,K)
+       - Uncertainty = entropy(P) (cf. LoT Eq. (6))
+       - Top-anchor argmin distance for "consistency"-style comparisons (cf. Eq. (5))
+     Returns (Dists, P, entropy)
+     """
    # Distances (N, K)
    dists = pairwise_distances(H, anchors, metric="euclidean")  # (N,K)
    # Soft assignments

    P = np.exp(logits)
    P /= P.sum(axis=1, keepdims=True) + 1e-12
    # Uncertainty (entropy)
+     H_unc = entropy_from_probs(P)
    return dists, P, H_unc
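A toy run of `anchor_features` (made-up anchor and state values) shows the shapes and the entropy behaviour described in the docstring: a state sitting on an anchor gets near-zero entropy, a state roughly between anchors gets a larger value.

    import numpy as np

    anchors = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])   # K=3 toy anchors
    H = np.array([[0.0, 0.0],       # on anchor 0 → low entropy
                  [3.3, 3.3]])      # between anchors → higher entropy
    dists, P, unc = anchor_features(H, anchors, temperature=1.0)
    print(dists.shape, P.shape, unc.shape)   # (2, 3) (2, 3) (2,)
    print(unc[0] < unc[1])                   # True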
 
+ # ====== 5. Dimensionality reduction / embeddings ================================================
def fit_umap_2d(pool: np.ndarray,
                n_neighbors: int = 30,
                min_dist: float = 0.05,
                metric: str = "cosine",
                random_state: int = 42) -> umap.UMAP:
+     """
+     Fit UMAP once on a diverse pool across layers to preserve orientation.
+     Later layers call .transform() to embed into the SAME 2D space → "MRI stack".
+     """
    reducer = umap.UMAP(n_components=2, n_neighbors=n_neighbors, min_dist=min_dist,
                        metric=metric, random_state=random_state)
    reducer.fit(pool)
    return reducer

+ def _corpus_fingerprint(texts, max_items=5, max_chars=4000) -> str:
+     """Stable key so cache invalidates if DEFAULT_CORPUS changes."""
+     joined = "\n".join(texts[:max_items])
+     joined = joined[:max_chars]
+     return hashlib.sha256(joined.encode("utf-8")).hexdigest()
+
+ @st.cache_data(show_spinner=False)
+ def get_pool_artifacts(
+     model_name: str,
+     max_length: int,
+     anchor_k: int,
+     anchor_temp: float,  # not strictly needed for fitting anchors, but included if you want cache keys aligned
+     umap_n_neighbors: int,
+     umap_min_dist: float,
+     umap_metric: str,
+     fit_pool_per_layer: int,
+     corpus_hash: str,
+ ):
+     """
+     Cached: build pooled hidden states on DEFAULT_CORPUS, fit anchors and a UMAP reducer once.
+     Returns:
+         anchors: (K, D) np.ndarray
+         reducer2d: fitted UMAP reducer object (must be pickleable; umap-learn's UMAP is)
+     """
+     # Use cached model loader (resource cache)
+     model, tok, device, dtype = get_model_and_tok(model_name)
+
+     texts = DEFAULT_CORPUS  # pooled set for stability
+
+     pool_states = []
+     for t in texts[: min(5, len(texts))]:
+         b = extract_hidden_states(model, tok, t, max_length, device)
+         for H in b.hidden_layers:
+             T = len(H)
+             take = min(fit_pool_per_layer, T)
+             if take <= 0:
+                 continue
+             idx = np.random.choice(T, size=take, replace=False)
+             pool_states.append(H[idx])
+
+     if not pool_states:
+         # fallback: this should rarely happen
+         raise RuntimeError("Pool construction produced no states.")
+
+     pool_states = np.vstack(pool_states)
+
+     anchors = fit_global_anchors(pool_states, anchor_k)
+
+     reducer2d = fit_umap_2d(
+         pool_states,
+         n_neighbors=umap_n_neighbors,
+         min_dist=umap_min_dist,
+         metric=umap_metric,
+     )
+
+     return anchors, reducer2d
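Because `get_pool_artifacts` is wrapped in `st.cache_data`, its result is recomputed only when one of its arguments changes; `corpus_hash` exists purely to fold the corpus contents into that cache key. A sketch of the intended call pattern (mirroring what `run_pipeline` does below):

    cfg = Config()
    anchors, reducer2d = get_pool_artifacts(
        model_name=cfg.model_name,
        max_length=cfg.max_length,
        anchor_k=cfg.anchor_k,
        anchor_temp=cfg.anchor_temp,
        umap_n_neighbors=cfg.umap_n_neighbors,
        umap_min_dist=cfg.umap_min_dist,
        umap_metric=cfg.umap_metric,
        fit_pool_per_layer=cfg.fit_pool_per_layer,
        corpus_hash=_corpus_fingerprint(DEFAULT_CORPUS),
    )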
+
def fit_umap_3d(all_states: np.ndarray,
                n_neighbors: int = 30,
                min_dist: float = 0.05,
                metric: str = "cosine",
                random_state: int = 42) -> np.ndarray:
+     """
+     Fit a global 3D UMAP embedding for all states at once (alternative to slice stack).
+     Returns coords_3d (N,3) for the concatenated states passed in.
+     """
    reducer = umap.UMAP(n_components=3, n_neighbors=n_neighbors, min_dist=min_dist,
                        metric=metric, random_state=random_state)
    return reducer.fit_transform(all_states)

+ # ====== 6. Volume construction (MRI) ============================================================
+ def stack_density_volume(xy_by_layer: List[np.ndarray],
+                          grid_res: int,
+                          use_hist2d: bool = True,
+                          kde_bandwidth: float = 0.15) -> np.ndarray:
+     """
+     Construct a 3D volume by estimating 2D density on the (x,y) manifold per layer (slice).
+     - If use_hist2d: fast uniform binning into grid_res x grid_res
+     - Else: KDE (slower but smoother)
+     Returns volume of shape (grid_res, grid_res, L) where L = #layers.
+     """
+     L = len(xy_by_layer)
+     vol = np.zeros((grid_res, grid_res, L), dtype=np.float32)
+
+     # Determine global bounds across layers to keep axes consistent
+     all_xy = np.vstack([xy for xy in xy_by_layer if len(xy) > 0]) if L > 0 else np.zeros((0, 2))
+     if len(all_xy) == 0:
+         return vol
+     x_min, y_min = all_xy.min(axis=0)
+     x_max, y_max = all_xy.max(axis=0)
+     # Slight padding
+     pad = 1e-6
+     x_edges = np.linspace(x_min - pad, x_max + pad, grid_res + 1)
+     y_edges = np.linspace(y_min - pad, y_max + pad, grid_res + 1)
+
+     for l, XY in enumerate(xy_by_layer):
+         if len(XY) == 0:
+             continue

+         if use_hist2d:
+             H, _, _ = np.histogram2d(XY[:, 0], XY[:, 1], bins=[x_edges, y_edges], density=False)
+             vol[:, :, l] = H.T  # histogram2d returns [x_bins, y_bins] → transpose to align
+         else:
+             kde = KernelDensity(bandwidth=kde_bandwidth, kernel="gaussian")
+             kde.fit(XY)
+             # Evaluate KDE on grid centers
+             xs = 0.5 * (x_edges[:-1] + x_edges[1:])
+             ys = 0.5 * (y_edges[:-1] + y_edges[1:])
+             xx, yy = np.meshgrid(xs, ys, indexing='xy')
+             grid_points = np.column_stack([xx.ravel(), yy.ravel()])
+             log_dens = kde.score_samples(grid_points)
+             dens = np.exp(log_dens).reshape(grid_res, grid_res)
+             vol[:, :, l] = dens
+
+     # Normalize volume to [0,1] for rendering convenience
+     if vol.max() > 0:
+         vol = vol / vol.max()
+     return vol
+
+
+ def render_volume_with_pyvista(volume: np.ndarray,
+                                out_png: str,
+                                opacity="sigmoid") -> None:
+     """
+     Visualize the 3D volume using PyVista/VTK (if installed); save a screenshot.
+     """
+     if not HAS_PYVISTA:
+         raise RuntimeError("PyVista is not installed; cannot render volume.")
+     pl = pv.Plotter()
+     # Wrap NumPy array as a VTK image data; PyVista expects z as the 3rd axis
+     vol_vtk = pv.wrap(volume)
+     pl.add_volume(vol_vtk, opacity=opacity, shade=True)
+     pl.show(screenshot=out_png)  # headless environments will still save a screenshot (if offscreen support)
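A toy call to `stack_density_volume` (made-up coordinates and a small grid) showing the expected shape and normalization:

    import numpy as np

    rng = np.random.RandomState(0)
    xy_by_layer = [rng.randn(40, 2) for _ in range(3)]          # 3 layers, 40 tokens each
    vol = stack_density_volume(xy_by_layer, grid_res=32, use_hist2d=True)
    print(vol.shape)          # (32, 32, 3)
    print(float(vol.max()))   # 1.0 after normalization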
+
+ # ====== 7. 3D Plotly visualization ==============================================================
def plotly_3d_layers(xy_layers: List[np.ndarray],
                     layer_tokens: List[List[str]],
                     layer_cluster_labels: List[np.ndarray],
                     layer_uncertainty: List[np.ndarray],
                     layer_graphs: List[nx.Graph],
+                      connect_token_trajectories: bool = True,
+                      title: str = "Qwen: 3D Cluster Formation (UMAP2D + Layer as Z)") -> go.Figure:
+     """
+     Build an interactive 3D Plotly figure:
+       - Nodes per layer at (x, y, z=layer)
+       - Edge segments (kNN or threshold graph) per layer
+       - Trajectory lines: connect same token index across consecutive layers (optional)
+       - Color nodes by cluster label; hover shows token & uncertainty
+     """
    fig_data = []

+     # Build a color per layer node trace
+     for l, (xy, tokens, labels, unc, G) in enumerate(zip(xy_layers, layer_tokens, layer_cluster_labels, layer_uncertainty, layer_graphs)):
+         if len(xy) == 0:
+             continue
        x, y = xy[:, 0], xy[:, 1]
        z = np.full_like(x, l, dtype=float)

+         # --- Nodes
+         node_text = [f"layer={l} | idx={i}<br>token={tokens[i]}<br>cluster={int(labels[i])}<br>uncertainty={unc[i]:.3f}"
+                      for i in range(len(tokens))]
        node_trace = go.Scatter3d(
            x=x, y=y, z=z,
            mode='markers',
            name=f"Layer {l}",
            marker=dict(
+                 size=4,
+                 opacity=0.7,
+                 color=labels,  # cluster ID → color scale
+                 colorscale='Viridis',
+                 showscale=(l == 0)  # show scale once
            ),
            text=node_text,
            hovertemplate="%{text}<extra></extra>"
        )
        fig_data.append(node_trace)

+         # --- Intra-layer edges (kNN or threshold)
        if G is not None and G.number_of_edges() > 0:
            edge_x, edge_y, edge_z = [], [], []
            for u, v in G.edges():
                edge_x += [x[u], x[v], None]
                edge_y += [y[u], y[v], None]
                edge_z += [z[u], z[v], None]
            edge_trace = go.Scatter3d(
                x=edge_x, y=edge_y, z=edge_z,
                mode='lines',
+                 line=dict(width=1),
+                 opacity=0.30,
+                 name=f"Edges L{l}"
            )
            fig_data.append(edge_trace)

+     # --- Trajectories: connect same token index across layers
+     if connect_token_trajectories:
+         # Only meaningful if tokenization length T is constant across layers (it is)
+         # We'll draw faint polylines for each position i across l=0..L-1
+         L = len(xy_layers)
+         if L > 1:
+             T = min(len(xy_layers[l]) for l in range(L))
+             for i in range(T):
+                 xs = [xy_layers[l][i, 0] for l in range(L)]
+                 ys = [xy_layers[l][i, 1] for l in range(L)]
+                 zs = list(range(L))
+                 traj = go.Scatter3d(
+                     x=xs, y=ys, z=zs,
+                     mode='lines',
+                     line=dict(width=1),
+                     opacity=0.15,
+                     name=f"traj_{i}",
+                     hoverinfo='skip'
                )
+                 fig_data.append(traj)

    fig = go.Figure(data=fig_data)
    fig.update_layout(
+         title=title,
        scene=dict(
            xaxis_title="UMAP X",
            yaxis_title="UMAP Y",
+             zaxis_title="Layer (depth)"
        ),
        height=900,
+         showlegend=False
    )
    return fig

+ # ====== 8. Orchestration ========================================================================
+ def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts: bool = False):
    seed_everything(42)

+     # 8.2 Collect hidden states for one representative text (detailed viz) + for pool
+     # You can extend to many texts; we keep a single text for clarity & speed.
+     texts = cfg.corpus or DEFAULT_CORPUS
+     #print(f"[Input] Example text: {main_text!r}")

+     # Hidden states for main text
    main_bundle = extract_hidden_states(model, tok, main_text, cfg.max_length, device)
+     layers_np: List[np.ndarray] = main_bundle.hidden_layers  # list of (T,D), length L_all = num_layers+1
+     tokens = main_bundle.tokens  # list of length T
+
+     # Cached pool artifacts (anchors + fitted UMAP reducer)
+     corpus_hash = _corpus_fingerprint(texts)  # texts is cfg.corpus or DEFAULT_CORPUS
+
+     anchors, reducer2d = get_pool_artifacts(
+         model_name=cfg.model_name,
+         max_length=cfg.max_length,
+         anchor_k=cfg.anchor_k,
+         anchor_temp=cfg.anchor_temp,
+         umap_n_neighbors=cfg.umap_n_neighbors,
+         umap_min_dist=cfg.umap_min_dist,
+         umap_metric=cfg.umap_metric,
+         fit_pool_per_layer=cfg.fit_pool_per_layer,
+         corpus_hash=corpus_hash,
+     )
+
    L_all = len(layers_np)
+     #print(f"[Hidden] Layers (incl. embedding): {L_all}, Tokens: {len(tokens)}")

+     """
+     # 8.3 Build a pool of states (across a few texts & layers) to fit anchors + UMAP
+     pool_states = []
+     # Sample across first few texts to improve diversity (lightweight)
+     for t in texts[: min(5, len(texts))]:
+         b = extract_hidden_states(model, tok, t, cfg.max_length, device)
+         # Take a subset from each layer to limit pool size
+         for H in b.hidden_layers:
+             T = len(H)
+             take = min(cfg.fit_pool_per_layer, T)
+             idx = np.random.choice(T, size=take, replace=False)
+             pool_states.append(H[idx])
+     pool_states = np.vstack(pool_states) if len(pool_states) else layers_np[-1]
+     #print(f"[Pool] Pooled states for anchors/UMAP: {pool_states.shape}")
+
+     # 8.4 Fit global anchors (LoT-style features)
+     anchors = fit_global_anchors(pool_states, cfg.anchor_k)
+     # Save anchors for reproducibility
+     """
+
+     # 8.5 Build per-layer features for main text (LoT-style distances & uncertainty)
+     layer_features = []       # list of (T,K)
+     layer_uncertainties = []  # list of (T,)
+     layer_top_anchor = []     # list of (T,) argmin-id
+
+     for l, H in enumerate(layers_np):
+         dists, P, H_unc = anchor_features(H, anchors, cfg.anchor_temp)
+         layer_features.append(dists)         # N x K distances (lower = closer)
+         layer_uncertainties.append(H_unc)    # N
+         layer_top_anchor.append(np.argmin(dists, axis=1))  # closest anchor id per token

+     # 8.6 Consistency metric (LoT Eq. (5)): does layer's top anchor match final layer's?
+     final_top = layer_top_anchor[-1]
+     layer_consistency = []
+     for l in range(L_all):
+         cons = (layer_top_anchor[l] == final_top).astype(np.int32)  # 1 if matches, 0 otherwise
+         layer_consistency.append(cons)

+     # 8.7 Build per-layer graphs (kNN by default) on FEATURE space for stability
    layer_graphs = []
    for l in range(L_all):
+         feats = layer_features[l]
        if cfg.graph_mode == "knn":
+             G = build_knn_graph(feats, cfg.knn_k, metric="euclidean")  # kNN in feature space
        else:
+             # Threshold graph in original hidden space (as in your notebook)
+             G = build_threshold_graph(layers_np[l], cfg.sim_threshold, use_cosine=cfg.use_cosine)
        layer_graphs.append(G)

+     # 8.8 Cluster per layer
+     layer_cluster_labels = []
+     for l in range(L_all):
+         feats = layer_features[l]
+         labels = cluster_layer(
+             feats,
+             layer_graphs[l],
+             method=cfg.cluster_method,
+             n_clusters_kmeans=cfg.n_clusters_kmeans,
+             hdbscan_min_cluster_size=cfg.hdbscan_min_cluster_size
+         )
        layer_cluster_labels.append(labels)

+     # 8.9 Percolation statistics (φ, #clusters, χ) per layer (as in your notebook)
+     percolation = []
    for l in range(L_all):
+         stats = percolation_stats(layer_graphs[l])
+         percolation.append(stats)

+     # 8.10 Common 2D manifold via UMAP (fit-once on the pool), then transform each layer
+     """reducer2d = fit_umap_2d(pool_states,
+                             n_neighbors=cfg.umap_n_neighbors,
+                             min_dist=cfg.umap_min_dist,
+                             metric=cfg.umap_metric)"""
+     xy_by_layer = [reducer2d.transform(layers_np[l]) for l in range(L_all)]

+     # OPTIONAL: orthogonal alignment across layers (helps if UMAP.transform still drifts)
+     # for l in range(1, L_all):
+     #     xy_by_layer[l] = orthogonal_align(xy_by_layer[l-1], xy_by_layer[l])
+
+     # 8.11 Plotly 3D point+graph view: X,Y from UMAP; Z = layer index
    fig = plotly_3d_layers(
        xy_layers=xy_by_layer,
+         layer_tokens=[tokens for _ in range(L_all)],
        layer_cluster_labels=layer_cluster_labels,
        layer_uncertainty=layer_uncertainties,
        layer_graphs=layer_graphs,
+         connect_token_trajectories=True,
+         title="Qwen: 3D Cluster Formation (UMAP2D + Layer as Z, LoT metrics on hover)"
    )

    if save_artifacts:
+         os.makedirs(cfg.out_dir, exist_ok=True)
+         html_path = os.path.join(cfg.out_dir, cfg.plotly_html)
+         fig.write_html(html_path)
+         # Save percolation series
+         with open(os.path.join(cfg.out_dir, "percolation_stats.json"), "w") as f:
+             json.dump(percolation, f, indent=2)
+         np.save(os.path.join(cfg.out_dir, "anchors.npy"), anchors)
+         #print(f"[Percolation] Saved per-layer stats → percolation_stats.json")
+         #print(f"[Plotly] 3D HTML saved → {html_path}")

    return fig, {"percolation": percolation, "tokens": tokens}
752
 
 
@st.cache_resource(show_spinner=False)
def get_model_and_tok(model_name: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
+     model, tok = load_qwen(model_name, device, dtype)
    return model, tok, device, dtype

def main():
+     st.set_page_config(page_title="Layer Explorer", layout="wide")
+     st.title("3D Token Embedding Explorer (Live Hidden States)")

    with st.sidebar:
        st.header("Model / Input")
+         model_name = st.selectbox("Model", ["Qwen/Qwen1.5-0.5B", "Qwen/Qwen1.5-1.8B", "Qwen/Qwen1.5-4B"], index=1)
        max_length = st.slider("Max tokens", 16, 256, 64, step=16)

        st.header("Graph")
        graph_mode = st.selectbox("Graph mode", ["knn", "threshold"], index=0)
        knn_k = st.slider("k (kNN)", 2, 50, 8) if graph_mode == "knn" else 8
+         sim_threshold = st.slider("Similarity threshold", 0.0, 0.99, 0.70, step=0.01) if graph_mode == "threshold" else 0.70
        use_cosine = st.checkbox("Use cosine similarity", value=True)

        st.header("Anchors / LoT")

        st.header("Outputs")
        save_artifacts = st.checkbox("Save artifacts to disk (HTML/CSV/NPZ)", value=False)

+     prompt_col, run_col = st.columns([4, 1])
    with prompt_col:
+         main_text = st.text_area(
+             "Text to visualize (hidden states computed on this text)",
+             value="Explain in one sentence what a transformer attention layer does.",
+             height=140
        )
    with run_col:
        st.write("")
        st.write("")