Thanh-Lam commited on
Commit
4193ca8
·
1 Parent(s): f7fe2af

Update project for Hugging Face Space deployment

Browse files
.gitignore CHANGED
@@ -1 +1,19 @@
1
  hugging_face_key.txt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  hugging_face_key.txt
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ .Python
7
+ *.so
8
+ *.egg
9
+ *.egg-info/
10
+ dist/
11
+ build/
12
+ *.log
13
+ .DS_Store
14
+ .vscode/
15
+ .idea/
16
+ *.swp
17
+ *.swo
18
+ *~
19
+ path_demo.txt
README.md CHANGED
@@ -1,3 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
1
  # Vietnamese_Diarization
2
 
3
  Kho mã mẫu diarization tiếng Việt dùng pyannote/speaker-diarization-community-1.
@@ -9,9 +20,15 @@ Kho mã mẫu diarization tiếng Việt dùng pyannote/speaker-diarization-comm
9
  - Hugging Face access token (dán vào hugging_face_key.txt hoặc đặt biến môi trường HUGGINGFACE_TOKEN/HUGGINGFACE_ACCESS_TOKEN)
10
 
11
  ## Cài đặt nhanh
12
- - Cài thư viện: `pip install pyannote.audio` hoặc `uv add pyannote.audio`
13
  - Đảm bảo ffmpeg đã có trong PATH
14
 
 
 
 
 
 
 
15
  ## Chạy mẫu
16
  - Diarization và in kết quả: `python infer.py path/to/audio.wav`
17
  - Lưu thêm RTTM: `python infer.py path/to/audio.wav --rttm outputs/audio.rttm`
@@ -25,7 +42,7 @@ segments = diarize_file("audio.wav", device="auto")
25
  ```
26
 
27
  ## Cấu trúc
28
- - app.py: API Python đơn giản
29
  - infer.py: CLI chạy diarization
30
  - src/models.py: Bao gói pipeline pyannote
31
  - src/utils.py: Hỗ trợ đọc token, định dạng kết quả
 
1
+ ---
2
+ title: Diarization Labeling
3
+ emoji: "\U0001F4E3"
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: "4.39.0"
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
  # Vietnamese_Diarization
13
 
14
  Kho mã mẫu diarization tiếng Việt dùng pyannote/speaker-diarization-community-1.
 
20
  - Hugging Face access token (dán vào hugging_face_key.txt hoặc đặt biến môi trường HUGGINGFACE_TOKEN/HUGGINGFACE_ACCESS_TOKEN)
21
 
22
  ## Cài đặt nhanh
23
+ - Cài thư viện: `pip install pyannote.audio gradio yt-dlp` hoặc `uv add pyannote.audio gradio yt-dlp`
24
  - Đảm bảo ffmpeg đã có trong PATH
25
 
26
+ ## Chạy Gradio
27
+ - Lệnh: `python app.py`
28
+ - Trình duyệt mở tại http://localhost:7860 (hoặc địa chỉ máy chủ nếu chạy từ xa)
29
+ - Điền token nếu chưa đặt sẵn, tải file âm thanh hoặc dán URL YouTube/TikTok, chọn thiết bị rồi nhấn Chạy
30
+ - Bảng kết quả hiển thị dạng phút:giây; có thể gán nhãn giới tính (nam/nữ), vùng miền (bắc/trung/nam) và transcription, sau đó bấm "Tách và tải" để nhận zip gồm các đoạn WAV và metadata.csv
31
+
32
  ## Chạy mẫu
33
  - Diarization và in kết quả: `python infer.py path/to/audio.wav`
34
  - Lưu thêm RTTM: `python infer.py path/to/audio.wav --rttm outputs/audio.rttm`
 
42
  ```
43
 
44
  ## Cấu trúc
45
+ - app.py: API Python giao diện Gradio
46
  - infer.py: CLI chạy diarization
47
  - src/models.py: Bao gói pipeline pyannote
48
  - src/utils.py: Hỗ trợ đọc token, định dạng kết quả
__pycache__/app.cpython-312.pyc CHANGED
Binary files a/__pycache__/app.cpython-312.pyc and b/__pycache__/app.cpython-312.pyc differ
 
app.py CHANGED
@@ -1,9 +1,29 @@
1
  from __future__ import annotations
2
 
 
 
3
  from pathlib import Path
4
- from typing import List
 
 
 
 
 
 
5
 
6
  from src.models import DiarizationEngine, Segment
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
 
9
  def diarize_file(
@@ -17,20 +37,437 @@ def diarize_file(
17
  return engine.run(audio_path, show_progress=show_progress)
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  if __name__ == "__main__":
21
- # Ví dụ nhanh: python app.py audio.wav
22
- import argparse
23
-
24
- parser = argparse.ArgumentParser(description=" dụ chạy diarization qua hàm Python.")
25
- parser.add_argument("audio", help="Đường dẫn tới file âm thanh")
26
- parser.add_argument(
27
- "--device",
28
- choices=["auto", "cpu", "cuda"],
29
- default="auto",
30
- help="Thiết bị ưu tiên khi khởi tạo pipeline",
31
- )
32
- args = parser.parse_args()
33
-
34
- segments = diarize_file(args.audio, device=args.device)
35
- for idx, seg in enumerate(segments, start=1):
36
- print(f"{idx:02d} | {seg.start:7.2f}s -> {seg.end:7.2f}s | speaker {seg.speaker}")
 
1
  from __future__ import annotations
2
 
3
+ import functools
4
+ import tempfile
5
  from pathlib import Path
6
+ from typing import List, Any
7
+ import shutil
8
+ import csv
9
+ import subprocess
10
+ import zipfile
11
+
12
+ import gradio as gr
13
 
14
  from src.models import DiarizationEngine, Segment
15
+ from src.utils import (
16
+ export_segments_json,
17
+ format_segments_table,
18
+ seconds_to_mmss,
19
+ download_audio_from_url,
20
+ )
21
+
22
+ DEFAULT_TOKEN_SENTINEL = "__FROM_FILE_OR_ENV__"
23
+ GENDER_MAP = {"nam": "0", "male": "0", "nữ": "1", "nu": "1", "female": "1"}
24
+ REGION_MAP = {"bắc": "0", "bac": "0", "north": "0", "trung": "1", "central": "1", "nam": "2", "south": "2"}
25
+ ALLOWED_GENDER = {"nam", "nữ", "nu", "male", "female"}
26
+ ALLOWED_REGION = {"bắc", "trung", "nam", "bac", "north", "central", "south"}
27
 
28
 
29
  def diarize_file(
 
37
  return engine.run(audio_path, show_progress=show_progress)
38
 
39
 
40
+ def _token_key(raw_token: str | None) -> str:
41
+ cleaned = raw_token.strip() if raw_token else None
42
+ return cleaned if cleaned else DEFAULT_TOKEN_SENTINEL
43
+
44
+
45
+ @functools.lru_cache(maxsize=2)
46
+ def _get_engine(token_key: str, device: str) -> DiarizationEngine:
47
+ token_value = None if token_key == DEFAULT_TOKEN_SENTINEL else token_key
48
+ return DiarizationEngine(token=token_value, device=device)
49
+
50
+
51
+ def _diarize_action(
52
+ audio_path: str | None,
53
+ hf_token: str | None,
54
+ device: str,
55
+ url: str | None = None,
56
+ ):
57
+ if not audio_path and not url:
58
+ empty_state = ["", "", "", ""]
59
+ return "Vui lòng tải file âm thanh hoặc nhập URL.", None, None, [], [], empty_state, ""
60
+ try:
61
+ downloaded_path = None
62
+ download_tmp = None
63
+ audio_input = audio_path
64
+ if url:
65
+ downloaded_path, download_tmp = download_audio_from_url(url)
66
+ audio_input = str(downloaded_path)
67
+
68
+ engine = _get_engine(_token_key(hf_token), device)
69
+ diarization, prepared_path, prep_tmpdir = engine.diarize(
70
+ audio_input, show_progress=False, keep_audio=True
71
+ )
72
+ segments = engine.to_segments(diarization)
73
+ dict_segments = [
74
+ {"start": float(seg.start), "end": float(seg.end), "speaker": seg.speaker}
75
+ for seg in segments
76
+ ]
77
+ table = format_segments_table(dict_segments)
78
+
79
+ output_tmp = Path(tempfile.mkdtemp(prefix="diarization_out_"))
80
+ rttm_path = engine.save_rttm(diarization, output_tmp / "output.rttm")
81
+ json_path = export_segments_json(dict_segments, output_tmp / "segments.json")
82
+
83
+ df_rows = [
84
+ [
85
+ seconds_to_mmss(seg["start"]),
86
+ seconds_to_mmss(seg["end"]),
87
+ seg["speaker"],
88
+ "",
89
+ "",
90
+ "",
91
+ ]
92
+ for seg in dict_segments
93
+ ]
94
+
95
+ source_name = Path(audio_input).stem if audio_input else "unknown"
96
+ audio_state = [
97
+ str(prepared_path),
98
+ str(prep_tmpdir) if prep_tmpdir else "",
99
+ source_name,
100
+ str(download_tmp) if download_tmp else "",
101
+ ]
102
+ return (
103
+ table,
104
+ str(rttm_path),
105
+ str(json_path),
106
+ df_rows,
107
+ dict_segments,
108
+ audio_state,
109
+ str(prepared_path),
110
+ )
111
+ except Exception as exc: # pragma: no cover - hiển thị lỗi cho người dùng giao diện
112
+ empty_state = ["", "", "", ""]
113
+ return f"Lỗi: {exc}", None, None, [], [], empty_state, ""
114
+
115
+
116
+ def _normalize_label(value: Any) -> str:
117
+ return str(value).strip().lower() if value is not None else ""
118
+
119
+
120
+ def _table_to_rows(table_data: Any) -> list[list[Any]]:
121
+ """Chuyển giá trị DataFrame/ndarray/list sang list of list để thao tác."""
122
+ if table_data is None:
123
+ return []
124
+ if hasattr(table_data, "values"): # pandas DataFrame hoặc ndarray
125
+ try:
126
+ return table_data.values.tolist()
127
+ except Exception:
128
+ pass
129
+ if isinstance(table_data, list):
130
+ return table_data
131
+ if isinstance(table_data, tuple):
132
+ return list(table_data)
133
+ return []
134
+
135
+
136
+ def _select_row_action(evt: gr.SelectData):
137
+ row_idx = evt.index[0] if evt and evt.index else -1
138
+ if row_idx is None or row_idx < 0:
139
+ return "Chưa chọn hàng", -1
140
+ return f"Đang chọn hàng {row_idx + 1}", row_idx
141
+
142
+
143
+ def _apply_dropdown_action(
144
+ table_rows: list[list[Any]] | None,
145
+ selected_idx: int,
146
+ gender_choice: str,
147
+ region_choice: str,
148
+ transcription_text: str,
149
+ ):
150
+ rows = _table_to_rows(table_rows)
151
+ if selected_idx is None or selected_idx < 0 or selected_idx >= len(rows):
152
+ return rows, "Chọn một hàng trước."
153
+
154
+ gender_val = _normalize_label(gender_choice)
155
+ region_val = _normalize_label(region_choice)
156
+ if gender_val and gender_val not in ALLOWED_GENDER:
157
+ return rows, "Giới tính chỉ được chọn nam/nữ."
158
+ if region_val and region_val not in ALLOWED_REGION:
159
+ return rows, "Vùng miền chỉ được chọn bắc/trung/nam."
160
+
161
+ new_rows = [list(r) for r in rows]
162
+ # row order: start_mmss, end_mmss, speaker, gender, region, transcription
163
+ if len(new_rows[selected_idx]) < 6:
164
+ new_rows[selected_idx] = (new_rows[selected_idx] + [""] * 6)[:6]
165
+ new_rows[selected_idx][3] = gender_val
166
+ new_rows[selected_idx][4] = region_val
167
+ new_rows[selected_idx][5] = transcription_text
168
+ return new_rows, f"Đã áp dụng cho hàng {selected_idx + 1}."
169
+
170
+
171
+ def _import_archives_action(files: list[Any] | None, output_root: str = "outputs"):
172
+ if not files:
173
+ return "Chọn ít nhất một file ZIP.", None
174
+ merged_root = Path(tempfile.mkdtemp(prefix="merged_zip_"))
175
+ merged_data = merged_root / "merged"
176
+ merged_data.mkdir(parents=True, exist_ok=True)
177
+ meta_all = merged_data / "metadata_all.csv"
178
+ appended = 0
179
+ extracted = 0
180
+
181
+ with meta_all.open("w", newline="", encoding="utf-8") as csvfile:
182
+ writer = csv.writer(csvfile)
183
+ writer.writerow(
184
+ [
185
+ "id",
186
+ "file_name",
187
+ "start_mmss",
188
+ "end_mmss",
189
+ "gender",
190
+ "region",
191
+ "transcription",
192
+ "speaker",
193
+ "duration_sec",
194
+ "source",
195
+ ]
196
+ )
197
+
198
+ for f in files:
199
+ zip_path = Path(getattr(f, "name", f))
200
+ if not zip_path.exists() and isinstance(f, dict) and "name" in f:
201
+ zip_path = Path(f["name"])
202
+ if not zip_path.exists():
203
+ continue
204
+ extracted += 1
205
+ dest_dir = merged_data / zip_path.stem
206
+ dest_dir.mkdir(parents=True, exist_ok=True)
207
+ with zipfile.ZipFile(zip_path, "r") as zf:
208
+ zf.extractall(dest_dir)
209
+
210
+ meta_csv = dest_dir / "metadata.csv"
211
+ if meta_csv.exists():
212
+ with meta_all.open("a", newline="", encoding="utf-8") as out_csv:
213
+ writer = csv.writer(out_csv)
214
+ with meta_csv.open("r", encoding="utf-8") as src:
215
+ next(src, None) # skip header
216
+ for line in src:
217
+ row = line.strip().split(",")
218
+ if row and any(row):
219
+ writer.writerow(row + [zip_path.stem])
220
+ appended += 1
221
+
222
+ merged_zip = shutil.make_archive(str(merged_data), "zip", merged_data)
223
+ status = f"Đã gộp {extracted} ZIP, metadata_all.csv có thêm {appended} dòng. Tải merged.zip."
224
+ return status, merged_zip
225
+
226
+
227
+ def _split_segments_action(
228
+ table_rows: list[list[Any]] | None,
229
+ segments_state: list[dict],
230
+ audio_state: list[str],
231
+ ):
232
+ if not shutil.which("ffmpeg"):
233
+ return "Cần cài ffmpeg để tách đoạn.", None
234
+ if not segments_state:
235
+ return "Chạy diarization trước.", None
236
+ if not audio_state or len(audio_state) < 1 or not audio_state[0]:
237
+ return "Thiếu thông tin file đã chuẩn hóa.", None
238
+
239
+ prepared_path = Path(audio_state[0])
240
+ tmp_root = Path(tempfile.mkdtemp(prefix="segments_"))
241
+ output_dir = tmp_root / "data"
242
+ output_dir.mkdir(parents=True, exist_ok=True)
243
+ metadata_path = output_dir / "metadata.csv"
244
+ rows = _table_to_rows(table_rows)
245
+
246
+ try:
247
+ with metadata_path.open("w", newline="", encoding="utf-8") as csvfile:
248
+ writer = csv.writer(csvfile)
249
+ writer.writerow(
250
+ [
251
+ "id",
252
+ "file_name",
253
+ "start_mmss",
254
+ "end_mmss",
255
+ "gender",
256
+ "region",
257
+ "transcription",
258
+ "speaker",
259
+ "duration_sec",
260
+ ]
261
+ )
262
+
263
+ for idx, seg in enumerate(segments_state):
264
+ row = rows[idx] if idx < len(rows) else []
265
+ # row order: start_mmss, end_mmss, speaker, gender, region, transcription
266
+ gender = _normalize_label(row[3] if len(row) > 3 else "")
267
+ region = _normalize_label(row[4] if len(row) > 4 else "")
268
+ transcription = row[5] if len(row) > 5 else ""
269
+
270
+ if gender and gender not in ALLOWED_GENDER:
271
+ return f"Lỗi: giới tính hàng {idx+1} phải là nam/nữ.", None
272
+ if region and region not in ALLOWED_REGION:
273
+ return f"Lỗi: vùng miền hàng {idx+1} phải là bắc/trung/nam.", None
274
+
275
+ gender_code = GENDER_MAP.get(gender, "")
276
+ region_code = REGION_MAP.get(region, "")
277
+ seg_id = f"id_{gender_code or 'x'}_{region_code or 'x'}_{idx:03d}"
278
+ gender_disp = "Nam" if gender_code == "0" else "Nữ" if gender_code == "1" else gender
279
+ region_disp = (
280
+ "Bắc"
281
+ if region_code == "0"
282
+ else "Trung"
283
+ if region_code == "1"
284
+ else "Nam"
285
+ if region_code == "2"
286
+ else region
287
+ )
288
+
289
+ start = float(seg["start"])
290
+ end = float(seg["end"])
291
+ duration = max(end - start, 0.0)
292
+ out_file = output_dir / f"{seg_id}.wav"
293
+
294
+ cmd = [
295
+ "ffmpeg",
296
+ "-y",
297
+ "-i",
298
+ str(prepared_path),
299
+ "-ss",
300
+ f"{start:.3f}",
301
+ "-to",
302
+ f"{end:.3f}",
303
+ "-ac",
304
+ "1",
305
+ "-ar",
306
+ "16000",
307
+ "-vn",
308
+ "-f",
309
+ "wav",
310
+ str(out_file),
311
+ ]
312
+ result = subprocess.run(cmd, capture_output=True, text=True)
313
+ if result.returncode != 0:
314
+ stderr = result.stderr.strip()
315
+ raise RuntimeError(f"ffmpeg lỗi khi tách đoạn {idx}: {stderr}")
316
+
317
+ writer.writerow(
318
+ [
319
+ seg_id,
320
+ out_file.name,
321
+ seconds_to_mmss(start),
322
+ seconds_to_mmss(end),
323
+ gender_disp,
324
+ region_disp,
325
+ transcription,
326
+ seg.get("speaker", ""),
327
+ duration,
328
+ ]
329
+ )
330
+
331
+ archive = shutil.make_archive(str(output_dir), "zip", output_dir)
332
+ return f"Tách {len(segments_state)} đoạn thành công. Tải zip bên dưới.", archive
333
+ except Exception as exc: # pragma: no cover
334
+ return f"Lỗi khi tách: {exc}", None
335
+
336
+
337
+ def build_interface() -> gr.Blocks:
338
+ with gr.Blocks(title="Vietnamese Diarization", analytics_enabled=False) as demo:
339
+ gr.Markdown(
340
+ """
341
+ ### Diarization tiếng Việt với pyannote
342
+ - Tải file âm thanh, điền Hugging Face access token (hoặc để trống nếu đã đặt trong môi trường/file).
343
+ - Chọn thiết bị chạy, nhấn Chạy. Kết quả hiển thị dạng bảng và file RTTM/JSON tải về.
344
+ """
345
+ )
346
+
347
+ segments_state = gr.State([])
348
+ audio_state = gr.State({})
349
+
350
+ with gr.Row():
351
+ with gr.Column():
352
+ audio_input = gr.Audio(label="Tải file audio (tùy chọn)", type="filepath")
353
+ playback = gr.Audio(
354
+ label="Audio đã chuyển đổi/đang dùng",
355
+ type="filepath",
356
+ interactive=False,
357
+ )
358
+ with gr.Column():
359
+ url_input = gr.Textbox(
360
+ label="URL YouTube/TikTok (tùy chọn)",
361
+ placeholder="Dán link video nếu không tải file",
362
+ )
363
+ token_input = gr.Textbox(
364
+ label="Hugging Face access token (tùy chọn)",
365
+ type="password",
366
+ placeholder="Để trống nếu đã cấu hình môi trường hoặc hugging_face_key.txt",
367
+ )
368
+ device_input = gr.Dropdown(
369
+ choices=["auto", "cpu", "cuda"],
370
+ value="auto",
371
+ label="Thiết bị",
372
+ )
373
+ run_btn = gr.Button("Chạy diarization")
374
+
375
+ gr.Markdown(
376
+ """
377
+ #### Gán nhãn và tách đoạn
378
+ - Chọn các ô gender (nam/nữ) và region (bắc/trung/nam) bằng dropdown trong bảng, transcription nhập tay.
379
+ - Nhấn "Tách và tải" để tải zip gồm các đoạn WAV và metadata.csv (không lưu lại trên server).
380
+ """
381
+ )
382
+ segment_df = gr.DataFrame(
383
+ headers=[
384
+ "start_mmss",
385
+ "end_mmss",
386
+ "speaker",
387
+ "gender",
388
+ "region",
389
+ "transcription",
390
+ ],
391
+ datatype=["str", "str", "str", "str", "str", "str"],
392
+ interactive=True,
393
+ row_count=(0, "dynamic"),
394
+ )
395
+ gender_dropdown = gr.Dropdown(choices=["", "nam", "nữ"], value="", label="Giới tính chọn nhanh")
396
+ region_dropdown = gr.Dropdown(choices=["", "bắc", "trung", "nam"], value="", label="Vùng miền chọn nhanh")
397
+ transcription_input = gr.Textbox(label="Transcription (áp dụng nhanh)", lines=1, placeholder="Nhập lời thoại")
398
+ selection_info = gr.Textbox(label="Hàng đang chọn", interactive=False, value="Chưa chọn hàng")
399
+ split_btn = gr.Button("Tách và tải")
400
+ split_status = gr.Textbox(label="Trạng thái tách", lines=2)
401
+ zip_file = gr.File(label="Tải ZIP các đoạn")
402
+
403
+ gr.Markdown(
404
+ """
405
+ #### Nhập ZIP đã tách (gộp nhiều ZIP thành một)
406
+ - Tải lên nhiều file ZIP đã tải về trước đó, công cụ sẽ gộp lại và tạo một merged.zip kèm metadata_all.csv.
407
+ """
408
+ )
409
+ import_files = gr.File(label="Chọn nhiều ZIP", file_count="multiple", file_types=[".zip"])
410
+ import_btn = gr.Button("Nhập ZIP vào thư mục chung")
411
+ import_status = gr.Textbox(label="Trạng thái nhập ZIP", lines=2)
412
+
413
+ result_box = gr.Textbox(label="Bảng phân đoạn", lines=12)
414
+ rttm_file = gr.File(label="Tải RTTM")
415
+ json_file = gr.File(label="Tải JSON")
416
+
417
+ selected_row = gr.State(-1)
418
+
419
+ run_btn.click(
420
+ fn=_diarize_action,
421
+ inputs=[audio_input, token_input, device_input, url_input],
422
+ outputs=[result_box, rttm_file, json_file, segment_df, segments_state, audio_state, playback],
423
+ )
424
+ segment_df.select(
425
+ fn=_select_row_action,
426
+ inputs=None,
427
+ outputs=[selection_info, selected_row],
428
+ )
429
+ gender_dropdown.change(
430
+ fn=_apply_dropdown_action,
431
+ inputs=[segment_df, selected_row, gender_dropdown, region_dropdown, transcription_input],
432
+ outputs=[segment_df, selection_info],
433
+ )
434
+ region_dropdown.change(
435
+ fn=_apply_dropdown_action,
436
+ inputs=[segment_df, selected_row, gender_dropdown, region_dropdown, transcription_input],
437
+ outputs=[segment_df, selection_info],
438
+ )
439
+ transcription_input.change(
440
+ fn=_apply_dropdown_action,
441
+ inputs=[segment_df, selected_row, gender_dropdown, region_dropdown, transcription_input],
442
+ outputs=[segment_df, selection_info],
443
+ )
444
+ split_btn.click(
445
+ fn=_split_segments_action,
446
+ inputs=[segment_df, segments_state, audio_state],
447
+ outputs=[split_status, zip_file],
448
+ )
449
+ import_btn.click(
450
+ fn=_import_archives_action,
451
+ inputs=[import_files],
452
+ outputs=[import_status, zip_file],
453
+ )
454
+ return demo
455
+
456
+
457
  if __name__ == "__main__":
458
+ import sys
459
+ print("=" * 60, file=sys.stderr)
460
+ print("Khởi tạo Vietnamese Diarization App...", file=sys.stderr)
461
+ print("=" * 60, file=sys.stderr)
462
+ try:
463
+ demo = build_interface()
464
+ print("Interface đã được khởi tạo thành công!", file=sys.stderr)
465
+ demo.launch(
466
+ server_name="0.0.0.0",
467
+ server_port=7860,
468
+ )
469
+ except Exception as e:
470
+ print(f"LỖI khi khởi động app: {e}", file=sys.stderr)
471
+ import traceback
472
+ traceback.print_exc()
473
+ sys.exit(1)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ pyannote.audio
2
+ torch==2.1.2
3
+ torchaudio==2.1.2
4
+ gradio==4.39.0
5
+ huggingface_hub==0.23.4
6
+ yt-dlp
7
+ numpy<2.0
src/__pycache__/models.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/models.cpython-312.pyc and b/src/__pycache__/models.cpython-312.pyc differ
 
src/__pycache__/utils.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/utils.cpython-312.pyc and b/src/__pycache__/utils.cpython-312.pyc differ
 
src/models.py CHANGED
@@ -2,13 +2,14 @@ from __future__ import annotations
2
 
3
  from dataclasses import dataclass
4
  from pathlib import Path
5
- from typing import Iterable, List
 
6
 
7
  import torch
8
  from pyannote.audio import Pipeline
9
  from pyannote.audio.pipelines.utils.hook import ProgressHook
10
 
11
- from .utils import ensure_audio_path, read_hf_token
12
 
13
 
14
  @dataclass
@@ -27,10 +28,27 @@ class DiarizationEngine:
27
  token: str | None = None,
28
  key_path: str | Path = "hugging_face_key.txt",
29
  device: str = "auto",
 
 
30
  ) -> None:
31
  self.device = self._resolve_device(device)
32
  auth_token = read_hf_token(token, key_path)
33
- self.pipeline = Pipeline.from_pretrained(model_id, token=auth_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  self.pipeline.to(self.device)
35
 
36
  @staticmethod
@@ -45,17 +63,37 @@ class DiarizationEngine:
45
  return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
46
  raise ValueError("Giá trị device hợp lệ: auto, cpu, cuda.")
47
 
48
- def diarize(self, audio_path: str | Path, show_progress: bool = True):
 
 
49
  audio_path = ensure_audio_path(audio_path)
50
- if show_progress:
51
- with ProgressHook() as hook:
52
- return self.pipeline(str(audio_path), hook=hook)
53
- return self.pipeline(str(audio_path))
 
 
 
 
 
 
 
 
 
54
 
55
  @staticmethod
56
- def to_segments(diarization) -> List[Segment]:
 
 
 
 
 
 
 
 
 
57
  segments: List[Segment] = []
58
- for segment, _, speaker in diarization.itertracks(yield_label=True):
59
  segments.append(
60
  Segment(
61
  start=float(segment.start),
@@ -65,11 +103,12 @@ class DiarizationEngine:
65
  )
66
  return segments
67
 
68
- @staticmethod
69
- def save_rttm(diarization, output_path: str | Path) -> Path:
70
  path = Path(output_path)
71
  path.parent.mkdir(parents=True, exist_ok=True)
72
- diarization.write_rttm(path)
 
73
  return path
74
 
75
  def run(self, audio_path: str | Path, show_progress: bool = True) -> List[Segment]:
 
2
 
3
  from dataclasses import dataclass
4
  from pathlib import Path
5
+ from typing import Iterable, List, Any, Dict, Optional
6
+ import shutil
7
 
8
  import torch
9
  from pyannote.audio import Pipeline
10
  from pyannote.audio.pipelines.utils.hook import ProgressHook
11
 
12
+ from .utils import ensure_audio_path, read_hf_token, convert_to_wav_16k
13
 
14
 
15
  @dataclass
 
28
  token: str | None = None,
29
  key_path: str | Path = "hugging_face_key.txt",
30
  device: str = "auto",
31
+ segmentation_params: Optional[Dict[str, float]] = None,
32
+ clustering_params: Optional[Dict[str, float]] = None,
33
  ) -> None:
34
  self.device = self._resolve_device(device)
35
  auth_token = read_hf_token(token, key_path)
36
+ pipeline = Pipeline.from_pretrained(model_id, token=auth_token)
37
+ params = pipeline.parameters()
38
+ # Giảm phân mảnh: chỉ cập nhật các khóa thực sự tồn tại để tránh lỗi.
39
+ seg_cfg = params.get("segmentation")
40
+ if seg_cfg:
41
+ default_seg = {"min_duration_on": 1.0, "min_duration_off": 0.8}
42
+ for k, v in default_seg.items():
43
+ if k in seg_cfg:
44
+ seg_cfg[k] = v
45
+ if segmentation_params:
46
+ for k, v in segmentation_params.items():
47
+ if k in seg_cfg:
48
+ seg_cfg[k] = v
49
+ if clustering_params and "clustering" in params:
50
+ params["clustering"].update(clustering_params)
51
+ self.pipeline = pipeline.instantiate(params)
52
  self.pipeline.to(self.device)
53
 
54
  @staticmethod
 
63
  return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
64
  raise ValueError("Giá trị device hợp lệ: auto, cpu, cuda.")
65
 
66
+ def diarize(
67
+ self, audio_path: str | Path, show_progress: bool = True, keep_audio: bool = False
68
+ ):
69
  audio_path = ensure_audio_path(audio_path)
70
+ prepared_path, tmpdir = convert_to_wav_16k(audio_path)
71
+ try:
72
+ if show_progress:
73
+ with ProgressHook() as hook:
74
+ result = self.pipeline(str(prepared_path), hook=hook)
75
+ else:
76
+ result = self.pipeline(str(prepared_path))
77
+ if keep_audio:
78
+ return result, prepared_path, tmpdir
79
+ return result
80
+ finally:
81
+ if tmpdir and not keep_audio:
82
+ shutil.rmtree(tmpdir, ignore_errors=True)
83
 
84
  @staticmethod
85
+ def _get_annotation(diarization: Any):
86
+ """Hỗ trợ cả dạng trả về cũ (Annotation) và mới (có speaker_diarization)."""
87
+ if hasattr(diarization, "itertracks"):
88
+ return diarization
89
+ if hasattr(diarization, "speaker_diarization"):
90
+ return diarization.speaker_diarization
91
+ raise TypeError("Output pipeline không có Annotation hoặc speaker_diarization.")
92
+
93
+ def to_segments(self, diarization: Any) -> List[Segment]:
94
+ annotation = self._get_annotation(diarization)
95
  segments: List[Segment] = []
96
+ for segment, _, speaker in annotation.itertracks(yield_label=True):
97
  segments.append(
98
  Segment(
99
  start=float(segment.start),
 
103
  )
104
  return segments
105
 
106
+ def save_rttm(self, diarization: Any, output_path: str | Path) -> Path:
107
+ annotation = self._get_annotation(diarization)
108
  path = Path(output_path)
109
  path.parent.mkdir(parents=True, exist_ok=True)
110
+ with path.open("w", encoding="utf-8") as f:
111
+ annotation.write_rttm(f)
112
  return path
113
 
114
  def run(self, audio_path: str | Path, show_progress: bool = True) -> List[Segment]:
src/utils.py CHANGED
@@ -2,8 +2,11 @@ from __future__ import annotations
2
 
3
  import json
4
  import os
 
 
 
5
  from pathlib import Path
6
- from typing import Iterable, List
7
 
8
 
9
  def read_hf_token(token: str | None = None, key_path: str | Path = "hugging_face_key.txt") -> str:
@@ -46,6 +49,12 @@ def export_segments_json(segments: Iterable[dict], output_path: str | Path) -> P
46
  return path
47
 
48
 
 
 
 
 
 
 
49
  def format_segments_table(segments: Iterable[dict]) -> str:
50
  """Trả về chuỗi bảng đơn giản để in ra terminal."""
51
  lines = []
@@ -53,5 +62,112 @@ def format_segments_table(segments: Iterable[dict]) -> str:
53
  start = seg.get("start", 0.0)
54
  end = seg.get("end", 0.0)
55
  speaker = seg.get("speaker", "unknown")
56
- lines.append(f"{idx:02d} | {start:7.2f}s -> {end:7.2f}s | speaker {speaker}")
 
 
57
  return "\n".join(lines)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import json
4
  import os
5
+ import shutil
6
+ import subprocess
7
+ import tempfile
8
  from pathlib import Path
9
+ from typing import Iterable, List, Tuple
10
 
11
 
12
  def read_hf_token(token: str | None = None, key_path: str | Path = "hugging_face_key.txt") -> str:
 
49
  return path
50
 
51
 
52
+ def seconds_to_mmss(seconds: float) -> str:
53
+ total_seconds = int(round(seconds))
54
+ minutes, sec = divmod(total_seconds, 60)
55
+ return f"{minutes:02d}:{sec:02d}"
56
+
57
+
58
  def format_segments_table(segments: Iterable[dict]) -> str:
59
  """Trả về chuỗi bảng đơn giản để in ra terminal."""
60
  lines = []
 
62
  start = seg.get("start", 0.0)
63
  end = seg.get("end", 0.0)
64
  speaker = seg.get("speaker", "unknown")
65
+ lines.append(
66
+ f"{idx:02d} | {seconds_to_mmss(start)} -> {seconds_to_mmss(end)} | speaker {speaker}"
67
+ )
68
  return "\n".join(lines)
69
+
70
+
71
+ def merge_adjacent_segments(
72
+ segments: list[dict],
73
+ max_gap: float = 0.5,
74
+ min_duration: float = 1.0,
75
+ ) -> list[dict]:
76
+ """
77
+ Ghép các đoạn liên tiếp cùng speaker nếu khoảng trống <= max_gap (giây).
78
+ Đồng thời lọc bỏ đoạn quá ngắn (< min_duration).
79
+ """
80
+ if not segments:
81
+ return []
82
+
83
+ merged: list[dict] = []
84
+ # đảm bảo sắp xếp theo thời gian
85
+ segs = sorted(segments, key=lambda s: s.get("start", 0.0))
86
+ current = segs[0].copy()
87
+
88
+ for seg in segs[1:]:
89
+ if (
90
+ seg.get("speaker") == current.get("speaker")
91
+ and seg.get("start", 0.0) - current.get("end", 0.0) <= max_gap
92
+ ):
93
+ current["end"] = max(current.get("end", 0.0), seg.get("end", 0.0))
94
+ else:
95
+ if current.get("end", 0.0) - current.get("start", 0.0) >= min_duration:
96
+ merged.append(current)
97
+ current = seg.copy()
98
+
99
+ if current.get("end", 0.0) - current.get("start", 0.0) >= min_duration:
100
+ merged.append(current)
101
+
102
+ return merged
103
+
104
+
105
+ def convert_to_wav_16k(audio_path: Path) -> Tuple[Path, Path | None]:
106
+ """
107
+ Chuyển audio về WAV mono 16 kHz bằng ffmpeg.
108
+ Trả về (đường dẫn dùng để suy luận, thư mục tạm để dọn dẹp hoặc None nếu không cần).
109
+ """
110
+ safe_stem = audio_path.stem.replace(" ", "_")
111
+ tmpdir = Path(tempfile.mkdtemp(prefix="diarization_audio_"))
112
+
113
+ if not shutil.which("ffmpeg"):
114
+ # Sao chép tệp hiện có vào thư mục tạm với tên không dấu cách
115
+ output = tmpdir / audio_path.name.replace(" ", "_")
116
+ shutil.copy2(audio_path, output)
117
+ return output, tmpdir
118
+
119
+ output = tmpdir / f"{safe_stem}_16k.wav"
120
+ cmd = [
121
+ "ffmpeg",
122
+ "-y",
123
+ "-i",
124
+ str(audio_path),
125
+ "-ac",
126
+ "1",
127
+ "-ar",
128
+ "16000",
129
+ "-vn",
130
+ "-f",
131
+ "wav",
132
+ str(output),
133
+ ]
134
+ result = subprocess.run(cmd, capture_output=True, text=True)
135
+ if result.returncode != 0:
136
+ stderr = result.stderr.strip()
137
+ raise RuntimeError(f"ffmpeg convert thất bại: {stderr}")
138
+ if not output.exists():
139
+ raise RuntimeError("ffmpeg không tạo được file WAV.")
140
+ return output, tmpdir
141
+
142
+
143
+ def download_audio_from_url(url: str) -> Tuple[Path, Path]:
144
+ """
145
+ Tải audio từ YouTube/TikTok/... dùng yt-dlp, xuất WAV để xử lý tiếp.
146
+ Trả về (đường dẫn file wav, thư mục tạm chứa file).
147
+ """
148
+ if not shutil.which("yt-dlp"):
149
+ raise RuntimeError("Cần cài yt-dlp để tải liên kết (pip install yt-dlp).")
150
+ if not shutil.which("ffmpeg"):
151
+ raise RuntimeError("Cần cài ffmpeg để chuyển đổi audio.")
152
+
153
+ tmpdir = Path(tempfile.mkdtemp(prefix="download_media_"))
154
+ out_tmpl = tmpdir / "%(title)s.%(ext)s"
155
+ cmd = [
156
+ "yt-dlp",
157
+ "-x",
158
+ "--audio-format",
159
+ "wav",
160
+ "--audio-quality",
161
+ "0",
162
+ "-o",
163
+ str(out_tmpl),
164
+ url,
165
+ ]
166
+ result = subprocess.run(cmd, capture_output=True, text=True)
167
+ if result.returncode != 0:
168
+ raise RuntimeError(f"Tải audio thất bại: {result.stderr.strip()}")
169
+
170
+ wav_files = list(tmpdir.glob("*.wav"))
171
+ if not wav_files:
172
+ raise RuntimeError("Không tìm thấy file WAV sau khi tải.")
173
+ return wav_files[0], tmpdir
test_gradio.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ def greet(name):
4
+ return f"Hello {name}!"
5
+
6
+ demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
+
8
+ if __name__ == "__main__":
9
+ print("Starting Gradio server...")
10
+ demo.launch(server_name="0.0.0.0", server_port=7860)
11
+ print("Server started!")