Text Classification
Safetensors
GLiClass
text classification
nli
sentiment analysis

Load and save all GLiClass models with recent transformers versions

#2
by alexneakameni - opened

In recent versions of transformers (5.x), loading these models silently ignores all checkpoint weights — no error is raised.

In transformers 5.x, `from_pretrained` silently fails to apply checkpoint weights for unregistered model types, so we load them manually via `load_state_dict`.

While evaluating GLiClass, I used the following script to load these models:

from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from gliclass.model import GLiClassModelConfig
from transformers import AutoTokenizer
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

# Hugging Face Hub repo ids of the GLiClass v3.0 checkpoints to load,
# ordered base -> large -> modern-base (MODELS[0] is used for the sanity check).
MODELS = [
    f"knowledgator/gliclass-{variant}-v3.0"
    for variant in ("base", "large", "modern-base")
]

def load_gliclass_pipeline(model_name, device=None):
    """Load a GLiClass zero-shot pipeline, applying checkpoint weights manually.

    transformers 5.x ``from_pretrained`` silently fails to apply checkpoint
    weights for unregistered model types. The model is therefore built from
    its config (random init), and the ``model.safetensors`` checkpoint is
    loaded with ``load_state_dict(strict=True)``. That call raises on any
    missing or unexpected key instead of failing silently.

    Args:
        model_name: Hugging Face Hub repo id, e.g.
            ``"knowledgator/gliclass-base-v3.0"``.
        device: Device handed to ``ZeroShotClassificationPipeline``.
            When ``None``, a module-level ``DEVICE`` is used if one is
            defined. Otherwise it is ``"cuda:0"`` when CUDA is available,
            else ``"cpu"``.

    Returns:
        A multi-label ``ZeroShotClassificationPipeline`` with the progress
        bar disabled.
    """
    # Resolve the device at call time. A default of ``DEVICE`` in the
    # signature raises NameError at definition time when no such global exists.
    if device is None:
        device = globals().get("DEVICE")
    if device is None:
        import torch
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    config = GLiClassModelConfig.from_pretrained(model_name)
    config.pad_token_id = tokenizer.pad_token_id

    # Wrap tie_weights so any keyword arguments passed to it are dropped.
    # NOTE(review): presumably newer transformers calls tie_weights with
    # kwargs that GLiClass's implementation does not accept; confirm.
    # Patch the class only once; re-wrapping on every call would stack one
    # wrapper per loaded model.
    if not getattr(GLiClassModel.tie_weights, "_drops_kwargs", False):
        _orig_tie_weights = GLiClassModel.tie_weights

        def _tie_weights(self, **kwargs):
            return _orig_tie_weights(self)

        _tie_weights._drops_kwargs = True
        GLiClassModel.tie_weights = _tie_weights

    # Create the model with random weights, then load the checkpoint manually.
    model = GLiClassModel(config)
    ckpt_path = hf_hub_download(model_name, "model.safetensors")
    model.load_state_dict(load_file(ckpt_path), strict=True)
    # from_pretrained would switch the model to eval mode. A directly
    # constructed nn.Module starts in training mode, so any dropout layers
    # would stay active during inference.
    model.eval()

    pipe = ZeroShotClassificationPipeline(
        model, tokenizer,
        classification_type="multi-label",
        device=device, progress_bar=False
    )
    print(f"  ✓ {model_name}")
    return pipe

# Load every checkpoint once, keyed by its Hub repo id (loading order follows MODELS).
pipelines = {name: load_gliclass_pipeline(name) for name in MODELS}

# Sanity check: score one sentence against three sentiment labels with the
# first model. threshold=0.0 presumably keeps every label in the output.
base_pipeline = pipelines[MODELS[0]]
sentiment_labels = ["positive", "negative", "neutral"]
predictions = base_pipeline("I absolutely loved this film!", sentiment_labels, threshold=0.0)
out = predictions[0]  # results for the single input text
print("\nSanity check (base):")
for r in out:
    print(f"  {r['label']:<12} {r['score']:.3f}")

Sign up or log in to comment