stephenebert committed on
Commit 0e857c8 · verified · 1 Parent(s): cb8b918

Update tagger.py

Files changed (1)
  1. tagger.py +72 -63
tagger.py CHANGED
@@ -1,83 +1,92 @@
  from __future__ import annotations

- """
- Caption with BLIP and derive simple tags (no POS/NLTK).
-
- - Tags are first unique non-stopword tokens from the caption.
- - Sidecar saved to ./data/<stem>.json
- """
-
  import os
- import datetime as _dt
- import json as _json
- import pathlib as _pl
- import re as _re
- from typing import List, Tuple

- import torch
  from PIL import Image
  from transformers import BlipForConditionalGeneration, BlipProcessor

- # Writable sidecar directory (writable on Spaces)
- CAP_TAG_DIR = _pl.Path(os.environ.get("CAP_TAG_DIR", "./data")).resolve()
- CAP_TAG_DIR.mkdir(parents=True, exist_ok=True)
-
- # Device + singletons
- _device = "cuda" if torch.cuda.is_available() else "cpu"
- _processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
- _model = BlipForConditionalGeneration.from_pretrained(
-     "Salesforce/blip-image-captioning-base"
- ).to(_device)

- # very small stopword set to clean tags
  _STOP = {
-     "a","an","the","and","or","but","if","then","so","to","from",
-     "of","in","on","at","by","for","with","without","into","out",
-     "is","are","was","were","be","being","been","it","its","this",
-     "that","these","those","as","over","under","near","above","below",
-     "up","down","left","right"
  }

- def _caption_to_tags(caption: str, k: int) -> List[str]:
-     tokens = _re.findall(r"[a-z0-9-]+", caption.lower())
-     out, seen = [], set()
-     for w in tokens:
-         if len(w) <= 2 or w in _STOP:
              continue
-         if w not in seen:
-             out.append(w)
-             seen.add(w)
-         if len(out) >= k:
-             break
-     return out

  def tag_pil_image(
      img: Image.Image,
      stem: str,
      *,
      top_k: int = 5,
- ) -> Tuple[str, List[str]]:
-     # sanitize stem for filesystem
-     safe_stem = _re.sub(r"[^A-Za-z0-9_.-]+", "_", stem) or "upload"
-
-     # caption
-     inputs = _processor(images=img, return_tensors="pt")
-     if _device == "cuda":
-         inputs = {k: v.to(_device) for k, v in inputs.items()}
-     with torch.inference_mode():
-         ids = _model.generate(**inputs, max_length=30)
-     caption = _processor.decode(ids[0], skip_special_tokens=True)
-
-     # tags
-     tags = _caption_to_tags(caption, top_k)
-
-     # sidecar
-     payload = {
-         "caption": caption,
-         "tags": tags,
-         "timestamp": _dt.datetime.now(_dt.timezone.utc).isoformat(),
-     }
-     (CAP_TAG_DIR / f"{safe_stem}.json").write_text(_json.dumps(payload, indent=2))
-
-     return caption, tags
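For reference, the removed version captioned and tagged in a single eagerly initialized call, returning a (caption, tags) tuple and always writing a timestamped sidecar under CAP_TAG_DIR (./data by default). A minimal sketch of a call site against that old interface, assuming the module is importable as tagger and using "photo.jpg" as a placeholder image path:

from PIL import Image

import tagger  # old module: BLIP processor and model load at import time

img = Image.open("photo.jpg").convert("RGB")
caption, tags = tagger.tag_pil_image(img, "photo", top_k=5)  # (caption, tags) tuple
print(caption)  # e.g. "a dog running on a beach" (illustrative output)
print(tags)     # e.g. ['dog', 'running', 'beach']
# A sidecar with caption, tags, and a UTC timestamp is written to ./data/photo.json.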
 
 
 
  from __future__ import annotations

+ import json
  import os
+ import re
+ from pathlib import Path
+ from typing import List, Optional

  from PIL import Image
  from transformers import BlipForConditionalGeneration, BlipProcessor

+ # -------------------- config --------------------
+ MODEL_ID = "Salesforce/blip-image-captioning-base"
+ DATA_DIR = Path(os.getenv("DATA_DIR", "/app/data"))
+ DATA_DIR.mkdir(parents=True, exist_ok=True)  # safe if already exists

+ # light, built-in stopword list (keeps us NLTK-free)
  _STOP = {
+     "a", "an", "the", "and", "or", "of", "to", "in", "on", "with", "near",
+     "at", "over", "under", "by", "from", "for", "into", "along", "through",
+     "is", "are", "be", "being", "been", "it", "its", "this", "that",
+     "as", "while", "than", "then", "there", "here",
  }

+ # -------------------- model cache --------------------
+ _processor: Optional[BlipProcessor] = None
+ _model: Optional[BlipForConditionalGeneration] = None
+
+
+ def init_models() -> None:
+     """Load BLIP once (idempotent)."""
+     global _processor, _model
+     if _processor is None or _model is None:
+         _processor = BlipProcessor.from_pretrained(MODEL_ID)
+         _model = BlipForConditionalGeneration.from_pretrained(MODEL_ID)
+
+
+ # -------------------- core functionality --------------------
+ def caption_image(img: Image.Image, max_len: int = 30) -> str:
+     """Generate a short caption for the image."""
+     assert _processor and _model, "Call init_models() first"
+     inputs = _processor(images=img, return_tensors="pt")
+     ids = _model.generate(**inputs, max_length=max_len)
+     return _processor.decode(ids[0], skip_special_tokens=True)
+
+
+ _TAG_RE = re.compile(r"[a-z0-9-]+")
+
+
+ def caption_to_tags(caption: str, top_k: int = 5) -> List[str]:
+     """
+     Convert a caption into up to K simple tags:
+       - normalize to lowercase alnum/hyphen tokens
+       - remove tiny stopword list
+       - keep order of appearance, dedup
+     """
+     tags: List[str] = []
+     seen = set()
+     for tok in _TAG_RE.findall(caption.lower()):
+         if tok in _STOP or tok in seen:
              continue
+         seen.add(tok)
+         tags.append(tok)
+         if len(tags) >= top_k:
+             break
+     return tags
+

  def tag_pil_image(
      img: Image.Image,
      stem: str,
      *,
      top_k: int = 5,
+     write_sidecar: bool = True,
+ ) -> List[str]:
+     """
+     Return ONLY the tags list.
+     (We optionally persist a sidecar JSON with caption + tags.)
+     """
+     cap = caption_image(img)
+     tags = caption_to_tags(cap, top_k=top_k)

+     if write_sidecar:
+         payload = {"caption": cap, "tags": tags}
+         sidecar = DATA_DIR / f"{stem}.json"
+         try:
+             sidecar.write_text(json.dumps(payload, indent=2))
+         except Exception:
+             # best-effort; tagging should still succeed
+             pass

+     return tags
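The new interface defers model loading to init_models() and returns only the tag list; the caption is still generated but is only persisted in the optional sidecar. A minimal usage sketch under the same assumptions (module importable as tagger, "photo.jpg" as a placeholder path); note that DATA_DIR is read and the directory created at import time, so point it at a writable location before importing when running outside the /app layout:

import os

os.environ.setdefault("DATA_DIR", "./data")  # must be set before importing tagger

from PIL import Image

import tagger

tagger.init_models()  # required once; caption_image() asserts the models are loaded
img = Image.open("photo.jpg").convert("RGB")

tags = tagger.tag_pil_image(img, "photo", top_k=5)  # also writes ./data/photo.json
tags_no_sidecar = tagger.tag_pil_image(img, "photo", top_k=5, write_sidecar=False)
print(tags)  # e.g. ['dog', 'running', 'beach'] (illustrative output)

# The tag helper is deterministic and easy to sanity-check without a model:
print(tagger.caption_to_tags("a dog running on a beach"))  # -> ['dog', 'running', 'beach']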