import random
from itertools import repeat
from typing import List

import numpy as np
import torch


def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch's RNGs.

    CUDA generators are seeded only when CUDA is available.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


def grad_status(model):
    """Return a lazy generator of ``requires_grad`` flags, one per parameter.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
    """
    return (par.requires_grad for par in model.parameters())


def lmap(f, x):
    """list(map(f, x))"""
    return list(map(f, x))


def assert_all_frozen(model) -> None:
    """Ensure no parameter of ``model`` requires a gradient.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Raises:
        AssertionError: If at least one parameter has ``requires_grad=True``.
            The message reports the fraction and total parameter count.
    """
    model_grads: List[bool] = list(grad_status(model))
    # Raise explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``. The exception type stays AssertionError for callers.
    if any(model_grads):
        n_require_grad = sum(model_grads)
        npars = len(model_grads)
        raise AssertionError(
            f"{n_require_grad / npars:.1%} of {npars} weights require grad"
        )


def split_dense_inputs(model_input: dict, chunk_size: int):
    """Split a single-argument input dict into chunks along dimension 0.

    ``model_input`` must contain exactly one key. Its value is a mapping from
    names to tensors (anything with ``.split(chunk_size, dim=0)``). Every
    tensor is split into pieces of ``chunk_size`` rows; the last piece may be
    smaller. The i-th pieces of all tensors are regrouped under the original
    names.

    NOTE(review): if the tensors differ in size along dim 0, ``zip`` stops at
    the shortest chunk list and trailing chunks of longer tensors are dropped.

    Args:
        model_input: ``{arg_name: {name: tensor, ...}}`` with one ``arg_name``.
        chunk_size: Number of rows per chunk.

    Returns:
        A list of ``{arg_name: {name: tensor_chunk, ...}}`` dicts, one per
        chunk. The list is empty when the inner mapping is empty.

    Raises:
        AssertionError: If ``model_input`` does not have exactly one key.
    """
    # Explicit raise (not ``assert``) so validation survives ``python -O``.
    if len(model_input) != 1:
        raise AssertionError(
            f"expected exactly one argument key, got {len(model_input)}"
        )
    arg_key = next(iter(model_input))
    arg_val = model_input[arg_key]
    keys = list(arg_val.keys())
    chunked_tensors = [arg_val[k].split(chunk_size, dim=0) for k in keys]
    # zip(*...) transposes per-tensor chunk lists into per-chunk tuples.
    chunked_arg_val = [dict(zip(keys, chunk)) for chunk in zip(*chunked_tensors)]
    return [{arg_key: c} for c in chunked_arg_val]


def get_dense_rep(x):
    """Return ``x.q_reps`` when it is not ``None``, otherwise ``x.p_reps``.

    Args:
        x: Object with ``q_reps`` and ``p_reps`` attributes (presumably the
            model's query/passage representations; confirm against callers).
    """
    return x.p_reps if x.q_reps is None else x.q_reps