Spaces:
Sleeping
Sleeping
| # prompt_rewrite.py | |
| import os | |
| from huggingface_hub import InferenceClient | |
| # ------------------------- | |
| # Utils | |
| # ------------------------- | |
def get_caption_language(prompt: str) -> str:
    """Guess which system-prompt language to use for *prompt*.

    Returns ``'zh'`` as soon as any character lies in the CJK Unified
    Ideographs block (U+4E00..U+9FFF); otherwise returns ``'en'``. An empty
    string therefore yields ``'en'``.
    """
    cjk_first, cjk_last = '\u4e00', '\u9fff'  # CJK Unified Ideographs
    has_ideograph = any(cjk_first <= ch <= cjk_last for ch in prompt)
    return 'zh' if has_ideograph else 'en'
def _get_client():
    """Create an ``InferenceClient`` that routes requests via the Cerebras provider.

    The API token is taken from the ``HF_TOKEN`` environment variable. When
    that is unset or empty, a variable named ``hf`` is tried instead.

    Raises:
        EnvironmentError: if neither variable holds a non-empty value.
    """
    for var_name in ("HF_TOKEN", "hf"):
        token = os.environ.get(var_name)
        if token:
            break
    else:
        raise EnvironmentError("HF_TOKEN is not set.")
    return InferenceClient(provider="cerebras", api_key=token)
| # ------------------------- | |
| # Core engine | |
| # ------------------------- | |
def polish_prompt(
    original_prompt: str,
    system_prompt: str,
    *,
    model: str = "Qwen/Qwen3-235B-A22B-Instruct-2507",
) -> str:
    """Rewrite *original_prompt* with a chat model guided by *system_prompt*.

    The request is sent through the client returned by ``_get_client()``. The
    model's reply is stripped and flattened to a single line: every line
    break, including ``\\r\\n`` and ``\\r``, becomes one space.

    This is best-effort. If the request fails, or the reply contains no usable
    text, an error line is printed and *original_prompt* is returned unchanged.

    Args:
        original_prompt: The prompt to rewrite, sent as the ``user`` message.
        system_prompt: Instructions sent as the ``system`` message.
        model: Model id passed to ``chat.completions.create``. Defaults to the
            id this module has always used.

    Returns:
        The rewritten single-line prompt, or *original_prompt* on failure.

    Raises:
        EnvironmentError: propagated from ``_get_client()`` when no API token
            is configured. Client creation happens outside the fallback
            ``try``, so a missing token is not masked.
    """
    client = _get_client()
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": original_prompt},
    ]
    try:
        completion = client.chat.completions.create(
            model=model,
            messages=messages,
        )
        content = completion.choices[0].message.content
    except Exception as e:  # best-effort boundary: any API/transport error falls back
        print(f"[prompt_rewrite] Error: {e}")
        return original_prompt

    # The reply's content may be None or blank. Returning "" would silently
    # erase the user's prompt, so fall back to the original instead.
    rewritten = (content or "").strip()
    if not rewritten:
        print("[prompt_rewrite] Error: empty completion, keeping original prompt")
        return original_prompt
    # splitlines() also breaks on "\r\n" and "\r", which replace("\n", " ")
    # left in place.
    return " ".join(rewritten.splitlines())
| # ------------------------- | |
| # System prompts | |
| # ------------------------- | |
# System prompt sent by polish_prompt_en(), i.e. for prompts that
# get_caption_language() classifies as 'en'.
# NOTE(review): this text is still a template stub. Its last line is Spanish
# for "paste your English prompt here verbatim", so the model currently
# receives only this placeholder as its instructions. Replace it with the
# real prompt.
SYSTEM_PROMPT_EN = """\
# Image Prompt Rewriting Expert
You are a world-class expert in crafting image prompts...
(PEGÁS ACÁ TU PROMPT EN INGLÉS TAL CUAL)
"""
# System prompt sent by polish_prompt_zh(), i.e. for prompts that
# get_caption_language() classifies as 'zh'. The heading and first line read
# "Image Prompt Rewriting Expert" / "You are a world-class image prompt
# construction expert...".
# NOTE(review): this text is still a template stub. Its last line is Spanish
# for "paste your Chinese prompt here verbatim", so the model currently
# receives only this placeholder as its instructions. Replace it with the
# real prompt.
SYSTEM_PROMPT_ZH = """\
# 图像 Prompt 改写专家
你是一位世界顶级的图像 Prompt 构建专家...
(PEGÁS ACÁ TU PROMPT EN CHINO TAL CUAL)
"""
| # ------------------------- | |
| # Public API | |
| # ------------------------- | |
def polish_prompt_en(original_prompt: str) -> str:
    """Rewrite *original_prompt* using the English system prompt.

    Leading and trailing whitespace is trimmed before the request is made.
    """
    trimmed = original_prompt.strip()
    return polish_prompt(original_prompt=trimmed, system_prompt=SYSTEM_PROMPT_EN)
def polish_prompt_zh(original_prompt: str) -> str:
    """Rewrite *original_prompt* using the Chinese system prompt.

    Leading and trailing whitespace is trimmed before the request is made.
    """
    trimmed = original_prompt.strip()
    return polish_prompt(original_prompt=trimmed, system_prompt=SYSTEM_PROMPT_ZH)
def rewrite(prompt: str) -> str:
    """Rewrite *prompt* with the system prompt matching its detected language.

    ``get_caption_language`` decides the route: ``'zh'`` goes to
    ``polish_prompt_zh``, anything else to ``polish_prompt_en``. The prompt is
    passed through untrimmed; the chosen helper does the trimming.
    """
    is_chinese = get_caption_language(prompt) == "zh"
    polisher = polish_prompt_zh if is_chinese else polish_prompt_en
    return polisher(prompt)