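# Gradio app (Hugging Face Space) for the EchoMimicV3 demo: audio-driven portrait animation.
# Given a reference image and an audio clip, it generates a lip-synced talking video using the
# Wan2.1-Fun base model together with the EchoMimicV3 audio-conditioned transformer.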
import os
import spaces
import torch
import gradio as gr
import tempfile
import subprocess
import sys
from pathlib import Path
import datetime
import math
import random
import gc
import json
import numpy as np
from PIL import Image
from moviepy import *
import librosa
from omegaconf import OmegaConf
from transformers import AutoTokenizer, Wav2Vec2Model, Wav2Vec2Processor
from diffusers import FlowMatchEulerDiscreteScheduler
from huggingface_hub import hf_hub_download, snapshot_download
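
# Clone the upstream EchoMimicV3 repository on first run and put it on sys.path so that
# the src.* modules imported further down are resolvable.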
def setup_repository():
    if not os.path.exists("echomimic_v3"):
        print("Cloning EchoMimicV3 repository...")
        subprocess.run([
            "git", "clone",
            "https://github.com/antgroup/echomimic_v3.git"
        ], check=True)
        print("Repository cloned successfully")
    sys.path.insert(0, "echomimic_v3")
    print("Repository added to Python path")
def download_models():
    print("Downloading models...")
    os.makedirs("models", exist_ok=True)
    try:
        print("Downloading base model...")
        base_model_path = snapshot_download(
            repo_id="alibaba-pai/Wan2.1-Fun-V1.1-1.3B-InP",
            local_dir="models/Wan2.1-Fun-V1.1-1.3B-InP",
            local_dir_use_symlinks=False
        )
        print(f"Base model downloaded to: {base_model_path}")
        print("Downloading EchoMimicV3 transformer...")
        os.makedirs("models/transformer", exist_ok=True)
        transformer_file = hf_hub_download(
            repo_id="BadToBest/EchoMimicV3",
            filename="transformer/diffusion_pytorch_model.safetensors",
            local_dir="models",
            local_dir_use_symlinks=False
        )
        print(f"Transformer downloaded to: {transformer_file}")
        config_file = hf_hub_download(
            repo_id="BadToBest/EchoMimicV3",
            filename="transformer/config.json",
            local_dir="models",
            local_dir_use_symlinks=False
        )
        print(f"Config downloaded to: {config_file}")
        print("Downloading Wav2Vec model...")
        wav2vec_path = snapshot_download(
            repo_id="facebook/wav2vec2-base-960h",
            local_dir="models/wav2vec2-base-960h",
            local_dir_use_symlinks=False
        )
        print(f"Wav2Vec model downloaded to: {wav2vec_path}")
        print("All models downloaded successfully!")
        return True
    except Exception as e:
        print(f"Error downloading models: {e}")
        return False
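
# Fetch a few demo image/audio/prompt files from the upstream GitHub repo for the Examples section.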
def download_examples():
    print("Downloading example files...")
    os.makedirs("examples", exist_ok=True)
    try:
        example_files = [
            "datasets/echomimicv3_demos/imgs/demo_ch_woman_04.png",
            "datasets/echomimicv3_demos/audios/demo_ch_woman_04.WAV",
            "datasets/echomimicv3_demos/prompts/demo_ch_woman_04.txt",
            "datasets/echomimicv3_demos/imgs/guitar_woman_01.png",
            "datasets/echomimicv3_demos/audios/guitar_woman_01.WAV",
            "datasets/echomimicv3_demos/prompts/guitar_woman_01.txt"
        ]
        repo_url = "https://github.com/antgroup/echomimic_v3/raw/main/"
        import urllib.request
        for file_path in example_files:
            filename = os.path.basename(file_path)
            local_path = f"examples/{filename}"
            try:
                if not os.path.exists(local_path):
                    print(f"Downloading {filename}...")
                    urllib.request.urlretrieve(f"{repo_url}{file_path}", local_path)
                    print(f"Downloaded {filename}")
                else:
                    print(f"{filename} already exists")
            except Exception as e:
                print(f"Warning: could not download {filename}: {e}")
        print("Example files downloaded!")
        return True
    except Exception as e:
        print(f"Error downloading examples: {e}")
        return False
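
# The repository must be cloned and added to sys.path before the src.* imports below can succeed.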
setup_repository()

from src.dist import set_multi_gpus_devices
from src.wan_vae import AutoencoderKLWan
from src.wan_image_encoder import CLIPModel
from src.wan_text_encoder import WanT5EncoderModel
from src.wan_transformer3d_audio import WanTransformerAudioMask3DModel
from src.pipeline_wan_fun_inpaint_audio import WanFunInpaintAudioPipeline
from src.utils import filter_kwargs, get_image_to_video_latent3, save_videos_grid
from src.fm_solvers import FlowDPMSolverMultistepScheduler
from src.fm_solvers_unipc import FlowUniPCMultistepScheduler
from src.cache_utils import get_teacache_coefficients
from src.face_detect import get_mask_coord
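
# Static inference configuration: model paths, weight dtype, default sample resolution, and sampler.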
class ComprehensiveConfig:
    def __init__(self):
        self.ulysses_degree = 1
        self.ring_degree = 1
        self.fsdp_dit = False
        self.config_path = "echomimic_v3/config/config.yaml"
        self.model_name = "models/Wan2.1-Fun-V1.1-1.3B-InP"
        self.transformer_path = "models/transformer/diffusion_pytorch_model.safetensors"
        self.wav2vec_model_dir = "models/wav2vec2-base-960h"
        self.weight_dtype = torch.bfloat16
        self.sample_size = [768, 768]
        self.sampler_name = "Flow_DPM++"
        self.lora_weight = 1.0


config = ComprehensiveConfig()
pipeline = None
wav2vec_processor = None
wav2vec_model = None
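
# Load the Wav2Vec2 processor and encoder in eval mode with gradients disabled.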
def load_wav2vec_models(wav2vec_model_dir):
    print(f"Loading Wav2Vec models from {wav2vec_model_dir}...")
    try:
        processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
        model = Wav2Vec2Model.from_pretrained(wav2vec_model_dir).eval()
        model.requires_grad_(False)
        print("Wav2Vec models loaded successfully")
        return processor, model
    except Exception as e:
        print(f"Error loading Wav2Vec models: {e}")
        raise
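
# Resample the audio to 16 kHz and encode it with Wav2Vec2; the base model emits roughly
# 50 feature frames per second (20 ms hop), which the generation loop pairs with the video
# at two audio features per video frame.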
def extract_audio_features(audio_path, processor, model):
    try:
        sr = 16000
        audio_segment, sample_rate = librosa.load(audio_path, sr=sr)
        input_values = processor(audio_segment, sampling_rate=sample_rate, return_tensors="pt").input_values
        input_values = input_values.to(model.device)
        with torch.no_grad():
            features = model(input_values).last_hidden_state
        return features.squeeze(0)
    except Exception as e:
        print(f"Error extracting audio features: {e}")
        raise
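
# Choose a generation resolution that keeps the reference image's aspect ratio, caps the area
# at the default sample size, and snaps both sides down to multiples of 16.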
def get_sample_size(image, default_size):
    width, height = image.size
    original_area = width * height
    default_area = default_size[0] * default_size[1]
    if default_area < original_area:
        ratio = math.sqrt(original_area / default_area)
        width = width / ratio // 16 * 16
        height = height / ratio // 16 * 16
    else:
        width = width // 16 * 16
        height = height // 16 * 16
    return int(height), int(width)
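
# Build a flattened binary mask over the latent grid marking the detected face region
# (coords are expressed in 1/16-resolution units).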
def get_ip_mask(coords):
    y1, y2, x1, x2, h, w = coords
    Y, X = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    mask = (Y.unsqueeze(-1) >= y1) & (Y.unsqueeze(-1) < y2) & (X.unsqueeze(-1) >= x1) & (X.unsqueeze(-1) < x2)
    mask = mask.reshape(-1)
    return mask.float()
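
# Download weights and assemble the WanFunInpaintAudioPipeline (transformer, VAE, text and CLIP
# encoders, scheduler) plus the Wav2Vec2 models. Populates the module-level `pipeline`,
# `wav2vec_processor`, and `wav2vec_model` globals used by generate_video.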
def initialize_models():
    global pipeline, wav2vec_processor, wav2vec_model, config
    print("Initializing EchoMimicV3 models...")
    try:
        if not download_models():
            raise Exception("Failed to download required models")
        download_examples()
        device = set_multi_gpus_devices(config.ulysses_degree, config.ring_degree)
        print(f"Device set to: {device}")
        cfg = OmegaConf.load(config.config_path)
        print(f"Config loaded from {config.config_path}")
        print("Loading transformer...")
        transformer = WanTransformerAudioMask3DModel.from_pretrained(
            os.path.join(config.model_name, cfg['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
            transformer_additional_kwargs=OmegaConf.to_container(cfg['transformer_additional_kwargs']),
            torch_dtype=config.weight_dtype,
        )
        if config.transformer_path is not None and os.path.exists(config.transformer_path):
            print(f"Loading custom transformer weights from {config.transformer_path}...")
            from safetensors.torch import load_file
            state_dict = load_file(config.transformer_path)
            state_dict = state_dict.get("state_dict", state_dict)
            missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
            print(f"Custom transformer weights loaded - Missing: {len(missing)}, Unexpected: {len(unexpected)}")
        print("Loading VAE...")
        vae = AutoencoderKLWan.from_pretrained(
            os.path.join(config.model_name, cfg['vae_kwargs'].get('vae_subpath', 'vae')),
            additional_kwargs=OmegaConf.to_container(cfg['vae_kwargs']),
        ).to(config.weight_dtype)
        print("VAE loaded")
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(config.model_name, cfg['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
        )
        print("Tokenizer loaded")
        print("Loading text encoder...")
        text_encoder = WanT5EncoderModel.from_pretrained(
            os.path.join(config.model_name, cfg['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
            additional_kwargs=OmegaConf.to_container(cfg['text_encoder_kwargs']),
            torch_dtype=config.weight_dtype,
        ).eval()
        print("Text encoder loaded")
        print("Loading CLIP image encoder...")
        clip_image_encoder = CLIPModel.from_pretrained(
            os.path.join(config.model_name, cfg['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
        ).to(config.weight_dtype).eval()
        print("CLIP image encoder loaded")
        print("Loading scheduler...")
        scheduler_cls_map = {
            "Flow": FlowMatchEulerDiscreteScheduler,
            "Flow_Unipc": FlowUniPCMultistepScheduler,
            "Flow_DPM++": FlowDPMSolverMultistepScheduler,
        }
        scheduler_cls = scheduler_cls_map.get(config.sampler_name, FlowDPMSolverMultistepScheduler)
        scheduler = scheduler_cls(**filter_kwargs(scheduler_cls, OmegaConf.to_container(cfg['scheduler_kwargs'])))
        print("Scheduler loaded")
        print("Creating pipeline...")
        pipeline = WanFunInpaintAudioPipeline(
            transformer=transformer,
            vae=vae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            scheduler=scheduler,
            clip_image_encoder=clip_image_encoder,
        )
        pipeline.to(device=device)
        if torch.__version__ >= "2.0":
            print("Compiling the pipeline with torch.compile()...")
            pipeline.transformer = torch.compile(pipeline.transformer, mode="reduce-overhead", fullgraph=True)
            print("Pipeline transformer compiled!")
        print("Pipeline created and moved to device")
        print("Loading Wav2Vec models...")
        wav2vec_processor, wav2vec_model = load_wav2vec_models(config.wav2vec_model_dir)
        wav2vec_model.to(device)
        print("Wav2Vec models loaded")
        print("All models initialized successfully!")
        return True
    except Exception as e:
        print(f"Model initialization failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return False
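
# Full generation path: face detection -> Wav2Vec2 audio features -> prompt encoding ->
# chunked diffusion over the timeline with overlapping windows -> mux the audio back in
# with MoviePy, returning the final mp4 path together with the seed that was used.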
def generate_video(
    image_path,
    audio_path,
    prompt,
    negative_prompt,
    seed_param,
    num_inference_steps,
    guidance_scale,
    audio_guidance_scale,
    fps,
    partial_video_length,
    overlap_video_length,
    neg_scale,
    neg_steps,
    use_dynamic_cfg,
    use_dynamic_acfg,
    sampler_name,
    shift,
    audio_scale,
    use_un_ip_mask,
    enable_teacache,
    teacache_threshold,
    teacache_offload,
    num_skip_start_steps,
    enable_riflex,
    riflex_k,
    progress=gr.Progress(track_tqdm=True)
):
    global pipeline, wav2vec_processor, wav2vec_model, config
    progress(0, desc="Starting video generation...")
    if image_path is None:
        raise gr.Error("Please upload an image")
    if audio_path is None:
        raise gr.Error("Please upload an audio file")
    if not models_ready or pipeline is None:
        raise gr.Error("Models not initialized. Please restart the space.")
    device = pipeline.device
    if seed_param < 0:
        seed = random.randint(0, np.iinfo(np.int32).max)
    else:
        seed = int(seed_param)
    print(f"Using seed: {seed}")
    try:
        generator = torch.Generator(device=device).manual_seed(seed)
        ref_img_pil = Image.open(image_path).convert("RGB")
        print(f"Image loaded: {ref_img_pil.size}")
        progress(0.1, desc="Detecting face...")
        try:
            y1, y2, x1, x2, h_, w_ = get_mask_coord(image_path)
            print("Face detection successful")
        except Exception as e:
            print(f"Face detection failed: {e}, using center crop")
            h_, w_ = ref_img_pil.size[1], ref_img_pil.size[0]
            y1, y2 = h_ // 4, 3 * h_ // 4
            x1, x2 = w_ // 4, 3 * w_ // 4
        progress(0.2, desc="Processing audio...")
        audio_clip = AudioFileClip(audio_path)
        audio_features = extract_audio_features(audio_path, wav2vec_processor, wav2vec_model)
        audio_embeds = audio_features.unsqueeze(0).to(device=device, dtype=config.weight_dtype)
        progress(0.25, desc="Encoding prompts...")
        prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
            prompt,
            device=device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=(guidance_scale > 1.0),
            negative_prompt=negative_prompt
        )
        video_length = int(audio_clip.duration * fps)
        video_length = (
            int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1
            if video_length != 1 else 1
        )
        print(f"Total video length: {video_length} frames")
        sample_height, sample_width = get_sample_size(ref_img_pil, config.sample_size)
        print(f"Sample size: {sample_width}x{sample_height}")
        downratio = math.sqrt(sample_height * sample_width / h_ / w_)
        coords = (
            y1 * downratio // 16, y2 * downratio // 16,
            x1 * downratio // 16, x2 * downratio // 16,
            sample_height // 16, sample_width // 16,
        )
        ip_mask = get_ip_mask(coords).unsqueeze(0)
        ip_mask = torch.cat([ip_mask] * 3).to(device=device, dtype=config.weight_dtype)
        if enable_riflex:
            latent_frames = (video_length - 1) // pipeline.vae.config.temporal_compression_ratio + 1
            pipeline.transformer.enable_riflex(k=riflex_k, L_test=latent_frames)
        if enable_teacache:
            try:
                coefficients = get_teacache_coefficients(config.model_name)
                if coefficients:
                    pipeline.transformer.enable_teacache(
                        coefficients, num_inference_steps, teacache_threshold,
                        num_skip_start_steps=num_skip_start_steps,
                        offload=teacache_offload
                    )
                    print("TeaCache enabled for this run")
            except Exception as e:
                print(f"Could not enable TeaCache: {e}")
        init_frames = 0
        new_sample = None
        ref_img_for_loop = ref_img_pil
        total_chunks = math.ceil(video_length / (partial_video_length - overlap_video_length)) if video_length > partial_video_length else 1
        chunk_num = 0
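        # Generate the video in overlapping chunks: each chunk is conditioned on the last
        # `overlap_video_length` frames of the previous one and cross-faded into it, which
        # keeps VRAM bounded for long audio while preserving temporal consistency.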
        while init_frames < video_length:
            chunk_num += 1
            progress(0.3 + (0.6 * (chunk_num / total_chunks)), desc=f"Generating chunk {chunk_num}/{total_chunks}...")
            current_partial_length = min(partial_video_length, video_length - init_frames)
            current_partial_length = (
                int((current_partial_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1
                if current_partial_length > 1 else 1
            )
            if current_partial_length <= 0:
                break
            input_video, input_video_mask, clip_image = get_image_to_video_latent3(
                ref_img_for_loop, None, video_length=current_partial_length,
                sample_size=[sample_height, sample_width]
            )
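            # The hard-coded factor of 2 maps video frames to Wav2Vec2 feature frames
            # (~50 features per second against the default 25 fps).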
            audio_start_frame = init_frames * 2
            audio_end_frame = (init_frames + current_partial_length) * 2
            if audio_embeds.shape[1] < audio_end_frame:
                repeat_times = (audio_end_frame // audio_embeds.shape[1]) + 1
                audio_embeds = audio_embeds.repeat(1, repeat_times, 1)
            partial_audio_embeds = audio_embeds[:, audio_start_frame:audio_end_frame]
            with torch.no_grad():
                sample = pipeline(
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=negative_prompt_embeds,
                    num_frames=current_partial_length,
                    audio_embeds=partial_audio_embeds,
                    audio_scale=audio_scale,
                    ip_mask=ip_mask,
                    use_un_ip_mask=use_un_ip_mask,
                    height=sample_height,
                    width=sample_width,
                    generator=generator,
                    neg_scale=neg_scale,
                    neg_steps=neg_steps,
                    use_dynamic_cfg=use_dynamic_cfg,
                    use_dynamic_acfg=use_dynamic_acfg,
                    guidance_scale=guidance_scale,
                    audio_guidance_scale=audio_guidance_scale,
                    num_inference_steps=num_inference_steps,
                    video=input_video,
                    mask_video=input_video_mask,
                    clip_image=clip_image,
                    shift=shift,
                ).videos
            if new_sample is None:
                new_sample = sample
            else:
                mix_ratio = torch.linspace(0, 1, steps=overlap_video_length, device=device).view(1, 1, -1, 1, 1).to(new_sample.dtype)
                new_sample[:, :, -overlap_video_length:] = (
                    new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) +
                    sample[:, :, :overlap_video_length] * mix_ratio
                )
                new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim=2)
            if new_sample.shape[2] >= video_length:
                break
            ref_img_for_loop = [
                Image.fromarray(
                    (new_sample[0, :, i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                ) for i in range(-overlap_video_length, 0)
            ]
            init_frames += current_partial_length - overlap_video_length
        progress(0.9, desc="Stitching video and audio...")
        final_sample = new_sample[:, :, :video_length]
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
            video_path = tmp_file.name
        with tempfile.NamedTemporaryFile(suffix="_audio.mp4", delete=False) as tmp_file:
            video_audio_path = tmp_file.name
        save_videos_grid(final_sample, video_path, fps=fps)
        video_clip_final = VideoFileClip(video_path)
        audio_clip_trimmed = audio_clip.subclipped(0, final_sample.shape[2] / fps)  # MoviePy 2.x API (subclipped / with_audio)
        final_video = video_clip_final.with_audio(audio_clip_trimmed)
        final_video.write_videofile(video_audio_path, codec="libx264", audio_codec="aac", threads=4, logger=None)
        video_clip_final.close()
        audio_clip.close()
        audio_clip_trimmed.close()
        final_video.close()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        progress(1.0, desc="Generation complete!")
        return video_audio_path, seed
    except Exception as e:
        print(f"Generation error: {str(e)}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"Generation failed: {str(e)}")
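
# Build the Gradio UI. The generate button is only wired to generate_video when the models
# loaded successfully (`models_ready`).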
def create_demo():
    with gr.Blocks(theme=gr.themes.Soft(), title="EchoMimicV3 Demo") as demo:
        gr.Markdown("""
        # EchoMimicV3: Audio-Driven Human Animation
        Transform a portrait photo into a talking video! Upload an image and an audio file to create lifelike, expressive animations. This demo showcases the EchoMimicV3 model.

        **Key Features:**
        - **High-Quality Lip Sync:** Accurate mouth movements that match the input audio.
        - **Natural Facial Expressions:** Generates subtle and natural facial emotions.
        - **Speech & Singing:** Works with both spoken word and singing.
        - **Efficient:** Powered by a compact 1.3B-parameter model.
        """)
        if not models_ready:
            gr.Warning("Models are still loading. The UI is disabled. Please wait and refresh the page if necessary.")
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Upload Portrait Image",
                    type="filepath",
                    sources=["upload"],
                    height=400,
                )
                audio_input = gr.Audio(
                    label="Upload Audio",
                    type="filepath",
                    sources=["upload"],
                )
                with gr.Accordion("Text Prompts", open=True):
                    prompt = gr.Textbox(
                        label="Prompt",
                        value="A person talking naturally with clear expressions.",
                    )
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="Gesture is bad, unclear. Strange, twisted, bad, blurry hands and fingers.",
                        lines=2,
                    )
            with gr.Column(scale=1):
                video_output = gr.Video(
                    label="Generated Video",
                    interactive=False,
                    height=400
                )
                seed_output = gr.Number(
                    label="Used Seed",
                    interactive=False,
                    precision=0
                )
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Core Generation Parameters")
                    seed_param = gr.Number(label="Seed", value=-1, precision=0, info="-1 for a random seed.")
                    num_inference_steps = gr.Slider(label="Inference Steps", minimum=5, maximum=50, value=20, step=1, info="More steps can improve quality but take longer. 15-25 is a good range.")
                    fps = gr.Slider(label="Frames Per Second (FPS)", minimum=10, maximum=30, value=25, step=1, info="Controls the smoothness of the output video.")
                with gr.Column():
                    gr.Markdown("### Classifier-Free Guidance (CFG)")
                    guidance_scale = gr.Slider(label="Text Guidance Scale (CFG)", minimum=1.0, maximum=10.0, value=4.5, step=0.1, info="How strongly to follow the text prompt. Recommended: 3.0-6.0.")
                    audio_guidance_scale = gr.Slider(label="Audio Guidance Scale (aCFG)", minimum=1.0, maximum=10.0, value=2.5, step=0.1, info="How strongly to follow the audio for lip sync. Recommended: 2.0-3.0.")
                    use_dynamic_cfg = gr.Checkbox(label="Use Dynamic Text CFG", value=True, info="Gradually adjusts CFG during generation; can improve quality.")
                    use_dynamic_acfg = gr.Checkbox(label="Use Dynamic Audio aCFG", value=True, info="Gradually adjusts aCFG during generation; can improve quality.")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Performance & VRAM (Chunking)")
                    partial_video_length = gr.Slider(label="Partial Video Length (Chunk Size)", minimum=49, maximum=161, value=113, step=16, info="Key setting for VRAM usage. 24 GB VRAM: ~113, 16 GB: ~81, 12 GB: ~49. Lower values use less memory but may affect consistency.")
                    overlap_video_length = gr.Slider(label="Overlap Length", minimum=4, maximum=16, value=8, step=1, info="How many frames to overlap between chunks for smooth transitions.")
                with gr.Column():
                    gr.Markdown("### Sampler & Scheduler")
                    sampler_name = gr.Dropdown(label="Sampler", choices=["Flow", "Flow_Unipc", "Flow_DPM++"], value="Flow_DPM++", info="Algorithm for the diffusion process.")
                    shift = gr.Slider(label="Scheduler Shift", minimum=1.0, maximum=10.0, value=5.0, step=0.1, info="Adjusts the noise schedule. The optimal range depends on the sampler.")
                    audio_scale = gr.Slider(label="Audio Scale", minimum=0.5, maximum=2.0, value=1.0, step=0.1, info="Global scale for audio feature influence.")
                    use_un_ip_mask = gr.Checkbox(label="Use Un-IP Mask", value=False, info="Inverts the inpainting mask.")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Negative Guidance (Advanced CFG)")
                    neg_scale = gr.Slider(label="Negative Scale", minimum=1.0, maximum=5.0, value=1.5, step=0.1, info="Strength of the negative prompt in early steps.")
                    neg_steps = gr.Slider(label="Negative Steps", minimum=0, maximum=10, value=2, step=1, info="How many initial steps apply the negative scale.")
        with gr.Accordion("Experimental Settings", open=False):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### TeaCache (Performance Boost)")
                    enable_teacache = gr.Checkbox(label="Enable TeaCache", value=True)
                    teacache_threshold = gr.Slider(label="TeaCache Threshold", minimum=0.0, maximum=0.2, value=0.1, step=0.01)
                    teacache_offload = gr.Checkbox(label="TeaCache Offload", value=True)
                with gr.Column():
                    gr.Markdown("### Riflex (Consistency)")
                    enable_riflex = gr.Checkbox(label="Enable Riflex", value=False)
                    riflex_k = gr.Slider(label="Riflex K", minimum=1, maximum=10, value=6, step=1)
                with gr.Column():
                    gr.Markdown("### Other")
                    num_skip_start_steps = gr.Slider(label="Num Skip Start Steps", minimum=0, maximum=10, value=5, step=1)
        generate_button = gr.Button(
            "Generate Video",
            variant='primary',
            size="lg",
            interactive=models_ready
        )
        all_inputs = [
            image_input, audio_input, prompt, negative_prompt, seed_param,
            num_inference_steps, guidance_scale, audio_guidance_scale, fps,
            partial_video_length, overlap_video_length, neg_scale, neg_steps,
            use_dynamic_cfg, use_dynamic_acfg, sampler_name, shift, audio_scale,
            use_un_ip_mask, enable_teacache, teacache_threshold, teacache_offload,
            num_skip_start_steps, enable_riflex, riflex_k
        ]
        if models_ready:
            generate_button.click(
                fn=generate_video,
                inputs=all_inputs,
                outputs=[video_output, seed_output]
            )
        gr.Markdown("---")
        gr.Markdown("### Click to Try Examples")
        gr.Examples(
            examples=[
                [
                    "examples/demo_ch_woman_04.png",
                    "examples/demo_ch_woman_04.WAV",
                    "A Chinese woman is talking naturally.",
                    "bad gestures, blurry, distorted face",
                    42, 20, 4.5, 2.5, 25, 113, 8, 1.5, 2, True, True, "Flow_DPM++", 5.0, 1.0, False, True, 0.1, True, 5, False, 6
                ],
                [
                    "examples/guitar_woman_01.png",
                    "examples/guitar_woman_01.WAV",
                    "A woman with glasses is singing and playing the guitar.",
                    "blurry, distorted face, bad hands",
                    123, 25, 5.0, 2.8, 25, 113, 8, 1.5, 2, True, True, "Flow_DPM++", 5.0, 1.0, False, True, 0.1, True, 5, False, 6
                ],
            ],
            inputs=all_inputs,
            outputs=[video_output, seed_output],
            fn=generate_video,
            cache_examples=True,
            label=None,
        )
        gr.Markdown("---")
        gr.Markdown("""
        ### How to Use
        1. **Upload Image:** Choose a clear portrait photo (front-facing works best).
        2. **Upload Audio:** Add an audio file with clear speech or singing.
        3. **Adjust Settings (Optional):** Fine-tune parameters in the advanced sections for different results. For memory issues, try lowering "Partial Video Length".
        4. **Generate:** Click the button and wait for your talking video!

        **Note:** Generation time depends on settings and audio length. It can take a few minutes.

        This demo is based on the [EchoMimicV3 repository](https://github.com/antgroup/echomimic_v3).
        """)
    return demo
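
# `models_ready` is set here at startup and read by create_demo() and generate_video();
# the app is intended to be run as a script (as Hugging Face Spaces runs app.py).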
if __name__ == "__main__":
    print("Starting model initialization...")
    models_ready = initialize_models()
    demo = create_demo()
    demo.launch(share=True)