import contextlib
import os
import tempfile
from typing import Optional

import gradio as gr
import requests
import torch
from PIL import Image
from torchvision import transforms

from BiRefNet import BiRefNet

# 1. Download BiRefNet weights if not present
MODEL_URL = "https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/BiRefNet.pth"
MODEL_PATH = "BiRefNet.pth"
# Seconds requests waits to connect / for the next chunk of data, so a stalled
# server makes startup fail loudly instead of hanging forever.
DOWNLOAD_TIMEOUT = 60
# Bytes per chunk when streaming the checkpoint to disk (avoids holding the
# whole file in memory).
DOWNLOAD_CHUNK_SIZE = 1 << 20


def download_weights(url=MODEL_URL, path=MODEL_PATH, timeout=DOWNLOAD_TIMEOUT):
    """Download the checkpoint at *url* to *path* unless *path* already exists.

    The response is streamed into a temporary file in the same directory and
    only renamed onto *path* after the whole body has been written. An HTTP
    error or an interrupted transfer therefore never leaves a truncated or
    error-page file behind that later runs would mistake for valid weights.

    Args:
        url: Location of the checkpoint.
        path: Destination file; if it already exists nothing is downloaded.
        timeout: Per-operation network timeout in seconds, passed to requests.

    Raises:
        requests.HTTPError: if the server responds with a 4xx/5xx status.
        requests.RequestException: on connection failures or timeouts.
    """
    if os.path.exists(path):
        return
    print("Downloading BiRefNet weights...")
    target_dir = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
    try:
        with os.fdopen(fd, "wb") as f, requests.get(
            url, stream=True, timeout=timeout
        ) as response:
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
                f.write(chunk)
        # Same-directory rename: *path* is either absent or complete, never partial.
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the partial download, then propagate the original error.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise
    print("Done downloading BiRefNet weights.")


# 2. Load BiRefNet model
def load_model():
    """Ensure the weights are on disk, then build BiRefNet in eval mode.

    The state dict is loaded with ``map_location="cpu"``.
    """
    download_weights()
    model = BiRefNet()
    # SECURITY NOTE(review): torch.load unpickles the file, and this file was
    # fetched over the network. Where the installed torch supports it, prefer
    # torch.load(..., weights_only=True) so arbitrary pickled objects are refused.
    state_dict = torch.load(MODEL_PATH, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Loaded once at import time so every request reuses the same model.
bi_ref_net = load_model()

# 3. Define transforms (assuming model expects 224x224 or similar, adjust if needed)
# TODO(review): confirm the input resolution and any mean/std normalization
# this checkpoint was trained with; 224x224 with no normalization is a
# placeholder per the note above.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # Adjust to BiRefNet input size if different
    transforms.ToTensor(),
])


def remove_bg(input_image: Optional[Image.Image]) -> Image.Image:
    """Return *input_image* as RGBA with the predicted foreground mask as alpha.

    Args:
        input_image: PIL image from the Gradio input, or ``None`` when the
            user submitted without uploading anything.

    Returns:
        An RGBA image of the original size whose alpha channel is the model
        mask resized back to that size with bilinear interpolation.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    # Preprocess image
    image = input_image.convert("RGB")
    img_tensor = preprocess(image).unsqueeze(0)  # Add batch dimension

    # Inference (no gradients needed)
    with torch.no_grad():
        # [0, 0] picks the first sample's first channel -> a 2-D mask.
        # NOTE(review): assumes the model returns a single tensor, not a list
        # of side outputs — confirm against the BiRefNet module in use.
        mask = bi_ref_net(img_tensor)[0, 0]

    # NOTE(review): clamp assumes the model already outputs values in [0, 1];
    # if it outputs logits, apply torch.sigmoid first — TODO confirm.
    mask_img = transforms.ToPILImage()(mask.cpu().clamp(0, 1))
    mask_img = mask_img.resize(image.size, Image.BILINEAR)

    # Create RGBA output by setting alpha to mask
    result = image.convert("RGBA")
    result.putalpha(mask_img)
    return result


demo = gr.Interface(
    fn=remove_bg,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Image(type="pil", label="Background Removed (PNG)"),
    title="Backdrop Studio - BiRefNet Background Removal",
    description="Upload an image to remove the background using BiRefNet AI.",
)

if __name__ == "__main__":
    demo.launch()