Spaces:
Sleeping
Sleeping
remove to(device) calls in load_model
Browse files- inference.py +5 -5
inference.py
CHANGED
|
@@ -161,10 +161,6 @@ def load_model(args):
|
|
| 161 |
# as these weights are only used for inference, keeping weights in full precision is not required.
|
| 162 |
weight_dtype = torch.bfloat16
|
| 163 |
|
| 164 |
-
text_encoder.to(args.device, dtype=weight_dtype)
|
| 165 |
-
transformer.to(args.device, dtype=weight_dtype)
|
| 166 |
-
vae.to(args.device, dtype=weight_dtype)
|
| 167 |
-
|
| 168 |
pipe = ControlnetCogVideoXPipeline.from_pretrained(
|
| 169 |
args.pretrained_model_path,
|
| 170 |
tokenizer=tokenizer,
|
|
@@ -186,7 +182,6 @@ def load_model(args):
|
|
| 186 |
scheduler_args["variance_type"] = variance_type
|
| 187 |
|
| 188 |
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
|
| 189 |
-
pipe = pipe.to(args.device)
|
| 190 |
|
| 191 |
return pipe, model_config
|
| 192 |
|
|
@@ -247,6 +242,11 @@ def main(args):
|
|
| 247 |
|
| 248 |
pipe, model_config = load_model(args)
|
| 249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
for image_path in image_paths:
|
| 251 |
image = Image.open(image_path)
|
| 252 |
|
|
|
|
| 161 |
# as these weights are only used for inference, keeping weights in full precision is not required.
|
| 162 |
weight_dtype = torch.bfloat16
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
pipe = ControlnetCogVideoXPipeline.from_pretrained(
|
| 165 |
args.pretrained_model_path,
|
| 166 |
tokenizer=tokenizer,
|
|
|
|
| 182 |
scheduler_args["variance_type"] = variance_type
|
| 183 |
|
| 184 |
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
|
|
|
|
| 185 |
|
| 186 |
return pipe, model_config
|
| 187 |
|
|
|
|
| 242 |
|
| 243 |
pipe, model_config = load_model(args)
|
| 244 |
|
| 245 |
+
# text_encoder.to(args.device, dtype=weight_dtype)
|
| 246 |
+
# transformer.to(args.device, dtype=weight_dtype)
|
| 247 |
+
# vae.to(args.device, dtype=weight_dtype)
|
| 248 |
+
pipe = pipe.to(args.device)
|
| 249 |
+
|
| 250 |
for image_path in image_paths:
|
| 251 |
image = Image.open(image_path)
|
| 252 |
|