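"""Multi-GPU (sequence-parallel) inference script for Pyramid Flow video generation.

Example launch (a sketch, assuming two GPUs and a torchrun-style launcher; the
script filename and checkpoint path are placeholders, adjust them to your setup):

    torchrun --nproc_per_node 2 inference_multigpu.py \
        --model_name pyramid_mmdit \
        --model_path /path/to/checkpoint_dir \
        --variant diffusion_transformer_768p \
        --resolution 768p \
        --task t2v \
        --prompt "A drone shot over a rocky coastline at sunset" \
        --sp_group_size 2 \
        --output_path ./output.mp4
"""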
import os
import sys
import torch
import argparse
from PIL import Image
from diffusers.utils import export_to_video

# Add the project root directory to sys.path
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(SCRIPT_DIR)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from pyramid_dit import PyramidDiTForVideoGeneration
from trainer_misc import init_distributed_mode, init_sequence_parallel_group


def get_args():
    parser = argparse.ArgumentParser('PyTorch Multi-process Script', add_help=False)
    parser.add_argument('--model_name', default='pyramid_mmdit', type=str, help="The model name", choices=["pyramid_flux", "pyramid_mmdit"])
    parser.add_argument('--model_dtype', default='bf16', type=str, help="The model dtype: bf16, fp16, or fp32")
    parser.add_argument('--model_path', required=True, type=str, help="Path to the downloaded checkpoint directory")
    parser.add_argument('--variant', default='diffusion_transformer_768p', type=str)
    parser.add_argument('--task', default='t2v', type=str, choices=['i2v', 't2v'])
    parser.add_argument('--temp', default=16, type=int, help="The number of generated latents; num_frames = temp * 8 + 1")
    parser.add_argument('--sp_group_size', default=2, type=int, help="The number of GPUs used for inference, should be 2 or 4")
    parser.add_argument('--sp_proc_num', default=-1, type=int, help="The number of processes used for sequence parallelism; -1 means using all processes")
    parser.add_argument('--prompt', type=str, required=True, help="Text prompt for video generation")
    parser.add_argument('--image_path', type=str, help="Path to the input image for image-to-video")
    parser.add_argument('--video_guidance_scale', type=float, default=5.0, help="Video guidance scale")
    parser.add_argument('--guidance_scale', type=float, default=9.0, help="Guidance scale for text-to-video")
    parser.add_argument('--resolution', type=str, default='768p', choices=['768p', '384p'], help="Model resolution")
    parser.add_argument('--output_path', type=str, required=True, help="Path to save the generated video")
    return parser.parse_args()


def main():
    args = get_args()

    # Set up DDP
    init_distributed_mode(args)
    assert args.world_size == args.sp_group_size, "The sequence parallel size should match the DDP world size"

    # Enable sequence parallelism
    init_sequence_parallel_group(args)

    device = torch.device('cuda')
    rank = args.rank
    model_dtype = args.model_dtype

    if args.model_name == "pyramid_flux":
        assert args.variant != "diffusion_transformer_768p", (
            "pyramid_flux does not support the 768p resolution yet; it will be released after training finishes. "
            "Set --model_name pyramid_mmdit to generate 768p videos."
        )

    model = PyramidDiTForVideoGeneration(
        args.model_path,
        model_dtype,
        model_name=args.model_name,
        model_variant=args.variant,
    )

    model.vae.to(device)
    model.dit.to(device)
    model.text_encoder.to(device)
    model.vae.enable_tiling()

    if model_dtype == "bf16":
        torch_dtype = torch.bfloat16
    elif model_dtype == "fp16":
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    # The video generation resolution
    if args.resolution == '768p':
        width = 1280
        height = 768
    else:
        width = 640
        height = 384
    try:
        if args.task == 't2v':
            prompt = args.prompt
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=(model_dtype != 'fp32'), dtype=torch_dtype):
                frames = model.generate(
                    prompt=prompt,
                    num_inference_steps=[20, 20, 20],
                    video_num_inference_steps=[10, 10, 10],
                    height=height,
                    width=width,
                    temp=args.temp,
                    guidance_scale=args.guidance_scale,
                    video_guidance_scale=args.video_guidance_scale,
                    output_type="pil",
                    save_memory=True,
                    cpu_offloading=False,
                    inference_multigpu=True,
                )
            # Only rank 0 writes the video file
            if rank == 0:
                export_to_video(frames, args.output_path, fps=24)

        elif args.task == 'i2v':
            if not args.image_path:
                raise ValueError("An image path is required for the image-to-video task")
            image = Image.open(args.image_path).convert("RGB")
            image = image.resize((width, height))
            prompt = args.prompt
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=(model_dtype != 'fp32'), dtype=torch_dtype):
                frames = model.generate_i2v(
                    prompt=prompt,
                    input_image=image,
                    num_inference_steps=[10, 10, 10],
                    temp=args.temp,
                    video_guidance_scale=args.video_guidance_scale,
                    output_type="pil",
                    save_memory=True,
                    cpu_offloading=False,
                    inference_multigpu=True,
                )
            # Only rank 0 writes the video file
            if rank == 0:
                export_to_video(frames, args.output_path, fps=24)

    except Exception as e:
        if rank == 0:
            print(f"[ERROR] Error during video generation: {e}")
        raise
    finally:
        # Keep all ranks in sync before exiting
        torch.distributed.barrier()


if __name__ == "__main__":
    main()