"""Gradio app that expands short AI-video prompts into detailed ones.

Best-practice rules are read from ``wan_instruction.md`` (next to this file),
prepended to the user's prompt, and sent to DeepSeek-R1-Distill-Qwen-14B.
"""

import logging
import os

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)

# Markers delimiting the instruction section inside wan_instruction.md.
_RECIPE_MARKER = "## Prompt Recipe"
_BANK_MARKER = "## Prompt Bank"
# Number of characters of the markdown used when the markers are not found.
_FALLBACK_CHARS = 2000

# Task description appended after the extracted best practices; it ends with
# the "用户输入:" ("user input:") label so the user's prompt can follow it.
# NOTE(review): "不要输出中文解释" ("do not output Chinese explanations") may have
# been meant as "do not output any explanation"; for Chinese input it can read
# as conflicting with the "match the input language" rule — confirm intent.
_TASK_SUFFIX = (
    "\n\n请你作为AI视频提示词增强助手,"
    "根据上述最佳实践,扩展和丰富用户输入的简短视频提示词,使其更详细、更具表现力、更符合AI视频生成的需求。"
    "输出仅为提示词,内容丰富、细致、流畅,适合直接用于AI视频生成。不要输出中文解释。"
    "如用户输入已包含部分要素,可自动补全缺失部分。"
    "输出语言和用户输入的语言保持一致。"
    "\n\n用户输入:"
)


def extract_instruction(md_path):
    """Build the prompt-enhancement instruction from the markdown at *md_path*.

    The section from the first line containing ``## Prompt Recipe`` up to
    (but excluding) the first *later* line containing ``## Prompt Bank`` is
    used. If either marker is missing, the first 2000 characters of the file
    are used instead. If the file does not exist, a warning is logged and the
    best-practice part is empty. In every case the fixed task description is
    appended, so the model is always told what to do.

    Args:
        md_path: Path to the markdown instruction file.

    Returns:
        The instruction string; the caller appends the user's prompt to it.
    """
    if not os.path.exists(md_path):
        # Keep the task description even without the guide; returning "" would
        # send the model the bare user text with no instructions at all.
        logger.warning("Instruction file not found: %s", md_path)
        return _TASK_SUFFIX

    with open(md_path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    start = next(
        (idx for idx, line in enumerate(lines) if _RECIPE_MARKER in line), None
    )
    end = None
    if start is not None:
        # Only accept an end marker that comes after the start marker.
        end = next(
            (idx for idx in range(start + 1, len(lines)) if _BANK_MARKER in lines[idx]),
            None,
        )

    if start is not None and end is not None:
        instruction = "".join(lines[start:end])
    else:
        # Fallback: the first 2000 characters (not lines) of the file.
        instruction = "".join(lines)[:_FALLBACK_CHARS]

    return instruction + _TASK_SUFFIX


# Adjust this path to match the actual deployment layout.
WAN_INSTRUCTION_MD = os.path.join(os.path.dirname(__file__), "wan_instruction.md")
WAN_INSTRUCTION = extract_instruction(WAN_INSTRUCTION_MD)

# Load the DeepSeek-R1-Distill-Qwen-14B model and tokenizer once at import time.
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

# R1-style models emit their reasoning first and close it with this tag; the
# actual answer follows it.
_THINK_START = "<think>"
_THINK_END = "</think>"
# Budget for the reasoning section plus the final prompt. The previous 256
# tokens are likely to be exhausted by reasoning alone before any answer.
MAX_NEW_TOKENS = 1024


def _strip_reasoning(text):
    """Return the answer part of a generation, without the reasoning section.

    Takes everything after the last ``</think>`` tag. If the tag is absent
    (e.g. generation stopped mid-reasoning), the whole text is returned with
    a leading ``<think>`` tag and surrounding whitespace removed.
    """
    _, sep, answer = text.rpartition(_THINK_END)
    if sep:
        return answer.strip()
    text = text.strip()
    if text.startswith(_THINK_START):
        text = text[len(_THINK_START):]
    return text.strip()


# Longer GPU window because MAX_NEW_TOKENS was raised to leave room for reasoning.
@spaces.GPU(duration=120)
def enhance_prompt(user_prompt):
    """Expand a short video prompt into a detailed AI-video-generation prompt.

    Args:
        user_prompt: The short prompt typed by the user. ``None`` or
            whitespace-only input returns ``""`` without running the model.

    Returns:
        The enhanced prompt, with any ``<think>`` reasoning section removed.
    """
    if not user_prompt or not user_prompt.strip():
        return ""

    full_prompt = WAN_INSTRUCTION + user_prompt.strip() + "\n"
    # DeepSeek-R1 distills are chat models: send the request as a single user
    # turn through the model's own chat template (no system prompt) instead of
    # raw text continuation.
    chat_text = tokenizer.apply_chat_template(
        [{"role": "user", "content": full_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # The chat template is expected to insert BOS itself, so don't add it twice.
    inputs = tokenizer(
        chat_text, return_tensors="pt", add_special_tokens=False
    ).to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    generated = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    return _strip_reasoning(generated)


with gr.Blocks(title="AI视频提示词增强器") as demo:
    gr.Markdown(
        """
# AI视频提示词增强器
输入简短的视频提示词,自动智能扩写为详细、专业、适合AI视频生成的提示词。
- 基于 deepseek-r1 14b 大模型
- 增强规则自动读取自 wan_instruction.md(从[官方文档](https://alidocs.dingtalk.com/i/nodes/EpGBa2Lm8aZxe5myC99MelA2WgN7R35y?utm_scene=team_space)提取)
"""
    )
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(
                label="输入你的视频提示词",
                lines=4,
                placeholder="如:A girl dancing on the stage",
            )
            submit_btn = gr.Button("一键增强")
        with gr.Column():
            enhanced_output = gr.Textbox(label="增强后的提示词", lines=10)
    submit_btn.click(enhance_prompt, inputs=user_input, outputs=enhanced_output)

if __name__ == "__main__":
    demo.launch()