ovi054 committed on
Commit
cf3fdd8
·
verified ·
1 Parent(s): 16c8236

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -34
app.py CHANGED
@@ -58,40 +58,44 @@ edit_pipe = LongCatImageEditPipeline.from_pretrained(
58
  edit_pipe.to(device, torch.bfloat16)
59
 
60
  print(f"✅ Image Edit model loaded successfully on {device}")
61
-
62
- # # --- Core Functions ---
63
- # @spaces.GPU(duration=120)
64
- # def generate_image(
65
- # prompt: str,
66
- # width: int,
67
- # height: int,
68
- # seed: int,
69
- # progress=gr.Progress()
70
- # ):
71
- # """Generate image from text prompt"""
72
- # if not prompt or prompt.strip() == "":
73
- # raise gr.Error("Please enter a prompt")
74
- # try:
75
- # progress(0.1, desc="Preparing generation...")
76
- # progress(0.2, desc="Generating image...")
77
- # generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(seed)
78
- # with torch.inference_mode():
79
- # output = t2i_pipe(
80
- # prompt,
81
- # negative_prompt="",
82
- # height=height,
83
- # width=width,
84
- # guidance_scale=4.5,
85
- # num_inference_steps=50,
86
- # num_images_per_prompt=1,
87
- # generator=generator,
88
- # enable_cfg_renorm=True,
89
- # enable_prompt_rewrite=True
90
- # )
91
- # progress(1.0, desc="Done!")
92
- # return output.images[0]
93
- # except Exception as e:
94
- # raise gr.Error(f"Error during image generation: {str(e)}")
 
 
 
 
95
 
96
  @spaces.GPU(duration=120)
97
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=4, num_inference_steps=28, lora_id=None, lora_scale=0.95, progress=gr.Progress(track_tqdm=True)):
 
58
  edit_pipe.to(device, torch.bfloat16)
59
 
60
  print(f"✅ Image Edit model loaded successfully on {device}")
61
def load_lora_auto(pipe, lora_input):
    """Load LoRA weights into *pipe* from a Hub ID or URL.

    Accepted forms:
      * "author/model"                         -> loaded directly from the Hub
      * https://huggingface.co/<repo> page URL -> repo id extracted from path
      * .../blob/... link                      -> rewritten to /resolve/ and downloaded
      * any other http(s) direct-file URL      -> downloaded to a temp dir and loaded

    Blank or whitespace-only input is a silent no-op. The temporary
    download directory is always removed, even if loading fails.
    """
    lora_input = lora_input.strip()
    if not lora_input:
        return

    # Bare repo ID like "author/model" — hand it straight to diffusers.
    if "/" in lora_input and not lora_input.startswith("http"):
        pipe.load_lora_weights(lora_input)
        return

    if lora_input.startswith("http"):
        url = lora_input

        # A repo *page* (no /blob/ or /resolve/): treat the URL path as a repo id.
        if "huggingface.co" in url and "/blob/" not in url and "/resolve/" not in url:
            repo_id = urlparse(url).path.strip("/")
            pipe.load_lora_weights(repo_id)
            return

        # Blob (HTML file view) link -> raw-file resolve link.
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")

        # Direct file URL: stream-download into a temp dir, then load.
        tmp_dir = tempfile.mkdtemp()
        local_path = os.path.join(tmp_dir, os.path.basename(urlparse(url).path))

        try:
            print(f"Downloading LoRA from {url}...")
            # timeout guards against a hung connection stalling the Space forever
            resp = requests.get(url, stream=True, timeout=60)
            resp.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
            print(f"Saved LoRA to {local_path}")
            pipe.load_lora_weights(local_path)
        finally:
            # Weights are already loaded into the pipeline; the file is disposable.
            shutil.rmtree(tmp_dir, ignore_errors=True)
99
 
100
  @spaces.GPU(duration=120)
101
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=4, num_inference_steps=28, lora_id=None, lora_scale=0.95, progress=gr.Progress(track_tqdm=True)):