import sys
import os

# Pin this process to a single GPU before torch is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

import torch
import torch.nn as nn
import torch.nn.functional as F

class BiasForceTransformer(nn.Module):
    def __init__(self,
                 args,
                 d_model=256,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=512,
                 dropout=0.1,
                 ):
        super().__init__()
        self.device = args.device
        self.N = args.num_particles

        self.use_delta_to_target = args.use_delta_to_target
        # rbf / sigma are stored from args but not used in forward().
        self.rbf = args.rbf
        self.sigma = args.sigma

        G = args.dim

        # Per-particle features: position, velocity, optional delta to target, and distance.
        feat_dim = (2 * G) + (G if self.use_delta_to_target else 0) + 1

        self.input_proj = nn.Linear(feat_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout, activation="gelu",
            batch_first=True, norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        # Positive scalar gain applied along the direction to the target.
        self.scale_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 1),
        )
        # Free vector; only its component perpendicular to the target direction is kept.
        self.vec_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, args.dim),
        )

        # Learnable scalar log Z; not used inside forward().
        self.log_z = nn.Parameter(torch.tensor(0.0))

    @staticmethod
    def _softplus_unit(x, beta=1.0, threshold=20.0, eps=1e-8):
        # Softplus shifted by eps so the output is strictly positive.
        return F.softplus(x, beta=beta, threshold=threshold) + eps

    def forward(self, pos, vel, target):
        """
        pos, vel, target: (B, N, D)
        Returns: force (B, N, D), the scaled pull toward `target` plus the
        component of the predicted vector perpendicular to that pull.

        N: number of cells in the batch
        D: dimension of the gene vector
        """
        B, N, G = pos.shape
        assert N == self.N, f"Expected N={self.N}, got {N}"

        delta = target - pos
        dist = torch.norm(delta, dim=-1, keepdim=True)
        feats = torch.cat([pos, vel, delta, dist], dim=-1) \
            if self.use_delta_to_target else torch.cat([pos, vel, dist], dim=-1)

        x = self.input_proj(feats)
        x = self.encoder(x)

        scale = self._softplus_unit(self.scale_head(x)).squeeze(-1)  # (B, N)
        vector = self.vec_head(x)                                    # (B, N, D)

        # Attractive term: positive per-particle gain times the vector to the target.
        d = target - pos
        scale = scale.unsqueeze(-1).expand(-1, -1, G)
        scaled = scale * d

        # Project out the component of `vector` along d so the learned term is
        # purely perpendicular and cannot cancel the attractive pull.
        eps = torch.finfo(pos.dtype).eps
        denom = d.pow(2).sum(dim=-1, keepdim=True).clamp_min(eps)
        vec_parallel = ((vector * d).sum(dim=-1, keepdim=True) / denom) * d
        vec_perp = vector - vec_parallel

        return vec_perp + scaled
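
# Minimal usage sketch (illustrative only): `args` is assumed to be an
# argparse.Namespace-like object exposing the fields read in __init__ above
# (device, num_particles, use_delta_to_target, rbf, sigma, dim); the concrete
# values below are hypothetical and only demonstrate the expected shapes.
def _demo_bias_force_transformer():
    from types import SimpleNamespace
    args = SimpleNamespace(device="cpu", num_particles=16,
                           use_delta_to_target=True, rbf=False,
                           sigma=1.0, dim=8)
    model = BiasForceTransformer(args)
    pos = torch.randn(2, args.num_particles, args.dim)
    vel = torch.randn_like(pos)
    target = torch.randn_like(pos)
    force = model(pos, vel, target)  # (2, 16, 8)
    return force.shape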

class BiasForceTransformerNoVel(nn.Module):
    def __init__(self,
                 args,
                 d_model=256,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=512,
                 dropout=0.1,
                 ):
        super().__init__()
        self.device = args.device
        self.N = args.num_particles

        self.use_delta_to_target = args.use_delta_to_target
        # rbf / sigma are stored from args but not used in forward().
        self.rbf = args.rbf
        self.sigma = args.sigma

        G = args.dim

        # Per-particle features: position, optional delta to target, and distance (no velocity).
        feat_dim = G + (G if self.use_delta_to_target else 0) + 1

        self.input_proj = nn.Linear(feat_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout, activation="gelu",
            batch_first=True, norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        # Positive scalar gain applied along the direction to the target.
        self.scale_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 1),
        )
        # Free vector; only its component perpendicular to the target direction is kept.
        self.vec_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, args.dim),
        )

        # Learnable scalar log Z; not used inside forward().
        self.log_z = nn.Parameter(torch.tensor(0.0))

    @staticmethod
    def _softplus_unit(x, beta=1.0, threshold=20.0, eps=1e-8):
        # Softplus shifted by eps so the output is strictly positive.
        return F.softplus(x, beta=beta, threshold=threshold) + eps

    def forward(self, pos, target):
        """
        pos, target: (B, N, D)
        Returns: force (B, N, D), the scaled pull toward `target` plus the
        component of the predicted vector perpendicular to that pull.

        N: number of cells in the batch
        D: dimension of the gene vector
        """
        B, N, G = pos.shape
        assert N == self.N, f"Expected N={self.N}, got {N}"

        delta = target - pos
        dist = torch.norm(delta, dim=-1, keepdim=True)
        feats = torch.cat([pos, delta, dist], dim=-1) \
            if self.use_delta_to_target else torch.cat([pos, dist], dim=-1)

        x = self.input_proj(feats)
        x = self.encoder(x)

        scale = self._softplus_unit(self.scale_head(x)).squeeze(-1)  # (B, N)
        vector = self.vec_head(x)                                    # (B, N, D)

        # Attractive term: positive per-particle gain times the vector to the target.
        d = target - pos
        scale = scale.unsqueeze(-1).expand(-1, -1, G)
        scaled = scale * d

        # Project out the component of `vector` along d so the learned term is
        # purely perpendicular and cannot cancel the attractive pull.
        eps = torch.finfo(pos.dtype).eps
        denom = d.pow(2).sum(dim=-1, keepdim=True).clamp_min(eps)
        vec_parallel = ((vector * d).sum(dim=-1, keepdim=True) / denom) * d
        vec_perp = vector - vec_parallel

        return vec_perp + scaled
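
# Hypothetical smoke test; run this file directly to check output shapes of
# both variants. The SimpleNamespace mirrors the fields read in __init__ and
# its values are illustrative assumptions, not project defaults.
if __name__ == "__main__":
    from types import SimpleNamespace

    print(_demo_bias_force_transformer())  # expected: torch.Size([2, 16, 8])

    args = SimpleNamespace(device="cpu", num_particles=16,
                           use_delta_to_target=True, rbf=False,
                           sigma=1.0, dim=8)
    model = BiasForceTransformerNoVel(args)
    pos = torch.randn(2, args.num_particles, args.dim)
    target = torch.randn_like(pos)
    print(model(pos, target).shape)  # expected: torch.Size([2, 16, 8])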