Spaces:
Running
Running
from datetime import date

import gradio as gr
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# Device setup: run on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor and model once at start-up so every call to the
# transcription function reuses the same objects.
model_name = "ibm-granite/granite-speech-3.3-8b"
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = processor.tokenizer
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name, device_map=device, torch_dtype=torch.bfloat16
)

# System prompt sent with every request. The date is captured once, when the
# script starts; `date` is `datetime.date` and must be imported at file level.
today_str = date.today().strftime("%B %d, %Y")
system_prompt = (
    "Knowledge Cutoff Date: April 2024.\n"
    f"Today's Date: {today_str}.\n"
    "You are Granite, developed by IBM. You are a helpful AI assistant."
)
def transcribe(audio_file):
    """Transcribe an audio file to text with the loaded speech model.

    Args:
        audio_file: Path of the audio file to transcribe. ``None`` or an
            empty path (nothing uploaded) is accepted.

    Returns:
        The decoded transcription with surrounding whitespace stripped, or
        an empty string when no audio file was given.
    """
    # Nothing to transcribe: return early instead of failing inside
    # torchaudio.load on a missing path.
    if not audio_file:
        return ""

    # load wav file
    wav, sr = torchaudio.load(audio_file, normalize=True)
    if wav.shape[0] != 1 or sr != 16000:
        # resample + convert to mono if needed
        wav = torch.mean(wav, dim=0, keepdim=True)  # mono
        wav = torchaudio.functional.resample(wav, sr, 16000)
        sr = 16000

    # user prompt
    user_prompt = "<|audio|>can you transcribe the speech into a written format?"
    chat = [
        dict(role="system", content=system_prompt),
        dict(role="user", content=user_prompt),
    ]
    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )

    # run model
    model_inputs = processor(
        prompt, wav, sampling_rate=sr, device=device, return_tensors="pt"
    ).to(device)
    model_outputs = model.generate(
        **model_inputs,
        max_new_tokens=200,
        do_sample=False,
        num_beams=1,
    )

    # strip prompt tokens: generate() output starts with the input ids, so
    # keep only what follows them (re-adding the batch dim for batch_decode)
    num_input_tokens = model_inputs["input_ids"].shape[-1]
    new_tokens = model_outputs[0, num_input_tokens:].unsqueeze(0)
    decoded = tokenizer.batch_decode(
        new_tokens, add_special_tokens=False, skip_special_tokens=True
    )
    return decoded[0].strip()
# Gradio UI: an audio input and a transcription textbox, plus a button that
# runs transcribe() on the uploaded file.
# NOTE(review): the original indentation was lost; the nesting below (Row
# holding the two components, button directly under Blocks) is reconstructed
# -- confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## Granite 3.3 Speech-to-Text Demo")

    with gr.Row():
        audio_input = gr.Audio(
            type="filepath",
            label="Upload Audio (16kHz mono preferred)",
        )
        output_text = gr.Textbox(label="Transcription", lines=5)

    transcribe_btn = gr.Button("Transcribe")
    transcribe_btn.click(
        fn=transcribe,
        inputs=audio_input,
        outputs=output_text,
    )

demo.launch()