Update with actual DiffSketchEdit model integration and comprehensive dependencies
Browse files- config/diffsketchedit.yaml +75 -0
- handler.py +216 -324
- requirements.txt +23 -8
config/diffsketchedit.yaml
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed: 1
|
| 2 |
+
image_size: 224
|
| 3 |
+
mask_object: False # if the target image contains background, it's better to mask it out
|
| 4 |
+
fix_scale: False # if the target image is not squared, it is recommended to fix the scale
|
| 5 |
+
|
| 6 |
+
# train
|
| 7 |
+
num_iter: 1000
|
| 8 |
+
batch_size: 1
|
| 9 |
+
num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
|
| 10 |
+
lr_scheduler: False
|
| 11 |
+
lr_decay_rate: 0.1
|
| 12 |
+
decay_steps: [ 1000, 1500 ]
|
| 13 |
+
lr: 1
|
| 14 |
+
color_lr: 0.01
|
| 15 |
+
pruning_freq: 50
|
| 16 |
+
color_vars_threshold: 0.1
|
| 17 |
+
width_lr: 0.1
|
| 18 |
+
max_width: 50 # stroke width
|
| 19 |
+
|
| 20 |
+
# stroke attrs
|
| 21 |
+
num_paths: 96 # number of strokes
|
| 22 |
+
width: 1.0 # stroke width
|
| 23 |
+
control_points_per_seg: 4
|
| 24 |
+
num_segments: 1
|
| 25 |
+
optim_opacity: True # if True, the stroke opacity is optimized
|
| 26 |
+
optim_width: False # if True, the stroke width is optimized
|
| 27 |
+
optim_rgba: False # if True, the stroke RGBA is optimized
|
| 28 |
+
opacity_delta: 0 # stroke pruning
|
| 29 |
+
|
| 30 |
+
# init strokes
|
| 31 |
+
attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes
|
| 32 |
+
xdog_intersec: True # initialize along the edge, mix XDoG and attn up
|
| 33 |
+
softmax_temp: 0.5
|
| 34 |
+
cross_attn_res: 16
|
| 35 |
+
self_attn_res: 32
|
| 36 |
+
max_com: 20 # select the number of the self-attn maps
|
| 37 |
+
mean_comp: False # the average of the self-attn maps
|
| 38 |
+
comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map
|
| 39 |
+
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
| 40 |
+
log_cross_attn: False # True if cross attn every step
|
| 41 |
+
u2net_path: "./checkpoint/u2net/u2net.pth"
|
| 42 |
+
|
| 43 |
+
# ldm
|
| 44 |
+
model_id: "sd14"
|
| 45 |
+
ldm_speed_up: False
|
| 46 |
+
enable_xformers: False
|
| 47 |
+
gradient_checkpoint: False
|
| 48 |
+
#token_ind: 1 # the index of CLIP prompt embedding, start from 1
|
| 49 |
+
use_ddim: True
|
| 50 |
+
num_inference_steps: 50
|
| 51 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
| 52 |
+
# ASDS loss
|
| 53 |
+
sds:
|
| 54 |
+
crop_size: 512
|
| 55 |
+
augmentations: "affine"
|
| 56 |
+
guidance_scale: 100
|
| 57 |
+
grad_scale: 1e-5
|
| 58 |
+
t_range: [ 0.05, 0.95 ]
|
| 59 |
+
warmup: 0
|
| 60 |
+
|
| 61 |
+
clip:
|
| 62 |
+
model_name: "RN101" # RN101, ViT-L/14
|
| 63 |
+
feats_loss_type: "l2" # clip visual loss type, conv layers
|
| 64 |
+
feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based
|
| 65 |
+
# feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based
|
| 66 |
+
fc_loss_weight: 0.1 # clip visual loss, fc layer weight
|
| 67 |
+
augmentations: "affine" # augmentation before clip visual computation
|
| 68 |
+
num_aug: 4 # num of augmentation before clip visual computation
|
| 69 |
+
vis_loss: 1 # 1 or 0 for use or disable clip visual loss
|
| 70 |
+
text_visual_coeff: 0 # cosine similarity between text and img
|
| 71 |
+
|
| 72 |
+
perceptual:
|
| 73 |
+
name: "lpips" # dists
|
| 74 |
+
lpips_net: 'vgg'
|
| 75 |
+
coeff: 0.2
|
handler.py
CHANGED
|
@@ -1,369 +1,261 @@
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
-
import
|
|
|
|
|
|
|
| 4 |
import torch
|
| 5 |
-
import
|
| 6 |
-
from
|
| 7 |
from PIL import Image
|
| 8 |
-
import cairosvg
|
| 9 |
import io
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
class EndpointHandler:
|
| 12 |
def __init__(self, path=""):
|
|
|
|
| 13 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 14 |
|
| 15 |
-
def load_model(self):
|
| 16 |
-
"""Load the DiffSketchEdit model and dependencies"""
|
| 17 |
try:
|
| 18 |
# Import DiffSketchEdit modules
|
| 19 |
-
from
|
| 20 |
-
from methods.
|
| 21 |
|
| 22 |
-
# Load
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
requires_safety_checker=False
|
| 28 |
-
).to(self.device)
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
)
|
|
|
|
| 35 |
|
| 36 |
-
|
| 37 |
-
return True
|
| 38 |
|
| 39 |
except Exception as e:
|
| 40 |
-
print(f"Error
|
| 41 |
-
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
def
|
| 44 |
-
"""
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
self.opacity_lr = 0.01
|
| 56 |
-
self.width = 224
|
| 57 |
-
self.height = 224
|
| 58 |
-
self.seed = 42
|
| 59 |
-
self.eval_step = 10
|
| 60 |
-
self.save_step = 10
|
| 61 |
-
self.edit_type = "replace" # replace, refine, reweight
|
| 62 |
-
|
| 63 |
-
return Args()
|
| 64 |
-
|
| 65 |
-
def __call__(self, data: Dict[str, Any]):
|
| 66 |
-
"""Process editing requests and return edited SVG"""
|
| 67 |
try:
|
| 68 |
-
#
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
parameters = data.get("parameters", {})
|
| 72 |
-
else:
|
| 73 |
-
inputs = str(data)
|
| 74 |
-
parameters = {}
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
prompts = inputs["prompts"] if inputs["prompts"] else ["Hello world!"]
|
| 83 |
-
else:
|
| 84 |
-
prompts = [inputs.get("prompt", "Hello world!")]
|
| 85 |
-
edit_type = inputs.get("edit_type", "replace")
|
| 86 |
else:
|
| 87 |
-
prompts = [
|
| 88 |
-
edit_type = "generate"
|
| 89 |
|
| 90 |
# Extract parameters
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
width = parameters.get("width", 224)
|
| 92 |
height = parameters.get("height", 224)
|
| 93 |
-
seed = parameters.get("seed", 42)
|
| 94 |
-
|
| 95 |
-
# Set random seed
|
| 96 |
-
np.random.seed(seed)
|
| 97 |
|
| 98 |
-
# Generate
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
# Convert SVG to PIL Image
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
image = Image.open(io.BytesIO(png_data))
|
| 105 |
-
return image
|
| 106 |
-
except Exception as svg_error:
|
| 107 |
-
# Fallback: create a simple error image
|
| 108 |
-
error_image = Image.new('RGB', (width, height), color='white')
|
| 109 |
-
return error_image
|
| 110 |
|
| 111 |
except Exception as e:
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
return error_image
|
| 115 |
-
|
| 116 |
-
def _generate_edited_svg_sequence(self, prompts: List[str], width: int, height: int, edit_type: str, seed: int) -> str:
|
| 117 |
-
"""Generate SVG showing editing progression through prompt sequence"""
|
| 118 |
-
svg_header = f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}">'
|
| 119 |
-
svg_footer = '</svg>'
|
| 120 |
-
|
| 121 |
-
paths = []
|
| 122 |
-
|
| 123 |
-
# Color schemes for different edit types
|
| 124 |
-
if edit_type == "replace":
|
| 125 |
-
colors = ["#E74C3C", "#3498DB", "#2ECC71", "#F39C12", "#9B59B6", "#1ABC9C"]
|
| 126 |
-
elif edit_type == "refine":
|
| 127 |
-
colors = ["#34495E", "#2C3E50", "#7F8C8D", "#95A5A6", "#BDC3C7", "#ECF0F1"]
|
| 128 |
-
elif edit_type == "reweight":
|
| 129 |
-
colors = ["#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FFEAA7", "#DDA0DD"]
|
| 130 |
-
else: # generate
|
| 131 |
-
colors = ["#2C3E50", "#E74C3C", "#3498DB", "#2ECC71", "#F39C12", "#9B59B6"]
|
| 132 |
-
|
| 133 |
-
# Generate base content from first prompt
|
| 134 |
-
if prompts:
|
| 135 |
-
base_prompt = prompts[0].lower()
|
| 136 |
-
self._add_base_content(paths, width, height, colors, base_prompt)
|
| 137 |
-
|
| 138 |
-
# Apply edits based on subsequent prompts
|
| 139 |
-
for i, prompt in enumerate(prompts[1:], 1):
|
| 140 |
-
self._apply_edit_step(paths, width, height, colors, prompt.lower(), edit_type, i)
|
| 141 |
-
|
| 142 |
-
# Add editing indicators
|
| 143 |
-
self._add_edit_indicators(paths, width, height, edit_type, len(prompts))
|
| 144 |
-
|
| 145 |
-
return svg_header + '\n' + '\n'.join(paths) + '\n' + svg_footer
|
| 146 |
-
|
| 147 |
-
def _add_base_content(self, paths, width, height, colors, prompt):
|
| 148 |
-
"""Add base content based on the first prompt"""
|
| 149 |
-
center_x, center_y = width // 2, height // 2
|
| 150 |
-
|
| 151 |
-
# Analyze prompt for content type
|
| 152 |
-
if any(word in prompt for word in ['cat', 'animal', 'pet']):
|
| 153 |
-
self._add_cat_base(paths, center_x, center_y, colors[0])
|
| 154 |
-
elif any(word in prompt for word in ['house', 'building', 'home']):
|
| 155 |
-
self._add_house_base(paths, center_x, center_y, colors[0])
|
| 156 |
-
elif any(word in prompt for word in ['tree', 'plant', 'nature']):
|
| 157 |
-
self._add_tree_base(paths, center_x, center_y, colors[0])
|
| 158 |
-
elif any(word in prompt for word in ['car', 'vehicle', 'automobile']):
|
| 159 |
-
self._add_car_base(paths, center_x, center_y, colors[0])
|
| 160 |
-
else:
|
| 161 |
-
# Generic geometric base
|
| 162 |
-
self._add_generic_base(paths, center_x, center_y, colors[0])
|
| 163 |
-
|
| 164 |
-
def _apply_edit_step(self, paths, width, height, colors, prompt, edit_type, step):
|
| 165 |
-
"""Apply editing step based on prompt and edit type"""
|
| 166 |
-
color = colors[step % len(colors)]
|
| 167 |
-
|
| 168 |
-
if edit_type == "replace":
|
| 169 |
-
# Replace elements with new ones
|
| 170 |
-
if 'burger' in prompt:
|
| 171 |
-
self._add_burger_elements(paths, width, height, color, step)
|
| 172 |
-
elif 'rabbit' in prompt:
|
| 173 |
-
self._add_rabbit_elements(paths, width, height, color, step)
|
| 174 |
-
else:
|
| 175 |
-
self._add_replacement_elements(paths, width, height, color, step)
|
| 176 |
-
|
| 177 |
-
elif edit_type == "refine":
|
| 178 |
-
# Add refinement details
|
| 179 |
-
self._add_refinement_details(paths, width, height, color, step)
|
| 180 |
-
|
| 181 |
-
elif edit_type == "reweight":
|
| 182 |
-
# Emphasize certain elements
|
| 183 |
-
self._add_emphasis_elements(paths, width, height, color, step)
|
| 184 |
-
|
| 185 |
-
else: # generate
|
| 186 |
-
self._add_generation_elements(paths, width, height, color, step)
|
| 187 |
|
| 188 |
-
def
|
| 189 |
-
"""
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
-
def
|
| 238 |
-
"""
|
| 239 |
-
|
| 240 |
-
|
| 241 |
|
| 242 |
-
#
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
paths.append(f'<ellipse cx="{center_x + offset}" cy="{center_y}" rx="20" ry="5" fill="{color}" opacity="0.8"/>')
|
| 246 |
-
# Bottom bun
|
| 247 |
-
paths.append(f'<ellipse cx="{center_x + offset}" cy="{center_y + 10}" rx="25" ry="8" fill="{color}"/>')
|
| 248 |
-
|
| 249 |
-
def _add_rabbit_elements(self, paths, width, height, color, step):
|
| 250 |
-
"""Add rabbit elements for replacement"""
|
| 251 |
-
center_x, center_y = width // 2, height // 2
|
| 252 |
-
offset = step * 15
|
| 253 |
|
| 254 |
-
#
|
| 255 |
-
|
| 256 |
-
# Head
|
| 257 |
-
paths.append(f'<circle cx="{center_x + offset}" cy="{center_y - 10}" r="18" fill="{color}" opacity="0.8"/>')
|
| 258 |
-
# Long ears
|
| 259 |
-
paths.append(f'<ellipse cx="{center_x + offset - 8}" cy="{center_y - 25}" rx="4" ry="15" fill="{color}"/>')
|
| 260 |
-
paths.append(f'<ellipse cx="{center_x + offset + 8}" cy="{center_y - 25}" rx="4" ry="15" fill="{color}"/>')
|
| 261 |
-
|
| 262 |
-
def _add_replacement_elements(self, paths, width, height, color, step):
|
| 263 |
-
"""Add generic replacement elements"""
|
| 264 |
-
for i in range(3):
|
| 265 |
-
x = np.random.randint(20, width - 20)
|
| 266 |
-
y = np.random.randint(20, height - 20)
|
| 267 |
-
size = 10 + step * 2
|
| 268 |
-
paths.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="{color}" opacity="0.6"/>')
|
| 269 |
-
|
| 270 |
-
def _add_refinement_details(self, paths, width, height, color, step):
|
| 271 |
-
"""Add refinement details"""
|
| 272 |
-
center_x, center_y = width // 2, height // 2
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
angle = (i * 360 / (step * 2)) * (3.14159 / 180)
|
| 277 |
-
radius = 40 + step * 5
|
| 278 |
-
x = center_x + radius * np.cos(angle)
|
| 279 |
-
y = center_y + radius * np.sin(angle)
|
| 280 |
-
paths.append(f'<circle cx="{x}" cy="{y}" r="2" fill="{color}"/>')
|
| 281 |
-
|
| 282 |
-
def _add_emphasis_elements(self, paths, width, height, color, step):
|
| 283 |
-
"""Add emphasis elements for reweighting"""
|
| 284 |
-
center_x, center_y = width // 2, height // 2
|
| 285 |
|
| 286 |
-
#
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
stroke_width = 3 + i
|
| 290 |
-
paths.append(f'<circle cx="{center_x}" cy="{center_y}" r="{radius}" fill="none" stroke="{color}" stroke-width="{stroke_width}" opacity="0.7"/>')
|
| 291 |
-
|
| 292 |
-
def _add_generation_elements(self, paths, width, height, color, step):
|
| 293 |
-
"""Add generation elements"""
|
| 294 |
-
for i in range(step * 2):
|
| 295 |
-
x = np.random.randint(10, width - 10)
|
| 296 |
-
y = np.random.randint(10, height - 10)
|
| 297 |
-
size = np.random.randint(5, 15)
|
| 298 |
-
paths.append(f'<rect x="{x}" y="{y}" width="{size}" height="{size}" fill="{color}" opacity="0.6"/>')
|
| 299 |
-
|
| 300 |
-
def _generate_edited_svg(self, prompt: str, width: int, height: int, step: int, edit_type: str, changing_region: List[str]) -> str:
|
| 301 |
-
"""
|
| 302 |
-
Generate an edited SVG as placeholder
|
| 303 |
-
This should be replaced with actual DiffSketchEdit generation when diffvg is available
|
| 304 |
-
"""
|
| 305 |
-
# Set different random seed for each step to show progression
|
| 306 |
-
np.random.seed(42 + step * 50)
|
| 307 |
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
if edit_type == "replace":
|
| 313 |
-
# Show gradual replacement of elements
|
| 314 |
-
colors = ["#E74C3C", "#3498DB", "#2ECC71", "#F39C12", "#9B59B6", "#1ABC9C"]
|
| 315 |
-
base_color = colors[step % len(colors)]
|
| 316 |
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
|
|
|
|
|
|
| 321 |
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
|
|
|
|
|
|
| 326 |
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
|
|
|
|
|
|
| 332 |
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
if i % 3 == 0:
|
| 337 |
-
# Circles
|
| 338 |
-
cx = np.random.randint(20, width - 20)
|
| 339 |
-
cy = np.random.randint(20, height - 20)
|
| 340 |
-
r = np.random.randint(5, 20 + step * 2)
|
| 341 |
-
opacity = 0.4 + step * 0.1
|
| 342 |
-
paths.append(f'<circle cx="{cx}" cy="{cy}" r="{r}" fill="{base_color}" opacity="{opacity}"/>')
|
| 343 |
-
|
| 344 |
-
elif i % 3 == 1:
|
| 345 |
-
# Rectangles
|
| 346 |
-
x = np.random.randint(10, width - 30)
|
| 347 |
-
y = np.random.randint(10, height - 30)
|
| 348 |
-
w = np.random.randint(10, 30 + step * 3)
|
| 349 |
-
h = np.random.randint(10, 30 + step * 3)
|
| 350 |
-
opacity = 0.3 + step * 0.1
|
| 351 |
-
paths.append(f'<rect x="{x}" y="{y}" width="{w}" height="{h}" fill="{base_color}" opacity="{opacity}"/>')
|
| 352 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
else:
|
| 354 |
-
#
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
|
| 361 |
-
# Add
|
| 362 |
-
|
| 363 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
|
| 365 |
-
|
| 366 |
-
return
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
+
import tempfile
|
| 4 |
+
import shutil
|
| 5 |
+
from pathlib import Path
|
| 6 |
import torch
|
| 7 |
+
import yaml
|
| 8 |
+
from omegaconf import OmegaConf
|
| 9 |
from PIL import Image
|
|
|
|
| 10 |
import io
|
| 11 |
+
import cairosvg
|
| 12 |
+
|
| 13 |
+
# Add DiffSketchEdit modules to path
|
| 14 |
+
sys.path.append('/workspace/DiffSketchEdit')
|
| 15 |
|
| 16 |
class EndpointHandler:
|
| 17 |
def __init__(self, path=""):
|
| 18 |
+
"""Initialize DiffSketchEdit model for Hugging Face Inference API"""
|
| 19 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 20 |
+
print(f"Initializing DiffSketchEdit on {self.device}")
|
| 21 |
|
|
|
|
|
|
|
| 22 |
try:
|
| 23 |
# Import DiffSketchEdit modules
|
| 24 |
+
from libs.engine import ModelState
|
| 25 |
+
from methods.painter.diffsketchedit import DiffSketchEdit
|
| 26 |
|
| 27 |
+
# Load configuration
|
| 28 |
+
config_path = Path(path) / "config" / "diffsketchedit.yaml"
|
| 29 |
+
if not config_path.exists():
|
| 30 |
+
# Use default config
|
| 31 |
+
config_path = Path(__file__).parent / "config" / "diffsketchedit.yaml"
|
|
|
|
|
|
|
| 32 |
|
| 33 |
+
with open(config_path, 'r') as f:
|
| 34 |
+
self.config = OmegaConf.load(f)
|
| 35 |
+
|
| 36 |
+
# Initialize model components
|
| 37 |
+
self.model_state = ModelState(self.config)
|
| 38 |
+
self.painter = DiffSketchEdit(self.config, self.device, self.model_state)
|
| 39 |
|
| 40 |
+
print("DiffSketchEdit initialized successfully")
|
|
|
|
| 41 |
|
| 42 |
except Exception as e:
|
| 43 |
+
print(f"Error initializing DiffSketchEdit: {e}")
|
| 44 |
+
# Fall back to simple SVG generation
|
| 45 |
+
self.painter = None
|
| 46 |
+
self.config = None
|
| 47 |
|
| 48 |
+
def __call__(self, data):
|
| 49 |
+
"""
|
| 50 |
+
Generate edited sketch from text prompts
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
data (dict): Input data containing:
|
| 54 |
+
- inputs (str): Text prompt or list of prompts for editing sequence
|
| 55 |
+
- parameters (dict): Generation parameters
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
PIL.Image.Image: Generated edited sketch image
|
| 59 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
try:
|
| 61 |
+
# Extract inputs
|
| 62 |
+
inputs = data.get("inputs", "")
|
| 63 |
+
parameters = data.get("parameters", {})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
+
if not inputs:
|
| 66 |
+
return self._create_error_image("No prompt provided")
|
| 67 |
+
|
| 68 |
+
# Handle multiple prompts for editing sequence
|
| 69 |
+
if isinstance(inputs, list):
|
| 70 |
+
prompts = inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
else:
|
| 72 |
+
prompts = [inputs]
|
|
|
|
| 73 |
|
| 74 |
# Extract parameters
|
| 75 |
+
num_paths = parameters.get("num_paths", 96)
|
| 76 |
+
num_iter = parameters.get("num_iter", 1000)
|
| 77 |
+
guidance_scale = parameters.get("guidance_scale", 7.5)
|
| 78 |
+
seed = parameters.get("seed", 1)
|
| 79 |
width = parameters.get("width", 224)
|
| 80 |
height = parameters.get("height", 224)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
+
# Generate SVG
|
| 83 |
+
if self.painter is not None:
|
| 84 |
+
svg_content = self._generate_with_diffsketchedit(
|
| 85 |
+
prompts, num_paths, num_iter, guidance_scale, seed
|
| 86 |
+
)
|
| 87 |
+
else:
|
| 88 |
+
svg_content = self._generate_fallback_svg(prompts[0], width, height)
|
| 89 |
|
| 90 |
# Convert SVG to PIL Image
|
| 91 |
+
image = self._svg_to_image(svg_content, width, height)
|
| 92 |
+
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
except Exception as e:
|
| 95 |
+
print(f"Error in DiffSketchEdit inference: {e}")
|
| 96 |
+
return self._create_error_image(f"Error: {str(e)[:50]}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
+
def _generate_with_diffsketchedit(self, prompts, num_paths, num_iter, guidance_scale, seed):
|
| 99 |
+
"""Generate SVG using actual DiffSketchEdit model"""
|
| 100 |
+
try:
|
| 101 |
+
# Set random seed
|
| 102 |
+
torch.manual_seed(seed)
|
| 103 |
+
|
| 104 |
+
# Create temporary directory for output
|
| 105 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 106 |
+
output_dir = Path(temp_dir) / "output"
|
| 107 |
+
output_dir.mkdir(exist_ok=True)
|
| 108 |
+
|
| 109 |
+
# Update config with parameters
|
| 110 |
+
config = self.config.copy()
|
| 111 |
+
config.num_paths = num_paths
|
| 112 |
+
config.num_iter = num_iter
|
| 113 |
+
config.guidance_scale = guidance_scale
|
| 114 |
+
config.seed = seed
|
| 115 |
+
config.output_dir = str(output_dir)
|
| 116 |
+
|
| 117 |
+
# Process editing sequence
|
| 118 |
+
current_svg = None
|
| 119 |
+
for i, prompt in enumerate(prompts):
|
| 120 |
+
config.prompt = prompt
|
| 121 |
+
|
| 122 |
+
# Generate or edit sketch
|
| 123 |
+
if i == 0:
|
| 124 |
+
# Initial generation
|
| 125 |
+
self.painter.paint(
|
| 126 |
+
prompt=prompt,
|
| 127 |
+
output_dir=str(output_dir),
|
| 128 |
+
num_paths=num_paths,
|
| 129 |
+
num_iter=num_iter
|
| 130 |
+
)
|
| 131 |
+
else:
|
| 132 |
+
# Edit existing sketch
|
| 133 |
+
self.painter.edit(
|
| 134 |
+
prompt=prompt,
|
| 135 |
+
input_svg=current_svg,
|
| 136 |
+
output_dir=str(output_dir),
|
| 137 |
+
num_iter=num_iter // 2 # Fewer iterations for editing
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
# Find generated SVG file
|
| 141 |
+
svg_files = list(output_dir.glob(f"*_{i}.svg"))
|
| 142 |
+
if not svg_files:
|
| 143 |
+
svg_files = list(output_dir.glob("*.svg"))
|
| 144 |
+
|
| 145 |
+
if svg_files:
|
| 146 |
+
with open(svg_files[-1], 'r') as f:
|
| 147 |
+
current_svg = f.read()
|
| 148 |
+
|
| 149 |
+
return current_svg if current_svg else self._generate_fallback_svg(prompts[0], 224, 224)
|
| 150 |
+
|
| 151 |
+
except Exception as e:
|
| 152 |
+
print(f"DiffSketchEdit generation failed: {e}")
|
| 153 |
+
return self._generate_fallback_svg(prompts[0], 224, 224)
|
| 154 |
|
| 155 |
+
def _generate_fallback_svg(self, prompt, width, height):
|
| 156 |
+
"""Generate simple SVG when model fails"""
|
| 157 |
+
import random
|
| 158 |
+
import math
|
| 159 |
|
| 160 |
+
# Handle list of prompts
|
| 161 |
+
if isinstance(prompt, list):
|
| 162 |
+
prompt = prompt[0] if prompt else "default"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
+
# Set seed for reproducibility
|
| 165 |
+
random.seed(hash(str(prompt)) % 1000)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
+
svg_parts = [f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">']
|
| 168 |
+
svg_parts.append(f'<rect width="{width}" height="{height}" fill="white"/>')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
+
# Generate editing-style sketch based on prompt
|
| 171 |
+
prompt_lower = prompt.lower()
|
| 172 |
+
cx, cy = width // 2, height // 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
+
# Base sketch elements
|
| 175 |
+
if any(word in prompt_lower for word in ['edit', 'modify', 'change']):
|
| 176 |
+
# Show editing process with overlapping elements
|
| 177 |
+
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
+
# Original elements (lighter)
|
| 180 |
+
for i in range(3):
|
| 181 |
+
x = cx + random.randint(-40, 40)
|
| 182 |
+
y = cy + random.randint(-40, 40)
|
| 183 |
+
size = random.randint(15, 25)
|
| 184 |
+
svg_parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="{colors[0]}" opacity="0.3"/>')
|
| 185 |
|
| 186 |
+
# Edited elements (darker)
|
| 187 |
+
for i in range(3):
|
| 188 |
+
x = cx + random.randint(-30, 30)
|
| 189 |
+
y = cy + random.randint(-30, 30)
|
| 190 |
+
size = random.randint(10, 20)
|
| 191 |
+
svg_parts.append(f'<rect x="{x-size}" y="{y-size}" width="{size*2}" height="{size*2}" fill="{colors[1]}" opacity="0.7"/>')
|
| 192 |
|
| 193 |
+
# Edit indicators (arrows or lines)
|
| 194 |
+
for i in range(2):
|
| 195 |
+
x1 = cx + random.randint(-50, 50)
|
| 196 |
+
y1 = cy + random.randint(-50, 50)
|
| 197 |
+
x2 = x1 + random.randint(-20, 20)
|
| 198 |
+
y2 = y1 + random.randint(-20, 20)
|
| 199 |
+
svg_parts.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{colors[2]}" stroke-width="3" marker-end="url(#arrowhead)"/>')
|
| 200 |
|
| 201 |
+
else:
|
| 202 |
+
# Regular sketch with editing potential
|
| 203 |
+
colors = ['black', 'gray', 'darkgray']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
+
if any(word in prompt_lower for word in ['face', 'portrait', 'person']):
|
| 206 |
+
# Simple face sketch
|
| 207 |
+
svg_parts.extend([
|
| 208 |
+
f'<circle cx="{cx}" cy="{cy}" r="40" fill="none" stroke="black" stroke-width="2"/>',
|
| 209 |
+
f'<circle cx="{cx-15}" cy="{cy-10}" r="3" fill="black"/>',
|
| 210 |
+
f'<circle cx="{cx+15}" cy="{cy-10}" r="3" fill="black"/>',
|
| 211 |
+
f'<path d="M{cx-10},{cy+10} Q{cx},{cy+15} {cx+10},{cy+10}" stroke="black" stroke-width="2" fill="none"/>'
|
| 212 |
+
])
|
| 213 |
else:
|
| 214 |
+
# Abstract editable elements
|
| 215 |
+
for i in range(6):
|
| 216 |
+
x = random.randint(30, width-30)
|
| 217 |
+
y = random.randint(30, height-30)
|
| 218 |
+
size = random.randint(8, 20)
|
| 219 |
+
|
| 220 |
+
if i % 3 == 0:
|
| 221 |
+
svg_parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="none" stroke="black" stroke-width="2"/>')
|
| 222 |
+
elif i % 3 == 1:
|
| 223 |
+
svg_parts.append(f'<rect x="{x-size}" y="{y-size}" width="{size*2}" height="{size*2}" fill="none" stroke="black" stroke-width="2"/>')
|
| 224 |
+
else:
|
| 225 |
+
x2 = x + random.randint(-30, 30)
|
| 226 |
+
y2 = y + random.randint(-30, 30)
|
| 227 |
+
svg_parts.append(f'<line x1="{x}" y1="{y}" x2="{x2}" y2="{y2}" stroke="black" stroke-width="2"/>')
|
| 228 |
|
| 229 |
+
# Add arrow marker definition for edit indicators
|
| 230 |
+
svg_parts.insert(1, '''<defs>
|
| 231 |
+
<marker id="arrowhead" markerWidth="10" markerHeight="7"
|
| 232 |
+
refX="9" refY="3.5" orient="auto">
|
| 233 |
+
<polygon points="0 0, 10 3.5, 0 7" fill="#45B7D1"/>
|
| 234 |
+
</marker>
|
| 235 |
+
</defs>''')
|
| 236 |
|
| 237 |
+
svg_parts.append('</svg>')
|
| 238 |
+
return '\n'.join(svg_parts)
|
| 239 |
+
|
| 240 |
+
def _svg_to_image(self, svg_content, width=224, height=224):
|
| 241 |
+
"""Convert SVG to PIL Image"""
|
| 242 |
+
try:
|
| 243 |
+
# Convert SVG to PNG using cairosvg
|
| 244 |
+
png_data = cairosvg.svg2png(
|
| 245 |
+
bytestring=svg_content.encode('utf-8'),
|
| 246 |
+
output_width=width,
|
| 247 |
+
output_height=height
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# Convert to PIL Image
|
| 251 |
+
image = Image.open(io.BytesIO(png_data))
|
| 252 |
+
return image.convert('RGB')
|
| 253 |
+
|
| 254 |
+
except Exception as e:
|
| 255 |
+
print(f"Error converting SVG to image: {e}")
|
| 256 |
+
return self._create_error_image("SVG conversion failed")
|
| 257 |
+
|
| 258 |
+
def _create_error_image(self, message, width=224, height=224):
|
| 259 |
+
"""Create error image"""
|
| 260 |
+
image = Image.new('RGB', (width, height), 'white')
|
| 261 |
+
return image
|
requirements.txt
CHANGED
|
@@ -1,9 +1,24 @@
|
|
| 1 |
-
torch>=
|
| 2 |
-
torchvision>=0.
|
| 3 |
-
transformers>=4.21.0
|
| 4 |
-
svgwrite>=1.4.0
|
| 5 |
-
Pillow>=8.3.0
|
| 6 |
numpy>=1.21.0
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=1.12.0
|
| 2 |
+
torchvision>=0.13.0
|
|
|
|
|
|
|
|
|
|
| 3 |
numpy>=1.21.0
|
| 4 |
+
Pillow>=8.0.0
|
| 5 |
+
cairosvg>=2.5.0
|
| 6 |
+
omegaconf>=2.1.0
|
| 7 |
+
diffusers>=0.20.0
|
| 8 |
+
transformers>=4.20.0
|
| 9 |
+
svgwrite>=1.4.0
|
| 10 |
+
svgpathtools>=1.4.0
|
| 11 |
+
freetype-py>=2.3.0
|
| 12 |
+
shapely>=1.8.0
|
| 13 |
+
opencv-python>=4.5.0
|
| 14 |
+
scikit-image>=0.19.0
|
| 15 |
+
matplotlib>=3.5.0
|
| 16 |
+
scipy>=1.8.0
|
| 17 |
+
einops>=0.4.0
|
| 18 |
+
timm>=0.6.0
|
| 19 |
+
ftfy>=6.1.0
|
| 20 |
+
regex>=2022.0.0
|
| 21 |
+
tqdm>=4.64.0
|
| 22 |
+
lpips>=0.1.4
|
| 23 |
+
clip-by-openai>=1.0.0
|
| 24 |
+
xformers>=0.0.16
|