Sophia Tang
Initial commit with LFS
7efee70
import sys
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
import torch
import torch.nn as nn
import torch.nn.functional as F
class BiasForceTransformer(nn.Module):
def __init__(self,
args,
d_model = 256,
nhead = 8,
num_layers = 4,
dim_feedforward = 512,
dropout = 0.1,
):
super().__init__()
self.device = args.device
self.N = args.num_particles
self.use_delta_to_target = args.use_delta_to_target
self.rbf = args.rbf
self.sigma = args.sigma
G = args.dim
# Per-atom features in aligned frame for the Transformer
# pos_al(3), vel_al(3), delta_to_target(3 optional), distance(1)
feat_dim = (2 * G) + (G if self.use_delta_to_target else 0) + 1
self.input_proj = nn.Linear(feat_dim, d_model)
enc_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout, activation="gelu",
batch_first=True, norm_first=True
)
self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
# Heads
self.scale_head = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Linear(d_model // 2, 1),
)
self.vec_head = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Linear(d_model // 2, args.dim),
)
self.log_z = nn.Parameter(torch.tensor(0.0))
#self.to(self.device)
@staticmethod
def _softplus_unit(x, beta=1.0, threshold=20.0, eps=1e-8):
return F.softplus(x, beta=beta, threshold=threshold) + eps
def forward(self, pos, vel, target):
"""
pos, vel, target: (B,N,D)
Returns: force (B,N,D), scale (B,N), vector (B,N,D)
N: number of cells in batch
D: dimension of gene vector
"""
B, N, G = pos.shape
assert N == self.N, f"Expected N={self.N}, got {N}"
# direction of target position
delta = target - pos # (B,N,G)
dist = torch.norm(delta, dim=-1, keepdim=True) # (B,N,1)
feats = torch.cat([pos, vel, delta, dist], dim=-1) \
if self.use_delta_to_target else torch.cat([pos, vel, dist], dim=-1)
x = self.input_proj(feats) # (B,N,d_model)
x = self.encoder(x) # (B,N,d_model)
# Heads
scale = self._softplus_unit(self.scale_head(x)).squeeze(-1) # (B,N)
vector = self.vec_head(x) # (B,N,3)
# Direction field d
d = (target - pos)
# Parallel component from scale head
scale = scale.unsqueeze(-1).expand(-1, -1, G)
scaled = scale * d # (B,N,3)
# Project vector head output onto plane orthogonal to d
eps = torch.finfo(pos.dtype).eps
denom = d.pow(2).sum(dim=-1, keepdim=True).clamp_min(eps) # (B,N,1)
vec_parallel = ((vector * d).sum(dim=-1, keepdim=True) / denom) * d
vec_perp = vector - vec_parallel
return vec_perp + scaled
class BiasForceTransformerNoVel(nn.Module):
def __init__(self,
args,
d_model = 256,
nhead = 8,
num_layers = 4,
dim_feedforward = 512,
dropout = 0.1,
):
super().__init__()
self.device = args.device
self.N = args.num_particles
self.use_delta_to_target = args.use_delta_to_target
self.rbf = args.rbf
self.sigma = args.sigma
G = args.dim
# Per-atom features in aligned frame for the Transformer
# pos_al(3), vel_al(3), delta_to_target(3 optional), distance(1)
feat_dim = G + (G if self.use_delta_to_target else 0) + 1
self.input_proj = nn.Linear(feat_dim, d_model)
enc_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout, activation="gelu",
batch_first=True, norm_first=True
)
self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
# Heads
self.scale_head = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Linear(d_model // 2, 1),
)
self.vec_head = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Linear(d_model // 2, args.dim),
)
self.log_z = nn.Parameter(torch.tensor(0.0))
#self.to(self.device)
@staticmethod
def _softplus_unit(x, beta=1.0, threshold=20.0, eps=1e-8):
return F.softplus(x, beta=beta, threshold=threshold) + eps
def forward(self, pos, target):
"""
pos, target: (B,N,D)
Returns: force (B,N,D), scale (B,N), vector (B,N,D)
N: number of cells in batch
D: dimension of gene vector
"""
B, N, G = pos.shape
assert N == self.N, f"Expected N={self.N}, got {N}"
# direction of target position
delta = target - pos # (B,N,G)
dist = torch.norm(delta, dim=-1, keepdim=True) # (B,N,1)
feats = torch.cat([pos, delta, dist], dim=-1) \
if self.use_delta_to_target else torch.cat([pos, dist], dim=-1)
x = self.input_proj(feats) # (B,N,d_model)
x = self.encoder(x) # (B,N,d_model)
# Heads
scale = self._softplus_unit(self.scale_head(x)).squeeze(-1) # (B,N)
vector = self.vec_head(x) # (B,N,3)
# Direction field d
d = (target - pos)
# Parallel component from scale head
scale = scale.unsqueeze(-1).expand(-1, -1, G)
scaled = scale * d # (B,N,3)
# Project vector head output onto plane orthogonal to d
eps = torch.finfo(pos.dtype).eps
denom = d.pow(2).sum(dim=-1, keepdim=True).clamp_min(eps) # (B,N,1)
vec_parallel = ((vector * d).sum(dim=-1, keepdim=True) / denom) * d
vec_perp = vector - vec_parallel
return vec_perp + scaled