# Copyright (c) OpenMMLab. All rights reserved.
import random
import re
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from mmengine.logging import MMLogger
from mmengine.model import BaseModel

from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample


@MODELS.register_module()
class MiniGPT4(BaseModel):
    """The multi-modality model of MiniGPT-4.

    The implementation of `MiniGPT-4 <https://arxiv.org/abs/2304.10592>`_.
    Modified from
    https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/models/mini_gpt4.py

    Args:
        vision_encoder (dict): The config for the vision encoder.
        q_former_model (dict): The config for the Q-Former.
        lang_encoder (dict): The config for the language model.
        tokenizer (dict): The config for the tokenizer.
        task (str): To define the task, which controls the processing of the
            generated text. Defaults to 'caption'.
        freeze_vit (bool): Freeze the training of the ViT. Defaults to True.
        freeze_q_former (bool): Freeze the training of the Q-Former.
            Defaults to True.
        num_query_token (int): Number of query tokens of the Q-Former.
            Defaults to 32.
        prompt_template (dict): Multi-language prompt templates of the model.
            Defaults to ``dict([('en', '###Ask: {} ###Answer: '),
            ('zh', '###问:{} ###答:')])``.
        raw_prompts (dict): Prompts for training. Defaults to dict().
        max_txt_len (int): Max token length while doing tokenization.
            Defaults to 32.
        end_sym (str): End symbol of the sequence. Defaults to '###'.
        generation_cfg (dict): The config of text generation. Defaults to
            dict().
        data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
            pre-processing data sampled by the dataloader to the format
            accepted by :meth:`forward`. Defaults to None.
        init_cfg (dict): Initialization config dict. Defaults to None.
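
    Examples:
        A minimal, illustrative inference sketch. Here ``cfg`` is assumed to
        be a complete MiniGPT-4 model config (for example, one of the
        MiniGPT-4 configs shipped with mmpretrain), and the input size is
        only a placeholder that has to match the vision encoder in that
        config.

        >>> import torch
        >>> from mmpretrain.registry import MODELS
        >>> from mmpretrain.structures import DataSample
        >>> model = MODELS.build(cfg)  # cfg: a full MiniGPT-4 config dict
        >>> images = torch.rand(1, 3, 224, 224)
        >>> data_samples = [DataSample(lang='en')]
        >>> results = model(images, data_samples, mode='predict')
        >>> results[0].pred_caption  # the generated caption text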
| """ # noqa | |

    def __init__(self,
                 vision_encoder: dict,
                 q_former_model: dict,
                 lang_encoder: dict,
                 tokenizer: dict,
                 task: str = 'caption',
                 freeze_vit: bool = True,
                 freeze_q_former: bool = True,
                 num_query_token: int = 32,
                 prompt_template: dict = dict([('en',
                                                '###Ask: {} ###Answer: '),
                                               ('zh', '###问:{} ###答:')]),
                 raw_prompts: dict = dict(),
                 max_txt_len: int = 32,
                 end_sym: str = '###',
                 generation_cfg: dict = dict(),
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[dict] = None):
        if data_preprocessor is None:
            data_preprocessor = {}
        data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
        data_preprocessor = MODELS.build(data_preprocessor)

        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.task = task

        logger = MMLogger.get_current_instance()

        # build vision model
        vision_encoder_weight = vision_encoder.pop('pretrained', None)
        self.vision_encoder = MODELS.build(vision_encoder)
        self.ln_vision = nn.LayerNorm(self.vision_encoder.embed_dims)

        if vision_encoder_weight is not None:
            from mmengine.runner.checkpoint import load_checkpoint
            load_checkpoint(self.vision_encoder, vision_encoder_weight)
            self.vision_encoder.is_init = True

        if freeze_vit:
            for name, param in self.ln_vision.named_parameters():
                param.requires_grad = False
            self.ln_vision = self.ln_vision.eval()
        else:
            logger.warning('Please check `frozen_stages` in the dict of '
                           '`vision_encoder`. Also set it to -1 if you do '
                           'not want to freeze the ViT.')

        # build Qformer
        q_former_model_weight = q_former_model.pop('pretrained', None)
        self.q_former = MODELS.build(q_former_model)
        self.q_former.cls = None
        self.q_former.bert.embeddings.word_embeddings = None
        self.q_former.bert.embeddings.position_embeddings = None
        for layer in self.q_former.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        self.query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, self.q_former.config.hidden_size))
        self.query_tokens.data.normal_(
            mean=0.0, std=self.q_former.config.initializer_range)

        if q_former_model_weight is not None:
            from mmengine.runner.checkpoint import CheckpointLoader
            state_dict = CheckpointLoader.load_checkpoint(
                q_former_model_weight)['state_dict']
            self.load_state_dict(state_dict, strict=False)
            # The ln_vision weights are also in the q-former checkpoint.
            setattr(self.ln_vision, 'is_init', True)
            setattr(self.q_former, 'is_init', True)

        if freeze_q_former:
            for name, param in self.q_former.named_parameters():
                param.requires_grad = False
            self.q_former.eval()
            self.query_tokens.requires_grad = False

        # build language model
        self.llama_tokenizer = TOKENIZER.build(tokenizer)
        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
        self.llama_model = MODELS.build(lang_encoder)
        for name, param in self.llama_model.named_parameters():
            param.requires_grad = False

        # build linear projection layer
        self.llama_proj = nn.Linear(self.q_former.config.hidden_size,
                                    self.llama_model.config.hidden_size)
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym
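        # The last token id of ``end_sym`` is used as the eos token id during
        # generation, so decoding stops once the end symbol is produced.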
        self.end_token_id = self.llama_tokenizer.encode(end_sym)[-1]

        # set prompts
        self.en_prompt_list, self.zh_prompt_list = [], []
        if raw_prompts.get('en') is not None:
            en_filted_prompts = [
                raw_prompt for raw_prompt in raw_prompts['en']
                if '<ImageHere>' in raw_prompt
            ]
            self.en_prompt_list = [
                prompt_template['en'].format(p) for p in en_filted_prompts
            ]
        if raw_prompts.get('zh') is not None:
            zh_filted_prompts = [
                raw_prompt for raw_prompt in raw_prompts['zh']
                if '<ImageHere>' in raw_prompt
            ]
            self.zh_prompt_list = [
                prompt_template['zh'].format(p) for p in zh_filted_prompts
            ]

        # update generation configs
        self.generation_cfg = dict(
            max_new_tokens=300,
            num_beams=1,
            do_sample=True,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.1,
            length_penalty=1.0,
            temperature=1.0)
        self.generation_cfg.update(**generation_cfg)

        if hasattr(self, 'register_load_state_dict_post_hook'):
            self.register_load_state_dict_post_hook(
                self._load_llama_proj_hook)

    def half(self):
        self.llama_model = self.llama_model.half()
        return self

    def encode_img(self,
                   images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """The function to encode the images."""
        device = images.device
        x = self.vision_encoder(images)[0]
        image_embeds = self.ln_vision(x).to(device)
        image_atts = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long).to(device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.q_former.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
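        # Project the Q-Former query outputs into the LLaMA embedding space
        # so they can be concatenated with the text token embeddings later.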
        inputs_llama = self.llama_proj(query_output.last_hidden_state)
        atts_llama = torch.ones(
            inputs_llama.size()[:-1], dtype=torch.long).to(images.device)
        return inputs_llama, atts_llama

    def prompt_wrap(self, img_embeds: torch.Tensor, atts_img: torch.Tensor,
                    prompt: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """The function to wrap the image and prompt.

        Make sure that len(prompt) == img_embeds.shape[0].

        Args:
            img_embeds (torch.Tensor): The embedding of the input images.
            atts_img (torch.Tensor): Attention map of the image embeddings.
            prompt (List[str]): The prompt of the batch data.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The embedding and attention
            map.
        """
        if len(prompt) > 0:
            p_before_list, p_after_list = [], []
            for pro in prompt:
                p_before, p_after = pro.split('<ImageHere>')
                p_before_list.append(p_before)
                p_after_list.append(p_after)
            p_before_tokens = self.llama_tokenizer(
                p_before_list,
                return_tensors='pt',
                padding='longest',
                add_special_tokens=False).to(img_embeds.device)
            p_after_tokens = self.llama_tokenizer(
                p_after_list,
                return_tensors='pt',
                padding='longest',
                add_special_tokens=False).to(img_embeds.device)
            p_before_embeds = self.llama_model.model.embed_tokens(
                p_before_tokens.input_ids)
            p_after_embeds = self.llama_model.model.embed_tokens(
                p_after_tokens.input_ids)
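            # Splice the prompt around the image embeddings; the resulting
            # per-sample sequence is
            # [prompt before <ImageHere>][image queries][prompt after].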
            wrapped_img_embeds = torch.cat(
                [p_before_embeds, img_embeds, p_after_embeds], dim=1)
            wrapped_atts_img = atts_img[:, :1].expand(
                -1, wrapped_img_embeds.shape[1])
            return wrapped_img_embeds, wrapped_atts_img
        else:
            return img_embeds, atts_img

    def loss(self,
             images: torch.Tensor,
             data_samples: Optional[List[DataSample]] = None) -> dict:
        """The forward function in training.

        Args:
            images (torch.Tensor): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        img_embeds, atts_img = self.encode_img(images)

        self.llama_tokenizer.padding_side = 'right'

        prompts, texts = [], []
        for t in data_samples:
            chat_content = t.chat_content
            split_mark = '###Answer: ' if t.lang == 'en' else '###答:'
            prompt, text = chat_content.split(split_mark)
            prompt += split_mark
            text += self.end_sym
            prompts.append(prompt)
            texts.append(text)

        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)

        to_regress_tokens = self.llama_tokenizer(
            texts,
            return_tensors='pt',
            padding='longest',
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False).to(images.device)
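        # Replace padding positions with -100 so that the language-model
        # cross-entropy loss ignores them; the image/prompt prefix (plus the
        # bos token) is masked out the same way below.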
        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id,
            -100)

        empty_targets = (
            torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
                       dtype=torch.long).to(images.device).fill_(
                           -100)  # plus one for bos
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device
                         ) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.llama_model.model.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        to_regress_embeds = self.llama_model.model.embed_tokens(
            to_regress_tokens.input_ids)
        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds],
                                  dim=1)
        attention_mask = torch.cat(
            [atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)

        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        loss = outputs.loss
        return dict(loss=loss)

    def predict(
            self,
            images: torch.Tensor,
            data_samples: Optional[List[DataSample]] = None
    ) -> List[DataSample]:
        """The forward function in prediction.

        Args:
            images (torch.Tensor): The input images.
            data_samples (List[DataSample], optional): The annotation data
                of every sample. Defaults to None.

        Returns:
            List[DataSample]: Data samples with the generated results.
        """
        with torch.no_grad():
            img_embeds, atts_img = self.encode_img(images)

        prompts = [
            random.choice(self.zh_prompt_list) if hasattr(t, 'lang')
            and t.lang == 'zh' else random.choice(self.en_prompt_list)
            for t in data_samples
        ]
        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)

        batch_size = img_embeds.shape[0]
        bos = torch.ones(
            [batch_size, 1], dtype=torch.long,
            device=img_embeds.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.llama_model.model.embed_tokens(bos)
        inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
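        # Generate autoregressively from the [bos][wrapped prompt + image]
        # prefix; generation stops once ``self.end_token_id`` (the last
        # token of ``end_sym``) is produced.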
        outputs = self.llama_model.generate(
            inputs_embeds=inputs_embeds,
            eos_token_id=self.end_token_id,
            **self.generation_cfg)

        return self.post_process(outputs, data_samples)

    def post_process(
            self, outputs: torch.Tensor,
            data_samples: Optional[List[DataSample]]) -> List[DataSample]:
        """Perform post-processing of the outputs for the different tasks.

        Args:
            outputs (torch.Tensor): The generated outputs.
            data_samples (List[DataSample], optional): The annotation data
                of every sample.

        Returns:
            List[DataSample]: Return list of data samples.
        """
        outputs = self.llama_tokenizer.batch_decode(
            outputs, skip_special_tokens=True)

        if data_samples is None:
            data_samples = [DataSample() for _ in range(len(outputs))]

        for output, data_sample in zip(outputs, data_samples):
            if self.task == 'caption':
                output = output.split('###')[0]
                data_sample.pred_caption = output
            else:
                # raw output
                data_sample.pred_output = output
        return data_samples

    def forward(
        self,
        images: torch.Tensor,
        data_samples: Optional[list] = None,
        mode: str = 'predict',
        **kwargs,
    ):
        """The unified entry for a forward process in both training and test.

        The method accepts the following modes:

        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.
        - "predict": Forward and return a list of data samples containing
          the predicted results.

        Args:
            images (torch.Tensor): The preprocessed image tensor of shape
                ``(N, C, H, W)``.
            data_samples (List[DataSample], optional): The annotation data
                of every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'predict'.
        """
        if mode == 'loss':
            return self.loss(images, data_samples)
        elif mode == 'predict':
            return self.predict(images, data_samples, **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')

    @staticmethod
    def _load_llama_proj_hook(module, incompatible_keys):
        """Avoid warnings about missing keys except the LLaMA projection
        keys."""
        proj_patterns = [
            'vision_encoder.*',
            'ln_vision.*',
            'q_former.*',
            'query_tokens',
            'llama_model.*',
        ]
        for key in list(incompatible_keys.missing_keys):
            if any(re.match(pattern, key) for pattern in proj_patterns):
                incompatible_keys.missing_keys.remove(key)