import argparse
import os
import shutil
import tempfile
from typing import List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
import yaml
from accelerate.logging import get_logger
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from transformers import T5EncoderModel, T5Tokenizer

logger = get_logger(__name__)


def get_args():
    parser = argparse.ArgumentParser(description="Training script for CogVideoX using a config file.")
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the YAML config file.",
    )
    args = parser.parse_args()

    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    # Expose the (possibly nested) config dict as an argparse.Namespace for easier downstream usage
    args = argparse.Namespace(**config)
    return args


def atomic_save(save_path, accelerator):
    parent = os.path.dirname(save_path)
    tmp_dir = tempfile.mkdtemp(dir=parent)
    backup_dir = save_path + "_backup"
    try:
        # Save state into the temp directory
        accelerator.save_state(tmp_dir)

        # Back up the existing save_path if it exists
        if os.path.exists(save_path):
            os.rename(save_path, backup_dir)

        # Atomically move the temp directory into place
        os.rename(tmp_dir, save_path)

        # Clean up the backup directory
        if os.path.exists(backup_dir):
            shutil.rmtree(backup_dir)
    except Exception:
        # Clean up the temp directory on failure
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
        # Restore from the backup if replacement failed
        if os.path.exists(backup_dir):
            if os.path.exists(save_path):
                shutil.rmtree(save_path)
            os.rename(backup_dir, save_path)
        raise


def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
    # Use the DeepSpeed dummy optimizer when DeepSpeed manages optimization
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )

    # Optimizer creation
    supported_optimizers = ["adam", "adamw", "prodigy"]
    if args.optimizer not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include "
            f"{supported_optimizers}. Defaulting to AdamW."
        )
        args.optimizer = "adamw"

    if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
        logger.warning(
            f"use_8bit_adam is ignored when the optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
            f"set to {args.optimizer.lower()}."
        )

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

    if args.optimizer.lower() == "adamw":
        optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif args.optimizer.lower() == "adam":
        optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam
        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif args.optimizer.lower() == "prodigy":
        try:
            import prodigyopt
        except ImportError:
            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`.")

        optimizer_class = prodigyopt.Prodigy

        if args.learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using Prodigy, it's generally better to set the learning rate around 1.0."
            )

        optimizer = optimizer_class(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            beta3=args.prodigy_beta3,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            decouple=args.prodigy_decouple,
            use_bias_correction=args.prodigy_use_bias_correction,
            safeguard_warmup=args.prodigy_safeguard_warmup,
        )

    return optimizer


def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    grid_height = height // (vae_scale_factor_spatial * patch_size)
    grid_width = width // (vae_scale_factor_spatial * patch_size)
    base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
    base_size_height = base_height // (vae_scale_factor_spatial * patch_size)

    grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
    freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
        embed_dim=attention_head_dim,
        crops_coords=grid_crops_coords,
        grid_size=(grid_height, grid_width),
        temporal_size=num_frames,
    )

    freqs_cos = freqs_cos.to(device=device)
    freqs_sin = freqs_sin.to(device=device)
    return freqs_cos, freqs_sin


def _get_t5_prompt_embeds(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # Duplicate text embeddings for each generation per prompt, using an MPS-friendly method
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds


def encode_prompt(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    prompt_embeds = _get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
        text_input_ids=text_input_ids,
    )
    return prompt_embeds


def compute_prompt_embeddings(
    tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
):
    if requires_grad:
        prompt_embeds = encode_prompt(
            tokenizer,
            text_encoder,
            prompt,
            num_videos_per_prompt=1,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )
    else:
        with torch.no_grad():
            prompt_embeds = encode_prompt(
                tokenizer,
                text_encoder,
                prompt,
                num_videos_per_prompt=1,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )
    return prompt_embeds


def save_frames_as_pngs(
    video_array,
    output_dir,
    downsample_spatial=1,   # e.g. 2 to halve width & height
    downsample_temporal=1,  # e.g. 2 to keep every 2nd frame
):
    """
    Save each frame of a (T, H, W, C) uint8 numpy array as a PNG with no compression.
    """
    assert video_array.ndim == 4 and video_array.shape[-1] == 3, "Expected a (T, H, W, C=3) array"
    assert video_array.dtype == np.uint8, "Expected a uint8 array"

    os.makedirs(output_dir, exist_ok=True)

    # Temporal downsample: keep every `downsample_temporal`-th frame
    frames = video_array[::downsample_temporal]

    # Compute the spatially downsampled size
    T, H, W, _ = frames.shape
    new_size = (W // downsample_spatial, H // downsample_spatial)

    # PNG compression param: 0 = no compression
    png_params = [cv2.IMWRITE_PNG_COMPRESSION, 0]

    for idx, frame in enumerate(frames):
        # Frames are RGB; convert to BGR for OpenCV
        bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        if downsample_spatial > 1:
            bgr = cv2.resize(bgr, new_size, interpolation=cv2.INTER_NEAREST)
        filename = os.path.join(output_dir, "frame_{:05d}.png".format(idx))
        success = cv2.imwrite(filename, bgr, png_params)
        if not success:
            raise RuntimeError(f"Failed to write frame to {filename}")
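

# Illustrative usage sketch (only runs when this module is executed directly, never on import).
# The resolution, frame count, and the "debug_frames" output directory are hypothetical
# placeholders chosen to exercise the helpers above; they are not values required by the
# training script.
if __name__ == "__main__":
    # Rotary embeddings for a 480x720 clip with 13 latent frames, kept on CPU.
    freqs_cos, freqs_sin = prepare_rotary_positional_embeddings(
        height=480, width=720, num_frames=13, device=torch.device("cpu")
    )
    print("rotary embeddings:", freqs_cos.shape, freqs_sin.shape)

    # Dump a small random RGB clip to PNGs as a quick sanity check of save_frames_as_pngs.
    dummy_video = np.random.randint(0, 256, size=(4, 64, 64, 3), dtype=np.uint8)
    save_frames_as_pngs(dummy_video, "debug_frames", downsample_spatial=2, downsample_temporal=2)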