Update tagger.py
tagger.py CHANGED
@@ -1,83 +1,92 @@
 from __future__ import annotations
 
-"""
-Caption with BLIP and derive simple tags (no POS/NLTK).
-
-- Tags are first unique non-stopword tokens from the caption.
-- Sidecar saved to ./data/<stem>.json
-"""
-
+import json
 import os
-import json as _json
-import datetime as _dt
-import re as _re
-from typing import List, Tuple
+import re
+from pathlib import Path
+from typing import List, Optional
 
-import torch
 from PIL import Image
 from transformers import BlipForConditionalGeneration, BlipProcessor
 
-# Device + singletons
-_device = "cuda" if torch.cuda.is_available() else "cpu"
-_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
-_model = BlipForConditionalGeneration.from_pretrained(
-    "Salesforce/blip-image-captioning-base"
-).to(_device)
+# -------------------- config --------------------
+MODEL_ID = "Salesforce/blip-image-captioning-base"
+DATA_DIR = Path(os.getenv("DATA_DIR", "/app/data"))
+DATA_DIR.mkdir(parents=True, exist_ok=True)  # safe if already exists
 
+# light, built-in stopword list (keeps us NLTK-free)
 _STOP = {
-    "a","an","the","and","or",
-    "is","are",
-    "up","down","left","right"
+    "a", "an", "the", "and", "or", "of", "to", "in", "on", "with", "near",
+    "at", "over", "under", "by", "from", "for", "into", "along", "through",
+    "is", "are", "be", "being", "been", "it", "its", "this", "that",
+    "as", "while", "than", "then", "there", "here",
 }
 
+# -------------------- model cache --------------------
+_processor: Optional[BlipProcessor] = None
+_model: Optional[BlipForConditionalGeneration] = None
+
+
+def init_models() -> None:
+    """Load BLIP once (idempotent)."""
+    global _processor, _model
+    if _processor is None or _model is None:
+        _processor = BlipProcessor.from_pretrained(MODEL_ID)
+        _model = BlipForConditionalGeneration.from_pretrained(MODEL_ID)
+
+
+# -------------------- core functionality --------------------
+def caption_image(img: Image.Image, max_len: int = 30) -> str:
+    """Generate a short caption for the image."""
+    assert _processor and _model, "Call init_models() first"
+    inputs = _processor(images=img, return_tensors="pt")
+    ids = _model.generate(**inputs, max_length=max_len)
+    return _processor.decode(ids[0], skip_special_tokens=True)
+
+
+_TAG_RE = re.compile(r"[a-z0-9-]+")
+
+
+def caption_to_tags(caption: str, top_k: int = 5) -> List[str]:
+    """
+    Convert a caption into up to K simple tags:
+    - normalize to lowercase alnum/hyphen tokens
+    - remove tiny stopword list
+    - keep order of appearance, dedup
+    """
+    tags: List[str] = []
+    seen = set()
+    for tok in _TAG_RE.findall(caption.lower()):
+        if tok in _STOP or tok in seen:
             continue
+        seen.add(tok)
+        tags.append(tok)
+        if len(tags) >= top_k:
+            break
+    return tags
+
 
 def tag_pil_image(
     img: Image.Image,
     stem: str,
     *,
     top_k: int = 5,
-) -> Tuple[str, List[str]]:
-    with torch.inference_mode():
-        ids = _model.generate(**inputs, max_length=30)
-    caption = _processor.decode(ids[0], skip_special_tokens=True)
-
-    # tags
-    tags = _caption_to_tags(caption, top_k)
-
-    # sidecar
-    payload = {
-        "caption": caption,
-        "tags": tags,
-        "timestamp": _dt.datetime.now(_dt.timezone.utc).isoformat(),
-    }
-    (CAP_TAG_DIR / f"{safe_stem}.json").write_text(_json.dumps(payload, indent=2))
+    write_sidecar: bool = True,
+) -> List[str]:
+    """
+    Return ONLY the tags list.
+    (We optionally persist a sidecar JSON with caption + tags.)
+    """
+    cap = caption_image(img)
+    tags = caption_to_tags(cap, top_k=top_k)
 
+    if write_sidecar:
+        payload = {"caption": cap, "tags": tags}
+        sidecar = DATA_DIR / f"{stem}.json"
+        try:
+            sidecar.write_text(json.dumps(payload, indent=2))
+        except Exception:
+            # best-effort; tagging should still succeed
+            pass
 
+    return tags
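
For reviewers, a minimal usage sketch of the new API (the image path is a placeholder; assumes DATA_DIR, which defaults to /app/data and can be overridden via the DATA_DIR environment variable, points somewhere writable):

    from PIL import Image

    import tagger

    tagger.init_models()                          # loads BLIP once; safe to call repeatedly
    img = Image.open("photo.jpg").convert("RGB")  # "photo.jpg" is a placeholder path
    tags = tagger.tag_pil_image(img, "photo")     # also writes <DATA_DIR>/photo.json
    print(tags)

    # caption_to_tags is deterministic for a given caption, e.g.:
    # tagger.caption_to_tags("a dog running in the grass") == ["dog", "running", "grass"]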