# Last commit 5020db1 (verified) by Mr7Explorer: "Update app.py"
import os
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import requests
from BiRefNet import BiRefNet
# 1. Download BiRefNet weights if not present
MODEL_URL = "https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/BiRefNet.pth"
MODEL_PATH = "BiRefNet.pth"


def download_weights(timeout=60):
    """Download the BiRefNet checkpoint from MODEL_URL to MODEL_PATH if it is missing.

    The body is streamed into a temporary ``.part`` file next to MODEL_PATH and
    only renamed into place once the whole download succeeded. Because the
    existence check below skips downloading whenever MODEL_PATH exists, writing
    directly to MODEL_PATH would let an interrupted download or an HTTP error
    page become a permanently "cached" corrupt checkpoint.

    Args:
        timeout: Seconds to wait on the connection / between received bytes
            before giving up.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
    """
    if os.path.exists(MODEL_PATH):
        return

    print("Downloading BiRefNet weights...")
    tmp_path = MODEL_PATH + ".part"
    try:
        with requests.get(MODEL_URL, stream=True, timeout=timeout) as response:
            # Without this, a 404/500 HTML body would be saved as the weights.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        # Atomic on the same filesystem: MODEL_PATH is either absent or complete.
        os.replace(tmp_path, MODEL_PATH)
    finally:
        # Only present if something failed before os.replace moved it.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Done downloading BiRefNet weights.")
# 2. Load BiRefNet model
def load_model():
    """Return a BiRefNet network populated with the checkpoint at MODEL_PATH.

    download_weights() is called first so the checkpoint file exists; the
    state dict is deserialized onto the CPU and the network is put into eval
    mode before it is handed back.
    """
    download_weights()
    net = BiRefNet()
    # NOTE(review): torch.load unpickles the checkpoint, and this file was
    # fetched over the network. On torch versions that support it, consider
    # passing weights_only=True so arbitrary pickled objects are refused.
    weights = torch.load(MODEL_PATH, map_location="cpu")
    net.load_state_dict(weights)
    net.eval()
    return net


# Built once at import time; remove_bg reuses this single instance per call.
bi_ref_net = load_model()
# 3. Define transforms (assuming model expects 224x224 or similar, adjust if needed)
# NOTE(review): the 224x224 size is a placeholder per the comment above, and no
# mean/std normalization is applied — confirm both against how the checkpoint
# was trained.
preprocess = transforms.Compose(
    [
        # Adjust to BiRefNet input size if different
        transforms.Resize((224, 224)),
        # PIL image (H, W, C) with values 0-255 -> float tensor (C, H, W) in [0, 1]
        transforms.ToTensor(),
    ]
)
def remove_bg(input_image):
    """Return a copy of *input_image* whose alpha channel is the predicted mask.

    Args:
        input_image: PIL image from the Gradio input component, or None when
            the user submits without uploading an image.

    Returns:
        An RGBA PIL image with the same size as the input, or None when no
        image was provided.
    """
    # Gradio hands the function None for an empty image input; without this
    # guard the .convert() call below raises AttributeError.
    if input_image is None:
        return None

    image = input_image.convert("RGB")
    img_tensor = preprocess(image).unsqueeze(0)  # add batch dimension

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # NOTE(review): assumes the model returns a (batch, channels, H, W)
        # tensor of foreground probabilities in [0, 1] — TODO confirm; raw
        # logits would need a sigmoid instead of the clamp below.
        mask = bi_ref_net(img_tensor)[0, 0]

    # Clamp into [0, 1] so ToPILImage produces an 8-bit grayscale mask, then
    # scale it from the model's input resolution back to the original size.
    mask_img = transforms.ToPILImage()(mask.cpu().clamp(0, 1))
    mask_img = mask_img.resize(image.size, Image.BILINEAR)

    # Use the mask as the alpha channel of the original RGB pixels.
    result = image.convert("RGBA")
    result.putalpha(mask_img)
    return result
# Gradio UI: a single PIL image in, the RGBA cut-out (as a PIL image) out.
demo = gr.Interface(
    title="Backdrop Studio - BiRefNet Background Removal",
    description="Upload an image to remove the background using BiRefNet AI.",
    fn=remove_bg,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Image(type="pil", label="Background Removed (PNG)"),
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()