multimodalart HF Staff committed on
Commit
15f0443
·
verified ·
1 Parent(s): c67e143

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -18,7 +18,8 @@ import gc
18
  from gradio_client import Client, handle_file # Import for API call
19
 
20
  # Import the optimization function from the separate file
21
- from optimization import optimize_pipeline_
 
22
 
23
  # --- Constants and Model Loading ---
24
  MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
@@ -67,13 +68,13 @@ for i in range(3):
67
  torch.cuda.synchronize()
68
  torch.cuda.empty_cache()
69
 
70
- optimize_pipeline_(pipe,
71
- image=Image.new('RGB', (MAX_DIMENSION, MIN_DIMENSION)),
72
- prompt='prompt',
73
- height=MIN_DIMENSION,
74
- width=MAX_DIMENSION,
75
- num_frames=MAX_FRAMES_MODEL,
76
- )
77
  print("All models loaded and optimized. Gradio app is ready.")
78
 
79
 
 
18
  from gradio_client import Client, handle_file # Import for API call
19
 
20
  # Import the optimization function from the separate file
21
+ from torchao.quantization import quantize_
22
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
23
 
24
  # --- Constants and Model Loading ---
25
  MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
 
68
  torch.cuda.synchronize()
69
  torch.cuda.empty_cache()
70
 
71
+ quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
72
+ quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
73
+ quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
74
+
75
+ spaces.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
76
+ spaces.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
77
+
78
  print("All models loaded and optimized. Gradio app is ready.")
79
 
80