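"""Speaker diarization engine built on pyannote.audio.

Wraps the gated pyannote/speaker-diarization-3.1 pipeline: resolves the
compute device, loads the model with a Hugging Face token, and exposes
helpers to run diarization and export results as Segment lists or RTTM.
"""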
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional
import shutil
import sys

import torch
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook

from .utils import ensure_audio_path, read_hf_token, convert_to_wav_16k


@dataclass
class Segment:
    """A single diarized speech turn, in seconds, with its speaker label."""

    start: float
    end: float
    speaker: str


class DiarizationEngine:
    """Wraps the pyannote speaker-diarization pipeline."""
    def __init__(
        self,
        model_id: str = "pyannote/speaker-diarization-3.1",
        token: str | None = None,
        key_path: str | Path = "hugging_face_key.txt",
        device: str = "auto",
        segmentation_params: Optional[Dict[str, float]] = None,
        clustering_params: Optional[Dict[str, float]] = None,
    ) -> None:
        self.device = self._resolve_device(device)
        auth_token = read_hf_token(token, key_path)

        # Load the pipeline with authentication (the model is gated on the Hub).
        print(f"DEBUG: Loading model {model_id} with token={'***' if auth_token else 'None'}", file=sys.stderr)
        pipeline = Pipeline.from_pretrained(model_id, use_auth_token=auth_token)
        if pipeline is None:
            raise RuntimeError(
                f"Failed to load pipeline '{model_id}'. "
                "IMPORTANT: You need to accept the terms for ALL of these models:\n"
                " 1. https://hf.co/pyannote/speaker-diarization-3.1\n"
                " 2. https://hf.co/pyannote/segmentation-3.0\n"
                " 3. https://hf.co/pyannote/embedding\n"
                "After accepting, add HF_TOKEN to the Space secrets with your token."
            )
        # Fetch the default hyperparameters so custom overrides can be merged in.
        try:
            params = pipeline.default_parameters()
        except NotImplementedError:
            # Some pipelines ship no defaults; fall back to an empty parameter dict.
            params = {}
        print(f"DEBUG: Pipeline params: {params}", file=sys.stderr)
        # Merge user-supplied segmentation/clustering overrides when present.
        if "segmentation" in params and segmentation_params:
            params["segmentation"].update(segmentation_params)
        if "clustering" in params and clustering_params:
            params["clustering"].update(clustering_params)
        # instantiate() applies the parameters in place and returns self.
        print("DEBUG: Instantiating pipeline...", file=sys.stderr)
        pipeline.instantiate(params)
        print("DEBUG: Pipeline instantiated successfully", file=sys.stderr)

        # Store the pipeline and move it to the target device.
        self.pipeline = pipeline
        self.pipeline.to(self.device)
        print(f"DEBUG: Pipeline moved to device: {self.device}", file=sys.stderr)
    @staticmethod
    def _resolve_device(device: str) -> torch.device:
        if device == "cpu":
            return torch.device("cpu")
        if device == "cuda":
            if not torch.cuda.is_available():
                raise RuntimeError("CUDA was requested but no usable GPU was detected.")
            return torch.device("cuda")
        if device == "auto":
            return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        raise ValueError("Valid values for device: auto, cpu, cuda.")
    def diarize(
        self, audio_path: str | Path, show_progress: bool = True, keep_audio: bool = False
    ):
        """Run diarization on an audio file.

        The input is first converted to 16 kHz WAV in a temporary directory.
        With keep_audio=True, the converted file and its directory are returned
        alongside the result, and the caller becomes responsible for cleanup.
        """
        audio_path = ensure_audio_path(audio_path)
        prepared_path, tmpdir = convert_to_wav_16k(audio_path)
        try:
            if show_progress:
                with ProgressHook() as hook:
                    result = self.pipeline(str(prepared_path), hook=hook)
            else:
                result = self.pipeline(str(prepared_path))
            if keep_audio:
                return result, prepared_path, tmpdir
            return result
        finally:
            if tmpdir and not keep_audio:
                shutil.rmtree(tmpdir, ignore_errors=True)
    @staticmethod
    def _get_annotation(diarization: Any):
        """Support both the legacy return type (a bare Annotation) and the
        newer one that exposes a speaker_diarization attribute."""
        if hasattr(diarization, "itertracks"):
            return diarization
        if hasattr(diarization, "speaker_diarization"):
            return diarization.speaker_diarization
        raise TypeError("Pipeline output exposes neither an Annotation nor speaker_diarization.")
    def to_segments(self, diarization: Any) -> List[Segment]:
        annotation = self._get_annotation(diarization)
        segments: List[Segment] = []
        for segment, _, speaker in annotation.itertracks(yield_label=True):
            segments.append(
                Segment(
                    start=float(segment.start),
                    end=float(segment.end),
                    speaker=str(speaker),
                )
            )
        return segments
    def save_rttm(self, diarization: Any, output_path: str | Path) -> Path:
        """Write the diarization result to an RTTM file and return its path."""
        annotation = self._get_annotation(diarization)
        path = Path(output_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w", encoding="utf-8") as f:
            annotation.write_rttm(f)
        return path
    def run(self, audio_path: str | Path, show_progress: bool = True) -> List[Segment]:
        """Run the pipeline and return the list of segments."""
        diarization = self.diarize(audio_path, show_progress=show_progress)
        return self.to_segments(diarization)
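

# --- Usage sketch -------------------------------------------------------------
# A minimal, illustrative driver for DiarizationEngine. The input file name
# "meeting.wav" and the RTTM output path are hypothetical placeholders, and
# running this still requires access to the gated pyannote models via a valid
# Hugging Face token (HF_TOKEN or hugging_face_key.txt).
if __name__ == "__main__":
    engine = DiarizationEngine(device="auto")
    # Keep the raw pipeline result so it can be exported as RTTM as well.
    diarization = engine.diarize("meeting.wav", show_progress=True)
    engine.save_rttm(diarization, "output/meeting.rttm")
    for seg in engine.to_segments(diarization):
        print(f"{seg.speaker}: {seg.start:7.2f}s -> {seg.end:7.2f}s")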