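"""Benchmark the yamoe MoE expert kernel against its bundled reference.

For several model shapes and batch sizes the script measures mean latency and
peak CUDA memory of both implementations, plots the comparison, saves it to
``moe_performance_comparison.png`` and prints a per-configuration speedup
summary.
"""
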
import gc
import sys
import time
from pathlib import Path

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import torch
from kernels import get_kernel, get_local_kernel
from torch.nn import functional as F

# Load the yamoe kernel package, pinned to a fixed revision so that benchmark
# numbers stay comparable between runs.
yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")
# Baseline implementation exposed by the kernel package; used below as the
# comparison target (``reference.GptOssExperts``).
reference = yamoe.reference

# Fix the RNG seed so the randomly generated inputs and weights are
# reproducible from run to run.
torch.manual_seed(0)

# Model shapes to benchmark. Each entry provides:
#   seq_len     - tokens per sequence
#   hidden_dim  - width of the hidden states (also used to size the expert weights)
#   num_experts - number of experts
#   top_k       - experts selected per token by the router
configs = [
    {"seq_len": 512, "hidden_dim": 2880, "num_experts": 32, "top_k": 4},
    {"seq_len": 1024, "hidden_dim": 2880, "num_experts": 32, "top_k": 4},
    {"seq_len": 512, "hidden_dim": 1024, "num_experts": 32, "top_k": 4},
    {"seq_len": 512, "hidden_dim": 2880, "num_experts": 16, "top_k": 2},
    {"seq_len": 2048, "hidden_dim": 1024, "num_experts": 16, "top_k": 2},
    {"seq_len": 768, "hidden_dim": 2048, "num_experts": 64, "top_k": 8},
]

# Batch sizes swept for every configuration, and the accumulator that receives
# one result dict per configuration.
batch_sizes = [1, 2, 4, 8, 16, 32, 64]
all_results = []

def _is_oom(err):
    """Return True when *err* reports an out-of-memory condition."""
    return "out of memory" in str(err).lower()


def _measure(fn, warmup=5, iters=10):
    """Time ``fn`` on the GPU and record peak memory.

    Runs ``warmup`` untimed calls, resets the CUDA peak-memory counter, then
    times ``iters`` calls (synchronising after each one).

    Returns:
        ``(mean latency in ms, peak allocated memory in MiB)``. The peak is
        the absolute allocator high-water mark during the timed calls, so it
        includes tensors that were already resident (inputs, weights).
    """
    # Results are deliberately not bound to a name: a retained output would
    # stay resident and inflate the peak-memory reading of later calls.
    for _ in range(warmup):
        fn()

    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    samples_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        torch.cuda.synchronize()
        samples_ms.append((time.perf_counter() - start) * 1e3)

    mean_ms = sum(samples_ms) / len(samples_ms)
    peak_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
    return mean_ms, peak_mb


def _make_inputs(batch_size, seq_len, hidden_dim, num_experts, top_k):
    """Build random routing data, activations and expert weights (fp16, CUDA).

    Tensors are drawn on the CPU in a fixed order (router logits, hidden
    states, gate/up weights, down weights) and then moved to the GPU.
    """
    logits = torch.randn(batch_size, seq_len, num_experts)

    # Top-k routing: softmax over the selected logits, scattered into a dense
    # (tokens, num_experts) matrix that is zero for unselected experts.
    weights, indices = torch.topk(logits, top_k, dim=-1)
    weights = F.softmax(weights, dim=-1)
    router_indices = indices.reshape(-1, top_k)
    routing_weights = torch.zeros(
        batch_size * seq_len, num_experts, device=logits.device, dtype=weights.dtype
    )
    routing_weights.scatter_(1, router_indices, weights.reshape(-1, top_k))

    hidden_states = torch.randn(batch_size, seq_len, hidden_dim).cuda().half()
    gate_up_proj = torch.randn(num_experts, hidden_dim, 2 * hidden_dim).cuda().half()
    gate_up_proj_bias = torch.ones(num_experts, 2 * hidden_dim).cuda().half()
    down_proj = torch.randn(num_experts, hidden_dim, hidden_dim).cuda().half()
    down_proj_bias = torch.ones(num_experts, hidden_dim).cuda().half()

    return {
        "hidden_states": hidden_states,
        "router_indices": router_indices.cuda(),
        "routing_weights": routing_weights.cuda().half(),
        "gate_up_proj": gate_up_proj,
        "gate_up_proj_bias": gate_up_proj_bias,
        "down_proj": down_proj,
        "down_proj_bias": down_proj_bias,
    }


def _bench_one(batch_size, seq_len, hidden_dim, num_experts, top_k):
    """Benchmark both implementations for a single shape.

    Returns:
        ``(yamoe_time, yamoe_mem, ref_time, ref_mem)``; the pair belonging to
        an implementation that ran out of memory is ``(None, None)``.

    Raises:
        RuntimeError: any runtime failure that is not an out-of-memory error.

    All tensors and the reference model are locals of this function, so they
    are released when it returns and cannot leak into the next measurement.
    """
    t = _make_inputs(batch_size, seq_len, hidden_dim, num_experts, top_k)

    yamoe_time = yamoe_mem = None
    try:
        yamoe_time, yamoe_mem = _measure(
            lambda: yamoe.experts(
                t["hidden_states"].view(-1, hidden_dim),
                t["router_indices"],
                t["routing_weights"].view(-1, num_experts),
                t["gate_up_proj"],
                t["gate_up_proj_bias"],
                t["down_proj"],
                t["down_proj_bias"],
                seq_len,
                num_experts,
                top_k,
            )
        )
    except RuntimeError as err:
        if not _is_oom(err):
            raise
        print(" Yamoe: OOM - skipping this batch size")

    ref_time = ref_mem = None
    try:
        # Minimal stand-in for the config object GptOssExperts expects.
        config_obj = type("Config", (), {})()
        config_obj.hidden_size = hidden_dim
        # NOTE(review): the weights assigned below have an expert dimension of
        # hidden_dim, not 4 * hidden_dim; this value appears to size only the
        # placeholder parameters that get replaced - confirm against
        # reference.GptOssExperts.
        config_obj.intermediate_size = 4 * hidden_dim
        config_obj.num_local_experts = num_experts

        # Share the exact weight tensors with the yamoe run.
        model = reference.GptOssExperts(config_obj)
        model.gate_up_proj.data = t["gate_up_proj"]
        model.gate_up_proj_bias.data = t["gate_up_proj_bias"]
        model.down_proj.data = t["down_proj"]
        model.down_proj_bias.data = t["down_proj_bias"]
        model = model.cuda().half()
        model.eval()

        with torch.no_grad():
            ref_time, ref_mem = _measure(
                lambda: model(
                    t["hidden_states"], t["router_indices"], t["routing_weights"]
                )
            )
    except RuntimeError as err:
        if not _is_oom(err):
            raise
        print(" Reference: OOM - skipping this batch size")

    return yamoe_time, yamoe_mem, ref_time, ref_mem


def _report(batch_size, seq_len, hidden_dim, yamoe_time, yamoe_mem, ref_time, ref_mem):
    """Print the per-batch-size result lines for whichever runs succeeded."""
    # Throughput is hidden-state elements processed per second (in units of
    # 1e9), not floating-point operations, hence the "Gelem/s" label.
    elements = batch_size * seq_len * hidden_dim
    yamoe_line = ref_line = None
    if yamoe_time is not None:
        throughput_yamoe = elements / (yamoe_time / 1000) / 1e9
        yamoe_line = (
            f" Yamoe: {yamoe_time:.3f} ms / {yamoe_mem:.1f} MB / "
            f"{throughput_yamoe:.2f} Gelem/s"
        )
    if ref_time is not None:
        throughput_ref = elements / (ref_time / 1000) / 1e9
        ref_line = (
            f" Reference: {ref_time:.3f} ms / {ref_mem:.1f} MB / "
            f"{throughput_ref:.2f} Gelem/s"
        )

    if yamoe_line is not None and ref_line is not None:
        print(yamoe_line)
        print(ref_line)
        print(
            f" Speedup: {ref_time / yamoe_time:.2f}x, "
            f"Memory reduction: {ref_mem / yamoe_mem:.2f}x, "
            f"Efficiency gain: {throughput_yamoe / throughput_ref:.2f}x"
        )
    elif yamoe_line is not None:
        print(yamoe_line)
        print(" Reference: OOM - unable to measure")
        print(" Yamoe runs successfully while Reference OOMs")
    elif ref_line is not None:
        print(" Yamoe: OOM - unable to measure")
        print(ref_line)
        print(" Reference runs successfully while Yamoe OOMs")
    else:
        print(f" Both implementations OOM at batch_size={batch_size}")


# Sweep every configuration over every batch size. Each per-config list ends
# up with exactly one entry per batch size (None where no measurement exists),
# so the lists stay index-aligned with ``batch_sizes``.
for config_idx, config in enumerate(configs):
    seq_len = config["seq_len"]
    hidden_dim = config["hidden_dim"]
    num_experts = config["num_experts"]
    top_k = config["top_k"]

    print(f"\n{'=' * 70}")
    print(
        f"Config {config_idx + 1}: seq={seq_len}, hidden={hidden_dim}, "
        f"experts={num_experts}, top_k={top_k}"
    )
    print(f"{'=' * 70}")

    yamoe_times = []
    reference_times = []
    yamoe_memory = []
    reference_memory = []
    speedups = []

    for batch_size in batch_sizes:
        print(f"\nBatch size = {batch_size}")

        # Default row if anything unexpected goes wrong for this batch size.
        row = (None, None, None, None, None)
        try:
            yamoe_time, yamoe_mem, ref_time, ref_mem = _bench_one(
                batch_size, seq_len, hidden_dim, num_experts, top_k
            )
            speedup = None
            if yamoe_time is not None and ref_time is not None:
                speedup = ref_time / yamoe_time
            _report(
                batch_size, seq_len, hidden_dim,
                yamoe_time, yamoe_mem, ref_time, ref_mem,
            )
            row = (yamoe_time, yamoe_mem, ref_time, ref_mem, speedup)
        except Exception as e:
            # Best effort: record the failure and continue with the sweep.
            print(f" Unexpected error at batch_size={batch_size}: {e}")

        # Append exactly once per batch size, regardless of where a failure
        # happened, so no list can get ahead of the others.
        yamoe_times.append(row[0])
        yamoe_memory.append(row[1])
        reference_times.append(row[2])
        reference_memory.append(row[3])
        speedups.append(row[4])

        # Drop anything still reachable only through reference cycles, then
        # hand cached blocks back to the driver before the next, larger run.
        gc.collect()
        torch.cuda.empty_cache()

    all_results.append(
        {
            "config": config,
            "yamoe_times": yamoe_times,
            "reference_times": reference_times,
            "yamoe_memory": yamoe_memory,
            "reference_memory": reference_memory,
            "speedups": speedups,
        }
    )

def _none_to_zero(values):
    """Replace missing measurements with 0 so they can be drawn as bars."""
    return [v if v is not None else 0 for v in values]


def _annotate(ax, x_pos, y_pos, text, color):
    """Write a small bold label centred above ``(x_pos, y_pos)``."""
    ax.text(
        x_pos,
        y_pos,
        text,
        ha="center",
        va="bottom",
        fontsize=7,
        fontweight="bold",
        color=color,
    )


def _grouped_bars(ax, x, width, yamoe_vals, ref_vals, yamoe_color, ref_color):
    """Draw side-by-side Yamoe/Reference bars for each batch size."""
    ax.bar(
        x - width / 2,
        _none_to_zero(yamoe_vals),
        width,
        label="Yamoe",
        color=yamoe_color,
        alpha=0.8,
    )
    ax.bar(
        x + width / 2,
        _none_to_zero(ref_vals),
        width,
        label="Reference",
        color=ref_color,
        alpha=0.8,
    )


def _finish_bar_axis(ax, x, tick_labels, ylabel, title, show_legend):
    """Apply the shared log-scale styling used by the time and memory panels."""
    ax.set_ylabel(ylabel, fontsize=9)
    ax.set_yscale("log")
    ax.set_xticks(x)
    ax.set_xticklabels(tick_labels, fontsize=8)
    ax.grid(True, alpha=0.3, axis="y")
    ax.set_title(title, fontsize=8, fontweight="bold")
    if show_legend:
        ax.legend(loc="upper left", fontsize=8)


def _ratio_series(sizes, ref_vals, yamoe_vals):
    """Return (batch sizes, reference/yamoe ratios) where both values exist."""
    kept_sizes = []
    ratios = []
    for size, r, y in zip(sizes, ref_vals, yamoe_vals):
        if r is not None and y is not None:
            kept_sizes.append(size)
            ratios.append(r / y)
    return kept_sizes, ratios


# One column per configuration, three rows: latency, memory, improvement.
# The grid is sized from the results so no configuration is silently dropped;
# with six configurations this is the 3x6 grid on a 24x16 inch canvas.
n_cols = max(len(all_results), 1)
plt.figure(figsize=(4 * n_cols, 16))

x = np.arange(len(batch_sizes))
width = 0.35

for config_idx, result in enumerate(all_results):
    cfg = result["config"]
    first_column = config_idx == 0  # legends are only drawn once per row

    # --- Row 1: latency ---------------------------------------------------
    ax1 = plt.subplot(3, n_cols, config_idx + 1)
    _grouped_bars(
        ax1, x, width,
        result["yamoe_times"], result["reference_times"],
        "#1f77b4", "#ff7f0e",
    )
    for i, (y_time, r_time) in enumerate(
        zip(result["yamoe_times"], result["reference_times"])
    ):
        if y_time is not None and r_time is not None:
            _annotate(ax1, i, max(y_time, r_time) * 1.05, f"{r_time / y_time:.1f}x", "green")
        elif y_time is not None:
            # Only yamoe produced a measurement.
            _annotate(ax1, i, y_time * 1.05, "Y-OK", "blue")
        elif r_time is not None:
            # Only the reference produced a measurement.
            _annotate(ax1, i, r_time * 1.05, "R-OK", "orange")
        else:
            # No bar to sit on; place the marker at a small positive height
            # so it is valid on the log axis.
            _annotate(ax1, i, 0.1, "OOM", "red")
    _finish_bar_axis(
        ax1, x, batch_sizes, "Time (ms)",
        f"Time: seq={cfg['seq_len']}, h={cfg['hidden_dim']}, e={cfg['num_experts']}",
        first_column,
    )

    # --- Row 2: peak memory -----------------------------------------------
    ax2 = plt.subplot(3, n_cols, config_idx + 1 + n_cols)
    _grouped_bars(
        ax2, x, width,
        result["yamoe_memory"], result["reference_memory"],
        "#2ca02c", "#d62728",
    )
    for i, (y_mem, r_mem) in enumerate(
        zip(result["yamoe_memory"], result["reference_memory"])
    ):
        if y_mem is not None and r_mem is not None:
            _annotate(ax2, i, max(y_mem, r_mem) * 1.05, f"{r_mem / y_mem:.1f}x", "purple")
    _finish_bar_axis(
        ax2, x, batch_sizes, "Memory (MB)",
        f"Memory: seq={cfg['seq_len']}, h={cfg['hidden_dim']}, e={cfg['num_experts']}",
        first_column,
    )

    # --- Row 3: improvement factors (reference / yamoe) ---------------------
    ax3 = plt.subplot(3, n_cols, config_idx + 1 + 2 * n_cols)
    speedup_sizes, speedup_vals = _ratio_series(
        batch_sizes, result["reference_times"], result["yamoe_times"]
    )
    mem_sizes, mem_vals = _ratio_series(
        batch_sizes, result["reference_memory"], result["yamoe_memory"]
    )
    if speedup_vals:
        ax3.plot(
            speedup_sizes,
            speedup_vals,
            "o-",
            label="Time Speedup",
            color="green",
            linewidth=2,
            markersize=6,
        )
    if mem_vals:
        ax3.plot(
            mem_sizes,
            mem_vals,
            "s-",
            label="Memory Reduction",
            color="purple",
            linewidth=2,
            markersize=6,
        )

    ax3.set_xlabel("Batch Size", fontsize=9)
    ax3.set_ylabel("Improvement Factor", fontsize=9)
    ax3.set_xticks(batch_sizes)
    ax3.grid(True, alpha=0.3)
    # Parity line: points above it mean yamoe is better than the reference.
    ax3.axhline(y=1, color="gray", linestyle="--", alpha=0.5)
    ax3.set_title(
        f"Improvements: seq={cfg['seq_len']}, h={cfg['hidden_dim']}",
        fontsize=8,
        fontweight="bold",
    )
    if first_column:
        ax3.legend(loc="upper left", fontsize=8)

plt.suptitle(
    "MoE Performance & Memory Comparison - Yamoe vs Reference",
    fontsize=16,
    fontweight="bold",
    y=0.98,
)
plt.tight_layout()
plt.savefig("moe_performance_comparison.png", dpi=150, bbox_inches="tight")
plt.show()

# Textual summary: average / best / worst speedup for every configuration.
print("\n" + "=" * 80)
print("DETAILED SUMMARY")
print("=" * 80)

for idx, result in enumerate(all_results, start=1):
    cfg = result["config"]
    print(f"\nConfiguration {idx}:")
    print(
        f" Parameters: seq_len={cfg['seq_len']}, hidden_dim={cfg['hidden_dim']}, "
        f"experts={cfg['num_experts']}, top_k={cfg['top_k']}"
    )

    # Pair each available speedup with the batch size it was measured at;
    # entries are None where either implementation produced no measurement.
    measured = [
        (speedup, size)
        for speedup, size in zip(result["speedups"], batch_sizes)
        if speedup is not None
    ]
    if not measured:
        print(" No valid speedup measurements (all OOM)")
        continue

    values = [speedup for speedup, _ in measured]
    print(f" Average Speedup: {sum(values) / len(values):.2f}x")

    # max()/min() return the first extreme element, i.e. the smallest batch
    # size on ties.
    max_speedup, max_size = max(measured, key=lambda pair: pair[0])
    min_speedup, min_size = min(measured, key=lambda pair: pair[0])
    print(f" Max Speedup: {max_speedup:.2f}x at batch_size={max_size}")
    print(f" Min Speedup: {min_speedup:.2f}x at batch_size={min_size}")