# Captured from a Hugging Face Spaces page that showed this Space's status as: Runtime error
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| import os | |
| import spaces | |
# Read the enhancement instructions (extracted from wan_instruction.md).
def extract_instruction(md_path, max_fallback_chars=2000):
    """Build the prompt-enhancement instruction text from wan_instruction.md.

    The guide section runs from the first line containing ``## Prompt Recipe``
    up to (not including) the first *later* line containing ``## Prompt Bank``,
    or to the end of the file when that end marker is absent. When the start
    marker is absent, the first ``max_fallback_chars`` characters of the file
    are used instead. A fixed task description (in Chinese, ending with a
    "user input:" lead-in) is always appended, also when the file does not
    exist, so the model never receives a bare user prompt.

    Args:
        md_path: Path of the markdown guide; read as UTF-8.
        max_fallback_chars: How many leading characters of the file to use
            when the ``## Prompt Recipe`` marker is not found.

    Returns:
        The instruction text; the caller in this file appends the user's
        prompt directly after it.
    """
    guide = ""
    if os.path.exists(md_path):
        with open(md_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        start = next(
            (i for i, line in enumerate(lines) if "## Prompt Recipe" in line),
            None,
        )
        if start is None:
            # Fallback: a bounded prefix of the file. Slice the joined text,
            # not the list of lines, so the limit really is in characters.
            guide = "".join(lines)[:max_fallback_chars]
        else:
            # Look for the end marker only after the start marker, so an
            # earlier "## Prompt Bank" (e.g. in a table of contents) cannot
            # discard the section; without an end marker, read to EOF.
            end = next(
                (
                    i
                    for i in range(start + 1, len(lines))
                    if "## Prompt Bank" in lines[i]
                ),
                len(lines),
            )
            guide = "".join(lines[start:end])
    # Task description appended after the guide (even when the guide is empty).
    # NOTE(review): the text forbids "Chinese explanations" while also asking to
    # match the user's input language; if the intent is "no explanations at
    # all", consider rewording.
    return guide + (
        "\n\n请你作为AI视频提示词增强助手,"
        "根据上述最佳实践,扩展和丰富用户输入的简短视频提示词,使其更详细、更具表现力、更符合AI视频生成的需求。"
        "输出仅为提示词,内容丰富、细致、流畅,适合直接用于AI视频生成。不要输出中文解释。"
        "如用户输入已包含部分要素,可自动补全缺失部分。"
        "输出语言和用户输入的语言保持一致。"
        "\n\n用户输入:"
    )
# Location of the Wan prompt guide, resolved relative to this script's directory.
# Adjust the path to match the actual deployment layout.
WAN_INSTRUCTION_MD = os.path.join(
    os.path.dirname(__file__),
    "wan_instruction.md",
)
# Instruction prefix built once at import time and reused for every request.
WAN_INSTRUCTION = extract_instruction(WAN_INSTRUCTION_MD)
# Load the DeepSeek-R1-Distill-Qwen-14B model and its tokenizer once, at import time.
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
# trust_remote_code is deliberately not enabled: it permits executing Python code
# downloaded from the model repository. NOTE(review): this assumes the checkpoint
# uses an architecture built into transformers (Qwen2, per the model name) —
# verify its config.json has no `auto_map` entry.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    # NOTE(review): float16 halves memory vs float32; if the checkpoint is stored
    # in bfloat16, consider matching it — confirm output quality in float16.
    torch_dtype=torch.float16,
    # Let transformers/accelerate choose device placement for the weights.
    device_map="auto",
)
def _strip_reasoning(text):
    """Return the final answer from a DeepSeek-R1 style completion.

    R1-style models write their reasoning before a closing ``</think>`` tag;
    everything up to and including the last such tag is dropped. If no closing
    tag is present (e.g. generation stopped at ``max_new_tokens`` while still
    reasoning), a leading ``<think>`` tag is removed and the remaining text is
    returned as a best effort.
    """
    _, sep, answer = text.rpartition("</think>")
    if sep:
        return answer.strip()
    answer = text.strip()
    if answer.startswith("<think>"):
        answer = answer[len("<think>"):]
    return answer.strip()


# On ZeroGPU hardware a GPU is only attached while a @spaces.GPU function runs;
# the decorator has no effect on other hardware.
@spaces.GPU(duration=120)
def enhance_prompt(user_prompt, max_new_tokens=1024, temperature=0.7, top_p=0.95):
    """Expand a short video prompt into a detailed one using the loaded model.

    The enhancement instruction (``WAN_INSTRUCTION``) and the stripped user
    prompt are sent as a single user turn formatted with the tokenizer's chat
    template. Output is sampled, and the model's reasoning section is removed
    before returning.

    Args:
        user_prompt: The short prompt typed by the user; may be empty or None.
        max_new_tokens: Generation budget. Reasoning models spend part of it
            on their reasoning before the answer, so keep it generous.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.

    Returns:
        The enhanced prompt text, or "" for empty input.
    """
    if not user_prompt or not user_prompt.strip():
        # Nothing to enhance; avoid a costly generation call.
        return ""
    messages = [{"role": "user", "content": WAN_INSTRUCTION + user_prompt.strip()}]
    chat_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The rendered chat template already carries any special tokens it needs,
    # so don't let the tokenizer add another BOS.
    inputs = tokenizer(chat_text, return_tensors="pt", add_special_tokens=False).to(
        model.device
    )
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        output = model.generate(
            **inputs,  # includes attention_mask, needed since pad == eos below
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    completion = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    return _strip_reasoning(completion)
# Gradio UI: a prompt box and button on the left, the enhanced result on the right.
with gr.Blocks(title="AI视频提示词增强器") as demo:
    # Page header and short description.
    gr.Markdown(
        """
# AI视频提示词增强器
输入简短的视频提示词,自动智能扩写为详细、专业、适合AI视频生成的提示词。
- 基于 deepseek-r1 14b 大模型
- 增强规则自动读取自 wan_instruction.md(从[官方文档](https://alidocs.dingtalk.com/i/nodes/EpGBa2Lm8aZxe5myC99MelA2WgN7R35y?utm_scene=team_space)提取)
"""
    )
    with gr.Row():
        # Input column: the short prompt and the trigger button.
        with gr.Column():
            user_input = gr.Textbox(
                label="输入你的视频提示词",
                lines=4,
                placeholder="如:A girl dancing on the stage",
            )
            submit_btn = gr.Button("一键增强")
        # Output column: the expanded prompt.
        with gr.Column():
            enhanced_output = gr.Textbox(label="增强后的提示词", lines=10)
    # Clicking the button runs enhance_prompt on the input box's text.
    submit_btn.click(
        fn=enhance_prompt,
        inputs=user_input,
        outputs=enhanced_output,
    )


if __name__ == "__main__":
    demo.launch()