| |
| """ |
Main application file for SHASHA AI (Gradio).

Latest change: the header logo is enlarged to 120 x 120 px.
| """ |
|
|
| import gradio as gr |
| from typing import Optional, Dict, List, Tuple, Any |
| import os |
|
|
| |
| from constants import ( |
| HTML_SYSTEM_PROMPT, |
| TRANSFORMERS_JS_SYSTEM_PROMPT, |
| AVAILABLE_MODELS, |
| DEMO_LIST, |
| ) |
| from hf_client import get_inference_client |
| from tavily_search import enhance_query_with_search |
| from utils import ( |
| extract_text_from_file, |
| extract_website_content, |
| apply_search_replace_changes, |
| history_to_messages, |
| history_to_chatbot_messages, |
| remove_code_block, |
| parse_transformers_js_output, |
| format_transformers_js_output, |
| ) |
| from deploy import send_to_sandbox |
|
|
# Conversation history: a list of (user prompt, generated output) pairs.
History = List[Tuple[str, str]]
# A single model entry, as stored in AVAILABLE_MODELS.
Model = Dict[str, Any]


# Choices for the "Target Language" dropdown. These appear to mirror the
# languages gr.Code can highlight (NOTE(review): confirm against the installed
# Gradio version). generation_code also special-cases "transformers.js",
# which is not offered here.
SUPPORTED_LANGUAGES = [
    "python", "c", "cpp", "markdown", "latex", "json", "html", "css",
    "javascript", "jinja2", "typescript", "yaml", "dockerfile", "shell", "r",
    "sql",
] + [
    # SQL dialect variants, in the same order as the generic "sql" entry's siblings.
    f"sql-{dialect}"
    for dialect in (
        "msSQL", "mySQL", "mariaDB", "sqlite", "cassandra", "plSQL",
        "hive", "pgSQL", "gql", "gpSQL", "sparkSQL", "esper",
    )
]
|
|
def get_model_details(name: str) -> Optional[Model]:
    """Look up a model entry by its display name.

    Args:
        name: Value to match against each entry's ``"name"`` key.

    Returns:
        The first dict in AVAILABLE_MODELS whose ``"name"`` equals *name*,
        or ``None`` when no entry matches.
    """
    for candidate in AVAILABLE_MODELS:
        if candidate["name"] == name:
            return candidate
    return None
|
|
# Leading marker of error entries stored in the history; used both to build
# the error text and to recognise it on the next turn.
_ERROR_MARKER = "β"


def _system_prompt_for(language: str) -> str:
    """Return the system prompt that steers the model toward *language*."""
    if language == "html":
        return HTML_SYSTEM_PROMPT
    if language == "transformers.js":
        return TRANSFORMERS_JS_SYSTEM_PROMPT
    return f"You are an expert {language} developer. Write clean, idiomatic {language} code."


def _provider_for(model_id: str) -> str:
    """Pick the provider name passed to get_inference_client for *model_id*."""
    if model_id.startswith("openai/") or model_id in {"gpt-4", "gpt-3.5-turbo"}:
        return "openai"
    if model_id.startswith(("gemini/", "google/")):
        return "gemini"
    if model_id.startswith("fireworks-ai/"):
        return "fireworks-ai"
    return "auto"


def _build_context(query: str, file: Optional[str], website_url: Optional[str]) -> str:
    """Append truncated attachment / website text to the user's query."""
    context = query
    if file:
        # Only the first 5000 characters of the extracted file text are sent.
        context += f"\n\n[Attached file]\n{extract_text_from_file(file)[:5000]}"
    if website_url:
        website_text = extract_website_content(website_url)
        # Text beginning with "Error" is treated as a failed fetch and left out;
        # otherwise only the first 8000 characters are sent.
        if not website_text.startswith("Error"):
            context += f"\n\n[Website content]\n{website_text[:8000]}"
    return context


def generation_code(
    query: Optional[str],
    file: Optional[str],
    website_url: Optional[str],
    current_model: Model,
    enable_search: bool,
    language: str,
    history: Optional[History],
) -> Tuple[str, History, str, List[Dict[str, str]]]:
    """Generate (or revise) code for *query* and extend the conversation.

    Args:
        query: The user's request; ``None`` is treated as ``""``.
        file: Optional attachment path whose extracted text is added to the prompt.
        website_url: Optional URL whose extracted text is added to the prompt.
        current_model: Model entry; its ``"id"`` selects the model and provider.
        enable_search: Forwarded to ``enhance_query_with_search``.
        language: Target language. ``"html"`` and ``"transformers.js"`` get
            dedicated system prompts and a sandbox preview.
        history: Previous ``(query, output)`` turns; ``None`` means empty.

    Returns:
        ``(code, new_history, preview, chatbot_messages)``. If anything fails
        (request, empty reply, or post-processing), code and preview are
        ``""`` and an error entry is appended to the returned history instead.
        This function does not append to the caller's ``history`` list.
    """
    query = query or ""
    history = history or []
    try:
        model_id = current_model["id"]
        messages = history_to_messages(history, _system_prompt_for(language))
        context = _build_context(query, file, website_url)
        messages.append(
            {"role": "user", "content": enhance_query_with_search(context, enable_search)}
        )

        client = get_inference_client(model_id, _provider_for(model_id))
        resp = client.chat.completions.create(
            model=model_id, messages=messages, max_tokens=16000, temperature=0.1
        )
        content = resp.choices[0].message.content
        if not content:
            # Chat-completions responses may carry null content; report it here
            # rather than crashing inside the post-processing below.
            raise ValueError("The model returned an empty response.")

        if language == "transformers.js":
            files = parse_transformers_js_output(content)
            code = format_transformers_js_output(files)
            preview = send_to_sandbox(files.get("index.html", ""))
        else:
            cleaned = remove_code_block(content)
            previous = history[-1][1] if history else None
            # Apply search/replace edits on top of the last successful output;
            # error entries are never used as a base.
            if previous is not None and not previous.startswith(_ERROR_MARKER):
                code = apply_search_replace_changes(previous, cleaned)
            else:
                code = cleaned
            preview = send_to_sandbox(code) if language == "html" else ""
    except Exception as e:
        # UI boundary: surface any failure as a chat entry instead of raising.
        err = f"{_ERROR_MARKER} **Error:**\n```\n{e}\n```"
        new_hist = history + [(query, err)]
        return "", new_hist, "", history_to_chatbot_messages(new_hist)

    new_hist = history + [(query, code)]
    return code, new_hist, preview, history_to_chatbot_messages(new_hist)
|
|
| |
# Page stylesheet passed to gr.Blocks(css=...). The #main_title, #subtitle and
# #gen_btn selectors target components created with matching elem_id values.
CUSTOM_CSS = (
    "\n"
    "body{font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,sans-serif;}\n"
    "#main_title{text-align:center;font-size:2.5rem;margin-top:.5rem;}\n"
    "#subtitle{text-align:center;color:#4a5568;margin-bottom:2rem;}\n"
    ".gradio-container{background-color:#f7fafc;}\n"
    "#gen_btn{box-shadow:0 4px 6px rgba(0,0,0,0.1);}\n"
)

# Relative path of the optional header logo; it is shown only if the file exists.
LOGO_PATH = "assets/logo.png"
|
|
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue"),
    css=CUSTOM_CSS,
    title="Shasha AI",
) as demo:
    # Per-session state: conversation history and the selected model entry.
    history_state = gr.State([])
    initial_model = AVAILABLE_MODELS[0]
    model_state = gr.State(initial_model)

    # Header: the logo is rendered only when the asset file is present.
    if os.path.exists(LOGO_PATH):
        gr.Image(value=LOGO_PATH, height=120, width=120,
                 show_label=False, container=False)

    gr.Markdown("# π Shasha AI", elem_id="main_title")
    gr.Markdown("Your AI partner for generating, modifying, and understanding code.", elem_id="subtitle")

    with gr.Row():
        # Left column: model choice, request context and output settings.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Model")
            model_dd = gr.Dropdown([m["name"] for m in AVAILABLE_MODELS],
                                   value=initial_model["name"], label="AI Model")

            gr.Markdown("### 2. Provide Context")
            with gr.Tabs():
                with gr.Tab("π Prompt"):
                    prompt_in = gr.Textbox(lines=7, placeholder="Describe your request...", show_label=False)
                with gr.Tab("π File"):
                    file_in = gr.File(type="filepath")
                with gr.Tab("π Website"):
                    url_in = gr.Textbox(placeholder="https://example.com")

            gr.Markdown("### 3. Configure Output")
            lang_dd = gr.Dropdown(SUPPORTED_LANGUAGES, value="html", label="Target Language")
            search_chk = gr.Checkbox(label="Enable Web Search")
            with gr.Row():
                clr_btn = gr.Button("Clear Session", variant="secondary")
                gen_btn = gr.Button("Generate Code", variant="primary", elem_id="gen_btn")

        # Right column: generated code, rendered preview and chat history.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("π» Code"):
                    code_out = gr.Code(language="html", interactive=True)
                with gr.Tab("ποΈ Live Preview"):
                    preview_out = gr.HTML()
                with gr.Tab("π History"):
                    chat_out = gr.Chatbot(type="messages")

    # Keep model_state in sync with the dropdown; unknown names fall back to
    # the first model.
    model_dd.change(lambda n: get_model_details(n) or initial_model,
                    inputs=[model_dd], outputs=[model_state])

    # Highlight the editor in the selected target language; previously it
    # stayed on "html" whatever was chosen. Assumes every SUPPORTED_LANGUAGES
    # entry is a language gr.Code accepts — NOTE(review): confirm against the
    # installed Gradio version.
    lang_dd.change(lambda lang: gr.update(language=lang),
                   inputs=[lang_dd], outputs=[code_out], queue=False)

    gen_btn.click(
        fn=generation_code,
        inputs=[prompt_in, file_in, url_in, model_state, search_chk, lang_dd, history_state],
        outputs=[code_out, history_state, preview_out, chat_out],
    )

    # Reset inputs, history and outputs; the model and language selections are kept.
    clr_btn.click(
        lambda: ("", None, "", [], "", "", []),
        outputs=[prompt_in, file_in, url_in, history_state, code_out, preview_out, chat_out],
        queue=False,
    )
|
|
if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start serving it.
    app = demo.queue()
    app.launch()
|
|