| import gradio as gr |
| import numpy as np |
| import torch |
| from PIL import Image |
|
|
| from infer_model import CLIPpyModel |
| from utils import get_similarity, get_transform, ade_palette, get_cmap_image |
|
|
# Released CLIPpy checkpoint; load_url downloads it on first use and reuses
# the locally cached copy afterwards.
pretrained_ckpt = "https://github.com/kahnchana/clippy/releases/download/v1.0/clippy_5k.pt"
# Everything below runs on CPU (the model is never moved to another device and
# results are converted with ``.numpy()``), so deserialize onto CPU explicitly.
# Without map_location, a checkpoint holding CUDA tensors cannot be loaded on
# a host that has no GPU.
ckpt = torch.utils.model_zoo.load_url(pretrained_ckpt, map_location="cpu")

clippy = CLIPpyModel()
# Preprocessing applied to every input image before it is encoded.
transform = get_transform((224, 224))

# strict=False tolerates keys that do not line up between checkpoint and
# model; the return value is kept in ``msg`` so it can be inspected.
msg = clippy.load_state_dict(ckpt, strict=False)
# The model is only used for inference: switch layers such as dropout and
# batch norm to their evaluation behaviour.
clippy.eval()

# Color table indexed by class position; process_image reads one entry per label.
palette = ade_palette()
|
|
|
|
def process_image(img, captions):
    """Segment an image into regions matching comma-separated text labels.

    Args:
        img: Input image as an array accepted by ``Image.fromarray`` (this is
            what the Gradio image input passes in).
        captions: Comma-separated class labels, e.g. "man, woman, background".
            Whitespace around each label is stripped and empty entries
            (blank input, doubled or trailing commas) are ignored.

    Returns:
        A tuple ``(cmap, rgb_seg, joint)``:
        the legend produced by ``get_cmap_image`` for the label/color pairs,
        the color-coded segmentation as a uint8 ``(H, W, 3)`` array, and
        a PIL image blending the input and the segmentation 50/50.

    Raises:
        ValueError: If no image is given or ``captions`` has no usable label.
    """
    # Validate first so bad input fails with a clear message instead of an
    # opaque error from deep inside PIL or the model.
    if img is None:
        raise ValueError("No image provided: please upload an image.")

    stripped = (part.strip() for part in (captions or "").split(","))
    sample_text = [label for label in stripped if label]
    if not sample_text:
        raise ValueError(
            "No labels provided: enter comma separated labels, e.g. 'man, woman, background'."
        )
    sample_prompts = [f"a photo of a {x}" for x in sample_text]

    # Image.blend needs both images in the same mode; the segmentation is RGB.
    image = Image.fromarray(img).convert("RGB")

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        image_vector = clippy.encode_image(transform(image).unsqueeze(0), get_pos_tokens=True)
        text_vector = clippy.text.encode(sample_prompts, convert_to_tensor=True)
        # With do_argmax=True the result is used as a per-pixel index into
        # sample_text (see the mask comparison below).
        similarity = get_similarity(image_vector, text_vector, (224, 224), do_argmax=True)[0, 0].numpy()

    # Wrap around the palette so more labels than colors reuses colors
    # instead of raising IndexError.
    colors = [palette[idx % len(palette)] for idx in range(len(sample_text))]

    rgb_seg = np.zeros((similarity.shape[0], similarity.shape[1], 3), dtype=np.uint8)
    for idx, color in enumerate(colors):
        rgb_seg[similarity == idx] = color

    seg_image = Image.fromarray(rgb_seg)
    # Image.blend also needs equal sizes; only resize when the caller passed
    # an image that is not already at the segmentation resolution.
    if image.size != seg_image.size:
        image = image.resize(seg_image.size)
    joint = Image.blend(image, seg_image, 0.5)

    cmap = get_cmap_image({label: tuple(color) for label, color in zip(sample_text, colors)})

    return cmap, rgb_seg, joint
|
|
|
|
# Static text shown by the Gradio interface: page title, the description
# rendered above the inputs, and the HTML footer linking to the paper and code.
title = 'CLIPpy'

# Fixes two textual errors in the user-facing copy: the project name was
# spelled "CLIPPy" (it is "CLIPpy" everywhere else) and "class label" lacked
# its plural.
description = """
Gradio Demo for CLIPpy: Perceptual Grouping in Contrastive Vision Language Models. \n \n
Upload an image and type in a set of comma separated labels (e.g.: "man, woman, background").
CLIPpy will segment the image, according to the set of class labels you provide.
"""

article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2210.09996' target='_blank'>
Perceptual Grouping in Contrastive Vision Language Models
</a>
|
<a href='https://github.com/kahnchana/clippy' target='_blank'>Github Repository</a></p>
"""
|
|
# All image components share the 224x224 shape that process_image works at.
_image_shape = (224, 224)
# Display heights for the three outputs, in the order process_image returns
# them: legend, segmentation map, blended overlay.
_output_heights = (150, 260, 260)

demo = gr.Interface(
    fn=process_image,
    inputs=[gr.Image(shape=_image_shape), "text"],
    outputs=[
        gr.Image(shape=_image_shape).style(height=height)
        for height in _output_heights
    ],
    title=title,
    description=description,
    article=article,
)

# Start the web server as soon as the script is executed.
demo.launch()
|
|