ftaubner commited on
Commit
c1f2d8e
·
1 Parent(s): f6c8dde

remove to(device) calls in load_model

Browse files
Files changed (1) hide show
  1. inference.py +5 -5
inference.py CHANGED
@@ -161,10 +161,6 @@ def load_model(args):
161
  # as these weights are only used for inference, keeping weights in full precision is not required.
162
  weight_dtype = torch.bfloat16
163
 
164
- text_encoder.to(args.device, dtype=weight_dtype)
165
- transformer.to(args.device, dtype=weight_dtype)
166
- vae.to(args.device, dtype=weight_dtype)
167
-
168
  pipe = ControlnetCogVideoXPipeline.from_pretrained(
169
  args.pretrained_model_path,
170
  tokenizer=tokenizer,
@@ -186,7 +182,6 @@ def load_model(args):
186
  scheduler_args["variance_type"] = variance_type
187
 
188
  pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
189
- pipe = pipe.to(args.device)
190
 
191
  return pipe, model_config
192
 
@@ -247,6 +242,11 @@ def main(args):
247
 
248
  pipe, model_config = load_model(args)
249
 
 
 
 
 
 
250
  for image_path in image_paths:
251
  image = Image.open(image_path)
252
 
 
161
  # as these weights are only used for inference, keeping weights in full precision is not required.
162
  weight_dtype = torch.bfloat16
163
 
 
 
 
 
164
  pipe = ControlnetCogVideoXPipeline.from_pretrained(
165
  args.pretrained_model_path,
166
  tokenizer=tokenizer,
 
182
  scheduler_args["variance_type"] = variance_type
183
 
184
  pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
 
185
 
186
  return pipe, model_config
187
 
 
242
 
243
  pipe, model_config = load_model(args)
244
 
245
+ # text_encoder.to(args.device, dtype=weight_dtype)
246
+ # transformer.to(args.device, dtype=weight_dtype)
247
+ # vae.to(args.device, dtype=weight_dtype)
248
+ pipe = pipe.to(args.device)
249
+
250
  for image_path in image_paths:
251
  image = Image.open(image_path)
252