akagtag committed on
Commit
19d9b40
·
1 Parent(s): 5756499

Fix ZeroGPU startup and local GPU inference path

Browse files
modules/m1_lipsync.py CHANGED
@@ -28,7 +28,7 @@ class LipSyncModule:
28
 
29
  def _load_model(self) -> None:
30
  ckpt_path = hf_hub_download(
31
- repo_id="AkshatAgarwal/LipFD-checkpoint",
32
  filename="ckpt.pth",
33
  cache_dir=self.cache_dir,
34
  )
 
28
 
29
  def _load_model(self) -> None:
30
  ckpt_path = hf_hub_download(
31
+ repo_id="akagtag/LipFD-checkpoint",
32
  filename="ckpt.pth",
33
  cache_dir=self.cache_dir,
34
  )
modules/m3_sstgnn.py CHANGED
@@ -11,7 +11,7 @@ class SSTGNNModule:
11
  self.load_error = ""
12
  try:
13
  ckpt_path = hf_hub_download(
14
- repo_id="AkshatAgarwal/SSTGNN-deepfake",
15
  filename="sstgnn_best.pt",
16
  cache_dir=cache_dir,
17
  )
 
11
  self.load_error = ""
12
  try:
13
  ckpt_path = hf_hub_download(
14
+ repo_id="akagtag/SSTGNN-deepfake",
15
  filename="sstgnn_best.pt",
16
  cache_dir=cache_dir,
17
  )
packages.txt CHANGED
@@ -1,3 +1,6 @@
1
  ffmpeg
2
  libsndfile1-dev
3
-
 
 
 
 
1
  ffmpeg
2
  libsndfile1-dev
3
+ libgles2
4
+ libegl1
5
+ libgl1
6
+ libglib2.0-0
src/api/main.py CHANGED
@@ -14,9 +14,26 @@ import numpy as np
14
  from dotenv import load_dotenv
15
  from fastapi import FastAPI, File, HTTPException, UploadFile
16
  from fastapi.middleware.cors import CORSMiddleware
17
- from fastapi.responses import HTMLResponse, RedirectResponse
18
  from PIL import ExifTags, Image
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  from src.continual.novelty_detector import NoveltyDetector
21
  from src.continual.registry import GeneratorRegistry
22
  from src.engines.coherence.engine import CoherenceEngine
@@ -252,8 +269,23 @@ def _model_inventory() -> dict[str, object]:
252
 
253
 
254
  @app.get("/", response_class=HTMLResponse)
255
- async def root() -> RedirectResponse:
256
- return RedirectResponse(url="/gradio", status_code=307)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
 
259
  @app.on_event("startup")
@@ -262,11 +294,6 @@ async def preload() -> None:
262
  logger.info("Skipping startup preload in test mode")
263
  return
264
 
265
- backend = get_inference_backend()
266
- if backend in {"hf", "runpod"}:
267
- logger.info("Skipping local model preload for backend=%s", backend)
268
- return
269
-
270
  logger.info("Preloading models...")
271
  # Keep model imports/loads sequential to avoid lazy-import race issues.
272
  await asyncio.to_thread(_fp._ensure)
@@ -275,6 +302,115 @@ async def preload() -> None:
275
  logger.info("Model preload complete")
276
 
277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  @app.get("/health")
279
  async def health() -> dict:
280
  return {
@@ -485,40 +621,12 @@ async def detect_image(file: UploadFile = File(...)) -> DetectionResponse:
485
  except Exception as exc:
486
  logger.warning("RunPod image route failed, falling back to local image inference: %s", exc)
487
 
488
- try:
489
- image = Image.open(io.BytesIO(data)).convert("RGB")
490
- except Exception as exc:
491
- raise HTTPException(status_code=422, detail=f"Could not decode image: {exc}") from exc
492
-
493
- await _ensure_models_loaded()
494
-
495
- fp, co, st = await asyncio.gather(
496
- asyncio.to_thread(_fp.run, image),
497
- asyncio.to_thread(_co.run, image),
498
- asyncio.to_thread(_st.run, image),
499
- )
500
-
501
- elapsed_ms = (time.monotonic() - t0) * 1000
502
- engine_results = _assign_processing_time([fp, co, st], elapsed_ms)
503
-
504
- verdict, conf, generator = fuse(engine_results, is_video=False)
505
- if _is_test_mode():
506
- explanation = _fallback_explanation(verdict, conf, generator)
507
- else:
508
- explanation = await asyncio.to_thread(explain, verdict, conf, engine_results, generator)
509
-
510
- response = DetectionResponse(
511
- verdict=verdict,
512
- confidence=conf,
513
- attributed_generator=generator,
514
- explanation=explanation,
515
- processing_time_ms=elapsed_ms,
516
- engine_breakdown=engine_results,
517
- )
518
- return _apply_metadata_keyword_signal(
519
- response,
520
- filename=file.filename,
521
- metadata_text=metadata_text,
522
  )
523
 
524
 
@@ -581,57 +689,11 @@ async def detect_video(file: UploadFile = File(...)) -> DetectionResponse:
581
  except Exception as exc:
582
  logger.warning("RunPod route failed, falling back to local video inference: %s", exc)
583
 
584
- with tempfile.NamedTemporaryFile(
585
- suffix=_video_temp_suffix(file.content_type, file.filename),
586
- delete=False,
587
- ) as tmp:
588
- tmp.write(data)
589
- tmp_path = tmp.name
590
-
591
- try:
592
- try:
593
- frames = await asyncio.to_thread(extract_video_frames, tmp_path, MAX_FRAMES)
594
- except Exception as exc:
595
- raise HTTPException(status_code=422, detail=f"Video decode failed: {exc}") from exc
596
-
597
- if not frames:
598
- raise HTTPException(status_code=422, detail="Could not extract frames")
599
-
600
- await _ensure_models_loaded()
601
- try:
602
- fp, co, st = await asyncio.gather(
603
- asyncio.to_thread(_fp.run_video, frames),
604
- asyncio.to_thread(_co.run_video, frames, tmp_path),
605
- asyncio.to_thread(_st.run_video, frames),
606
- )
607
- except Exception as exc:
608
- logger.exception("Video engine inference failed")
609
- raise HTTPException(
610
- status_code=503,
611
- detail=f"Video analysis failed: {type(exc).__name__}: {exc}",
612
- ) from exc
613
- finally:
614
- Path(tmp_path).unlink(missing_ok=True)
615
-
616
- elapsed_ms = (time.monotonic() - t0) * 1000
617
- engine_results = _assign_processing_time([fp, co, st], elapsed_ms)
618
-
619
- verdict, conf, generator = fuse(engine_results, is_video=True)
620
- if _is_test_mode():
621
- explanation = _fallback_explanation(verdict, conf, generator)
622
- else:
623
- explanation = await asyncio.to_thread(explain, verdict, conf, engine_results, generator)
624
-
625
- response = DetectionResponse(
626
- verdict=verdict,
627
- confidence=conf,
628
- attributed_generator=generator,
629
- explanation=explanation,
630
- processing_time_ms=elapsed_ms,
631
- engine_breakdown=engine_results,
632
- )
633
- return _apply_metadata_keyword_signal(
634
- response,
635
- filename=file.filename,
636
- metadata_text=metadata_text,
637
  )
 
14
  from dotenv import load_dotenv
15
  from fastapi import FastAPI, File, HTTPException, UploadFile
16
  from fastapi.middleware.cors import CORSMiddleware
17
+ from fastapi.responses import HTMLResponse
18
  from PIL import ExifTags, Image
19
 
20
+ try:
21
+ import spaces # type: ignore
22
+ except ImportError:
23
+ spaces = None
24
+
25
+
26
+ if spaces is None or not hasattr(spaces, "GPU"):
27
+ class _SpacesShim:
28
+ @staticmethod
29
+ def GPU(*args, **kwargs):
30
+ def decorator(fn):
31
+ return fn
32
+
33
+ return decorator
34
+
35
+ spaces = _SpacesShim()
36
+
37
  from src.continual.novelty_detector import NoveltyDetector
38
  from src.continual.registry import GeneratorRegistry
39
  from src.engines.coherence.engine import CoherenceEngine
 
269
 
270
 
271
  @app.get("/", response_class=HTMLResponse)
272
+ async def root() -> HTMLResponse:
273
+ return HTMLResponse(
274
+ """
275
+ <html>
276
+ <head><title>GenAI-DeepDetect</title></head>
277
+ <body style="font-family: sans-serif; max-width: 720px; margin: 48px auto; line-height: 1.5;">
278
+ <h1>GenAI-DeepDetect</h1>
279
+ <p>The FastAPI backend is running.</p>
280
+ <ul>
281
+ <li><a href="/gradio">Open Gradio UI</a></li>
282
+ <li><a href="/docs">Open API Docs</a></li>
283
+ <li><a href="/health">Health Check</a></li>
284
+ </ul>
285
+ </body>
286
+ </html>
287
+ """
288
+ )
289
 
290
 
291
  @app.on_event("startup")
 
294
  logger.info("Skipping startup preload in test mode")
295
  return
296
 
 
 
 
 
 
297
  logger.info("Preloading models...")
298
  # Keep model imports/loads sequential to avoid lazy-import race issues.
299
  await asyncio.to_thread(_fp._ensure)
 
302
  logger.info("Model preload complete")
303
 
304
 
305
+ @spaces.GPU(duration=120)
306
+ def _local_detect_image_sync(
307
+ data: bytes,
308
+ filename: str | None,
309
+ metadata_text: str,
310
+ elapsed_start: float,
311
+ ) -> DetectionResponse:
312
+ try:
313
+ image = Image.open(io.BytesIO(data)).convert("RGB")
314
+ except Exception as exc:
315
+ raise HTTPException(status_code=422, detail=f"Could not decode image: {exc}") from exc
316
+
317
+ _fp._ensure()
318
+ _co._ensure()
319
+ _st._ensure()
320
+
321
+ fp = _fp.run(image)
322
+ co = _co.run(image)
323
+ st = _st.run(image)
324
+
325
+ elapsed_ms = (time.monotonic() - elapsed_start) * 1000
326
+ engine_results = _assign_processing_time([fp, co, st], elapsed_ms)
327
+
328
+ verdict, conf, generator = fuse(engine_results, is_video=False)
329
+ if _is_test_mode():
330
+ explanation = _fallback_explanation(verdict, conf, generator)
331
+ else:
332
+ explanation = explain(verdict, conf, engine_results, generator)
333
+
334
+ response = DetectionResponse(
335
+ verdict=verdict,
336
+ confidence=conf,
337
+ attributed_generator=generator,
338
+ explanation=explanation,
339
+ processing_time_ms=elapsed_ms,
340
+ engine_breakdown=engine_results,
341
+ )
342
+ return _apply_metadata_keyword_signal(
343
+ response,
344
+ filename=filename,
345
+ metadata_text=metadata_text,
346
+ )
347
+
348
+
349
+ @spaces.GPU(duration=180)
350
+ def _local_detect_video_sync(
351
+ data: bytes,
352
+ content_type: str | None,
353
+ filename: str | None,
354
+ metadata_text: str,
355
+ elapsed_start: float,
356
+ ) -> DetectionResponse:
357
+ with tempfile.NamedTemporaryFile(
358
+ suffix=_video_temp_suffix(content_type, filename),
359
+ delete=False,
360
+ ) as tmp:
361
+ tmp.write(data)
362
+ tmp_path = tmp.name
363
+
364
+ try:
365
+ try:
366
+ frames = extract_video_frames(tmp_path, MAX_FRAMES)
367
+ except Exception as exc:
368
+ raise HTTPException(status_code=422, detail=f"Video decode failed: {exc}") from exc
369
+
370
+ if not frames:
371
+ raise HTTPException(status_code=422, detail="Could not extract frames")
372
+
373
+ _fp._ensure()
374
+ _co._ensure()
375
+ _st._ensure()
376
+
377
+ try:
378
+ fp = _fp.run_video(frames)
379
+ co = _co.run_video(frames, tmp_path)
380
+ st = _st.run_video(frames)
381
+ except Exception as exc:
382
+ logger.exception("Video engine inference failed")
383
+ raise HTTPException(
384
+ status_code=503,
385
+ detail=f"Video analysis failed: {type(exc).__name__}: {exc}",
386
+ ) from exc
387
+ finally:
388
+ Path(tmp_path).unlink(missing_ok=True)
389
+
390
+ elapsed_ms = (time.monotonic() - elapsed_start) * 1000
391
+ engine_results = _assign_processing_time([fp, co, st], elapsed_ms)
392
+
393
+ verdict, conf, generator = fuse(engine_results, is_video=True)
394
+ if _is_test_mode():
395
+ explanation = _fallback_explanation(verdict, conf, generator)
396
+ else:
397
+ explanation = explain(verdict, conf, engine_results, generator)
398
+
399
+ response = DetectionResponse(
400
+ verdict=verdict,
401
+ confidence=conf,
402
+ attributed_generator=generator,
403
+ explanation=explanation,
404
+ processing_time_ms=elapsed_ms,
405
+ engine_breakdown=engine_results,
406
+ )
407
+ return _apply_metadata_keyword_signal(
408
+ response,
409
+ filename=filename,
410
+ metadata_text=metadata_text,
411
+ )
412
+
413
+
414
  @app.get("/health")
415
  async def health() -> dict:
416
  return {
 
621
  except Exception as exc:
622
  logger.warning("RunPod image route failed, falling back to local image inference: %s", exc)
623
 
624
+ return await asyncio.to_thread(
625
+ _local_detect_image_sync,
626
+ data,
627
+ file.filename,
628
+ metadata_text,
629
+ t0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
630
  )
631
 
632
 
 
689
  except Exception as exc:
690
  logger.warning("RunPod route failed, falling back to local video inference: %s", exc)
691
 
692
+ return await asyncio.to_thread(
693
+ _local_detect_video_sync,
694
+ data,
695
+ file.content_type,
696
+ file.filename,
697
+ metadata_text,
698
+ t0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
699
  )
src/engines/coherence/engine.py CHANGED
@@ -13,6 +13,11 @@ from typing import Optional
13
  import numpy as np
14
  from PIL import Image
15
 
 
 
 
 
 
16
  from src.types import EngineResult
17
 
18
  logger = logging.getLogger(__name__)
@@ -28,6 +33,10 @@ _resnet_fallback = None # torchvision ResNet-18 used when facenet-pytorch unav
28
  _transform_fallback = None
29
 
30
 
 
 
 
 
31
  def _skip_model_loads() -> bool:
32
  return os.environ.get("GENAI_SKIP_MODEL_LOAD", "").strip().lower() in {
33
  "1",
@@ -130,7 +139,7 @@ def _load() -> None:
130
  import torch # type: ignore
131
 
132
  _torch = torch
133
- _device = "cuda" if torch.cuda.is_available() else "cpu"
134
  logger.info(" Coherence device: %s", _device)
135
 
136
  from facenet_pytorch import InceptionResnetV1, MTCNN # type: ignore
@@ -150,7 +159,7 @@ def _load() -> None:
150
  import torchvision.transforms as tv_transforms # type: ignore
151
 
152
  _torch = torch
153
- _device = "cuda" if torch.cuda.is_available() else "cpu"
154
 
155
  model = tv_models.resnet18(weights=tv_models.ResNet18_Weights.DEFAULT)
156
  model.fc = torch.nn.Identity() # strip classifier → 512-d embedding
 
13
  import numpy as np
14
  from PIL import Image
15
 
16
+ try:
17
+ import spaces # type: ignore # noqa: F401
18
+ except ImportError:
19
+ spaces = None
20
+
21
  from src.types import EngineResult
22
 
23
  logger = logging.getLogger(__name__)
 
33
  _transform_fallback = None
34
 
35
 
36
+ def _prefer_cuda(torch_module) -> bool:
37
+ return torch_module.cuda.is_available() or os.environ.get("SPACE_ID", "").startswith("akagtag/")
38
+
39
+
40
  def _skip_model_loads() -> bool:
41
  return os.environ.get("GENAI_SKIP_MODEL_LOAD", "").strip().lower() in {
42
  "1",
 
139
  import torch # type: ignore
140
 
141
  _torch = torch
142
+ _device = "cuda" if _prefer_cuda(torch) else "cpu"
143
  logger.info(" Coherence device: %s", _device)
144
 
145
  from facenet_pytorch import InceptionResnetV1, MTCNN # type: ignore
 
159
  import torchvision.transforms as tv_transforms # type: ignore
160
 
161
  _torch = torch
162
+ _device = "cuda" if _prefer_cuda(torch) else "cpu"
163
 
164
  model = tv_models.resnet18(weights=tv_models.ResNet18_Weights.DEFAULT)
165
  model.fc = torch.nn.Identity() # strip classifier → 512-d embedding
src/engines/fingerprint/engine.py CHANGED
@@ -17,13 +17,22 @@ import torch
17
  from PIL import Image
18
  from transformers import CLIPModel, CLIPProcessor
19
 
 
 
 
 
 
20
  from src.types import EngineResult
21
 
22
  logger = logging.getLogger(__name__)
23
  CACHE = os.environ.get("MODEL_CACHE_DIR", "/tmp/models")
24
 
25
- # GPU device selection — A100 / any CUDA GPU if available, else CPU
26
- _DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
27
  _PIPELINE_DEVICE = 0 if _DEVICE == "cuda" else -1 # HF pipeline convention
28
 
29
  DETECTOR_CANDIDATES = [
 
17
  from PIL import Image
18
  from transformers import CLIPModel, CLIPProcessor
19
 
20
+ try:
21
+ import spaces # type: ignore # noqa: F401
22
+ except ImportError:
23
+ spaces = None
24
+
25
  from src.types import EngineResult
26
 
27
  logger = logging.getLogger(__name__)
28
  CACHE = os.environ.get("MODEL_CACHE_DIR", "/tmp/models")
29
 
30
+ def _prefer_cuda() -> bool:
31
+ return torch.cuda.is_available() or os.environ.get("SPACE_ID", "").startswith("akagtag/")
32
+
33
+
34
+ # GPU device selection — ZeroGPU emulates CUDA outside the decorated section.
35
+ _DEVICE = "cuda" if _prefer_cuda() else "cpu"
36
  _PIPELINE_DEVICE = 0 if _DEVICE == "cuda" else -1 # HF pipeline convention
37
 
38
  DETECTOR_CANDIDATES = [
src/engines/sstgnn/engine.py CHANGED
@@ -12,13 +12,22 @@ import numpy as np
12
  import torch
13
  from PIL import Image
14
 
 
 
 
 
 
15
  from src.types import EngineResult
16
 
17
  logger = logging.getLogger(__name__)
18
  CACHE = os.environ.get("MODEL_CACHE_DIR", "/tmp/models")
19
 
 
 
 
 
20
  # GPU device selection
21
- _DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
  _PIPELINE_DEVICE = 0 if _DEVICE == "cuda" else -1 # HF pipeline convention
23
 
24
  _lock = threading.Lock()
 
12
  import torch
13
  from PIL import Image
14
 
15
+ try:
16
+ import spaces # type: ignore # noqa: F401
17
+ except ImportError:
18
+ spaces = None
19
+
20
  from src.types import EngineResult
21
 
22
  logger = logging.getLogger(__name__)
23
  CACHE = os.environ.get("MODEL_CACHE_DIR", "/tmp/models")
24
 
25
+ def _prefer_cuda() -> bool:
26
+ return torch.cuda.is_available() or os.environ.get("SPACE_ID", "").startswith("akagtag/")
27
+
28
+
29
  # GPU device selection
30
+ _DEVICE = "cuda" if _prefer_cuda() else "cpu"
31
  _PIPELINE_DEVICE = 0 if _DEVICE == "cuda" else -1 # HF pipeline convention
32
 
33
  _lock = threading.Lock()
tests/test_api.py CHANGED
@@ -51,9 +51,10 @@ def test_health_models_returns_inventory(client):
51
  # ── GET / ─────────────────────────────────────────────────────────────────────
52
 
53
  def test_root_returns_html(client):
54
- r = client.get("/", follow_redirects=False)
55
- assert r.status_code == 307
56
- assert r.headers["location"] == "/gradio"
 
57
 
58
 
59
  # ── POST /detect/image ────────────────────────────────────────────────────────
 
51
  # ── GET / ─────────────────────────────────────────────────────────────────────
52
 
53
  def test_root_returns_html(client):
54
+ r = client.get("/")
55
+ assert r.status_code == 200
56
+ assert "text/html" in r.headers["content-type"]
57
+ assert "Open Gradio UI" in r.text
58
 
59
 
60
  # ── POST /detect/image ────────────────────────────────────────────────────────
tests/test_zero_gpu_contract.py CHANGED
@@ -40,12 +40,12 @@ def test_app_mounts_gradio_onto_fastapi():
40
  assert 'uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)' in source
41
 
42
 
43
- def test_api_root_redirects_to_gradio():
44
  source = (ROOT / "src" / "api" / "main.py").read_text(encoding="utf-8")
45
  tree = ast.parse(source)
46
 
47
- assert "RedirectResponse" in source
48
- assert 'return RedirectResponse(url="/gradio", status_code=307)' in source
49
  assert any(
50
  isinstance(node, ast.AsyncFunctionDef) and node.name == "root"
51
  for node in tree.body
 
40
  assert 'uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)' in source
41
 
42
 
43
+ def test_api_root_serves_html_landing_page():
44
  source = (ROOT / "src" / "api" / "main.py").read_text(encoding="utf-8")
45
  tree = ast.parse(source)
46
 
47
+ assert "HTMLResponse" in source
48
+ assert 'Open Gradio UI' in source
49
  assert any(
50
  isinstance(node, ast.AsyncFunctionDef) and node.name == "root"
51
  for node in tree.body