MutantSparrow commited on
Commit
c192178
·
verified ·
1 Parent(s): efc6236

Upload app.py

Browse files

updated the app

Files changed (1) hide show
  1. app.py +30 -18
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  import gradio as gr
3
  import spaces
@@ -5,12 +6,15 @@ import spaces
5
  from huggingface_hub import hf_hub_download
6
  from safetensors.torch import load_file
7
 
8
- from diffusers import ZImagePipeline, ZImageTransformer2DModel # Z-Image specific
9
 
10
  BASE_ID = "Tongyi-MAI/Z-Image-Turbo"
11
  CUSTOM_REPO = "MutantSparrow/Ray"
12
  CUSTOM_FILE = "Z-IMAGE-TURBO/Rayzist.v1.0.safetensors"
13
 
 
 
 
14
  pipe = None
15
 
16
  def load_pipe():
@@ -18,7 +22,6 @@ def load_pipe():
18
  if pipe is not None:
19
  return pipe
20
 
21
- # Load base components like the official demo
22
  transformer = ZImageTransformer2DModel.from_pretrained(
23
  BASE_ID,
24
  subfolder="transformer",
@@ -31,7 +34,6 @@ def load_pipe():
31
  torch_dtype=torch.bfloat16,
32
  ).to("cuda")
33
 
34
- # Now load your custom denoiser weights into the transformer
35
  ckpt_path = hf_hub_download(CUSTOM_REPO, CUSTOM_FILE)
36
  state = load_file(ckpt_path)
37
 
@@ -44,31 +46,41 @@ def load_pipe():
44
  return pipe
45
 
46
  @spaces.GPU
47
- def generate(prompt, steps=9, height=1024, width=1024, seed=0):
48
  p = load_pipe()
49
- g = torch.Generator("cuda").manual_seed(int(seed))
 
 
 
50
 
51
  img = p(
52
  prompt=prompt,
53
  height=int(height),
54
  width=int(width),
55
- num_inference_steps=int(steps),
56
- guidance_scale=0.0, # turbo-style in the official demo
57
  generator=g,
58
  ).images[0]
59
- return img
 
60
 
61
  with gr.Blocks() as demo:
62
- gr.Markdown("RAYZIST! A Z-Image Turbo Finetune")
63
- prompt = gr.Textbox(label="Prompt", lines=5)
64
- steps = gr.Slider(1, 12, value=8, step=1, label="Steps")
65
- width = gr.Dropdown([512, 768, 1024, 1280], value=1024, label="Width")
66
- height = gr.Dropdown([512, 768, 1024, 1280], value=1024, label="Height")
67
- seed = gr.Number(value=0, label="Seed")
68
- out = gr.Image(label="Result")
69
-
70
- btn = gr.Button("GO >")
71
- btn.click(generate, [prompt, steps, height, width, seed], out)
 
 
 
 
 
 
72
 
73
  demo.queue()
74
  demo.launch()
 
1
+ import random
2
  import torch
3
  import gradio as gr
4
  import spaces
 
6
  from huggingface_hub import hf_hub_download
7
  from safetensors.torch import load_file
8
 
9
+ from diffusers import ZImagePipeline, ZImageTransformer2DModel
10
 
11
  BASE_ID = "Tongyi-MAI/Z-Image-Turbo"
12
  CUSTOM_REPO = "MutantSparrow/Ray"
13
  CUSTOM_FILE = "Z-IMAGE-TURBO/Rayzist.v1.0.safetensors"
14
 
15
+ FIXED_STEPS = 8
16
+ GUIDANCE = 1.0
17
+
18
  pipe = None
19
 
20
  def load_pipe():
 
22
  if pipe is not None:
23
  return pipe
24
 
 
25
  transformer = ZImageTransformer2DModel.from_pretrained(
26
  BASE_ID,
27
  subfolder="transformer",
 
34
  torch_dtype=torch.bfloat16,
35
  ).to("cuda")
36
 
 
37
  ckpt_path = hf_hub_download(CUSTOM_REPO, CUSTOM_FILE)
38
  state = load_file(ckpt_path)
39
 
 
46
  return pipe
47
 
48
  @spaces.GPU
49
+ def generate(prompt, height, width):
50
  p = load_pipe()
51
+
52
+ # Random seed every run
53
+ seed = random.randint(0, 2**31 - 1)
54
+ g = torch.Generator("cuda").manual_seed(seed)
55
 
56
  img = p(
57
  prompt=prompt,
58
  height=int(height),
59
  width=int(width),
60
+ num_inference_steps=FIXED_STEPS,
61
+ guidance_scale=GUIDANCE,
62
  generator=g,
63
  ).images[0]
64
+
65
+ return img, seed
66
 
67
  with gr.Blocks() as demo:
68
+ gr.Markdown("Ray's Z-Image Turbo finetune: RAYZIST!")
69
+
70
+ prompt = gr.Textbox(label="Prompt", lines=3)
71
+ width = gr.Dropdown([512, 768, 1024, 1280, 1344], value=1024, label="Width")
72
+ height = gr.Dropdown([512, 768, 1024, 1280, 1344], value=1024, label="Height")
73
+
74
+ # Button ABOVE output
75
+ btn = gr.Button("GO>")
76
+ out = gr.Image(label="Your image")
77
+ seed_info = gr.Markdown()
78
+
79
+ def _run(prompt, height, width):
80
+ img, seed = generate(prompt, height, width)
81
+ return img, f"Seed: `{seed}`"
82
+
83
+ btn.click(_run, [prompt, height, width], [out, seed_info])
84
 
85
  demo.queue()
86
  demo.launch()