Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from argparse import Namespace | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .features import GVPGraphEmbedding | |
| from .gvp_modules import GVPConvLayer, LayerNorm | |
| from .gvp_utils import unflatten_graph | |
class GVPEncoder(nn.Module):
    """Encode a structure graph with a stack of GVP convolution layers.

    The inputs are first embedded into node/edge features by
    ``GVPGraphEmbedding``. Those features are then refined by
    ``args.num_encoder_layers`` identically configured ``GVPConvLayer``
    blocks. Finally, the node features are passed through ``unflatten_graph``.

    Args:
        args: configuration namespace. This class itself reads
            ``node_hidden_dim_scalar``, ``node_hidden_dim_vector``,
            ``edge_hidden_dim_scalar``, ``edge_hidden_dim_vector``,
            ``dropout`` and ``num_encoder_layers``. It also passes the whole
            object to ``GVPGraphEmbedding``, which may read further fields.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # NOTE: the attribute names ``embed_graph`` and ``encoder_layers``
        # determine the state_dict keys; keep them stable so previously
        # saved checkpoints still load.
        self.embed_graph = GVPGraphEmbedding(args)

        # (scalar, vector) channel counts for node and edge features.
        node_hidden_dim = (args.node_hidden_dim_scalar,
                           args.node_hidden_dim_vector)
        edge_hidden_dim = (args.edge_hidden_dim_scalar,
                           args.edge_hidden_dim_vector)
        # Activation pair handed to every GVPConvLayer. Presumably this is
        # (scalar activation, vector-gate activation); confirm in gvp_modules.
        conv_activations = (F.relu, torch.sigmoid)

        self.encoder_layers = nn.ModuleList(
            GVPConvLayer(
                node_hidden_dim,
                edge_hidden_dim,
                drop_rate=args.dropout,
                vector_gate=True,
                attention_heads=0,
                n_message=3,
                conv_activations=conv_activations,
                n_edge_gvps=0,
                eps=1e-4,
                layernorm=True,
            )
            # The loop index is unused: every layer is configured identically.
            for _ in range(args.num_encoder_layers)
        )

    def forward(self, coords, coord_mask, padding_mask, confidence):
        """Embed the inputs as a graph and run every encoder layer over it.

        Args:
            coords: coordinate tensor. Its first dimension is passed to
                ``unflatten_graph`` (presumably the batch size, i.e. coords
                is (batch, length, ...); TODO confirm against
                GVPGraphEmbedding).
            coord_mask: forwarded unchanged to ``GVPGraphEmbedding``.
            padding_mask: forwarded unchanged to ``GVPGraphEmbedding``.
            confidence: forwarded unchanged to ``GVPGraphEmbedding``.

        Returns:
            The result of ``unflatten_graph(node_embeddings, coords.shape[0])``
            applied to the final node features. Presumably these are the node
            features regrouped per batch example; confirm in gvp_utils.
        """
        node_embeddings, edge_embeddings, edge_index = self.embed_graph(
            coords, coord_mask, padding_mask, confidence)
        # Each layer updates both node and edge features. The graph
        # connectivity (edge_index) is never reassigned, so every layer
        # sees the same graph.
        for layer in self.encoder_layers:
            node_embeddings, edge_embeddings = layer(
                node_embeddings, edge_index, edge_embeddings)
        # Only node features are returned; the final edge features are dropped.
        node_embeddings = unflatten_graph(node_embeddings, coords.shape[0])
        return node_embeddings