"""FastAPI server exposing Qwen3.5 multimodal models via an OpenAI-style
/v1/chat/completions endpoint, with optional SSE streaming and on-demand
model switching between a small allow-list of checkpoints."""

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import List, Optional, Union
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
import torch
import threading
import json
import requests
import gc
from PIL import Image
from io import BytesIO

app = FastAPI(title="Qwen3.5 Multimodal API", version="3.0.0")

# --- CORS Middleware ---
# Wide-open CORS: intended for local/demo use, not hardened deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Global Model State ---
# Only one model is resident at a time; requests naming a different model
# trigger a swap (see load_model_global).
ALLOWED_MODELS = [
    "Qwen/Qwen3.5-2B",
    "Qwen/Qwen3.5-0.8B",
]
DEFAULT_MODEL = "Qwen/Qwen3.5-2B"

current_model_id = None  # HF repo id of the currently loaded model, or None
model = None             # loaded AutoModelForCausalLM, or None
processor = None         # matching AutoProcessor, or None

# Serializes model swaps: FastAPI serves requests concurrently, and an
# unguarded clear/load sequence can interleave between two requests.
_model_lock = threading.Lock()


# --- Pydantic Schemas ---
class ImageURL(BaseModel):
    url: str


class ContentPart(BaseModel):
    # OpenAI-style content part: type is "text" or "image_url".
    type: str
    text: Optional[str] = None
    image_url: Optional[ImageURL] = None


class ChatMessage(BaseModel):
    role: str
    # Either a plain string or a list of multimodal content parts.
    content: Union[str, List[ContentPart]]


class GenerationParameters(BaseModel):
    temperature: float = 0.7
    max_new_tokens: int = 256


class ChatCompletionRequest(BaseModel):
    model: Optional[str] = DEFAULT_MODEL
    messages: List[ChatMessage]
    stream: bool = False
    generation_params: GenerationParameters = GenerationParameters()


# --- Helper Functions ---
def download_image(url: str) -> Image.Image:
    """Fetch *url* and return it as an RGB PIL image.

    Raises HTTPException(400) on any network or decode failure so the
    endpoint surfaces a client error rather than a 500.
    """
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not download image: {str(e)}")


def format_sse(token: str, finish_reason: Optional[str] = None) -> str:
    """Wrap one token delta as an OpenAI-style server-sent-event line."""
    payload = {
        "choices": [{
            "delta": {"content": token},
            "finish_reason": finish_reason
        }]
    }
    return f"data: {json.dumps(payload)}\n\n"


# --- Model Management ---
def load_model_global(model_id: str):
    """Ensure *model_id* is the resident model, swapping out any other.

    Thread-safe via _model_lock. On failure the globals are reset to None
    (never left unbound or stale), so callers see a consistent "no model
    loaded" state and a retry with any model id starts clean.
    """
    global model, processor, current_model_id
    with _model_lock:
        if current_model_id == model_id:
            return  # Already loaded

        print(f"Switching to model: {model_id}...")

        # 1. Clear existing model to save RAM. Rebind to None (not `del`):
        # `del` would unbind the globals and a failed reload below would
        # then make every later `if not model` check raise NameError.
        if model is not None:
            model = None
            processor = None
            current_model_id = None
            gc.collect()
            print("Previous model cleared.")

        # 2. Load new model
        try:
            processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                dtype=torch.float16,  # NOTE(review): fp16 on CPU is slow and some ops lack fp16 CPU kernels — confirm intended
                low_cpu_mem_usage=True,
                device_map="cpu",
                trust_remote_code=True,
            )
            model.eval()
            current_model_id = model_id
            print(f"Model {model_id} loaded successfully!")
        except Exception as e:
            # Leave a consistent "nothing loaded" state so the next request
            # gets a clean 503 + retry instead of a half-initialized model.
            model = None
            processor = None
            current_model_id = None
            print(f"!!! ERROR loading model {model_id}: {e} !!!")
            raise


# --- Endpoints ---
@app.on_event("startup")
def startup_event():
    # Load default model on startup so the first request doesn't pay the cost.
    load_model_global(DEFAULT_MODEL)


@app.get("/")
def read_root():
    """Health/info endpoint: reports the resident model and the allow-list."""
    return {
        "message": "Qwen3.5 Multimodal API is running",
        "current_model": current_model_id,
        "available_models": ALLOWED_MODELS,
    }


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion over the resident Qwen model.

    Validates the requested model against ALLOWED_MODELS (400), swaps it in
    if necessary (503 on load failure), then runs generation either as a
    single JSON response or as an SSE stream when request.stream is true.
    """
    # 1. Validate and Load Model
    target_model = request.model or DEFAULT_MODEL
    if target_model not in ALLOWED_MODELS:
        raise HTTPException(status_code=400, detail=f"Model {target_model} not supported. Use one of {ALLOWED_MODELS}")

    try:
        # NOTE(review): model loading blocks the event loop inside this async
        # handler; acceptable for a single-user demo, consider run_in_executor.
        load_model_global(target_model)
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"Failed to load model: {str(e)}")

    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # 2. Process Inputs: flatten OpenAI-style messages into the
        # processor's chat-template format, collecting images in order.
        conversation_history = []
        images = []
        for msg in request.messages:
            role = msg.role
            if isinstance(msg.content, str):
                conversation_history.append({"role": role, "content": msg.content})
            elif isinstance(msg.content, list):
                processed_content = []
                for part in msg.content:
                    if part.type == "text" and part.text:
                        processed_content.append({"type": "text", "text": part.text})
                    elif part.type == "image_url" and part.image_url:
                        img = download_image(part.image_url.url)
                        images.append(img)
                        # Placeholder; the processor pairs it with images[i].
                        processed_content.append({"type": "image"})
                conversation_history.append({"role": role, "content": processed_content})

        # Prepare text prompt
        text_inputs = processor.apply_chat_template(
            conversation_history,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize
        model_inputs = processor(
            text=[text_inputs],
            images=images if images else None,
            return_tensors="pt",
            padding=True,
        ).to(model.device)

        # Some processor versions emit mm_token_type_ids, which
        # model.generate does not accept — drop it before generation.
        if "mm_token_type_ids" in model_inputs:
            model_inputs.pop("mm_token_type_ids")

        # Prepare Generation Args
        gen_kwargs = request.generation_params.model_dump(exclude_none=True)
        if "max_new_tokens" not in gen_kwargs:
            gen_kwargs["max_new_tokens"] = 256

        # --- Non-Streaming ---
        if not request.stream:
            with torch.no_grad():
                generated_ids = model.generate(**model_inputs, **gen_kwargs)
            # Strip the prompt tokens so only the completion is decoded.
            generated_ids = [
                output_ids[len(input_ids):]
                for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]
            response_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            return {
                "id": "cmpl-" + str(torch.randint(0, 1000000, (1,)).item()),
                "model": current_model_id,
                "choices": [{
                    "index": 0,
                    "message": {"role": "assistant", "content": response_text},
                    "finish_reason": "stop"
                }]
            }

        # --- Streaming ---
        else:
            # skip_prompt=True makes the streamer yield only newly generated
            # text, replacing the previous hand-rolled buffer that compared
            # decoded output against the prompt string (fragile: decode need
            # not round-trip the template text exactly).
            streamer = TextIteratorStreamer(
                processor.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
            )
            full_kwargs = {**model_inputs, "streamer": streamer, **gen_kwargs}
            thread = threading.Thread(target=model.generate, kwargs=full_kwargs)
            thread.start()

            async def event_generator():
                # NOTE(review): iterating the streamer blocks between tokens;
                # fine for one client, consider a thread/queue bridge at scale.
                try:
                    for new_text in streamer:
                        if new_text:
                            yield format_sse(new_text)
                    yield format_sse("", finish_reason="stop")
                except Exception as e:
                    yield format_sse(f"Error: {str(e)}", finish_reason="error")

            return StreamingResponse(event_generator(), media_type="text/event-stream")

    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")