# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
RAPO++ is a three-stage framework for text-to-video (T2V) generation prompt optimization. It combines:
- **Stage 1 (RAPO)**: Retrieval-Augmented Prompt Optimization using relation graphs
- **Stage 2 (SSPO)**: Self-Supervised Prompt Optimization with test-time iterative refinement
- **Stage 3**: LLM fine-tuning on collected feedback data
The system is model-agnostic and works with various T2V models (Wan2.1, Open-Sora-Plan, HunyuanVideo, etc.).
## Environment Setup
```bash
# Create and activate environment
conda create -n rapo_plus python=3.10
conda activate rapo_plus
# Install dependencies
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu118
pip install -r requirement.txt
```
## Required Checkpoints
Download and place in `ckpt/` directory:
**Stage 1:**
- `all-MiniLM-L6-v2/` - Sentence transformer for embeddings
- `llama3_1_instruct_lora_rewrite/` - LLM for prompt rewriting
- `Mistral-7B-Instruct-v0.3/` - Alternative instruction-tuned LLM
**Stage 2 (example with Wan2.1):**
- `Wan2.1-T2V-1.3B-Diffusers/` - Base T2V model
- `Qwen2.5-7B-Instruct/` - Instruction-following LLM for prompt refinement
- `Qwen2.5-vl-7B-instruct/` - Vision-language model for video alignment assessment
Also place relation graph data in `relation_graph/graph_data/`.
## Core Workflows
### Stage 1: RAPO (Retrieval-Augmented Prompt Optimization)
**Location:** `examples/Stage1_RAPO/`
**Pipeline:**
1. **Graph Construction** (`construct_graph.py`):
- Reads CSV with columns: `Input`, `verb_obj_word`, `scenario_word`, `place`
- Creates NetworkX graphs linking places to verbs and scenes
- Generates embeddings with SentenceTransformer
- Outputs: JSON dictionaries, GraphML files to `relation_graph/` (see the construction sketch after this list)
2. **Modifier Retrieval** (`retrieve_modifiers.py`):
- Input: Test prompts from `data/test_prompts.txt`
- Encodes prompts and retrieves top-K related places via cosine similarity (sketched after the key parameters below)
- Samples connected verbs/scenes from graph neighbors
- Outputs: `output/retrieve_words/{filename}.txt` and `.csv`
- Run: `sh retrieve_modifiers.sh`
3. **Word Augmentation** (`word_augment.py`):
- Filters retrieved modifiers by similarity threshold
- Merges modifiers interactively
- Run: `sh word_augment.sh`
4. **Sentence Refactoring** (`refactoring.py`):
- Restructures prompts with augmented modifiers
- Run: `sh refactoring.sh`
5. **Instruction-Based Rewriting** (`rewrite_via_instruction.py`):
- Uses LLM to refine prompts with natural language instructions
- Run: `sh rewrite_via_instruction.sh`
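The construction step maps to a standard NetworkX + SentenceTransformer pattern. Below is a minimal sketch under stated assumptions: the input path and output file names are hypothetical, and only the column names come from the CSV schema above.
```python
import json

import networkx as nx
import pandas as pd
from sentence_transformers import SentenceTransformer

# Build a graph with place tokens as hub nodes linking verbs and scenes.
df = pd.read_csv("data/graph_source.csv")  # hypothetical path; columns as documented above
G = nx.Graph()
for _, row in df.iterrows():
    place, verb, scene = row["place"], row["verb_obj_word"], row["scenario_word"]
    for node, kind in [(place, "place"), (verb, "verb"), (scene, "scene")]:
        G.add_node(node, kind=kind)
    # Track co-occurrence counts as edge weights.
    for u, v in [(place, verb), (place, scene)]:
        if G.has_edge(u, v):
            G[u][v]["weight"] += 1
        else:
            G.add_edge(u, v, weight=1)

# Pre-compute and cache 384-dim embeddings for every place token.
model = SentenceTransformer("../../ckpt/all-MiniLM-L6-v2")
places = [n for n, d in G.nodes(data=True) if d["kind"] == "place"]
embeddings = {p: model.encode(p).tolist() for p in places}

nx.write_graphml(G, "relation_graph/graph.graphml")
with open("relation_graph/place_embeddings.json", "w") as f:
    json.dump(embeddings, f)
```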
**Key Parameters:**
- `place_num`: Top-K places to retrieve (default: 3)
- `verb_num`, `topk_num`: Controls verb/scene sampling
- `SIMILARITY_THRESHOLD`: Filters modifiers in word_augment.py
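The retrieval step can then be sketched against this cache. `place_num` and `topk_num` below play the roles described above; the file names and function shape are assumptions, not the script's actual interface.
```python
import json
import random

import networkx as nx
import torch
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("../../ckpt/all-MiniLM-L6-v2")
G = nx.read_graphml("relation_graph/graph.graphml")
place_emb = json.load(open("relation_graph/place_embeddings.json"))
places = list(place_emb)
place_matrix = torch.tensor([place_emb[p] for p in places])

def retrieve_modifiers(prompt: str, place_num: int = 3, topk_num: int = 5):
    # Rank places by cosine similarity between prompt and place embeddings.
    q = model.encode(prompt, convert_to_tensor=True).cpu()
    scores = util.cos_sim(q, place_matrix)[0]
    top_places = [places[i] for i in scores.topk(place_num).indices.tolist()]
    # Sample connected verbs/scenes from each retrieved place's graph neighbors.
    modifiers = []
    for p in top_places:
        neighbors = list(G.neighbors(p))
        modifiers += random.sample(neighbors, min(topk_num, len(neighbors)))
    return top_places, modifiers
```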
### Stage 2: SSPO (Self-Supervised Prompt Optimization)
**Location:** `examples/Stage2_SSPO/`
**Main Script:** `phyaware_wan2.1.py`
**Architecture:**
This script implements a closed-loop iterative optimization pipeline:
1. **Video Generation** (`load_model()`, `generate_single_video()`):
- Uses WanPipeline to generate videos from prompts
- Configurable: height=480, width=832, num_frames=81, fps=15
2. **Optical Flow Analysis** (`extract_optical_flow()`):
- Extracts motion statistics using cv2.calcOpticalFlowFarneback
- Samples frames at configurable intervals
- Returns sequence of (x, y) flow vectors
3. **VLM Alignment Assessment** (`misalignment_assessment()`):
- Uses Qwen2.5-VL to evaluate video-prompt alignment
- Assesses objects, actions, scenes
- Returns textual alignment score (1-5 scale)
4. **Physics Consistency Check + Prompt Refinement** (`evaluate_physical_consistency()`):
- **Phase 1**: LLM analyzes optical flow for physical plausibility (inertia, momentum, etc.)
- **Phase 2**: Fuses physics analysis + VLM alignment feedback
- Rewrites prompt to enforce physical rules and semantic alignment
- Uses Qwen2.5-7B-Instruct
5. **Iterative Loop**:
- Generates video → Analyzes → Refines prompt → Generates again (see the loop sketch after this list)
- Default: 5 refinement iterations per prompt
- Logs to CSV: `results/examples_refined/refined_prompts.csv`
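The loop itself reduces to a few lines. The sketch below reuses the function names documented above, but their signatures are assumptions, so read it as orientation rather than the script's actual control flow:
```python
import csv

def refine(initial_prompt: str, phys_law: str, num_refine_iterations: int = 5) -> str:
    prompt = initial_prompt
    rows = []
    for i in range(num_refine_iterations):
        video_path = generate_single_video(prompt)                       # T2V generation
        flow = extract_optical_flow(video_path)                          # motion statistics
        alignment = misalignment_assessment(video_path, initial_prompt)  # VLM feedback
        prompt = evaluate_physical_consistency(flow, alignment, prompt, phys_law)  # LLM rewrite
        rows.append({"iteration": i, "prompt": prompt})
    with open("results/examples_refined/refined_prompts.csv", "a", newline="") as f:
        csv.DictWriter(f, fieldnames=["iteration", "prompt"]).writerows(rows)
    return prompt
```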
**Resume Capability:**
The script checks existing logs and videos so it can resume from the last completed iteration while maintaining prompt-chain consistency.
**Input Format:**
CSV with columns: `captions` (prompt), `phys_law` (physical rule to enforce)
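A hypothetical two-row example of this format (values are illustrative only):
```csv
captions,phys_law
"A glass falls from a kitchen counter","objects accelerate downward under gravity"
"A ball rolls across a wooden floor","rolling objects decelerate due to friction"
```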
**Key Configuration (lines 248-264):**
```python
WAN_MODEL_ID = "../../ckpt/Wan2.1-T2V-1.3B-Diffusers"
INSTRUCT_LLM_PATH = "../../ckpt/Qwen2.5-7B-Instruct"
QWEN_VL_PATH = "../../ckpt/qwen2.5-vl-7B-instruct"
num_refine_iterations = 5
```
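A minimal generation sketch consistent with these constants and the diffusers `WanPipeline` named above; the guidance scale is an assumption, not a value read from the script:
```python
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video

pipe = WanPipeline.from_pretrained(
    "../../ckpt/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

frames = pipe(
    prompt="A glass falls from a kitchen counter",  # illustrative prompt
    height=480, width=832, num_frames=81,
    guidance_scale=5.0,  # assumption
).frames[0]
export_to_video(frames, "results/sample.mp4", fps=15)
```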
### Stage 3: LLM Fine-Tuning
Code for this stage is not included in the repository; Stage 3 fine-tunes the LLM on the feedback data collected in Stage 2 to produce model-specific prompt refiners.
## Key Architectural Patterns
### Graph-Based Retrieval (Stage 1)
- **Data Structure**: NetworkX graphs with place nodes as hubs
- **Retrieval**: Cosine similarity between prompt embeddings and place embeddings
- **Augmentation**: Graph neighbors provide contextually relevant modifiers
- **Caching**: Pre-computed embeddings stored in JSON for efficiency
### Closed-Loop Optimization (Stage 2)
- **Multi-Modal Feedback**: Combines optical flow (physics) + VLM (semantics)
- **Iterative Refinement**: Each video informs next prompt
- **Logging**: CSV tracks full prompt evolution chain
- **Modularity**: Easy to swap T2V models, reward functions, or VLMs
### Embedding Model Usage
- SentenceTransformer for text similarity (Stage 1)
- Pre-encode and cache all graph tokens to avoid redundant computation
## Common Commands
**Stage 1 - Full Pipeline:**
```bash
cd examples/Stage1_RAPO
# Build graph from scratch
python construct_graph.py
# Run full RAPO pipeline
sh retrieve_modifiers.sh
sh word_augment.sh
sh refactoring.sh
sh rewrite_via_instruction.sh
```
**Stage 2 - SSPO:**
```bash
cd examples/Stage2_SSPO
python phyaware_wan2.1.py
```
## File Dependencies
**Input Files:**
- `data/test_prompts.txt` - One prompt per line for Stage 1
- `examples/Stage2_SSPO/examples.csv` - Prompts + physical rules for Stage 2
- `relation_graph/graph_data/*.json` - Pre-built graph data
- `relation_graph/graph_data/*.graphml` - Graph structure
**Output Structure:**
- `examples/Stage1_RAPO/output/retrieve_words/` - Retrieved modifiers
- `examples/Stage1_RAPO/output/refactor/` - Augmented prompts
- `examples/Stage2_SSPO/results/examples_refined/` - Videos + logs
## Critical Implementation Details
### Stage 1 Graph Construction
- Place tokens serve as central nodes linking verbs and scenes
- Edge weights implicitly represent co-occurrence frequency
- Embedding dimension from SentenceTransformer: 384 (all-MiniLM-L6-v2)
### Stage 2 Physics Analysis
The `evaluate_physical_consistency()` function uses a two-phase LLM prompting strategy:
1. First call: Analyze optical flow for physics violations
2. Second call: Synthesize physics + VLM feedback into a refined prompt (see the sketch below)
The prompt rewriting instruction explicitly constrains:
- Motion continuity and force consistency
- Object states and timings
- Camera motion if needed
- Output limited to <120 words
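A hypothetical sketch of the two-phase pattern using the standard transformers chat-template API; the prompt wording, placeholder inputs, and `chat` helper are illustrative, not the script's actual prompts:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("../../ckpt/Qwen2.5-7B-Instruct")
llm = AutoModelForCausalLM.from_pretrained(
    "../../ckpt/Qwen2.5-7B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
)

def chat(messages: list[dict]) -> str:
    ids = tok.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(llm.device)
    out = llm.generate(ids, max_new_tokens=512)
    return tok.decode(out[0, ids.shape[1]:], skip_special_tokens=True)

# Placeholder inputs standing in for Stage 2 outputs.
flow = [(0.1, -0.4), (0.2, -0.9)]
phys_law = "objects accelerate downward under gravity"
prompt = "A glass falls from a kitchen counter"
alignment = "Score 3/5: glass present, fall motion partially missing"

# Phase 1: analyze optical flow statistics for physics violations.
physics_report = chat([
    {"role": "system", "content": "You analyze per-frame optical flow for physical plausibility."},
    {"role": "user", "content": f"Flow vectors: {flow}. Physical rule: {phys_law}. List violations."},
])

# Phase 2: fuse the physics report and VLM feedback into a refined prompt.
refined_prompt = chat([
    {"role": "system", "content": "Rewrite the prompt to enforce physics and alignment. Under 120 words."},
    {"role": "user", "content": f"Prompt: {prompt}\nPhysics: {physics_report}\nAlignment: {alignment}"},
])
```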
### Optical Flow Extraction
- Uses Farneback algorithm (dense optical flow)
- Samples frames at 0.5-second intervals by default
- Returns mean (x, y) flow per frame pair
- Sudden reversals or inconsistent magnitudes can indicate physics violations
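These details fit a short OpenCV routine. A minimal sketch follows; the Farneback parameter values are common defaults, assumed rather than taken from the script:
```python
import cv2

def extract_optical_flow(video_path: str, interval_sec: float = 0.5):
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 15.0
    step = max(1, round(fps * interval_sec))  # sample roughly every 0.5 s
    flows, prev_gray, idx = [], None, 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if idx % step == 0:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            if prev_gray is not None:
                # Dense Farneback flow: an HxWx2 field of (dx, dy) vectors.
                flow = cv2.calcOpticalFlowFarneback(
                    prev_gray, gray, None, 0.5, 3, 15, 3, 5, 1.2, 0
                )
                flows.append(tuple(flow.mean(axis=(0, 1))))  # mean (x, y) per frame pair
            prev_gray = gray
        idx += 1
    cap.release()
    return flows
```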
## Model Swapping
**To use a different T2V model in Stage 2:**
1. Update pipeline loading in `load_model()` function
2. Adjust generation parameters (height, width, num_frames)
3. Ensure model outputs diffusers-compatible format
4. Update checkpoint path constants (lines 249-251)
**To use a different VLM:**
- Replace `Qwen2_5_VLForConditionalGeneration` with alternative
- Adjust processor and prompt template in `misalignment_assessment()`
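For reference, the stock Qwen2.5-VL usage pattern (per its model card; requires the `qwen-vl-utils` package) looks like the sketch below, with an illustrative assessment prompt. A replacement VLM would swap these classes and this template:
```python
import torch
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "../../ckpt/qwen2.5-vl-7B-instruct", torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained("../../ckpt/qwen2.5-vl-7B-instruct")

messages = [{
    "role": "user",
    "content": [
        {"type": "video", "video": "results/sample.mp4"},
        {"type": "text", "text": "Rate how well this video matches the prompt on a 1-5 scale."},
    ],
}]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text], images=image_inputs, videos=video_inputs,
    padding=True, return_tensors="pt",
).to(model.device)
out = model.generate(**inputs, max_new_tokens=128)
reply = processor.batch_decode(out[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
```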
**To use a different LLM for refinement:**
- Update `INSTRUCT_LLM_PATH` and ensure transformers compatibility
- Modify system/user message format if needed
## Troubleshooting
**Graph loading errors:**
- Ensure all JSON files exist in `relation_graph/graph_data/`
- Check GraphML files are valid NetworkX format
**CUDA OOM:**
- Stage 2 loads 3 large models simultaneously (T2V, VLM, LLM)
- Reduce batch size or use smaller models
- Consider offloading models between steps
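For the last point, diffusers pipelines support model-level CPU offload out of the box (whether the script exposes a flag for this is not documented here):
```python
import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained(
    "../../ckpt/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
)
# Keep submodules on CPU and move each to GPU only while it runs,
# trading generation speed for VRAM headroom.
pipe.enable_model_cpu_offload()
```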
**Syntax error in phyaware_wan2.1.py line 251:**
- Missing opening quote: `QWEN_VL_PATH = ../../ckpt//qwen2.5-vl-7B-instruct"`
- Should be: `QWEN_VL_PATH = "../../ckpt/qwen2.5-vl-7B-instruct"`
## Paper References
- **RAPO**: "The Devil is in the Prompts: Retrieval-Augmented Prompt Optimization for Text-to-Video Generation" (CVPR 2025)
- **RAPO++**: arXiv:2510.20206
- Project pages and models available on HuggingFace