Spaces:
Runtime error
Runtime error
| import os | |
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| import requests | |
| from BiRefNet import BiRefNet | |
# 1. Download BiRefNet weights if not present
MODEL_URL = "https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/BiRefNet.pth"
MODEL_PATH = "BiRefNet.pth"
# Seconds to wait for the server (connect / between bytes) before giving up.
DOWNLOAD_TIMEOUT = 60


def download_weights():
    """Download the BiRefNet checkpoint to MODEL_PATH unless it already exists.

    The response is streamed into a temporary ``MODEL_PATH + ".part"`` file.
    That file is renamed to MODEL_PATH only after the whole body has been
    written. An HTTP error or an interrupted transfer therefore never leaves
    a bogus or truncated checkpoint behind. Such a file would otherwise be
    treated as "already downloaded" on every later start and make
    ``torch.load`` fail.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection errors or timeouts.
        OSError: if the file cannot be written or renamed.
    """
    if os.path.exists(MODEL_PATH):
        return
    print("Downloading BiRefNet weights...")
    tmp_path = MODEL_PATH + ".part"
    try:
        with requests.get(MODEL_URL, stream=True, timeout=DOWNLOAD_TIMEOUT) as r:
            # Without this check, an error page (e.g. a 404 body) would be
            # saved to disk as if it were the checkpoint.
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                # Stream in 1 MiB chunks instead of holding the whole
                # checkpoint in memory.
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        # Atomic on the same filesystem: readers see either no file or the
        # complete one.
        os.replace(tmp_path, MODEL_PATH)
    finally:
        # After a successful os.replace the temp file is gone; otherwise
        # discard the partial download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Done downloading BiRefNet weights.")
# 2. Load BiRefNet model
def load_model():
    """Build a BiRefNet model, load the checkpoint at MODEL_PATH and return it in eval mode.

    The checkpoint is fetched first via ``download_weights()`` if it is missing.
    Tensors are mapped to the CPU (``map_location="cpu"``), so the returned
    model lives on the CPU.

    Returns:
        The ``BiRefNet`` instance with weights loaded, switched to ``eval()`` mode.
    """
    download_weights()
    model = BiRefNet()
    # The checkpoint comes from the network: weights_only=True restricts
    # unpickling to tensors and plain containers, so a tampered file cannot
    # execute code on load.
    state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
    # Checkpoints saved from nn.DataParallel or torch.compile carry
    # "module." / "_orig_mod." key prefixes, which make the strict
    # load_state_dict below fail. Strip them; other keys are left untouched.
    unwanted_prefixes = ("module.", "_orig_mod.")
    cleaned = {}
    for key, value in state_dict.items():
        for prefix in unwanted_prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
        cleaned[key] = value
    model.load_state_dict(cleaned)
    model.eval()
    return model
# Loaded once at import time so every request reuses the same model instance.
bi_ref_net = load_model()

# 3. Define transforms.
# The reference BiRefNet usage (ZhengPeng7/BiRefNet model card) resizes
# inputs to 1024x1024 and applies ImageNet mean/std normalization. Feeding
# small, un-normalized tensors noticeably degrades the predicted masks.
INPUT_SIZE = (1024, 1024)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
preprocess = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),  # HWC uint8 PIL -> CHW float in [0, 1]
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
def remove_bg(input_image):
    """Remove the background of ``input_image`` with BiRefNet.

    Args:
        input_image: PIL image supplied by the Gradio input component. It is
            ``None`` when the user submits without uploading an image.

    Returns:
        An RGBA PIL image the same size as the input, whose alpha channel is
        the predicted foreground mask. Returns ``None`` if no image was given.
    """
    if input_image is None:
        return None

    image = input_image.convert("RGB")
    img_tensor = preprocess(image).unsqueeze(0)  # add batch dimension

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        output = bi_ref_net(img_tensor)

    # Reference BiRefNet returns a list of side outputs ordered coarse to
    # fine; the last entry is the final prediction. Indexing that list with
    # [0, 0] would raise TypeError. Also accept a bare tensor.
    if isinstance(output, (list, tuple)):
        output = output[-1]

    # Predictions are logits. Sigmoid turns them into [0, 1] probabilities
    # (clamping raw logits would give an almost binary, jagged mask).
    mask = output[0, 0].sigmoid().cpu()  # first image, first channel -> [H, W]

    # Back to a PIL mask at the original resolution.
    mask_img = transforms.ToPILImage()(mask)
    mask_img = mask_img.resize(image.size, Image.BILINEAR)

    # Use the mask as the alpha channel of an RGBA copy of the input.
    result = image.convert("RGBA")
    result.putalpha(mask_img)
    return result
def _build_demo():
    """Assemble the Gradio interface: one PIL image in, one PIL image out, served by remove_bg."""
    source_image = gr.Image(type="pil", label="Input Image")
    cutout_image = gr.Image(type="pil", label="Background Removed (PNG)")
    return gr.Interface(
        fn=remove_bg,
        inputs=source_image,
        outputs=cutout_image,
        title="Backdrop Studio - BiRefNet Background Removal",
        description="Upload an image to remove the background using BiRefNet AI.",
    )


demo = _build_demo()
def main():
    """Start the Gradio server for the background-removal demo."""
    demo.launch()


if __name__ == "__main__":
    main()