Spaces:
Build error
Build error
| # IMPORTS | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from lang_sam import LangSAM | |
| import gradio as gr | |
def run_lang_sam(
    input_image,
    text_prompt,
    model,
    *,
    size=256,
    color=(255, 0, 0),
    alpha=0.5,
):
    """Segment ``text_prompt`` in ``input_image`` and return a tinted overlay.

    The image is converted to RGB and resized to a ``size`` x ``size`` square.
    It is then passed to ``model.predict`` together with the prompt. All
    returned masks are merged into one, and only the pixels inside that
    merged mask are blended towards ``color``. Pixels outside the mask keep
    their original values.

    Args:
        input_image: PIL image, as delivered by a ``gr.Image(type="pil")``
            component.
        text_prompt: Text describing what to segment.
        model: Object with ``predict(image, text_prompt)`` returning a
            4-tuple whose first element holds the masks (``LangSAM`` here).
        size: Side length in pixels of the square working/output image.
        color: RGB tint applied to the segmented area.
        alpha: Tint strength in [0, 1]. 0 leaves the image unchanged and 1
            paints the mask area solid ``color``.

    Returns:
        A ``size`` x ``size`` RGB PIL image. If the model returns no masks,
        this is the resized image without any tint.

    Raises:
        ValueError: If no image was supplied or ``alpha`` is outside [0, 1].
    """
    if input_image is None:
        raise ValueError("Please provide an input image.")
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be between 0 and 1, got {alpha!r}")

    # PIL's resize takes (width, height); the square makes the order moot.
    image = input_image.convert("RGB").resize((size, size))

    masks, _, _, _ = model.predict(image, text_prompt)
    masks = torch.as_tensor(masks).detach().cpu()
    if masks.numel() == 0:
        # Nothing matched the prompt: return the image untouched rather than
        # failing on a reduction over an empty tensor.
        return image
    if masks.ndim == 2:
        # A single (H, W) mask: add the mask axis so the union below is a no-op.
        masks = masks.unsqueeze(0)

    # Union of all masks: a pixel is selected if any mask selects it.
    # NOTE(review): assumes masks is (num_masks, size, size), i.e. at the
    # resolution of the resized image passed to predict — confirm for LangSAM.
    unified_mask = masks.to(torch.uint8).amax(dim=0).to(torch.bool).numpy()

    # Tint only the masked pixels. Blending the whole frame with a
    # zero-filled colour layer would also darken everything outside the mask.
    pixels = np.array(image, dtype=np.float32)
    tint = np.asarray(color, dtype=np.float32)
    pixels[unified_mask] = (1.0 - alpha) * pixels[unified_mask] + alpha * tint
    return Image.fromarray(np.clip(np.rint(pixels), 0, 255).astype(np.uint8))
def setup_gradio_interface(model):
    """Build the Lang SAM Gradio demo bound to ``model``.

    The UI takes an input image and a text prompt. When "Run" is clicked it
    shows the result of ``run_lang_sam(image, prompt, model)``. One example
    (``bw-image.jpeg`` with the prompt ``"road"``) is registered with
    ``cache_examples=True``, so Gradio pre-computes its output by running the
    model.

    NOTE(review): the example file must be resolvable from the app's working
    directory, otherwise example caching fails — confirm it ships with the
    Space.

    Args:
        model: Segmentation model forwarded to ``run_lang_sam``.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """

    def segment(image, prompt):
        # Single callback shared by the Run button and the cached example so
        # both always use the same model and settings.
        return run_lang_sam(image, prompt, model)

    with gr.Blocks() as block:
        gr.Markdown("<h1><center>Lang SAM</center></h1>")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                text_prompt = gr.Textbox(label="Enter what you want to segment")
                run_button = gr.Button(value="Run")
            with gr.Column():
                # run_lang_sam returns a PIL image, so declare the output as such.
                output_mask = gr.Image(type="pil", label="Segmentation Mask")

        run_button.click(
            fn=segment,
            inputs=[input_image, text_prompt],
            outputs=[output_mask],
        )
        gr.Examples(
            examples=[["bw-image.jpeg", "road"]],
            inputs=[input_image, text_prompt],
            outputs=[output_mask],
            fn=segment,
            cache_examples=True,
            label="Try this example input!",
        )
    return block
if __name__ == "__main__":
    # Script entry point: load LangSAM, build the UI around it, and serve it.
    # Launch options: no public share link, API link hidden, errors shown in
    # the UI.
    demo = setup_gradio_interface(LangSAM())
    demo.launch(
        show_error=True,
        show_api=False,
        share=False,
    )