# Hugging Face file-viewer residue (uploader: PatoFlamejanteTV, commit
# 5c9ef1d, "Update app.py") — kept as a comment so the file parses.
import os
import torch
from functools import lru_cache
import gradio as gr
from transformers import AutoProcessor
from transformers.models.audio_flamingo.modeling_audio_flamingo import AudioFlamingoForConditionalGeneration
from typing import Optional
from urllib.parse import urlparse
import requests
# Hugging Face Hub model id used throughout the demo.
MODEL_ID = "nvidia/audio-flamingo-3-hf"
@lru_cache(maxsize=1)
def load_model_and_processor(device: Optional[str] = None):
    """Load the Audio Flamingo 3 processor and model (cached after first call).

    Args:
        device: target device name; when ``None``, picks ``"cuda"`` if a GPU
            is available, otherwise ``"cpu"``.

    Returns:
        A ``(processor, model)`` tuple with the model set to eval mode.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    if device == "cpu":
        # Plain single-device load, then move explicitly to CPU.
        model = AudioFlamingoForConditionalGeneration.from_pretrained(MODEL_ID)
        model.to("cpu")
    else:
        # device_map="auto" lets transformers/accelerate place the weights.
        model = AudioFlamingoForConditionalGeneration.from_pretrained(
            MODEL_ID, device_map="auto"
        )
    model.eval()
    return processor, model
def download_if_url(path_or_url: str) -> str:
    """Return a local file path for *path_or_url*.

    If the argument is an http(s) URL, download it to /tmp and return the
    local path; any other value is returned unchanged.

    Args:
        path_or_url: local file path or http(s) URL.

    Returns:
        A local filesystem path to the audio.

    Raises:
        RuntimeError: if the download or the local write fails.
    """
    if not path_or_url.startswith(("http://", "https://")):
        return path_or_url
    parsed = urlparse(path_or_url)
    # Fall back to a generic name when the URL path has no basename.
    local_name = os.path.join(
        "/tmp", os.path.basename(parsed.path) or "input_audio.wav"
    )
    try:
        # Stream to disk so large audio files are not buffered in memory.
        with requests.get(path_or_url, timeout=30, stream=True) as r:
            r.raise_for_status()
            with open(local_name, "wb") as f:
                for chunk in r.iter_content(chunk_size=1 << 16):
                    f.write(chunk)
    except (requests.RequestException, OSError) as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Failed to download audio: {e}") from e
    return local_name
def run_inference(audio_path_or_url: str, instruction: str, max_new_tokens: int = 256, do_sample: bool = False, temperature: float = 0.0):
    """Run Audio Flamingo 3 on a single audio input.

    Args:
        audio_path_or_url: local file path or http(s) URL of the audio.
        instruction: text instruction for the model (e.g. "Transcribe...").
        max_new_tokens: generation budget for newly produced tokens.
        do_sample: enable stochastic decoding.
        temperature: sampling temperature; only forwarded to ``generate``
            when ``do_sample`` is True and the value is positive.

    Returns:
        The decoded model output string, or an error-message string when the
        audio cannot be fetched.
    """
    try:
        audio_local = download_if_url(audio_path_or_url)
    except Exception as e:
        # Surface fetch failures to the UI instead of crashing the app.
        return f"Error fetching audio: {e}"
    processor, model = load_model_and_processor()
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": instruction},
                {"type": "audio", "path": audio_local},
            ],
        }
    ]
    batch = processor.apply_chat_template(
        conversation, tokenize=True, add_generation_prompt=True, return_dict=True
    ).to(model.device)
    generate_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    # Forward temperature only when it is actually used: transformers warns
    # when temperature accompanies greedy decoding, and temperature=0.0 with
    # do_sample=True is invalid (the logits would be divided by zero).
    if do_sample and temperature > 0:
        generate_kwargs["temperature"] = float(temperature)
    with torch.no_grad():
        gen_ids = model.generate(**batch, **generate_kwargs)
    # Strip the prompt tokens so only the newly generated text is decoded.
    inp_len = batch["input_ids"].shape[1]
    new_tokens = gen_ids[:, inp_len:]
    texts = processor.batch_decode(new_tokens, skip_special_tokens=True)
    return texts[0] if texts else "(no output)"
def build_ui():
    """Assemble and return the Gradio Blocks interface for the demo."""

    def _handle(audio_path, url, instr, mnt, temp, ds):
        # An explicitly provided URL takes priority over the uploaded file.
        chosen = (url or "").strip() or (audio_path or "")
        if not chosen:
            return "Please upload an audio file or provide a valid URL."
        return run_inference(
            chosen,
            instr,
            max_new_tokens=int(mnt),
            do_sample=bool(ds),
            temperature=float(temp),
        )

    with gr.Blocks(title="Audio Flamingo 3 — Demo") as demo:
        gr.Markdown("# 🎧 Audio Flamingo 3 — Demo\nUpload or link audio, then give an instruction (e.g. 'Transcribe this speech').")
        with gr.Row():
            # NOTE(review): Gradio 4.x renamed `source=` to `sources=[...]` on
            # gr.Audio — confirm the installed Gradio version accepts this.
            audio_input = gr.Audio(label="Upload audio (or provide URL below)", source="upload", type="filepath")
            url_input = gr.Textbox(label="Or audio URL", placeholder="https://.../file.wav")
        instruction = gr.Textbox(label="Instruction", value="Transcribe the input speech.")
        with gr.Row():
            max_new_tokens = gr.Slider(label="Max new tokens", minimum=16, maximum=1024, step=16, value=256)
            temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
        do_sample = gr.Checkbox(label="Use sampling (stochastic decoding)")
        submit = gr.Button("Submit")
        output_text = gr.Textbox(label="Model output", lines=8)
        submit.click(
            _handle,
            inputs=[audio_input, url_input, instruction, max_new_tokens, temperature, do_sample],
            outputs=[output_text],
        )
    return demo
def main():
    """Entry point: build the demo UI and start the Gradio server."""
    build_ui().launch()


if __name__ == "__main__":
    main()