| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # | |
| # Portions of this file were adapted from the open source code for the following | |
| # two papers: | |
| # | |
| # Ingraham, J., Garg, V., Barzilay, R., & Jaakkola, T. (2019). Generative | |
| # models for graph-based protein design. Advances in Neural Information | |
| # Processing Systems, 32. | |
| # | |
| # Jing, B., Eismann, S., Suriana, P., Townshend, R. J. L., & Dror, R. (2020). | |
| # Learning from Protein Structure with Geometric Vector Perceptrons. In | |
| # International Conference on Learning Representations. | |
| # | |
| # MIT License | |
| # | |
| # Copyright (c) 2020 Bowen Jing, Stephan Eismann, Patricia Suriana, Raphael Townshend, Ron Dror | |
| # | |
| # Permission is hereby granted, free of charge, to any person obtaining a copy | |
| # of this software and associated documentation files (the "Software"), to deal | |
| # in the Software without restriction, including without limitation the rights | |
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| # copies of the Software, and to permit persons to whom the Software is | |
| # furnished to do so, subject to the following conditions: | |
| # | |
| # The above copyright notice and this permission notice shall be included in all | |
| # copies or substantial portions of the Software. | |
| # | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| # SOFTWARE. | |
| # | |
| # ================================================================ | |
| # The below license applies to the portions of the code (parts of | |
| # src/datasets.py and src/models.py) adapted from Ingraham, et al. | |
| # ================================================================ | |
| # | |
| # MIT License | |
| # | |
| # Copyright (c) 2019 John Ingraham, Vikas Garg, Regina Barzilay, Tommi Jaakkola | |
| # | |
| # Permission is hereby granted, free of charge, to any person obtaining a copy | |
| # of this software and associated documentation files (the "Software"), to deal | |
| # in the Software without restriction, including without limitation the rights | |
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| # copies of the Software, and to permit persons to whom the Software is | |
| # furnished to do so, subject to the following conditions: | |
| # | |
| # The above copyright notice and this permission notice shall be included in all | |
| # copies or substantial portions of the Software. | |
| # | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| # SOFTWARE. | |
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .gvp_utils import flatten_graph
from .gvp_modules import GVP, LayerNorm
from .util import normalize, norm, nan_to_num, rbf
class GVPInputFeaturizer(nn.Module):
    """Derive per-residue scalar and vector features from backbone coordinates.

    All helpers are stateless; they are invoked both through the class
    (``GVPInputFeaturizer._dihedrals``) and through instances/subclasses
    (``self.get_node_features``), so they are declared ``@staticmethod`` to
    prevent ``self`` from being bound as the first positional argument.
    """

    @staticmethod
    def get_node_features(coords, coord_mask, with_coord_mask=True):
        """Return ``(scalar_features, vector_features)`` for each residue.

        coords: backbone coordinates; atom index 1 along dim 2 is used as
            the CA position — assumes N/CA/C atom ordering, TODO confirm
            against the data loader.
        coord_mask: bool mask of residues that have valid coordinates.
        with_coord_mask: when True, append the mask as one extra scalar channel.
        """
        # scalar features: sin/cos of backbone dihedrals (+ optional mask bit)
        node_scalar_features = GVPInputFeaturizer._dihedrals(coords)
        if with_coord_mask:
            node_scalar_features = torch.cat([
                node_scalar_features,
                coord_mask.float().unsqueeze(-1)
            ], dim=-1)
        # vector features: forward/backward CA orientations + sidechain direction
        X_ca = coords[:, :, 1]
        orientations = GVPInputFeaturizer._orientations(X_ca)
        sidechains = GVPInputFeaturizer._sidechains(coords)
        node_vector_features = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2)
        return node_scalar_features, node_vector_features

    @staticmethod
    def _orientations(X):
        # Unit vectors from each CA to its successor and predecessor,
        # zero-padded at the chain ends so shapes stay aligned with X.
        forward = normalize(X[:, 1:] - X[:, :-1])
        backward = normalize(X[:, :-1] - X[:, 1:])
        forward = F.pad(forward, [0, 0, 0, 1])
        backward = F.pad(backward, [0, 0, 1, 0])
        return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)

    @staticmethod
    def _sidechains(X):
        # Imputed sidechain (C-beta) direction from the N, CA, C frame.
        n, origin, c = X[:, :, 0], X[:, :, 1], X[:, :, 2]
        c, n = normalize(c - origin), normalize(n - origin)
        bisector = normalize(c + n)
        perp = normalize(torch.cross(c, n, dim=-1))
        # Fixed combination of bisector and perpendicular (tetrahedral geometry).
        vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
        return vec

    @staticmethod
    def _dihedrals(X, eps=1e-7):
        """Return sin/cos of the phi, psi, omega backbone dihedral angles."""
        # Flatten the first three atoms (N, CA, C) into one point sequence.
        X = torch.flatten(X[:, :, :3], 1, 2)
        bsz = X.shape[0]
        dX = X[:, 1:] - X[:, :-1]
        U = normalize(dX, dim=-1)
        u_2 = U[:, :-2]
        u_1 = U[:, 1:-1]
        u_0 = U[:, 2:]
        # Backbone normals
        n_2 = normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
        # Angle between normals; clamp keeps acos away from the +/-1 poles.
        cosD = torch.sum(n_2 * n_1, -1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)
        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, [1, 2])
        D = torch.reshape(D, [bsz, -1, 3])
        # Lift angle representations to the circle
        D_features = torch.cat([torch.cos(D), torch.sin(D)], -1)
        return D_features

    @staticmethod
    def _positional_embeddings(edge_index,
                               num_embeddings=None,
                               num_positional_embeddings=16,
                               period_range=[2, 1000]):
        """Sinusoidal embedding of the signed sequence offset of each edge."""
        # From https://github.com/jingraham/neurips19-graph-protein-design
        num_embeddings = num_embeddings or num_positional_embeddings
        d = edge_index[0] - edge_index[1]
        frequency = torch.exp(
            torch.arange(0, num_embeddings, 2, dtype=torch.float32,
                         device=edge_index.device)
            * -(np.log(10000.0) / num_embeddings)
        )
        angles = d.unsqueeze(-1) * frequency
        E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
        return E

    @staticmethod
    def _dist(X, coord_mask, padding_mask, top_k_neighbors, eps=1e-8):
        """ Pairwise euclidean distances """
        bsz, maxlen = X.size(0), X.size(1)
        coord_mask_2D = torch.unsqueeze(coord_mask, 1) * torch.unsqueeze(coord_mask, 2)
        residue_mask = ~padding_mask
        residue_mask_2D = torch.unsqueeze(residue_mask, 1) * torch.unsqueeze(residue_mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        D = coord_mask_2D * norm(dX, dim=-1)
        # sorting preference: first those with coords, then among the residues that
        # exist but are masked use distance in sequence as tie breaker, and then the
        # residues that came from padding are last
        seqpos = torch.arange(maxlen, device=X.device)
        Dseq = torch.abs(seqpos.unsqueeze(1) - seqpos.unsqueeze(0)).repeat(bsz, 1, 1)
        # Sentinel offsets (1e8, 1e10) push masked / padded pairs to the end
        # of the nearest-neighbor ordering without affecting real distances.
        D_adjust = nan_to_num(D) + (~coord_mask_2D) * (1e8 + Dseq * 1e6) + (
            ~residue_mask_2D) * (1e10)
        if top_k_neighbors == -1:
            # Fully connected: every residue is a neighbor of every other.
            D_neighbors = D_adjust
            E_idx = seqpos.repeat(
                *D_neighbors.shape[:-1], 1)
        else:
            # Identify k nearest neighbors (including self)
            k = min(top_k_neighbors, X.size(1))
            D_neighbors, E_idx = torch.topk(D_adjust, k, dim=-1, largest=False)
        # Thresholds halfway below the sentinel offsets recover the masks.
        coord_mask_neighbors = (D_neighbors < 5e7)
        residue_mask_neighbors = (D_neighbors < 5e9)
        return D_neighbors, E_idx, coord_mask_neighbors, residue_mask_neighbors
class Normalize(nn.Module):
    """Layer normalization with a learnable per-feature gain and bias.

    Normalizes the input along a chosen dimension and then applies the
    learned affine transform.
    """

    def __init__(self, features, epsilon=1e-6):
        super(Normalize, self).__init__()
        self.gain = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.epsilon = epsilon

    def forward(self, x, dim=-1):
        """Normalize ``x`` along ``dim`` and apply gain/bias."""
        mean = x.mean(dim, keepdim=True)
        std = torch.sqrt(x.var(dim, keepdim=True) + self.epsilon)
        scale, shift = self.gain, self.bias
        if dim != -1:
            # Reshape the affine parameters so they broadcast along `dim`.
            broadcast_shape = [1] * mean.dim()
            broadcast_shape[dim] = self.gain.size(0)
            scale = scale.view(broadcast_shape)
            shift = shift.view(broadcast_shape)
        return scale * (x - mean) / (std + self.epsilon) + shift
class DihedralFeatures(nn.Module):
    """Embed backbone dihedral-angle features into a node embedding."""

    def __init__(self, node_embed_dim):
        """ Embed dihedral angle features. """
        super(DihedralFeatures, self).__init__()
        # 3 dihedral angles; sin and cos of each angle
        node_in = 6
        # Normalization and embedding
        self.node_embedding = nn.Linear(node_in, node_embed_dim, bias=True)
        self.norm_nodes = Normalize(node_embed_dim)

    def forward(self, X):
        """ Featurize coordinates as an attributed graph """
        # _dihedrals must be a staticmethod for this instance call to work;
        # otherwise `self` would be bound as X.
        V = self._dihedrals(X)
        V = self.node_embedding(V)
        V = self.norm_nodes(V)
        return V

    @staticmethod
    def _dihedrals(X, eps=1e-7, return_angles=False):
        """Compute phi/psi/omega backbone dihedrals.

        X: [batch, length, atoms, 3]; the first three atoms are N, CA, C.
        Returns cos/sin features [batch, length, 6], or the raw angle
        triple when ``return_angles=True``.
        """
        # First 3 coordinates are N, CA, C
        X = X[:, :, :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)
        # Shifted slices of unit vectors
        dX = X[:, 1:, :] - X[:, :-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]
        # Backbone normals
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
        # Angle between normals; clamp keeps acos numerically stable.
        cosD = (n_2 * n_1).sum(-1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, (1, 2), 'constant', 0)
        D = D.view((D.size(0), int(D.size(1) / 3), 3))
        phi, psi, omega = torch.unbind(D, -1)
        if return_angles:
            return phi, psi, omega
        # Lift angle representations to the circle
        D_features = torch.cat((torch.cos(D), torch.sin(D)), 2)
        return D_features
class GVPGraphEmbedding(GVPInputFeaturizer):
    """Embed backbone coordinates as a flattened GVP graph.

    Produces (scalar, vector) node embeddings, edge embeddings, and an edge
    index in the flattened batch-size-1 layout used by torch_geometric.
    """

    def __init__(self, args):
        # args must provide: top_k_neighbors, node_hidden_dim_scalar,
        # node_hidden_dim_vector, edge_hidden_dim_scalar, edge_hidden_dim_vector.
        super().__init__()
        self.top_k_neighbors = args.top_k_neighbors
        # Number of sinusoidal relative-position channels per edge.
        self.num_positional_embeddings = 16
        self.remove_edges_without_coords = True
        # 6 dihedral sin/cos + 1 coord-mask flag = 7 scalars; 2 orientation
        # vectors + 1 sidechain vector = 3 vector channels (matches
        # GVPInputFeaturizer.get_node_features).
        node_input_dim = (7, 3)
        # 16 distance RBFs + 16 positional channels + 2 missing-coord flags
        # = 34 scalars; 1 relative-displacement vector channel.
        edge_input_dim = (34, 1)
        node_hidden_dim = (args.node_hidden_dim_scalar,
                args.node_hidden_dim_vector)
        edge_hidden_dim = (args.edge_hidden_dim_scalar,
                args.edge_hidden_dim_vector)
        self.embed_node = nn.Sequential(
            GVP(node_input_dim, node_hidden_dim, activations=(None, None)),
            LayerNorm(node_hidden_dim, eps=1e-4)
        )
        self.embed_edge = nn.Sequential(
            GVP(edge_input_dim, edge_hidden_dim, activations=(None, None)),
            LayerNorm(edge_hidden_dim, eps=1e-4)
        )
        # Projects a 16-bin RBF expansion of per-residue confidence into the
        # scalar node channel.
        self.embed_confidence = nn.Linear(16, args.node_hidden_dim_scalar)

    def forward(self, coords, coord_mask, padding_mask, confidence):
        # Raw geometric features are fixed functions of the input coordinates,
        # so they are computed without gradient tracking.
        with torch.no_grad():
            node_features = self.get_node_features(coords, coord_mask)
            edge_features, edge_index = self.get_edge_features(
                coords, coord_mask, padding_mask)
        node_embeddings_scalar, node_embeddings_vector = self.embed_node(node_features)
        edge_embeddings = self.embed_edge(edge_features)
        # RBF over [0, 1] — assumes confidence values lie in that range;
        # TODO confirm with callers.
        rbf_rep = rbf(confidence, 0., 1.)
        # Confidence is mixed only into the scalar channel.
        node_embeddings = (
            node_embeddings_scalar + self.embed_confidence(rbf_rep),
            node_embeddings_vector
        )
        node_embeddings, edge_embeddings, edge_index = flatten_graph(
            node_embeddings, edge_embeddings, edge_index)
        return node_embeddings, edge_embeddings, edge_index

    def get_edge_features(self, coords, coord_mask, padding_mask):
        # CA coordinates define inter-residue geometry.
        X_ca = coords[:, :, 1]
        # Get distances to the top k neighbors
        E_dist, E_idx, E_coord_mask, E_residue_mask = GVPInputFeaturizer._dist(
            X_ca, coord_mask, padding_mask, self.top_k_neighbors)
        # Flatten the graph to be batch size 1 for torch_geometric package
        dest = E_idx
        B, L, k = E_idx.shape[:3]
        # Each source residue is repeated once per neighbor.
        src = torch.arange(L, device=E_idx.device).view([1, L, 1]).expand(B, L, k)
        # After flattening, [2, B, E]
        edge_index = torch.stack([src, dest], dim=0).flatten(2, 3)
        # After flattening, [B, E]
        E_dist = E_dist.flatten(1, 2)
        E_coord_mask = E_coord_mask.flatten(1, 2).unsqueeze(-1)
        E_residue_mask = E_residue_mask.flatten(1, 2)
        # Calculate relative positional embeddings and distance RBF
        pos_embeddings = GVPInputFeaturizer._positional_embeddings(
            edge_index,
            num_positional_embeddings=self.num_positional_embeddings,
        )
        D_rbf = rbf(E_dist, 0., 20.)
        # Calculate relative orientation
        X_src = X_ca.unsqueeze(2).expand(-1, -1, k, -1).flatten(1, 2)
        # Gather destination CA positions by the (flattened) edge targets.
        X_dest = torch.gather(
            X_ca,
            1,
            edge_index[1, :, :].unsqueeze(-1).expand([B, L*k, 3])
        )
        coord_mask_src = coord_mask.unsqueeze(2).expand(-1, -1, k).flatten(1, 2)
        coord_mask_dest = torch.gather(
            coord_mask,
            1,
            edge_index[1, :, :].expand([B, L*k])
        )
        E_vectors = X_src - X_dest
        # For the ones without coordinates, substitute in the average vector
        E_vector_mean = torch.sum(E_vectors * E_coord_mask, dim=1,
                keepdims=True) / torch.sum(E_coord_mask, dim=1, keepdims=True)
        E_vectors = E_vectors * E_coord_mask + E_vector_mean * ~(E_coord_mask)
        # Normalize and remove nans
        edge_s = torch.cat([D_rbf, pos_embeddings], dim=-1)
        edge_v = normalize(E_vectors).unsqueeze(-2)
        edge_s, edge_v = map(nan_to_num, (edge_s, edge_v))
        # Also add indications of whether the coordinates are present
        edge_s = torch.cat([
            edge_s,
            (~coord_mask_src).float().unsqueeze(-1),
            (~coord_mask_dest).float().unsqueeze(-1),
        ], dim=-1)
        # Mark edges into padding (and, optionally, edges without coordinates)
        # with index -1 so downstream code can drop them.
        edge_index[:, ~E_residue_mask] = -1
        if self.remove_edges_without_coords:
            edge_index[:, ~E_coord_mask.squeeze(-1)] = -1
        return (edge_s, edge_v), edge_index.transpose(0, 1)