"""Benchmark phoneme-recognition models on several speech datasets.

For each dataset listed in constants.DATASETS, the script loads a capped number
of samples, converts references to IPA phonemes, runs every model registered in
MODEL_RUNNERS, and writes per-model leaderboard JSON files to eval-results/.
"""

import argparse
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor

import pandas as pd
from datasets import Audio, Dataset, load_dataset

from constants import DATASETS, FINAL_SIZE
from utils.audio_process import calculate_error_rate, load_audio
from utils.cmu_process import clean_cmu, cmu_to_ipa, text_to_phoneme
from utils.load_model import (
    run_gruut,
    run_hubert_base,
    run_model,
    run_timit,
    run_wavlm_large_phoneme,
    run_whisper,
)

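# For reference, each entry in constants.DATASETS is assumed to be a dict whose keys
# are the ones read below ("name", "split", "field", and optionally "config",
# "max_samples", "use_streaming"). The values here are purely illustrative; the real
# entries live in constants.py.
#
#   {
#       "name": "some-org/some-speech-dataset",   # hypothetical Hub dataset id
#       "config": "clean",                         # optional; forwarded as `name=` to load_dataset
#       "split": "test",
#       "field": "text",                           # or "phonetic" if phonemes are provided
#       "max_samples": 200,
#       "use_streaming": True,
#   }
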
MODEL_RUNNERS = {
    "HuBERT-Base": run_hubert_base,
    "Whisper": run_whisper,
    "HuBERT fine-tuned": run_model,
    "Timit": run_timit,
    "WavLM": run_wavlm_large_phoneme,
    "LJSpeech Gruut": run_gruut,
}

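# Each runner above is called as run_func(wav) in get_output() and is expected to
# return a (predicted_phonemes, inference_duration_seconds) pair; that contract is
# inferred from how the return value is unpacked below, not from the runners' own
# documentation.
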
def set_output(model, pre_pho, ref_pho, duration, per, score):
    """Package one model's result on one example as a flat dict."""
    return {
        "model": model,
        "phonemes": pre_pho,
        "ref_phonemes": ref_pho,
        "duration": duration,
        "PER": per,
        "score": score,
    }

def get_output(model, wav, reference_phoneme):
    """
    Run the given model, compute error rate, and return formatted output.
    """
    if model not in MODEL_RUNNERS:
        raise ValueError(f"Unknown model: {model}")

    run_func = MODEL_RUNNERS[model]
    phonemes, dur = run_func(wav)
    per, score = calculate_error_rate(reference_phoneme, phonemes)

    return set_output(model, phonemes, reference_phoneme, dur, per, score)

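# A single result row therefore looks roughly like the following; the values are
# made up for illustration ("PER" and "score" come from calculate_error_rate,
# "duration" from the model runner):
#
#   {"model": "Whisper", "phonemes": "h ɛ l oʊ", "ref_phonemes": "h ə l oʊ",
#    "duration": 0.42, "PER": 0.25, "score": 0.75}
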
def benchmark_all(example):
    """
    Run all models on a single dataset example in parallel.
    """
    wav = load_audio(example["audio"])
    reference_phoneme = cmu_to_ipa(clean_cmu(example["phonetic"]))

    models = list(MODEL_RUNNERS)

    with ThreadPoolExecutor(max_workers=len(models)) as executor:
        futures = [
            executor.submit(get_output, model, wav, reference_phoneme)
            for model in models
        ]
        results = [future.result() for future in futures]

    return pd.DataFrame(results)

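# Threads (rather than processes) are used above on the assumption that the heavy
# lifting inside each runner happens in native code (e.g. torch inference) that
# releases the GIL; if the runners are pure Python, a ProcessPoolExecutor might be
# a better fit.
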
def benchmark_dataset(dataset):
    """
    Run benchmark_all on each sample and compute average PER and duration per model.
    """
    all_results = []
    for example in dataset:
        df = benchmark_all(example)
        all_results.append(df)

    full_df = pd.concat(all_results, ignore_index=True)

    avg_stats = (
        full_df.groupby("model")[["PER", "duration"]]
        .mean()
        .reset_index()
        .rename(columns={"PER": "Average PER", "duration": "Average Duration (s)"})
    )

    return full_df, avg_stats

def load_dataset_with_limits(dataset_config, max_samples=None, use_streaming=False):
    """
    Load a dataset with optional size limits and streaming.

    Args:
        dataset_config: Dictionary containing dataset configuration
        max_samples: Maximum number of samples to load (None for no limit)
        use_streaming: Whether to use streaming for large datasets

    Returns:
        Dataset object, or None if loading failed
    """
    try:
        load_args = {
            "path": dataset_config["name"],
            "split": dataset_config["split"],
        }
        if "config" in dataset_config:
            load_args["name"] = dataset_config["config"]

        if use_streaming:
            load_args["streaming"] = True
            print(f"Loading {dataset_config['name']} with streaming...")
        else:
            print(f"Loading {dataset_config['name']}...")

        dataset = load_dataset(**load_args)

        if max_samples is not None:
            print(f"Limiting dataset to {max_samples} samples...")
            if use_streaming:
                # Streaming datasets are iterable, so cap them lazily with take().
                dataset = dataset.take(max_samples)
            else:
                dataset = dataset.select(range(min(max_samples, len(dataset))))

        return dataset
    except Exception as e:
        print(f"[warn] skip dataset {dataset_config['name']}: {e}")
        return None

def parse_cli_args():
    """
    Parse and return CLI arguments for the evaluation script.
    """
    parser = argparse.ArgumentParser(description='Phoneme Detection Evaluation')
    parser.add_argument('--max-samples', type=int, default=None,
                        help='Override max_samples for all datasets')
    parser.add_argument('--dataset', type=str, default=None,
                        help='Process only datasets whose name contains this string')
    return parser.parse_args()

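# Example invocation (the script filename is assumed here; adjust to the actual
# entry point):
#
#   python evaluate.py --dataset timit --max-samples 50
#
# --dataset matches by substring against the dataset's Hub name, and --max-samples
# overrides the per-dataset cap from constants.DATASETS.
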
def cast_audio_column_safely(dataset):
    """
    Ensure the dataset's 'audio' column is set to non-decoding Audio.
    """
    try:
        # decode=False keeps raw bytes/paths so load_audio() can handle decoding itself.
        dataset = dataset.cast_column("audio", Audio(decode=False))
    except Exception:
        # Some datasets (or streaming variants) may not support the cast; keep them as-is.
        pass
    return dataset

def prepare_dataset_for_evaluation(dataset, dataset_config, max_samples):
    """
    Normalize, deduplicate, and filter dataset examples for evaluation.
    Handles both streaming and non-streaming datasets.
    Returns a finalized small dataset suitable for benchmarking.
    """
    field = dataset_config["field"]
    use_streaming = dataset_config.get("use_streaming", False)

    if use_streaming:
        print("Processing streaming dataset...")
        valid_samples = []

        # max_samples may be None (no per-dataset cap), so fall back to FINAL_SIZE alone.
        streaming_limit = FINAL_SIZE if max_samples is None else min(max_samples, FINAL_SIZE)

        for example in dataset:
            if field == "text":
                # Plain-text datasets are converted to phonemes on the fly.
                phonetic_text = text_to_phoneme(example[field])
                example = {**example, "phonetic": phonetic_text}
                current_field = "phonetic"
            else:
                current_field = field

            if current_field in example:
                phoneme_tokens = example[current_field].split()
                if len(phoneme_tokens) >= 10:
                    valid_samples.append(example)
                    if len(valid_samples) >= streaming_limit:
                        break

        print(f"Found {len(valid_samples)} valid samples")
        if len(valid_samples) == 0:
            print("No valid samples found, skipping dataset")
            return None

        return Dataset.from_list(valid_samples)

    if field == "text":
        dataset = dataset.map(lambda x: {"phonetic": text_to_phoneme(x[field])})
        field = "phonetic"

    unique_texts = dataset.unique(field)
    print("Unique phonetic strings (", dataset_config["name"], "):", len(unique_texts))

    # Keep only the first occurrence of each phonetic string (filtering on membership
    # in unique_texts alone would keep every row).
    seen = set()

    def keep_first(example):
        value = example[field]
        if value in seen:
            return False
        seen.add(value)
        return True

    dataset_unique = dataset.filter(keep_first)

    def is_valid(example):
        # Require at least 10 phoneme tokens; very short utterances are skipped.
        return len(example[field].split()) >= 10

    dataset_filtered = dataset_unique.filter(is_valid)
    final_size = min(FINAL_SIZE, len(dataset_filtered))
    dataset_final = dataset_filtered.shuffle(seed=42).select(range(final_size))
    return dataset_final

def evaluate_dataset(dataset_final):
    """
    Run benchmarking on a capped subset of the dataset and return both
    the full per-example results and the aggregated stats per model.
    """
    benchmark_size = min(FINAL_SIZE, len(dataset_final))
    return benchmark_dataset(dataset_final.select(range(benchmark_size)))

def update_aggregates(per_model_results, avg_stats, dataset_name):
    """
    Update the aggregate dictionary per model with results from one dataset.
    """
    dataset_key = dataset_name.split("/")[-1]
    for _, row in avg_stats.iterrows():
        model_name = str(row["model"]).replace(" ", "-")
        # Missing averages come back from pandas as NaN rather than None.
        per = float(row["Average PER"]) if pd.notna(row["Average PER"]) else None
        avg_dur = float(row["Average Duration (s)"]) if pd.notna(row["Average Duration (s)"]) else None

        per_model_results.setdefault(model_name, {})[dataset_key] = {
            "per": per,
            "avg_duration": avg_dur,
        }

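# After all datasets are processed, per_model_results is a nested mapping of
# model -> dataset -> metrics; the dataset keys and numbers below are illustrative:
#
#   {
#       "HuBERT-Base": {
#           "timit": {"per": 0.18, "avg_duration": 0.53},
#           "ljspeech": {"per": 0.22, "avg_duration": 0.49},
#       },
#       ...
#   }
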
def save_leaderboard_results(per_model_results, results_dir="eval-results"):
    """
    Persist one JSON file per model for the leaderboard app to consume.
    """
    os.makedirs(results_dir, exist_ok=True)
    timestamp = int(time.time())
    for model_name, task_results in per_model_results.items():
        payload = {
            "config": {
                "model_name": model_name,
                "model_dtype": "float32",
                "model_sha": "",
            },
            "results": task_results,
        }
        out_path = os.path.join(results_dir, f"results_{timestamp}_{model_name}.json")
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        print(f"Saved leaderboard result: {out_path}")

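# Each written file (eval-results/results_<timestamp>_<model>.json) therefore contains
# something like the following; dataset keys and metric values are illustrative:
#
#   {
#       "config": {"model_name": "Whisper", "model_dtype": "float32", "model_sha": ""},
#       "results": {"timit": {"per": 0.21, "avg_duration": 0.8}}
#   }
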
def process_single_dataset(dataset_config, args, per_model_results):
    """
    Load, normalize, and evaluate a single dataset, then update the aggregates.
    """
    if args.dataset and args.dataset not in dataset_config["name"]:
        return

    max_samples = args.max_samples if args.max_samples is not None else dataset_config.get("max_samples")
    use_streaming = dataset_config.get("use_streaming", False)

    dataset = load_dataset_with_limits(
        dataset_config,
        max_samples=max_samples,
        use_streaming=use_streaming,
    )
    if dataset is None:
        return

    dataset = cast_audio_column_safely(dataset)

    dataset_final = prepare_dataset_for_evaluation(dataset, dataset_config, max_samples)
    if dataset_final is None:
        return

    print(dataset_final)
    print("Final size:", len(dataset_final))

    full_results, avg_stats = evaluate_dataset(dataset_final)
    print("Average statistics per model (", dataset_config["name"], "):")
    print(avg_stats)

    update_aggregates(per_model_results, avg_stats, dataset_config["name"])

def main():
    args = parse_cli_args()

    per_model_results = {}
    for dataset_config in DATASETS:
        process_single_dataset(dataset_config, args, per_model_results)

    save_leaderboard_results(per_model_results)


if __name__ == "__main__":
    main()