| """ |
| GPU Training Script for MARS on MovieLens-1M. |
| |
| Trains MARS (innovative method) and SASRec (baseline) for comparison. |
| Pushes results to HF Hub. |
| """ |
|
|
| import os |
| import sys |
| import time |
| import json |
| import random |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
|
|
| |
def set_seed(seed=42):
    """Seed every RNG in use (python, numpy, torch CPU and all CUDA devices).

    Args:
        seed: integer seed applied to all generators for reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # BUG FIX: the attribute is `total_memory`; `.total_mem` does not exist and
    # raised AttributeError on every CUDA machine.
    print(f"GPU Memory: {props.total_memory / 1e9:.1f} GB")
|
|
| |
# Hyper-parameters per model. 'mars' is the proposed dual-branch model and
# carries extra architecture knobs; 'sasrec' is the transformer baseline.
CONFIGS = {
    'mars': dict(
        embed_dim=64,
        max_seq_len=512,        # MARS consumes much longer histories
        short_term_len=50,
        num_memory_tokens=8,
        num_tadn_layers=3,
        num_attn_layers=2,
        num_heads=2,
        state_dim=64,
        dropout=0.1,
        batch_size=128,
        lr=1e-3,
        weight_decay=0.01,
        epochs=80,
        num_negatives=4,
    ),
    'sasrec': dict(
        embed_dim=64,
        max_seq_len=200,
        num_heads=2,
        num_layers=2,
        dropout=0.1,
        batch_size=256,
        lr=1e-3,
        weight_decay=0.0,
        epochs=80,
        num_negatives=4,
    ),
}
|
|
| |
| from model import MARS, SASRecBaseline |
| from data import load_movielens_1m, ReindexedData, create_dataloaders, save_data_config |
| from evaluate import evaluate_model, print_comparison |
|
|
| |
# Experiment tracking is optional: if trackio is missing or its init fails,
# training proceeds with logging disabled.
use_trackio = False
try:
    import trackio
    trackio.init(name="MARS-SeqRec-ML1M", project="mars-seqrec")
    use_trackio = True
    print("Trackio initialized successfully")
except Exception as exc:
    print(f"Trackio not available: {exc}")
|
|
| |
| print("\n" + "="*60) |
| print("Loading MovieLens-1M Dataset") |
| print("="*60) |
|
|
| sequences = load_movielens_1m(min_interactions=5) |
| print(f"Loaded {len(sequences)} user sequences") |
|
|
| |
| seq_lens = [len(v['item_ids']) for v in sequences.values()] |
| print(f"Sequence length: mean={np.mean(seq_lens):.1f}, median={np.median(seq_lens):.1f}, " |
| f"max={np.max(seq_lens)}, p90={np.percentile(seq_lens, 90):.0f}") |
| print(f"Users with 100+ interactions: {sum(1 for l in seq_lens if l >= 100)}") |
| print(f"Users with 200+ interactions: {sum(1 for l in seq_lens if l >= 200)}") |
| print(f"Users with 500+ interactions: {sum(1 for l in seq_lens if l >= 500)}") |
|
|
|
|
def train_model(model_name, config, sequences, device):
    """Train one model and return its test metrics.

    Args:
        model_name: 'mars' selects the MARS architecture; anything else
            builds the SASRec baseline.
        config: hyper-parameter dict (see CONFIGS for expected keys).
        sequences: raw user sequences as produced by load_movielens_1m.
        device: torch.device to train and evaluate on.

    Returns:
        Tuple (test_metrics, config, num_params) where test_metrics is the
        metric dict computed on the test split with the best-validation
        (HR@10) checkpoint restored.

    Side effects:
        Writes './checkpoints/<model_name>/best_model.pt' and logs to
        trackio when the global `use_trackio` flag is set.
    """
    print(f"\n{'='*60}")
    print(f"Training: {model_name.upper()}")
    print(f"{'='*60}")

    max_seq_len = config['max_seq_len']
    batch_size = config['batch_size']
    epochs = config['epochs']

    # Build data splits/loaders for this model's sequence-length budget.
    data = ReindexedData(sequences, max_seq_len=max_seq_len)
    train_loader, val_loader, test_loader = create_dataloaders(
        data, max_seq_len=max_seq_len,
        batch_size=batch_size,
        num_negatives=config['num_negatives'],
        num_workers=4,
    )

    if model_name == 'mars':
        model = MARS(
            num_items=data.num_items,
            embed_dim=config['embed_dim'],
            max_seq_len=max_seq_len,
            short_term_len=config['short_term_len'],
            num_memory_tokens=config['num_memory_tokens'],
            num_tadn_layers=config['num_tadn_layers'],
            num_attn_layers=config['num_attn_layers'],
            num_heads=config['num_heads'],
            state_dim=config['state_dim'],
            dropout=config['dropout'],
        )
    else:
        model = SASRecBaseline(
            num_items=data.num_items,
            embed_dim=config['embed_dim'],
            max_seq_len=config['max_seq_len'],
            num_heads=config['num_heads'],
            num_layers=config['num_layers'],
            dropout=config['dropout'],
        )

    model = model.to(device)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Parameters: {num_params:,}")

    optimizer = AdamW(
        model.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay'],
    )
    # Cosine decay over the full run, bottoming out at 1% of the base LR.
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=config['lr'] * 0.01)

    # BUG FIX: initialize to -1.0 (was 0) so the first validation pass always
    # snapshots `best_state`. Previously, a run whose HR@10 stayed at exactly
    # 0.0 never set best_state and crashed at load_state_dict(None) below.
    best_val_hr10 = -1.0
    best_epoch = 0
    best_state = None

    save_dir = f'./checkpoints/{model_name}'
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0
        num_batches = 0
        t0 = time.time()

        for batch_idx, batch in enumerate(train_loader):
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()
            # Model's forward on a batch dict returns the training loss.
            loss = model(batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        scheduler.step()
        # max(..., 1) guards against a pathological empty train loader.
        avg_loss = total_loss / max(num_batches, 1)
        epoch_time = time.time() - t0
        lr = scheduler.get_last_lr()[0]

        print(f"Epoch {epoch:3d}/{epochs} | Loss: {avg_loss:.4f} | "
              f"LR: {lr:.6f} | Time: {epoch_time:.1f}s")

        if use_trackio:
            trackio.log({
                f"{model_name}/train_loss": avg_loss,
                f"{model_name}/lr": lr,
                f"{model_name}/epoch_time": epoch_time,
                "epoch": epoch,
            })

        # Validate on the first 5 epochs, every 10th epoch, and the last one.
        if epoch % 10 == 0 or epoch == epochs or epoch <= 5:
            metrics = evaluate_model(
                model, val_loader, data.num_items, device,
                ks=[5, 10, 20, 50]
            )

            print(f" Val | HR@10={metrics['HR@10']:.4f} NDCG@10={metrics['NDCG@10']:.4f} "
                  f"MRR@10={metrics['MRR@10']:.4f}")

            if use_trackio:
                trackio.log({f"{model_name}/val_{k}": v for k, v in metrics.items() if k != 'eval_time'})
                trackio.log({"epoch": epoch})

            hr10 = metrics['HR@10']
            if hr10 > best_val_hr10:
                best_val_hr10 = hr10
                best_epoch = epoch
                # Snapshot on CPU so the copy doesn't hold GPU memory.
                best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
                print(f" β New best! HR@10={hr10:.4f}")

    # Defensive fallback (e.g. epochs == 0): keep current weights rather than
    # crashing on load_state_dict(None).
    if best_state is None:
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
    model.load_state_dict(best_state)
    model = model.to(device)

    test_metrics = evaluate_model(
        model, test_loader, data.num_items, device,
        ks=[5, 10, 20, 50]
    )

    print(f"\nFinal Test Results ({model_name.upper()}, best epoch {best_epoch}):")
    for k, v in test_metrics.items():
        if k != 'eval_time':
            print(f" {k}: {v:.4f}")

    if use_trackio:
        trackio.log({f"{model_name}/test_{k}": v for k, v in test_metrics.items() if k != 'eval_time'})

    # Persist the best checkpoint together with everything needed to reload it.
    torch.save({
        'model_state_dict': best_state,
        'config': config,
        'test_metrics': test_metrics,
        'best_epoch': best_epoch,
        'num_items': data.num_items,
    }, os.path.join(save_dir, 'best_model.pt'))

    return test_metrics, config, num_params
|
|
|
|
| |
| print("\n" + "="*60) |
| print("PHASE 1: Training SASRec Baseline") |
| print("="*60) |
|
|
| sasrec_metrics, sasrec_config, sasrec_params = train_model( |
| 'sasrec', CONFIGS['sasrec'], sequences, device |
| ) |
|
|
| print("\n" + "="*60) |
| print("PHASE 2: Training MARS (our method)") |
| print("="*60) |
|
|
| mars_metrics, mars_config, mars_params = train_model( |
| 'mars', CONFIGS['mars'], sequences, device |
| ) |
|
|
| |
| print_comparison(mars_metrics, sasrec_metrics, ks=[5, 10, 20, 50]) |
|
|
| |
def _model_entry(metrics, config, params):
    """Bundle one model's metrics/config/param-count for the JSON summary."""
    return {'metrics': metrics, 'config': config, 'params': params}


# Aggregate everything into one JSON artifact next to the checkpoints.
# default=str stringifies any non-JSON-native values (e.g. numpy scalars).
final_results = {
    'mars': _model_entry(mars_metrics, mars_config, mars_params),
    'sasrec': _model_entry(sasrec_metrics, sasrec_config, sasrec_params),
    'dataset': 'MovieLens-1M',
    'num_users': len(sequences),
}

with open('./checkpoints/final_results.json', 'w') as fh:
    json.dump(final_results, fh, indent=2, default=str)

print("\nβ All training complete! Results saved to ./checkpoints/")
|
|
| |
# Best-effort push of checkpoints + README + sources to the Hugging Face Hub.
# Any failure (no credentials, no network, missing package) is caught below so
# a failed upload never discards the training run.
try:
    from huggingface_hub import HfApi, upload_folder
    # Target repo is overridable via the HF_HUB_MODEL_ID env var.
    hub_model_id = os.environ.get('HF_HUB_MODEL_ID', 'CyberDancer/MARS-SeqRec')
    api = HfApi()
    api.create_repo(hub_model_id, exist_ok=True)

    # Model card: interpolates final metrics/params of both models into a
    # README written to the checkpoint folder (uploaded below).
    # NOTE(review): the non-ASCII 'β'/'Ξ²'/'Ο' characters look like mojibake of
    # arrows/greek letters — confirm the intended glyphs before re-encoding.
    readme = f"""# MARS: Multi-scale Adaptive Recurrence with State compression

An innovative method for **super long sequence modeling** in sequential recommendation.

## Key Innovations

1. **Temporal-Aware Delta Network (TADN)** β O(n) linear complexity with explicit temporal decay gating
2. **Compressive Memory Tokens** β Fixed-size learnable memory as information bottleneck
3. **Dual-Branch Architecture** β Long-term (TADN) + Short-term (Self-Attention) with adaptive fusion
4. **Multi-Scale Temporal Encoding** β Captures daily/weekly/seasonal patterns

## Architecture

```
Input: Full user interaction sequence + timestamps
|
v
[Item Embedding + Multi-Scale Temporal Encoding]
|
+---- Long-term Branch (TADN layers, O(n) complexity)
| |
| [Compressive Memory] β fixed-size memory tokens
| |
+---- Short-term Branch (Causal Self-Attention, recent K items)
|
v
[Adaptive Fusion Gate (per-user learned)]
|
v
[Prediction Head] β next item scores
```

## Results on MovieLens-1M

| Model | Params | Max Seq | HR@10 | NDCG@10 | MRR@10 |
|-------|--------|---------|-------|---------|--------|
| SASRec (baseline) | {sasrec_params:,} | {sasrec_config['max_seq_len']} | {sasrec_metrics.get('HR@10', 0):.4f} | {sasrec_metrics.get('NDCG@10', 0):.4f} | {sasrec_metrics.get('MRR@10', 0):.4f} |
| **MARS (ours)** | {mars_params:,} | {mars_config['max_seq_len']} | {mars_metrics.get('HR@10', 0):.4f} | {mars_metrics.get('NDCG@10', 0):.4f} | {mars_metrics.get('MRR@10', 0):.4f} |

## Method Details

### Temporal-Aware Delta Network (TADN)
The TADN layer maintains a state matrix S that is updated with a delta rule incorporating temporal decay:

```
S_t = S_{{t-1}} * (1 - g_t β Ξ²_t β k_t) + Ξ²_t β v_t β k_t
g_t = Ξ± * Ο(W_g * [h_t; Ξh_t]) * Ο_t + (1-Ξ±) * g_static
Ο_t = exp(-(t_current - t_behavior) / T)
```

This gives O(n) complexity for processing arbitrarily long sequences, while explicitly modeling temporal decay patterns.

### Compressive Memory
Cross-attention memory queries compress the full encoded history into a fixed number of tokens, acting as an information bottleneck that denoises the sequence.

### Adaptive Fusion
A learned gate balances long-term (TADN) and short-term (attention) signals:
```
output = Ο(gate(long, short, memory)) * long + (1 - Ο) * short
```

## References

Based on ideas from:
- HyTRec (2602.18283) β Hybrid Temporal-Aware Dual-Branch Attention
- Rec2PM (2602.11605) β Recurrent Preference Memory
- SIGMA (2408.11451) β Bidirectional Selective Gated Mamba
- HSTU (2402.17152) β Generative Recommenders
- SASRec (1808.09781) β Self-Attentive Sequential Recommendation
"""

    with open('./checkpoints/README.md', 'w') as f:
        f.write(readme)

    # Mirror the project source files into the checkpoint folder so the Hub
    # repo is self-contained; only files that actually exist are copied.
    import shutil
    for fname in ['model.py', 'data.py', 'evaluate.py', 'train.py', 'train_gpu.py']:
        if os.path.exists(fname):
            shutil.copy(fname, f'./checkpoints/{fname}')

    # Single-commit upload of the whole checkpoint folder.
    upload_folder(
        folder_path='./checkpoints',
        repo_id=hub_model_id,
        commit_message="MARS: Multi-scale Adaptive Recurrence with State compression"
    )
    print(f"\nβ Pushed to https://huggingface.co/{hub_model_id}")

except Exception as e:
    # Results are already saved locally, so only report the failure.
    print(f"Hub push failed: {e}")

print("\nDone!")
|
|