# Commit 81c336f (Thanh-Lam) — Fix: pipeline.instantiate() modifies in-place,
# use pipeline directly after instantiate
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, Any, Dict, Optional
import shutil
import torch
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook
from .utils import ensure_audio_path, read_hf_token, convert_to_wav_16k
@dataclass
class Segment:
    """One diarized speech turn: a time span attributed to a single speaker."""

    # Turn start time in seconds (float, as produced by pyannote).
    start: float
    # Turn end time in seconds.
    end: float
    # Speaker label assigned by the pipeline (e.g. "SPEAKER_00").
    speaker: str
class DiarizationEngine:
    """Wrapper around the pyannote speaker-diarization pipeline.

    Loads the pretrained pipeline with Hugging Face authentication, merges
    optional segmentation/clustering hyper-parameter overrides into the
    pipeline defaults, moves the pipeline to the requested device, and
    exposes helpers to run diarization and convert or save the result.
    """

    def __init__(
        self,
        model_id: str = "pyannote/speaker-diarization-3.1",
        token: str | None = None,
        key_path: str | Path = "hugging_face_key.txt",
        device: str = "auto",
        segmentation_params: Optional[Dict[str, float]] = None,
        clustering_params: Optional[Dict[str, float]] = None,
    ) -> None:
        """Load and configure the diarization pipeline.

        Parameters
        ----------
        model_id: Hugging Face id of the diarization pipeline.
        token: explicit HF token; when None, ``key_path`` is consulted.
        key_path: file holding the HF token (used when ``token`` is None).
        device: one of "auto", "cpu", "cuda".
        segmentation_params: overrides merged into the default
            ``segmentation`` hyper-parameters (ignored if the pipeline
            exposes no such group).
        clustering_params: same, for the ``clustering`` group.

        Raises
        ------
        RuntimeError: pipeline could not be loaded (typically gated-model
            terms not accepted), or CUDA requested but unavailable.
        ValueError: unrecognized ``device`` string.
        """
        import sys

        self.device = self._resolve_device(device)
        auth_token = read_hf_token(token, key_path)
        # Load pipeline with authentication; the token itself is masked in the log.
        print(f"DEBUG: Loading model {model_id} with token={'***' if auth_token else 'None'}", file=sys.stderr)
        pipeline = Pipeline.from_pretrained(model_id, use_auth_token=auth_token)
        if pipeline is None:
            # from_pretrained returns None (instead of raising) when the gated
            # model terms have not been accepted for every sub-model.
            raise RuntimeError(
                f"Failed to load pipeline '{model_id}'. "
                "IMPORTANT: You need to accept terms for ALL these models:\n"
                " 1. https://hf.co/pyannote/speaker-diarization-3.1\n"
                " 2. https://hf.co/pyannote/segmentation-3.0\n"
                " 3. https://hf.co/pyannote/embedding\n"
                "After accepting, add HF_TOKEN to Space secrets with your token."
            )
        # Start from the pipeline's defaults so caller overrides are merged
        # into, not substituted for, the full hyper-parameter set.
        try:
            params = pipeline.default_parameters()
        except NotImplementedError:
            # Pipeline exposes no defaults; instantiate with an empty mapping
            # (note: caller overrides cannot be applied in that case).
            params = {}
        print(f"DEBUG: Pipeline params: {params}", file=sys.stderr)
        if "segmentation" in params and segmentation_params:
            params["segmentation"].update(segmentation_params)
        if "clustering" in params and clustering_params:
            params["clustering"].update(clustering_params)
        # instantiate() configures the pipeline in place (and returns self),
        # so the same object is used afterwards.
        print("DEBUG: Instantiating pipeline...", file=sys.stderr)
        pipeline.instantiate(params)
        print("DEBUG: Pipeline instantiated successfully", file=sys.stderr)
        # Store and move to the resolved device.
        self.pipeline = pipeline
        self.pipeline.to(self.device)
        print(f"DEBUG: Pipeline moved to device: {self.device}", file=sys.stderr)

    @staticmethod
    def _resolve_device(device: str) -> torch.device:
        """Map a device string ("auto" | "cpu" | "cuda") to a torch.device.

        "auto" prefers CUDA when available, otherwise CPU. Raises
        RuntimeError when "cuda" is requested without an available GPU,
        ValueError for any other string.
        """
        if device == "cpu":
            return torch.device("cpu")
        if device == "cuda":
            if not torch.cuda.is_available():
                raise RuntimeError("Yêu cầu CUDA nhưng không phát hiện GPU khả dụng.")
            return torch.device("cuda")
        if device == "auto":
            return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        raise ValueError("Giá trị device hợp lệ: auto, cpu, cuda.")

    def diarize(
        self, audio_path: str | Path, show_progress: bool = True, keep_audio: bool = False
    ):
        """Run diarization on ``audio_path`` (converted to 16 kHz WAV first).

        Returns the raw pipeline output, or — when ``keep_audio`` is True —
        a ``(result, prepared_path, tmpdir)`` tuple so the caller can reuse
        the converted audio (the caller then owns ``tmpdir`` cleanup).
        """
        audio_path = ensure_audio_path(audio_path)
        prepared_path, tmpdir = convert_to_wav_16k(audio_path)
        succeeded = False
        try:
            if show_progress:
                with ProgressHook() as hook:
                    result = self.pipeline(str(prepared_path), hook=hook)
            else:
                result = self.pipeline(str(prepared_path))
            succeeded = True
            if keep_audio:
                return result, prepared_path, tmpdir
            return result
        finally:
            # Keep the temp dir only when the caller asked for the audio AND
            # diarization succeeded; otherwise (including on error with
            # keep_audio=True, which previously leaked it) remove it.
            if tmpdir and not (keep_audio and succeeded):
                shutil.rmtree(tmpdir, ignore_errors=True)

    @staticmethod
    def _get_annotation(diarization: Any):
        """Support both the legacy return type (an Annotation) and the new
        one (an object carrying a ``speaker_diarization`` attribute)."""
        if hasattr(diarization, "itertracks"):
            return diarization
        if hasattr(diarization, "speaker_diarization"):
            return diarization.speaker_diarization
        raise TypeError("Output pipeline không có Annotation hoặc speaker_diarization.")

    def to_segments(self, diarization: Any) -> List[Segment]:
        """Flatten a diarization result into a list of Segment records."""
        annotation = self._get_annotation(diarization)
        segments: List[Segment] = []
        for segment, _, speaker in annotation.itertracks(yield_label=True):
            segments.append(
                Segment(
                    start=float(segment.start),
                    end=float(segment.end),
                    speaker=str(speaker),
                )
            )
        return segments

    def save_rttm(self, diarization: Any, output_path: str | Path) -> Path:
        """Write the diarization result to ``output_path`` in RTTM format,
        creating parent directories as needed. Returns the written path."""
        annotation = self._get_annotation(diarization)
        path = Path(output_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w", encoding="utf-8") as f:
            annotation.write_rttm(f)
        return path

    def run(self, audio_path: str | Path, show_progress: bool = True) -> List[Segment]:
        """Run the pipeline and return the resulting list of segments."""
        diarization = self.diarize(audio_path, show_progress=show_progress)
        return self.to_segments(diarization)