import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import torch
import torchvision
from huggingface_hub import hf_hub_download
from hydra import utils

import vc_models
from attn_helper import VITAttentionGradRollout, overlay_attn
# Access token for the facebook/vc1-* Hub repos, read from the environment.
HF_TOKEN = os.environ['HF_ACC_TOKEN']

# Checkpoints live in a model_ckpts directory next to the installed vc_models
# package, where its loader expects to find them.
eai_filepath = vc_models.__file__.split('src')[0]
MODEL_DIR = os.path.join(os.path.dirname(eai_filepath), 'model_ckpts')
if not os.path.isdir(MODEL_DIR):
    os.mkdir(MODEL_DIR)

FILENAME = "config.yaml"

# Lazily populated caches, one (model, embedding_dim, transform, metadata)
# tuple per model size.
BASE_MODEL_TUPLE = None
LARGE_MODEL_TUPLE = None
def get_model(model_name):
    """Instantiate the requested VC-1 model from its Hub config, caching the result."""
    global BASE_MODEL_TUPLE, LARGE_MODEL_TUPLE
    download_bin(model_name)
    model = None
    if BASE_MODEL_TUPLE is None and model_name == 'vc1-base':
        repo_name = "facebook/" + model_name
        model_cfg = omegaconf.OmegaConf.load(
            hf_hub_download(repo_id=repo_name, filename=FILENAME, token=HF_TOKEN)
        )
        BASE_MODEL_TUPLE = utils.instantiate(model_cfg)
        BASE_MODEL_TUPLE[0].eval()
        model = BASE_MODEL_TUPLE
    elif LARGE_MODEL_TUPLE is None and model_name == 'vc1-large':
        repo_name = "facebook/" + model_name
        model_cfg = omegaconf.OmegaConf.load(
            hf_hub_download(repo_id=repo_name, filename=FILENAME, token=HF_TOKEN)
        )
        LARGE_MODEL_TUPLE = utils.instantiate(model_cfg)
        LARGE_MODEL_TUPLE[0].eval()
        model = LARGE_MODEL_TUPLE
    elif model_name == 'vc1-base':
        model = BASE_MODEL_TUPLE
    elif model_name == 'vc1-large':
        model = LARGE_MODEL_TUPLE
    return model
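
# Usage sketch (hedged): the Hydra-instantiated config is expected to yield a
# (model, embedding_dim, transform, metadata) tuple, matching how run_attn
# unpacks it below:
#   model, embedding_dim, transform, metadata = get_model("vc1-base")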
def download_bin(model):
    """Fetch the checkpoint for `model` from the Hub into MODEL_DIR, once."""
    if model == "vc1-large":
        bin_file = 'vc1_vitl.pth'
    elif model == "vc1-base":
        bin_file = 'vc1_vitb.pth'
    else:
        raise NameError("model not found: " + model)
    repo_name = 'facebook/' + model
    bin_path = os.path.join(MODEL_DIR, bin_file)
    if not os.path.isfile(bin_path):
        model_bin = hf_hub_download(repo_id=repo_name, filename='pytorch_model.bin',
                                    local_dir=MODEL_DIR, local_dir_use_symlinks=True,
                                    token=HF_TOKEN)
        # Rename the generic pytorch_model.bin to the filename vc_models expects.
        os.rename(model_bin, bin_path)
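
# For example, download_bin("vc1-base") leaves MODEL_DIR/vc1_vitb.pth on disk;
# subsequent calls find the file and skip the download.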
def run_attn(input_img, model="vc1-large", discard_ratio=0.89):
    """Run a VC-1 model on an image and return the last-layer attention overlay."""
    # get_model already calls download_bin, so no separate download is needed here.
    model, embedding_dim, transform, metadata = get_model(model)

    # Gradio delivers HWC uint8 arrays; convert to a CHW float batch of one.
    if input_img.shape[0] != 3:
        input_img = input_img.transpose(2, 0, 1)
    if len(input_img.shape) == 3:
        input_img = torch.tensor(input_img).unsqueeze(0)
    input_img = input_img.float()

    resize_transform = torchvision.transforms.Resize((250, 250))
    input_img = resize_transform(input_img)
    x = transform(input_img)

    # Hook the attention layers, then run a forward pass to populate them;
    # the embedding y itself is not used.
    attention_rollout = VITAttentionGradRollout(model, head_fusion="max",
                                                discard_ratio=discard_ratio)
    y = model(x)
    mask = attention_rollout.get_attn_mask()
    attn_img = overlay_attn(input_img[0].permute(1, 2, 0), mask)
    return attn_img
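
# Standalone usage sketch (hedged: assumes a local JPEG at imgs/example.jpg;
# the PIL import is illustrative and not used elsewhere in this script):
#   from PIL import Image
#   img = np.array(Image.open("imgs/example.jpg"))
#   heatmap = run_attn(img, model="vc1-base", discard_ratio=0.9)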
model_type = gr.Dropdown(
    ["vc1-base", "vc1-large"], label="Model Size", value="vc1-large")
input_img = gr.Image(shape=(250, 250))
output_img = gr.Image(shape=(250, 250))
discard_ratio = gr.Slider(0, 1, value=0.89, label="Discard Ratio")
css = "#component-2, .input-image, .image-preview {height: 240px !important}"
markdown = ("This is a demo for the Visual Cortex (VC-1) models. Given an input "
            "image, it overlays the attention of the transformer's last layer.")

demo = gr.Interface(fn=run_attn, title="Visual Cortex Model", description=markdown,
                    examples=[[os.path.join('./imgs', x), None, None]
                              for x in os.listdir(os.path.join(os.getcwd(), 'imgs'))
                              if 'jpg' in x],
                    inputs=[input_img, model_type, discard_ratio],
                    outputs=output_img, css=css)
demo.launch()