"""Gradio demo for NVIDIA Audio Flamingo 3: audio + text instruction -> text."""

import os
from functools import lru_cache
from typing import Optional
from urllib.parse import urlparse

import gradio as gr
import requests
import torch
from transformers import AutoProcessor
from transformers.models.audio_flamingo.modeling_audio_flamingo import (
    AudioFlamingoForConditionalGeneration,
)

MODEL_ID = "nvidia/audio-flamingo-3-hf"


@lru_cache(maxsize=1)
def load_model_and_processor(device: Optional[str] = None):
    """Load and memoize the processor and model (one copy per process).

    Args:
        device: "cuda" or "cpu"; auto-detected from CUDA availability when None.

    Returns:
        Tuple of (processor, model) with the model in eval mode.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    if device == "cpu":
        model = AudioFlamingoForConditionalGeneration.from_pretrained(MODEL_ID)
        model.to("cpu")
    else:
        # device_map="auto" lets accelerate place/shard the model on GPU(s).
        model = AudioFlamingoForConditionalGeneration.from_pretrained(
            MODEL_ID, device_map="auto"
        )
    model.eval()
    return processor, model


def download_if_url(path_or_url: str) -> str:
    """Return a local file path for *path_or_url*, downloading it if it is a URL.

    Non-URL inputs are returned unchanged. Downloads are streamed to /tmp so
    large audio files are never held fully in memory.

    Raises:
        RuntimeError: if the HTTP download fails (cause is chained).
    """
    if path_or_url.startswith(("http://", "https://")):
        parsed = urlparse(path_or_url)
        local_name = os.path.join(
            "/tmp", os.path.basename(parsed.path) or "input_audio.wav"
        )
        try:
            with requests.get(path_or_url, timeout=30, stream=True) as r:
                r.raise_for_status()
                with open(local_name, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1 << 16):
                        f.write(chunk)
        except requests.RequestException as e:
            # Chain the cause so the underlying HTTP error stays visible.
            raise RuntimeError(f"Failed to download audio: {e}") from e
        return local_name
    return path_or_url


def run_inference(
    audio_path_or_url: str,
    instruction: str,
    max_new_tokens: int = 256,
    do_sample: bool = False,
    temperature: float = 0.0,
):
    """Run the model on one audio input plus a text instruction.

    Args:
        audio_path_or_url: Local file path or http(s) URL of the audio.
        instruction: User instruction, e.g. "Transcribe the input speech."
        max_new_tokens: Generation budget for new tokens.
        do_sample: Use stochastic decoding instead of greedy.
        temperature: Sampling temperature; only forwarded when sampling.

    Returns:
        The decoded model output, or a human-readable error string if the
        audio could not be fetched (kept as a string for the UI callback).
    """
    try:
        audio_local = download_if_url(audio_path_or_url)
    except Exception as e:
        # Surface fetch errors to the UI instead of crashing the callback.
        return f"Error fetching audio: {e}"

    processor, model = load_model_and_processor()
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": instruction},
                {"type": "audio", "path": audio_local},
            ],
        }
    ]
    batch = processor.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
    ).to(model.device)

    generate_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    # Only forward temperature when actually sampling with a positive value:
    # temperature=0.0 is invalid for sampling, and passing temperature with
    # greedy decoding triggers transformers generation-config warnings.
    if do_sample and temperature > 0:
        generate_kwargs["temperature"] = float(temperature)

    with torch.no_grad():
        gen_ids = model.generate(**batch, **generate_kwargs)

    # Decode only the newly generated tokens, not the echoed prompt.
    inp_len = batch["input_ids"].shape[1]
    new_tokens = gen_ids[:, inp_len:]
    texts = processor.batch_decode(new_tokens, skip_special_tokens=True)
    return texts[0] if texts else "(no output)"


def build_ui():
    """Construct and return the Gradio Blocks interface."""
    with gr.Blocks(title="Audio Flamingo 3 — Demo") as demo:
        gr.Markdown(
            "# 🎧 Audio Flamingo 3 — Demo\n"
            "Upload or link audio, then give an instruction "
            "(e.g. 'Transcribe this speech')."
        )
        with gr.Row():
            # Gradio 4 renamed the `source=` kwarg to `sources=[...]`;
            # the old spelling raises TypeError on Gradio >= 4.
            audio_input = gr.Audio(
                label="Upload audio (or provide URL below)",
                sources=["upload"],
                type="filepath",
            )
            url_input = gr.Textbox(
                label="Or audio URL", placeholder="https://.../file.wav"
            )
        instruction = gr.Textbox(
            label="Instruction", value="Transcribe the input speech."
        )
        with gr.Row():
            max_new_tokens = gr.Slider(
                label="Max new tokens", minimum=16, maximum=1024, step=16, value=256
            )
            temperature = gr.Slider(
                label="Temperature", minimum=0.0, maximum=1.0, step=0.01, value=0.0
            )
        do_sample = gr.Checkbox(label="Use sampling (stochastic decoding)")
        submit = gr.Button("Submit")
        output_text = gr.Textbox(label="Model output", lines=8)

        def on_submit(audio_path, url, instr, mnt, temp, ds):
            # An explicit URL takes precedence over the uploaded file.
            source = url.strip() if url and url.strip() else (audio_path or "")
            if not source:
                return "Please upload an audio file or provide a valid URL."
            return run_inference(
                source,
                instr,
                max_new_tokens=int(mnt),
                do_sample=bool(ds),
                temperature=float(temp),
            )

        submit.click(
            on_submit,
            inputs=[
                audio_input,
                url_input,
                instruction,
                max_new_tokens,
                temperature,
                do_sample,
            ],
            outputs=[output_text],
        )
    return demo


def main():
    """Entry point: build the UI and launch the local Gradio server."""
    demo = build_ui()
    demo.launch()


if __name__ == "__main__":
    main()