"""Gradio app that removes image backgrounds with a BriaRMBG segmentation model."""

import gradio as gr
import torch
import torchvision.transforms as T
from PIL import Image

from briarmbg import BriaRMBG

# ===== Load model =====
# Build the model from the files in the current directory, then overwrite its
# weights with the state dict stored in model.pth (loaded onto the CPU).
model = BriaRMBG.from_pretrained("./")
model.load_state_dict(torch.load("model.pth", map_location="cpu"))
model.eval()

# ===== Define preprocessing =====
# Resize to the fixed 512x512 network input and convert to a float tensor.
# NOTE(review): no mean/std normalization is applied here; confirm this matches
# how the weights in model.pth expect their input.
transform = T.Compose([
    T.Resize((512, 512)),
    T.ToTensor(),
])


def _extract_prediction(result):
    """Unwrap the raw model output into the prediction tensor.

    Accepts a dict with a "pred" key, or a tuple/list whose first element is the
    prediction; a further list layer (e.g. a list of side outputs) is unwrapped
    by taking its first element.
    """
    if isinstance(result, dict) and "pred" in result:
        result = result["pred"]
    elif isinstance(result, (tuple, list)):
        result = result[0]
    if isinstance(result, list):
        result = result[0]
    return result


# ===== Background removal function (requires no authentication) =====
def remove_background(image):
    """Return a copy of *image* with its background made transparent.

    Args:
        image: PIL image from the Gradio input, or None when no image was given.

    Returns:
        An RGBA PIL image with the same size as *image*. Pixels whose upsampled
        model score is below 0.5 become transparent white (255, 255, 255, 0);
        all other pixels keep their original RGBA value. Returns None when
        *image* is None.

    Raises:
        ValueError: if the model prediction does not squeeze to a 2-D mask.
    """
    if image is None:
        return None

    # ToTensor emits one channel per image band, so force 3-channel RGB input;
    # RGBA / greyscale / palette images would otherwise have the wrong channel count.
    batch = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        pred = _extract_prediction(model(batch))

    pred = pred.squeeze().float()
    if pred.ndim != 2:
        raise ValueError(
            f"expected a 2-D mask from the model, got shape {tuple(pred.shape)}"
        )

    # Upsample the low-resolution mask to the original image size instead of
    # shrinking (and distorting the aspect ratio of) the input to the mask size.
    pred = torch.nn.functional.interpolate(
        pred[None, None],
        size=(image.height, image.width),
        mode="bilinear",
        align_corners=False,
    )[0, 0]

    # Binary "L" mask: 255 keeps the original pixel, 0 makes it transparent.
    keep = ((pred >= 0.5).to(torch.uint8) * 255).cpu().numpy()
    mask = Image.fromarray(keep)

    # Vectorized compositing replaces a per-pixel Python loop.
    foreground = image.convert("RGBA")
    background = Image.new("RGBA", foreground.size, (255, 255, 255, 0))
    return Image.composite(foreground, background, mask)


# ===== Launch Gradio app =====
# launch() with default arguments serves the app locally; pass share=True to
# request a temporary public link.
demo = gr.Interface(
    fn=remove_background,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Background Remover",
    allow_flagging="never",
)
demo.launch()