| | """ |
| | This file is used for T2I generation, it also compute the clip similarity between the generated images and the input prompt |
| | """ |
import builtins
import math
import os
import tempfile
import time

import accelerate
import einops
import ml_collections
import numpy as np
import torch
import torch.nn as nn
from absl import app, flags, logging
from ml_collections import config_flags
from PIL import Image
from torch import multiprocessing as mp

import libs.autoencoder
import utils
from diffusion.flow_matching import FlowMatching, ODEFlowMatchingSolver, ODEEulerFlowMatchingSolver
from libs.clip import FrozenCLIPEmbedder
from libs.t5 import T5Embedder
from tools.clip_score import ClipSocre


def unpreprocess(x):
    # Map images from the model's [-1, 1] range back to [0, 1] for saving.
    x = 0.5 * (x + 1.)
    x.clamp_(0., 1.)
    return x


def get_caption(llm, text_model, _batch_prompt):
    # Encode a batch of prompts into per-token text embeddings, together with
    # the corresponding token mask and raw token ids.
    if llm == "clip":
        _latent, _latent_and_others = text_model.encode(_batch_prompt)
        _con = _latent_and_others['token_embedding'].detach()
    elif llm == "t5":
        _latent, _latent_and_others = text_model.get_text_embeddings(_batch_prompt)
        # T5 features are rescaled by a fixed factor of 10.0.
        _con = (_latent_and_others['token_embedding'] * 10.0).detach()
    else:
        raise NotImplementedError
    _con_mask = _latent_and_others['token_mask'].detach()
    _batch_token = _latent_and_others['tokens'].detach()
    return _con, _con_mask, _batch_token, _batch_prompt


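# Usage sketch (names and values are illustrative): with the CLIP text encoder,
#   con, mask, tokens, caps = get_caption("clip", clip_model, ["a cat"] * 4)
# returns per-token embeddings, the token mask, the raw token ids, and the
# captions themselves, in that order.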


def evaluate(config):
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)
    if accelerator.is_main_process:
        utils.set_logger(log_level='info', fname=config.output_path)
    else:
        # Silence logging and printing on non-main processes.
        utils.set_logger(log_level='error')
        builtins.print = lambda *args: None

    nnet = utils.get_nnet(**config.nnet)
    nnet = accelerator.prepare(nnet)
    logging.info(f'loading nnet from {config.nnet_path}')
    accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
    nnet.eval()

    # Pick the text encoder matching the checkpoint's text-feature width:
    # 4096 corresponds to T5 features, 768 to CLIP text features.
    if config.nnet.model_args.clip_dim == 4096:
        llm = "t5"
        text_model = T5Embedder(device=device)
    elif config.nnet.model_args.clip_dim == 768:
        llm = "clip"
        text_model = FrozenCLIPEmbedder()
        text_model.eval()
        text_model.to(device)
    else:
        raise NotImplementedError

    # Despite the name, this is a fixed tuple: the single prompt is encoded
    # once, repeated to fill a whole mini-batch.
    context_generator = get_caption(llm, text_model, _batch_prompt=[config.prompt] * config.sample.mini_batch_size)

    autoencoder = libs.autoencoder.get_model(**config.autoencoder)
    autoencoder.to(device)

    @torch.cuda.amp.autocast()
    def encode(_batch):
        return autoencoder.encode(_batch)

    @torch.cuda.amp.autocast()
    def decode(_batch):
        return autoencoder.decode(_batch)
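    # Generation happens in the autoencoder's latent space: noise of shape
    # config.z_shape is drawn below, integrated by the flow-matching ODE
    # solver, and only decoded back to pixel space at the very end.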

    bdv_nnet = None  # optional auxiliary model for the solver; unused here
    ClipSocre_model = ClipSocre(device=device)  # note: spelling follows tools/clip_score.py

    logging.info(config.sample)
    logging.info(f'sample: n_samples={config.sample.n_samples}, mode=t2i, mixed_precision={config.mixed_precision}')

    def ode_fm_solver_sample(nnet_ema, _n_samples, _sample_steps, bdv_nnet=bdv_nnet, context=None, caption=None, testbatch_img_blurred=None, two_stage_generation=-1, token=None, token_mask=None, return_clipScore=False, ClipSocre_model=None):
        with torch.no_grad():
            del testbatch_img_blurred  # unused

            _z_gaussian = torch.randn(_n_samples, *config.z_shape, device=device)

            if 'dimr' in config.nnet.name or 'dit' in config.nnet.name:
                # The network's text-encoder branch maps the text context to an
                # initial latent (plus a mean and log-variance, unused here).
                _z_x0, _mu, _log_var = nnet_ema(context, text_encoder=True, shape=_z_gaussian.shape, mask=token_mask)
                _z_init = _z_x0.reshape(_z_gaussian.shape)
            else:
                raise NotImplementedError

            assert config.sample.scale > 1
            # --cfg overrides the guidance scale; -1 means "use config.sample.scale".
            if config.cfg != -1:
                _cfg = config.cfg
            else:
                _cfg = config.sample.scale

            has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")

            _sample_steps = config.sample.sample_steps  # overrides the argument

            ode_solver = ODEEulerFlowMatchingSolver(nnet_ema, bdv_model_fn=bdv_nnet, step_size_type="step_in_dsigma", guidance_scale=_cfg)
            _z, _ = ode_solver.sample(x_T=_z_init, batch_size=_n_samples, sample_steps=_sample_steps, unconditional_guidance_scale=_cfg, has_null_indicator=has_null_indicator)
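            # With guidance scale s, classifier-free guidance conventionally
            # combines the conditional and unconditional predictions as
            #   v = v_uncond + s * (v_cond - v_uncond);
            # the exact combination lives inside ODEEulerFlowMatchingSolver,
            # so this standard formulation is shown here only for reference.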

            image_unprocessed = decode(_z)  # latents -> images in [-1, 1]
            clip_score = ClipSocre_model.calculate_clip_score(caption, image_unprocessed)

            return image_unprocessed, clip_score

    def sample_fn(_n_samples, return_caption=False, return_clipScore=False, ClipSocre_model=None, config=None):
        _context, _token_mask, _token, _caption = context_generator
        assert _context.size(0) == _n_samples
        assert return_clipScore
        assert not return_caption
        return ode_fm_solver_sample(nnet, _n_samples, config.sample.sample_steps, bdv_nnet=bdv_nnet, context=_context, token=_token, token_mask=_token_mask, return_clipScore=return_clipScore, ClipSocre_model=ClipSocre_model, caption=_caption)
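    # sample_fn is the callback consumed by utils.sample2dir_wCLIP below; it is
    # invoked once per mini-batch until n_samples images have been generated
    # (behavior inferred from the call site below).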

    with tempfile.TemporaryDirectory() as temp_path:
        path = config.img_save_path or config.sample.path or temp_path
        if accelerator.is_main_process:
            os.makedirs(path, exist_ok=True)
        logging.info(f'Samples are saved in {path}')

        clip_score_list = utils.sample2dir_wCLIP(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, unpreprocess, return_clipScore=True, ClipSocre_model=ClipSocre_model, config=config)
        if clip_score_list is not None:
            _clip_score_list = torch.cat(clip_score_list)
            if accelerator.is_main_process:
                logging.info(f'nnet_path={config.nnet_path}, clip_score ({len(_clip_score_list)} samples) = {_clip_score_list.mean().item()}')


FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    "config", None, "Configuration file.", lock_config=False)
flags.mark_flags_as_required(["config"])
flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
flags.DEFINE_string("prompt", None, "The prompt used for generation.")
flags.DEFINE_string("output_path", None, "The path to the output log.")
flags.DEFINE_float("cfg", -1, "CFG scale; the scale defined in the config file is used if not assigned.")
flags.DEFINE_string("img_save_path", None, "The path to the image log.")


def main(argv):
    config = FLAGS.config
    config.nnet_path = FLAGS.nnet_path
    config.prompt = FLAGS.prompt
    config.output_path = FLAGS.output_path
    config.img_save_path = FLAGS.img_save_path
    config.cfg = FLAGS.cfg
    evaluate(config)


if __name__ == "__main__":
    app.run(main)