# Copyright (c) OpenMMLab. All rights reserved.
import math
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule
from mmengine.utils import digit_version
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions, ModelOutput, Seq2SeqLMOutput)
from transformers.modeling_utils import (GenerationConfig, GenerationMixin,
                                         PretrainedConfig)

from mmpretrain.registry import MODELS
from ...backbones.resnet import Bottleneck, ResNet

if digit_version(torch.__version__) >= digit_version('1.10.0'):
    torch_meshgrid = partial(torch.meshgrid, indexing='ij')
else:
    torch_meshgrid = torch.meshgrid


def make_token_bucket_position(bucket_size, max_position=1024):
    context_pos = torch.arange(max_position, dtype=torch.long)[:, None]
    memory_pos = torch.arange(max_position, dtype=torch.long)[None, :]
    relative_pos = context_pos - memory_pos
    sign = torch.sign(relative_pos)
    mid = bucket_size // 2
    abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid),
                          mid - 1, torch.abs(relative_pos))
    log_pos = torch.ceil(
        torch.log(abs_pos / mid) / math.log(
            (max_position - 1) / mid) * (mid - 1)) + mid
    log_pos = log_pos.int()
    bucket_pos = torch.where(abs_pos.le(mid), relative_pos,
                             log_pos * sign).long()
    return bucket_pos + bucket_size - 1
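
# A minimal sketch (illustrative only, not part of the module) of how the
# bucket table above is consumed later in this file: the
# (max_position, max_position) index table is sliced to the sequence length
# and looked up in a small per-layer embedding table to get per-head relative
# position biases (mirroring ``get_rel_pos_bias`` below). The sizes here are
# assumptions.
#
#   rp_bucket = make_token_bucket_position(bucket_size=256)      # (1024, 1024)
#   table = nn.Embedding(2 * 256 - 1, 12)                        # per-layer, per-head
#   bias = F.embedding(rp_bucket[:L, :L], table.weight)          # (L, L, 12)
#   bias = bias.permute(2, 0, 1)                                 # (12, L, L)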


def make_image_bucket_position(bucket_size, num_relative_distance):
    coords_h = torch.arange(bucket_size)
    coords_w = torch.arange(bucket_size)
    # (2, h, w)
    coords = torch.stack(torch_meshgrid([coords_h, coords_w]))
    # (2, h*w)
    coords_flatten = torch.flatten(coords, 1)
    # (2, h*w, h*w)
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    # (h*w, h*w, 2)
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    relative_coords[:, :, 0] += bucket_size - 1  # shift to start from 0
    relative_coords[:, :, 1] += bucket_size - 1
    relative_coords[:, :, 0] *= 2 * bucket_size - 1
    relative_position_index = torch.zeros(
        size=(bucket_size * bucket_size + 1, ) * 2,
        dtype=relative_coords.dtype)
    # (h*w, h*w)
    relative_position_index[1:, 1:] = relative_coords.sum(-1)
    relative_position_index[0, 0:] = num_relative_distance - 3
    relative_position_index[0:, 0] = num_relative_distance - 2
    relative_position_index[0, 0] = num_relative_distance - 1
    return relative_position_index


def _make_causal_mask(input_ids_shape: torch.Size,
                      dtype: torch.dtype,
                      past_key_values_length: int = 0):
    """Make causal mask used for uni-directional self-attention."""
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), float('-inf'))
    mask_cond = torch.arange(mask.size(-1))
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat(
            [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask],
            dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len,
                                         tgt_len + past_key_values_length)


def _expand_mask(mask: torch.Tensor,
                 dtype: torch.dtype,
                 tgt_len: Optional[int] = None):
    """Expand attention_mask from ``[B, L_s]`` to ``[B, 1, L_t, L_s]``.

    Where ``B`` is the batch size, ``L_s`` is the source sequence length, and
    ``L_t`` is the target sequence length.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
                                                  src_len).to(dtype)
    return expanded_mask.masked_fill(expanded_mask.bool(),
                                     torch.finfo(dtype).min)
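
# Illustrative sketch (an assumption, not code used by the module) of how the
# two mask helpers combine for unidirectional decoding: ``_make_causal_mask``
# builds the additive -inf causal mask and ``_expand_mask`` turns a (B, L_s)
# padding mask (True/1 = padded) into additive (B, 1, L_t, L_s) form; the
# decoder sums both before the attention softmax.
#
#   causal = _make_causal_mask(torch.Size([2, 5]), torch.float32)  # (2, 1, 5, 5)
#   pad = torch.tensor([[0, 0, 0, 1, 1], [0, 0, 0, 0, 0]])         # 1 = padded
#   additive = _expand_mask(pad, torch.float32)                    # (2, 1, 5, 5)
#   full_mask = causal + additive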


class MultiheadAttention(BaseModule):
    """Multi-head Attention Module for OFA.

    Args:
        embedding_dim (int): The embedding dimension of query.
        num_heads (int): Parallel attention heads.
        kdim (int, optional): The embedding dimension of key.
            Defaults to None, which means the same as the ``embedding_dim``.
        vdim (int, optional): The embedding dimension of value.
            Defaults to None, which means the same as the ``embedding_dim``.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        scale_factor (float): The scale of qk will be
            ``(head_dim * scale_factor) ** -0.5``. Defaults to 1.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        scale_heads (bool): Whether to learn a per-head scaling factor applied
            to the attention output. Defaults to False.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embedding_dim,
                 num_heads,
                 kdim=None,
                 vdim=None,
                 attn_drop=0.,
                 scale_factor=1.,
                 qkv_bias=True,
                 proj_bias=True,
                 scale_heads=False,
                 init_cfg=None):
        super(MultiheadAttention, self).__init__(init_cfg=init_cfg)

        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.kdim = kdim or embedding_dim
        self.vdim = vdim or embedding_dim

        self.head_dim = embedding_dim // num_heads
        self.scale = (self.head_dim * scale_factor)**-0.5

        self.q_proj = nn.Linear(embedding_dim, embedding_dim, bias=qkv_bias)
        self.k_proj = nn.Linear(self.kdim, embedding_dim, bias=qkv_bias)
        self.v_proj = nn.Linear(self.vdim, embedding_dim, bias=qkv_bias)
        self.out_proj = nn.Linear(embedding_dim, embedding_dim, bias=proj_bias)
        self.attn_drop = nn.Dropout(p=attn_drop)

        if scale_heads:
            self.c_attn = nn.Parameter(torch.ones(num_heads))
        else:
            self.c_attn = None

    def forward(
        self,
        query,
        key_value=None,
        attn_mask=None,
        attn_bias=None,
        past_key_value=None,
        output_attentions=False,
    ):
        B, _, C = query.shape
        assert C == self.head_dim * self.num_heads

        is_cross_attention = key_value is not None
        if key_value is None:
            key_value = query

        # (B, L, C) -> (B, num_heads, L, head_dims)
        q = self.q_proj(query).reshape(B, -1, self.num_heads,
                                       self.head_dim).transpose(1, 2)

        if is_cross_attention and past_key_value is not None:
            # Reuse key and value in cross_attentions
            k, v = past_key_value
        else:
            k = self.k_proj(key_value).reshape(B, -1, self.num_heads,
                                               self.head_dim).transpose(1, 2)
            v = self.v_proj(key_value).reshape(B, -1, self.num_heads,
                                               self.head_dim).transpose(1, 2)
            if past_key_value is not None:
                past_key, past_value = past_key_value
                k = torch.cat([past_key, k], dim=2)
                v = torch.cat([past_value, v], dim=2)

        past_key_value = (k, v)

        attn_weights = q @ k.transpose(-2, -1) * self.scale

        if attn_bias is not None:
            src_len = k.size(2)
            attn_weights[:, :, -src_len:] += attn_bias[:, :, -src_len:]

        if attn_mask is not None:
            attn_weights += attn_mask

        attn_weights = torch.softmax(attn_weights, dim=-1)

        attn = self.attn_drop(attn_weights) @ v
        if self.c_attn is not None:
            attn = torch.einsum('bhlc,h->bhlc', attn, self.c_attn)
        # (B, num_heads, L, head_dims) -> (B, L, C)
        attn = attn.transpose(1, 2).reshape(B, -1, self.embedding_dim)
        attn = self.out_proj(attn)

        if output_attentions:
            return attn, attn_weights, past_key_value
        else:
            return attn, None, past_key_value
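
# Minimal usage sketch for the attention module above (illustrative values,
# not part of the module): self-attention with cached key/value states, the
# way the decoder layers below use it during incremental decoding. All sizes
# are assumptions.
#
#   attn = MultiheadAttention(embedding_dim=768, num_heads=12, scale_factor=2.)
#   x = torch.rand(2, 10, 768)                      # (B, L, C)
#   out, _, cache = attn(query=x)                   # cache = (k, v)
#   new_token = torch.rand(2, 1, 768)               # next decoding step
#   out, _, cache = attn(query=new_token, past_key_value=cache)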


class OFAResNet(ResNet):
    """ResNet module for OFA.

    The ResNet in OFA has only three stages.
    """
    arch_settings = {
        50: (Bottleneck, (3, 4, 6)),
        101: (Bottleneck, (3, 4, 23)),
        152: (Bottleneck, (3, 8, 36)),
    }

    def __init__(self, depth, *args, **kwargs):
        super().__init__(
            depth=depth,
            *args,
            num_stages=3,
            out_indices=(2, ),
            dilations=(1, 1, 1),
            strides=(1, 2, 2),
            **kwargs)


@dataclass
class OFAEncoderOutput(ModelOutput):
    """OFA encoder outputs.

    Args:
        last_hidden_state (torch.tensor): The hidden-states of the output at
            the last layer of the model. The shape is (B, L, C).
        padding_mask (torch.Tensor): The padding mask of the inputs, where
            padded positions are marked as True. The shape is (B, L).
        hidden_states (Tuple[torch.tensor]): The initial embedding and the
            output of each layer. The shape of every item is (B, L, C).
        attentions (Tuple[torch.tensor]): The attention weights after the
            attention softmax, used to compute the weighted average in the
            self-attention heads. The shape of every item is
            (B, num_heads, L, L).
        position_embedding (torch.tensor): The positional embeddings of the
            inputs. The shape is (B, L, C).
    """

    last_hidden_state: torch.FloatTensor = None
    padding_mask: torch.Tensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    position_embedding: Optional[torch.FloatTensor] = None


class OFAEncoderLayer(nn.Module):
    """OFAEncoder layer block."""

    def __init__(self,
                 embedding_dim,
                 num_heads,
                 dropout_rate=0.,
                 drop_path_rate=0.,
                 attn_drop=0.,
                 act_drop=0.,
                 scale_factor=2.,
                 mlp_ratio=4.,
                 scale_heads=True,
                 normformer=True,
                 pre_norm=True,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.pre_norm = pre_norm

        self.attn = MultiheadAttention(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            attn_drop=attn_drop,
            scale_factor=scale_factor,
            scale_heads=scale_heads,
        )

        mid_channels = int(embedding_dim * mlp_ratio)
        self.fc1 = nn.Linear(embedding_dim, mid_channels)
        self.fc2 = nn.Linear(mid_channels, embedding_dim)
        self.act = MODELS.build(act_cfg)
        self.act_drop = nn.Dropout(
            act_drop) if act_drop > 0. else nn.Identity()

        # LayerNorm between attention block and ffn block.
        self.attn_ln = nn.LayerNorm(embedding_dim)
        self.ffn_ln = nn.LayerNorm(embedding_dim)

        # Extra LayerNorm
        self.normformer = normformer
        if self.normformer:
            self.attn_mid_ln = nn.LayerNorm(embedding_dim)
            self.ffn_mid_ln = nn.LayerNorm(mid_channels)

        self.dropout = nn.Dropout(dropout_rate)
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

    def forward(self,
                x,
                attention_mask=None,
                attn_bias=None,
                output_attentions=False):
        """Forward the encoder layer.

        Args:
            x (torch.tensor): The input to the layer of shape ``(B, L, C)``.
            attention_mask (torch.Tensor, optional): The attention mask of
                size ``(B, 1, L, L)``, where padding elements are indicated
                by very large negative values. Defaults to None.
            attn_bias (torch.tensor, optional): The bias for positional
                information. Defaults to None.
            output_attentions (bool): Whether to return the attentions tensors
                of the attention layer.

        Returns:
            List[torch.tensor]: The first element is the encoded output of
            shape ``(B, L, C)``. And the second element is the output
            attentions if ``output_attentions=True``.
        """
        residual = x

        # Attention block
        if self.pre_norm:
            x = self.attn_ln(x)
        x, attn_weights, _ = self.attn(
            query=x,
            attn_mask=attention_mask,
            attn_bias=attn_bias,
            output_attentions=output_attentions)
        if self.normformer:
            x = self.attn_mid_ln(x)
        x = self.dropout(x)
        x = residual + self.drop_path(x)
        if not self.pre_norm:
            x = self.attn_ln(x)

        residual = x

        # FFN block
        if self.pre_norm:
            x = self.ffn_ln(x)
        x = self.act(self.fc1(x))
        x = self.act_drop(x)
        if self.normformer:
            x = self.ffn_mid_ln(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = residual + self.drop_path(x)
        if not self.pre_norm:
            x = self.ffn_ln(x)

        if output_attentions:
            return [x, attn_weights]
        else:
            return [x]


class OFADecoderLayer(nn.Module):
    """OFADecoder layer block."""

    def __init__(self,
                 embedding_dim,
                 num_heads,
                 dropout_rate=0.,
                 drop_path_rate=0.,
                 attn_drop=0.,
                 act_drop=0.,
                 scale_factor=2.,
                 mlp_ratio=4.,
                 encoder_embed_dim=None,
                 scale_heads=True,
                 normformer=True,
                 pre_norm=True,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.pre_norm = pre_norm

        self.self_attn = MultiheadAttention(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            attn_drop=attn_drop,
            scale_factor=scale_factor,
            scale_heads=scale_heads,
        )
        self.cross_attn = MultiheadAttention(
            embedding_dim=embedding_dim,
            kdim=encoder_embed_dim,
            vdim=encoder_embed_dim,
            num_heads=num_heads,
            attn_drop=attn_drop,
            scale_factor=scale_factor,
            scale_heads=scale_heads,
        )

        mid_channels = int(embedding_dim * mlp_ratio)
        self.fc1 = nn.Linear(embedding_dim, mid_channels)
        self.fc2 = nn.Linear(mid_channels, embedding_dim)
        self.act = MODELS.build(act_cfg)
        self.act_drop = nn.Dropout(
            act_drop) if act_drop > 0. else nn.Identity()

        # LayerNorm between attention block and ffn block.
        self.self_attn_ln = nn.LayerNorm(embedding_dim)
        self.cross_attn_ln = nn.LayerNorm(embedding_dim)
        self.ffn_ln = nn.LayerNorm(embedding_dim)

        # Extra LayerNorm
        self.normformer = normformer
        if self.normformer:
            self.self_attn_mid_ln = nn.LayerNorm(embedding_dim)
            self.cross_attn_mid_ln = nn.LayerNorm(embedding_dim)
            self.ffn_mid_ln = nn.LayerNorm(mid_channels)

        self.dropout = nn.Dropout(dropout_rate)
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

    def forward(
        self,
        x,
        attention_mask=None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[List[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        self_attn_bias: Optional[torch.Tensor] = None,
        cross_attn_bias: Optional[torch.Tensor] = None,
    ):
        """Forward the decoder layer.

        Args:
            x (torch.tensor): The input to the layer of shape ``(B, L, C)``.
            attention_mask (torch.Tensor, optional): The attention mask of
                size ``(B, 1, L, L)``, where padding elements are indicated
                by very large negative values. Defaults to None.
            encoder_hidden_states (torch.Tensor, optional): The cross attention
                input to the layer of size ``(B, L, C)``. Defaults to None.
            encoder_attention_mask (torch.Tensor, optional): The cross
                attention mask where padding elements are indicated by very
                large negative values. Defaults to None.
            past_key_value (Tuple[torch.tensor], optional): The cached past key
                and value projection states. Defaults to None.
            output_attentions (bool): Whether to return the attentions tensors
                of all attention layers. Defaults to False.
            use_cache (bool, optional): Whether to use cache.
                Defaults to False.
            self_attn_bias (torch.Tensor, optional): The self attention bias
                for positional information. Defaults to None.
            cross_attn_bias (torch.Tensor, optional): The cross attention bias
                for positional information. Defaults to None.

        Returns:
            List[torch.tensor]: The first element is the encoded output of
            shape ``(B, L, C)``. The following two elements can be the output
            self-attentions and cross-attentions if ``output_attentions=True``.
            The following one element can be the cached past key and value
            projection states.
        """
        residual = x

        if past_key_value is not None:
            self_past_key_value = past_key_value[:2]
            cross_past_key_value = past_key_value[2:]
        else:
            self_past_key_value, cross_past_key_value = None, None

        # Self-Attention block
        if self.pre_norm:
            x = self.self_attn_ln(x)
        x, self_attn_weights, present_key_value = self.self_attn(
            query=x,
            past_key_value=self_past_key_value,
            attn_mask=attention_mask,
            output_attentions=output_attentions,
            attn_bias=self_attn_bias,
        )
        if self.normformer:
            x = self.self_attn_mid_ln(x)
        x = self.dropout(x)
        x = residual + self.drop_path(x)
        if not self.pre_norm:
            x = self.self_attn_ln(x)

        # Cross-Attention block
        if encoder_hidden_states is not None:
            residual = x
            if self.pre_norm:
                x = self.cross_attn_ln(x)
            x, cross_attn_weights, cross_key_value = self.cross_attn.forward(
                query=x,
                key_value=encoder_hidden_states,
                attn_mask=encoder_attention_mask,
                past_key_value=cross_past_key_value,
                output_attentions=output_attentions,
                attn_bias=cross_attn_bias)
            if self.normformer:
                x = self.cross_attn_mid_ln(x)
            x = self.dropout(x)
            x = residual + self.drop_path(x)
            if not self.pre_norm:
                x = self.cross_attn_ln(x)

            present_key_value = present_key_value + cross_key_value

        residual = x

        # FFN block
        if self.pre_norm:
            x = self.ffn_ln(x)
        x = self.act(self.fc1(x))
        x = self.act_drop(x)
        if self.normformer:
            x = self.ffn_mid_ln(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = residual + self.drop_path(x)
        if not self.pre_norm:
            x = self.ffn_ln(x)

        outputs = [x]

        if output_attentions:
            outputs.extend([self_attn_weights, cross_attn_weights])

        if use_cache:
            outputs.append(present_key_value)

        return outputs


class OFAEncoder(BaseModule):
    """The encoder module of OFA.

    Args:
        embed_tokens (nn.Embedding): The embedding module to embed the
            input tokens.
        embed_images (dict | nn.Module): The module to embed the input
            images into features. The output number of channels should
            be 1024.
        num_layers (int): The number of encoder layers. Defaults to 6.
        num_heads (int): The number of heads of attention. Defaults to 12.
        dropout_rate (float): The prob of dropout for embedding and
            transformer layers. Defaults to 0.
        drop_path_rate (float): The prob of droppath for transformer layers.
            Defaults to 0.
        max_source_positions (int): The maximum length of the input tokens.
            Defaults to 1024.
        token_bucket_size (int): The token bucket size, it's used as the
            maximum relative position index in relative position embedding
            of input tokens. Defaults to 256.
        image_bucket_size (int): The image bucket size, it's used to generate
            the image relative position embedding table. It should be larger
            than the shape of image feature map. Defaults to 42.
        attn_scale_factor (float): The scale factor to calculate qk scale in
            attentions. Defaults to 2.
        scale_embedding (bool): Whether to scale the embeddings by the square
            root of the dimension. Defaults to False.
        add_embedding_ln (bool): Whether to add an extra layer norm for token
            embeddings. Defaults to True.
        add_type_embed (bool): Whether to add a type embedding that
            distinguishes token embeddings from image embeddings.
            Defaults to True.
        add_image_embedding_ln (bool): Whether to add an extra layer norm for
            image embeddings. Defaults to True.
        pre_norm (bool): Whether to do layer norm before attention and ffn
            blocks in transformer layers. Defaults to True.
        entangle_position_embedding (bool): Whether to add the position
            embedding on the embeddings directly. Defaults to False.
        init_cfg (dict, optional): The initialization config. Defaults to None.
    """

    def __init__(
        self,
        embed_tokens,
        embed_images: dict,
        num_layers=6,
        num_heads=12,
        dropout_rate=0.,
        drop_path_rate=0.,
        max_source_positions=1024,
        token_bucket_size=256,
        image_bucket_size=42,
        attn_scale_factor=2.,
        scale_embedding=False,
        add_embedding_ln=True,
        add_type_embed=True,
        add_image_embedding_ln=True,
        pre_norm=True,
        entangle_position_embedding=False,
        init_cfg=None,
    ):
        super().__init__(init_cfg=init_cfg)
        self.num_layers = num_layers
        embedding_dim = embed_tokens.embedding_dim
        self.embedding_dim = embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = max_source_positions
        self.num_heads = num_heads

        # Build embedding process components
        self.embed_tokens = embed_tokens
        self.embedding_scale = math.sqrt(
            embedding_dim) if scale_embedding else 1.0

        if not isinstance(embed_images, nn.Module):
            self.embed_images = MODELS.build(embed_images)
        else:
            self.embed_images = embed_images
        self.image_proj = nn.Linear(1024, embedding_dim)

        if add_embedding_ln:
            self.embedding_ln = nn.LayerNorm(embedding_dim)
        else:
            self.embedding_ln = None

        if add_type_embed:
            self.embed_type = nn.Embedding(2, embedding_dim)
        else:
            self.embed_type = None

        if add_image_embedding_ln:
            self.image_embedding_ln = nn.LayerNorm(embedding_dim)
        else:
            self.image_embedding_ln = None

        self.entangle_position_embedding = entangle_position_embedding

        # Build position embedding
        self.embed_positions = nn.Embedding(self.max_source_positions + 2,
                                            embedding_dim)
        self.pos_ln = nn.LayerNorm(embedding_dim)
        self.embed_image_positions = nn.Embedding(image_bucket_size**2 + 1,
                                                  embedding_dim)
        self.image_pos_ln = nn.LayerNorm(embedding_dim)

        self.pos_scaling = float(embedding_dim / num_heads *
                                 attn_scale_factor)**-0.5
        self.pos_q_linear = nn.Linear(embedding_dim, embedding_dim)
        self.pos_k_linear = nn.Linear(embedding_dim, embedding_dim)

        self.dropout = nn.Dropout(
            dropout_rate) if dropout_rate > 0. else nn.Identity()

        # Register token relative position embedding table
        self.token_bucket_size = token_bucket_size
        token_num_rel_dis = 2 * token_bucket_size - 1
        token_rp_bucket = make_token_bucket_position(
            token_bucket_size, self.max_source_positions)
        self.register_buffer('token_rp_bucket', token_rp_bucket)
        self.token_rel_pos_table_list = nn.ModuleList()

        # Register image relative position embedding table
        self.image_bucket_size = image_bucket_size
        image_num_rel_dis = (2 * image_bucket_size -
                             1) * (2 * image_bucket_size - 1) + 3
        image_rp_bucket = make_image_bucket_position(image_bucket_size,
                                                     image_num_rel_dis)
        self.register_buffer('image_rp_bucket', image_rp_bucket)
        self.image_rel_pos_table_list = nn.ModuleList()

        # Build encoder layers
        self.layers = nn.ModuleList()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
        for index in range(self.num_layers):
            layer = OFAEncoderLayer(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                dropout_rate=dropout_rate,
                drop_path_rate=dpr[index],
                scale_factor=attn_scale_factor,
                pre_norm=pre_norm,
            )
            self.layers.append(layer)
            token_pos_table = nn.Embedding(token_num_rel_dis, self.num_heads)
            image_pos_table = nn.Embedding(image_num_rel_dis, self.num_heads)
            nn.init.constant_(token_pos_table.weight, 0.)
            nn.init.constant_(image_pos_table.weight, 0.)
            self.token_rel_pos_table_list.append(token_pos_table)
            self.image_rel_pos_table_list.append(image_pos_table)

        if pre_norm:
            self.final_ln = nn.LayerNorm(embedding_dim)
        else:
            self.final_ln = None

    main_input_name = 'input_ids'

    def forward(self,
                input_ids,
                images,
                images_mask,
                output_attentions=False,
                output_hidden_states=False,
                sample_patch_num=None):
        padding_mask = input_ids.eq(self.padding_idx)
        has_pads = padding_mask.any()

        token_embedding = self.embed_tokens(input_ids)
        token_embedding = self.embedding_scale * token_embedding

        # Embed the token position
        src_pos_idx = torch.arange(input_ids.size(-1), device=input_ids.device)
        src_pos_idx = src_pos_idx.expand(*input_ids.shape).contiguous()
        pos_embedding = self.embed_positions(src_pos_idx)

        # Embed the input tokens
        x = self.process_embedding(
            embedding=token_embedding,
            type_tokens=input_ids.new_zeros(token_embedding.shape[:2]),
            pos_embedding=pos_embedding,
            embedding_ln=self.embedding_ln,
        )
        pos_embedding = self.pos_ln(pos_embedding)

        # Embed the input images
        if images is not None:
            (image_tokens, image_padding_mask, image_position_ids,
             image_pos_embedding) = self.get_image_tokens(
                 images,
                 sample_patch_num,
                 images_mask,
             )
            image_embedding = self.image_proj(image_tokens)

            image_x = self.process_embedding(
                embedding=image_embedding,
                type_tokens=input_ids.new_ones(image_embedding.shape[:2]),
                pos_embedding=image_pos_embedding,
                embedding_ln=self.image_embedding_ln,
            )
            image_pos_embedding = self.image_pos_ln(image_pos_embedding)

            x = torch.cat([image_x, x], dim=1)
            padding_mask = torch.cat([image_padding_mask, padding_mask], dim=1)
            pos_embedding = torch.cat([image_pos_embedding, pos_embedding],
                                      dim=1)

        # account for padding while computing the representation
        if has_pads:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        # Decoupled position embedding
        B, L = pos_embedding.shape[:2]
        pos_q = self.pos_q_linear(pos_embedding).view(
            B, L, self.num_heads, -1).transpose(1, 2) * self.pos_scaling
        pos_k = self.pos_k_linear(pos_embedding).view(B, L, self.num_heads,
                                                      -1).transpose(1, 2)
        abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))

        all_hidden_states = [] if output_hidden_states else None
        all_attentions = [] if output_attentions else None

        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states.append(x)

            self_attn_bias = abs_pos_bias.clone()
            # Add decoupled position embedding for input tokens.
            token_len = input_ids.size(1)
            rel_pos_bias = self.get_rel_pos_bias(input_ids, idx)
            self_attn_bias[:, :, -token_len:, -token_len:] += rel_pos_bias

            # Add decoupled position embedding for images
            if images is not None:
                token_len = image_tokens.size(1)
                rel_pos_bias = self.get_image_rel_pos_bias(
                    image_position_ids, idx)
                self_attn_bias[:, :, :token_len, :token_len] += rel_pos_bias

            if has_pads:
                attention_mask = _expand_mask(padding_mask, dtype=x.dtype)
            else:
                attention_mask = None

            out = layer(
                x,
                attention_mask=attention_mask,
                attn_bias=self_attn_bias,
                output_attentions=output_attentions)
            x = out[0]

            if output_attentions:
                all_attentions.append(out[1])

        if output_hidden_states:
            all_hidden_states.append(x)

        if self.final_ln is not None:
            x = self.final_ln(x)

        return OFAEncoderOutput(
            last_hidden_state=x,  # (B, L, C)
            padding_mask=padding_mask,  # (B, L)
            position_embedding=pos_embedding,  # (B, L, C)
            hidden_states=all_hidden_states,  # list of (B, L, C)
            attentions=all_attentions,  # list of (B, num_heads, L, L)
        )

    def get_image_tokens(self, images, sample_patch_num, images_mask):
        image_embedding = self.embed_images(images)[-1]
        B, C, H, W = image_embedding.shape
        num_patches = H * W

        padding_mask = images.new_zeros((B, num_patches)).bool()
        position_col = torch.arange(W).unsqueeze(0)
        position_row = torch.arange(H).unsqueeze(1) * self.image_bucket_size
        position_idx = (position_col + position_row + 1).view(-1)
        position_idx = position_idx.to(images.device).expand(B, num_patches)

        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
        image_embedding = image_embedding.flatten(2).transpose(1, 2)

        if sample_patch_num is not None:
            patch_orders = torch.stack([
                torch.randperm(num_patches)[:sample_patch_num]
                for _ in range(B)
            ])
            num_patches = sample_patch_num
            image_embedding = image_embedding.gather(
                dim=1, index=patch_orders.unsqueeze(2).expand(-1, -1, C))
            padding_mask = padding_mask.gather(1, patch_orders)
            position_idx = position_idx.gather(1, patch_orders)

        pos_embedding = self.embed_image_positions(position_idx)
        padding_mask[~images_mask] = True
        return image_embedding, padding_mask, position_idx, pos_embedding

    def process_embedding(self,
                          embedding,
                          pos_embedding=None,
                          type_tokens=None,
                          embedding_ln=None):
        if self.entangle_position_embedding and pos_embedding is not None:
            embedding += pos_embedding
        if self.embed_type is not None:
            embedding += self.embed_type(type_tokens)
        if embedding_ln is not None:
            embedding = embedding_ln(embedding)
        embedding = self.dropout(embedding)

        return embedding

    def get_rel_pos_bias(self, x, idx):
        seq_len = x.size(1)
        rp_bucket = self.token_rp_bucket[:seq_len, :seq_len]
        values = F.embedding(rp_bucket,
                             self.token_rel_pos_table_list[idx].weight)
        values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1)
        values = values.permute([0, 3, 1, 2])
        return values.contiguous()

    def get_image_rel_pos_bias(self, image_position_ids, idx):
        bsz, seq_len = image_position_ids.shape
        rp_bucket_size = self.image_rp_bucket.size(1)

        rp_bucket = self.image_rp_bucket.unsqueeze(0).expand(
            bsz, rp_bucket_size, rp_bucket_size).gather(
                1, image_position_ids[:, :, None].expand(
                    bsz, seq_len, rp_bucket_size)).gather(
                        2, image_position_ids[:, None, :].expand(
                            bsz, seq_len, seq_len))
        values = F.embedding(rp_bucket,
                             self.image_rel_pos_table_list[idx].weight)
        values = values.permute(0, 3, 1, 2)
        return values
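
# Rough construction sketch for the encoder above (illustrative assumptions,
# not part of the module). `embed_images` may be a ready-made module; an
# OFAResNet backbone keeps the expected 1024 output channels. The vocabulary
# size and shapes below are arbitrary placeholders.
#
#   embed_tokens = nn.Embedding(59457, 768, padding_idx=1)
#   encoder = OFAEncoder(
#       embed_tokens,
#       embed_images=OFAResNet(depth=101),
#       num_layers=6,
#       num_heads=12)
#   out = encoder(input_ids, images, images_mask)   # -> OFAEncoderOutput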


class OFADecoder(BaseModule):
    """The decoder module of OFA.

    Args:
        embed_tokens (nn.Embedding): The embedding module to embed the
            input tokens.
        num_layers (int): The number of decoder layers. Defaults to 6.
        num_heads (int): The number of heads of attention. Defaults to 12.
        dropout_rate (float): The prob of dropout for embedding and
            transformer layers. Defaults to 0.
        drop_layer_rate (float): The prob of dropping whole decoder layers.
            Not implemented yet and must be 0. Defaults to 0.
        drop_path_rate (float): The prob of droppath for transformer layers.
            Defaults to 0.
        max_target_positions (int): The maximum length of the input tokens.
            Defaults to 1024.
        code_image_size (int): The resolution of the generated image in the
            image infilling task. Defaults to 128.
        token_bucket_size (int): The token bucket size, it's used as the
            maximum relative position index in relative position embedding
            of input tokens. Defaults to 256.
        image_bucket_size (int): The image bucket size, it's used to generate
            the image relative position embedding table. It should be larger
            than the shape of image feature map. Defaults to 42.
        attn_scale_factor (float): The scale factor to calculate qk scale in
            attentions. Defaults to 2.
        scale_embedding (bool): Whether to scale the embeddings by the square
            root of the dimension. Defaults to False.
        add_embedding_ln (bool): Whether to add an extra layer norm for token
            embeddings. Defaults to True.
        add_code_embedding_ln (bool): Whether to add an extra layer norm for
            code embeddings. Defaults to True.
        pre_norm (bool): Whether to do layer norm before attention and ffn
            blocks in transformer layers. Defaults to True.
        entangle_position_embedding (bool): Whether to add the position
            embedding on the embeddings directly. Defaults to False.
        share_input_output_embed (bool): Share the weights of the input token
            embedding module and the output projection module.
            Defaults to True.
        init_cfg (dict, optional): The initialization config. Defaults to None.
    """

    def __init__(
        self,
        embed_tokens,
        num_layers=6,
        num_heads=12,
        dropout_rate=0.,
        drop_layer_rate=0.,
        drop_path_rate=0.,
        max_target_positions=1024,
        code_image_size=128,
        token_bucket_size=256,
        image_bucket_size=42,
        attn_scale_factor=2.,
        scale_embedding=False,
        add_embedding_ln=True,
        add_code_embedding_ln=True,
        pre_norm=True,
        entangle_position_embedding=False,
        share_input_output_embed=True,
        init_cfg=None,
    ):
        super().__init__(init_cfg=init_cfg)
        self._future_mask = torch.empty(0)

        self.num_layers = num_layers
        embedding_dim = embed_tokens.embedding_dim
        self.embedding_dim = embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = max_target_positions
        self.num_heads = num_heads

        # Build embedding process components
        self.embed_tokens = embed_tokens
        self.embedding_scale = math.sqrt(
            embedding_dim) if scale_embedding else 1.0

        if add_embedding_ln:
            self.embedding_ln = nn.LayerNorm(embedding_dim)
        else:
            self.embedding_ln = None

        if add_code_embedding_ln:
            self.code_embedding_ln = nn.LayerNorm(embedding_dim)
        else:
            self.code_embedding_ln = None

        # Build position embedding
        self.embed_positions = nn.Embedding(self.max_target_positions + 2,
                                            embedding_dim)
        self.pos_ln = nn.LayerNorm(embedding_dim)
        self.embed_image_positions = nn.Embedding(image_bucket_size**2 + 1,
                                                  embedding_dim)
        self.image_pos_ln = nn.LayerNorm(embedding_dim)

        self.pos_scaling = float(embedding_dim / num_heads *
                                 attn_scale_factor)**-0.5
        self.self_pos_q_linear = nn.Linear(embedding_dim, embedding_dim)
        self.self_pos_k_linear = nn.Linear(embedding_dim, embedding_dim)
        self.cross_pos_q_linear = nn.Linear(embedding_dim, embedding_dim)
        self.cross_pos_k_linear = nn.Linear(embedding_dim, embedding_dim)

        self.entangle_position_embedding = entangle_position_embedding

        self.dropout = nn.Dropout(
            dropout_rate) if dropout_rate > 0. else nn.Identity()
        if drop_layer_rate > 0.:
            raise NotImplementedError

        # Register token relative position embedding table
        self.token_bucket_size = token_bucket_size
        token_num_rel_dis = 2 * token_bucket_size - 1
        token_rp_bucket = make_token_bucket_position(token_bucket_size)
        self.register_buffer('token_rp_bucket', token_rp_bucket)
        self.token_rel_pos_table_list = nn.ModuleList()

        # Register image relative position embedding table
        self.image_bucket_size = image_bucket_size
        image_num_rel_dis = (2 * image_bucket_size -
                             1) * (2 * image_bucket_size - 1) + 3
        image_rp_bucket = make_image_bucket_position(image_bucket_size,
                                                     image_num_rel_dis)
        self.register_buffer('image_rp_bucket', image_rp_bucket)
        self.image_rel_pos_table_list = nn.ModuleList()

        self.window_size = code_image_size // 8

        position_col = torch.arange(self.window_size).unsqueeze(0)
        position_row = torch.arange(
            self.window_size).unsqueeze(1) * self.image_bucket_size
        image_position_idx = (position_col + position_row + 1)
        image_position_idx = torch.cat(
            [torch.tensor([0]), image_position_idx.view(-1)])
        image_position_idx = torch.cat(
            [image_position_idx, torch.tensor([1024] * 768)])
        self.register_buffer('image_position_idx', image_position_idx)

        # Build decoder layers
        self.layers = nn.ModuleList()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
        for index in range(self.num_layers):
            layer = OFADecoderLayer(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                dropout_rate=dropout_rate,
                drop_path_rate=dpr[index],
                scale_factor=attn_scale_factor,
                pre_norm=pre_norm,
            )
            self.layers.append(layer)
            token_pos_table = nn.Embedding(token_num_rel_dis, self.num_heads)
            image_pos_table = nn.Embedding(image_num_rel_dis, self.num_heads)
            nn.init.constant_(token_pos_table.weight, 0.)
            nn.init.constant_(image_pos_table.weight, 0.)
            self.token_rel_pos_table_list.append(token_pos_table)
            self.image_rel_pos_table_list.append(image_pos_table)

        if pre_norm:
            self.final_ln = nn.LayerNorm(embedding_dim)
        else:
            self.final_ln = None

        # Build output projection
        if share_input_output_embed:
            self.output_projection = nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            self.output_projection.weight = self.embed_tokens.weight
        else:
            vocab_size = self.embed_tokens.num_embeddings
            self.output_projection = nn.Linear(
                embedding_dim, vocab_size, bias=False)
            nn.init.normal_(
                self.output_projection.weight,
                mean=0,
                std=embedding_dim**-0.5,
            )

    main_input_name = 'input_ids'

    def forward(
        self,
        input_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        encoder_hidden_states: torch.Tensor = None,
        encoder_attention_mask: torch.Tensor = None,
        code_masks: Optional[torch.Tensor] = None,
        encoder_pos_embedding: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        if past_key_values is not None and len(past_key_values) > 0:
            B, _, L_past, _ = past_key_values[0][0].shape
            L = L_past + 1
        else:
            B, L = input_ids.shape
            L_past = 0

        # Embed the token position
        target_pos_idx = torch.arange(
            L, device=input_ids.device).expand([B, L]).contiguous()
        pos_embedding = self.embed_positions(target_pos_idx)

        # Embed the code positions
        if code_masks is not None and torch.any(code_masks):
            image_position_idx = self.image_position_idx[:input_ids.size(1)]
            image_position_idx = image_position_idx.unsqueeze(0).expand(B, L)
            pos_embedding[code_masks] = self.embed_image_positions(
                image_position_idx)[code_masks]

        # Self-attention position bias (B, num_heads, L_t, L_t)
        self_abs_pos_bias = self.get_pos_info(self.pos_ln(pos_embedding))
        if code_masks is not None and torch.any(code_masks):
            self_image_abs_pos_bias = self.get_pos_info(
                self.image_pos_ln(pos_embedding))
            self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks]

        # Cross-attention position bias (B, num_heads, L_t, L_s)
        cross_abs_pos_bias = self.get_pos_info(
            self.pos_ln(pos_embedding), encoder_pos_embedding)
        if code_masks is not None and torch.any(code_masks):
            cross_image_abs_pos_bias = self.get_pos_info(
                self.image_pos_ln(pos_embedding), encoder_pos_embedding)
            cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[
                code_masks]

        all_prev_output_tokens = input_ids.clone()
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
            cross_abs_pos_bias = cross_abs_pos_bias[:, :, -1:, :]
            pos_embedding = pos_embedding[:, -1:, :]

        # Embed the input tokens
        x = self.embed_tokens(input_ids) * self.embedding_scale

        if self.entangle_position_embedding:
            x += pos_embedding

        if self.embedding_ln is not None:
            if (code_masks is None or not code_masks.any()
                    or self.code_embedding_ln is None):
                x = self.embedding_ln(x)
            elif code_masks is not None and code_masks.all():
                x = self.code_embedding_ln(x)
            else:
                x[~code_masks] = self.embedding_ln(x[~code_masks])
                x[code_masks] = self.code_embedding_ln(x[code_masks])

        x = self.dropout(x)

        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, input_ids.shape, x.dtype, L_past)
        attention_mask = attention_mask.to(x.device)

        # decoder layers
        all_hidden_states = [] if output_hidden_states else None
        all_self_attns = [] if output_attentions else None
        all_cross_attentions = [] if (
            output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = [] if use_cache else None

        for idx, layer in enumerate(self.layers):
            # add hidden states from the last decoder layer
            if output_hidden_states:
                all_hidden_states.append(x)

            if past_key_values is not None and len(past_key_values) > 0:
                past_key_value = past_key_values[idx]
            else:
                past_key_value = None

            self_attn_bias = self_abs_pos_bias.clone()
            if code_masks is None or not code_masks.any():
                self_attn_bias += self.get_rel_pos_bias(
                    all_prev_output_tokens, idx)
            elif code_masks is not None and code_masks.all():
                self_attn_bias += self.get_image_rel_pos_bias(
                    all_prev_output_tokens, idx)
            else:
                self_attn_bias[~code_masks] += self.get_rel_pos_bias(
                    all_prev_output_tokens, idx)
                self_attn_bias[code_masks] += self.get_image_rel_pos_bias(
                    all_prev_output_tokens, idx)
            if past_key_value is not None:
                self_attn_bias = self_attn_bias[:, :, -1:, :]

            out = layer(
                x,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                self_attn_bias=self_attn_bias,
                cross_attn_bias=cross_abs_pos_bias,
            )
            x = out.pop(0)

            if output_attentions:
                all_self_attns.append(out.pop(0))
                if encoder_hidden_states is not None:
                    all_cross_attentions.append(out.pop(0))

            if use_cache:
                next_decoder_cache.append(out.pop(0))

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (x, )

        if self.final_ln is not None:
            x = self.final_ln(x)

        x = self.output_projection(x)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=x,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )

    def _prepare_decoder_attention_mask(
        self,
        attention_mask,
        input_shape,
        dtype,
        past_key_values_length,
    ):
        r"""Create causal mask for unidirectional decoding.

        [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        """
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                dtype,
                past_key_values_length=past_key_values_length).to(
                    attention_mask.device)

        if attention_mask is not None:
            # (B, L_s) -> (B, 1, L_t, L_s)
            expanded_attention_mask = _expand_mask(
                attention_mask, dtype, tgt_len=input_shape[-1])
            combined_attention_mask = (
                expanded_attention_mask if combined_attention_mask is None else
                expanded_attention_mask + combined_attention_mask)

        return combined_attention_mask

    def get_pos_info(self, pos_embedding, src_pos_embedding=None):
        B, tgt_len = pos_embedding.shape[:2]
        if src_pos_embedding is not None:
            src_len = src_pos_embedding.size(1)
            pos_q = self.cross_pos_q_linear(pos_embedding).view(
                B, tgt_len, self.num_heads, -1).transpose(1, 2)
            pos_q = pos_q * self.pos_scaling
            pos_k = self.cross_pos_k_linear(src_pos_embedding).view(
                B, src_len, self.num_heads, -1).transpose(1, 2)
        else:
            pos_q = self.self_pos_q_linear(pos_embedding).view(
                B, tgt_len, self.num_heads, -1).transpose(1, 2)
            pos_q = pos_q * self.pos_scaling
            pos_k = self.self_pos_k_linear(pos_embedding).view(
                B, tgt_len, self.num_heads, -1).transpose(1, 2)
        abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
        return abs_pos_bias

    def get_rel_pos_bias(self, x, idx):
        seq_len = x.size(1)
        rp_bucket = self.token_rp_bucket[:seq_len, :seq_len]
        values = F.embedding(rp_bucket,
                             self.token_rel_pos_table_list[idx].weight)
        values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1)
        values = values.permute([0, 3, 1, 2])
        return values.contiguous()

    def get_image_rel_pos_bias(self, image_position_ids, idx):
        bsz, seq_len = image_position_ids.shape
        rp_bucket_size = self.image_rp_bucket.size(1)

        rp_bucket = self.image_rp_bucket.unsqueeze(0).expand(
            bsz, rp_bucket_size, rp_bucket_size).gather(
                1, image_position_ids[:, :, None].expand(
                    bsz, seq_len, rp_bucket_size)).gather(
                        2, image_position_ids[:, None, :].expand(
                            bsz, seq_len, seq_len))
        values = F.embedding(rp_bucket,
                             self.image_rel_pos_table_list[idx].weight)
        values = values.permute(0, 3, 1, 2)
        return values


class OFAEncoderDecoder(BaseModule, GenerationMixin):
    """The OFA main architecture with an encoder and a decoder.

    Args:
        encoder_cfg (dict): The config of the encoder, accept the keyword
            arguments of :class:`OFAEncoder`.
        decoder_cfg (dict): The config of the decoder, accept the keyword
            arguments of :class:`OFADecoder`.
        padding_idx (int): The index of the padding token.
        vocab_size (int): The size of the vocabulary.
        embedding_dim (int): The embedding dimensions of both the encoder
            and the decoder.
        generation_cfg (dict): The extra generation config, accept the keyword
            arguments of :class:`~transformers.GenerationConfig`.
            Defaults to an empty dict.
        init_cfg (dict, optional): The initialization config. Defaults to None.
    """

    base_model_prefix = ''

    def __init__(
        self,
        encoder_cfg,
        decoder_cfg,
        padding_idx,
        vocab_size,
        embedding_dim,
        generation_cfg=dict(),
        init_cfg=None,
    ):
        super().__init__(init_cfg=init_cfg)

        self.padding_idx = padding_idx
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        embed_tokens = nn.Embedding(vocab_size, embedding_dim, padding_idx)

        self.encoder = OFAEncoder(embed_tokens, **encoder_cfg)
        self.decoder = OFADecoder(embed_tokens, **decoder_cfg)

        self.config = PretrainedConfig(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
            bos_token_id=0,
            decoder_start_token_id=0,
            pad_token_id=1,
            eos_token_id=2,
            forced_eos_token_id=2,
            use_cache=False,
            is_encoder_decoder=True,
        )
        self.config.update(generation_cfg)

        self.generation_config = GenerationConfig.from_model_config(
            self.config)

    @property
    def device(self):
        return next(self.parameters()).device

    def can_generate(self):
        return True

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()

    def get_normalized_probs(self, net_output, log_probs: bool, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs,
                                                    sample)

    def get_normalized_probs_scriptable(
        self,
        net_output,
        log_probs: bool,
        sample=None,
    ):
        """Scriptable helper function for ``get_normalized_probs`` in
        ``~BaseFairseqModel``."""
        if hasattr(self, 'decoder'):
            return self.decoder.get_normalized_probs(net_output, log_probs,
                                                     sample)
        elif torch.is_tensor(net_output):
            # syntactic sugar for simple models which don't have a decoder
            # (e.g., the classification tutorial)
            logits = net_output.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            else:
                return F.softmax(logits, dim=-1)
        raise NotImplementedError

    main_input_name = 'input_ids'

    def forward(self,
                input_ids=None,
                images=None,
                images_mask=None,
                sample_patch_num=None,
                decoder_input_ids=None,
                code_masks=None,
                attention_mask=None,
                encoder_outputs=None,
                past_key_values=None,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
                constrain_fn=None,
                return_dict=False):
        """Forward the module.

        Args:
            input_ids (torch.Tensor): The indices of the input tokens in the
                vocabulary, and padding will be ignored by default. The
                indices can be obtained using :class:`OFATokenizer`.
                The shape is (B, L).
            images (torch.Tensor): The input images. The shape is (B, 3, H, W).
            images_mask (torch.Tensor): The mask of all available images. The
                shape is (B, ).
            sample_patch_num (int): The number of patches to sample for the
                images. Defaults to None, which means to use all patches.
            decoder_input_ids (torch.Tensor): The indices of the input tokens
                for the decoder.
            code_masks (torch.Tensor): The mask of all samples for image
                generation. The shape is (B, ).
            attention_mask (torch.Tensor): The attention mask for decoding.
                The shape is (B, L).
            encoder_outputs (OFAEncoderOutput): The encoder outputs with
                hidden states, positional embeddings, and padding masks.
            past_key_values (Tuple[Tuple[torch.Tensor]]): If use cache, the
                parameter is a tuple of length ``num_layers``. Every item is
                also a tuple with four tensors, two for the key and value of
                self-attention, two for the key and value of cross-attention.
            use_cache (bool): Whether to use cache for faster inference.
                Defaults to False.
            output_attentions (bool): Whether to output attention weights.
                Defaults to False.
            output_hidden_states (bool): Whether to output hidden states.
                Defaults to False.
            constrain_fn (Callable, optional): The function to constrain the
                output logits. Defaults to None.
            return_dict (bool): Not used, it's only for compatibility with the
                ``generate`` interface of ``transformers``.

        Returns:
            Seq2SeqLMOutput:

            - logits (``torch.Tensor``): The last decoder hidden states.
              The shape is (B, L, C).
            - past_key_values (``Tuple[Tuple[torch.Tensor]]``): The past keys
              and values for faster inference.
            - decoder_hidden_states (``Tuple[torch.Tensor]``): The decoder
              hidden states of all layers.
            - decoder_attentions (``Tuple[torch.Tensor]``): The self-attention
              weights of all layers in the decoder.
            - cross_attentions (``Tuple[torch.Tensor]``): The cross-attention
              weights of all layers in the decoder.
            - encoder_last_hidden_state (``torch.Tensor``): The last encoder
              hidden states.
            - encoder_hidden_states (``Tuple[torch.Tensor]``): The encoder
              hidden states of all layers, including the embeddings.
            - encoder_attentions (``Tuple[torch.Tensor]``): The self-attention
              weights of all layers in the encoder.
        """
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                images=images,
                images_mask=images_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                sample_patch_num=sample_patch_num,
            )

        if decoder_input_ids.eq(self.padding_idx).any():
            attention_mask = decoder_input_ids.eq(self.padding_idx)

        encoder_hidden_states = encoder_outputs.last_hidden_state
        encoder_attention_mask = _expand_mask(encoder_outputs.padding_mask,
                                              encoder_hidden_states.dtype,
                                              decoder_input_ids.shape[-1])
        src_pos_embed = encoder_outputs.position_embedding

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            code_masks=code_masks,
            encoder_pos_embedding=src_pos_embed,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # The constrain operation for fine-tuned model in OFA is applied
        # before log_softmax, therefore we cannot use
        # `prefix_allowed_tokens_fn` to implement it.
        if constrain_fn is not None:
            logits = constrain_fn(decoder_input_ids,
                                  decoder_outputs.last_hidden_state)
        else:
            logits = decoder_outputs.last_hidden_state

        return Seq2SeqLMOutput(
            logits=logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(self,
                                      decoder_input_ids=None,
                                      past=None,
                                      attention_mask=None,
                                      code_masks=None,
                                      use_cache=False,
                                      encoder_outputs=None,
                                      constrain_fn=None,
                                      **kwargs):
        # if attention_mask is None:
        attention_mask = decoder_input_ids.new_zeros(decoder_input_ids.shape)

        # cut decoder_input_ids if past is used
        if past is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            'input_ids': None,
            'images': None,
            'images_mask': None,
            'sample_patch_num': None,
            'attention_mask': attention_mask,
            'encoder_outputs': encoder_outputs,
            'past_key_values': past,
            'decoder_input_ids': decoder_input_ids,
            'code_masks': code_masks,
            'use_cache': use_cache,
            'constrain_fn': constrain_fn,
        }

    def _prepare_encoder_decoder_kwargs_for_generation(
            self,
            inputs_tensor: torch.Tensor,
            model_kwargs,
            model_input_name: Optional[str] = None):
        # 1. get encoder
        encoder = self.get_encoder()

        # 2. prepare encoder args and encoder kwargs from model kwargs
        irrelevant_prefix = [
            'decoder_', 'cross_attn', 'use_cache', 'attention_mask',
            'constrain_fn'
        ]
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not any(argument.startswith(p) for p in irrelevant_prefix)
        }
        if encoder_kwargs.get('images_mask') is None:
            encoder_kwargs['images_mask'] = torch.tensor(
                [True] * inputs_tensor.size(0))

        # 3. make sure that encoder returns `ModelOutput`
        model_input_name = model_input_name or self.main_input_name
        encoder_kwargs[model_input_name] = inputs_tensor
        model_kwargs['encoder_outputs']: ModelOutput = encoder(
            **encoder_kwargs)
        model_kwargs['attention_mask'] = None

        return model_kwargs

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx)
                for past_state in layer_past), )
        return reordered_past

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[ModelOutput] = None,
        **model_kwargs,
    ):
        expanded_return_idx = (
            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(
                1, expand_size).view(-1).to(input_ids.device))
        input_ids = input_ids.index_select(0, expanded_return_idx)

        if attention_mask is not None:
            model_kwargs['attention_mask'] = attention_mask.index_select(
                0, expanded_return_idx)

        if is_encoder_decoder:
            if encoder_outputs is None:
                raise ValueError('If `is_encoder_decoder` is True, make '
                                 'sure that `encoder_outputs` is defined.')
            encoder_outputs['last_hidden_state'] = encoder_outputs.\
                last_hidden_state.index_select(0, expanded_return_idx)
            encoder_outputs['position_embedding'] = encoder_outputs.\
                position_embedding.index_select(0, expanded_return_idx)
            encoder_outputs['padding_mask'] = encoder_outputs.\
                padding_mask.index_select(0, expanded_return_idx)
            model_kwargs['encoder_outputs'] = encoder_outputs
        return input_ids, model_kwargs
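
# End-to-end usage sketch (illustrative assumptions, with placeholder sizes
# and configs): the module plugs into the `transformers` generation utilities
# through `GenerationMixin`, so decoding typically goes through `generate()`
# with the encoder inputs passed as model kwargs.
#
#   model = OFAEncoderDecoder(
#       encoder_cfg=dict(embed_images=OFAResNet(depth=101)),
#       decoder_cfg=dict(),
#       padding_idx=1,
#       vocab_size=59457,
#       embedding_dim=768,
#       generation_cfg=dict(num_beams=5, max_new_tokens=20))
#   outputs = model.generate(
#       input_ids, images=images, images_mask=images_mask)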