from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import asyncio
import logging
import aiofiles
import json
from typing import List, Optional
from datetime import datetime

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

app = FastAPI()

# NOTE: allow_origins=["*"] accepts requests from any browser origin;
# restrict it to known frontend origins before deploying publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

app.mount("/static", StaticFiles(directory="static"), name="static")

# NOTE: the published checkpoint is "mistralai/Mixtral-8x7B-v0.1";
# "mistralai/Mistral-8x7B" is not a valid Hugging Face model id.
MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
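
# Hypothetical GPU variant — a sketch assuming the `accelerate` package is
# installed: device_map="auto" shards the weights across available devices.
# In float16 this ~47B-parameter model needs roughly 90 GB of memory, so the
# plain CPU load above is rarely practical.
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_NAME, torch_dtype=torch.float16, device_map="auto"
# )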

# In-memory only: history is lost on restart and is not shared across workers.
search_history = []


class InferenceRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 100


class InferenceResponse(BaseModel):
    generated_text: str
    timestamp: str


class SearchHistoryResponse(BaseModel):
    history: List[InferenceResponse]


@app.post("/inference", response_model=InferenceResponse)
async def run_inference(request: InferenceRequest):
    """Run inference using the AI model."""
    try:
        inputs = tokenizer(request.prompt, return_tensors="pt")
        # model.generate is compute-bound; run it in a worker thread so it
        # does not block the event loop. Unpacking the full encoding also
        # passes the attention mask along with input_ids.
        outputs = await asyncio.to_thread(
            model.generate, **inputs, max_length=request.max_length
        )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        search_entry = InferenceResponse(
            generated_text=generated_text,
            timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        )
        search_history.append(search_entry)

        logger.info(f"Inference completed for prompt: {request.prompt}")
        return search_entry
    except Exception as e:
        logger.error(f"Error during inference: {e}")
        raise HTTPException(status_code=500, detail="Failed to run inference.")
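
# Example request — assumes the server is running locally on port 8000:
#
#     curl -X POST http://localhost:8000/inference \
#          -H "Content-Type: application/json" \
#          -d '{"prompt": "Once upon a time", "max_length": 60}'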


@app.get("/search-history", response_model=SearchHistoryResponse)
async def get_search_history():
    """Get the history of all searches."""
    return SearchHistoryResponse(history=search_history)
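
# Example — assumes the server is running locally on port 8000:
#
#     curl http://localhost:8000/search-history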


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            request = json.loads(data)
            prompt = request.get("prompt")
            max_length = request.get("max_length", 100)

            if not prompt:
                await websocket.send_text(json.dumps({"error": "Prompt is required."}))
                continue

            inputs = tokenizer(prompt, return_tensors="pt")
            # As in /inference, run the blocking generate call in a worker
            # thread so other connections stay responsive.
            outputs = await asyncio.to_thread(
                model.generate, **inputs, max_length=max_length
            )
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

            response = {
                "generated_text": generated_text,
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }
            # Record WebSocket inferences too, so /search-history reflects them.
            search_history.append(InferenceResponse(**response))
            await websocket.send_text(json.dumps(response))
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected.")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        await websocket.send_text(json.dumps({"error": str(e)}))
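
# A minimal client sketch — assumes the third-party `websockets` package and
# a server on localhost:8000. Messages are JSON objects with a "prompt" key
# and an optional "max_length":
#
#     import asyncio, json, websockets
#
#     async def main():
#         async with websockets.connect("ws://localhost:8000/ws") as ws:
#             await ws.send(json.dumps({"prompt": "Hello", "max_length": 50}))
#             print(await ws.recv())
#
#     asyncio.run(main())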


@app.get("/")
async def serve_frontend():
    """Serve the frontend HTML file."""
    async with aiofiles.open("static/index.html", mode="r") as file:
        content = await file.read()
    # Returning the raw string would be JSON-encoded; wrap it in HTMLResponse
    # so the browser receives text/html.
    return HTMLResponse(content=content)
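
# Local entry point — a sketch assuming `uvicorn` (FastAPI's usual ASGI
# server) is installed; adjust host and port as needed.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)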