# EntangledSBM / entangled-cell / train_utils.py
# Author: Sophia Tang
# Initial commit with LFS (commit 7efee70)
import yaml
import string
import secrets
import os
import torch
import wandb
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from torchdyn.core import NeuralODE
import torch
@torch.no_grad()
def gather_local_starts(x0s, X0_pool, N, k=64):
    """Sample N start cells for each anchor from its local neighborhood.

    For each anchor row in ``x0s``, take its k nearest neighbors in
    ``X0_pool`` (by squared Euclidean distance) and draw N of them at
    random — without replacement when possible.

    Args:
        x0s: (B, G) anchor points.
        X0_pool: (M0, G) pool of candidate start cells.
        N: number of cells to draw per anchor.
        k: neighborhood size; effectively ``min(k, M0)``.

    Returns:
        (B, N, G) tensor of sampled pool cells, one cluster per anchor.
    """
    B, G = x0s.shape
    # Squared distances are monotone in distance, so kNN ranking is the
    # same; .pow(2) is kept for parity with the original metric.
    d2 = torch.cdist(x0s, X0_pool).pow(2)  # (B, M0)
    k_eff = min(k, X0_pool.size(0))
    knn_idx = d2.topk(k=k_eff, largest=False).indices  # (B, k_eff)
    x0_clusters = []
    for b in range(B):
        choices = knn_idx[b]
        if choices.numel() >= N:
            # Enough candidates: draw N distinct neighbors.
            pick = choices[torch.randperm(choices.numel(), device=choices.device)[:N]]
        else:
            # BUGFIX: fewer than N candidates used to shrink the output to
            # (B, k_eff, G). Sample with replacement so the documented
            # (B, N, G) shape always holds.
            pick = choices[torch.randint(choices.numel(), (N,), device=choices.device)]
        x0_clusters.append(X0_pool[pick])  # (N, G)
    return torch.stack(x0_clusters, dim=0)  # (B, N, G)
@torch.no_grad()
def make_aligned_clusters(ot_sampler, x0s, x1s, N, replace=True, k_local=128):
    """Build OT-aligned (source, target) clusters for each anchor.

    For every anchor in ``x0s``, gathers N nearby source cells (via
    ``gather_local_starts``) and N target cells drawn from the optimal
    transport coupling between ``x0s`` and ``x1s``.

    Args:
        ot_sampler: object exposing ``coupling`` or ``plan`` (expected to
            return a (B, M) row-stochastic torch tensor), and/or a
            ``sample_plan`` method used as a row-by-row fallback.
        x0s: (B, G) source anchors.
        x1s: (M, G) target pool.
        N: cluster size per anchor.
        replace: sample target indices with replacement.
        k_local: neighborhood size forwarded to ``gather_local_starts``.

    Returns:
        x0_clusters: (B, N, G) local source cells.
        x1_clusters: (B, N, G) matched target cells.
        idx1: (B, N) long indices into ``x1s`` (nearest-neighbor estimates
            on the sampler-fallback paths).
    """
    device, dtype = x0s.device, x0s.dtype
    B, G = x0s.shape

    # N local start cells per anchor; the pool is the anchors themselves.
    x0_clusters = gather_local_starts(x0s, x0s, N, k=k_local).to(device=device, dtype=dtype)
    x1_clusters = torch.empty((B, N, G), device=device, dtype=dtype)
    idx1 = torch.empty((B, N), device=device, dtype=torch.long)

    # Prefer a full coupling matrix P of shape (B, M) when available.
    P = None
    if hasattr(ot_sampler, "coupling"):
        P = ot_sampler.coupling(x0s, x1s)  # expected (B, M) torch tensor
    elif hasattr(ot_sampler, "plan"):
        P = ot_sampler.plan(x0s, x1s)  # same expectation
    # If the sampler only supports sampling, fall back row-by-row below.

    for b in range(B):
        x0_b = x0s[b:b+1]  # (1, G)
        if P is not None:
            # --- Sample N targets from the row distribution P[b] ---
            probs = P[b].clamp_min(0)
            total = probs.sum()
            if total <= 0:
                # Degenerate all-zero row: use a uniform distribution so
                # torch.multinomial does not raise.
                probs = torch.full_like(probs, 1.0 / probs.numel())
            else:
                probs = probs / total
            if replace:
                j = torch.multinomial(probs, num_samples=N, replacement=True)  # (N,)
            else:
                k = min(N, int((probs > 0).sum().item()))
                j = torch.multinomial(probs, num_samples=k, replacement=False)
                if k < N:  # pad by repeating the last choice to keep shape
                    j = torch.cat([j, j[-1:].expand(N - k)], dim=0)
            x1_match = x1s[j]  # (N, G)
        else:
            # --- Row-wise fallback using the sampler's own sampling API ---
            got = False
            if hasattr(ot_sampler, "sample_plan"):
                try:
                    # Many samplers accept an argument like n_pairs / k / n.
                    x0_rep, x1_match = ot_sampler.sample_plan(
                        x0_b, x1s, replace=replace, n_pairs=N
                    )
                    # x0_rep: (N, G) or (1, N, G) -> normalize the shape
                    x1_match = x1_match.view(N, G)
                    # BUGFIX: `j` was never assigned on this path, so
                    # `idx1[b] = j` below read an undefined (b == 0) or
                    # stale value. Recover indices by nearest neighbor.
                    j = torch.cdist(x1_match, x1s).argmin(dim=1)
                    got = True
                except TypeError:
                    pass
            if not got:
                # Last resort: call sample_plan N times, one pair at a time.
                ys, js = [], []
                for _ in range(N):
                    _, x1_one = ot_sampler.sample_plan(x0_b, x1s, replace=replace)
                    # Infer the index by nearest neighbor for bookkeeping.
                    j_hat = torch.cdist(x1_one.view(1, -1), x1s).argmin()
                    ys.append(x1_one.view(1, G))
                    js.append(j_hat.view(1))
                x1_match = torch.cat(ys, dim=0)
                j = torch.cat(js, dim=0)
        # Fill clusters for this anchor.
        x1_clusters[b] = x1_match
        idx1[b] = j
    return x0_clusters, x1_clusters, idx1
def load_config(path):
    """Read a YAML configuration file and return its parsed contents."""
    with open(path, "r") as fh:
        return yaml.safe_load(fh)
def merge_config(args, config_updates):
    """Overwrite attributes of ``args`` with values from ``config_updates``.

    Every key must already exist as an attribute on ``args``; unknown keys
    are rejected so typos in config files fail loudly.

    Raises:
        ValueError: if a key in ``config_updates`` is not an attribute of
            ``args``.
    """
    for key in config_updates:
        if not hasattr(args, key):
            raise ValueError(
                f"Unknown configuration parameter '{key}' found in the config file."
            )
        setattr(args, key, config_updates[key])
    return args
def generate_group_string(length=16):
    """Return a cryptographically random alphanumeric ID of ``length`` chars."""
    pool = string.ascii_letters + string.digits
    picked = [secrets.choice(pool) for _ in range(length)]
    return "".join(picked)