---
license: apache-2.0
language: en
pipeline_tag: text-generation
tags:
- gemma3
- multitask
- qlora
- customer-service
- fashion
- complaint-analysis
---

![image/png](https://cdn-uploads.huggingface.co/production/uploads/66ad89f2685fc4c1c2397398/JJJCAzMpxOmvDxsKnItYv.png)

# Fine-tuned Gemma-3 4B for Multi-Task Customer Service Complaint Analysis

This repository contains a `google/gemma-3-4b-it` model that has been fine-tuned using QLoRA for a comprehensive, multi-task customer service application. The model was trained on a synthetic dataset of fashion-related customer complaints to perform both causal language modeling (generating a structured JSON response) and several classification tasks simultaneously via specialized classification heads.

This model is designed to act as an "agent" that can ingest a customer complaint and its surrounding context, then output a complete analysis covering multiple business-critical dimensions.

![image/png](https://cdn-uploads.huggingface.co/production/uploads/66ad89f2685fc4c1c2397398/N5zvIkKPO107nKtMT2SS5.png)

## Model Capabilities

![image/png](https://cdn-uploads.huggingface.co/production/uploads/66ad89f2685fc4c1c2397398/ARLiJ9Vw9x6lfqsQ4p-4Q.png)

This model is trained to perform 8 classification tasks simultaneously based on the input complaint:

1. **`is_actionable`**: Determines if the complaint requires a direct action (boolean).
2. **`complaint_category`**: Classifies the complaint into one of 11 categories (e.g., "Sizing Issue", "Damaged Item").
3. **`decision_recommendation`**: Recommends a course of action from 11 options (e.g., "Full_Refund_With_Return").
4. **`info_complete`**: Assesses if all necessary information is present to resolve the issue (boolean).
5. **`tone`**: Classifies the required tone for a formal response (e.g., "Empathetic_Standard").
6. **`refund_percentage`**: Suggests a specific refund percentage (0-100).
7. **`sentiment`**: Detects the customer's sentiment (e.g., "negative", "very_negative").
8. **`aggression`**: Detects the level of aggression in the customer's message.

![image/png](https://cdn-uploads.huggingface.co/production/uploads/66ad89f2685fc4c1c2397398/JbaH_J37D4KC5wXnFkZfA.png)

## How to Use (for Classification)

This model uses custom classification heads and requires the `GemmaComplaintResolver` wrapper class from the training notebook to be used correctly.

```python
import torch
from transformers import AutoTokenizer, AutoConfig
from peft import PeftModel
from huggingface_hub import hf_hub_download

# You must have the GemmaComplaintResolver class definition in your environment,
# defined exactly as it was in the training notebook.

# --- Configuration ---
repo_id = "ShovalBenjer/gemma-3-4b-fashion-multitask_A4000_v7"
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- 1. Load Tokenizer and Model Config ---
tokenizer = AutoTokenizer.from_pretrained(repo_id)
config = AutoConfig.from_pretrained("google/gemma-3-4b-it", trust_remote_code=True)

# Define the label structure the model was trained with
num_labels_dict = {
    "is_actionable": 2,
    "complaint_category": 11,
    "decision_recommendation": 11,
    "info_complete": 2,
    "tone": 7,
    "refund_percentage": 13,
    "sentiment": 6,
    "aggression": 5,
}

# --- 2. Instantiate the Custom Model Wrapper ---
# IMPORTANT: This assumes the GemmaComplaintResolver class is defined.
model = GemmaComplaintResolver(
    base_model_name_or_path="google/gemma-3-4b-it",
    num_labels_dict=num_labels_dict,
    model_config_for_base_loading=config,
)
# --- 3. Load the Fine-Tuned Weights ---

# a) Load the classification head weights
weights_path = hf_hub_download(repo_id=repo_id, filename="classification_heads.pth")
model.load_state_dict(torch.load(weights_path, map_location="cpu"), strict=False)

# b) Apply the LoRA adapter
model = PeftModel.from_pretrained(model, repo_id)

# --- 4. Prepare for Inference ---
# Cast to the appropriate dtype and move to the target device
compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
model.to(dtype=compute_dtype).to(device).eval()

# --- 5. Run Inference ---
customer_complaint = (
    "The t-shirt I ordered arrived with a huge hole in it! "
    "I'm very angry and want a full refund immediately."
)

# The model expects the full prompt structure used during training. In the
# training notebook, the pre-processed column was 'text_for_lm', which wrapped
# the complaint in Gemma's chat-turn format:
#   <start_of_turn>user\n{complaint_details}<end_of_turn>\n<start_of_turn>model\n{json_output}
# For inference on just the classification heads, we only need the prompt part.
input_text = f"<start_of_turn>user\n{customer_complaint}<end_of_turn>\n<start_of_turn>model\n"
inputs = tokenizer(input_text, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

# --- 6. Decode a Prediction ---
# Example: Get the predicted complaint category
category_logits = outputs['logits_complaint_category']
predicted_category_id = torch.argmax(category_logits, dim=-1).item()

complaint_categories = [
    "Sizing Issue", "Damaged Item", "Not as Described", "Shipping Problem",
    "Policy Inquiry", "Late Delivery", "Wrong Item Received", "Quality Issue",
    "Return Process Issue", "Other", "N/A",
]
predicted_category = complaint_categories[predicted_category_id]

print(f"Customer Complaint: '{customer_complaint}'")
print(f"Predicted Complaint Category: {predicted_category}")
```
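
The same decoding pattern extends to the other seven heads. As a convenience, a loop like the following sketch collects the argmax prediction for every task at once; it assumes each head's logits are exposed under a `logits_<task_name>` key, as in step 6 above. Mapping the resulting IDs back to label names requires the label orderings from the training notebook.

```python
# Sketch: collect argmax predictions from all classification heads at once.
# Assumes every head's logits are returned under the key 'logits_<task_name>'.
predicted_ids = {
    task: torch.argmax(outputs[f"logits_{task}"], dim=-1).item()
    for task in num_labels_dict
}
print(predicted_ids)
# Map each ID back to a label name using the same label ordering used during training.
```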
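
## About the `GemmaComplaintResolver` Wrapper

The `GemmaComplaintResolver` class is defined in the training notebook and is not shipped with this repository. For orientation only, below is a minimal sketch of what such a multi-head wrapper could look like; it is an assumption inferred from the usage code above, not the actual training implementation. It assumes one linear head per task on top of the hidden state of the last non-padding token, logits exposed under `logits_<task_name>` keys, and the constructor signature shown in step 2. The real class's pooling strategy, head architecture, and parameter names may differ, in which case the saved `classification_heads.pth` will not line up with this sketch.

```python
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM


class GemmaComplaintResolver(nn.Module):
    """Illustrative sketch only; the real class lives in the training notebook."""

    def __init__(self, base_model_name_or_path, num_labels_dict,
                 model_config_for_base_loading=None):
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name_or_path,
            config=model_config_for_base_loading,
        )
        # Gemma-3 configs nest the text settings under `text_config`.
        cfg = self.base_model.config
        hidden_size = getattr(cfg, "text_config", cfg).hidden_size
        # One linear classification head per task (assumed naming/architecture).
        self.classification_heads = nn.ModuleDict({
            task: nn.Linear(hidden_size, n_labels)
            for task, n_labels in num_labels_dict.items()
        })

    def forward(self, input_ids, attention_mask=None, **kwargs):
        base_out = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs,
        )
        # Pool on the hidden state of the last non-padding token of each sequence.
        last_hidden = base_out.hidden_states[-1]
        if attention_mask is not None:
            last_token_idx = attention_mask.sum(dim=1) - 1
        else:
            last_token_idx = torch.full(
                (last_hidden.size(0),), last_hidden.size(1) - 1,
                dtype=torch.long, device=last_hidden.device,
            )
        pooled = last_hidden[torch.arange(last_hidden.size(0)), last_token_idx]

        # Return the LM logits plus one logits tensor per classification task.
        outputs = {"lm_logits": base_out.logits}
        for task, head in self.classification_heads.items():
            outputs[f"logits_{task}"] = head(pooled.to(head.weight.dtype))
        return outputs
```

Because `classification_heads.pth` contains only the head weights, the `load_state_dict(..., strict=False)` call in the usage example skips all backbone parameters; any drop-in replacement for the wrapper must therefore reproduce the exact parameter names under which those heads were saved.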