| """ |
| Evaluation module for sequential recommendation models. |
| |
| Metrics: |
| - HR@K (Hit Rate): Whether the ground truth item appears in top-K |
| - NDCG@K (Normalized Discounted Cumulative Gain): Position-aware ranking quality |
| - MRR@K (Mean Reciprocal Rank): Reciprocal of the rank of the first correct item |
| """ |
|
|
import time
from typing import Dict, List, Optional

import numpy as np
import torch
|
|
|
|
@torch.no_grad()
def evaluate_model(
    model,
    eval_loader,
    num_items: int,
    device: torch.device,
    ks: Optional[List[int]] = None,
    full_ranking: bool = False,
) -> Dict[str, float]:
    """
    Evaluate a sequential recommendation model.

    Uses sampled metrics by default (the positive is ranked against the
    negatives supplied by each batch).  Set full_ranking=True to rank the
    positive against every item in the catalog (slower but unbiased).

    Args:
        model: trained model; must be callable on a batch dict and expose an
            ``item_embeddings`` embedding table (row 0 is assumed to be the
            padding embedding — confirm against the model definition).
        eval_loader: iterable of batch dicts containing 'positive_ids' (B,)
            and 'negative_ids' (B, N) long tensors.
        num_items: total number of items (unused in sampled mode; kept for
            interface compatibility).
        device: torch device to evaluate on.
        ks: metric cutoffs; defaults to [5, 10, 20, 50].
        full_ranking: if True, rank against all items instead of the sampled
            negatives.

    Returns:
        dict mapping 'HR@K', 'NDCG@K', 'MRR@K' to floats, plus 'eval_time'
        (wall-clock seconds spent evaluating).
    """
    # Avoid a mutable default argument; resolve the default per call.
    if ks is None:
        ks = [5, 10, 20, 50]

    model.eval()

    all_hrs = {k: [] for k in ks}
    all_ndcgs = {k: [] for k in ks}
    all_mrrs = {k: [] for k in ks}

    start_time = time.time()

    for batch in eval_loader:
        batch_device = {k: v.to(device) for k, v in batch.items()}

        user_emb = model(batch_device)

        pos_ids = batch_device['positive_ids']
        neg_ids = batch_device['negative_ids']

        if full_ranking:
            # Score against every real item; item id i lives at embedding
            # row i, so dropping row 0 (padding) maps id i -> column i - 1.
            all_item_embs = model.item_embeddings.weight[1:]
            scores = torch.matmul(user_emb, all_item_embs.t())

            # Vectorized rank: 1 + number of items scored strictly higher
            # than the ground truth (ties resolve optimistically).  This
            # replaces the previous per-row Python loop.
            gt_scores = scores.gather(1, (pos_ids - 1).unsqueeze(1))
            ranks = (scores > gt_scores).sum(dim=1) + 1
        else:
            # Sampled ranking: positive vs. the batch-provided negatives.
            pos_emb = model.item_embeddings(pos_ids)
            neg_emb = model.item_embeddings(neg_ids)

            pos_scores = (user_emb * pos_emb).sum(dim=-1, keepdim=True)
            neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb)

            all_scores = torch.cat([pos_scores, neg_scores], dim=1)

            # Rank of the positive among [positive] + negatives; the
            # positive never beats itself, so the +1 places it correctly.
            ranks = (all_scores > pos_scores).sum(dim=1) + 1

        # Accumulate per-user metric values for every cutoff (shared by
        # both ranking modes now that both produce a `ranks` tensor).
        for k in ks:
            within_k = ranks <= k
            hits = within_k.float()
            ndcgs = torch.where(
                within_k,
                1.0 / torch.log2(ranks.float() + 1),
                torch.zeros_like(ranks, dtype=torch.float),
            )
            mrrs = torch.where(
                within_k,
                1.0 / ranks.float(),
                torch.zeros_like(ranks, dtype=torch.float),
            )

            all_hrs[k].extend(hits.cpu().tolist())
            all_ndcgs[k].extend(ndcgs.cpu().tolist())
            all_mrrs[k].extend(mrrs.cpu().tolist())

    eval_time = time.time() - start_time

    metrics = {}
    for k in ks:
        metrics[f'HR@{k}'] = np.mean(all_hrs[k])
        metrics[f'NDCG@{k}'] = np.mean(all_ndcgs[k])
        metrics[f'MRR@{k}'] = np.mean(all_mrrs[k])

    metrics['eval_time'] = eval_time

    return metrics
|
|
|
|
@torch.no_grad()
def compute_metrics_full(
    model,
    eval_data,
    num_items: int,
    device: torch.device,
    max_seq_len: int = 512,
    ks: Optional[List[int]] = None,
    batch_size: int = 256,
) -> Dict[str, float]:
    """
    Full-ranking evaluation: rank each user's ground-truth item against the
    whole catalog.  More accurate than sampled evaluation but slower.

    Args:
        model: trained model; callable on a batch dict with 'item_ids' and
            'mask', exposing an ``item_embeddings`` table (row 0 assumed to
            be padding — confirm against the model definition).
        eval_data: sequence of dicts with 'item_ids' (list of item ids) and
            'next_item' (ground-truth item id).
        num_items: total number of items (kept for interface compatibility).
        device: torch device.
        max_seq_len: sequences are truncated to their last max_seq_len items.
        ks: metric cutoffs; defaults to [5, 10, 20, 50].
        batch_size: number of users scored per forward pass.

    Returns:
        dict mapping 'HR@K' and 'NDCG@K' to floats.
    """
    # Avoid a mutable default argument; resolve the default per call.
    if ks is None:
        ks = [5, 10, 20, 50]

    model.eval()

    # Embedding rows 1..num_items: item id i -> column index i - 1.
    all_item_embs = model.item_embeddings.weight[1:].to(device)

    all_hrs = {k: [] for k in ks}
    all_ndcgs = {k: [] for k in ks}

    for start in range(0, len(eval_data), batch_size):
        batch_data = eval_data[start:start + batch_size]

        # Left-pad every sequence in this batch to a common length.
        max_len = min(max(len(d['item_ids']) for d in batch_data), max_seq_len)

        item_ids_batch = []
        mask_batch = []
        gt_items = []

        for d in batch_data:
            ids = d['item_ids'][-max_len:]
            pad_len = max_len - len(ids)
            item_ids_batch.append([0] * pad_len + ids)
            mask_batch.append([False] * pad_len + [True] * len(ids))
            gt_items.append(d['next_item'])

        batch = {
            'item_ids': torch.tensor(item_ids_batch, dtype=torch.long, device=device),
            'mask': torch.tensor(mask_batch, dtype=torch.bool, device=device),
        }

        user_emb = model(batch)

        # (B, num_items) scores against the full catalog.
        scores = torch.matmul(user_emb, all_item_embs.t())

        # Vectorized rank of each ground truth: 1 + number of items scored
        # strictly higher (ties resolve optimistically).  Replaces the
        # previous per-user Python loop and its per-user GPU->CPU syncs.
        gt_idx = torch.tensor(gt_items, dtype=torch.long, device=device) - 1
        gt_scores = scores.gather(1, gt_idx.unsqueeze(1))
        ranks = (scores > gt_scores).sum(dim=1) + 1

        for k in ks:
            within_k = ranks <= k
            hits = within_k.float()
            ndcgs = torch.where(
                within_k,
                1.0 / torch.log2(ranks.float() + 1),
                torch.zeros_like(ranks, dtype=torch.float),
            )
            all_hrs[k].extend(hits.cpu().tolist())
            all_ndcgs[k].extend(ndcgs.cpu().tolist())

    metrics = {}
    for k in ks:
        metrics[f'HR@{k}'] = np.mean(all_hrs[k])
        metrics[f'NDCG@{k}'] = np.mean(all_ndcgs[k])

    return metrics
|
|
|
|
def print_comparison(mars_results: Dict, sasrec_results: Dict, ks: List[int] = [5, 10, 20]):
    """Print a side-by-side table comparing MARS metrics against SASRec."""
    heavy_rule = '=' * 70
    light_rule = '-' * 70

    print(f"\n{heavy_rule}")
    print(f"{'Metric':<15} | {'MARS':>10} | {'SASRec':>10} | {'Δ':>10} | {'Δ%':>10}")
    print(light_rule)

    for k in ks:
        for name in (f'HR@{k}', f'NDCG@{k}', f'MRR@{k}'):
            ours = mars_results.get(name, 0)
            base = sasrec_results.get(name, 0)
            diff = ours - base
            # Relative change is only meaningful for a positive baseline.
            pct = diff / base * 100 if base > 0 else 0

            if diff > 0:
                arrow = '↑'
            elif diff < 0:
                arrow = '↓'
            else:
                arrow = '='

            row = (f"{name:<15} | {ours:>10.4f} | {base:>10.4f} | "
                   f"{diff:>+10.4f} | {arrow} {abs(pct):>7.2f}%")
            print(row)

    print(heavy_rule)
|
|