# MARS-SeqRec / evaluate.py
# MARS: Multi-scale Adaptive Recurrence with State compression
"""
Evaluation module for sequential recommendation models.
Metrics:
- HR@K (Hit Rate): Whether the ground truth item appears in top-K
- NDCG@K (Normalized Discounted Cumulative Gain): Position-aware ranking quality
- MRR@K (Mean Reciprocal Rank): Reciprocal of the rank of the first correct item
"""
import time
from typing import Dict, List, Optional

import numpy as np
import torch
@torch.no_grad()
def evaluate_model(
    model,
    eval_loader,
    num_items: int,
    device: torch.device,
    ks: Optional[List[int]] = None,
    full_ranking: bool = False,
) -> Dict[str, float]:
    """
    Evaluate a sequential recommendation model.

    Uses sampled metrics by default (positive + negatives from batch).
    Set full_ranking=True to rank against all items (slower but accurate).

    Args:
        model: trained model; must expose ``item_embeddings`` and map a batch
            dict to user embeddings of shape (B, D)
        eval_loader: iterable of batch dicts with 'positive_ids' (B,) and,
            for sampled evaluation, 'negative_ids' (B, num_neg)
        num_items: total number of items (ids are 1-indexed; 0 is padding)
        device: torch device
        ks: list of K cutoffs for the metrics; defaults to [5, 10, 20, 50]
            (default is created per call to avoid a shared mutable default)
        full_ranking: if True, rank the positive against all items

    Returns:
        dict of metrics: HR@K, NDCG@K, MRR@K for each K, plus 'eval_time'
        (wall-clock seconds spent evaluating)
    """
    if ks is None:
        ks = [5, 10, 20, 50]
    model.eval()
    all_hrs = {k: [] for k in ks}
    all_ndcgs = {k: [] for k in ks}
    all_mrrs = {k: [] for k in ks}
    start_time = time.time()
    for batch in eval_loader:
        batch_device = {k: v.to(device) for k, v in batch.items()}
        # Get user embeddings
        user_emb = model(batch_device)  # (B, D)
        pos_ids = batch_device['positive_ids']  # (B,)
        if full_ranking:
            # Rank against ALL items. Vectorized: gather the ground-truth
            # score per row, then count strictly-greater scores.
            all_item_embs = model.item_embeddings.weight[1:]  # Skip padding (0)
            scores = torch.matmul(user_emb, all_item_embs.t())  # (B, num_items)
            # Item ids are 1-indexed; column j holds item j+1.
            gt_scores = scores.gather(1, (pos_ids - 1).unsqueeze(1))  # (B, 1)
            ranks = (scores > gt_scores).sum(dim=1) + 1  # (B,), 1-indexed rank
        else:
            # Sampled ranking: positive + sampled negatives only.
            neg_ids = batch_device['negative_ids']  # (B, num_neg)
            pos_emb = model.item_embeddings(pos_ids)  # (B, D)
            neg_emb = model.item_embeddings(neg_ids)  # (B, num_neg, D)
            pos_scores = (user_emb * pos_emb).sum(dim=-1, keepdim=True)  # (B, 1)
            neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb)  # (B, num_neg)
            all_scores = torch.cat([pos_scores, neg_scores], dim=1)  # (B, 1+num_neg)
            # Rank of the positive item (1-indexed; ties rank optimistically)
            ranks = (all_scores > pos_scores).sum(dim=1) + 1  # (B,)
        # Shared, fully vectorized metric accumulation for both branches.
        ranks_f = ranks.float()
        zeros = torch.zeros_like(ranks_f)  # hoisted: reused for every k
        for k in ks:
            in_top_k = ranks <= k
            hits = in_top_k.float()
            ndcgs = torch.where(in_top_k, 1.0 / torch.log2(ranks_f + 1), zeros)
            mrrs = torch.where(in_top_k, 1.0 / ranks_f, zeros)
            all_hrs[k].extend(hits.cpu().tolist())
            all_ndcgs[k].extend(ndcgs.cpu().tolist())
            all_mrrs[k].extend(mrrs.cpu().tolist())
    eval_time = time.time() - start_time
    metrics = {}
    for k in ks:
        metrics[f'HR@{k}'] = np.mean(all_hrs[k])
        metrics[f'NDCG@{k}'] = np.mean(all_ndcgs[k])
        metrics[f'MRR@{k}'] = np.mean(all_mrrs[k])
    metrics['eval_time'] = eval_time
    return metrics
@torch.no_grad()
def compute_metrics_full(
model,
eval_data,
num_items: int,
device: torch.device,
max_seq_len: int = 512,
ks: List[int] = [5, 10, 20, 50],
batch_size: int = 256,
) -> Dict[str, float]:
"""
Full-ranking evaluation (all items).
More accurate but slower.
"""
model.eval()
# Get all item embeddings
all_item_embs = model.item_embeddings.weight[1:].to(device) # (num_items, D)
all_hrs = {k: [] for k in ks}
all_ndcgs = {k: [] for k in ks}
for i in range(0, len(eval_data), batch_size):
batch_data = eval_data[i:i+batch_size]
# Prepare batch
max_len = min(max(len(d['item_ids']) for d in batch_data), max_seq_len)
item_ids_batch = []
mask_batch = []
gt_items = []
for d in batch_data:
ids = d['item_ids'][-max_len:]
pad_len = max_len - len(ids)
item_ids_batch.append([0] * pad_len + ids)
mask_batch.append([False] * pad_len + [True] * len(ids))
gt_items.append(d['next_item'])
batch = {
'item_ids': torch.tensor(item_ids_batch, dtype=torch.long, device=device),
'mask': torch.tensor(mask_batch, dtype=torch.bool, device=device),
}
# Encode
user_emb = model(batch) # (B, D)
# Score all items
scores = torch.matmul(user_emb, all_item_embs.t()) # (B, num_items)
# Compute metrics
for j, gt in enumerate(gt_items):
gt_idx = gt - 1 # 0-indexed
gt_score = scores[j, gt_idx]
rank = (scores[j] > gt_score).sum().item() + 1
for k in ks:
all_hrs[k].append(1.0 if rank <= k else 0.0)
all_ndcgs[k].append(1.0 / np.log2(rank + 1) if rank <= k else 0.0)
metrics = {}
for k in ks:
metrics[f'HR@{k}'] = np.mean(all_hrs[k])
metrics[f'NDCG@{k}'] = np.mean(all_ndcgs[k])
return metrics
def print_comparison(mars_results: Dict, sasrec_results: Dict, ks: List[int] = [5, 10, 20]):
    """Pretty-print a side-by-side metric comparison of MARS vs. SASRec.

    Missing metrics default to 0; Δ% is reported as 0 when the SASRec
    baseline is not positive.
    """
    rule = '=' * 70
    print(f"\n{rule}")
    print(f"{'Metric':<15} | {'MARS':>10} | {'SASRec':>10} | {'Δ':>10} | {'Δ%':>10}")
    print(f"{'-' * 70}")
    # HR, NDCG, MRR for each cutoff, in that order.
    for name in (f'{m}@{k}' for k in ks for m in ('HR', 'NDCG', 'MRR')):
        ours = mars_results.get(name, 0)
        base = sasrec_results.get(name, 0)
        diff = ours - base
        pct = diff / base * 100 if base > 0 else 0
        if diff > 0:
            arrow = '↑'
        elif diff < 0:
            arrow = '↓'
        else:
            arrow = '='
        print(f"{name:<15} | {ours:>10.4f} | {base:>10.4f} | "
              f"{diff:>+10.4f} | {arrow} {abs(pct):>7.2f}%")
    print(f"{rule}")