import json
import os
import time

import pandas as pd
from datasets import load_dataset, Audio

from utils.load_model import run_hubert_base, run_whisper, run_model, run_timit
from utils.audio_process import calculate_error_rate, load_audio
from utils.cmu_process import clean_cmu, cmu_to_ipa


def set_output(model, pre_pho, ref_pho, duration, per, score):
    """Package one model's benchmark result as a flat dict (one DataFrame row)."""
    return {
        "model": model,
        "phonemes": pre_pho,
        "ref_phonemes": ref_pho,
        "duration": duration,
        "PER": per,
        "score": score,
    }


# Map model names to their runner functions
MODEL_RUNNERS = {
    "HuBERT-Base": run_hubert_base,
    "Whisper": run_whisper,
    "HuBERT fine-tuned": run_model,
    "Timit": run_timit,
}
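
# Each runner is expected to accept a decoded waveform and return a
# (predicted_phonemes, inference_duration) pair, as unpacked in get_output() below.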


def get_output(model, wav, reference_phoneme):
    """
    Run the given model, compute error rate, and return formatted output.
    """
    if model not in MODEL_RUNNERS:
        raise ValueError(f"Unknown model: {model}")
    run_func = MODEL_RUNNERS[model]
    phonemes, dur = run_func(wav)
    per, score = calculate_error_rate(reference_phoneme, phonemes)
    return set_output(model, phonemes, reference_phoneme, dur, per, score)
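
# Usage sketch (values are illustrative, not real output):
#   get_output("Whisper", wav, "h ɛ l oʊ")
#   -> {"model": "Whisper", "phonemes": "...", "PER": 0.12, "score": ..., ...}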


def benchmark_all(example):
    """
    Run all models on a single dataset example.
    """
    # Load waveform manually to avoid datasets' torchcodec dependency
    wav = load_audio(example["audio"])
    # Convert the CMU reference transcription to IPA before scoring
    reference_phoneme = cmu_to_ipa(clean_cmu(example["phonetic"]))
    # Run all models
    results = [
        get_output("HuBERT-Base", wav, reference_phoneme),
        get_output("Whisper", wav, reference_phoneme),
        get_output("HuBERT fine-tuned", wav, reference_phoneme),
        get_output("Timit", wav, reference_phoneme),
    ]
    return pd.DataFrame(results)
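
# The returned frame has one row per model with the columns built by
# set_output(): model, phonemes, ref_phonemes, duration, PER, score.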


def benchmark_dataset(dataset):
    """
    Run benchmark_all on each sample and compute average PER and duration per model.
    """
    all_results = [benchmark_all(example) for example in dataset]
    full_df = pd.concat(all_results, ignore_index=True)
    # Compute average PER and duration per model
    avg_stats = (
        full_df.groupby("model")[["PER", "duration"]]
        .mean()
        .reset_index()
        .rename(columns={"PER": "Average PER", "duration": "Average Duration (s)"})
    )
    return full_df, avg_stats
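
# Callers get both the per-example rows (full_df) and the per-model summary
# (avg_stats) with columns: model, Average PER, Average Duration (s).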


DATASET_LIST = [
    "mirfan899/phoneme_asr",
    "mirfan899/kids_phoneme_md",
]


def main():
    field = "phonetic"
    # Collect per-model metrics across datasets
    per_model_results = {}
    for dataset_name in DATASET_LIST:
        try:
            dataset = load_dataset(dataset_name, split="train")
        except Exception as e:
            print(f"[warn] skip dataset {dataset_name}: {e}")
            continue
        try:
            # Keep audio undecoded; benchmark_all decodes it with load_audio
            dataset = dataset.cast_column("audio", Audio(decode=False))
        except Exception:
            pass

        unique_texts = dataset.unique(field)
        print(f"Unique phonetic strings ({dataset_name}): {len(unique_texts)}")

        # Keep only the first example for each distinct phonetic string
        seen = set()

        def is_first_occurrence(example):
            if example[field] in seen:
                return False
            seen.add(example[field])
            return True

        dataset_unique = dataset.filter(is_first_occurrence)

        def is_valid(example):
            # Require at least 10 phoneme tokens in the reference
            return len(example[field].split()) >= 10

        dataset_filtered = dataset_unique.filter(is_valid)
        dataset_final = dataset_filtered.shuffle(seed=42).select(
            range(min(100, len(dataset_filtered)))
        )
        print(dataset_final)
        print("Final size:", len(dataset_final))

        full_results, avg_stats = benchmark_dataset(
            dataset_final.select(range(min(10, len(dataset_final))))
        )
        print(f"Average statistics per model ({dataset_name}):")
        print(avg_stats)

        # Use the repo name (the part after the slash) as the dataset key
        dataset_key = dataset_name.split("/")[-1]
        for _, row in avg_stats.iterrows():
            model_name = str(row["model"]).replace(" ", "-")
            # groupby().mean() produces NaN (not None) for missing values
            per = float(row["Average PER"]) if pd.notna(row["Average PER"]) else None
            avg_dur = (
                float(row["Average Duration (s)"])
                if pd.notna(row["Average Duration (s)"])
                else None
            )
            per_model_results.setdefault(model_name, {})[dataset_key] = {
                "per": per,
                "avg_duration": avg_dur,
            }
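
    # per_model_results maps model name -> {dataset key -> {"per": ..., "avg_duration": ...}}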

    # Save results for leaderboard consumption (one JSON per model)
    # results_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "eval-results")
    results_dir = os.path.join("eval-results")
    os.makedirs(results_dir, exist_ok=True)
    timestamp = int(time.time())
    for model_name, task_results in per_model_results.items():
        payload = {
            "config": {
                "model_name": model_name,
                "model_dtype": "float32",
                "model_sha": "",
            },
            "results": task_results,
        }
        out_path = os.path.join(results_dir, f"results_{timestamp}_{model_name}.json")
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        print(f"Saved leaderboard result: {out_path}")


if __name__ == "__main__":
    main()