# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

RAPO++ is a three-stage framework for text-to-video (T2V) generation prompt optimization. It combines:

- **Stage 1 (RAPO)**: Retrieval-Augmented Prompt Optimization using relation graphs
- **Stage 2 (SSPO)**: Self-Supervised Prompt Optimization with test-time iterative refinement
- **Stage 3**: LLM fine-tuning on collected feedback data

The system is model-agnostic and works with various T2V models (Wan2.1, Open-Sora-Plan, HunyuanVideo, etc.).

## Environment Setup

```bash
# Create and activate environment
conda create -n rapo_plus python=3.10
conda activate rapo_plus

# Install dependencies
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu118
pip install -r requirement.txt
```

## Required Checkpoints

Download the following and place them in the `ckpt/` directory:

**Stage 1:**

- `all-MiniLM-L6-v2/` - Sentence transformer for embeddings
- `llama3_1_instruct_lora_rewrite/` - LLM for prompt rewriting
- `Mistral-7B-Instruct-v0.3/` - Alternative instruction-tuned LLM

**Stage 2 (example with Wan2.1):**

- `Wan2.1-T2V-1.3B-Diffusers/` - Base T2V model
- `Qwen2.5-7B-Instruct/` - Instruction-following LLM for prompt refinement
- `Qwen2.5-vl-7B-instruct/` - Vision-language model for video alignment assessment

Also place the relation graph data in `relation_graph/graph_data/`.
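
Before running either stage, it can help to confirm the checkpoint layout described above. The sketch below is a hypothetical helper and is not part of the repository. It lists which expected directories are missing under `ckpt/`. The directory names are copied from the lists above, and `EXPECTED_CKPTS` and `missing_checkpoints` are illustrative names.

```python
from pathlib import Path

# Hypothetical helper: directory names copied from the checkpoint lists above.
EXPECTED_CKPTS = {
    "stage1": [
        "all-MiniLM-L6-v2",
        "llama3_1_instruct_lora_rewrite",
        "Mistral-7B-Instruct-v0.3",
    ],
    "stage2": [
        "Wan2.1-T2V-1.3B-Diffusers",
        "Qwen2.5-7B-Instruct",
        "Qwen2.5-vl-7B-instruct",
    ],
}


def missing_checkpoints(ckpt_root, stage):
    """Return the expected checkpoint directories that do not exist under ckpt_root."""
    root = Path(ckpt_root)
    return [name for name in EXPECTED_CKPTS[stage] if not (root / name).is_dir()]
```

For example, `missing_checkpoints("ckpt", "stage2")` returns an empty list once everything is in place. Directory names are case-sensitive on Linux. The Stage 2 script's `QWEN_VL_PATH` spells the VLM folder in lowercase (`qwen2.5-vl-7B-instruct`), so the folder on disk should match whichever spelling the script uses.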

## Core Workflows

### Stage 1: RAPO (Retrieval-Augmented Prompt Optimization)

**Location:** `examples/Stage1_RAPO/`

**Pipeline:**

1. **Graph Construction** (`construct_graph.py`):
   - Reads CSV with columns: `Input`, `verb_obj_word`, `scenario_word`, `place`
   - Creates NetworkX graphs linking places to verbs and scenes
   - Generates embeddings with SentenceTransformer
   - Outputs: JSON dictionaries, GraphML files to `relation_graph/`
2. **Modifier Retrieval** (`retrieve_modifiers.py`):
   - Input: Test prompts from `data/test_prompts.txt`
   - Encodes prompts and retrieves top-K related places via cosine similarity
   - Samples connected verbs/scenes from graph neighbors
   - Outputs: `output/retrieve_words/{filename}.txt` and `.csv`
   - Run: `sh retrieve_modifiers.sh`
3. **Word Augmentation** (`word_augment.py`):
   - Filters retrieved modifiers by similarity threshold
   - Merges modifiers interactively
   - Run: `sh word_augment.sh`
4. **Sentence Refactoring** (`refactoring.py`):
   - Restructures prompts with augmented modifiers
   - Run: `sh refactoring.sh`
5. **Instruction-Based Rewriting** (`rewrite_via_instruction.py`):
   - Uses LLM to refine prompts with natural language instructions
   - Run: `sh rewrite_via_instruction.sh`

**Key Parameters:**

- `place_num`: Top-K places to retrieve (default: 3)
- `verb_num`, `topk_num`: Control verb/scene sampling
- `SIMILARITY_THRESHOLD`: Filters modifiers in `word_augment.py`
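
Here is a minimal sketch of the retrieval and filtering logic described above. It assumes the prompt and place embeddings are already computed, for example with `SentenceTransformer.encode`. A plain adjacency dict stands in for the NetworkX graph. `retrieve_places`, `sample_neighbors` and `filter_modifiers` are hypothetical names, and `topk_num` is not modeled. The actual sampling in `retrieve_modifiers.py` and the thresholding in `word_augment.py` may differ.

```python
import math
import random


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def retrieve_places(prompt_emb, place_embs, place_num=3):
    """Return the place_num places whose embeddings are most similar to the prompt."""
    ranked = sorted(place_embs, key=lambda p: cosine(prompt_emb, place_embs[p]), reverse=True)
    return ranked[:place_num]


def sample_neighbors(graph, places, verb_num=2, seed=0):
    """Sample up to verb_num neighbors (verbs/scenes) of each retrieved place.

    `graph` is a plain {place: [neighbor, ...]} dict standing in for the NetworkX graph.
    """
    rng = random.Random(seed)
    modifiers = []
    for place in places:
        neighbors = sorted(graph.get(place, []))
        modifiers.extend(rng.sample(neighbors, min(verb_num, len(neighbors))))
    return modifiers


def filter_modifiers(prompt_emb, modifiers, modifier_embs, threshold=0.5):
    """Keep modifiers whose similarity to the prompt reaches the threshold (hypothetical rule)."""
    return [m for m in modifiers if cosine(prompt_emb, modifier_embs[m]) >= threshold]
```

Sorting the neighbors before sampling makes the seeded sample reproducible regardless of the order in which the graph stores them.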

### Stage 2: SSPO (Self-Supervised Prompt Optimization)

**Location:** `examples/Stage2_SSPO/`

**Main Script:** `phyaware_wan2.1.py`

**Architecture:**

This script implements a closed-loop iterative optimization pipeline:

1. **Video Generation** (`load_model()`, `generate_single_video()`):
   - Uses WanPipeline to generate videos from prompts
   - Configurable: height=480, width=832, num_frames=81, fps=15
2. **Optical Flow Analysis** (`extract_optical_flow()`):
   - Extracts motion statistics using `cv2.calcOpticalFlowFarneback`
   - Samples frames at configurable intervals
   - Returns a sequence of (x, y) flow vectors
3. **VLM Alignment Assessment** (`misalignment_assessment()`):
   - Uses Qwen2.5-VL to evaluate video-prompt alignment
   - Assesses objects, actions, and scenes
   - Returns a textual alignment score (1-5 scale)
4. **Physics Consistency Check + Prompt Refinement** (`evaluate_physical_consistency()`):
   - **Phase 1**: The LLM analyzes optical flow for physical plausibility (inertia, momentum, etc.)
   - **Phase 2**: Fuses the physics analysis with the VLM alignment feedback
   - Rewrites the prompt to enforce physical rules and semantic alignment
   - Uses Qwen2.5-7B-Instruct
5. **Iterative Loop**:
   - Generates video → Analyzes → Refines prompt → Generates again
   - Default: 5 refinement iterations per prompt
   - Logs to CSV: `results/examples_refined/refined_prompts.csv`
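
The closed loop can be sketched as follows. The three model stages are passed in as plain callables, so any T2V model, VLM or LLM can be plugged in. This is not the code of `phyaware_wan2.1.py`. The function `run_refinement_loop`, the callable signatures and the CSV columns in `FIELDS` are assumptions made for illustration.

```python
import csv
import os

# Hypothetical log schema; the real script's columns may differ.
FIELDS = ["prompt_id", "iteration", "prompt", "refined_prompt", "video_path", "alignment_feedback"]


def run_refinement_loop(prompt_id, prompt, phys_law, generate, assess, refine,
                        log_path, num_refine_iterations=5, start_iteration=0):
    """Generate a video, assess it, refine the prompt; log one CSV row per iteration.

    generate(prompt, iteration) -> video path          (T2V model)
    assess(video_path, prompt) -> alignment feedback   (VLM)
    refine(prompt, phys_law, video_path, feedback) -> new prompt
        (optical-flow physics check + LLM rewrite)
    Returns the last refined prompt.
    """
    os.makedirs(os.path.dirname(log_path) or ".", exist_ok=True)
    write_header = not os.path.exists(log_path)
    with open(log_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDS)
        if write_header:
            writer.writeheader()
        for it in range(start_iteration, num_refine_iterations):
            video_path = generate(prompt, it)
            feedback = assess(video_path, prompt)
            refined = refine(prompt, phys_law, video_path, feedback)
            writer.writerow({"prompt_id": prompt_id, "iteration": it, "prompt": prompt,
                             "refined_prompt": refined, "video_path": video_path,
                             "alignment_feedback": feedback})
            f.flush()  # keep the log usable for resuming if the run is interrupted
            prompt = refined
    return prompt
```

Logging both the prompt used and the prompt produced in each row keeps the whole prompt evolution chain recoverable from the CSV alone.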

**Resume Capability:**

The script checks existing logs and videos to resume from the last completed iteration, keeping the prompt chain consistent.
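
A minimal sketch of such a resume check follows. It assumes a hypothetical per-iteration CSV log with the columns `prompt_id`, `iteration`, `refined_prompt` and `video_path`, which is not necessarily the script's actual schema. An iteration counts as finished only if it is logged and its video still exists. The prompt refined at the last finished iteration becomes the next prompt. `resume_point` is an illustrative name.

```python
import csv
import os


def resume_point(log_path, prompt_id):
    """Return (next_iteration, next_prompt) for prompt_id, or (0, None) if nothing usable is logged.

    Hypothetical schema: columns prompt_id, iteration, refined_prompt, video_path.
    An iteration counts as finished only if it is logged AND its video exists on disk.
    """
    if not os.path.exists(log_path):
        return 0, None
    with open(log_path, newline="") as f:
        rows = [r for r in csv.DictReader(f) if r["prompt_id"] == prompt_id]
    finished = [r for r in rows if os.path.exists(r["video_path"])]
    if not finished:
        return 0, None
    last = max(finished, key=lambda r: int(r["iteration"]))
    # Continue the chain from the prompt refined at the last finished iteration.
    return int(last["iteration"]) + 1, last["refined_prompt"]
```

The caller would then continue the loop at `next_iteration` using `next_prompt`, or start from the original caption when `(0, None)` is returned.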

**Input Format:**

CSV with columns: `captions` (prompt), `phys_law` (physical rule to enforce)

**Key Configuration (lines 248-264):**

```python
# Checkpoint paths, relative to examples/Stage2_SSPO/
WAN_MODEL_ID = "../../ckpt/Wan2.1-T2V-1.3B-Diffusers"  # T2V model (diffusers format)
INSTRUCT_LLM_PATH = "../../ckpt/Qwen2.5-7B-Instruct"   # LLM for physics analysis and prompt rewriting
QWEN_VL_PATH = "../../ckpt/qwen2.5-vl-7B-instruct"     # VLM for alignment assessment
num_refine_iterations = 5                              # refinement rounds per prompt
```

### Stage 3: LLM Fine-Tuning

No code for this stage is included in the repository. It uses the feedback data collected in Stage 2 to fine-tune model-specific prompt refiners.

## Key Architectural Patterns

### Graph-Based Retrieval (Stage 1)

- **Data Structure**: NetworkX graphs with place nodes as hubs
- **Retrieval**: Cosine similarity between prompt embeddings and place embeddings
- **Augmentation**: Graph neighbors provide contextually relevant modifiers
- **Caching**: Pre-computed embeddings stored in JSON for efficiency
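
The caching pattern amounts to encoding each graph token once and storing the vectors in a JSON file keyed by token. Here is a minimal sketch. It assumes an `encode` callable that maps a list of strings to a list of vectors. With SentenceTransformer, this could wrap `model.encode(texts).tolist()`. `load_or_build_cache` and the file layout are hypothetical, not the repository's actual format.

```python
import json
import os


def load_or_build_cache(tokens, encode, cache_path):
    """Return {token: vector}, encoding only tokens missing from the JSON cache.

    `encode` is an assumed callable: list[str] -> list of vectors.
    """
    cache = {}
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            cache = json.load(f)
    missing = [t for t in dict.fromkeys(tokens) if t not in cache]  # dedupe, keep order
    if missing:
        for token, vec in zip(missing, encode(missing)):
            cache[token] = [float(x) for x in vec]  # plain floats are JSON-serializable
        with open(cache_path, "w") as f:
            json.dump(cache, f)
    return cache
```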

### Closed-Loop Optimization (Stage 2)

- **Multi-Modal Feedback**: Combines optical flow (physics) + VLM (semantics)
- **Iterative Refinement**: Each video informs the next prompt
- **Logging**: CSV tracks the full prompt evolution chain
- **Modularity**: Easy to swap T2V models, reward functions, or VLMs

### Embedding Model Usage

- SentenceTransformer for text similarity (Stage 1)
- Pre-encode and cache all graph tokens to avoid redundant computation

## Common Commands

**Stage 1 - Full Pipeline:**

```bash
cd examples/Stage1_RAPO
# Build graph from scratch
python construct_graph.py
# Run full RAPO pipeline
sh retrieve_modifiers.sh
sh word_augment.sh
sh refactoring.sh
sh rewrite_via_instruction.sh
```

**Stage 2 - SSPO:**

```bash
cd examples/Stage2_SSPO
python phyaware_wan2.1.py
```

## File Dependencies

**Input Files:**

- `data/test_prompts.txt` - One prompt per line for Stage 1
- `examples/Stage2_SSPO/examples.csv` - Prompts + physical rules for Stage 2
- `relation_graph/graph_data/*.json` - Pre-built graph data
- `relation_graph/graph_data/*.graphml` - Graph structure

**Output Structure:**

- `examples/Stage1_RAPO/output/retrieve_words/` - Retrieved modifiers
- `examples/Stage1_RAPO/output/refactor/` - Augmented prompts
- `examples/Stage2_SSPO/results/examples_refined/` - Videos + logs

## Critical Implementation Details

### Stage 1 Graph Construction

- Place tokens serve as central nodes linking verbs and scenes
- Edge weights implicitly represent co-occurrence frequency
- Embedding dimension from SentenceTransformer: 384 (all-MiniLM-L6-v2)
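
A graph like this could be built from rows that use the CSV columns listed in the Stage 1 pipeline (`verb_obj_word`, `scenario_word`, `place`), with co-occurrence counts made explicit. Here is a minimal sketch. It assumes one verb and one scene per row. A dict of `Counter`s stands in for the NetworkX graph. `construct_graph.py` may split multi-word cells or store weights differently, and `build_place_graph` is a hypothetical name.

```python
from collections import Counter, defaultdict


def build_place_graph(rows):
    """Map each place to a Counter of neighboring verbs and scenes.

    Each count is the number of rows in which the place and the neighbor co-occur,
    i.e. an explicit co-occurrence edge weight. Assumes one verb/scene per cell.
    """
    graph = defaultdict(Counter)
    for row in rows:
        place = row["place"].strip()
        for col in ("verb_obj_word", "scenario_word"):
            word = row[col].strip()
            if place and word:
                graph[place][word] += 1
    return graph
```

To obtain a NetworkX graph, each entry can be copied with `G.add_edge(place, word, weight=count)` on an `nx.Graph`. The result can then be saved with `nx.write_graphml`.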

### Stage 2 Physics Analysis

The `evaluate_physical_consistency()` function uses a two-phase LLM prompting strategy:

1. First call: Analyze optical flow for physics violations
2. Second call: Synthesize physics + VLM feedback into a refined prompt

The prompt rewriting instruction explicitly constrains:

- Motion continuity and force consistency
- Object states and timings
- Camera motion if needed
- Output limited to <120 words
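
The two-phase strategy can be sketched as two chat-style message lists sent one after the other. The prompt wording below is invented for illustration and is not copied from `evaluate_physical_consistency()`. `chat` stands for any function that sends messages to Qwen2.5-7B-Instruct (or another LLM) and returns its reply. The final word-limit trim is a hypothetical safeguard. `two_phase_refine` and `format_flow` are illustrative names.

```python
def format_flow(flow):
    """Render a sequence of (x, y) mean-flow vectors as compact text for the LLM."""
    return "; ".join(f"t{i}: ({x:.2f}, {y:.2f})" for i, (x, y) in enumerate(flow))


def two_phase_refine(chat, prompt, phys_law, flow, vlm_feedback, max_words=120):
    """Hypothetical two-call refinement; `chat(messages) -> str` is an assumed LLM wrapper."""
    # Phase 1: ask the LLM whether the motion statistics violate physical rules.
    # (Prompt wording is illustrative, not the script's.)
    physics_report = chat([
        {"role": "system", "content": "You are a physics expert reviewing generated videos."},
        {"role": "user", "content": (
            f"Prompt: {prompt}\nPhysical rule: {phys_law}\n"
            f"Mean optical flow per frame pair: {format_flow(flow)}\n"
            "Identify violations of inertia, momentum or motion continuity.")},
    ])
    # Phase 2: fuse the physics report with the VLM alignment feedback into a new prompt.
    refined = chat([
        {"role": "system", "content": "You rewrite text-to-video prompts."},
        {"role": "user", "content": (
            f"Original prompt: {prompt}\nPhysics analysis: {physics_report}\n"
            f"Alignment feedback: {vlm_feedback}\n"
            "Rewrite the prompt to enforce motion continuity, force consistency, "
            "object states and timings, and camera motion if needed. "
            f"Use fewer than {max_words} words.")},
    ])
    # Guard against the LLM ignoring the word limit: keep at most max_words - 1 words.
    words = refined.split()
    return " ".join(words[:max_words - 1]) if len(words) >= max_words else refined
```

With Hugging Face transformers, `chat` could be built on `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)` followed by `model.generate`.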

### Optical Flow Extraction

- Uses the Farneback algorithm (dense optical flow)
- Samples frames at 0.5-second intervals by default
- Returns the mean (x, y) flow per frame pair
- Sudden reversals or inconsistent magnitudes indicate physics violations
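
Below is a sketch of the extraction step and of one possible reversal check. The OpenCV calls are the standard `cv2.VideoCapture` and `cv2.calcOpticalFlowFarneback` APIs, used with commonly seen parameter values that are not necessarily the script's. The reversal heuristic and its `min_magnitude` threshold are hypothetical, since the text above only says that sudden reversals or inconsistent magnitudes suggest violations. `cv2` is imported inside the extraction function, so the statistics helper works without OpenCV installed.

```python
import math


def extract_mean_flow(video_path, interval_sec=0.5):
    """Return a list of mean (x, y) Farneback flow vectors between sampled frames."""
    import cv2  # OpenCV is only needed for extraction

    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 15.0  # fall back if the container reports 0
    step = max(1, int(round(fps * interval_sec)))
    flows, prev_gray, idx = [], None, 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if idx % step == 0:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            if prev_gray is not None:
                # Common default parameters, not necessarily the script's.
                flow = cv2.calcOpticalFlowFarneback(prev_gray, gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
                flows.append((float(flow[..., 0].mean()), float(flow[..., 1].mean())))
            prev_gray = gray
        idx += 1
    cap.release()
    return flows


def find_reversals(flows, min_magnitude=0.2):
    """Hypothetical heuristic: indices where mean motion flips by more than 90 degrees.

    Pairs where either vector is shorter than min_magnitude are ignored as noise.
    """
    reversals = []
    for i in range(1, len(flows)):
        (x0, y0), (x1, y1) = flows[i - 1], flows[i]
        if math.hypot(x0, y0) < min_magnitude or math.hypot(x1, y1) < min_magnitude:
            continue
        if x0 * x1 + y0 * y1 < 0:  # negative dot product => angle above 90 degrees
            reversals.append(i)
    return reversals
```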

## Model Swapping

**To use a different T2V model in Stage 2:**

1. Update the pipeline loading in the `load_model()` function
2. Adjust the generation parameters (height, width, num_frames)
3. Ensure the model outputs a diffusers-compatible format
4. Update the checkpoint path constants (lines 249-251)
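
One way to confine such a swap to a single place is a small table of per-model settings. The sketch assumes a diffusers release that ships `WanPipeline` and `AutoencoderKLWan`, used as in the diffusers Wan2.1 example. The `T2V_MODELS` table, the `"wan2.1-1.3b"` key and the helper names are hypothetical. Another model would need its own pipeline class and settings.

```python
# Hypothetical registry: one entry per T2V backend (values taken from this guide).
T2V_MODELS = {
    "wan2.1-1.3b": {
        "path": "../../ckpt/Wan2.1-T2V-1.3B-Diffusers",
        "height": 480, "width": 832, "num_frames": 81, "fps": 15,
    },
}


def generation_kwargs(name):
    """Return the per-model generation parameters passed to the pipeline call."""
    cfg = T2V_MODELS[name]
    return {k: cfg[k] for k in ("height", "width", "num_frames")}


def load_t2v_pipeline(name, device="cuda"):
    """Load the diffusers pipeline for `name` (only the Wan2.1 entry is wired up here).

    Assumes a diffusers version that provides WanPipeline and AutoencoderKLWan.
    """
    import torch
    from diffusers import AutoencoderKLWan, WanPipeline

    path = T2V_MODELS[name]["path"]
    vae = AutoencoderKLWan.from_pretrained(path, subfolder="vae", torch_dtype=torch.float32)
    pipe = WanPipeline.from_pretrained(path, vae=vae, torch_dtype=torch.bfloat16)
    return pipe.to(device)
```

A call would then look like `frames = pipe(prompt=prompt, **generation_kwargs("wan2.1-1.3b")).frames[0]`, followed by `export_to_video(frames, "out.mp4", fps=15)` from `diffusers.utils`.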

**To use a different VLM:**

- Replace `Qwen2_5_VLForConditionalGeneration` with an alternative model class
- Adjust the processor and prompt template in `misalignment_assessment()`

**To use a different LLM for refinement:**

- Update `INSTRUCT_LLM_PATH` and ensure the model is compatible with transformers
- Modify the system/user message format if needed

## Troubleshooting

**Graph loading errors:**

- Ensure all JSON files exist in `relation_graph/graph_data/`
- Check that the GraphML files are valid and can be read by NetworkX
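
Here is a quick pre-flight check for these two conditions. It assumes the data lives in `relation_graph/graph_data/` as described above. The check uses only the standard library: `json.load` for the JSON files, and `xml.etree.ElementTree` to confirm that each GraphML file is well-formed XML. That is a weaker test than loading with `networkx.read_graphml`, which remains the definitive check. `check_graph_data` is a hypothetical helper.

```python
import json
import xml.etree.ElementTree as ET
from pathlib import Path


def check_graph_data(data_dir):
    """Return a list of human-readable problems found in the graph data directory."""
    root = Path(data_dir)
    if not root.is_dir():
        return [f"missing directory: {root}"]
    problems = []
    json_files = sorted(root.glob("*.json"))
    graphml_files = sorted(root.glob("*.graphml"))
    if not json_files:
        problems.append("no *.json files found")
    if not graphml_files:
        problems.append("no *.graphml files found")
    for path in json_files:
        try:
            with open(path) as f:
                json.load(f)
        except json.JSONDecodeError as exc:
            problems.append(f"{path.name}: invalid JSON ({exc})")
    for path in graphml_files:
        try:
            ET.parse(path)  # well-formedness only, not NetworkX semantics
        except ET.ParseError as exc:
            problems.append(f"{path.name}: malformed XML ({exc})")
    return problems
```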

**CUDA OOM:**

- Stage 2 loads 3 large models simultaneously (T2V, VLM, LLM)
- Reduce batch size or use smaller models
- Consider offloading models between steps
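
One way to offload between steps is to keep each model on the CPU and move it to the GPU only for its own step. The context manager below is a hypothetical helper. It relies only on the `.to(device)` method that PyTorch modules and diffusers pipelines provide, and it calls `torch.cuda.empty_cache()` after moving a model back. For the T2V pipeline alone, diffusers' built-in `enable_model_cpu_offload()` is an alternative.

```python
from contextlib import contextmanager


@contextmanager
def on_device(model, device="cuda", offload_device="cpu"):
    """Hypothetical helper: move `model` to `device` for the block, then offload it again."""
    model.to(device)
    try:
        yield model
    finally:
        model.to(offload_device)
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # release cached blocks freed by the move
        except ImportError:
            pass
```

Each Stage 2 step would then be wrapped, for example `with on_device(vlm_model): ...` around the alignment assessment. At any moment, only one of the three models occupies GPU memory.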

**Syntax error in `phyaware_wan2.1.py` line 251:**

- The opening quote is missing (and the path contains a doubled slash): `QWEN_VL_PATH = ../../ckpt//qwen2.5-vl-7B-instruct"`
- It should be: `QWEN_VL_PATH = "../../ckpt/qwen2.5-vl-7B-instruct"`

## Paper References

- **RAPO**: "The Devil is in the Prompts: Retrieval-Augmented Prompt Optimization for Text-to-Video Generation" (CVPR 2025)
- **RAPO++**: arXiv:2510.20206
- Project pages and models are available on Hugging Face