kfoughali committed on
Commit 39106d7 · verified · 1 Parent(s): fcc351d

Update compression.py

Files changed (1)
  1. compression.py +0 -1204
compression.py CHANGED
@@ -1,1204 +0,0 @@
# compression.py
"""
Enhanced SPG compression algorithms with RocketKV-style 450x compression.
NO ESTIMATIONS - only measured values. FAIL FAST on errors.
FIXED: CUDA assert errors, safe tensor operations, bounds checking.
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional, Dict, Any, List
from dataclasses import replace
import logging

from config import (
    CompressionConfig, EnhancedSPGConfig, CompressionType,
    ResearchConstants
)

logger = logging.getLogger(__name__)

def safe_topk(tensor, k, dim=-1):
    """Safe version of topk that handles edge cases."""
    if tensor.numel() == 0:
        logger.warning("Empty tensor in topk operation")
        # Return (values, indices) to match torch.topk's output order.
        return torch.empty(0, device=tensor.device), torch.empty(0, dtype=torch.long, device=tensor.device)

    # Ensure k doesn't exceed tensor size
    max_k = tensor.shape[dim]
    actual_k = min(k, max_k)

    if actual_k <= 0:
        logger.warning(f"Invalid k={k} for tensor with shape {tensor.shape}")
        return torch.empty(0, device=tensor.device), torch.empty(0, dtype=torch.long, device=tensor.device)

    return torch.topk(tensor, actual_k, dim=dim)


def safe_index_select(tensor, dim, indices):
    """Safe version of index_select that validates indices."""
    if indices.numel() == 0:
        # Return empty tensor with correct shape
        shape = list(tensor.shape)
        shape[dim] = 0
        return torch.empty(shape, dtype=tensor.dtype, device=tensor.device)

    # Validate indices are within bounds
    max_idx = tensor.shape[dim] - 1
    if indices.max() > max_idx:
        logger.warning(f"Index {indices.max()} exceeds max {max_idx}, clamping")
        indices = indices.clamp(0, max_idx)

    if indices.min() < 0:
        logger.warning(f"Negative index {indices.min()}, clamping to 0")
        indices = indices.clamp(0, max_idx)

    return tensor.index_select(dim, indices)

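# Illustrative usage of the safe wrappers above (editor's sketch, not part of
# the original module). Shows how out-of-range k and indices are clamped
# instead of triggering CUDA-side asserts:
#
#   scores = torch.randn(5)
#   values, indices = safe_topk(scores, k=10)   # k clamped to 5
#   x = torch.randn(2, 4, 5, 8)
#   idx = torch.tensor([0, 4, 99])              # 99 clamped to 4
#   sel = safe_index_select(x, 2, idx)          # shape [2, 4, 3, 8]
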
class EnhancedSlidingPrecisionGradient:
    """
    Research-grade Enhanced SPG with RocketKV-style 450x compression capability.
    NO ESTIMATIONS OR HARDCODED VALUES - all parameters from validated config.
    FIXED: Safe tensor operations with bounds checking.
    """

    def __init__(self, config: EnhancedSPGConfig):
        self.config = config
        self.constants = ResearchConstants()
        self.layer_decay_rates: Optional[List[float]] = None
        self.compression_stats: List[Dict[str, Any]] = []

        # Progressive compression state
        self.current_compression_ratio = config.initial_compression_ratio if config.enable_progressive else None
        self.progressive_step = 0
        self.quality_history: List[float] = []

        # Adaptive state
        self.adaptive_enabled = config.enable_adaptive
        self.decay_adjustment_rate = config.decay_adjustment_rate
        self.target_perplexity_delta = config.target_perplexity_delta

        # RocketKV-style adaptive decomposition
        self.use_adaptive_decomposition = config.use_adaptive_decomposition
        self.use_hybrid_sparse_attention = config.use_hybrid_sparse_attention
        self.target_compression_ratio = config.target_compression_ratio

        logger.info(f"Enhanced SPG initialized with {config.magnitude_threshold_mode} magnitude thresholds")
        if self.use_hybrid_sparse_attention:
            logger.info("RocketKV-style Hybrid Sparse Attention enabled")

    def initialize_layer_decay_rates(self, n_layers: int) -> None:
        """Initialize per-layer decay rates with validation."""
        if not self.constants.MIN_LAYERS <= n_layers <= self.constants.MAX_LAYERS:
            logger.warning(f"n_layers {n_layers} outside typical range [{self.constants.MIN_LAYERS}, {self.constants.MAX_LAYERS}]")

        # All layers start at the base decay rate. With per_layer_decay enabled,
        # the rates diverge later via update_decay_rate(); otherwise they stay shared.
        self.layer_decay_rates = [self.config.base_decay_rate] * n_layers

        self.n_layers = n_layers
        logger.info(f"Initialized decay rates for {n_layers} layers")

    def update_decay_rate(self, layer_idx: int, quality_metric: float, target_quality: float) -> None:
        """Update decay rate for adaptive SPG with proper validation."""
        if not self.adaptive_enabled or self.layer_decay_rates is None:
            return

        if not 0 <= layer_idx < len(self.layer_decay_rates):
            logger.error(f"Invalid layer_idx {layer_idx}, valid range: [0, {len(self.layer_decay_rates)})")
            return

        # Validate and clamp inputs
        quality_metric = max(0.1, min(1000.0, float(quality_metric)))
        target_quality = max(0.1, min(1000.0, float(target_quality)))

        # Compute adjustment
        quality_delta = quality_metric - target_quality

        if quality_delta > 0:  # Quality worse than target
            adjustment = -self.decay_adjustment_rate * (quality_delta / target_quality)
        else:  # Quality better than target
            adjustment = self.decay_adjustment_rate * (abs(quality_delta) / target_quality)

        # Apply with bounds
        old_rate = self.layer_decay_rates[layer_idx]
        new_rate = max(0.8, min(0.99, old_rate + adjustment))
        self.layer_decay_rates[layer_idx] = new_rate

        logger.debug(f"Adaptive SPG Layer {layer_idx}: quality={quality_metric:.3f}, "
                     f"target={target_quality:.3f}, decay_rate: {old_rate:.3f} → {new_rate:.3f}")

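    # Worked example of the update rule above (editor's sketch with hypothetical
    # numbers; decay_adjustment_rate comes from config in practice):
    #
    #   quality_metric=12.0, target_quality=10.0, decay_adjustment_rate=0.05
    #   quality_delta = +2.0 (worse than target)
    #   adjustment    = -0.05 * (2.0 / 10.0) = -0.01
    #   old_rate 0.95 -> new_rate = clamp(0.95 - 0.01, 0.8, 0.99) = 0.94
    #
    # A lower decay rate makes precision_scores = decay_rate ** (distance / norm)
    # fall off more steeply with token age.
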
    def compute_magnitude_importance(self, keys: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        """
        Compute importance scores based on magnitude statistics.
        This is an EXPLICIT magnitude-based proxy, not an estimation.
        """
        try:
            # Compute L2 norm across head dimension for each token
            k_norms = keys.norm(dim=-1).mean(dim=1).mean(dim=0)  # [seq_len]
            v_norms = values.norm(dim=-1).mean(dim=1).mean(dim=0)  # [seq_len]

            # Combine key and value magnitudes (explicit formula)
            importance_scores = (k_norms + v_norms) / 2.0

            # Normalize to [0, 1] range for consistent thresholding
            score_min = importance_scores.min()
            score_max = importance_scores.max()

            if score_max > score_min:
                importance_scores = (importance_scores - score_min) / (score_max - score_min)
            else:
                importance_scores = torch.ones_like(importance_scores)

            logger.debug(f"Computed magnitude importance: min={score_min:.6f}, max={score_max:.6f}")
            return importance_scores

        except Exception as e:
            logger.error(f"Error computing magnitude importance: {e}")
            raise

    def estimate_attention_sparsity(self, keys: torch.Tensor, values: torch.Tensor) -> float:
        """Estimate attention pattern sparsity for adaptive decomposition. FAIL FAST on error."""
        try:
            # Compute approximate attention patterns using key-key similarity
            k_norm = F.normalize(keys.float(), p=2, dim=-1)
            attention_approx = torch.matmul(k_norm, k_norm.transpose(-2, -1))

            # Measure sparsity as fraction of near-zero attention weights
            # Use configurable threshold from constants
            threshold = self.constants.ATTENTION_SPARSITY_THRESHOLD
            sparse_fraction = (attention_approx.abs() < threshold).float().mean().item()

            return sparse_fraction

        except Exception as e:
            # FAIL FAST - NO FALLBACK VALUES
            logger.error(f"Failed to estimate attention sparsity: {e}")
            raise RuntimeError(f"Cannot measure attention sparsity: {e}")

    def adaptive_stage_split(self, target_ratio: float, seq_len: int, sparsity: float) -> Tuple[float, float]:
        """RocketKV-style adaptive compression decomposition with explicit parameters."""
        # Use explicit formulas from research constants
        if sparsity > self.constants.SPARSITY_HIGH_THRESHOLD:
            stage1_power = self.constants.SPARSE_STAGE1_POWER
        elif sparsity > self.constants.SPARSITY_MEDIUM_THRESHOLD:
            stage1_power = self.constants.BALANCED_STAGE1_POWER
        else:
            stage1_power = self.constants.DENSE_STAGE1_POWER

        stage1_ratio = target_ratio ** stage1_power
        stage2_ratio = target_ratio / stage1_ratio

        # Bounds checking with explicit limits from config
        stage1_ratio = max(self.config.stage_compression_min, min(self.config.stage_compression_max, stage1_ratio))
        stage2_ratio = max(self.config.stage_compression_min, min(self.config.stage_compression_max, stage2_ratio))

        logger.debug(f"Adaptive split: sparsity={sparsity:.3f}, stage1={stage1_ratio:.1f}x, stage2={stage2_ratio:.1f}x")
        return stage1_ratio, stage2_ratio

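    # Worked example of the decomposition (editor's sketch; the *_STAGE1_POWER
    # constants live in ResearchConstants, the 0.6 here is a hypothetical value):
    #
    #   target_ratio=450, stage1_power=0.6
    #   stage1_ratio = 450 ** 0.6 ≈ 39.0x   (permanent eviction)
    #   stage2_ratio = 450 / 39.0 ≈ 11.5x   (multi-dimensional compression)
    #
    # The two stages multiply back to the overall target: 39.0 * 11.5 ≈ 450x,
    # before the per-stage clamping against stage_compression_min/max.
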
    def snapkv_plus_plus(self, keys: torch.Tensor, values: torch.Tensor,
                         compression_ratio: float) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """SnapKV++ with GQA support and adaptive pooling - FIXED with safe operations."""
        batch_size, n_heads, seq_len, head_dim = keys.shape

        # CRITICAL: Ensure minimum tokens retained
        min_tokens = max(8, self.config.min_tokens_for_stability)  # At least 8 tokens
        n_keep = max(min_tokens, int(seq_len / compression_ratio))
        n_keep = min(n_keep, seq_len)  # Can't keep more than we have

        logger.debug(f"SnapKV++: seq_len={seq_len}, compression_ratio={compression_ratio:.1f}, n_keep={n_keep}")

        if n_keep >= seq_len:
            # No compression needed
            return keys, values, list(range(seq_len))

        # Adaptive kernel size based on sequence length (from config)
        kernel_size = self.config.get_adaptive_kernel_size(seq_len)

        # Compute importance scores with adaptive pooling
        try:
            key_norms = keys.norm(dim=-1)  # [batch, heads, seq]
            value_norms = values.norm(dim=-1)
            combined_importance = (key_norms + value_norms) / 2.0

            # Multi-head aggregation with adaptive pooling
            if kernel_size > 1 and seq_len > kernel_size:
                # Apply 1D pooling along sequence dimension
                pooled_importance = F.avg_pool1d(
                    combined_importance.mean(dim=1).unsqueeze(1),  # [batch, 1, seq]
                    kernel_size=kernel_size,
                    stride=1,
                    padding=kernel_size // 2
                ).squeeze(1)  # [batch, seq]
                # Ensure pooled output matches original sequence length
                if pooled_importance.shape[-1] != seq_len:
                    pooled_importance = pooled_importance[:, :seq_len]
            else:
                pooled_importance = combined_importance.mean(dim=1)

            # Aggregate across batch
            final_importance = pooled_importance.mean(dim=0)  # [seq]
        except Exception as e:
            logger.error(f"Error computing importance: {e}")
            # Fallback to uniform importance
            final_importance = torch.ones(seq_len, device=keys.device)

        # Ensure importance tensor matches sequence length
        if final_importance.shape[0] != seq_len:
            final_importance = final_importance[:seq_len]

        # Preserve sink and recent tokens
        preserve_mask = torch.zeros(seq_len, dtype=torch.bool, device=keys.device)

        # Recent tokens
        recent_window = min(self.config.recent_window, seq_len // 2)  # Don't preserve more than half
        if recent_window > 0:  # guard: mask[-0:] would select (and preserve) every token
            preserve_mask[-recent_window:] = True

        # Sink tokens
        if self.config.sink_tokens > 0:
            sink_count = min(self.config.sink_tokens, seq_len // 4)  # Don't preserve more than quarter
            preserve_mask[:sink_count] = True

        preserved_count = preserve_mask.sum().item()
        remaining_slots = max(0, n_keep - preserved_count)

        if remaining_slots > 0:
            masked_importance = final_importance.clone()
            masked_importance[preserve_mask] = -float('inf')

            available_indices = (~preserve_mask).nonzero(as_tuple=True)[0]
            if len(available_indices) > 0:
                k = min(remaining_slots, len(available_indices))
                if k > 0:
                    available_importance = masked_importance[available_indices]
                    _, relative_top_indices = safe_topk(available_importance, k)

                    if relative_top_indices.numel() > 0:
                        absolute_indices = available_indices[relative_top_indices]
                        preserve_mask[absolute_indices] = True

        # Get final retained indices
        retained_indices = preserve_mask.nonzero(as_tuple=True)[0]

        if retained_indices.numel() == 0:
            logger.error("No indices retained! Keeping at least recent tokens")
            # Emergency fallback - keep last few tokens
            retained_indices = torch.arange(max(0, seq_len - min_tokens), seq_len,
                                            device=keys.device, dtype=torch.long)

        # Safe indexing
        keys_compressed = safe_index_select(keys, 2, retained_indices)
        values_compressed = safe_index_select(values, 2, retained_indices)

        actual_ratio = seq_len / len(retained_indices) if len(retained_indices) > 0 else 1.0
        logger.debug(f"SnapKV++ compressed: {seq_len} → {len(retained_indices)} tokens ({actual_ratio:.1f}x)")

        return keys_compressed, values_compressed, retained_indices.tolist()

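    # Illustrative call (editor's sketch; shapes are hypothetical and the
    # window/sink settings come from EnhancedSPGConfig in practice):
    #
    #   keys   = torch.randn(1, 12, 1024, 64)          # [batch, heads, seq, dim]
    #   values = torch.randn(1, 12, 1024, 64)
    #   k_c, v_c, kept = spg.snapkv_plus_plus(keys, values, compression_ratio=8.0)
    #   # k_c/v_c keep roughly 1024/8 = 128 sequence positions, with sink and
    #   # recent tokens always in `kept` regardless of pooled importance.
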
    def hybrid_sparse_attention(self, keys: torch.Tensor, values: torch.Tensor,
                                head_budget: int, seq_budget: int) -> Dict[str, Any]:
        """RocketKV-style Hybrid Sparse Attention for Stage 2 - FIXED with safe operations."""
        batch_size, n_heads, seq_len, head_dim = keys.shape

        # Ensure minimum budgets
        head_budget = max(1, min(head_budget, n_heads))
        seq_budget = max(self.config.min_tokens_for_stability, min(seq_budget, seq_len))

        logger.debug(f"HSA: n_heads={n_heads}, seq_len={seq_len}, head_budget={head_budget}, seq_budget={seq_budget}")

        # 1. Head-wise importance scoring with safe computation
        try:
            head_importance = (
                keys.float().pow(2).sum(dim=(-1, -2)).mean(dim=0) +  # Average over batch
                values.float().pow(2).sum(dim=(-1, -2)).mean(dim=0)
            )  # [n_heads]
        except Exception as e:
            logger.error(f"Error computing head importance: {e}")
            head_importance = torch.ones(n_heads, device=keys.device)

        # Select top heads safely
        _, top_head_indices = safe_topk(head_importance, head_budget)

        if top_head_indices.numel() == 0:
            # Fallback - keep first head
            top_head_indices = torch.tensor([0], device=keys.device, dtype=torch.long)

        compressed_data = {
            'keys': {},
            'values': {},
            'metadata': {
                'head_selection': top_head_indices.tolist(),
                'original_shape': keys.shape,
                'compression_type': 'hybrid_sparse_attention'
            }
        }

        # 2. Sequence-wise top-k selection per selected head
        for head_idx in top_head_indices:
            head_idx_int = head_idx.item()

            # Extract head data safely
            head_keys = keys[:, head_idx_int:head_idx_int+1, :, :]
            head_values = values[:, head_idx_int:head_idx_int+1, :, :]

            # Compute sequence importance for this head
            try:
                seq_importance = (
                    head_keys.norm(dim=-1).squeeze(1).mean(dim=0) +
                    head_values.norm(dim=-1).squeeze(1).mean(dim=0)
                ) / 2.0
            except Exception as e:
                logger.error(f"Error computing seq importance for head {head_idx_int}: {e}")
                seq_importance = torch.ones(seq_len, device=keys.device)

            # Apply position-based boost (from research constants)
            position_boost = torch.ones_like(seq_importance)
            if self.config.sink_tokens > 0:
                sink_count = min(self.config.sink_tokens, seq_len // 4)
                position_boost[:sink_count] *= self.constants.POSITION_BOOST_SINK
            if self.config.recent_window > 0:
                recent_count = min(self.config.recent_window, seq_len // 2)
                if recent_count > 0:  # guard: boost[-0:] would boost every position
                    position_boost[-recent_count:] *= self.constants.POSITION_BOOST_RECENT

            boosted_importance = seq_importance * position_boost

            # Select top tokens for this head
            _, top_token_indices = safe_topk(boosted_importance, seq_budget)

            if top_token_indices.numel() == 0:
                # Fallback - keep last few tokens
                top_token_indices = torch.arange(max(0, seq_len - seq_budget), seq_len,
                                                 device=keys.device, dtype=torch.long)

            # Store compressed data
            head_key = f'head_{head_idx_int}'
            compressed_data['keys'][head_key] = {
                'data': safe_index_select(head_keys, 2, top_token_indices),
                'indices': top_token_indices.tolist()
            }
            compressed_data['values'][head_key] = {
                'data': safe_index_select(head_values, 2, top_token_indices),
                'indices': top_token_indices.tolist()
            }

        return compressed_data

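    # Shape of the resulting structure (editor's illustration, hypothetical
    # budgets): with n_heads=12, head_budget=4, seq_budget=64, the dict holds
    # four 'head_<i>' entries under both 'keys' and 'values', each carrying a
    # [batch, 1, 64, head_dim] tensor plus the 64 kept token indices, so the
    # cache stores 4*64 token slots instead of 12*seq_len.
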
    def stage1_permanent_eviction(self, keys: torch.Tensor, values: torch.Tensor,
                                  layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """
        Stage 1: RocketKV-style permanent eviction with SnapKV++ or magnitude-guided approach.
        """
        batch_size, n_heads, seq_len, head_dim = keys.shape

        if self.use_adaptive_decomposition:
            # Use adaptive compression split
            sparsity = self.estimate_attention_sparsity(keys, values)  # May raise if fails
            stage1_ratio, _ = self.adaptive_stage_split(self.target_compression_ratio, seq_len, sparsity)
        else:
            stage1_ratio = self.config.stage1_compression_ratio

        # Choose compression method based on configuration
        if self.config.use_snapkv_plus_plus:
            return self.snapkv_plus_plus(keys, values, stage1_ratio)
        else:
            # Original magnitude-guided approach
            return self._magnitude_guided_stage1(keys, values, layer_idx, stage1_ratio)

    def _magnitude_guided_stage1(self, keys: torch.Tensor, values: torch.Tensor,
                                 layer_idx: int, compression_ratio: float) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """Original magnitude-guided Stage 1 eviction with explicit parameters."""
        batch_size, n_heads, seq_len, head_dim = keys.shape

        # Calculate retention based on compression ratio
        retention_ratio = 1.0 / compression_ratio
        min_retain = max(8, self.config.sink_tokens + self.config.recent_window, self.config.min_tokens_for_stability)
        n_retain = max(min_retain, int(seq_len * retention_ratio))

        # Apply layer-specific constraints (from research constants)
        layer_position = layer_idx / max(getattr(self, 'n_layers', 12) - 1, 1)
        if layer_position <= 0.5:  # Early layers
            max_retain = int(seq_len * self.constants.EARLY_LAYER_MAX_RETENTION)
        else:  # Late layers
            max_retain = int(seq_len * self.constants.LATE_LAYER_MAX_RETENTION)

        n_retain = min(n_retain, max_retain, seq_len)

        # Compute magnitude-based importance
        importance_scores = self.compute_magnitude_importance(keys, values)

        # Quality preservation: boost recent tokens (explicit formula from config)
        recent_boost = torch.zeros_like(importance_scores)
        if self.config.recent_window > 0:
            recent_window = min(self.config.recent_window, seq_len // 2)
            if recent_window > 0:  # guard against boost[-0:] touching every token
                recent_boost[-recent_window:] = importance_scores.max() * self.config.recent_boost_factor
        importance_scores = importance_scores + recent_boost

        # Initialize preservation mask
        preserve_mask = torch.zeros(seq_len, dtype=torch.bool, device=keys.device)
        if self.config.sink_tokens > 0:
            sink_count = min(self.config.sink_tokens, seq_len // 4)
            preserve_mask[:sink_count] = True
        if self.config.recent_window > 0:
            recent_count = min(self.config.recent_window, seq_len // 2)
            if recent_count > 0:  # same guard as above
                preserve_mask[-recent_count:] = True

        # Select additional tokens based on importance
        remaining_slots = n_retain - preserve_mask.sum().item()
        if remaining_slots > 0:
            masked_importance = importance_scores.clone()
            masked_importance[preserve_mask] = -float('inf')

            # Use configured threshold (not hardcoded)
            magnitude_threshold = torch.quantile(
                importance_scores.float(),
                self.config.get_magnitude_threshold()
            )

            below_threshold = masked_importance < magnitude_threshold
            masked_importance[below_threshold] = -float('inf')

            available = (masked_importance > -float('inf')).sum().item()
            k = min(remaining_slots, available)
            if k > 0:
                _, top_indices = safe_topk(masked_importance, k)
                if top_indices.numel() > 0:
                    preserve_mask[top_indices] = True

        # Extract retained tokens
        retained_indices = preserve_mask.nonzero(as_tuple=True)[0]

        if retained_indices.numel() == 0:
            logger.error(f"No tokens retained in stage 1 layer {layer_idx}! Using fallback")
            min_keep = max(8, self.config.min_tokens_for_stability)
            retained_indices = torch.arange(seq_len - min_keep, seq_len, device=keys.device, dtype=torch.long)

        keys_stage1 = safe_index_select(keys, 2, retained_indices)
        values_stage1 = safe_index_select(values, 2, retained_indices)

        actual_ratio = seq_len / len(retained_indices) if len(retained_indices) > 0 else 1.0
        logger.debug(f"Stage 1 Layer {layer_idx}: {seq_len} → {len(retained_indices)} tokens ({actual_ratio:.1f}x)")

        return keys_stage1, values_stage1, retained_indices.tolist()

    def stage2_multi_dimensional_compression(self, keys: torch.Tensor, values: torch.Tensor,
                                             layer_idx: int, retained_indices: List[int]) -> Dict[str, Any]:
        """
        Stage 2: RocketKV-style Hybrid Sparse Attention compression.
        Uses dynamic top-k selection with head and sequence reductions.
        """
        batch_size, n_heads, seq_len, head_dim = keys.shape

        if self.use_hybrid_sparse_attention:
            # RocketKV-style compression with adaptive budgets
            try:
                sparsity = self.estimate_attention_sparsity(keys, values)
            except Exception as e:
                logger.warning(f"Sparsity estimation failed, using neutral default: {e}")
                sparsity = 0.5  # Default if estimation fails

            if self.use_adaptive_decomposition:
                _, stage2_ratio = self.adaptive_stage_split(
                    self.target_compression_ratio, seq_len, sparsity
                )
            else:
                stage2_ratio = self.config.stage2_compression_ratio

            # Dynamic budgets based on compression target (from config)
            head_retention_ratio = self.config.get_head_retention_ratio()
            head_budget = max(1, int(n_heads * head_retention_ratio))
            seq_budget = max(self.config.min_tokens_for_stability, int(seq_len / stage2_ratio))

            # Use hybrid sparse attention
            compressed_data = self.hybrid_sparse_attention(keys, values, head_budget, seq_budget)

            # Add metadata
            compressed_data['metadata'].update({
                'stage1_retained_indices': retained_indices,
                'original_shape_after_stage1': keys.shape,
                'original_dtype': keys.dtype,
                'layer_idx': layer_idx,
                'sparsity_estimate': sparsity,
                'stage2_compression_ratio': stage2_ratio,
                'head_budget': head_budget,
                'seq_budget': seq_budget,
                'head_retention_ratio': head_retention_ratio
            })

            return compressed_data

        # Fallback to original multi-dimensional compression
        return self._original_stage2_compression(keys, values, layer_idx, retained_indices)

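    # Budget arithmetic for the branch above (editor's sketch, hypothetical
    # numbers): with n_heads=12, head_retention_ratio=0.5, seq_len=128 after
    # Stage 1, and stage2_ratio=11.5x:
    #
    #   head_budget = max(1, int(12 * 0.5))        = 6 heads
    #   seq_budget  = max(min_tokens, 128 / 11.5)  ≈ 11 tokens per kept head
    #
    # min_tokens_for_stability puts a floor under seq_budget so extreme ratios
    # cannot evict the entire sequence.
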
    def _original_stage2_compression(self, keys: torch.Tensor, values: torch.Tensor,
                                     layer_idx: int, retained_indices: List[int]) -> Dict[str, Any]:
        """Original Stage 2 implementation for comparison."""
        batch_size, n_heads, seq_len, head_dim = keys.shape

        # Compute importance for remaining tokens
        importance_scores = self.compute_magnitude_importance(keys, values)

        # Combine with position-based decay (explicit formula)
        decay_rate = self.layer_decay_rates[layer_idx] if self.layer_decay_rates else self.config.base_decay_rate
        position_scores = torch.pow(
            decay_rate,
            torch.arange(seq_len, device=keys.device).float() / self.config.decay_normalization
        )

        combined_importance = importance_scores * position_scores

        compressed_data = {
            'keys': {},
            'values': {},
            'metadata': {
                'stage1_retained_indices': retained_indices,
                'importance_scores': combined_importance,
                'original_shape_after_stage1': keys.shape,
                'original_dtype': keys.dtype,
                'layer_idx': layer_idx,
                'magnitude_threshold_mode': self.config.magnitude_threshold_mode,
                'compression_type': 'original_multi_dimensional'
            }
        }

        # Head dimension compression with explicit parameters
        if self.config.enable_head_compression:
            n_important_heads = max(1, int(n_heads * self.config.head_compression_ratio))

            # UPDATED: Always reserve top head_fp16_reserve heads at full precision
            n_reserved_heads = min(getattr(self.config, 'head_fp16_reserve', 2), n_heads)
            n_important_heads = max(n_reserved_heads, n_important_heads)

            # Compute head importance (explicit calculation)
            head_importance = (
                keys.float().pow(2).sum(dim=(-1, -2)).sum(dim=0) +
                values.float().pow(2).sum(dim=(-1, -2)).sum(dim=0)
            )

            _, important_head_indices = safe_topk(head_importance, n_important_heads)

            if important_head_indices.numel() == 0:
                important_head_indices = torch.tensor([0], device=keys.device, dtype=torch.long)

            other_head_indices = torch.tensor(
                [h for h in range(n_heads) if h not in important_head_indices.tolist()],
                device=keys.device, dtype=torch.long
            )

            # Store important heads at full precision
            compressed_data['keys']['heads_fp16'] = {
                'data': safe_index_select(keys, 1, important_head_indices).clone(),
                'indices': important_head_indices.tolist()
            }
            compressed_data['values']['heads_fp16'] = {
                'data': safe_index_select(values, 1, important_head_indices).clone(),
                'indices': important_head_indices.tolist()
            }

            if other_head_indices.numel() == 0:
                return compressed_data

            seq_keys = safe_index_select(keys, 1, other_head_indices)
            seq_values = safe_index_select(values, 1, other_head_indices)
        else:
            seq_keys = keys
            seq_values = values

        # Sequence dimension compression with explicit ratios
        levels = self.config.precision_levels

        # Explicit top-K selection for FP16 (topk returns (values, indices);
        # the indices, not the values, build the mask)
        keep_fp16 = max(0, int(seq_len * self.config.sequence_compression_ratio))
        if keep_fp16 > 0:
            _, top_fp16 = safe_topk(combined_importance, k=keep_fp16)
            is_fp16 = torch.zeros(seq_len, dtype=torch.bool, device=keys.device)
            if top_fp16.numel() > 0:
                is_fp16[top_fp16] = True
        else:
            is_fp16 = torch.zeros(seq_len, dtype=torch.bool, device=keys.device)

        # Vectorized token binning (torch.bucketize requires ascending boundaries)
        thresh = torch.tensor([pl.threshold for pl in levels], device=keys.device)
        thresh_sorted, order = torch.sort(thresh)  # ascending
        level_ids = torch.bucketize(combined_importance, thresh_sorted, right=True)

        # Assign tokens to precision levels
        for i in range(seq_len):
            if is_fp16[i]:
                precision_key = 'seq_fp16'
            else:
                # Pick the level with the highest threshold <= score;
                # scores below every threshold land in the lowest level.
                level_idx = max(level_ids[i].item() - 1, 0)
                level = levels[order[level_idx]]

                if level.bits is not None:
                    precision_key = f'seq_{level.bits}bit'
                else:
                    precision_key = f'seq_{level.name}'

            if precision_key not in compressed_data['keys']:
                compressed_data['keys'][precision_key] = {
                    'indices': [], 'data': None, 'scale': None, 'zero': None
                }
                compressed_data['values'][precision_key] = {
                    'indices': [], 'data': None, 'scale': None, 'zero': None
                }

            compressed_data['keys'][precision_key]['indices'].append(i)
            compressed_data['values'][precision_key]['indices'].append(i)

        # Store data with aggressive precision (FP16 for most important tokens)
        keys_to_delete = []
        for precision_key in list(compressed_data['keys'].keys()):
            if not precision_key.startswith('seq_'):
                continue

            indices = compressed_data['keys'][precision_key]['indices']
            if not indices:
                keys_to_delete.append(precision_key)
                continue

            if precision_key == 'seq_discard':
                keys_to_delete.append(precision_key)
                continue

            idx_tensor = torch.tensor(indices, device=keys.device, dtype=torch.long)
            k_slice = safe_index_select(seq_keys, 2, idx_tensor)
            v_slice = safe_index_select(seq_values, 2, idx_tensor)

            # Store with aggressive precision - only FP16 for ultra-selective tokens
            compressed_data['keys'][precision_key]['data'] = k_slice.clone()
            compressed_data['values'][precision_key]['data'] = v_slice.clone()

        # Clean up empty keys
        for pk in keys_to_delete:
            compressed_data['keys'].pop(pk, None)
            compressed_data['values'].pop(pk, None)

        return compressed_data

    def compress_with_enhanced_gradient(self, keys: torch.Tensor, values: torch.Tensor,
                                        layer_idx: int, current_position: int) -> Dict[str, Any]:
        """
        Main compression function with explicit two-stage approach.
        """
        if not self.config.enable_two_stage:
            return self._fallback_to_original_spg(keys, values, layer_idx, current_position)

        try:
            # Record original shape
            orig_shape_full = keys.shape

            # Stage 1: Permanent eviction
            keys_stage1, values_stage1, retained_indices = self.stage1_permanent_eviction(
                keys, values, layer_idx
            )

            # Stage 2: Multi-dimensional compression
            compressed_data = self.stage2_multi_dimensional_compression(
                keys_stage1, values_stage1, layer_idx, retained_indices
            )

            # Add metadata
            compressed_data['metadata']['original_full_shape'] = orig_shape_full

            # Progressive compression
            if self.config.enable_progressive:
                compressed_data = self._apply_progressive_compression(compressed_data, layer_idx)

            return compressed_data

        except Exception as e:
            logger.error(f"Error in enhanced compression for layer {layer_idx}: {e}")
            # Fallback to original SPG on error
            return self._fallback_to_original_spg(keys, values, layer_idx, current_position)

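    # Round-trip sketch for the two-stage path (editor's illustration; the
    # EnhancedSPGConfig constructor arguments shown are assumptions about its
    # interface):
    #
    #   spg = EnhancedSlidingPrecisionGradient(EnhancedSPGConfig(enable_two_stage=True))
    #   spg.initialize_layer_decay_rates(n_layers=12)
    #   packed = spg.compress_with_enhanced_gradient(keys, values, layer_idx=0,
    #                                                current_position=keys.shape[2])
    #   k_rec, v_rec = spg.decompress(packed)        # zeros where tokens were evicted
    #   bytes_used = spg.get_memory_footprint(packed)
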
    def _fallback_to_original_spg(self, keys: torch.Tensor, values: torch.Tensor,
                                  layer_idx: int, current_position: Optional[int]) -> Dict[str, Any]:
        """Fallback to original SPG implementation with actual data storage."""
        batch_size, n_heads, seq_len, head_dim = keys.shape

        # Original position-based precision computation
        device = keys.device
        precision_scores = torch.zeros(seq_len, device=device)

        decay_rate = self.layer_decay_rates[layer_idx] if self.layer_decay_rates else self.config.base_decay_rate

        positions = torch.arange(seq_len, device=device)
        if current_position is None or not isinstance(current_position, (int, float)):
            current_position = seq_len
        current_position = int(current_position)
        distances = torch.tensor(current_position, device=device, dtype=positions.dtype) - positions

        precision_scores = torch.pow(decay_rate, distances.float() / self.config.decay_normalization)
        precision_scores[:self.config.sink_tokens] = 1.0

        recent_mask = distances < self.config.recent_window
        precision_scores[recent_mask] = torch.maximum(
            precision_scores[recent_mask],
            torch.tensor(self.config.recent_min_precision, device=device)
        )

        # Apply precision levels with actual data storage
        compressed_data = {
            'keys': {},
            'values': {},
            'metadata': {
                'precision_scores': precision_scores,
                'original_shape': keys.shape,
                'original_dtype': keys.dtype,
                'layer_idx': layer_idx,
                'compression_type': 'original_spg'
            }
        }

        # Exclusive binning for precision levels (assumes levels are ordered by
        # descending threshold, so each score falls into exactly one bin)
        levels = self.config.precision_levels
        for i, score in enumerate(precision_scores):
            for j, level in enumerate(levels):
                lo = level.threshold
                hi = levels[j-1].threshold if j > 0 else float('inf')

                if lo <= score < hi:
                    if level.bits is not None:
                        precision_key = f'{level.bits}bit'
                    else:
                        precision_key = level.name

                    if precision_key not in compressed_data['keys']:
                        compressed_data['keys'][precision_key] = {
                            'indices': [], 'data': None, 'scale': None, 'zero': None
                        }
                        compressed_data['values'][precision_key] = {
                            'indices': [], 'data': None, 'scale': None, 'zero': None
                        }

                    compressed_data['keys'][precision_key]['indices'].append(i)
                    compressed_data['values'][precision_key]['indices'].append(i)
                    break

        # Process data
        keys_to_delete = []
        for precision_key in list(compressed_data['keys'].keys()):
            indices = compressed_data['keys'][precision_key]['indices']
            if not indices:
                keys_to_delete.append(precision_key)
                continue

            if precision_key == 'discard':
                keys_to_delete.append(precision_key)
                continue

            level_indices = torch.tensor(indices, device=device, dtype=torch.long)
            k_slice = safe_index_select(keys, 2, level_indices)
            v_slice = safe_index_select(values, 2, level_indices)

            # Store with FP16 precision (simplified for original SPG)
            compressed_data['keys'][precision_key]['data'] = k_slice.clone()
            compressed_data['values'][precision_key]['data'] = v_slice.clone()

        # Clean up empty keys
        for pk in keys_to_delete:
            compressed_data['keys'].pop(pk, None)
            compressed_data['values'].pop(pk, None)

        return compressed_data

    def _apply_progressive_compression(self, compressed_data: Dict, layer_idx: int) -> Dict:
        """Apply progressive compression with relative quality change detection."""
        if len(self.quality_history) >= self.constants.PROGRESSIVE_QUALITY_WINDOW:
            recent = float(np.mean(self.quality_history[-self.constants.PROGRESSIVE_RECENT_WINDOW:]))
            prev = float(np.mean(self.quality_history[-self.constants.PROGRESSIVE_QUALITY_WINDOW:-self.constants.PROGRESSIVE_RECENT_WINDOW]))
            rel_delta = (recent - prev) / max(prev, 1e-9)

            if rel_delta <= self.config.quality_threshold:
                old_ratio = self.current_compression_ratio or self.config.initial_compression_ratio
                new_ratio = min(old_ratio * self.config.progression_factor, self.config.max_compression_ratio)

                if new_ratio > old_ratio:
                    self.current_compression_ratio = new_ratio
                    compression_factor = new_ratio / old_ratio

                    # Tighten compression ratios (use configurable minimum from config)
                    self.config.head_compression_ratio = max(self.config.progressive_min_ratio,
                                                             self.config.head_compression_ratio / compression_factor)
                    self.config.sequence_compression_ratio = max(self.config.progressive_min_ratio,
                                                                 self.config.sequence_compression_ratio / compression_factor)

                    self.progressive_step += 1

                    logger.info(f"Progressive step {self.progressive_step}: rel_delta={rel_delta:.4f}, new_ratio={new_ratio:.1f}x")

        compressed_data['metadata']['progressive_compression_ratio'] = self.current_compression_ratio
        compressed_data['metadata']['progressive_step'] = self.progressive_step

        return compressed_data

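    # Numeric trace of the progression test (editor's sketch; the window sizes
    # are hypothetical stand-ins for the PROGRESSIVE_* constants):
    #
    #   quality_history tail = [10.2, 10.1, 10.0, 10.05, 10.0]
    #   recent = mean(last 2)      = 10.025
    #   prev   = mean(preceding 3) = 10.1
    #   rel_delta = (10.025 - 10.1) / 10.1 ≈ -0.0074
    #
    # rel_delta <= quality_threshold means quality has not degraded, so the
    # target ratio is multiplied by progression_factor (capped at the maximum).
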
    def decompress(self, compressed_data: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decompress enhanced SPG compressed data."""
        metadata = compressed_data['metadata']

        if metadata.get('compression_type') == 'original_spg':
            return self._decompress_original_spg(compressed_data)

        return self._decompress_enhanced_spg(compressed_data)

    def _decompress_enhanced_spg(self, compressed_data: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decompress enhanced multi-stage compressed data with HSA support."""
        metadata = compressed_data['metadata']

        # Get device from first available tensor
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        for storage_type in ['keys', 'values']:
            for key, data in compressed_data[storage_type].items():
                if isinstance(data, dict) and 'data' in data and isinstance(data['data'], torch.Tensor):
                    device = data['data'].device
                    break
            if device != torch.device('cuda' if torch.cuda.is_available() else 'cpu'):
                break

        # Handle hybrid sparse attention format
        if metadata.get('compression_type') == 'hybrid_sparse_attention':
            return self._decompress_hybrid_sparse_attention(compressed_data)

        # Original enhanced SPG decompression
        original_shape = metadata['original_shape_after_stage1']
        original_dtype = metadata['original_dtype']

        keys_full = torch.zeros(original_shape, dtype=original_dtype, device=device)
        values_full = torch.zeros(original_shape, dtype=original_dtype, device=device)

        # Decompress head dimension data first
        if 'heads_fp16' in compressed_data['keys']:
            head_indices = compressed_data['keys']['heads_fp16']['indices']
            head_idx_tensor = torch.tensor(head_indices, device=device, dtype=torch.long)

            # Safe assignment
            head_data_k = compressed_data['keys']['heads_fp16']['data']
            head_data_v = compressed_data['values']['heads_fp16']['data']

            if head_data_k is not None and head_data_v is not None:
                for i, idx in enumerate(head_indices):
                    if idx < keys_full.shape[1]:
                        keys_full[:, idx, :, :] = head_data_k[:, i, :, :]
                        values_full[:, idx, :, :] = head_data_v[:, i, :, :]

            if self.config.enable_head_compression:
                n_heads = original_shape[1]
                other_head_indices = torch.tensor([h for h in range(n_heads) if h not in head_indices],
                                                  device=device, dtype=torch.long)
            else:
                other_head_indices = head_idx_tensor
        else:
            other_head_indices = torch.arange(original_shape[1], device=device, dtype=torch.long)

        # Decompress sequence dimension data
        for precision_key in [k for k in compressed_data['keys'].keys() if k.startswith('seq_')]:
            if 'data' not in compressed_data['keys'][precision_key]:
                continue

            indices = compressed_data['keys'][precision_key]['indices']
            if not indices:
                continue

            # All data stored as FP16 in this simplified version
            k_data = compressed_data['keys'][precision_key]['data']
            v_data = compressed_data['values'][precision_key]['data']

            if k_data is not None and v_data is not None:
                # k_data's head dim (dim 1) follows the order of other_head_indices,
                # so map each stored position j back to its original head index.
                for j, head_idx in enumerate(other_head_indices):
                    if head_idx < keys_full.shape[1]:
                        for i, seq_idx in enumerate(indices):
                            if seq_idx < keys_full.shape[2]:
                                keys_full[:, head_idx, seq_idx, :] = k_data[:, j, i, :]
                                values_full[:, head_idx, seq_idx, :] = v_data[:, j, i, :]

        return keys_full, values_full

    def _decompress_hybrid_sparse_attention(self, compressed_data: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decompress RocketKV-style hybrid sparse attention data."""
        metadata = compressed_data['metadata']
        original_shape = metadata['original_shape']

        # Get device from first available tensor
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        for head_key in compressed_data['keys'].keys():
            if head_key.startswith('head_'):
                device = compressed_data['keys'][head_key]['data'].device
                break

        # Initialize full tensors
        keys_full = torch.zeros(original_shape, dtype=torch.float16, device=device)
        values_full = torch.zeros(original_shape, dtype=torch.float16, device=device)

        # Reconstruct selected heads with their tokens
        for head_key in compressed_data['keys'].keys():
            if not head_key.startswith('head_'):
                continue

            head_idx = int(head_key.split('_')[1])
            head_data_k = compressed_data['keys'][head_key]
            head_data_v = compressed_data['values'][head_key]

            token_indices = head_data_k['indices']

            # Place data in the correct head and token positions
            if head_idx < keys_full.shape[1]:
                for i, token_idx in enumerate(token_indices):
                    if token_idx < keys_full.shape[2]:
                        keys_full[:, head_idx, token_idx, :] = head_data_k['data'][:, 0, i, :]
                        values_full[:, head_idx, token_idx, :] = head_data_v['data'][:, 0, i, :]

        return keys_full, values_full

    def _decompress_original_spg(self, compressed_data: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decompress original SPG data."""
        metadata = compressed_data['metadata']
        original_shape = metadata['original_shape']
        original_dtype = metadata['original_dtype']
        device = metadata['precision_scores'].device

        keys_full = torch.zeros(original_shape, dtype=original_dtype, device=device)
        values_full = torch.zeros(original_shape, dtype=original_dtype, device=device)

        for precision_key in compressed_data['keys']:
            data_dict = compressed_data['keys'][precision_key]
            if 'data' in data_dict and 'indices' in data_dict:
                indices = data_dict['indices']
                if not indices:
                    continue

                # All data stored as original precision
                k_data = data_dict['data']
                v_data = compressed_data['values'][precision_key]['data']

                if k_data is not None and v_data is not None:
                    for i, seq_idx in enumerate(indices):
                        if seq_idx < keys_full.shape[2]:
                            keys_full[:, :, seq_idx, :] = k_data[:, :, i, :]
                            values_full[:, :, seq_idx, :] = v_data[:, :, i, :]

        return keys_full, values_full

    def get_memory_footprint(self, compressed_data: Dict[str, Any]) -> int:
        """
        Calculate ACTUAL memory usage - NO ESTIMATES.
        Every byte is accounted for explicitly.
        """
        total_bytes = 0

        try:
            # Count all stored tensors
            for storage_type in ['keys', 'values']:
                for key, data in compressed_data[storage_type].items():
                    if isinstance(data, dict):
                        # Data tensors
                        if 'data' in data and isinstance(data['data'], torch.Tensor):
                            total_bytes += data['data'].nelement() * data['data'].element_size()

                        # Scale/zero tensors
                        if 'scale' in data and isinstance(data['scale'], torch.Tensor):
                            total_bytes += data['scale'].nelement() * data['scale'].element_size()
                        if 'zero' in data and isinstance(data['zero'], torch.Tensor):
                            total_bytes += data['zero'].nelement() * data['zero'].element_size()

                        # Levels tensor for bit-packed data
                        if 'levels' in data and isinstance(data['levels'], torch.Tensor):
                            total_bytes += data['levels'].nelement() * data['levels'].element_size()

                        # Metadata overhead (measured, not estimated)
                        if 'meta' in data and isinstance(data['meta'], dict):
                            total_bytes += self.constants.INT2_METADATA_BYTES

                        # Indices (count only once under keys to avoid double counting)
                        if storage_type == 'keys' and 'indices' in data and data['indices']:
                            total_bytes += len(data['indices']) * self.constants.INDEX_SIZE_BYTES

            # Metadata overhead
            total_bytes += self.constants.METADATA_OVERHEAD_BYTES

            logger.debug(f"Measured memory footprint: {total_bytes} bytes ({total_bytes/1024/1024:.2f} MB)")
            return total_bytes

        except Exception as e:
            logger.error(f"Error calculating memory footprint: {e}")
            raise

    def update_quality_feedback(self, layer_idx: int, quality_metric: float):
        """Update quality feedback for progressive compression."""
        self.quality_history.append(quality_metric)

        # Keep only recent history
        if len(self.quality_history) > self.constants.QUALITY_HISTORY_MAX_SIZE:
            self.quality_history = self.quality_history[-self.constants.QUALITY_HISTORY_MAX_SIZE:]


class QuantizedKVCache:
    """Enhanced quantized KV cache with working multi-stage SPG support."""

    def __init__(self, config: CompressionConfig):
        self.config = config
        self.compressed_data = {}
        self.dtypes = {}

        # Initialize enhanced SPG with RocketKV features
        if config.compression_type in [CompressionType.SPG, CompressionType.ADAPTIVE_SPG]:
            spg_config = replace(config.enhanced_spg_config,
                                 enable_two_stage=False,
                                 enable_adaptive=(config.compression_type == CompressionType.ADAPTIVE_SPG))
            self.spg = EnhancedSlidingPrecisionGradient(spg_config)
        elif config.compression_type in [CompressionType.ENHANCED_SPG, CompressionType.PROGRESSIVE_SPG]:
            enhanced_config = config.enhanced_spg_config
            if config.compression_type == CompressionType.PROGRESSIVE_SPG:
                enhanced_config.enable_progressive = True
            self.spg = EnhancedSlidingPrecisionGradient(enhanced_config)
        else:
            self.spg = None

        self.current_position = 0
        self.quality_history = []
        self.n_layers = None

    def compress_and_store(self, layer_idx: int, keys: torch.Tensor, values: torch.Tensor):
        """Compress and store KV pairs with enhanced SPG support."""
        key_dtype = keys.dtype
        value_dtype = values.dtype

        if self.config.compression_type in [CompressionType.SPG, CompressionType.ADAPTIVE_SPG,
                                            CompressionType.ENHANCED_SPG, CompressionType.PROGRESSIVE_SPG]:
            if self.spg.layer_decay_rates is None:
                if self.n_layers is None:
                    raise ValueError("Model layer count not set - call detect_model_layers first")
                self.spg.initialize_layer_decay_rates(self.n_layers)

            if self.config.compression_type in [CompressionType.ENHANCED_SPG, CompressionType.PROGRESSIVE_SPG]:
                compressed_data = self.spg.compress_with_enhanced_gradient(
                    keys, values, layer_idx, self.current_position
                )
            else:
                compressed_data = self.spg._fallback_to_original_spg(
                    keys, values, layer_idx, self.current_position
                )

            self.compressed_data[layer_idx] = compressed_data
            self.dtypes[layer_idx] = {'keys': key_dtype, 'values': value_dtype}
        else:
            # No compression - store original tensors
            self.compressed_data[layer_idx] = {
                'keys': {'original': {'data': keys.clone(), 'indices': list(range(keys.shape[2]))}},
                'values': {'original': {'data': values.clone(), 'indices': list(range(values.shape[2]))}},
                'metadata': {
                    'compression_type': 'none',
                    'original_shape': keys.shape,
                    'original_dtype': keys.dtype
                }
            }
            self.dtypes[layer_idx] = {'keys': key_dtype, 'values': value_dtype}

    def get_decompressed(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get decompressed KV pairs with enhanced SPG support."""
        if self.config.compression_type in [CompressionType.SPG, CompressionType.ADAPTIVE_SPG,
                                            CompressionType.ENHANCED_SPG, CompressionType.PROGRESSIVE_SPG]:
            if layer_idx in self.compressed_data:
                return self.spg.decompress(self.compressed_data[layer_idx])
            return None, None
        else:
            # No compression - return original tensors
            if layer_idx in self.compressed_data:
                data = self.compressed_data[layer_idx]
                return data['keys']['original']['data'], data['values']['original']['data']
            return None, None

    def get_memory_footprint(self) -> int:
        """Calculate actual memory usage with enhanced SPG support."""
        total_bytes = 0
        constants = ResearchConstants()

        if self.config.compression_type in [CompressionType.SPG, CompressionType.ADAPTIVE_SPG,
                                            CompressionType.ENHANCED_SPG, CompressionType.PROGRESSIVE_SPG]:
            for layer_idx in self.compressed_data:
                total_bytes += self.spg.get_memory_footprint(self.compressed_data[layer_idx])
        else:
            # No compression - calculate uncompressed memory
            for layer_idx in self.compressed_data:
                data = self.compressed_data[layer_idx]
                keys_data = data['keys']['original']['data']
                values_data = data['values']['original']['data']
                total_bytes += keys_data.nelement() * keys_data.element_size()
                total_bytes += values_data.nelement() * values_data.element_size()
                total_bytes += constants.METADATA_OVERHEAD_BYTES

        return total_bytes

    def update_position(self, new_position: int):
        """Update current generation position."""
        self.current_position = new_position

    def update_quality_feedback(self, layer_idx: int, quality_metric: float):
        """Provide quality feedback for adaptive methods."""
        if self.config.compression_type == CompressionType.ADAPTIVE_SPG and hasattr(self.spg, 'update_decay_rate'):
            target_quality = self.config.enhanced_spg_config.target_perplexity_delta
            self.spg.update_decay_rate(layer_idx, quality_metric, target_quality)
            self.quality_history.append((layer_idx, quality_metric))
        elif self.config.compression_type in [CompressionType.ENHANCED_SPG, CompressionType.PROGRESSIVE_SPG]:
            self.spg.update_quality_feedback(layer_idx, quality_metric)


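# End-to-end cache usage (editor's sketch; the CompressionConfig fields shown
# are assumptions about its interface, and model/per_layer_kv are placeholders):
#
#   config = CompressionConfig(compression_type=CompressionType.ENHANCED_SPG)
#   cache = QuantizedKVCache(config)
#   cache.n_layers = detect_model_layers(model)
#   for layer_idx, (k, v) in enumerate(per_layer_kv):   # [batch, heads, seq, dim]
#       cache.compress_and_store(layer_idx, k, v)
#   k0, v0 = cache.get_decompressed(0)
#   print(f"{cache.get_memory_footprint() / 1e6:.1f} MB held by the cache")
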
def detect_model_layers(model) -> int:
    """Detect the number of transformer layers with comprehensive validation."""
    config_attrs = [
        'num_hidden_layers',
        'n_layer',
        'num_layers',
        'n_layers',
        'decoder_layers',
        'n_head_layers',
    ]

    for attr in config_attrs:
        if hasattr(model.config, attr):
            n_layers = getattr(model.config, attr)
            if isinstance(n_layers, int) and n_layers > 0:
                logger.info(f"Detected {n_layers} layers from config.{attr}")
                return n_layers

    layer_patterns = [
        'layer', 'layers', 'h', 'blocks', 'decoder.layers', 'transformer_blocks', 'decoderlayer',
    ]

    for module_name, module in model.named_modules():
        name = module_name.lower()
        for pattern in layer_patterns:
            # Match whole path suffixes, not substrings: a bare substring test
            # (e.g. 'h') would also hit unrelated modules like 'lm_head'.
            if name == pattern or name.endswith('.' + pattern):
                if hasattr(module, '__len__'):
                    n_layers = len(module)
                    if n_layers > 0:
                        logger.info(f"Detected {n_layers} layers by counting {module_name}")
                        return n_layers

    decoder_layer_types = [
        'TransformerBlock', 'DecoderLayer', 'EncoderLayer', 'Block', 'Layer',
        'GPT2Block', 'LlamaDecoderLayer', 'MistralDecoderLayer', 'OPTDecoderLayer',
    ]

    layers = []
    for module in model.modules():
        module_type = type(module).__name__
        if any(layer_type in module_type for layer_type in decoder_layer_types):
            layers.append(module)

    if layers:
        n_layers = len(set(layers))
        if n_layers > 0:
            logger.info(f"Detected {n_layers} layers by module type matching")
            return n_layers

    # Fail fast if cannot detect layers
    raise ValueError(
        f"Could not automatically detect the number of layers for model {type(model).__name__}. "
        "Please check the model architecture and update the detection logic."
    )
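

# Wiring the detector into the cache (editor's sketch; `model` is any
# Hugging Face-style module with a .config, e.g. from AutoModelForCausalLM):
#
#   n_layers = detect_model_layers(model)   # raises ValueError if undetectable
#   cache = QuantizedKVCache(config)
#   cache.n_layers = n_layers               # required before compress_and_store()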