lataon committed
Commit aa67214 · 1 Parent(s): 99d9342
Files changed (3):
  1. note.txt +2 -0
  2. phoneme_eval.py +153 -112
  3. utils/load_model.py +15 -87
note.txt ADDED
@@ -0,0 +1,2 @@
+ - create a range for the items in the dataset
+ - run in parallel
phoneme_eval.py CHANGED
@@ -2,7 +2,7 @@ import pandas as pd
  from utils.load_model import run_hubert_base, run_whisper, run_model, run_timit, run_wavlm_large_phoneme, run_gruut
  from utils.audio_process import calculate_error_rate, load_audio
  from utils.cmu_process import clean_cmu, cmu_to_ipa, text_to_phoneme
- from constants import DATASETS
+ from constants import DATASETS, FINAL_SIZE
  from datasets import load_dataset, Audio
  import argparse

@@ -12,8 +12,8 @@ MODEL_RUNNERS = {
      "Whisper": run_whisper,
      "HuBERT fine-tuned": run_model,
      "Timit": run_timit,
-     "speech31/wavlm-large-english-phoneme": run_wavlm_large_phoneme,
-     "bookbot/wav2vec2-ljspeech-gruut": run_gruut,
+     "WavLM": run_wavlm_large_phoneme,
+     "LJSpeech Gruut": run_gruut,
  }

  def set_output(model, pre_pho, ref_pho, duration, per, score):
@@ -42,23 +42,32 @@ def get_output(model, wav, reference_phoneme):

  def benchmark_all(example):
      """
-     Run all models on a single dataset example.
+     Run all models on a single dataset example in parallel.
      """
      # Load waveform manually to avoid datasets' torchcodec dependency
      wav = load_audio(example["audio"])
      reference_phoneme = example["phonetic"]
      reference_phoneme = cmu_to_ipa(clean_cmu(reference_phoneme))

-     # Run all models
-     results = [
-         get_output("HuBERT-Base", wav, reference_phoneme),
-         get_output("Whisper", wav, reference_phoneme),
-         get_output("HuBERT fine-tuned", wav, reference_phoneme),
-         get_output("Timit", wav, reference_phoneme),
-         get_output("WavLM", wav, reference_phoneme),
-         get_output("LJSpeech Gruut", wav, reference_phoneme),
+     # Run all models in parallel using ThreadPoolExecutor
+     from concurrent.futures import ThreadPoolExecutor
+
+     models = [
+         "HuBERT-Base",
+         "Whisper",
+         "HuBERT fine-tuned",
+         "Timit",
+         "WavLM",
+         "LJSpeech Gruut"
      ]

+     with ThreadPoolExecutor(max_workers=len(models)) as executor:
+         futures = [
+             executor.submit(get_output, model, wav, reference_phoneme)
+             for model in models
+         ]
+         results = [future.result() for future in futures]
+
      return pd.DataFrame(results)

  def benchmark_dataset(dataset):
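A note on the hunk above: ThreadPoolExecutor futures come back in submission order, so `results` still lines up with `models`, and the threads can overlap usefully because most of the time is spent inside PyTorch ops that largely release the GIL. A minimal sketch of the same fan-out pattern, where `run_all` and `run` are illustrative names (a stand-in for `get_output` with its extra arguments bound), not part of the commit:

    # Minimal sketch of the fan-out used in benchmark_all above.
    from concurrent.futures import ThreadPoolExecutor

    def run_all(models, run):
        with ThreadPoolExecutor(max_workers=len(models)) as executor:
            futures = [executor.submit(run, m) for m in models]  # one worker per model
            return [f.result() for f in futures]                 # preserves submission order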
@@ -127,124 +136,112 @@ def load_dataset_with_limits(dataset_config, max_samples=None, use_streaming=False):
      print(f"[warn] skip dataset {dataset_config['name']}: {e}")
      return None

- def main():
-     # Parse command line arguments
+ def parse_cli_args():
+     """
+     Parse and return CLI arguments for the evaluation script.
+     """
      parser = argparse.ArgumentParser(description='Phoneme Detection Evaluation')
      parser.add_argument('--max-samples', type=int, default=None,
                          help='Override max_samples for all datasets')
      parser.add_argument('--dataset', type=str, default=None,
                          help='Process only specific dataset (by name)')
-     args = parser.parse_args()
-
-     per_model_results = {}
+     return parser.parse_args()

-     for dataset_config in DATASETS:
-         # Skip dataset if specific dataset is requested and this isn't it
-         if args.dataset and args.dataset not in dataset_config["name"]:
-             continue
-
-         # Override max_samples if provided via command line
-         max_samples = args.max_samples if args.max_samples is not None else dataset_config.get("max_samples")
-         use_streaming = dataset_config.get("use_streaming", False)
-
-         # Load dataset with limits
-         dataset = load_dataset_with_limits(
-             dataset_config,
-             max_samples=max_samples,
-             use_streaming=use_streaming
-         )
-
-         if dataset is None:
-             continue
-
-         try:
-             dataset = dataset.cast_column("audio", Audio(decode=False))
-         except Exception:
-             pass
-
-         field = dataset_config["field"]
-
-         # Handle streaming datasets differently
-         if use_streaming:
-             print("Processing streaming dataset...")
-             valid_samples = []
-
-             # Set a reasonable limit for streaming (max 100 samples)
-             streaming_limit = max(max_samples or 100, 100)
-
-             for example in dataset:
-                 # Convert text to phonemes if needed
-                 if field == "text":
-                     phonetic_text = text_to_phoneme(example[field])
-                     example = {**example, "phonetic": phonetic_text}
-                     current_field = "phonetic"
-                 else:
-                     current_field = field
-
-                 # Check if valid
-                 if current_field in example:
-                     phoneme_tokens = example[current_field].split()
-                     if len(phoneme_tokens) >= 10:
-                         valid_samples.append(example)
-                         # Stop when we reach the streaming limit
-                         if len(valid_samples) >= streaming_limit:
-                             break
-
-             print(f"Found {len(valid_samples)} valid samples")
-             if len(valid_samples) == 0:
-                 print("No valid samples found, skipping dataset")
-                 continue
-
-             # Convert to regular dataset for processing
-             from datasets import Dataset
-             dataset_final = Dataset.from_list(valid_samples)
-             field = "phonetic" if field == "text" else field
-         else:
-             # Regular dataset processing
-             if field == "text":
-                 dataset = dataset.map(lambda x: {"phonetic": text_to_phoneme(x[field])})
-                 field = "phonetic"
-
-             unique_texts = dataset.unique(field)
-             print("Unique phonetic strings (", dataset_config["name"], "):", len(unique_texts))
-
-             dataset_unique = dataset.filter(lambda x: x[field] in unique_texts)
-
-             def is_valid(example):
-                 phoneme_tokens = example[field].split()
-                 return len(phoneme_tokens) >= 10
-
-             dataset_filtered = dataset_unique.filter(is_valid)
-             # Use smaller final size for evaluation
-             final_size = min(100, len(dataset_filtered))
-             dataset_final = dataset_filtered.shuffle(seed=42).select(range(final_size))
-
-         print(dataset_final)
-         print("Final size:", len(dataset_final))
-
-         # Limit to 10 samples for benchmarking
-         benchmark_size = min(10, len(dataset_final))
-         full_results, avg_stats = benchmark_dataset(dataset_final.select(range(benchmark_size)))
-         print("Average Statistic per model (", dataset_config["name"], "):")
-         print(avg_stats)
-
-         # Use dataset name as key (extract the actual name part)
-         dataset_key = dataset_config["name"].split("/")[-1]  # Get the last part after the slash
-         for _, row in avg_stats.iterrows():
-             model_name = str(row["model"]).replace(" ", "-")
-             per = float(row["Average PER"]) if row["Average PER"] is not None else None
-             avg_dur = float(row["Average Duration (s)"]) if row["Average Duration (s)"] is not None else None
-
-             if model_name not in per_model_results:
-                 per_model_results[model_name] = {}
-             per_model_results[model_name][dataset_key] = {"per": per, "avg_duration": avg_dur}
-
-     # Save results for leaderboard consumption (one JSON per model)
+ def cast_audio_column_safely(dataset):
+     """
+     Ensure the dataset's 'audio' column is set to non-decoding Audio.
+     """
+     try:
+         dataset = dataset.cast_column("audio", Audio(decode=False))
+     except Exception:
+         pass
+     return dataset
+
+ def prepare_dataset_for_evaluation(dataset, dataset_config, max_samples):
+     """
+     Normalize, deduplicate, and filter dataset examples for evaluation.
+     Handles both streaming and non-streaming datasets.
+     Returns a finalized small dataset suitable for benchmarking.
+     """
+     field = dataset_config["field"]
+     use_streaming = dataset_config.get("use_streaming", False)
+
+     if use_streaming:
+         print("Processing streaming dataset...")
+         valid_samples = []
+
+         streaming_limit = min(max_samples, FINAL_SIZE)
+
+         for example in dataset:
+             if field == "text":
+                 phonetic_text = text_to_phoneme(example[field])
+                 example = {**example, "phonetic": phonetic_text}
+                 current_field = "phonetic"
+             else:
+                 current_field = field
+
+             if current_field in example:
+                 phoneme_tokens = example[current_field].split()
+                 if len(phoneme_tokens) >= 10:
+                     valid_samples.append(example)
+                     if len(valid_samples) >= streaming_limit:
+                         break
+
+         print(f"Found {len(valid_samples)} valid samples")
+         if len(valid_samples) == 0:
+             print("No valid samples found, skipping dataset")
+             return None
+
+         from datasets import Dataset
+         dataset_final = Dataset.from_list(valid_samples)
+         return dataset_final
+     else:
+         if field == "text":
+             dataset = dataset.map(lambda x: {"phonetic": text_to_phoneme(x[field])})
+             field = "phonetic"
+
+         unique_texts = dataset.unique(field)
+         print("Unique phonetic strings (", dataset_config["name"], "):", len(unique_texts))
+
+         dataset_unique = dataset.filter(lambda x: x[field] in unique_texts)
+
+         def is_valid(example):
+             phoneme_tokens = example[field].split()
+             return len(phoneme_tokens) >= 10
+
+         dataset_filtered = dataset_unique.filter(is_valid)
+         final_size = min(FINAL_SIZE, len(dataset_filtered))
+         dataset_final = dataset_filtered.shuffle(seed=42).select(range(final_size))
+         return dataset_final
+
+ def evaluate_dataset(dataset_final):
+     """
+     Run benchmarking on a capped subset of the dataset and return both
+     the full per-example results and the aggregated stats per model.
+     """
+     benchmark_size = min(FINAL_SIZE, len(dataset_final))
+     return benchmark_dataset(dataset_final.select(range(benchmark_size)))
+
+ def update_aggregates(per_model_results, avg_stats, dataset_name):
+     """
+     Update the aggregate dictionary per model with results from one dataset.
+     """
+     dataset_key = dataset_name.split("/")[-1]
+     for _, row in avg_stats.iterrows():
+         model_name = str(row["model"]).replace(" ", "-")
+         per = float(row["Average PER"]) if row["Average PER"] is not None else None
+         avg_dur = float(row["Average Duration (s)"]) if row["Average Duration (s)"] is not None else None
+
+         if model_name not in per_model_results:
+             per_model_results[model_name] = {}
+         per_model_results[model_name][dataset_key] = {"per": per, "avg_duration": avg_dur}
+
+ def save_leaderboard_results(per_model_results, results_dir="eval-results"):
+     """
+     Persist one JSON file per model for the leaderboard app to consume.
+     """
      import json, os, time
-     # results_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "eval-results")
-     results_dir = os.path.join("eval-results")
      os.makedirs(results_dir, exist_ok=True)
-
      timestamp = int(time.time())
      for model_name, task_results in per_model_results.items():
          org_model = f"{model_name}"
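The hunk above leans on two names from constants: DATASETS and the new FINAL_SIZE cap that replaces the hard-coded 100. constants.py is not part of this commit; a hypothetical entry consistent with the keys the code reads (name, field, max_samples, use_streaming) could look like:

    # Hypothetical constants.py sketch -- the dataset id and values are
    # placeholders, not taken from this diff.
    FINAL_SIZE = 100  # assumed to match the old hard-coded min(100, ...) cap

    DATASETS = [
        {
            "name": "org/placeholder-dataset",
            "field": "phonetic",        # or "text", converted via text_to_phoneme
            "max_samples": 100,
            "use_streaming": False,
        },
    ]

One caveat in the new prepare_dataset_for_evaluation: streaming_limit = min(max_samples, FINAL_SIZE) raises a TypeError when max_samples is None, a case the old max(max_samples or 100, 100) tolerated.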
@@ -261,6 +258,50 @@ def main():
          json.dump(payload, f, ensure_ascii=False, indent=2)
          print(f"Saved leaderboard result: {out_path}")

+ def process_single_dataset(dataset_config, args, per_model_results):
+     """
+     Load, normalize, evaluate a single dataset and update aggregates.
+     """
+     if args.dataset and args.dataset not in dataset_config["name"]:
+         return
+
+     max_samples = args.max_samples if args.max_samples is not None else dataset_config.get("max_samples")
+     use_streaming = dataset_config.get("use_streaming", False)
+
+     dataset = load_dataset_with_limits(
+         dataset_config,
+         max_samples=max_samples,
+         use_streaming=use_streaming
+     )
+
+     if dataset is None:
+         return
+
+     dataset = cast_audio_column_safely(dataset)
+
+     dataset_final = prepare_dataset_for_evaluation(dataset, dataset_config, max_samples)
+     if dataset_final is None:
+         return
+
+     print(dataset_final)
+     print("Final size:", len(dataset_final))
+
+     full_results, avg_stats = evaluate_dataset(dataset_final)
+     print("Average Statistic per model (", dataset_config["name"], "):")
+     print(avg_stats)
+
+     update_aggregates(per_model_results, avg_stats, dataset_config["name"])
+
+ def main():
+     args = parse_cli_args()
+
+     per_model_results = {}
+
+     for dataset_config in DATASETS:
+         process_single_dataset(dataset_config, args, per_model_results)
+
+     save_leaderboard_results(per_model_results)
+

  if __name__ == "__main__":
      main()
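For reference, update_aggregates fills the nested structure that save_leaderboard_results then serializes, one JSON file per model per run. After a single dataset it has this shape (the dataset key and numbers below are invented for illustration):

    # Illustrative shape of per_model_results; values are made up.
    per_model_results = {
        "HuBERT-Base": {
            "placeholder-dataset": {"per": 0.42, "avg_duration": 1.3},
        },
        "LJSpeech-Gruut": {  # spaces in model names become hyphens
            "placeholder-dataset": {"per": 0.37, "avg_duration": 0.9},
        },
    }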
 
utils/load_model.py CHANGED
@@ -9,7 +9,6 @@ from transformers import (
  from .cmu_process import text_to_phoneme, cmu_to_ipa, clean_cmu

  from dotenv import load_dotenv
- import torch.backends.cudnn as cudnn

  # Load environment variables from .env file
  load_dotenv()
@@ -18,10 +17,6 @@ load_dotenv()
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print("Using device:", device)

- # Enable faster cudnn autotuner for variable input lengths
- if device.type == "cuda":
-     cudnn.benchmark = True
-
  # === Helper: move all tensors to model device ===
  def to_device(batch, device):
      if isinstance(batch, dict):
@@ -66,16 +61,9 @@ wavlm_model = AutoModelForCTC.from_pretrained("speech31/wavlm-large-english-phoneme")
  def run_hubert_base(wav):
      start = time.time()
      inputs = base_proc(wav, sampling_rate=16000, return_tensors="pt", padding=True).input_values
-     if device.type == "cuda":
-         try:
-             inputs = inputs.pin_memory()
-         except Exception:
-             pass
-         inputs = inputs.to(device, non_blocking=True)
-     else:
-         inputs = inputs.to(device)
+     inputs = inputs.to(device)

-     with torch.inference_mode():
+     with torch.no_grad():
          logits = base_model(inputs).logits
      ids = torch.argmax(logits, dim=-1)
      text = base_proc.batch_decode(ids)[0]
@@ -87,43 +75,14 @@ def run_whisper(wav):
      start = time.time()

      inputs = whisper_proc(wav, sampling_rate=16000, return_tensors="pt")
-     input_features = inputs.input_features
-     if device.type == "cuda":
-         try:
-             input_features = input_features.pin_memory()
-         except Exception:
-             pass
-         input_features = input_features.to(device, non_blocking=True)
-     else:
-         input_features = input_features.to(device)
+     input_features = inputs.input_features.to(device)
      attention_mask = inputs.get("attention_mask", None)
      gen_kwargs = {"language": "en"}
      if attention_mask is not None:
-         if device.type == "cuda":
-             try:
-                 attention_mask = attention_mask.pin_memory()
-             except Exception:
-                 pass
-             gen_kwargs["attention_mask"] = attention_mask.to(device, non_blocking=True)
-         else:
-             gen_kwargs["attention_mask"] = attention_mask.to(device)
-
-     # Force English transcription and use greedy decoding with short max tokens for speed
-     try:
-         forced_ids = whisper_proc.get_decoder_prompt_ids(language="en", task="transcribe")
-     except Exception:
-         forced_ids = None
-
-     with torch.inference_mode():
-         pred_ids = whisper_model.generate(
-             input_features,
-             forced_decoder_ids=forced_ids,
-             do_sample=False,
-             num_beams=1,
-             max_new_tokens=64,
-             use_cache=True,
-             **gen_kwargs,
-         )
+         gen_kwargs["attention_mask"] = attention_mask.to(device)
+
+     with torch.no_grad():
+         pred_ids = whisper_model.generate(input_features, **gen_kwargs)

      text = whisper_proc.batch_decode(pred_ids, skip_special_tokens=True)[0]
      phonemes = text_to_phoneme(text)
@@ -134,18 +93,10 @@ def run_model(wav):
      start = time.time()

      # Prepare input (BatchEncoding supports .to(device))
-     inputs = proc(wav, sampling_rate=16000, return_tensors="pt")
-     if device.type == "cuda":
-         try:
-             inputs = inputs.pin_memory()
-         except Exception:
-             pass
-         inputs = inputs.to(device, non_blocking=True)
-     else:
-         inputs = inputs.to(device)
+     inputs = proc(wav, sampling_rate=16000, return_tensors="pt").to(device)

      # Forward pass
-     with torch.inference_mode():
+     with torch.no_grad():
          logits = model(**inputs).logits

      # Greedy decode
@@ -159,17 +110,10 @@ def run_timit(wav):
      start = time.time()
      # Read and process the input
      inputs = timit_proc(wav, sampling_rate=16_000, return_tensors="pt", padding=True)
-     if device.type == "cuda":
-         try:
-             inputs = inputs.pin_memory()
-         except Exception:
-             pass
-         inputs = inputs.to(device, non_blocking=True)
-     else:
-         inputs = inputs.to(device)
+     inputs = inputs.to(device)

      # Forward pass
-     with torch.inference_mode():
+     with torch.no_grad():
          logits = timit_model(inputs.input_values, attention_mask=inputs.attention_mask).logits

      # Decode id into string
@@ -189,18 +133,10 @@ def run_gruut(wav):
          sampling_rate=16000,
          return_tensors="pt",
          padding=True
-     )
-     if device.type == "cuda":
-         try:
-             inputs = inputs.pin_memory()
-         except Exception:
-             pass
-         inputs = inputs.to(device, non_blocking=True)
-     else:
-         inputs = inputs.to(device)
+     ).to(device)

      # Forward pass
-     with torch.inference_mode():
+     with torch.no_grad():
          logits = gruut_model(**inputs).logits

      # Greedy decode → IPA phonemes
@@ -219,21 +155,13 @@ def run_wavlm_large_phoneme(wav):
          sampling_rate=16000,
          return_tensors="pt",
          padding=True
-     )
-     if device.type == "cuda":
-         try:
-             inputs = inputs.pin_memory()
-         except Exception:
-             pass
-         inputs = inputs.to(device, non_blocking=True)
-     else:
-         inputs = inputs.to(device)
+     ).to(device)

      input_values = inputs.input_values
      attention_mask = inputs.get("attention_mask", None)

      # Forward pass
-     with torch.inference_mode():
+     with torch.no_grad():
          logits = wavlm_model(input_values, attention_mask=attention_mask).logits

      # Greedy decode → phoneme tokens
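After this cleanup every runner in utils/load_model.py reduces to the same shape: processor call, a single .to(device), torch.no_grad() (which, like the torch.inference_mode() it replaces, disables gradient tracking during the forward pass), greedy argmax, batch_decode. A generic sketch of that shared pattern; run_ctc and its parameters are illustrative names, not from this file:

    import time
    import torch

    def run_ctc(processor, model, wav, device):
        # Generic CTC-style inference mirroring the simplified runners above.
        start = time.time()
        inputs = processor(wav, sampling_rate=16000, return_tensors="pt").to(device)
        with torch.no_grad():                  # no autograd bookkeeping at inference
            logits = model(**inputs).logits
        ids = torch.argmax(logits, dim=-1)     # greedy decode
        text = processor.batch_decode(ids)[0]
        return text, time.time() - start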