Oysiyl committed on
Commit 00847d8 · 1 Parent(s): f7291d8

Remove progress tracking, add essential fixes only


- Fix PyTorch indexing deprecation warning (tuple instead of list)
- Suppress torchsde floating-point precision warnings (see the sketch after this list)
- Add 1-hour automatic cache cleanup (delete_cache)
- Add app.queue() for proper Gradio functionality
- Remove all ProgressTracker and progress display code for simplicity
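
The torchsde warning suppression listed above is not visible in the hunks shown below. A minimal sketch of one way such a filter is typically installed near the top of app.py, assuming the warnings are UserWarning instances raised from the torchsde module (the filter arguments are an assumption, not taken from this commit):

```python
import warnings

# Assumption: torchsde emits UserWarning about floating-point precision of
# Brownian-interval timesteps; silence only warnings originating from torchsde.
warnings.filterwarnings("ignore", category=UserWarning, module="torchsde")
```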

app.py CHANGED
@@ -1464,7 +1464,7 @@ def _pipeline_artistic(
 
 if __name__ == "__main__" and not os.environ.get("QR_TESTING_MODE"):
     # Start your Gradio app
-    with gr.Blocks() as app:
+    with gr.Blocks(delete_cache=(3600, 3600)) as app:
         # Add a title and description
         gr.Markdown("# QR Code Art Generator")
         gr.Markdown("""
@@ -2707,6 +2707,7 @@ if __name__ == "__main__" and not os.environ.get("QR_TESTING_MODE"):
         )
 
         # ARTISTIC QR TAB
+    app.queue()
     app.launch(share=False, mcp_server=True)
     # Note: Automatic file cleanup via delete_cache not available in Gradio 5.49.1
     # Files will be cleaned up when the server is restarted
comfy/ldm/modules/sub_quadratic_attention.py CHANGED
@@ -9,36 +9,41 @@
 # Self-attention Does Not Need O(n2) Memory":
 # https://arxiv.org/abs/2112.05682v2
 
+import logging
+import math
 from functools import partial
+
 import torch
 from torch import Tensor
 from torch.utils.checkpoint import checkpoint
-import math
-import logging
 
 try:
-    from typing import Optional, NamedTuple, List, Protocol
+    from typing import List, NamedTuple, Optional, Protocol
 except ImportError:
-    from typing import Optional, NamedTuple, List
+    from typing import List, NamedTuple, Optional
+
     from typing_extensions import Protocol
 
 from typing import List
 
 from comfy import model_management
 
+
 def dynamic_slice(
     x: Tensor,
     starts: List[int],
     sizes: List[int],
 ) -> Tensor:
-    slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
+    slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
     return x[slicing]
 
+
 class AttnChunk(NamedTuple):
     exp_values: Tensor
     exp_weights_sum: Tensor
     max_score: Tensor
 
+
 class SummarizeChunk(Protocol):
     @staticmethod
     def __call__(
@@ -47,6 +52,7 @@ class SummarizeChunk(Protocol):
         value: Tensor,
     ) -> AttnChunk: ...
 
+
 class ComputeQueryChunkAttn(Protocol):
     @staticmethod
     def __call__(
@@ -55,6 +61,7 @@ class ComputeQueryChunkAttn(Protocol):
         value: Tensor,
     ) -> Tensor: ...
 
+
 def _summarize_chunk(
     query: Tensor,
     key_t: Tensor,
@@ -64,7 +71,7 @@ def _summarize_chunk(
     mask,
 ) -> AttnChunk:
     if upcast_attention:
-        with torch.autocast(enabled=False, device_type = 'cuda'):
+        with torch.autocast(enabled=False, device_type="cuda"):
             query = query.float()
             key_t = key_t.float()
             attn_weights = torch.baddbmm(
@@ -93,6 +100,7 @@ def _summarize_chunk(
     max_score = max_score.squeeze(-1)
     return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
 
+
 def _query_chunk_attention(
     query: Tensor,
     key_t: Tensor,
@@ -108,15 +116,15 @@ def _query_chunk_attention(
         key_chunk = dynamic_slice(
             key_t,
             (0, 0, chunk_idx),
-            (batch_x_heads, k_channels_per_head, kv_chunk_size)
+            (batch_x_heads, k_channels_per_head, kv_chunk_size),
         )
         value_chunk = dynamic_slice(
             value,
             (0, chunk_idx, 0),
-            (batch_x_heads, kv_chunk_size, v_channels_per_head)
+            (batch_x_heads, kv_chunk_size, v_channels_per_head),
         )
         if mask is not None:
-            mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
+            mask = mask[:, :, chunk_idx : chunk_idx + kv_chunk_size]
 
         return summarize_chunk(query, key_chunk, value_chunk, mask=mask)
 
@@ -135,6 +143,7 @@ def _query_chunk_attention(
     all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
     return all_values / all_weights
 
+
 # TODO: refactor CrossAttention#get_attention_scores to share code with this
 def _get_attention_scores_no_kv_chunking(
     query: Tensor,
@@ -145,7 +154,7 @@ def _get_attention_scores_no_kv_chunking(
     mask,
 ) -> Tensor:
     if upcast_attention:
-        with torch.autocast(enabled=False, device_type = 'cuda'):
+        with torch.autocast(enabled=False, device_type="cuda"):
             query = query.float()
             key_t = key_t.float()
             attn_scores = torch.baddbmm(
@@ -170,8 +179,10 @@ def _get_attention_scores_no_kv_chunking(
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
     except model_management.OOM_EXCEPTION:
-        logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
-        attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
+        logging.warning(
+            "ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead"
+        )
+        attn_scores -= attn_scores.max(dim=-1, keepdim=True).values  # noqa: F821 attn_scores is not defined
         torch.exp(attn_scores, out=attn_scores)
         summed = torch.sum(attn_scores, dim=-1, keepdim=True)
         attn_scores /= summed
@@ -180,10 +191,12 @@ def _get_attention_scores_no_kv_chunking(
     hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
     return hidden_states_slice
 
+
 class ScannedChunk(NamedTuple):
     chunk_idx: int
     attn_chunk: AttnChunk
 
+
 def efficient_dot_product_attention(
     query: Tensor,
     key_t: Tensor,
@@ -193,28 +206,28 @@ def efficient_dot_product_attention(
     kv_chunk_size_min: Optional[int] = None,
     use_checkpoint=True,
     upcast_attention=False,
-    mask = None,
+    mask=None,
 ):
     """Computes efficient dot-product attention given query, transposed key, and value.
-      This is efficient version of attention presented in
-      https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
-      Args:
-        query: queries for calculating attention with shape of
-          `[batch * num_heads, tokens, channels_per_head]`.
-        key_t: keys for calculating attention with shape of
-          `[batch * num_heads, channels_per_head, tokens]`.
-        value: values to be used in attention with shape of
-          `[batch * num_heads, tokens, channels_per_head]`.
-        query_chunk_size: int: query chunks size
-        kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
-        kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
-        use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
-      Returns:
-        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
-      """
+    This is efficient version of attention presented in
+    https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
+    Args:
+      query: queries for calculating attention with shape of
+        `[batch * num_heads, tokens, channels_per_head]`.
+      key_t: keys for calculating attention with shape of
+        `[batch * num_heads, channels_per_head, tokens]`.
+      value: values to be used in attention with shape of
+        `[batch * num_heads, tokens, channels_per_head]`.
+      query_chunk_size: int: query chunks size
+      kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
+      kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
+      use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
+    Returns:
+      Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
+    """
     batch_x_heads, q_tokens, q_channels_per_head = query.shape
     _, _, k_tokens = key_t.shape
-    scale = q_channels_per_head ** -0.5
+    scale = q_channels_per_head**-0.5
 
     kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
     if kv_chunk_size_min is not None:
@@ -227,7 +240,7 @@ def efficient_dot_product_attention(
         return dynamic_slice(
             query,
             (0, chunk_idx, 0),
-            (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
+            (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head),
         )
 
     def get_mask_chunk(chunk_idx: int) -> Tensor:
@@ -236,20 +249,28 @@ def efficient_dot_product_attention(
         if mask.shape[1] == 1:
             return mask
         chunk = min(query_chunk_size, q_tokens)
-        return mask[:,chunk_idx:chunk_idx + chunk]
-
-    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
-    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
-    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
-        _get_attention_scores_no_kv_chunking,
-        scale=scale,
-        upcast_attention=upcast_attention
-    ) if k_tokens <= kv_chunk_size else (
-        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
+        return mask[:, chunk_idx : chunk_idx + chunk]
+
+    summarize_chunk: SummarizeChunk = partial(
+        _summarize_chunk, scale=scale, upcast_attention=upcast_attention
+    )
+    summarize_chunk: SummarizeChunk = (
+        partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
+    )
+    compute_query_chunk_attn: ComputeQueryChunkAttn = (
         partial(
-            _query_chunk_attention,
-            kv_chunk_size=kv_chunk_size,
-            summarize_chunk=summarize_chunk,
+            _get_attention_scores_no_kv_chunking,
+            scale=scale,
+            upcast_attention=upcast_attention,
+        )
+        if k_tokens <= kv_chunk_size
+        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
+        else (
+            partial(
+                _query_chunk_attention,
+                kv_chunk_size=kv_chunk_size,
+                summarize_chunk=summarize_chunk,
+            )
         )
     )
 
@@ -264,12 +285,16 @@ def efficient_dot_product_attention(
 
     # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
-    res = torch.cat([
-        compute_query_chunk_attn(
-            query=get_query_chunk(i * query_chunk_size),
-            key_t=key_t,
-            value=value,
-            mask=get_mask_chunk(i * query_chunk_size)
-        ) for i in range(math.ceil(q_tokens / query_chunk_size))
-    ], dim=1)
+    res = torch.cat(
+        [
+            compute_query_chunk_attn(
+                query=get_query_chunk(i * query_chunk_size),
+                key_t=key_t,
+                value=value,
+                mask=get_mask_chunk(i * query_chunk_size),
+            )
+            for i in range(math.ceil(q_tokens / query_chunk_size))
+        ],
+        dim=1,
+    )
     return res
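
The list-to-tuple switch in dynamic_slice above is the PyTorch indexing-deprecation fix named in the commit message. A minimal standalone sketch of why it matters, assuming a recent PyTorch release (the tensor shape and sizes here are illustrative, not taken from the repository):

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)
starts, sizes = [0, 1, 0], [2, 2, 3]

# Indexing with a Python list of slice objects triggers a deprecation warning in
# recent PyTorch (non-tuple sequences are treated as advanced indexing); a tuple
# of slices performs plain basic slicing and is the supported spelling.
slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
print(x[slicing].shape)  # torch.Size([2, 2, 3])
```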