# Hugging Face Spaces scrape header: the Space page reported "Runtime error" at capture time.
| import os | |
| import torch | |
| from functools import lru_cache | |
| import gradio as gr | |
| from transformers import AutoProcessor | |
| from transformers.models.audio_flamingo.modeling_audio_flamingo import AudioFlamingoForConditionalGeneration | |
| from typing import Optional | |
| from urllib.parse import urlparse | |
| import requests | |
# Hugging Face Hub repository holding the Audio Flamingo 3 checkpoint (model + processor).
MODEL_ID = "nvidia/audio-flamingo-3-hf"
@lru_cache(maxsize=2)
def load_model_and_processor(device: Optional[str] = None):
    """Load the Audio Flamingo 3 processor and model, caching the result.

    The checkpoint is several gigabytes, so reloading it on every inference
    call is prohibitively slow. ``lru_cache`` (imported at the top of the
    file but previously unused) keeps one loaded (processor, model) pair per
    requested device string across calls.

    Args:
        device: Explicit target device ("cpu" or "cuda"); autodetected from
            CUDA availability when None.

    Returns:
        Tuple of (processor, model), with the model in eval mode.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    if device == "cpu":
        model = AudioFlamingoForConditionalGeneration.from_pretrained(MODEL_ID)
        model.to("cpu")
    else:
        # Let accelerate place/shard the weights across available GPUs.
        model = AudioFlamingoForConditionalGeneration.from_pretrained(MODEL_ID, device_map="auto")
    model.eval()
    return processor, model
def download_if_url(path_or_url: str) -> str:
    """Return a local filesystem path for *path_or_url*.

    HTTP(S) URLs are downloaded into /tmp and the local path is returned;
    anything else is assumed to already be a local path and is returned
    unchanged.

    Raises:
        RuntimeError: if the download fails (network error, non-2xx status,
            or the file cannot be written), chained to the original cause.
    """
    if not path_or_url.startswith(("http://", "https://")):
        return path_or_url
    parsed = urlparse(path_or_url)
    # Fall back to a fixed filename when the URL path has no basename (e.g. "https://host/").
    local_name = os.path.join("/tmp", os.path.basename(parsed.path) or "input_audio.wav")
    try:
        r = requests.get(path_or_url, timeout=30)
        r.raise_for_status()
        with open(local_name, "wb") as f:
            f.write(r.content)
    except (requests.RequestException, OSError) as e:
        # Chain the cause so the underlying network/filesystem error is preserved.
        raise RuntimeError(f"Failed to download audio: {e}") from e
    return local_name
def run_inference(audio_path_or_url: str, instruction: str, max_new_tokens: int = 256, do_sample: bool = False, temperature: float = 0.0):
    """Run Audio Flamingo 3 on one audio clip with a text instruction.

    Args:
        audio_path_or_url: Local audio file path or HTTP(S) URL.
        instruction: Text prompt (e.g. "Transcribe the input speech.").
        max_new_tokens: Generation length cap.
        do_sample: Use stochastic decoding instead of greedy.
        temperature: Sampling temperature; only used when ``do_sample``.

    Returns:
        The generated text, or an error message string if the audio could
        not be fetched.
    """
    try:
        audio_local = download_if_url(audio_path_or_url)
    except Exception as e:
        # Surface fetch failures to the UI instead of crashing the handler.
        return f"Error fetching audio: {e}"
    processor, model = load_model_and_processor()
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": instruction},
                {"type": "audio", "path": audio_local},
            ],
        }
    ]
    batch = processor.apply_chat_template(
        conversation, tokenize=True, add_generation_prompt=True, return_dict=True
    ).to(model.device)
    generate_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample:
        # temperature is only meaningful when sampling, and must be > 0 there;
        # passing it alongside greedy decoding triggers transformers warnings.
        generate_kwargs["temperature"] = max(float(temperature), 1e-5)
    with torch.no_grad():
        gen_ids = model.generate(**batch, **generate_kwargs)
    # Decode only the newly generated tokens, not the echoed prompt.
    inp_len = batch["input_ids"].shape[1]
    new_tokens = gen_ids[:, inp_len:]
    texts = processor.batch_decode(new_tokens, skip_special_tokens=True)
    return texts[0] if texts else "(no output)"
def build_ui():
    """Construct the Gradio Blocks UI for the Audio Flamingo 3 demo.

    Returns:
        The (unlaunched) ``gr.Blocks`` application.
    """
    with gr.Blocks(title="Audio Flamingo 3 — Demo") as demo:
        gr.Markdown("# 🎧 Audio Flamingo 3 — Demo\nUpload or link audio, then give an instruction (e.g. 'Transcribe this speech').")
        with gr.Row():
            # Gradio 4.x removed Audio's `source=` kwarg in favor of `sources=[...]`;
            # the old kwarg raises TypeError at startup (the Space's "Runtime error").
            audio_input = gr.Audio(label="Upload audio (or provide URL below)", sources=["upload"], type="filepath")
            url_input = gr.Textbox(label="Or audio URL", placeholder="https://.../file.wav")
        instruction = gr.Textbox(label="Instruction", value="Transcribe the input speech.")
        with gr.Row():
            max_new_tokens = gr.Slider(label="Max new tokens", minimum=16, maximum=1024, step=16, value=256)
            temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
        do_sample = gr.Checkbox(label="Use sampling (stochastic decoding)")
        submit = gr.Button("Submit")
        output_text = gr.Textbox(label="Model output", lines=8)

        def on_submit(audio_path, url, instr, mnt, temp, ds):
            # A non-empty URL takes precedence over an uploaded file.
            source = url.strip() if url and url.strip() else (audio_path or "")
            if not source:
                return "Please upload an audio file or provide a valid URL."
            return run_inference(source, instr, max_new_tokens=int(mnt), do_sample=bool(ds), temperature=float(temp))

        submit.click(on_submit, inputs=[audio_input, url_input, instruction, max_new_tokens, temperature, do_sample], outputs=[output_text])
    return demo
def main():
    """Entry point: build the demo UI and start the Gradio server."""
    build_ui().launch()


if __name__ == "__main__":
    main()