ftaubner committed
Commit 7245cc5 · 1 Parent(s): bbe1648

initial commit

LICENSE ADDED
@@ -0,0 +1,201 @@

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2024 CogVideo Model Team @ Zhipu AI

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
cogvideo_embeddings.py ADDED
The diff for this file is too large to render. See raw diff
 
cogvideo_transformer.py ADDED
@@ -0,0 +1,547 @@

# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
# from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from cogvideo_embeddings import CogVideoXPatchEmbedWBlur, TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Transformer block used in the [CogVideoX](https://github.com/THUDM/CogVideo) model.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in the timestep embedding.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in the feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in the attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after the query and key projections in attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool`, defaults to `True`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of the feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in the feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in the attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self Attention
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # 2. Feed Forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # attention
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # feed-forward
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

        return hidden_states, encoder_hidden_states


class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    """
    A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).

    Parameters:
        num_attention_heads (`int`, defaults to `30`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        ofs_embed_dim (`int`, defaults to `512`):
            Output dimension of the "ofs" embeddings used in CogVideoX-5B I2V in version 1.5.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        attention_bias (`bool`, defaults to `True`):
            Whether to use bias in the attention projection layers.
        sample_width (`int`, defaults to `90`):
            The width of the input latents.
        sample_height (`int`, defaults to `60`):
            The height of the input latents.
        sample_frames (`int`, defaults to `49`):
            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
            instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
            but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
            K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        temporal_compression_ratio (`int`, defaults to `4`):
            The compression ratio across the temporal dimension. See documentation for `sample_frames`.
        max_text_seq_length (`int`, defaults to `226`):
            The maximum sequence length of the input text embeddings.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in the feed-forward.
        timestep_activation_fn (`str`, defaults to `"silu"`):
            Activation function to use when generating the timestep embeddings.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value to use in normalization layers.
        spatial_interpolation_scale (`float`, defaults to `1.875`):
            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
        temporal_interpolation_scale (`float`, defaults to `1.0`):
            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        ofs_embed_dim: Optional[int] = None,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        patch_size_t: Optional[int] = None,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
        patch_bias: bool = True,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
            raise ValueError(
                "There are no CogVideoX checkpoints available with rotary embeddings disabled and learned positional "
                "embeddings enabled. If you're using a custom model and/or believe this should be supported, please "
                "open an issue at https://github.com/huggingface/diffusers/issues."
            )

        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbedWBlur(
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            in_channels=in_channels,
            embed_dim=inner_dim,
            text_embed_dim=text_embed_dim,
            bias=patch_bias,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            max_text_seq_length=max_text_seq_length,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
            use_learned_positional_embeddings=use_learned_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Time embeddings and ofs embedding (only CogVideoX 1.5-5B I2V has the latter)
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        self.ofs_proj = None
        self.ofs_embedding = None
        if ofs_embed_dim:
            self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
            self.ofs_embedding = TimestepEmbedding(
                ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
            )  # same as time embeddings, for ofs

        # 3. Define spatio-temporal transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                CogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 4. Output blocks
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )

        if patch_size_t is None:
            # For CogVideoX 1.0
            output_dim = patch_size * patch_size * out_channels
        else:
            # For CogVideoX 1.5
            output_dim = patch_size * patch_size * patch_size_t * out_channels

        self.proj_out = nn.Linear(inner_dim, output_dim)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        intervals: Optional[torch.Tensor],
        condition_mask: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_frames, channels, height, width = hidden_states.shape

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16, so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        if self.ofs_embedding is not None:
            ofs_emb = self.ofs_proj(ofs)
            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
            ofs_emb = self.ofs_embedding(ofs_emb)
            emb = emb + ofs_emb

        # 2. Patch embedding
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states, intervals, condition_mask)
        hidden_states = self.embedding_dropout(hidden_states)

        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        p = self.config.patch_size
        p_t = self.config.patch_size_t

        if p_t is None:
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            output = hidden_states.reshape(
                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
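
As a quick sanity check of the unpatchify step at the end of CogVideoXTransformer3DModel.forward (the patch_size_t is None branch), the same reshape/permute can be run on a dummy tensor. The sizes below (13 latent frames, a 60x90 latent grid, patch size 2, 16 output channels) are illustrative defaults only, not tied to any particular checkpoint; this is a minimal sketch, not part of the committed code.

import torch

# Illustrative sizes only: 13 latent frames, 60x90 latent grid, patch size 2, 16 output channels.
B, F, C_out, p = 1, 13, 16, 2
H_lat, W_lat = 60, 90
tokens_per_frame = (H_lat // p) * (W_lat // p)

# Shape produced by proj_out: one (p * p * C_out)-dim vector per patch token.
hidden_states = torch.randn(B, F * tokens_per_frame, p * p * C_out)

# Same unpatchify as in the forward pass above (patch_size_t is None branch).
output = hidden_states.reshape(B, F, H_lat // p, W_lat // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

assert output.shape == (B, F, C_out, H_lat, W_lat)  # [batch, frames, channels, height, width]
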
controlnet_pipeline.py ADDED
@@ -0,0 +1,733 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import math
3
+ from typing import Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import numpy as np
7
+ from PIL import Image
8
+ from torchvision import transforms
9
+ from einops import rearrange, repeat
10
+ from transformers import T5EncoderModel, T5Tokenizer
11
+ from diffusers.video_processor import VideoProcessor
12
+ from diffusers.utils.torch_utils import randn_tensor
13
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
15
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
16
+ from diffusers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
17
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
18
+ from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipelineOutput, CogVideoXLoraLoaderMixin
19
+ from training.helpers import random_insert_latent_frame, transform_intervals
20
+ import torch.nn.functional as F
21
+ from torch.utils.checkpoint import checkpoint
22
+
23
+ def resize_for_crop(image, crop_h, crop_w):
24
+ img_h, img_w = image.shape[-2:]
25
+ if img_h >= crop_h and img_w >= crop_w:
26
+ coef = max(crop_h / img_h, crop_w / img_w)
27
+ elif img_h <= crop_h and img_w <= crop_w:
28
+ coef = max(crop_h / img_h, crop_w / img_w)
29
+ else:
30
+ coef = crop_h / img_h if crop_h > img_h else crop_w / img_w
31
+ out_h, out_w = int(img_h * coef), int(img_w * coef)
32
+ resized_image = transforms.functional.resize(image, (out_h, out_w), antialias=True)
33
+ return resized_image
34
+
35
+
36
+ def prepare_frames(input_images, video_size, do_resize=True, do_crop=True):
37
+ input_images = np.stack([np.array(x) for x in input_images])
38
+ images_tensor = torch.from_numpy(input_images).permute(0, 3, 1, 2) / 127.5 - 1
39
+ if do_resize:
40
+ images_tensor = [resize_for_crop(x, crop_h=video_size[0], crop_w=video_size[1]) for x in images_tensor]
41
+ if do_crop:
42
+ images_tensor = [transforms.functional.center_crop(x, video_size) for x in images_tensor]
43
+ if isinstance(images_tensor, list):
44
+ images_tensor = torch.stack(images_tensor)
45
+ return images_tensor.unsqueeze(0)
46
+
47
+
48
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
49
+ tw = tgt_width
50
+ th = tgt_height
51
+ h, w = src
52
+ r = h / w
53
+ if r > (th / tw):
54
+ resize_height = th
55
+ resize_width = int(round(th / h * w))
56
+ else:
57
+ resize_width = tw
58
+ resize_height = int(round(tw / w * h))
59
+
60
+ crop_top = int(round((th - resize_height) / 2.0))
61
+ crop_left = int(round((tw - resize_width) / 2.0))
62
+
63
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
64
+
65
+
66
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
67
+ def retrieve_timesteps(
68
+ scheduler,
69
+ num_inference_steps: Optional[int] = None,
70
+ device: Optional[Union[str, torch.device]] = None,
71
+ timesteps: Optional[List[int]] = None,
72
+ sigmas: Optional[List[float]] = None,
73
+ **kwargs,
74
+ ):
75
+ """
76
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
77
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
78
+
79
+ Args:
80
+ scheduler (`SchedulerMixin`):
81
+ The scheduler to get timesteps from.
82
+ num_inference_steps (`int`):
83
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
84
+ must be `None`.
85
+ device (`str` or `torch.device`, *optional*):
86
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
87
+ timesteps (`List[int]`, *optional*):
88
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
89
+ `num_inference_steps` and `sigmas` must be `None`.
90
+ sigmas (`List[float]`, *optional*):
91
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
92
+ `num_inference_steps` and `timesteps` must be `None`.
93
+
94
+ Returns:
95
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
96
+ second element is the number of inference steps.
97
+ """
98
+ if timesteps is not None and sigmas is not None:
99
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
100
+ if timesteps is not None:
101
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
102
+ if not accepts_timesteps:
103
+ raise ValueError(
104
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
105
+ f" timestep schedules. Please check whether you are using the correct scheduler."
106
+ )
107
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
108
+ timesteps = scheduler.timesteps
109
+ num_inference_steps = len(timesteps)
110
+ elif sigmas is not None:
111
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
112
+ if not accept_sigmas:
113
+ raise ValueError(
114
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
115
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
116
+ )
117
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
118
+ timesteps = scheduler.timesteps
119
+ num_inference_steps = len(timesteps)
120
+ else:
121
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
122
+ timesteps = scheduler.timesteps
123
+ return timesteps, num_inference_steps
124
+
125
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
126
+ def retrieve_latents(
127
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
128
+ ):
129
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
130
+ return encoder_output.latent_dist.sample(generator)
131
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
132
+ return encoder_output.latent_dist.mode()
133
+ elif hasattr(encoder_output, "latents"):
134
+ return encoder_output.latents
135
+ else:
136
+ raise AttributeError("Could not access latents of provided encoder_output")
137
+
138
+ class ControlnetCogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
139
+ _optional_components = []
140
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
141
+
142
+ _callback_tensor_inputs = [
143
+ "latents",
144
+ "prompt_embeds",
145
+ "negative_prompt_embeds",
146
+ ]
147
+
148
+ def __init__(
149
+ self,
150
+ tokenizer: T5Tokenizer,
151
+ text_encoder: T5EncoderModel,
152
+ vae: AutoencoderKLCogVideoX,
153
+ transformer: CogVideoXTransformer3DModel,
154
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
155
+ ):
156
+ super().__init__()
157
+
158
+ self.register_modules(
159
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
160
+ )
161
+ self.vae_scale_factor_spatial = (
162
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
163
+ )
164
+ self.vae_scale_factor_temporal = (
165
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
166
+ )
167
+
168
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
169
+
170
+
171
+
172
+ def _get_t5_prompt_embeds(
173
+ self,
174
+ prompt: Union[str, List[str]] = None,
175
+ num_videos_per_prompt: int = 1,
176
+ max_sequence_length: int = 226,
177
+ device: Optional[torch.device] = None,
178
+ dtype: Optional[torch.dtype] = None,
179
+ ):
180
+ device = device or self._execution_device
181
+ dtype = dtype or self.text_encoder.dtype
182
+
183
+ prompt = [prompt] if isinstance(prompt, str) else prompt
184
+ batch_size = len(prompt)
185
+
186
+ text_inputs = self.tokenizer(
187
+ prompt,
188
+ padding="max_length",
189
+ max_length=max_sequence_length,
190
+ truncation=True,
191
+ add_special_tokens=True,
192
+ return_tensors="pt",
193
+ )
194
+ text_input_ids = text_inputs.input_ids
195
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
196
+
197
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
198
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
199
+ logger.warning(
200
+ "The following part of your input was truncated because `max_sequence_length` is set to "
201
+ f" {max_sequence_length} tokens: {removed_text}"
202
+ )
203
+
204
+ # Had to disable auto cast here, otherwise the text encoder produces NaNs.
205
+ # Hope it doesn't break training
206
+ with torch.autocast(device_type=device.type, enabled=False):
207
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
208
+ # prompt embeds is nan here!
209
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
210
+
211
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
212
+ _, seq_len, _ = prompt_embeds.shape
213
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
214
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
215
+
216
+ return prompt_embeds
217
+
218
+ def encode_prompt(
219
+ self,
220
+ prompt: Union[str, List[str]],
221
+ negative_prompt: Optional[Union[str, List[str]]] = None,
222
+ do_classifier_free_guidance: bool = True,
223
+ num_videos_per_prompt: int = 1,
224
+ prompt_embeds: Optional[torch.Tensor] = None,
225
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
226
+ max_sequence_length: int = 226,
227
+ device: Optional[torch.device] = None,
228
+ dtype: Optional[torch.dtype] = None,
229
+ ):
230
+ r"""
231
+ Encodes the prompt into text encoder hidden states.
232
+
233
+ Args:
234
+ prompt (`str` or `List[str]`, *optional*):
235
+ prompt to be encoded
236
+ negative_prompt (`str` or `List[str]`, *optional*):
237
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
238
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
239
+ less than `1`).
240
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
241
+ Whether to use classifier free guidance or not.
242
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
243
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
244
+ prompt_embeds (`torch.Tensor`, *optional*):
245
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
246
+ provided, text embeddings will be generated from `prompt` input argument.
247
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
248
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
249
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
250
+ argument.
251
+ device: (`torch.device`, *optional*):
252
+ torch device
253
+ dtype: (`torch.dtype`, *optional*):
254
+ torch dtype
255
+ """
256
+ device = device or self._execution_device
257
+
258
+ prompt = [prompt] if isinstance(prompt, str) else prompt
259
+ if prompt is not None:
260
+ batch_size = len(prompt)
261
+ else:
262
+ batch_size = prompt_embeds.shape[0]
263
+
264
+ if prompt_embeds is None:
265
+ prompt_embeds = self._get_t5_prompt_embeds(
266
+ prompt=prompt,
267
+ num_videos_per_prompt=num_videos_per_prompt,
268
+ max_sequence_length=max_sequence_length,
269
+ device=device,
270
+ dtype=dtype,
271
+ )
272
+
273
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
274
+ negative_prompt = negative_prompt or ""
275
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
276
+
277
+ if prompt is not None and type(prompt) is not type(negative_prompt):
278
+ raise TypeError(
279
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
280
+ f" {type(prompt)}."
281
+ )
282
+ elif batch_size != len(negative_prompt):
283
+ raise ValueError(
284
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
285
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
286
+ " the batch size of `prompt`."
287
+ )
288
+
289
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
290
+ prompt=negative_prompt,
291
+ num_videos_per_prompt=num_videos_per_prompt,
292
+ max_sequence_length=max_sequence_length,
293
+ device=device,
294
+ dtype=dtype,
295
+ )
296
+
297
+ return prompt_embeds, negative_prompt_embeds
298
+
299
+ def prepare_latents(
300
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
301
+ ):
302
+ shape = (
303
+ batch_size,
304
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
305
+ num_channels_latents,
306
+ height // self.vae_scale_factor_spatial,
307
+ width // self.vae_scale_factor_spatial,
308
+ )
309
+ if isinstance(generator, list) and len(generator) != batch_size:
310
+ raise ValueError(
311
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
312
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
313
+ )
314
+
315
+ if latents is None:
316
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
317
+ else:
318
+ latents = latents.to(device)
319
+
320
+ # scale the initial noise by the standard deviation required by the scheduler
321
+ latents = latents * self.scheduler.init_noise_sigma
322
+ return latents
323
+
324
+
325
+
326
+ def prepare_image_latents(self,
327
+ image: torch.Tensor,
328
+ batch_size: int = 1,
329
+ num_channels_latents: int = 16,
330
+ num_frames: int = 13,
331
+ height: int = 60,
332
+ width: int = 90,
333
+ dtype: Optional[torch.dtype] = None,
334
+ device: Optional[torch.device] = None,
335
+ generator: Optional[torch.Generator] = None,
336
+ latents: Optional[torch.Tensor] = None,):
337
+
338
+ image_prepared = prepare_frames(image, (height, width)).to(device).to(dtype=dtype).permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
339
+
340
+ image_latents = [retrieve_latents(self.vae.encode(image_prepared), generator)]
341
+
342
+ image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
343
+
344
+ if not self.vae.config.invert_scale_latents:
345
+ image_latents = self.vae_scaling_factor_image * image_latents
346
+ else:
347
+ # This is awkward but required because the CogVideoX team forgot to multiply the
348
+ # scaling factor during training :)
349
+ image_latents = 1 / self.vae_scaling_factor_image * image_latents
350
+
351
+ # else:
352
+ # # This is awkward but required because the CogVideoX team forgot to multiply the
353
+ # # scaling factor during training :)
354
+ # image_latents = 1 / self.vae_scaling_factor_image * image_latents
355
+
356
+ return image_prepared, image_latents
357
+
358
+ # def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
359
+ # latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
360
+ # latents = 1 / self.vae.config.scaling_factor * latents
361
+
362
+ # frames = self.vae.decode(latents).sample
363
+ # return frames
364
+
365
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
366
+ latents = latents.permute(0, 2, 1, 3, 4) # [B, C, T, H, W]
367
+ latents = 1 / self.vae.config.scaling_factor * latents
368
+
369
+ def decode_fn(x):
370
+ return self.vae.decode(x).sample
371
+
372
+ # Use checkpointing to save memory
373
+ frames = checkpoint(decode_fn, latents, use_reentrant=False)
374
+ return frames
375
+
376
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
377
+ def prepare_extra_step_kwargs(self, generator, eta):
378
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
379
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
380
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
381
+ # and should be between [0, 1]
382
+
383
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
384
+ extra_step_kwargs = {}
385
+ if accepts_eta:
386
+ extra_step_kwargs["eta"] = eta
387
+
388
+ # check if the scheduler accepts generator
389
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
390
+ if accepts_generator:
391
+ extra_step_kwargs["generator"] = generator
392
+ return extra_step_kwargs
393
+
394
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
395
+ def check_inputs(
396
+ self,
397
+ prompt,
398
+ height,
399
+ width,
400
+ negative_prompt,
401
+ callback_on_step_end_tensor_inputs,
402
+ prompt_embeds=None,
403
+ negative_prompt_embeds=None,
404
+ ):
405
+ if height % 8 != 0 or width % 8 != 0:
406
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
407
+
408
+ if callback_on_step_end_tensor_inputs is not None and not all(
409
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
410
+ ):
411
+ raise ValueError(
412
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
413
+ )
414
+ if prompt is not None and prompt_embeds is not None:
415
+ raise ValueError(
416
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
417
+ " only forward one of the two."
418
+ )
419
+ elif prompt is None and prompt_embeds is None:
420
+ raise ValueError(
421
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
422
+ )
423
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
424
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
425
+
426
+ if prompt is not None and negative_prompt_embeds is not None:
427
+ raise ValueError(
428
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
429
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
430
+ )
431
+
432
+ if negative_prompt is not None and negative_prompt_embeds is not None:
433
+ raise ValueError(
434
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
435
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
436
+ )
437
+
438
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
439
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
440
+ raise ValueError(
441
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
442
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
443
+ f" {negative_prompt_embeds.shape}."
444
+ )
445
+ def fuse_qkv_projections(self) -> None:
446
+ r"""Enables fused QKV projections."""
447
+ self.fusing_transformer = True
448
+ self.transformer.fuse_qkv_projections()
449
+
450
+ def unfuse_qkv_projections(self) -> None:
451
+ r"""Disable QKV projection fusion if enabled."""
452
+ if not self.fusing_transformer:
453
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
454
+ else:
455
+ self.transformer.unfuse_qkv_projections()
456
+ self.fusing_transformer = False
457
+
458
+ def _prepare_rotary_positional_embeddings(
459
+ self,
460
+ height: int,
461
+ width: int,
462
+ num_frames: int,
463
+ device: torch.device,
464
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
465
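+ # Grid sizes are expressed in transformer patches; 720x480 is the base spatial resolution the rotary embeddings are defined for.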
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
466
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
467
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
468
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
469
+
470
+ grid_crops_coords = get_resize_crop_region_for_grid(
471
+ (grid_height, grid_width), base_size_width, base_size_height
472
+ )
473
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
474
+ embed_dim=self.transformer.config.attention_head_dim,
475
+ crops_coords=grid_crops_coords,
476
+ grid_size=(grid_height, grid_width),
477
+ temporal_size=num_frames,
478
+ )
479
+
480
+ freqs_cos = freqs_cos.to(device=device)
481
+ freqs_sin = freqs_sin.to(device=device)
482
+ return freqs_cos, freqs_sin
483
+
484
+ @property
485
+ def guidance_scale(self):
486
+ return self._guidance_scale
487
+
488
+ @property
489
+ def num_timesteps(self):
490
+ return self._num_timesteps
491
+
492
+ @property
493
+ def attention_kwargs(self):
494
+ return self._attention_kwargs
495
+
496
+ @property
497
+ def interrupt(self):
498
+ return self._interrupt
499
+
500
+ @torch.no_grad()
501
+ def __call__(
502
+ self,
503
+ image,
504
+ input_intervals,
505
+ output_intervals,
506
+ prompt: Optional[Union[str, List[str]]] = None,
507
+ negative_prompt: Optional[Union[str, List[str]]] = None,
508
+ height: int = 480,
509
+ width: int = 720,
510
+ num_frames: int = 49,
511
+ num_inference_steps: int = 50,
512
+ timesteps: Optional[List[int]] = None,
513
+ guidance_scale: float = 6,
514
+ use_dynamic_cfg: bool = False,
515
+ num_videos_per_prompt: int = 1,
516
+ eta: float = 0.0,
517
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
518
+ latents: Optional[torch.FloatTensor] = None,
519
+ prompt_embeds: Optional[torch.FloatTensor] = None,
520
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
521
+ output_type: str = "pil",
522
+ return_dict: bool = True,
523
+ callback_on_step_end: Optional[
524
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
525
+ ] = None,
526
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
527
+ max_sequence_length: int = 226,
528
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
529
+ if num_frames > 49:
530
+ raise ValueError(
531
+ "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
532
+ )
533
+
534
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
535
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
536
+
537
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
538
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
539
+ num_videos_per_prompt = 1
540
+
541
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
542
+
543
+
544
+ # 1. Check inputs. Raise error if not correct
545
+ self.check_inputs(
546
+ prompt,
547
+ height,
548
+ width,
549
+ negative_prompt,
550
+ callback_on_step_end_tensor_inputs,
551
+ prompt_embeds,
552
+ negative_prompt_embeds,
553
+ )
554
+ self._guidance_scale = guidance_scale
555
+ self._interrupt = False
556
+
557
+ # 2. Default call parameters
558
+ if prompt is not None and isinstance(prompt, str):
559
+ batch_size = 1
560
+ elif prompt is not None and isinstance(prompt, list):
561
+ batch_size = len(prompt)
562
+ else:
563
+ batch_size = prompt_embeds.shape[0]
564
+
565
+ device = self._execution_device
566
+
567
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
568
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
569
+ # corresponds to doing no classifier free guidance.
570
+ do_classifier_free_guidance = guidance_scale > 1.0
571
+
572
+ # 3. Encode input prompt
573
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
574
+ prompt,
575
+ negative_prompt,
576
+ do_classifier_free_guidance,
577
+ num_videos_per_prompt=num_videos_per_prompt,
578
+ prompt_embeds=prompt_embeds,
579
+ negative_prompt_embeds=negative_prompt_embeds,
580
+ max_sequence_length=max_sequence_length,
581
+ device=device,
582
+ )
583
+ if do_classifier_free_guidance:
584
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
585
+
586
+ # 4. Prepare timesteps
587
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
588
+ self._num_timesteps = len(timesteps)
589
+
590
+ # 5. Prepare latents.
591
+ latent_channels = 16 #self.transformer.config.in_channels
592
+ latents = self.prepare_latents(
593
+ batch_size * num_videos_per_prompt,
594
+ latent_channels,
595
+ num_frames,
596
+ height,
597
+ width,
598
+ prompt_embeds.dtype,
599
+ device,
600
+ generator,
601
+ latents,
602
+ )
603
+
604
+
605
+
606
+ image_prepared, image_latents = self.prepare_image_latents(
607
+ image,
608
+ batch_size=batch_size,
609
+ num_channels_latents=latent_channels,
610
+ num_frames=num_frames,
611
+ height=height,
612
+ width=width,
613
+ dtype=prompt_embeds.dtype,
614
+ device=device,
615
+ generator=generator,
616
+ )
617
+
618
+
619
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
620
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
621
+
622
+ # 8. Create rotary positional embeddings if required (currently unused)
623
+ image_rotary_emb = (
624
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
625
+ if self.transformer.config.use_rotary_positional_embeddings
626
+ else None
627
+ )
628
+
629
+ # 9. Denoising loop
630
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
631
+
632
+
633
+ input_intervals = input_intervals.to(device)
634
+ output_intervals = output_intervals.to(device)
635
+
636
+ input_intervals = transform_intervals(input_intervals)
637
+ output_intervals = transform_intervals(output_intervals)
638
+
639
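+ # Insert the clean image latents at the conditioning frame position and build condition_mask marking which latent frames are conditioning frames ("just_one" presumably restricts this to a single frame).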
+ latents_initial, target, condition_mask, intervals = random_insert_latent_frame(image_latents, latents, latents, input_intervals, output_intervals, special_info="just_one")
640
+
641
+ latents = latents_initial.clone()
642
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
643
+ # for DPM-solver++
644
+ old_pred_original_sample = None
645
+ for i, t in enumerate(timesteps):
646
+ if self.interrupt:
647
+ continue
648
+
649
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
650
+ #replace first latent with image_latents
651
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
652
+
653
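+ # The CFG batch is [unconditional, conditional]: zero out the conditioning latents in the unconditional branch and keep their clean values in the conditional branch.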
+ if do_classifier_free_guidance:
654
+ latent_model_input[0][condition_mask[0]] = 0 #set unconditioned latents to 0
655
+ #TODO: Replace the conditional latents with the input latents
656
+ latent_model_input[1][condition_mask[0]] = latents_initial[0][condition_mask[0]].to(latent_model_input.dtype)
657
+ else:
658
+ latent_model_input[:, condition_mask[0]] = latents_initial[0][condition_mask[0]].to(latent_model_input.dtype)
659
+
660
+ timestep = t.expand(latent_model_input.shape[0])
661
+
662
+ current_sampling_percent = i / len(timesteps)
663
+
664
+ latent_model_input = latent_model_input.to(dtype=self.transformer.dtype)
665
+ prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
666
+ # predict noise model_output
667
+ noise_pred = self.transformer(
668
+ hidden_states=latent_model_input,
669
+ encoder_hidden_states=prompt_embeds,
670
+ timestep=timestep,
671
+ intervals=intervals,
672
+ condition_mask=condition_mask,
673
+ image_rotary_emb=image_rotary_emb,
674
+ return_dict=False,
675
+ )[0]
676
+ noise_pred = noise_pred.float()
677
+
678
+ # perform guidance
679
+ if use_dynamic_cfg:
680
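+ # Dynamic CFG: ramp the guidance weight over the sampling trajectory with a cosine schedule (as in the stock CogVideoX pipelines).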
+ self._guidance_scale = 1 + guidance_scale * (
681
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
682
+ )
683
+ if do_classifier_free_guidance:
684
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
685
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
686
+ # Note: the conditional branch does not give a realistic noise prediction for the conditioning frame itself;
687
+ # ideally the unconditional noise at that frame would be replaced instead.
688
+
689
+ # compute the previous noisy sample x_t -> x_t-1
690
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
691
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
692
+ else:
693
+ latents, old_pred_original_sample = self.scheduler.step(
694
+ noise_pred,
695
+ old_pred_original_sample,
696
+ t,
697
+ timesteps[i - 1] if i > 0 else None,
698
+ latents,
699
+ **extra_step_kwargs,
700
+ return_dict=False,
701
+ )
702
+ latents = latents.to(prompt_embeds.dtype)
703
+
704
+ # call the callback, if provided
705
+ if callback_on_step_end is not None:
706
+ callback_kwargs = {}
707
+ for k in callback_on_step_end_tensor_inputs:
708
+ callback_kwargs[k] = locals()[k]
709
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
710
+
711
+ latents = callback_outputs.pop("latents", latents)
712
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
713
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
714
+
715
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
716
+ progress_bar.update()
717
+
718
+ # After the denoising loop the conditioning latents are dropped before decoding; the commented line below would instead replace them with image_latents.
719
+ #latents[:, motion_blur_amount:motion_blur_amount+1] = image_latents[:, 0:1]
720
+ if not output_type == "latent":
721
+ latents = latents[~condition_mask].unsqueeze(0)
722
+ video = self.decode_latents(latents)
723
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
724
+ else:
725
+ video = latents
726
+
727
+ # Offload all models
728
+ self.maybe_free_model_hooks()
729
+
730
+ if not return_dict:
731
+ return (video,)
732
+
733
+ return CogVideoXPipelineOutput(frames=video)
extra/checkpoints_to_hf.py ADDED
@@ -0,0 +1,16 @@
1
+ from huggingface_hub import HfApi
2
+ import os
3
+ # Run with HF_TOKEN=your_hf_token set in the environment before running this script
4
+ api = HfApi(token=os.getenv("HF_TOKEN"))
5
+ folders = ["/datasets/sai/blur2vid/training/cogvideox-baist-test",
6
+ "/datasets/sai/blur2vid/training/cogvideox-gopro-test",
7
+ "/datasets/sai/blur2vid/training/cogvideox-gopro-2x-test",
8
+ "/datasets/sai/blur2vid/training/cogvideox-full-test",
9
+ "/datasets/sai/blur2vid/training/cogvideox-outsidephotos"]
10
+ for folder in folders:
11
+ api.upload_folder(
12
+ folder_path=folder,
13
+ repo_id="tedlasai/blur2vid",
14
+ repo_type="model",
15
+ path_in_repo=os.path.basename(folder)
16
+ )
extra/moMets-parallel-baist.py ADDED
@@ -0,0 +1,330 @@
1
+ # Motion Metrics
2
+ from concurrent.futures import ProcessPoolExecutor, as_completed
3
+ import numpy as np
4
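+ # Restore aliases removed in NumPy >= 1.24 for older dependencies that still reference np.float / np.int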
+ np.float = np.float64
5
+ np.int = np.int_
6
+ import os
7
+ from cdfvd import fvd
8
+ from skimage.metrics import structural_similarity
9
+ import torch
10
+ import lpips
11
+ #from DISTS_pytorch import DISTS
12
+ #import colour as c
13
+ #from torchmetrics.image.fid import FrechetInceptionDistance
14
+ import torch.nn.functional as F
15
+ from epe_metric import compute_bidirectional_epe as epe
16
+ import pdb
17
+ import multiprocessing
18
+ import cv2
19
+ import glob
20
+ # init
21
+ dataDir = 'BAISTResultsImages' # 'dataGoPro' #
22
+ gtDir = 'GT' #'GT' #
23
+ methodDirs = ['Ours', 'Animation-from-blur'] #['Favaro','MotionETR','Ours','GOPROGeneralize'] #
24
+ depth = 8
25
+ resFile = './kellytest.npy'#resultsGoPro20250520.npy'#
26
+
27
+ patchDim = 32 #64 #
28
+ pixMax = 1.0
29
+
30
+ nMets = 7 # new results: scoreFVD, scorePWPSNR, scoreEPE, scoreBlurryPSNR, scorePatchSSIM, scorePatchLPIPS, scorePSNR
31
+ compute = True # if False, load previously computed
32
+ eps = 1e-8
33
+
34
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
35
+
36
+ def read_pngs_to_array(path):
37
+ """Read all PNGs in `path`, sort them by filename, convert BGR→RGB, and stack into an np.ndarray."""
38
+ return np.stack([
39
+ cv2.imread(f, cv2.IMREAD_UNCHANGED)[..., ::-1]
40
+ for f in sorted(glob.glob(f"{path}/*.png"))
41
+ ])
42
+
43
+
44
+ # Use 'spawn' to avoid CUDA context issues
45
+ multiprocessing.freeze_support() # on Windows
46
+ multiprocessing.set_start_method('spawn', force=True)
47
+
48
+ def compute_method(results_local, methodDir, files, countMethod):
49
+
50
+ fnLPIPS = lpips.LPIPS(net='alex').to(device)
51
+ #fnDISTS = DISTS().to(device)
52
+ fnFVD = fvd.cdfvd(model='videomae', device=device)
53
+
54
+ countFile = -1
55
+ for file in files:
56
+ countFile+=1
57
+
58
+ # pull frames from the PNG sequence
59
+ pathMethod = os.path.join(dataDir, methodDir, file)
60
+ framesMethod = np.clip(read_pngs_to_array(pathMethod).astype(np.float32) / (2**depth-1),0,1)
61
+ pathGT = os.path.join(dataDir, gtDir, file)
62
+ framesGT = np.clip(read_pngs_to_array(pathGT).astype(np.float32) / (2**depth-1),0,1)
63
+
64
+ #make sure the GT and method have the same shape
65
+ assert framesGT.shape == framesMethod.shape, f"GT shape {framesGT.shape} does not match method shape {framesMethod.shape} for file {file}"
66
+ # video metrics
67
+
68
+ # vmaf
69
+ #scoreVMAF = callVMAF(pathGT, pathMethod)
70
+
71
+ # EPE - convert frames to tensors for the bidirectional end-point-error metric
72
+ framesMethodTensor = torch.from_numpy(framesMethod)
73
+ framesGTtensor = torch.from_numpy(framesGT)
74
+ scoreEPE = epe(framesMethodTensor[0,:,:,:], framesMethodTensor[-1,:,:,:], framesGTtensor[0,:,:,:], framesGTtensor[-1,:,:,:], per_pixel_mode=True).cpu().detach().numpy()
75
+
76
+ # motion blur baseline
77
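+ # Re-synthesize the blurry frame by averaging frames in approximately linear light (gamma 2.2) for both GT and method, then compare the two blurs with PSNR below.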
+ blurryGT = np.mean(framesGT ** 2.2,axis=0) ** (1/2.2)
78
+ blurryMethod = np.mean(framesMethod ** 2.2,axis=0) ** (1/2.2)
79
+ # MSE -> PSNR
80
+ mapBlurryMSE = (blurryGT - blurryMethod)**2
81
+ scoreBlurryMSE = np.mean(mapBlurryMSE)
82
+ scoreBlurryPSNR = (10 * np.log10(pixMax**2 / scoreBlurryMSE))
83
+
84
+ # fvd
85
+ #scoreFVD = fnFVD.compute_fvd(real_videos=(np.expand_dims(framesGT, axis=0)*(2**depth-1)).astype(np.uint8), fake_videos=(np.expand_dims(framesMethod, axis=0)*(2**depth-1)).astype(np.uint8))
86
+ framesGTfvd = np.expand_dims((framesGT * (2**depth-1)).astype(np.uint8), axis=0)
87
+ fnFVD.add_real_stats(framesGTfvd)
88
+ framesMethodFVD = np.expand_dims((framesMethod * (2**depth-1)).astype(np.uint8), axis=0)
89
+ fnFVD.add_fake_stats(framesMethodFVD)
90
+
91
+ # loop directions
92
+ framesMSE = np.stack((framesGT,framesGT)) # pre-allocate array for per-direction MSE maps
93
+ countDirect = -1
94
+ for direction in directions:
95
+ countDirect = countDirect+1
96
+ order = direction
97
+
98
+ # loop frames + image level metrics
99
+ countFrames = -1
100
+ for i in order:
101
+ countFrames+=1
102
+
103
+ frameMethod = framesMethod[i,:,:,:] # method frames can be re-ordered
104
+ frameGT = framesGT[countFrames,:,:,:]
105
+
106
+
107
+ # assert image size is divisible by patch size
108
+ rows, cols, ch = frameGT.shape
109
+ assert rows % patchDim == 0, f"rows {rows} is not divisible by patchDim {patchDim}"
110
+ assert cols % patchDim == 0, f"cols {cols} is not divisible by patchDim {patchDim}"
111
+
112
+ rPatch = np.ceil(rows/patchDim)
113
+ cPatch = np.ceil(cols/patchDim)
114
+
115
+ # LPIPS
116
+ #pdb.set_trace()
117
+ methodTensor = (torch.from_numpy(np.moveaxis(frameMethod, -1, 0)).unsqueeze(0) * 2 - 1).to(device)
118
+ gtTensor = (torch.from_numpy(np.moveaxis(frameGT, -1, 0)).unsqueeze(0) * 2 - 1).to(device)
119
+ #scoreLPIPS = fnLPIPS(gtTensor, methodTensor).squeeze(0,1,2).cpu().detach().numpy()[0]
120
+
121
+ # FID
122
+ #fnFID.update((gtTensor * (2**depth - 1)).to(torch.uint8), real=True)
123
+ #fnFID.update((methodTensor * (2**depth - 1)).to(torch.uint8), real=False)
124
+
125
+ # DISTS
126
+ #scoreDISTS = fnDISTS(gtTensor.to(torch.float), methodTensor.to(torch.float), require_grad=True, batch_average=True).cpu().detach().numpy()
127
+
128
+ # compute ssim
129
+ #scoreSSIM = structural_similarity(frameGT, frameMethod, data_range=pixMax, channel_axis=2)
130
+
131
+ # compute DE 2000
132
+ #frameMethodXYZ = c.RGB_to_XYZ(frameMethod, c.models.RGB_COLOURSPACE_sRGB, apply_cctf_decoding=True)
133
+ #frameMethodLAB = c.XYZ_to_Lab(frameMethodXYZ)
134
+ #frameGTXYZ = c.RGB_to_XYZ(frameGT, c.models.RGB_COLOURSPACE_sRGB, apply_cctf_decoding=True)
135
+ #frameGTLAB = c.XYZ_to_Lab(frameGTXYZ)
136
+ #mapDE2000 = c.delta_E(frameGTLAB, frameMethodLAB, method='CIE 2000')
137
+ #scoreDE2000 = np.mean(mapDE2000)
138
+
139
+ # MSE
140
+ mapMSE = (frameGT - frameMethod)**2
141
+ scoreMSE = np.mean(mapMSE)
142
+
143
+ # PSNR
144
+ framesMSE[countDirect,countFrames,:,:,:] = mapMSE
145
+ #framesPSNR[countDirect,countFrames,:,:,:] = np.clip((10 * np.log10(pixMax**2 / np.clip(mapMSE,a_min=1e-10,a_max=None))),0,100)
146
+ scorePSNR = (10 * np.log10(pixMax**2 / scoreMSE))
147
+
148
+ #for l in range(ch):
149
+
150
+ # channel-wise metrics
151
+ #chanFrameMethod = frameMethod[:,:,l]
152
+ #chanFrameGT = frameGT[:,:,l]
153
+
154
+ # loop patches rows
155
+ for j in range(int(rPatch)):
156
+
157
+ # loop patches cols + patch level metrics
158
+ for k in range(int(cPatch)):
159
+
160
+ startR = j*patchDim
161
+ startC = k*patchDim
162
+ endR = j*patchDim+patchDim
163
+ endC = k*patchDim+patchDim
164
+
165
+ if endR > rows:
166
+ endR = rows
167
+ else:
168
+ pass
169
+
170
+ if endC > cols:
171
+ endC = cols
172
+ else:
173
+ pass
174
+
175
+ # patch metrics
176
+ #patchMSE = np.mean(mapMSE[startR:endR,startC:endC,:])
177
+ #scorePatchPSNR = np.clip((10 * np.log10(pixMax**2 / patchMSE)),0,100)
178
+ if dataDir == 'BAISTResultsImages':
179
+ patchGtTensor = F.interpolate(gtTensor[:,:,startR:endR,startC:endC], scale_factor=2.0, mode='bilinear', align_corners=False)
180
+ patchMethodTensor = F.interpolate(methodTensor[:,:,startR:endR,startC:endC], scale_factor=2.0, mode='bilinear', align_corners=False)
181
+ scorePatchLPIPS = fnLPIPS(patchGtTensor, patchMethodTensor).squeeze(0,1,2).cpu().detach().numpy()[0]
182
+ else:
183
+ scorePatchLPIPS = fnLPIPS(gtTensor[:,:,startR:endR,startC:endC], methodTensor[:,:,startR:endR,startC:endC]).squeeze(0,1,2).cpu().detach().numpy()[0]
184
+ scorePatchSSIM = structural_similarity(frameGT[startR:endR,startC:endC,:], frameMethod[startR:endR,startC:endC,:], data_range=pixMax, channel_axis=2)
185
+ #scorePatchDISTS = fnDISTS(gtTensor[:,:,startR:endR,startC:endC].to(torch.float), methodTensor[:,:,startR:endR,startC:endC].to(torch.float), require_grad=True, batch_average=True).cpu().detach().numpy()
186
+ #scorePatchDE2000 = np.mean(mapDE2000[startR:endR,startC:endC])
187
+
188
+ # i: frame number, j: patch row, k: patch col
189
+ #results[countMethod,countFile,countDirect,i,j,k,3:] = [scoreEPE, scoreBlurryPSNR, scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scorePatchDE2000]
190
+ results_local[countMethod,countFile,countDirect,i,j,k,2:] = [scoreEPE, scoreBlurryPSNR, scorePatchSSIM, scorePatchLPIPS, scorePSNR]
191
+ print('Method: ', methodDir, ' File: ', file, ' Frame: ', str(i), ' PSNR: ', scorePSNR, end='\r')
192
+ #print('VMAF: ', str(scoreVMAF), ' FVD: ', str(scoreFVD), ' LPIPS: ', str(scoreLPIPS), ' FID: ', str(scoreFID), ' DISTS: ', str(scoreDISTS), ' SSIM: ', str(scoreSSIM), ' DE2000: ', str(scoreDE2000), ' PSNR: ', str(scorePSNR), ' Patch PSNR: ', str(scorePatchPSNR), end='\r')
193
+ #pdb.set_trace()
194
+ scorePWPSNR = (10 * np.log10(pixMax**2 / np.mean(np.min(np.mean(framesMSE, axis=(1)),axis=0)))) # per pixel, keep the better (lower-MSE) frame ordering of the two directions, average over the image, then convert to PSNR
195
+ #print('Method: ', methodDir, ' File: ', file, ' Frame: ', str(i), ' PWPSNR: ', scorePWPSNR, end='\n')
196
+ #scorePWPSNR = np.clip((10 * np.log10(pixMax**2 / np.mean(np.min(framesPSNR, axis=0),axis=(1,2,3)))),0,100) # take max pixel wise PSNR per direction, average over image dims
197
+ results_local[countMethod,countFile,:,:,:,:,1] = np.tile(scorePWPSNR, results_local.shape[2:-1])#np.broadcast_to(scorePWPSNR[:, np.newaxis, np.newaxis], results.shape[3:-1])
198
+ np.save(resFile, results_local) # save part of the way through the loop ..
199
+
200
+ #scoreFID = fnFID.compute().cpu().detach().numpy()
201
+ #fnFID.reset()
202
+ #results[countMethod,:,:,:,:,:,0] = np.tile(scoreFID, results.shape[1:-1])
203
+ scoreFVD = fnFVD.compute_fvd_from_stats()
204
+ fnFVD.empty_real_stats()
205
+ fnFVD.empty_fake_stats()
206
+ results_local[countMethod,:,:,:,:,:,0] = np.tile(scoreFVD, results_local.shape[1:-1])
207
+ print('Results computed .. analyzing ..')
208
+
209
+ return results_local
210
+
211
+
212
+ # init results matrix
213
+ path = os.path.join(dataDir, gtDir)
214
+ clipDirs = [name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))]
215
+ files = []
216
+ if dataDir == 'BAISTResultsImages':
217
+ extraFknDir = 'blur'
218
+ else:
219
+ extraFknDir = ''
220
+ for clipDir in clipDirs:
221
+ path = os.path.join(dataDir, gtDir, clipDir, extraFknDir)
222
+ files = files + [os.path.join(clipDir,extraFknDir,name) for name in os.listdir(path)]
223
+ files = sorted(files)
224
+ path = os.path.join(dataDir, methodDirs[0], files[0])
225
+ testFileGT = read_pngs_to_array(path)
226
+ frams,rows,cols,ch = testFileGT.shape
227
+ framRange = [i for i in range(frams)]
228
+ directions = [framRange, framRange[::-1]]
229
+
230
+ #loop through all methods and make sure they all have the same directory structure and same number of files
231
+ for methodDir in methodDirs:
232
+ path = os.path.join(dataDir, methodDir)
233
+ clipDirs = [name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))]
234
+ filesMethod = []
235
+ for clipDir in clipDirs:
236
+ path = os.path.join(dataDir, methodDir, clipDir, extraFknDir)
237
+ filesMethod = filesMethod + [os.path.join(clipDir,extraFknDir,name) for name in os.listdir(path)]
238
+ filesMethod = sorted(filesMethod)
239
+ assert len(files) == len(filesMethod), f"Number of files in {methodDir} does not match GT number of files"
240
+ assert files == filesMethod, f"Files in {methodDir} do not match GT files"
241
+
242
+ def main():
243
+
244
+ results = np.zeros((len(methodDirs),len(files),len(directions),frams,int(np.ceil(rows/patchDim)),int(np.ceil(cols/patchDim)),nMets))
245
+
246
+ if compute:
247
+
248
+ # loop methods + compute dataset level metrics (after nested for loops)
249
+ import multiprocessing as mp
250
+ ctx = mp.get_context('spawn')
251
+ with ProcessPoolExecutor(mp_context=ctx, max_workers=len(methodDirs)) as executor:
252
+ # submit one job per method
253
+ futures = {
254
+ executor.submit(compute_method, np.copy(results), md, files, idx): idx
255
+ for idx, md in enumerate(methodDirs)
256
+ }
257
+ # collect and merge results as they finish
258
+ for fut in as_completed(futures):
259
+ idx = futures[fut]
260
+ res_local = fut.result()
261
+ results[idx] = res_local[idx]
262
+
263
+
264
+ else:
265
+
266
+ results = np.load(resFile)
267
+
268
+ np.save(resFile, results)
269
+ # analyze
270
+
271
+ # new results: scoreFID, scoreFVD, scorePWPSNR, scoreEPE, scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scorePatchDE2000
272
+ upMetrics = [1,3,4,6]
273
+
274
+
275
+ # 0508 results: scoreFID, scoreFVD, scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scorePatchDE2000
276
+ #upMetrics = [4,6,7,8] # PSNR, SSIM, Patch PSNR, Patch SSIM
277
+ print("Results shape 1: ", results.shape)
278
+ forwardBackwardResults = np.mean(results,axis=(3))
279
+ #print("Results shape 2: ", forwardResults.shape)
280
+ maxDirResults = np.max(forwardBackwardResults,axis=(2))
281
+ minDirResults = np.min(forwardBackwardResults,axis=(2))
282
+ bestDirResults = minDirResults
283
+ #pdb.set_trace()
284
+ bestDirResults[:,:,:,:,upMetrics] = maxDirResults[:,:,:,:,upMetrics]
285
+ import pdb
286
+ #pdb.set_trace()
287
+
288
+ meanResults = bestDirResults.mean(axis=(1, 2, 3)) # Shape becomes (num_methods, nMets)
289
+ meanResultsT = meanResults.T
290
+
291
+ '''
292
+ maxDirResults = np.max(results,axis=2)
293
+ minDirResults = np.min(results,axis=2)
294
+ bestDirResults = minDirResults
295
+ bestDirResults[:,:,:,:,:,upMetrics] = maxDirResults[:,:,:,:,:,upMetrics]
296
+ meanResults = bestDirResults.mean(axis=(1, 2, 3, 4)) # Shape becomes (3, 6)
297
+ meanResultsT = meanResults.T
298
+ '''
299
+
300
+ #
301
+ #meanResults = forwardResults.mean(axis=(1, 2, 3, 4)) # Shape becomes (3, 6)
302
+ #meanResultsT = meanResults.T
303
+
304
+ # print latex table
305
+ method_labels = methodDirs
306
+
307
+ # results 0508: scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scoreFID, scoreFVD
308
+ # metric_labels = ["FID $\downarrow$","FVD $\downarrow$","LPIPS $\downarrow$", "DISTS $\downarrow$", "SSIM $\downarrow$", "DE2000 $\downarrow$", "PSNR $\downarrow$", "Patch PSNR $\downarrow$", "Patch SSIM $\downarrow$", "Patch LPIPS $\downarrow$", "Patch DISTS $\downarrow$", "Patch DE2000 $\downarrow$"]
309
+ # results 0517:
310
+ # metric_labels = ["FID $\downarrow$","FVD $\downarrow$","PWPSNR $\downarrow$","EPE $\downarrow$","BlurryPSNR $\downarrow$", "LPIPS $\downarrow$", "DISTS $\downarrow$", "SSIM $\downarrow$", "DE2000 $\downarrow$", "PSNR $\downarrow$", "Patch PSNR $\downarrow$", "Patch SSIM $\downarrow$", "Patch LPIPS $\downarrow$", "Patch DISTS $\downarrow$", "Patch DE2000 $\downarrow$"]
311
+
312
+ # results 0518:
313
+ metric_labels = ["FVD $\downarrow$","PWPSNR $\downarrow$","EPE $\downarrow$","BlurryPSNR $\downarrow$","Patch SSIM $\downarrow$","Patch LPIPS $\downarrow$", "PSNR $\downarrow$"]
314
+
315
+ # appropriate for results 0507
316
+ #metric_labels = ["FID $\downarrow$", "FVD $\downarrow$", "LPIPS $\downarrow$", "DISTS $\downarrow$", "SSIM $\downarrow$", "DE2000 $\downarrow$", "PSNR $\downarrow$"]
317
+
318
+ latex_table = "\\begin{tabular}{l" + "c" * len(method_labels) + "}\n"
319
+ latex_table += "Metric & " + " & ".join(method_labels) + " \\\\\n"
320
+ latex_table += "\\hline\n"
321
+
322
+ for metric, row in zip(metric_labels, meanResultsT):
323
+ row_values = " & ".join(f"{v:.4f}" for v in row)
324
+ latex_table += f"{metric} & {row_values} \\\\\n"
325
+
326
+ latex_table += "\\end{tabular}"
327
+ print(latex_table)
328
+
329
+ if __name__ == '__main__':
330
+ main()
extra/moMets-parallel-gopro.py ADDED
@@ -0,0 +1,343 @@
1
+ # Motion Metrics
2
+ from concurrent.futures import ProcessPoolExecutor, as_completed
3
+ import numpy as np
4
+ np.float = np.float64
5
+ np.int = np.int_
6
+ import os
7
+ from cdfvd import fvd
8
+ from skimage.metrics import structural_similarity
9
+ import torch
10
+ import lpips
11
+ #from DISTS_pytorch import DISTS
12
+ #import colour as c
13
+ #from torchmetrics.image.fid import FrechetInceptionDistance
14
+ import torch.nn.functional as F
15
+ from epe_metric import compute_bidirectional_epe as epe
16
+ import pdb
17
+ import multiprocessing
18
+ import cv2
19
+ import glob
20
+ # init
21
+ # dataDir = 'BaistCroppedOutput' # 'dataGoPro' #
22
+ # gtDir = 'gt_subset' #'GT' #
23
+ # methodDirs = ['deblurred', 'animation-from-blur', ] #['Favaro','MotionETR','Ours','GOPROGeneralize'] #
24
+ # fType = '.mp4'
25
+ # depth = 8
26
+ # resFile = './resultsBaist20250521.npy'#resultsGoPro20250520.npy'#
27
+
28
+ # patchDim = 32 #64 #
29
+ # pixMax = 1.0
30
+
31
+ # nMets = 7 # new results: scoreFVD, scorePWPSNR, scoreEPE, scorePatchSSIM, scorePatchLPIPS, scorePSNR
32
+ # compute = True # if False, load previously computed
33
+ # eps = 1e-8
34
+
35
+ dataDir = 'GOPROResultsImages' # 'dataBaist' #
36
+ gtDir = 'GT' #'gt' #
37
+ methodDirs = ['Jin','MotionETR','Ours'] #'GOPROGeneralize',# ['animation-from-blur'] #
38
+ depth = 8
39
+ resFile = 'resultsGoPro20250521.npy'# './resultsBaist20250521.npy'#
40
+ patchDim = 40 #32 #
41
+ pixMax = 1.0
42
+ nMets = 7 # new results: scoreFVD, scorePWPSNR, scoreEPE, scoreBlurryPSNR, scorePatchSSIM, scorePatchLPIPS, scorePSNR
43
+ compute = False # if False, load previously computed
44
+ eps = 1e-8
45
+
46
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
47
+
48
+ # Use 'spawn' to avoid CUDA context issues
49
+ multiprocessing.freeze_support() # on Windows
50
+ multiprocessing.set_start_method('spawn', force=True)
51
+
52
+ def read_pngs_to_array(path):
53
+ """Read all PNGs in `path`, sort them by filename, convert BGR→RGB, and stack into an np.ndarray."""
54
+ return np.stack([
55
+ cv2.imread(f, cv2.IMREAD_UNCHANGED)[..., ::-1]
56
+ for f in sorted(glob.glob(f"{path}/*.png"))
57
+ ])
58
+ def compute_method(results_local, methodDir, files, countMethod):
59
+
60
+ fnLPIPS = lpips.LPIPS(net='alex').to(device)
61
+ #fnDISTS = DISTS().to(device)
62
+ fnFVD = fvd.cdfvd(model='videomae', device=device)
63
+
64
+ countFile = -1
65
+ for file in files:
66
+ countFile+=1
67
+
68
+ # pull frames from the PNG sequence
69
+ pathMethod = os.path.join(dataDir, methodDir, file)
70
+ framesMethod = np.clip(read_pngs_to_array(pathMethod).astype(np.float32) / (2**depth-1),0,1)
71
+ pathGT = os.path.join(dataDir, gtDir, file)
72
+ framesGT = np.clip(read_pngs_to_array(pathGT).astype(np.float32) / (2**depth-1),0,1)
73
+
74
+ #make sure the GT and method have the same shape
75
+ assert framesGT.shape == framesMethod.shape, f"GT shape {framesGT.shape} does not match method shape {framesMethod.shape} for file {file}"
76
+
77
+ # video metrics
78
+
79
+ # vmaf
80
+ #scoreVMAF = callVMAF(pathGT, pathMethod)
81
+
82
+ # EPE - convert frames to tensors for the bidirectional end-point-error metric
83
+ framesMethodTensor = torch.from_numpy(framesMethod)
84
+ framesGTtensor = torch.from_numpy(framesGT)
85
+ scoreEPE = epe(framesMethodTensor[0,:,:,:], framesMethodTensor[-1,:,:,:], framesGTtensor[0,:,:,:], framesGTtensor[-1,:,:,:], per_pixel_mode=True).cpu().detach().numpy()
86
+
87
+ # motion blur baseline
88
+ blurryGT = np.mean(framesGT ** 2.2,axis=0) ** (1/2.2)
89
+ blurryMethod = np.mean(framesMethod ** 2.2,axis=0) ** (1/2.2)
90
+ # MSE -> PSNR
91
+ mapBlurryMSE = (blurryGT - blurryMethod)**2
92
+ scoreBlurryMSE = np.mean(mapBlurryMSE)
93
+ scoreBlurryPSNR = (10 * np.log10(pixMax**2 / scoreBlurryMSE))
94
+
95
+ # fvd
96
+ #scoreFVD = fnFVD.compute_fvd(real_videos=(np.expand_dims(framesGT, axis=0)*(2**depth-1)).astype(np.uint8), fake_videos=(np.expand_dims(framesMethod, axis=0)*(2**depth-1)).astype(np.uint8))
97
+ framesGTfvd = np.expand_dims((framesGT * (2**depth-1)).astype(np.uint8), axis=0)
98
+ fnFVD.add_real_stats(framesGTfvd)
99
+ framesMethodFVD = np.expand_dims((framesMethod * (2**depth-1)).astype(np.uint8), axis=0)
100
+ fnFVD.add_fake_stats(framesMethodFVD)
101
+
102
+ # loop directions
103
+ framesMSE = np.stack((framesGT,framesGT)) # pre-allocate array for per-direction MSE maps
104
+ countDirect = -1
105
+ for direction in directions:
106
+ countDirect = countDirect+1
107
+ order = direction
108
+
109
+ # loop frames + image level metrics
110
+ countFrames = -1
111
+ for i in order:
112
+ countFrames+=1
113
+
114
+ frameMethod = framesMethod[i,:,:,:] # method frames can be re-ordered
115
+ frameGT = framesGT[countFrames,:,:,:]
116
+
117
+ # assert image size is divisible by patch size
118
+ rows, cols, ch = frameGT.shape
119
+ assert rows % patchDim == 0, f"rows {rows} is not divisible by patchDim {patchDim}"
120
+ assert cols % patchDim == 0, f"cols {cols} is not divisible by patchDim {patchDim}"
121
+
122
+ rPatch = np.ceil(rows/patchDim)
123
+ cPatch = np.ceil(cols/patchDim)
124
+
125
+ # LPIPS
126
+ #pdb.set_trace()
127
+ methodTensor = (torch.from_numpy(np.moveaxis(frameMethod, -1, 0)).unsqueeze(0) * 2 - 1).to(device)
128
+ gtTensor = (torch.from_numpy(np.moveaxis(frameGT, -1, 0)).unsqueeze(0) * 2 - 1).to(device)
129
+ #scoreLPIPS = fnLPIPS(gtTensor, methodTensor).squeeze(0,1,2).cpu().detach().numpy()[0]
130
+
131
+ # FID
132
+ #fnFID.update((gtTensor * (2**depth - 1)).to(torch.uint8), real=True)
133
+ #fnFID.update((methodTensor * (2**depth - 1)).to(torch.uint8), real=False)
134
+
135
+ # DISTS
136
+ #scoreDISTS = fnDISTS(gtTensor.to(torch.float), methodTensor.to(torch.float), require_grad=True, batch_average=True).cpu().detach().numpy()
137
+
138
+ # compute ssim
139
+ #scoreSSIM = structural_similarity(frameGT, frameMethod, data_range=pixMax, channel_axis=2)
140
+
141
+ # compute DE 2000
142
+ #frameMethodXYZ = c.RGB_to_XYZ(frameMethod, c.models.RGB_COLOURSPACE_sRGB, apply_cctf_decoding=True)
143
+ #frameMethodLAB = c.XYZ_to_Lab(frameMethodXYZ)
144
+ #frameGTXYZ = c.RGB_to_XYZ(frameGT, c.models.RGB_COLOURSPACE_sRGB, apply_cctf_decoding=True)
145
+ #frameGTLAB = c.XYZ_to_Lab(frameGTXYZ)
146
+ #mapDE2000 = c.delta_E(frameGTLAB, frameMethodLAB, method='CIE 2000')
147
+ #scoreDE2000 = np.mean(mapDE2000)
148
+
149
+ # MSE
150
+ mapMSE = (frameGT - frameMethod)**2
151
+ scoreMSE = np.mean(mapMSE)
152
+
153
+ # PSNR
154
+ framesMSE[countDirect,countFrames,:,:,:] = mapMSE
155
+ #framesPSNR[countDirect,countFrames,:,:,:] = np.clip((10 * np.log10(pixMax**2 / np.clip(mapMSE,a_min=1e-10,a_max=None))),0,100)
156
+ scorePSNR = (10 * np.log10(pixMax**2 / scoreMSE))
157
+
158
+ #for l in range(ch):
159
+
160
+ # channel-wise metrics
161
+ #chanFrameMethod = frameMethod[:,:,l]
162
+ #chanFrameGT = frameGT[:,:,l]
163
+
164
+ # loop patches rows
165
+ for j in range(int(rPatch)):
166
+
167
+ # loop patches cols + patch level metrics
168
+ for k in range(int(cPatch)):
169
+
170
+ startR = j*patchDim
171
+ startC = k*patchDim
172
+ endR = j*patchDim+patchDim
173
+ endC = k*patchDim+patchDim
174
+
175
+ if endR > rows:
176
+ endR = rows
177
+ else:
178
+ pass
179
+
180
+ if endC > cols:
181
+ endC = cols
182
+ else:
183
+ pass
184
+
185
+ # patch metrics
186
+ #patchMSE = np.mean(mapMSE[startR:endR,startC:endC,:])
187
+ #scorePatchPSNR = np.clip((10 * np.log10(pixMax**2 / patchMSE)),0,100)
188
+ if dataDir == 'BaistCroppedOutput':
189
+ patchGtTensor = F.interpolate(gtTensor[:,:,startR:endR,startC:endC], scale_factor=2.0, mode='bilinear', align_corners=False)
190
+ patchMethodTensor = F.interpolate(methodTensor[:,:,startR:endR,startC:endC], scale_factor=2.0, mode='bilinear', align_corners=False)
191
+ scorePatchLPIPS = fnLPIPS(patchGtTensor, patchMethodTensor).squeeze(0,1,2).cpu().detach().numpy()[0]
192
+ else:
193
+ scorePatchLPIPS = fnLPIPS(gtTensor[:,:,startR:endR,startC:endC], methodTensor[:,:,startR:endR,startC:endC]).squeeze(0,1,2).cpu().detach().numpy()[0]
194
+ scorePatchSSIM = structural_similarity(frameGT[startR:endR,startC:endC,:], frameMethod[startR:endR,startC:endC,:], data_range=pixMax, channel_axis=2)
195
+ #scorePatchDISTS = fnDISTS(gtTensor[:,:,startR:endR,startC:endC].to(torch.float), methodTensor[:,:,startR:endR,startC:endC].to(torch.float), require_grad=True, batch_average=True).cpu().detach().numpy()
196
+ #scorePatchDE2000 = np.mean(mapDE2000[startR:endR,startC:endC])
197
+
198
+ # i: frame number, j: patch row, k: patch col
199
+ #results[countMethod,countFile,countDirect,i,j,k,3:] = [scoreEPE, scoreBlurryPSNR, scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scorePatchDE2000]
200
+ results_local[countMethod,countFile,countDirect,i,j,k,2:] = [scoreEPE, scoreBlurryPSNR, scorePatchSSIM, scorePatchLPIPS, scorePSNR]
201
+ print('Method: ', methodDir, ' File: ', file, ' Frame: ', str(i), ' PSNR: ', scorePSNR, end='\r')
202
+
203
+ #print('VMAF: ', str(scoreVMAF), ' FVD: ', str(scoreFVD), ' LPIPS: ', str(scoreLPIPS), ' FID: ', str(scoreFID), ' DISTS: ', str(scoreDISTS), ' SSIM: ', str(scoreSSIM), ' DE2000: ', str(scoreDE2000), ' PSNR: ', str(scorePSNR), ' Patch PSNR: ', str(scorePatchPSNR), end='\r')
204
+ #pdb.set_trace()
205
+ scorePWPSNR = (10 * np.log10(pixMax**2 / np.mean(np.min(np.mean(framesMSE, axis=(1)),axis=0)))) # per pixel, keep the better (lower-MSE) frame ordering of the two directions, average over the image, then convert to PSNR
206
+ #print('Method: ', methodDir, ' File: ', file, ' Frame: ', str(i), ' PWPSNR: ', scorePWPSNR, end='\n')
207
+ #scorePWPSNR = np.clip((10 * np.log10(pixMax**2 / np.mean(np.min(framesPSNR, axis=0),axis=(1,2,3)))),0,100) # take max pixel wise PSNR per direction, average over image dims
208
+ results_local[countMethod,countFile,:,:,:,:,1] = np.tile(scorePWPSNR, results_local.shape[2:-1])#np.broadcast_to(scorePWPSNR[:, np.newaxis, np.newaxis], results.shape[3:-1])
209
+ np.save(resFile, results_local) # save part of the way through the loop ..
210
+
211
+ #scoreFID = fnFID.compute().cpu().detach().numpy()
212
+ #fnFID.reset()
213
+ #results[countMethod,:,:,:,:,:,0] = np.tile(scoreFID, results.shape[1:-1])
214
+ scoreFVD = fnFVD.compute_fvd_from_stats()
215
+ fnFVD.empty_real_stats()
216
+ fnFVD.empty_fake_stats()
217
+ results_local[countMethod,:,:,:,:,:,0] = np.tile(scoreFVD, results_local.shape[1:-1])
218
+ print('Results computed .. analyzing ..')
219
+
220
+ return results_local
221
+
222
+
223
+ # init results matrix
224
+ path = os.path.join(dataDir, gtDir)
225
+ clipDirs = [name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))]
226
+ files = []
227
+ if dataDir == 'BaistCroppedOutput':
228
+ extraFknDir = 'blur'
229
+ else:
230
+ extraFknDir = ''
231
+ for clipDir in clipDirs:
232
+ path = os.path.join(dataDir, gtDir, clipDir, extraFknDir)
233
+ files = files + [os.path.join(clipDir,extraFknDir,name) for name in os.listdir(path)]
234
+ files = sorted(files)
235
+ path = os.path.join(dataDir, methodDirs[0], files[0])
236
+ testFileGT = read_pngs_to_array(path)
237
+ frams,rows,cols,ch = testFileGT.shape
238
+ framRange = [i for i in range(frams)]
239
+ directions = [framRange, framRange[::-1]]
240
+
241
+ #loop through all methods and make sure they all have the same directory structure and same number of files
242
+ for methodDir in methodDirs:
243
+ path = os.path.join(dataDir, methodDir)
244
+ clipDirs = [name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))]
245
+ filesMethod = []
246
+ for clipDir in clipDirs:
247
+ path = os.path.join(dataDir, methodDir, clipDir, extraFknDir)
248
+ filesMethod = filesMethod + [os.path.join(clipDir,extraFknDir,name) for name in os.listdir(path)]
249
+ filesMethod = sorted(filesMethod)
250
+ print('Method: ', methodDir, ' Number of files: ', len(filesMethod))
251
+ assert len(files) == len(filesMethod), f"Number of files in {methodDir} does not match GT number of files"
252
+ assert files == filesMethod, f"Files in {methodDir} do not match GT files"
253
+
254
+
255
+ def main():
256
+
257
+ results = np.zeros((len(methodDirs),len(files),len(directions),frams,int(np.ceil(rows/patchDim)),int(np.ceil(cols/patchDim)),nMets))
258
+
259
+ if compute:
260
+
261
+ # loop methods + compute dataset level metrics (after nested for loops)
262
+ import multiprocessing as mp
263
+ ctx = mp.get_context('spawn')
264
+ with ProcessPoolExecutor(mp_context=ctx, max_workers=len(methodDirs)) as executor:
265
+ # submit one job per method
266
+ futures = {
267
+ executor.submit(compute_method, np.copy(results), md, files, idx): idx
268
+ for idx, md in enumerate(methodDirs)
269
+ }
270
+ # collect and merge results as they finish
271
+ for fut in as_completed(futures):
272
+ idx = futures[fut]
273
+ res_local = fut.result()
274
+ results[idx] = res_local[idx]
275
+
276
+
277
+ else:
278
+
279
+ results = np.load(resFile)
280
+
281
+ np.save(resFile, results)
282
+ # analyze
283
+
284
+ # new results: scoreFID, scoreFVD, scorePWPSNR, scoreEPE, scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scorePatchDE2000
285
+ upMetrics = [1,3,4,6]
286
+
287
+
288
+ # 0508 results: scoreFID, scoreFVD, scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scorePatchDE2000
289
+ #upMetrics = [4,6,7,8] # PSNR, SSIM, Patch PSNR, Patch SSIM
290
+ print("Results shape 1: ", results.shape)
291
+ forwardBackwardResults = np.mean(results,axis=(3))
292
+ #print("Results shape 2: ", forwardResults.shape)
293
+ maxDirResults = np.max(forwardBackwardResults,axis=(2))
294
+ minDirResults = np.min(forwardBackwardResults,axis=(2))
295
+ bestDirResults = minDirResults
296
+ #pdb.set_trace()
297
+ bestDirResults[:,:,:,:,upMetrics] = maxDirResults[:,:,:,:,upMetrics]
298
+ import pdb
299
+ #pdb.set_trace()
300
+
301
+ meanResults = bestDirResults.mean(axis=(1, 2, 3)) # Shape becomes (num_methods, nMets)
302
+ meanResultsT = meanResults.T
303
+
304
+ '''
305
+ maxDirResults = np.max(results,axis=2)
306
+ minDirResults = np.min(results,axis=2)
307
+ bestDirResults = minDirResults
308
+ bestDirResults[:,:,:,:,:,upMetrics] = maxDirResults[:,:,:,:,:,upMetrics]
309
+ meanResults = bestDirResults.mean(axis=(1, 2, 3, 4)) # Shape becomes (3, 6)
310
+ meanResultsT = meanResults.T
311
+ '''
312
+
313
+ #
314
+ #meanResults = forwardResults.mean(axis=(1, 2, 3, 4)) # Shape becomes (3, 6)
315
+ #meanResultsT = meanResults.T
316
+
317
+ # print latex table
318
+ method_labels = methodDirs
319
+
320
+ # results 0508: scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scoreFID, scoreFVD
321
+ # metric_labels = ["FID $\downarrow$","FVD $\downarrow$","LPIPS $\downarrow$", "DISTS $\downarrow$", "SSIM $\downarrow$", "DE2000 $\downarrow$", "PSNR $\downarrow$", "Patch PSNR $\downarrow$", "Patch SSIM $\downarrow$", "Patch LPIPS $\downarrow$", "Patch DISTS $\downarrow$", "Patch DE2000 $\downarrow$"]
322
+ # results 0517:
323
+ # metric_labels = ["FID $\downarrow$","FVD $\downarrow$","PWPSNR $\downarrow$","EPE $\downarrow$","BlurryPSNR $\downarrow$", "LPIPS $\downarrow$", "DISTS $\downarrow$", "SSIM $\downarrow$", "DE2000 $\downarrow$", "PSNR $\downarrow$", "Patch PSNR $\downarrow$", "Patch SSIM $\downarrow$", "Patch LPIPS $\downarrow$", "Patch DISTS $\downarrow$", "Patch DE2000 $\downarrow$"]
324
+
325
+ # results 0518:
326
+ metric_labels = ["FVD $\downarrow$","PWPSNR $\downarrow$","EPE $\downarrow$","BlurryPSNR $\downarrow$","Patch SSIM $\downarrow$","Patch LPIPS $\downarrow$", "PSNR $\downarrow$"]
327
+
328
+ # appropriate for results 0507
329
+ #metric_labels = ["FID $\downarrow$", "FVD $\downarrow$", "LPIPS $\downarrow$", "DISTS $\downarrow$", "SSIM $\downarrow$", "DE2000 $\downarrow$", "PSNR $\downarrow$"]
330
+
331
+ latex_table = "\\begin{tabular}{l" + "c" * len(method_labels) + "}\n"
332
+ latex_table += "Metric & " + " & ".join(method_labels) + " \\\\\n"
333
+ latex_table += "\\hline\n"
334
+
335
+ for metric, row in zip(metric_labels, meanResultsT):
336
+ row_values = " & ".join(f"{v:.4f}" for v in row)
337
+ latex_table += f"{metric} & {row_values} \\\\\n"
338
+
339
+ latex_table += "\\end{tabular}"
340
+ print(latex_table)
341
+
342
+ if __name__ == '__main__':
343
+ main()
gradio/app.py ADDED
@@ -0,0 +1,85 @@
1
+ import os
2
+ import uuid
3
+ from pathlib import Path
4
+
5
+ import gradio as gr
6
+ from PIL import Image
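+ # import imageio.v3 as iio  # only needed if you write frames with the iio.imwrite call sketched below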
7
+
8
+ # -----------------------
9
+ # 1. Load your model here
10
+ # -----------------------
11
+ # Example:
12
+ # from my_model_lib import MyVideoModel
13
+ # model = MyVideoModel.from_pretrained("your/model/hub/id")
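+ # See inference.py (load_model / inference_on_image) in this repo for a full example of loading the ControlnetCogVideoXPipeline.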
14
+
15
+ OUTPUT_DIR = Path("/tmp/generated_videos")
16
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
17
+
18
+
19
+ def generate_video_from_image(image: Image.Image) -> str:
20
+ video_id = uuid.uuid4().hex
21
+ output_path = OUTPUT_DIR / f"{video_id}.mp4"
22
+
23
+ # 1. Preprocess image
24
+ # img_tensor = preprocess(image) # your code
25
+
26
+ # 2. Run model
27
+ # frames = model(img_tensor) # e.g. np.ndarray of shape (T, H, W, 3), dtype=uint8
28
+
29
+ # 3. Save frames to video
30
+ # iio.imwrite(
31
+ # uri=output_path,
32
+ # image=frames,
33
+ # fps=16,
34
+ # codec="h264",
35
+ # )
36
+
37
+ return str(output_path)
38
+
39
+
40
+ def demo_predict(image: Image.Image) -> str:
41
+ """
42
+ Wrapper for Gradio. Takes an image and returns a video path.
43
+ """
44
+ if image is None:
45
+ raise gr.Error("Please upload an image first.")
46
+
47
+ video_path = generate_video_from_image(image)
48
+ if not os.path.exists(video_path):
49
+ raise gr.Error("Video generation failed: output file not found.")
50
+ return video_path
51
+
52
+
53
+ with gr.Blocks(css="footer {visibility: hidden}") as demo:
54
+ gr.Markdown(
55
+ """
56
+ # 🖼️ ➜ 🎬 Recover motion from a blurry image!
57
+
58
+ Upload an image and the model will generate a short video.
59
+ """
60
+ )
61
+
62
+ with gr.Row():
63
+ with gr.Column():
64
+ image_in = gr.Image(
65
+ type="pil",
66
+ label="Input image",
67
+ interactive=True,
68
+ )
69
+ generate_btn = gr.Button("Generate video", variant="primary")
70
+ with gr.Column():
71
+ video_out = gr.Video(
72
+ label="Generated video",
73
+ format="mp4", # ensures browser-friendly output
74
+ autoplay=True,
75
+ )
76
+
77
+ generate_btn.click(
78
+ fn=demo_predict,
79
+ inputs=image_in,
80
+ outputs=video_out,
81
+ api_name="predict",
82
+ )
83
+
84
+ if __name__ == "__main__":
85
+ demo.launch()
inference.py ADDED
@@ -0,0 +1,317 @@
1
+ # Copyright 2024 The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+ from pathlib import Path
19
+ import io
20
+ import yaml
21
+
22
+ from PIL import Image, ImageCms
23
+ import torch
24
+ import numpy as np
25
+ from transformers import T5Tokenizer, T5EncoderModel
26
+ from safetensors.torch import load_file
27
+ import diffusers
28
+ from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler
29
+ from diffusers.utils import check_min_version, export_to_video
30
+
31
+ from controlnet_pipeline import ControlnetCogVideoXPipeline
32
+ from cogvideo_transformer import CogVideoXTransformer3DModel
33
+
34
+ from training.utils import save_frames_as_pngs
35
+ from training.helpers import get_conditioning
36
+
37
+ # Will raise an error if the minimal version of diffusers is not installed. Remove at your own risk.
38
+ check_min_version("0.31.0.dev0")
39
+
40
+
41
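+ # Convert a PIL image to sRGB, honoring an embedded ICC profile when present; otherwise assume the image is already sRGB.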
+ def convert_to_srgb(img: Image):
42
+ if 'icc_profile' in img.info:
43
+ icc = img.info['icc_profile']
44
+ src_profile = ImageCms.ImageCmsProfile(io.BytesIO(icc))
45
+ dst_profile = ImageCms.createProfile("sRGB")
46
+ img = ImageCms.profileToProfile(img, src_profile, dst_profile, outputMode='RGB')
47
+ else:
48
+ img = img.convert("RGB") # Assume sRGB
49
+ return img
50
+
51
+
52
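+ # Conditioning presets. Assumed semantics (from the field names): in_start/in_end mark the part of the window covered by the blurry input, out_start/out_end the frames to reconstruct, over window_size frames at the given fps.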
+ INTERVALS = {
53
+ "present": {
54
+ "in_start": 0,
55
+ "in_end": 16,
56
+ "out_start": 0,
57
+ "out_end": 16,
58
+ "center": 8,
59
+ "window_size": 16,
60
+ "mode": "1x",
61
+ "fps": 240
62
+ },
63
+ "past_present_and_future": {
64
+ "in_start": 4,
65
+ "in_end": 12,
66
+ "out_start": 0,
67
+ "out_end": 16,
68
+ "center": 8,
69
+ "window_size": 16,
70
+ "mode": "2x",
71
+ "fps": 240,
72
+ },
73
+ }
74
+
75
+
76
+ def convert_to_batch(
77
+ image,
78
+ interval_key="present",
79
+ image_size=(720, 1280),
80
+ ):
81
+ interval = INTERVALS[interval_key]
82
+
83
+ inp_int, out_int, num_frames = get_conditioning(
84
+ in_start=interval['in_start'],
85
+ in_end=interval['in_end'],
86
+ out_start=interval['out_start'],
87
+ out_end=interval['out_end'],
88
+ mode=interval['mode'],
89
+ fps=interval['fps'],
90
+ )
91
+
92
+ blur_img_original = convert_to_srgb(image)
93
+ W, H = blur_img_original.size # PIL .size is (width, height)
94
+
95
+ blur_img = blur_img_original.resize((image_size[1], image_size[0])) # pil is width, height
96
+ blur_img = torch.from_numpy(np.array(blur_img)[None]).permute(0, 3, 1, 2).contiguous().float()
97
+ blur_img = blur_img / 127.5 - 1.0
98
+
99
+ data = {
100
+ "original_size": (H, W),
101
+ 'blur_img': blur_img,
102
+ 'caption': "",
103
+ 'input_interval': inp_int,
104
+ 'output_interval': out_int,
105
+ 'height': image_size[0],
106
+ 'width': image_size[1],
107
+ 'num_frames': num_frames,
108
+ }
109
+ return data
110
+
111
+
112
+ def load_model(args):
113
+ with open(args.model_config_path) as f:
114
+ model_config = yaml.safe_load(f)
115
+
116
+ load_dtype = torch.float16
117
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
118
+ args.pretrained_model_path,
119
+ subfolder="transformer",
120
+ torch_dtype=load_dtype,
121
+ revision=model_config["revision"],
122
+ variant=model_config["variant"],
123
+ low_cpu_mem_usage=False,
124
+ )
125
+ transformer.load_state_dict(load_file(args.weight_path))
126
+
127
+ text_encoder = T5EncoderModel.from_pretrained(
128
+ args.pretrained_model_path,
129
+ subfolder="text_encoder",
130
+ revision=model_config["revision"],
131
+ )
132
+
133
+ tokenizer = T5Tokenizer.from_pretrained(
134
+ args.pretrained_model_path,
135
+ subfolder="tokenizer",
136
+ revision=model_config["revision"],
137
+ )
138
+
139
+ vae = AutoencoderKLCogVideoX.from_pretrained(
140
+ args.pretrained_model_path,
141
+ subfolder="vae",
142
+ revision=model_config["revision"],
143
+ variant=model_config["variant"],
144
+ )
145
+
146
+ scheduler = CogVideoXDPMScheduler.from_pretrained(
147
+ args.pretrained_model_path,
148
+ subfolder="scheduler"
149
+ )
150
+
151
+ # Enable VAE slicing and tiling to reduce VRAM usage.
152
+ vae.enable_slicing()
153
+ vae.enable_tiling()
154
+
155
+ # For mixed-precision inference we cast all weights (vae, text_encoder and transformer) to half precision
156
+ # as these weights are only used for inference, keeping weights in full precision is not required.
157
+ weight_dtype = torch.bfloat16
158
+
159
+ text_encoder.to(args.device, dtype=weight_dtype)
160
+ transformer.to(args.device, dtype=weight_dtype)
161
+ vae.to(args.device, dtype=weight_dtype)
162
+
163
+ pipe = ControlnetCogVideoXPipeline.from_pretrained(
164
+ args.pretrained_model_path,
165
+ tokenizer=tokenizer,
166
+ transformer=transformer,
167
+ text_encoder=text_encoder,
168
+ vae=vae,
169
+ scheduler=scheduler,
170
+ torch_dtype=weight_dtype,
171
+ )
172
+
173
+ scheduler_args = {}
174
+
175
+ if "variance_type" in pipe.scheduler.config:
176
+ variance_type = pipe.scheduler.config.variance_type
177
+
178
+ if variance_type in ["learned", "learned_range"]:
179
+ variance_type = "fixed_small"
180
+
181
+ scheduler_args["variance_type"] = variance_type
182
+
183
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
184
+ pipe = pipe.to(args.device)
185
+
186
+ return pipe, model_config
187
+
188
+
189
+ def inference_on_image(pipe, image, interval_key, model_config, args):
190
+ # If passed along, set the random seed now.
191
+ if args.seed is not None:
192
+ np.random.seed(args.seed)
193
+ torch.manual_seed(args.seed)
194
+
195
+ # run inference
196
+ generator = torch.Generator(device=args.device).manual_seed(args.seed) if args.seed else None
197
+
198
+ with torch.autocast(args.device, enabled=True):
199
+ batch = convert_to_batch(image, interval_key, (args.video_height, args.video_width))
200
+
201
+ frame = batch["blur_img"].permute(0, 2, 3, 1).cpu().numpy()
202
+ frame = (frame + 1.0) * 127.5
203
+ frame = frame.astype(np.uint8)
204
+ pipeline_args = {
205
+ "prompt": "",
206
+ "negative_prompt": "",
207
+ "image": frame,
208
+ "input_intervals": torch.stack([batch["input_interval"]]),
209
+ "output_intervals": torch.stack([batch["output_interval"]]),
210
+ "guidance_scale": model_config["guidance_scale"],
211
+ "use_dynamic_cfg": model_config["use_dynamic_cfg"],
212
+ "height": batch["height"],
213
+ "width": batch["width"],
214
+ "num_frames": torch.tensor([[model_config["max_num_frames"]]]), # torch.tensor([[batch["num_frames"]]]),
215
+ "num_inference_steps": model_config["num_inference_steps"],
216
+ }
217
+
218
+ input_image = frame
219
+
220
+ num_frames = batch["num_frames"] # this is the actual number of frames, the video generation is padded by one frame
221
+
222
+ print(f"Running inference for interval {interval_key}...")
223
+ video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
224
+
225
+ video = video[0:num_frames]
226
+
227
+ return input_image, video
228
+
229
+
230
+ def main(args):
231
+ output_path = Path(args.output_path)
232
+ output_path.mkdir(exist_ok=True)
233
+
234
+ image_path = Path(args.image_path)
235
+
236
+ is_dir = image_path.is_dir()
237
+
238
+ if is_dir:
239
+ image_paths = sorted(list(image_path.glob("*.*")))
240
+ else:
241
+ image_paths = [image_path]
242
+
243
+ pipe, model_config = load_model(args)
244
+
245
+ for image_path in image_paths:
246
+ image = Image.open(image_path)
247
+
248
+ processed_image, video = inference_on_image(pipe, image, "past_present_and_future", model_config, args)
249
+
250
+ vid_output_path = output_path / f"{image_path.stem}.mp4"
251
+ export_to_video(video, vid_output_path, fps=20)
252
+
253
+ # save input image as well
254
+ input_image_output_path = output_path / f"{image_path.stem}_input.png"
255
+ Image.fromarray(processed_image[0]).save(input_image_output_path)
256
+
257
+
258
+ if __name__ == "__main__":
259
+ parser = argparse.ArgumentParser()
260
+ parser.add_argument(
261
+ "--image_path",
262
+ type=str,
263
+ required=True,
264
+ help="Path to image input or directory containing input images",
265
+ )
266
+ parser.add_argument(
267
+ "--weight_path",
268
+ type=str,
269
+ default="training/cogvideox-outsidephotos/checkpoint/model.safetensors",
270
+ help="directory containing weight files",
271
+ )
272
+ parser.add_argument(
273
+ "--pretrained_model_path",
274
+ type=str,
275
+ default="THUDM/CogVideoX-2b",
276
+ help="repo id or path for pretrained CogVideoX model",
277
+ )
278
+ parser.add_argument(
279
+ "--model_config_path",
280
+ type=str,
281
+ default="training/configs/outsidephotos.yaml",
282
+ help="path to model config yaml",
283
+ )
284
+ parser.add_argument(
285
+ "--output_path",
286
+ type=str,
287
+ required=True,
288
+ help="path to output",
289
+ )
290
+ parser.add_argument(
291
+ "--video_width",
292
+ type=int,
293
+ default=1280,
294
+ help="video resolution width",
295
+ )
296
+ parser.add_argument(
297
+ "--video_height",
298
+ type=int,
299
+ default=720,
300
+ help="video resolution height",
301
+ )
302
+ parser.add_argument(
303
+ "--seed",
304
+ type=int,
305
+ default=None,
306
+ help="random generator seed",
307
+ )
308
+ parser.add_argument(
309
+ "--device",
310
+ type=str,
311
+ default="cuda",
312
+ help="inference device",
313
+ )
314
+ args = parser.parse_args()
315
+ main(args)
316
+
317
+ # python inference.py --image_path assets/dummy_image.png --output_path output/
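+ # Remaining flags (--weight_path, --pretrained_model_path, --model_config_path, --video_width, --video_height, --seed, --device) fall back to the defaults defined above.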
requirements.txt ADDED
@@ -0,0 +1,22 @@
1
+ spaces>=0.29.3
2
+ safetensors>=0.4.5
3
+ spandrel>=0.4.0
4
+ tqdm>=4.66.5
5
+ scikit-video>=1.1.11
6
+ git+https://github.com/huggingface/diffusers.git@main
7
+ transformers>=4.44.0
8
+ accelerate>=0.34.2
9
+ opencv-python>=4.10.0.84
10
+ sentencepiece>=0.2.0
11
+ numpy==1.26.0
12
+ torch>=2.4.0
13
+ torchvision>=0.19.0
14
+ gradio>=4.44.0
15
+ imageio>=2.34.2
16
+ imageio-ffmpeg>=0.5.1
17
+ openai>=1.45.0
18
+ moviepy>=1.0.3
19
+ pillow==9.5.0
20
+ denku==0.0.51
21
+ controlnet-aux==0.0.9
setup/download_checkpoints.py ADDED
@@ -0,0 +1,53 @@
1
+ from huggingface_hub import snapshot_download
2
+ import os
3
+ import sys
4
+ # Make sure HF_TOKEN is set in your env beforehand:
5
+ # export HF_TOKEN=your_hf_token
6
+ # Select which checkpoint set to download from the first command-line argument
7
+
8
+
9
+ mode = sys.argv[1] if len(sys.argv) > 1 else "outsidephotos"
10
+
11
+
12
+ REPO_ID = "tedlasai/blur2vid"
13
+ REPO_TYPE = "model"
14
+
15
+ # These are the subfolders you previously used as path_in_repo
16
+ if mode == "outsidephotos":
17
+ checkpoints = [
18
+ "cogvideox-outsidephotos",
19
+ ]
20
+ elif mode == "gopro":
21
+ checkpoints = [
22
+ "cogvideox-gopro-test",
23
+ "cogvideox-gopro-2x-test",
24
+ ]
25
+ elif mode == "baist":
26
+ checkpoints = [
27
+ "cogvideox-baist-test",
28
+ ]
29
+ elif mode == "full":
30
+ checkpoints = [
31
+ "cogvideox-baist-test",
32
+ "cogvideox-gopro-test",
33
+ "cogvideox-gopro-2x-test",
34
+ "cogvideox-full-test",
35
+ "cogvideox-outsidephotos",
36
+ ]
37
+
38
+ # This is the root local directory where you want everything saved
39
+ # (resolved relative to this file's location)
40
+ LOCAL_TRAINING_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "training")
41
+ os.makedirs(LOCAL_TRAINING_ROOT, exist_ok=True)
42
+
43
+ # Download only those folders from the repo and place them under LOCAL_TRAINING_ROOT
44
+ snapshot_download(
45
+ repo_id=REPO_ID,
46
+ repo_type=REPO_TYPE,
47
+ local_dir=LOCAL_TRAINING_ROOT,
48
+ local_dir_use_symlinks=False,
49
+ allow_patterns=[f"{name}/*" for name in checkpoints],
50
+ token=os.getenv("HF_TOKEN"),
51
+ )
52
+
53
+ print(f"Done! Checkpoints downloaded under: {LOCAL_TRAINING_ROOT}")
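+ # Example: python setup/download_checkpoints.py gopro
+ # Valid modes: outsidephotos (default), gopro, baist, full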
setup/download_cogvideo_weights.py ADDED
@@ -0,0 +1,6 @@
1
+ from huggingface_hub import snapshot_download
2
+
3
+ # Download the entire model repository and store it locally
4
+ model_path = snapshot_download(repo_id="THUDM/CogVideoX-2b", cache_dir="./CogVideoX-2b")
5
+
6
+ print(f"Model downloaded to: {model_path}")
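+ # With cache_dir set, snapshot_download stores files in the HF cache layout, i.e.
+ # ./CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/<commit>/, which matches the
+ # pretrained_model_name_or_path used in the training configs.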
setup/environment.yaml ADDED
@@ -0,0 +1,225 @@
1
+ name: blur2vid
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=main
7
+ - _openmp_mutex=5.1=1_gnu
8
+ - asttokens=3.0.0=pyhd8ed1ab_1
9
+ - bzip2=1.0.8=h5eee18b_6
10
+ - ca-certificates=2025.4.26=hbd8a1cb_0
11
+ - comm=0.2.2=pyhd8ed1ab_1
12
+ - debugpy=1.6.0=py310hd8f1fbe_0
13
+ - entrypoints=0.4=pyhd8ed1ab_1
14
+ - exceptiongroup=1.2.2=pyhd8ed1ab_1
15
+ - executing=2.2.0=pyhd8ed1ab_0
16
+ - ffmpeg=4.3.2=hca11adc_0
17
+ - freetype=2.10.4=h0708190_1
18
+ - gmp=6.2.1=h58526e2_0
19
+ - gnutls=3.6.13=h85f3911_1
20
+ - ipykernel=6.20.2=pyh210e3f2_0
21
+ - ipython=8.36.0=pyh907856f_0
22
+ - jedi=0.19.2=pyhd8ed1ab_1
23
+ - jupyter_client=7.3.4=pyhd8ed1ab_0
24
+ - jupyter_core=5.7.2=pyh31011fe_1
25
+ - lame=3.100=h7f98852_1001
26
+ - ld_impl_linux-64=2.40=h12ee557_0
27
+ - libevent=2.1.12=hdbd6064_1
28
+ - libffi=3.4.4=h6a678d5_1
29
+ - libgcc-ng=11.2.0=h1234567_1
30
+ - libgomp=11.2.0=h1234567_1
31
+ - libpng=1.6.37=h21135ba_2
32
+ - libsodium=1.0.18=h36c2ea0_1
33
+ - libstdcxx-ng=11.2.0=h1234567_1
34
+ - libuuid=1.41.5=h5eee18b_0
35
+ - matplotlib-inline=0.1.7=pyhd8ed1ab_1
36
+ - ncurses=6.4=h6a678d5_0
37
+ - nest-asyncio=1.6.0=pyhd8ed1ab_1
38
+ - nettle=3.6=he412f7d_0
39
+ - openh264=2.1.1=h780b84a_0
40
+ - openssl=3.0.16=h5eee18b_0
41
+ - parso=0.8.4=pyhd8ed1ab_1
42
+ - pexpect=4.9.0=pyhd8ed1ab_1
43
+ - pickleshare=0.7.5=pyhd8ed1ab_1004
44
+ - pip=25.0=py310h06a4308_0
45
+ - platformdirs=4.3.7=pyh29332c3_0
46
+ - prompt-toolkit=3.0.51=pyha770c72_0
47
+ - ptyprocess=0.7.0=pyhd8ed1ab_1
48
+ - pure_eval=0.2.3=pyhd8ed1ab_1
49
+ - pygments=2.19.1=pyhd8ed1ab_0
50
+ - python=3.10.16=he870216_1
51
+ - python-dateutil=2.9.0.post0=pyhff2d567_1
52
+ - python_abi=3.10=2_cp310
53
+ - pyzmq=23.0.0=py310h330234f_0
54
+ - readline=8.2=h5eee18b_0
55
+ - setuptools=75.8.0=py310h06a4308_0
56
+ - six=1.17.0=pyhd8ed1ab_0
57
+ - sqlite=3.45.3=h5eee18b_0
58
+ - stack_data=0.6.3=pyhd8ed1ab_1
59
+ - tk=8.6.14=h39e8969_0
60
+ - tmux=3.3a=h5eee18b_1
61
+ - tornado=6.1=py310h5764c6d_3
62
+ - traitlets=5.14.3=pyhd8ed1ab_1
63
+ - typing_extensions=4.13.2=pyh29332c3_0
64
+ - wcwidth=0.2.13=pyhd8ed1ab_1
65
+ - wheel=0.45.1=py310h06a4308_0
66
+ - x264=1!161.3030=h7f98852_1
67
+ - xz=5.6.4=h5eee18b_1
68
+ - zeromq=4.3.4=h9c3ff4c_1
69
+ - zlib=1.2.13=h5eee18b_1
70
+ - pip:
71
+ - absl-py==2.2.0
72
+ - accelerate==1.5.2
73
+ - aiofiles==23.2.1
74
+ - aiohappyeyeballs==2.6.1
75
+ - aiohttp==3.12.14
76
+ - aiosignal==1.4.0
77
+ - annotated-types==0.7.0
78
+ - anyio==4.9.0
79
+ - async-timeout==5.0.1
80
+ - atomicwrites==1.4.1
81
+ - attrs==25.3.0
82
+ - beautifulsoup4==4.13.4
83
+ - certifi==2025.1.31
84
+ - cffi==1.17.1
85
+ - charset-normalizer==3.4.1
86
+ - click==8.1.8
87
+ - colour-science==0.4.6
88
+ - contourpy==1.3.1
89
+ - controlnet-aux==0.0.9
90
+ - cycler==0.12.1
91
+ - decorator==4.4.2
92
+ - decord==0.6.0
93
+ - denku==0.0.51
94
+ - diffusers==0.32.0
95
+ - distro==1.9.0
96
+ - docker-pycreds==0.4.0
97
+ - einops==0.8.1
98
+ - einops-exts==0.0.4
99
+ - fastapi==0.115.11
100
+ - ffmpeg-python==0.2.0
101
+ - ffmpy==0.5.0
102
+ - filelock==3.18.0
103
+ - flatbuffers==25.2.10
104
+ - fonttools==4.56.0
105
+ - frozenlist==1.7.0
106
+ - fsspec==2025.3.0
107
+ - future==1.0.0
108
+ - gdown==5.2.0
109
+ - gitdb==4.0.12
110
+ - gitpython==3.1.44
111
+ - gradio==5.22.0
112
+ - gradio-client==1.8.0
113
+ - groovy==0.1.2
114
+ - h11==0.14.0
115
+ - hf-transfer==0.1.9
116
+ - httpcore==1.0.7
117
+ - httpx==0.28.1
118
+ - huggingface-hub==0.29.3
119
+ - idna==3.10
120
+ - imageio==2.37.0
121
+ - imageio-ffmpeg==0.6.0
122
+ - importlib-metadata==8.6.1
123
+ - jax==0.5.3
124
+ - jaxlib==0.5.3
125
+ - jinja2==3.1.6
126
+ - jiter==0.9.0
127
+ - kiwisolver==1.4.8
128
+ - lazy-loader==0.4
129
+ - lightning==2.5.2
130
+ - lightning-utilities==0.14.3
131
+ - markdown-it-py==3.0.0
132
+ - markupsafe==3.0.2
133
+ - matplotlib==3.10.1
134
+ - mdurl==0.1.2
135
+ - mediapipe==0.10.21
136
+ - ml-dtypes==0.5.1
137
+ - moviepy==1.0.3
138
+ - mpmath==1.3.0
139
+ - multidict==6.6.3
140
+ - networkx==3.4.2
141
+ - numpy==1.26.0
142
+ - nvidia-cublas-cu12==12.4.5.8
143
+ - nvidia-cuda-cupti-cu12==12.4.127
144
+ - nvidia-cuda-nvrtc-cu12==12.4.127
145
+ - nvidia-cuda-runtime-cu12==12.4.127
146
+ - nvidia-cudnn-cu12==9.1.0.70
147
+ - nvidia-cufft-cu12==11.2.1.3
148
+ - nvidia-curand-cu12==10.3.5.147
149
+ - nvidia-cusolver-cu12==11.6.1.9
150
+ - nvidia-cusparse-cu12==12.3.1.170
151
+ - nvidia-cusparselt-cu12==0.6.2
152
+ - nvidia-ml-py==12.570.86
153
+ - nvidia-nccl-cu12==2.21.5
154
+ - nvidia-nvjitlink-cu12==12.4.127
155
+ - nvidia-nvtx-cu12==12.4.127
156
+ - nvitop==1.4.2
157
+ - openai==1.68.2
158
+ - opencv-contrib-python==4.11.0.86
159
+ - opencv-python==4.11.0.86
160
+ - opencv-python-headless==4.11.0.86
161
+ - opt-einsum==3.4.0
162
+ - orjson==3.10.15
163
+ - packaging==24.2
164
+ - pandas==2.2.3
165
+ - peft==0.15.0
166
+ - pillow==9.5.0
167
+ - proglog==0.1.10
168
+ - propcache==0.3.2
169
+ - protobuf==4.25.6
170
+ - psutil==5.9.8
171
+ - ptflops==0.7.4
172
+ - pycparser==2.22
173
+ - pydantic==2.10.6
174
+ - pydantic-core==2.27.2
175
+ - pydub==0.25.1
176
+ - pyparsing==3.2.1
177
+ - pysocks==1.7.1
178
+ - python-dotenv==1.0.1
179
+ - python-multipart==0.0.20
180
+ - pytorch-lightning==2.5.2
181
+ - pytz==2025.1
182
+ - pyyaml==6.0.2
183
+ - regex==2024.11.6
184
+ - requests==2.32.3
185
+ - rich==13.9.4
186
+ - ruff==0.11.2
187
+ - safehttpx==0.1.6
188
+ - safetensors==0.5.3
189
+ - scikit-image==0.24.0
190
+ - scikit-video==1.1.11
191
+ - scipy==1.15.2
192
+ - semantic-version==2.10.0
193
+ - sentencepiece==0.2.0
194
+ - sentry-sdk==2.24.0
195
+ - setproctitle==1.3.5
196
+ - shellingham==1.5.4
197
+ - smmap==5.0.2
198
+ - sniffio==1.3.1
199
+ - sounddevice==0.5.1
200
+ - soupsieve==2.7
201
+ - spaces==0.32.0
202
+ - spandrel==0.4.1
203
+ - starlette==0.46.1
204
+ - sympy==1.13.1
205
+ - tifffile==2025.3.13
206
+ - timm==0.6.7
207
+ - tokenizers==0.21.1
208
+ - tomlkit==0.13.2
209
+ - torch==2.6.0
210
+ - torch-fidelity==0.3.0
211
+ - torchmetrics==1.7.4
212
+ - torchvision==0.21.0
213
+ - tqdm==4.67.1
214
+ - transformers==4.50.0
215
+ - triton==3.2.0
216
+ - typer==0.15.2
217
+ - typing-extensions==4.12.2
218
+ - tzdata==2025.1
219
+ - urllib3==2.3.0
220
+ - uvicorn==0.34.0
221
+ - videoio==0.3.0
222
+ - wandb==0.19.8
223
+ - websockets==15.0.1
224
+ - yarl==1.20.1
225
+ - zipp==3.21.0
training/accelerator_configs/accelerate_test.py ADDED
@@ -0,0 +1,17 @@
1
+ # accelerate_test.py
2
+ from accelerate import Accelerator
3
+ import os
4
+ print("MADE IT HERE")
5
+ # Force unbuffered printing
6
+ import sys; sys.stdout.reconfigure(line_buffering=True)
7
+
8
+ acc = Accelerator()
9
+ print(acc.num_processes)
10
+ print(
11
+ f"[host {os.uname().nodename}] "
12
+ f"global rank {acc.process_index}/{acc.num_processes}, "
13
+ f"local rank {acc.local_process_index}"
14
+ )
15
+
16
+ # Print out assigned CUDA device
17
+ print(f"Device: {acc.device}")
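+ # Sanity check for a distributed setup, launched e.g. with:
+ # accelerate launch --config_file training/accelerator_configs/accelerator_multigpu.yaml training/accelerator_configs/accelerate_test.py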
training/accelerator_configs/accelerator_multigpu.yaml ADDED
@@ -0,0 +1,6 @@
1
+ # Specify distributed_type as `MULTI_GPU` for DDP
2
+ distributed_type: "MULTI_GPU"
3
+ # Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
4
+ mixed_precision: "bf16"
5
+ # Specify the number of GPUs to use
6
+ num_processes: 4
training/accelerator_configs/accelerator_multinode.yaml ADDED
@@ -0,0 +1,4 @@
1
+ distributed_type: "MULTI_GPU"
2
+ mixed_precision: "bf16"
3
+ num_processes: 16
4
+ num_machines: 4
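+ # 16 total processes across 4 machines, i.e. 4 GPUs per node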
training/accelerator_configs/accelerator_singlegpu.yaml ADDED
@@ -0,0 +1,25 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ main_process_port: 29501
3
+ debug: false
4
+ deepspeed_config:
5
+ gradient_accumulation_steps: 1
6
+ gradient_clipping: 1.0
7
+ offload_optimizer_device: none
8
+ offload_param_device: none
9
+ zero3_init_flag: false
10
+ zero_stage: 2
11
+ distributed_type: DEEPSPEED
12
+ downcast_bf16: 'no'
13
+ enable_cpu_affinity: false
14
+ machine_rank: 0
15
+ main_training_function: main
16
+ dynamo_backend: 'no'
17
+ mixed_precision: 'no'
18
+ num_machines: 1
19
+ num_processes: 1
20
+ rdzv_backend: static
21
+ same_network: true
22
+ tpu_env: []
23
+ tpu_use_cluster: false
24
+ tpu_use_sudo: false
25
+ use_cpu: false
training/accelerator_configs/accelerator_val_config.yaml ADDED
@@ -0,0 +1,25 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ main_process_port: 29501
3
+ debug: false
4
+ deepspeed_config:
5
+ gradient_accumulation_steps: 1
6
+ gradient_clipping: 1.0
7
+ offload_optimizer_device: none
8
+ offload_param_device: none
9
+ zero3_init_flag: false
10
+ zero_stage: 2
11
+ distributed_type: DEEPSPEED
12
+ downcast_bf16: 'no'
13
+ enable_cpu_affinity: false
14
+ machine_rank: 0
15
+ main_training_function: main
16
+ dynamo_backend: 'no'
17
+ mixed_precision: 'no'
18
+ num_machines: 1
19
+ num_processes: 4
20
+ rdzv_backend: static
21
+ same_network: true
22
+ tpu_env: []
23
+ tpu_use_cluster: false
24
+ tpu_use_sudo: false
25
+ use_cpu: false
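+ # Identical to accelerator_singlegpu.yaml except for num_processes: 4 (multi-GPU validation runs)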
training/available-qos.txt ADDED
@@ -0,0 +1,10 @@
1
+ Name Priority GraceTime Preempt PreemptExemptTime PreemptMode Flags UsageThres UsageFactor GrpTRES GrpTRESMins GrpTRESRunMin GrpJobs GrpSubmit GrpWall MaxTRES MaxTRESPerNode MaxTRESMins MaxWall MaxTRESPU MaxJobsPU MaxSubmitPU MaxTRESPA MaxJobsPA MaxSubmitPA MinTRES
2
+ ---------- ---------- ---------- ---------- ------------------- ----------- ---------------------------------------- ---------- ----------- ------------- ------------- ------------- ------- --------- ----------- ------------- -------------- ------------- ----------- ------------- --------- ----------- ------------- --------- ----------- -------------
3
+ normal 0 00:00:00 cluster 1.000000
4
+ gpu1-32h 10000 00:00:00 scavenger cluster DenyOnLimit 1.000000 1-08:00:00 cpu=28,gres/+
5
+ gpu2-16h 10000 00:00:00 scavenger cluster DenyOnLimit 1.000000 16:00:00 cpu=56,gres/+
6
+ gpu4-8h 10000 00:00:00 scavenger cluster DenyOnLimit 1.000000 08:00:00 cpu=112,gres+
7
+ gpu8-4h 10000 00:00:00 scavenger cluster DenyOnLimit 1.000000 04:00:00 cpu=224,gres+
8
+ gpu16-2h 10000 00:00:00 scavenger cluster DenyOnLimit 1.000000 02:00:00 cpu=448,gres+
9
+ gpu32-1h 10000 00:00:00 scavenger cluster DenyOnLimit 1.000000 01:00:00 cpu=896,gres+
10
+ scavenger 0 00:00:00 01:00:00 cluster 0.250000
training/configs/baist_test.yaml ADDED
@@ -0,0 +1,77 @@
1
+ # === Required or overridden ===
2
+ base_dir: "/datasets/sai/gencam/blur2vid"
3
+ pretrained_model_name_or_path: "CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
4
+ video_root_dir: "datasets/b-aist"
5
+ csv_path: "set-path-to-csv-file" # Replace with actual CSV path
6
+ output_dir: "cogvideox-baist-test"
7
+ tracker_name: "cogvideox-baist-test"
8
+
9
+ # === Data-related ===
10
+ stride_min: 1
11
+ stride_max: 3
12
+ hflip_p: 0.5
13
+ downscale_coef: 8
14
+ init_from_transformer: true
15
+ dataloader_num_workers: 32
16
+ val_split: "test"
17
+ dataset: "baist"
18
+
19
+ # === Validation ===
20
+ num_inference_steps: 50
21
+ validation_prompt: ""
22
+ validation_video: "../resources/car.mp4:::../resources/ship.mp4"
23
+ validation_prompt_separator: ":::"
24
+ num_validation_videos: 1
25
+ validation_steps: 400
26
+ guidance_scale: 1.1
27
+ use_dynamic_cfg: false
28
+ just_validate: true
29
+ special_info: "just_one"
30
+
31
+ # === Training ===
32
+ seed: 42
33
+ mixed_precision: "bf16"
34
+ height: 720
35
+ width: 1280
36
+ fps: 8
37
+ max_num_frames: 17
38
+ train_batch_size: 2
39
+ num_train_epochs: 100
40
+ max_train_steps: null
41
+ checkpointing_steps: 200
42
+ checkpoints_total_limit: null
43
+ gradient_accumulation_steps: 1
44
+ gradient_checkpointing: true
45
+ learning_rate: 0.0001
46
+ scale_lr: false
47
+ lr_scheduler: "constant"
48
+ lr_warmup_steps: 250
49
+ lr_num_cycles: 1
50
+ lr_power: 1.0
51
+ enable_slicing: true
52
+ enable_tiling: true
53
+
54
+ # === Optimizer ===
55
+ optimizer: "adamw"
56
+ use_8bit_adam: false
57
+ adam_beta1: 0.9
58
+ adam_beta2: 0.95
59
+ prodigy_beta3: null
60
+ prodigy_decouple: false
61
+ adam_weight_decay: 0.0001
62
+ adam_epsilon: 0.0000001
63
+ max_grad_norm: 1.0
64
+ prodigy_use_bias_correction: false
65
+ prodigy_safeguard_warmup: false
66
+
67
+ # === Logging & Reporting ===
68
+ push_to_hub: false
69
+ hub_token: null
70
+ hub_model_id: null
71
+ logging_dir: "logs"
72
+ allow_tf32: true
73
+ report_to: null
74
+
75
+ # === Optional HuggingFace model variant ===
76
+ revision: null
77
+ variant: null
training/configs/baist_train.yaml ADDED
@@ -0,0 +1,78 @@
1
+ # === Required or overridden ===
2
+ base_dir: "/datasets/sai/gencam/blur2vid"
3
+ pretrained_model_name_or_path: "CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
4
+ video_root_dir: "datasets/b-aist"
5
+ csv_path: "set-path-to-csv-file" # Replace with actual CSV path
6
+ output_dir: "cogvideox-baist-train"
7
+ tracker_name: "cogvideox-baist-train"
8
+
9
+
10
+ # === Data-related ===
11
+ stride_min: 1
12
+ stride_max: 3
13
+ hflip_p: 0.5
14
+ downscale_coef: 8
15
+ init_from_transformer: true
16
+ dataloader_num_workers: 32
17
+ val_split: "val"
18
+ dataset: "baist"
19
+
20
+ # === Validation ===
21
+ num_inference_steps: 50
22
+ validation_prompt: ""
23
+ validation_video: "../resources/car.mp4:::../resources/ship.mp4"
24
+ validation_prompt_separator: ":::"
25
+ num_validation_videos: 1
26
+ validation_steps: 400
27
+ guidance_scale: 1.1
28
+ use_dynamic_cfg: false
29
+ just_validate: false
30
+ special_info: "just_one"
31
+
32
+ # === Training ===
33
+ seed: 42
34
+ mixed_precision: "bf16"
35
+ height: 720
36
+ width: 1280
37
+ fps: 8
38
+ max_num_frames: 17
39
+ train_batch_size: 2
40
+ num_train_epochs: 100
41
+ max_train_steps: null
42
+ checkpointing_steps: 200
43
+ checkpoints_total_limit: null
44
+ gradient_accumulation_steps: 1
45
+ gradient_checkpointing: true
46
+ learning_rate: 0.0001
47
+ scale_lr: false
48
+ lr_scheduler: "constant"
49
+ lr_warmup_steps: 250
50
+ lr_num_cycles: 1
51
+ lr_power: 1.0
52
+ enable_slicing: true
53
+ enable_tiling: true
54
+
55
+ # === Optimizer ===
56
+ optimizer: "adamw"
57
+ use_8bit_adam: false
58
+ adam_beta1: 0.9
59
+ adam_beta2: 0.95
60
+ prodigy_beta3: null
61
+ prodigy_decouple: false
62
+ adam_weight_decay: 0.0001
63
+ adam_epsilon: 0.0000001
64
+ max_grad_norm: 1.0
65
+ prodigy_use_bias_correction: false
66
+ prodigy_safeguard_warmup: false
67
+
68
+ # === Logging & Reporting ===
69
+ push_to_hub: false
70
+ hub_token: null
71
+ hub_model_id: null
72
+ logging_dir: "logs"
73
+ allow_tf32: true
74
+ report_to: null
75
+
76
+ # === Optional HuggingFace model variant ===
77
+ revision: null
78
+ variant: null
training/configs/full_test.yaml ADDED
@@ -0,0 +1,78 @@
1
+ # === Required or overridden ===
2
+ base_dir: "/datasets/sai/gencam/blur2vid"
3
+ pretrained_model_name_or_path: "CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
4
+ video_root_dir: "datasets/FullDataset"
5
+ csv_path: "set-path-to-csv-file" # Replace with actual CSV path
6
+ output_dir: "cogvideox-full-test"
7
+ tracker_name: "cogvideox-full-test"
8
+
9
+
10
+ # === Data-related ===
11
+ stride_min: 1
12
+ stride_max: 3
13
+ hflip_p: 0.5
14
+ downscale_coef: 8
15
+ init_from_transformer: true
16
+ dataloader_num_workers: 32
17
+ val_split: "test"
18
+ dataset: "full"
19
+
20
+ # === Validation ===
21
+ num_inference_steps: 50
22
+ validation_prompt: ""
23
+ validation_video: "../resources/car.mp4:::../resources/ship.mp4"
24
+ validation_prompt_separator: ":::"
25
+ num_validation_videos: 1
26
+ validation_steps: 400
27
+ guidance_scale: 1.1
28
+ use_dynamic_cfg: false
29
+ just_validate: true
30
+ special_info: "just_one"
31
+
32
+ # === Training ===
33
+ seed: 42
34
+ mixed_precision: "bf16"
35
+ height: 720
36
+ width: 1280
37
+ fps: 8
38
+ max_num_frames: 17
39
+ train_batch_size: 2
40
+ num_train_epochs: 200
41
+ max_train_steps: null
42
+ checkpointing_steps: 200
43
+ checkpoints_total_limit: null
44
+ gradient_accumulation_steps: 2
45
+ gradient_checkpointing: true
46
+ learning_rate: 0.0001
47
+ scale_lr: false
48
+ lr_scheduler: "constant"
49
+ lr_warmup_steps: 250
50
+ lr_num_cycles: 1
51
+ lr_power: 1.0
52
+ enable_slicing: true
53
+ enable_tiling: true
54
+
55
+ # === Optimizer ===
56
+ optimizer: "adamw"
57
+ use_8bit_adam: false
58
+ adam_beta1: 0.9
59
+ adam_beta2: 0.95
60
+ prodigy_beta3: null
61
+ prodigy_decouple: false
62
+ adam_weight_decay: 0.0001
63
+ adam_epsilon: 0.0000001
64
+ max_grad_norm: 1.0
65
+ prodigy_use_bias_correction: false
66
+ prodigy_safeguard_warmup: false
67
+
68
+ # === Logging & Reporting ===
69
+ push_to_hub: false
70
+ hub_token: null
71
+ hub_model_id: null
72
+ logging_dir: "logs"
73
+ allow_tf32: true
74
+ report_to: null
75
+
76
+ # === Optional HuggingFace model variant ===
77
+ revision: null
78
+ variant: null
training/configs/full_train.yaml ADDED
@@ -0,0 +1,78 @@
1
+ # === Required or overridden ===
2
+ base_dir: "/datasets/sai/gencam/blur2vid"
3
+ pretrained_model_name_or_path: "CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
4
+ video_root_dir: "datasets/FullDataset"
5
+ csv_path: "set-path-to-csv-file" # Replace with actual CSV path
6
+ output_dir: "cogvideox-full-train"
7
+ tracker_name: "cogvideox-full-train"
8
+
9
+
10
+ # === Data-related ===
11
+ stride_min: 1
12
+ stride_max: 3
13
+ hflip_p: 0.5
14
+ downscale_coef: 8
15
+ init_from_transformer: true
16
+ dataloader_num_workers: 2
17
+ val_split: "val"
18
+ dataset: "full"
19
+
20
+ # === Validation ===
21
+ num_inference_steps: 50
22
+ validation_prompt: ""
23
+ validation_video: "../resources/car.mp4:::../resources/ship.mp4"
24
+ validation_prompt_separator: ":::"
25
+ num_validation_videos: 1
26
+ validation_steps: 400
27
+ guidance_scale: 1.0
28
+ use_dynamic_cfg: false
29
+ just_validate: false
30
+ special_info: "just_one"
31
+
32
+ # === Training ===
33
+ seed: 42
34
+ mixed_precision: "bf16"
35
+ height: 720
36
+ width: 1280
37
+ fps: 8
38
+ max_num_frames: 17
39
+ train_batch_size: 2
40
+ num_train_epochs: 200
41
+ max_train_steps: null
42
+ checkpointing_steps: 200
43
+ checkpoints_total_limit: null
44
+ gradient_accumulation_steps: 2
45
+ gradient_checkpointing: true
46
+ learning_rate: 0.0001
47
+ scale_lr: false
48
+ lr_scheduler: "constant"
49
+ lr_warmup_steps: 250
50
+ lr_num_cycles: 1
51
+ lr_power: 1.0
52
+ enable_slicing: true
53
+ enable_tiling: true
54
+
55
+ # === Optimizer ===
56
+ optimizer: "adamw"
57
+ use_8bit_adam: false
58
+ adam_beta1: 0.9
59
+ adam_beta2: 0.95
60
+ prodigy_beta3: null
61
+ prodigy_decouple: false
62
+ adam_weight_decay: 0.0001
63
+ adam_epsilon: 0.0000001
64
+ max_grad_norm: 1.0
65
+ prodigy_use_bias_correction: false
66
+ prodigy_safeguard_warmup: false
67
+
68
+ # === Logging & Reporting ===
69
+ push_to_hub: false
70
+ hub_token: null
71
+ hub_model_id: null
72
+ logging_dir: "logs"
73
+ allow_tf32: true
74
+ report_to: null
75
+
76
+ # === Optional HuggingFace model variant ===
77
+ revision: null
78
+ variant: null
training/configs/gopro_2x_test.yaml ADDED
@@ -0,0 +1,78 @@
1
+ # === Required or overridden ===
2
+ base_dir: "/datasets/sai/gencam/blur2vid"
3
+ pretrained_model_name_or_path: "CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
4
+ video_root_dir: "datasets/GOPRO_7"
5
+ csv_path: "set-path-to-csv-file" # Replace with actual CSV path
6
+ output_dir: "cogvideox-gopro-2x-test"
7
+ tracker_name: "cogvideox-gopro-2x-test"
8
+
9
+
10
+ # === Data-related ===
11
+ stride_min: 1
12
+ stride_max: 3
13
+ hflip_p: 0.5
14
+ downscale_coef: 8
15
+ init_from_transformer: true
16
+ dataloader_num_workers: 32
17
+ val_split: "test"
18
+ dataset: "gopro2x"
19
+
20
+ # === Validation ===
21
+ num_inference_steps: 50
22
+ validation_prompt: ""
23
+ validation_video: "../resources/car.mp4:::../resources/ship.mp4"
24
+ validation_prompt_separator: ":::"
25
+ num_validation_videos: 1
26
+ validation_steps: 400
27
+ guidance_scale: 1.1
28
+ use_dynamic_cfg: false
29
+ just_validate: true
30
+ special_info: "just_one"
31
+
32
+ # === Training ===
33
+ seed: 42
34
+ mixed_precision: "bf16"
35
+ height: 720
36
+ width: 1280
37
+ fps: 8
38
+ max_num_frames: 17
39
+ train_batch_size: 4
40
+ num_train_epochs: 100
41
+ max_train_steps: null
42
+ checkpointing_steps: 400
43
+ checkpoints_total_limit: null
44
+ gradient_accumulation_steps: 1
45
+ gradient_checkpointing: true
46
+ learning_rate: 0.0001
47
+ scale_lr: false
48
+ lr_scheduler: "constant"
49
+ lr_warmup_steps: 250
50
+ lr_num_cycles: 1
51
+ lr_power: 1.0
52
+ enable_slicing: true
53
+ enable_tiling: true
54
+
55
+ # === Optimizer ===
56
+ optimizer: "adamw"
57
+ use_8bit_adam: false
58
+ adam_beta1: 0.9
59
+ adam_beta2: 0.95
60
+ prodigy_beta3: null
61
+ prodigy_decouple: false
62
+ adam_weight_decay: 0.0001
63
+ adam_epsilon: 0.0000001
64
+ max_grad_norm: 1.0
65
+ prodigy_use_bias_correction: false
66
+ prodigy_safeguard_warmup: false
67
+
68
+ # === Logging & Reporting ===
69
+ push_to_hub: false
70
+ hub_token: null
71
+ hub_model_id: null
72
+ logging_dir: "logs"
73
+ allow_tf32: true
74
+ report_to: null
75
+
76
+ # === Optional HuggingFace model variant ===
77
+ revision: null
78
+ variant: null
training/configs/gopro_test.yaml ADDED
@@ -0,0 +1,78 @@
1
+ # === Required or overridden ===
2
+ base_dir: "/datasets/sai/gencam/blur2vid"
3
+ pretrained_model_name_or_path: "CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
4
+ video_root_dir: "datasets/GOPRO_7"
5
+ csv_path: "set-path-to-csv-file" # Replace with actual CSV path
6
+ output_dir: "cogvideox-gopro-test"
7
+ tracker_name: "cogvideox-gopro-test"
8
+
9
+
10
+ # === Data-related ===
11
+ stride_min: 1
12
+ stride_max: 3
13
+ hflip_p: 0.5
14
+ downscale_coef: 8
15
+ init_from_transformer: true
16
+ dataloader_num_workers: 32
17
+ val_split: "test"
18
+ dataset: "gopro"
19
+
20
+ # === Validation ===
21
+ num_inference_steps: 50
22
+ validation_prompt: ""
23
+ validation_video: "../resources/car.mp4:::../resources/ship.mp4"
24
+ validation_prompt_separator: ":::"
25
+ num_validation_videos: 1
26
+ validation_steps: 400
27
+ guidance_scale: 1.1
28
+ use_dynamic_cfg: false
29
+ just_validate: true
30
+ special_info: "just_one"
31
+
32
+ # === Training ===
33
+ seed: 42
34
+ mixed_precision: "bf16"
35
+ height: 720
36
+ width: 1280
37
+ fps: 8
38
+ max_num_frames: 9
39
+ train_batch_size: 4
40
+ num_train_epochs: 500
41
+ max_train_steps: null
42
+ checkpointing_steps: 100
43
+ checkpoints_total_limit: null
44
+ gradient_accumulation_steps: 1
45
+ gradient_checkpointing: true
46
+ learning_rate: 0.0001
47
+ scale_lr: false
48
+ lr_scheduler: "constant"
49
+ lr_warmup_steps: 250
50
+ lr_num_cycles: 1
51
+ lr_power: 1.0
52
+ enable_slicing: true
53
+ enable_tiling: true
54
+
55
+ # === Optimizer ===
56
+ optimizer: "adamw"
57
+ use_8bit_adam: false
58
+ adam_beta1: 0.9
59
+ adam_beta2: 0.95
60
+ prodigy_beta3: null
61
+ prodigy_decouple: false
62
+ adam_weight_decay: 0.0001
63
+ adam_epsilon: 0.0000001
64
+ max_grad_norm: 1.0
65
+ prodigy_use_bias_correction: false
66
+ prodigy_safeguard_warmup: false
67
+
68
+ # === Logging & Reporting ===
69
+ push_to_hub: false
70
+ hub_token: null
71
+ hub_model_id: null
72
+ logging_dir: "logs"
73
+ allow_tf32: true
74
+ report_to: null
75
+
76
+ # === Optional HuggingFace model variant ===
77
+ revision: null
78
+ variant: null
training/configs/gopro_train.yaml ADDED
@@ -0,0 +1,77 @@
1
+ # === Required or overridden ===
2
+ base_dir: "/datasets/sai/gencam/blur2vid"
3
+ pretrained_model_name_or_path: "CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
4
+ video_root_dir: "datasets/GOPRO_7"
5
+ csv_path: "set-path-to-csv-file" # Replace with actual CSV path
6
+ output_dir: "cogvideox-gopro-train"
7
+ tracker_name: "cogvideox-gopro-train"
8
+
9
+ # === Data-related ===
10
+ stride_min: 1
11
+ stride_max: 3
12
+ hflip_p: 0.5
13
+ downscale_coef: 8
14
+ init_from_transformer: true
15
+ dataloader_num_workers: 2
16
+ val_split: "val"
17
+ dataset: "gopro"
18
+
19
+ # === Validation ===
20
+ num_inference_steps: 50
21
+ validation_prompt: ""
22
+ validation_video: "../resources/car.mp4:::../resources/ship.mp4"
23
+ validation_prompt_separator: ":::"
24
+ num_validation_videos: 1
25
+ validation_steps: 100
26
+ guidance_scale: 1.0
27
+ use_dynamic_cfg: false
28
+ just_validate: false
29
+ special_info: "just_one"
30
+
31
+ # === Training ===
32
+ seed: 42
33
+ mixed_precision: "bf16"
34
+ height: 720
35
+ width: 1280
36
+ fps: 8
37
+ max_num_frames: 9
38
+ train_batch_size: 4
39
+ num_train_epochs: 500
40
+ max_train_steps: null
41
+ checkpointing_steps: 100
42
+ checkpoints_total_limit: null
43
+ gradient_accumulation_steps: 1
44
+ gradient_checkpointing: true
45
+ learning_rate: 0.0001
46
+ scale_lr: false
47
+ lr_scheduler: "constant"
48
+ lr_warmup_steps: 250
49
+ lr_num_cycles: 1
50
+ lr_power: 1.0
51
+ enable_slicing: true
52
+ enable_tiling: true
53
+
54
+ # === Optimizer ===
55
+ optimizer: "adamw"
56
+ use_8bit_adam: false
57
+ adam_beta1: 0.9
58
+ adam_beta2: 0.95
59
+ prodigy_beta3: null
60
+ prodigy_decouple: false
61
+ adam_weight_decay: 0.0001
62
+ adam_epsilon: 0.0000001
63
+ max_grad_norm: 1.0
64
+ prodigy_use_bias_correction: false
65
+ prodigy_safeguard_warmup: false
66
+
67
+ # === Logging & Reporting ===
68
+ push_to_hub: false
69
+ hub_token: null
70
+ hub_model_id: null
71
+ logging_dir: "logs"
72
+ allow_tf32: true
73
+ report_to: null
74
+
75
+ # === Optional HuggingFace model variant ===
76
+ revision: null
77
+ variant: null
training/configs/outsidephotos.yaml ADDED
@@ -0,0 +1,76 @@
1
+ # === Required or overridden ===
2
+ base_dir: "/datasets/sai/gencam/blur2vid"
3
+ pretrained_model_name_or_path: "cogvideox/CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
4
+ video_root_dir: "datasets/my_motion_blurred_images"
5
+ csv_path: "set-path-to-csv-file" # Replace with actual CSV path
6
+ output_dir: "cogvideox-outsidephotos"
7
+ tracker_name: "cogvideox-outsidephotos"
8
+
9
+ # === Data-related ===
10
+ stride_min: 1
11
+ stride_max: 3
12
+ hflip_p: 0.5
13
+ downscale_coef: 8
14
+ init_from_transformer: true
15
+ dataloader_num_workers: 0
16
+ val_split: "test"
17
+ dataset: "outsidephotos"
18
+
19
+ # === Validation ===
20
+ num_inference_steps: 50
21
+ just_validate: true
22
+ validation_prompt: ""
23
+ validation_video: "../resources/car.mp4:::../resources/ship.mp4"
24
+ validation_prompt_separator: ":::"
25
+ num_validation_videos: 1
26
+ validation_steps: 100
27
+ guidance_scale: 1.1
28
+ use_dynamic_cfg: false
29
+
30
+ # === Training ===
31
+ seed: 42
32
+ mixed_precision: "bf16"
33
+ height: 720
34
+ width: 1280
35
+ fps: 8
36
+ max_num_frames: 17
37
+ train_batch_size: 1
38
+ num_train_epochs: 100
39
+ max_train_steps: null
40
+ checkpointing_steps: 100
41
+ checkpoints_total_limit: null
42
+ gradient_accumulation_steps: 1
43
+ gradient_checkpointing: true
44
+ learning_rate: 0.0001
45
+ scale_lr: false
46
+ lr_scheduler: "constant"
47
+ lr_warmup_steps: 250
48
+ lr_num_cycles: 1
49
+ lr_power: 1.0
50
+ enable_slicing: true
51
+ enable_tiling: true
52
+
53
+ # === Optimizer ===
54
+ optimizer: "adamw"
55
+ use_8bit_adam: false
56
+ adam_beta1: 0.9
57
+ adam_beta2: 0.95
58
+ prodigy_beta3: null
59
+ prodigy_decouple: false
60
+ adam_weight_decay: 0.0001
61
+ adam_epsilon: 0.0000001
62
+ max_grad_norm: 1.0
63
+ prodigy_use_bias_correction: false
64
+ prodigy_safeguard_warmup: false
65
+
66
+ # === Logging & Reporting ===
67
+ push_to_hub: false
68
+ hub_token: null
69
+ hub_model_id: null
70
+ logging_dir: "logs"
71
+ allow_tf32: true
72
+ report_to: null
73
+
74
+ # === Optional HuggingFace model variant ===
75
+ revision: null
76
+ variant: null
training/controlnet_datasets.py ADDED
@@ -0,0 +1,735 @@
1
+ import io
2
+ import os
3
+ import glob
4
+ from pathlib import Path
5
+ import pickle
6
+ import random
7
+ import time
8
+
9
+
10
+ import cv2
11
+ import torch
12
+ import numpy as np
13
+ import pandas as pd
14
+ import torchvision.transforms as transforms
15
+ from PIL import Image, ImageOps, ImageCms
16
+ from decord import VideoReader
17
+ from torch.utils.data.dataset import Dataset
18
+ from controlnet_aux import CannyDetector, HEDdetector
19
+ import torch.nn.functional as F
20
+ from helpers import generate_1x_sequence, generate_2x_sequence, generate_large_blur_sequence, generate_test_case
21
+
22
+
23
+ def unpack_mm_params(p):
24
+ if isinstance(p, (tuple, list)):
25
+ return p[0], p[1]
26
+ elif isinstance(p, (int, float)):
27
+ return p, p
28
+ raise Exception(f'Unknown input parameter type.\nParameter: {p}.\nType: {type(p)}')
29
+
30
+
31
+ def resize_for_crop(image, min_h, min_w):
32
+ img_h, img_w = image.shape[-2:]
33
+
34
+ if img_h >= min_h and img_w >= min_w:
35
+ coef = min(min_h / img_h, min_w / img_w)
36
+ elif img_h <= min_h and img_w <= min_w:
37
+ coef = max(min_h / img_h, min_w / img_w)
38
+ else:
39
+ coef = min_h / img_h if min_h > img_h else min_w / img_w
40
+
41
+ out_h, out_w = int(img_h * coef), int(img_w * coef)
42
+ resized_image = transforms.functional.resize(image, (out_h, out_w), antialias=True)
43
+ return resized_image
44
+
45
+
46
+
47
+ class BaseClass(Dataset):
48
+ def __init__(
49
+ self,
50
+ data_dir,
51
+ output_dir,
52
+ image_size=(320, 512),
53
+ hflip_p=0.5,
54
+ controlnet_type='canny',
55
+ split='train',
56
+ *args,
57
+ **kwargs
58
+ ):
59
+ self.split = split
60
+ self.height, self.width = unpack_mm_params(image_size)
61
+ self.data_dir = data_dir
62
+ self.output_dir = output_dir
63
+ self.hflip_p = hflip_p
64
+ self.image_size = image_size
65
+ self.length = 0
66
+
67
+ def __len__(self):
68
+ return self.length
69
+
70
+
71
+ def load_frames(self, frames):
72
+ # frames: numpy array of shape (N, H, W, C), 0–255
73
+ # → tensor of shape (N, C, H, W) as float
74
+ pixel_values = torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().float()
75
+ # normalize to [-1, 1]
76
+ pixel_values = pixel_values / 127.5 - 1.0
77
+ # resize to (self.height, self.width)
78
+ pixel_values = F.interpolate(
79
+ pixel_values,
80
+ size=(self.height, self.width),
81
+ mode="bilinear",
82
+ align_corners=False
83
+ )
84
+ return pixel_values
85
+
86
+ def get_batch(self, idx):
87
+ raise NotImplementedError('get_batch must be implemented by subclasses.')
88
+
89
+ def __getitem__(self, idx):
90
+ while True:
91
+ try:
92
+ video, caption, motion_blur = self.get_batch(idx)
93
+ break
94
+ except Exception as e:
95
+ print(e)
96
+ idx = random.randint(0, self.length - 1)
97
+
98
+ video, = [
99
+ resize_for_crop(x, self.height, self.width) for x in [video]
100
+ ]
101
+ video, = [
102
+ transforms.functional.center_crop(x, (self.height, self.width)) for x in [video]
103
+ ]
104
+ data = {
105
+ 'video': video,
106
+ 'caption': caption,
107
+ }
108
+ return data
109
+
110
+ def load_as_srgb(path):
111
+ img = Image.open(path)
112
+ img = ImageOps.exif_transpose(img)
113
+
114
+ if 'icc_profile' in img.info:
115
+ icc = img.info['icc_profile']
116
+ src_profile = ImageCms.ImageCmsProfile(io.BytesIO(icc))
117
+ dst_profile = ImageCms.createProfile("sRGB")
118
+ img = ImageCms.profileToProfile(img, src_profile, dst_profile, outputMode='RGB')
119
+ else:
120
+ img = img.convert("RGB") # Assume sRGB
121
+ return img
122
+
123
+ class GoProMotionBlurDataset(BaseClass): # 7-frame GoPro dataset
124
+ def __init__(self,
125
+ *args, **kwargs):
126
+ super().__init__(*args, **kwargs)
127
+ # Set blur and sharp directories based on split
128
+ if self.split == 'train':
129
+ self.blur_root = os.path.join(self.data_dir, 'train', 'blur')
130
+ self.sharp_root = os.path.join(self.data_dir, 'train', 'sharp')
131
+ elif self.split in ['val', 'test']:
132
+ self.blur_root = os.path.join(self.data_dir, 'test', 'blur')
133
+ self.sharp_root = os.path.join(self.data_dir, 'test', 'sharp')
134
+ else:
135
+ raise ValueError(f"Unsupported split: {self.split}")
136
+
137
+ # Collect all blurred image paths
138
+ pattern = os.path.join(self.blur_root, '*', '*.png')
139
+
140
+ self.blur_paths = sorted(glob.glob(pattern))
141
+
142
+ if self.split == 'val':
143
+ # Optional: limit validation subset
144
+ self.blur_paths = self.blur_paths[:5]
145
+
146
+ filtered_blur_paths = []
147
+ for path in self.blur_paths:
148
+ output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
149
+ full_output_path = Path(output_deblurred_dir, *path.split('/')[-2:]).with_suffix(".mp4")
150
+ if not os.path.exists(full_output_path):
151
+ filtered_blur_paths.append(path)
152
+
153
+ # Window and padding parameters
154
+ self.window_size = 7 # original number of sharp frames
155
+ self.pad = 2 # number of times to repeat last frame
156
+ self.output_length = self.window_size + self.pad
157
+ self.half_window = self.window_size // 2
158
+ self.length = len(self.blur_paths)
159
+
160
+ # Normalized input interval: always [-0.5, 0.5]
161
+ self.input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)
162
+
163
+ # Precompute normalized output intervals: first for window_size frames, then pad duplicates
164
+ step = 1.0 / (self.window_size - 1)
165
+ # intervals for the original 7 frames
166
+ window_intervals = []
167
+ for i in range(self.window_size):
168
+ start = -0.5 + i * step
169
+ if i < self.window_size - 1:
170
+ end = -0.5 + (i + 1) * step
171
+ else:
172
+ end = 0.5
173
+ window_intervals.append([start, end])
174
+ # append the last interval pad times
175
+ intervals = window_intervals + [window_intervals[-1]] * self.pad
176
+ self.output_interval = torch.tensor(intervals, dtype=torch.float)
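+ # Example: window_size=7 gives step = 1/6, so the intervals are [-0.5, -1/3], [-1/3, -1/6], ...,
+ # [1/3, 0.5] and finally the degenerate [0.5, 0.5] for the last frame; that last interval is
+ # repeated pad=2 more times, giving output_length = 9 entries.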
177
+
178
+ def __len__(self):
179
+ return self.length
180
+
181
+ def __getitem__(self, idx):
182
+ # Path to the blurred (center) frame
183
+ blur_path = self.blur_paths[idx]
184
+ seq_name = os.path.basename(os.path.dirname(blur_path))
185
+ frame_name = os.path.basename(blur_path)
186
+ center_idx = int(os.path.splitext(frame_name)[0])
187
+
188
+ # Compute sharp frame range [center-half, center+half]
189
+ start_idx = center_idx - self.half_window
190
+ end_idx = center_idx + self.half_window
191
+
192
+ # Load sharp frames
193
+ sharp_dir = os.path.join(self.sharp_root, seq_name)
194
+ frames = []
195
+ for i in range(start_idx, end_idx + 1):
196
+ sharp_filename = f"{i:06d}.png"
197
+ sharp_path = os.path.join(sharp_dir, sharp_filename)
198
+ img = Image.open(sharp_path).convert('RGB')
199
+ frames.append(img)
200
+
201
+ # Repeat last sharp frame so total frames == output_length
202
+ while len(frames) < self.output_length:
203
+ frames.append(frames[-1])
204
+
205
+ # Load blurred image
206
+ blur_img = Image.open(blur_path).convert('RGB')
207
+
208
+ # Convert to pixel values via BaseClass loader
209
+ video = self.load_frames(np.array(frames)) # shape: (output_length, H, W, C)
210
+ blur_input = self.load_frames(np.expand_dims(np.array(blur_img), 0)) # shape: (1, H, W, C)
211
+ end_time = time.time()
212
+ data = {
213
+ 'file_name': os.path.join(seq_name, frame_name),
214
+ 'blur_img': blur_input,
215
+ 'video': video,
216
+ "caption": "",
217
+ 'motion_blur_amount': torch.tensor(self.half_window, dtype=torch.long),
218
+ 'input_interval': self.input_interval,
219
+ 'output_interval': self.output_interval,
220
+ "num_frames": self.window_size,
221
+ "mode": "1x",
222
+ }
223
+ return data
224
+
225
+
226
+ class OutsidePhotosDataset(BaseClass):
227
+ def __init__(self, *args, **kwargs):
228
+ super().__init__(*args, **kwargs)
229
+ self.image_paths = sorted(glob.glob(os.path.join(self.data_dir, '**', '*.*'), recursive=True))
230
+
231
+ INTERVALS = [
232
+ {"in_start": 0, "in_end": 16, "out_start": 0, "out_end": 16, "center": 8, "window_size": 16, "mode": "1x", "fps": 240},
233
+ {"in_start": 4, "in_end": 12, "out_start": 0, "out_end": 16, "center": 8, "window_size": 16, "mode": "2x", "fps": 240},]
234
+ #other modes commented out for faster processing
235
+ #{"in_start": 0, "in_end": 4, "out_start": 0, "out_end": 4, "center": 2, "window_size": 4, "mode": "1x", "fps": 240},
236
+ #{"in_start": 0, "in_end": 8, "out_start": 0, "out_end": 8, "center": 4, "window_size": 8, "mode": "1x", "fps": 240},
237
+ #{"in_start": 0, "in_end": 12, "out_start": 0, "out_end": 12, "center": 6, "window_size": 12, "mode": "1x", "fps": 240},
238
+ #{"in_start": 0, "in_end": 32, "out_start": 0, "out_end": 32, "center": 12, "window_size": 32, "mode": "lb", "fps": 120}
239
+ #{"in_start": 0, "in_end": 48, "out_start": 0, "out_end": 48, "center": 24, "window_size": 48, "mode": "lb", "fps": 80}
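+ # Each INTERVALS entry parameterises one generate_test_case call: the input (blur) window
+ # [in_start, in_end], the output frame range [out_start, out_end], the centre frame, the
+ # window size, the blur mode and the source fps.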
240
+
241
+
242
+ self.cleaned_intervals = []
243
+ for image_path in self.image_paths:
244
+ for interval in INTERVALS:
245
+ #create a copy of the interval dictionary
246
+ i = interval.copy()
247
+ #add the image path to the interval dictionary
248
+ i['video_name'] = image_path
249
+ video_name = i['video_name']
250
+ mode = i['mode']
251
+
252
+ vid_name_w_extension = os.path.relpath(video_name, self.data_dir).split('.')[0] # "frame_00000"
253
+ output_name = (
254
+ f"{vid_name_w_extension}_{mode}.mp4"
255
+ )
256
+
257
+ full_output_path = os.path.join("/datasets/sai/gencam/cogvideox/training/cogvideox-outsidephotos/deblurred", output_name) #THIS IS A HACK - YOU NEED TO UPDATE THIS TO YOUR OUTPUT DIRECTORY
258
+
259
+ # Keep only if output doesn't exist
260
+ if not os.path.exists(full_output_path):
261
+ self.cleaned_intervals.append(i)
262
+
263
+
264
+ self.length = len(self.cleaned_intervals)
265
+
266
+ def __len__(self):
267
+ return self.length
268
+
269
+ def __getitem__(self, idx):
270
+
271
+ interval = self.cleaned_intervals[idx]
272
+
273
+ in_start = interval['in_start']
274
+ in_end = interval['in_end']
275
+ out_start = interval['out_start']
276
+ out_end = interval['out_end']
277
+ center = interval['center']
278
+ window = interval['window_size']
279
+ mode = interval['mode']
280
+ fps = interval['fps']
281
+
282
+
283
+ image_path = interval['video_name']
284
+ blur_img_original = load_as_srgb(image_path)
285
+ H,W = blur_img_original.size
286
+
287
+ frame_paths = []
288
+ frame_paths = ["../assets/dummy_image.png" for _ in range(window)] #any random path replicated
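+ # real photos have no ground-truth frame sequence, so a dummy image is replicated to stand in for it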
289
+
290
+ # generate test case
291
+ _, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
292
+ frame_paths=frame_paths, window_max=window, in_start=in_start, in_end=in_end, out_start=out_start,out_end=out_end, center=center, mode=mode, fps=fps
293
+ )
294
+ file_name = image_path
295
+
296
+ # Get base directory and frame prefix
297
+ relative_file_name = os.path.relpath(file_name, self.data_dir)
298
+ base_dir = os.path.dirname(relative_file_name)
299
+ frame_stem = os.path.splitext(os.path.basename(file_name))[0] # "frame_00000"
300
+ # Build new filename
301
+ new_filename = (
302
+ f"{frame_stem}_{mode}.png"
303
+ )
304
+
305
+ blur_img = blur_img_original.resize((self.image_size[1], self.image_size[0])) # PIL resize expects (width, height)
306
+
307
+ # Final path
308
+ relative_file_name = os.path.join(base_dir, new_filename)
309
+
310
+
311
+ blur_input = self.load_frames(np.expand_dims(blur_img, 0).copy())
312
+ # seq_frames is list of frames; stack along time dim
313
+ video = self.load_frames(np.stack(seq_frames, axis=0))
314
+
315
+
316
+ data = {
317
+ 'file_name': relative_file_name,
318
+ "original_size": (H, W),
319
+ 'blur_img': blur_input,
320
+ 'video': video,
321
+ 'caption': "",
322
+ 'input_interval': inp_int,
323
+ 'output_interval': out_int,
324
+ "num_frames": num_frames,
325
+ }
326
+ return data
327
+
328
+ class FullMotionBlurDataset(BaseClass):
329
+ """
330
+ A dataset that randomly selects among 1×, 2×, or large-blur modes per sample.
331
+ Uses category-specific <split>_list.txt files under each subfolder of FullDataset to assemble sequences.
332
+ In 'test' split, it instead loads precomputed intervals from intervals_test.pkl and uses generate_test_case.
333
+ """
334
+ def __init__(self, *args, **kwargs):
335
+ super().__init__(*args, **kwargs)
336
+ self.seq_dirs = []
337
+
338
+ # TEST split: load fixed intervals early
339
+ if self.split == 'test':
340
+ pkl_path = os.path.join(self.data_dir, 'intervals_test.pkl')
341
+ with open(pkl_path, 'rb') as f:
342
+ self.test_intervals = pickle.load(f)
343
+ assert self.test_intervals, f"No test intervals found in {pkl_path}"
344
+
345
+ cleaned_intervals = []
346
+ for interval in self.test_intervals:
347
+ # Extract interval values
348
+ in_start = interval['in_start']
349
+ in_end = interval['in_end']
350
+ out_start = interval['out_start']
351
+ out_end = interval['out_end']
352
+ center = interval['center']
353
+ window = interval['window_size']
354
+ mode = interval['mode']
355
+ fps = interval['fps']
356
+ category, seq = interval['video_name'].split('/')
357
+ seq_dir = os.path.join(self.data_dir, category, 'lower_fps_frames', seq)
358
+ frame_paths = sorted(glob.glob(os.path.join(seq_dir, '*.png'))) # e.g. ".../lower_fps_frames/720p_240fps_1/frame_00247.png"
359
+ rel_path = os.path.relpath(frame_paths[center], self.data_dir)
360
+ rel_path = os.path.splitext(rel_path)[0] # remove the file extension
361
+
362
+ output_name = (
363
+ f"{rel_path}_"
364
+ f"in{in_start:04d}_ie{in_end:04d}_"
365
+ f"os{out_start:04d}_oe{out_end:04d}_"
366
+ f"ctr{center:04d}_win{window:04d}_"
367
+ f"fps{fps:04d}_{mode}.mp4"
368
+ )
369
+ output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
370
+ full_output_path = os.path.join(output_deblurred_dir, output_name)
371
+
372
+ # Keep only if output doesn't exist
373
+ if not os.path.exists(full_output_path):
374
+ cleaned_intervals.append(interval)
375
+ print("Len of test intervals after cleaning: ", len(cleaned_intervals))
376
+ print("Len of test intervals before cleaning: ", len(self.test_intervals))
377
+ self.test_intervals = cleaned_intervals
378
+
379
+
380
+ # TRAIN/VAL: build seq_dirs from each category's list or fallback
381
+ list_file = 'train_list.txt' if self.split == 'train' else 'test_list.txt'
382
+ for category in sorted(os.listdir(self.data_dir)):
383
+ cat_dir = os.path.join(self.data_dir, category)
384
+ if not os.path.isdir(cat_dir):
385
+ continue
386
+ list_path = os.path.join(cat_dir, list_file)
387
+ if os.path.isfile(list_path):
388
+ with open(list_path, 'r') as f:
389
+ for line in f:
390
+ rel = line.strip()
391
+ if not rel:
392
+ continue
393
+ seq_dir = os.path.join(self.data_dir, rel)
394
+ if os.path.isdir(seq_dir):
395
+ self.seq_dirs.append(seq_dir)
396
+ else:
397
+ fps_root = os.path.join(cat_dir, 'lower_fps_frames')
398
+ if os.path.isdir(fps_root):
399
+ for seq in sorted(os.listdir(fps_root)):
400
+ seq_path = os.path.join(fps_root, seq)
401
+ if os.path.isdir(seq_path):
402
+ self.seq_dirs.append(seq_path)
403
+
404
+ if self.split == 'val':
405
+ self.seq_dirs = self.seq_dirs[:5]
406
+ if self.split == 'train':
407
+ self.seq_dirs *= 10
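+ # duplicate the sequence list so a single epoch draws more samples; each __getitem__ still picks a random mode and stride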
408
+
409
+ assert self.seq_dirs, \
410
+ f"No sequences found for split '{self.split}' in {self.data_dir}"
411
+
412
+ def __len__(self):
413
+ return len(self.test_intervals) if self.split == 'test' else len(self.seq_dirs)
414
+
415
+ def __getitem__(self, idx):
416
+ # Prepare base items
417
+ if self.split == 'test':
418
+ interval = self.test_intervals[idx]
419
+ category, seq = interval['video_name'].split('/')
420
+ seq_dir = os.path.join(self.data_dir, category, 'lower_fps_frames', seq)
421
+ frame_paths = sorted(glob.glob(os.path.join(seq_dir, '*.png')))
422
+
423
+ in_start = interval['in_start']
424
+ in_end = interval['in_end']
425
+ out_start = interval['out_start']
426
+ out_end = interval['out_end']
427
+ center = interval['center']
428
+ window = interval['window_size']
429
+ mode = interval['mode']
430
+ fps = interval['fps']
431
+
432
+ # generate test case
433
+ blur_img, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
434
+ frame_paths=frame_paths, window_max=window, in_start=in_start, in_end=in_end, out_start=out_start,out_end=out_end, center=center, mode=mode, fps=fps
435
+ )
436
+ file_name = frame_paths[center]
437
+
438
+ else:
439
+ seq_dir = self.seq_dirs[idx]
440
+ frame_paths = sorted(glob.glob(os.path.join(seq_dir, '*.png')))
441
+ mode = random.choice(['1x', '2x', 'large_blur'])
442
+
443
+ if mode == '1x' or len(frame_paths) < 50:
444
+ base_rate = random.choice([1, 2])
445
+ blur_img, seq_frames, inp_int, out_int, _ = generate_1x_sequence(
446
+ frame_paths, window_max=16, output_len=17, base_rate=base_rate
447
+ )
448
+ elif mode == '2x':
449
+ base_rate = random.choice([1, 2])
450
+ blur_img, seq_frames, inp_int, out_int, _ = generate_2x_sequence(
451
+ frame_paths, window_max=16, output_len=17, base_rate=base_rate
452
+ )
453
+ else:
454
+ max_base = min((len(frame_paths) - 1) // 17, 3)
455
+ base_rate = random.randint(1, max_base)
456
+ blur_img, seq_frames, inp_int, out_int, _ = generate_large_blur_sequence(
457
+ frame_paths, window_max=16, output_len=17, base_rate=base_rate
458
+ )
459
+ file_name = frame_paths[0]
460
+ num_frames = 16
461
+
462
+ # blur_img is a single frame; wrap in batch dim
463
+ blur_input = self.load_frames(np.expand_dims(blur_img, 0))
464
+ # seq_frames is list of frames; stack along time dim
465
+ video = self.load_frames(np.stack(seq_frames, axis=0))
466
+
467
+
468
+ relative_file_name = os.path.relpath(file_name, self.data_dir)
469
+
470
+ if self.split == 'test':
471
+ # Get base directory and frame prefix
472
+ base_dir = os.path.dirname(relative_file_name)
473
+ frame_stem = os.path.splitext(os.path.basename(relative_file_name))[0] # "frame_00000"
474
+
475
+ # Build new filename
476
+ new_filename = (
477
+ f"{frame_stem}_"
478
+ f"in{in_start:04d}_ie{in_end:04d}_"
479
+ f"os{out_start:04d}_oe{out_end:04d}_"
480
+ f"ctr{center:04d}_win{window:04d}_"
481
+ f"fps{fps:04d}_{mode}.png"
482
+ )
483
+
484
+ # Final path
485
+ relative_file_name = os.path.join(base_dir, new_filename)
486
+
487
+ data = {
488
+ 'file_name': relative_file_name,
489
+ 'blur_img': blur_input,
490
+ 'num_frames': num_frames,
491
+ 'video': video,
492
+ 'caption': "",
493
+ 'mode': mode,
494
+ 'input_interval': inp_int,
495
+ 'output_interval': out_int,
496
+ }
497
+ if self.split == 'test':
498
+ high_fps_video = self.load_frames(np.stack(high_fps_video, axis=0))
499
+ data['high_fps_video'] = high_fps_video
500
+ return data
501
+
502
+ class GoPro2xMotionBlurDataset(BaseClass):
503
+ def __init__(self,
504
+ *args, **kwargs):
505
+ super().__init__(*args, **kwargs)
506
+ # Set blur and sharp directories based on split
507
+ if self.split == 'train':
508
+ self.blur_root = os.path.join(self.data_dir, 'train', 'blur')
509
+ self.sharp_root = os.path.join(self.data_dir, 'train', 'sharp')
510
+ elif self.split in ['val', 'test']:
511
+ self.blur_root = os.path.join(self.data_dir, 'test', 'blur')
512
+ self.sharp_root = os.path.join(self.data_dir, 'test', 'sharp')
513
+ else:
514
+ raise ValueError(f"Unsupported split: {self.split}")
515
+
516
+ # Collect all blurred image paths
517
+ pattern = os.path.join(self.blur_root, '*', '*.png')
518
+
519
+ def get_sharp_paths(blur_paths):
520
+ sharp_paths = []
521
+ for blur_path in blur_paths:
522
+ base_dir = blur_path.replace('/blur/', '/sharp/')
523
+ frame_num = int(os.path.basename(blur_path).split('.')[0])
524
+ dir_path = os.path.dirname(base_dir)
525
+ sequence = [
526
+ os.path.join(dir_path, f"{frame_num + offset:06d}.png")
527
+ for offset in range(-6, 7)
528
+ ]
529
+ if all(os.path.exists(path) for path in sequence):
530
+ sharp_paths.append(sequence)
531
+ return sharp_paths
532
+
533
+
534
+
535
+
536
+ self.blur_paths = sorted(glob.glob(pattern))
537
+ filtered_blur_paths = []
538
+ for path in self.blur_paths:
539
+ output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
540
+ full_output_path = Path(output_deblurred_dir, *path.split('/')[-2:]).with_suffix(".mp4")
541
+ if not os.path.exists(full_output_path):
542
+ filtered_blur_paths.append(path)
543
+ self.blur_paths = filtered_blur_paths
544
+
545
+ self.sharp_paths = get_sharp_paths(self.blur_paths)
546
+ if self.split == 'val':
547
+ # Optional: limit validation subset
548
+ self.sharp_paths = self.sharp_paths[:5]
549
+ self.length = len(self.sharp_paths)
550
+
551
+ def __len__(self):
552
+ return self.length
553
+
554
+ def __getitem__(self, idx):
555
+ # Path to the blurred (center) frame
556
+ sharp_path = self.sharp_paths[idx]
557
+
558
+
559
+ # Load sharp frames
560
+ blur_img, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
561
+ frame_paths=sharp_path, window_max=13, in_start=3, in_end=10, out_start=0,out_end=13, center=6, mode="2x", fps=240
562
+ )
563
+
564
+ # Convert to pixel values via BaseClass loader
565
+ video = self.load_frames(np.array(seq_frames)) # shape: (output_length, H, W, C)
566
+ blur_input = self.load_frames(np.expand_dims(np.array(blur_img), 0)) # shape: (1, H, W, C)
567
+ last_two_parts_of_path = os.path.join(*sharp_path[6].split(os.sep)[-2:])
568
+ #print(f"Time taken to load and process data: {end_time - start_time:.2f} seconds")
569
+ data = {
570
+ 'file_name': last_two_parts_of_path,
571
+ 'blur_img': blur_input,
572
+ 'video': video,
573
+ "caption": "",
574
+ 'input_interval': inp_int,
575
+ 'output_interval': out_int,
576
+ "num_frames": num_frames,
577
+ "mode": "2x",
578
+ }
579
+ return data
580
+
581
+ class BAISTDataset(BaseClass):
582
+ def __init__(self, *args, **kwargs):
583
+ super().__init__(*args, **kwargs)
584
+
585
+
586
+ test_folders = {
587
+ "gWA_sBM_c01_d26_mWA0_ch06_cropped_32X": None,
588
+ "gBR_sBM_c01_d05_mBR0_ch01_cropped_32X": None,
589
+ "gMH_sBM_c01_d22_mMH0_ch04_cropped_32X": None,
590
+ "gHO_sBM_c01_d20_mHO0_ch05_cropped_32X": None,
591
+ "gMH_sBM_c01_d22_mMH0_ch08_cropped_32X": None,
592
+ "gWA_sBM_c01_d26_mWA0_ch02_cropped_32X": None,
593
+ "gJS_sBM_c01_d02_mJS0_ch08_cropped_32X": None,
594
+ "gHO_sBM_c01_d20_mHO0_ch07_cropped_32X": None,
595
+ "gHO_sBM_c01_d20_mHO0_ch06_cropped_32X": None,
596
+ "gBR_sBM_c01_d05_mBR0_ch03_cropped_32X": None,
597
+ "gBR_sBM_c01_d05_mBR0_ch05_cropped_32X": None,
598
+ "gHO_sBM_c01_d20_mHO0_ch02_cropped_32X": None,
599
+ "gHO_sBM_c01_d20_mHO0_ch03_cropped_32X": None,
600
+ "gHO_sBM_c01_d20_mHO0_ch09_cropped_32X": None,
601
+ "gMH_sBM_c01_d22_mMH0_ch10_cropped_32X": None,
602
+ "gWA_sBM_c01_d26_mWA0_ch10_cropped_32X": None,
603
+ "gBR_sBM_c01_d05_mBR0_ch06_cropped_32X": None,
604
+ "gHO_sBM_c01_d20_mHO0_ch08_cropped_32X": None,
605
+ "gMH_sBM_c01_d22_mMH0_ch06_cropped_32X": None,
606
+ "gHO_sBM_c01_d20_mHO0_ch10_cropped_32X": None,
607
+ "gMH_sBM_c01_d22_mMH0_ch09_cropped_32X": None,
608
+ "gMH_sBM_c01_d22_mMH0_ch02_cropped_32X": None,
609
+ "gBR_sBM_c01_d05_mBR0_ch04_cropped_32X": None,
610
+ "gPO_sBM_c01_d10_mPO0_ch09_cropped_32X": None,
611
+ "gMH_sBM_c01_d22_mMH0_ch01_cropped_32X": None,
612
+ "gMH_sBM_c01_d22_mMH0_ch07_cropped_32X": None,
613
+ "gMH_sBM_c01_d22_mMH0_ch03_cropped_32X": None,
614
+ "gHO_sBM_c01_d20_mHO0_ch04_cropped_32X": None,
615
+ "gBR_sBM_c01_d05_mBR0_ch02_cropped_32X": None,
616
+ "gHO_sBM_c01_d20_mHO0_ch01_cropped_32X": None,
617
+ "gMH_sBM_c01_d22_mMH0_ch05_cropped_32X": None,
618
+ "gPO_sBM_c01_d10_mPO0_ch10_cropped_32X": None,
619
+ }
620
+
621
+ def collect_blur_images(root_dir, allowed_folders, skip_start=40, skip_end=40):
622
+ blur_image_paths = []
623
+
624
+ for dirpath, dirnames, filenames in os.walk(root_dir):
625
+ if os.path.basename(dirpath) == "blur":
626
+ parent_folder = os.path.basename(os.path.dirname(dirpath))
627
+ if (self.split in ["test", "val"] and parent_folder in test_folders) or (self.split == "train" and parent_folder not in test_folders):
628
+ # Filter and sort valid image filenames
629
+ valid_files = [
630
+ f for f in filenames
631
+ if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')) and os.path.splitext(f)[0].isdigit()
632
+ ]
633
+ valid_files.sort(key=lambda x: int(os.path.splitext(x)[0]))
634
+
635
+ # Skip first and last N files
636
+ middle_files = valid_files[skip_start:len(valid_files) - skip_end]
637
+
638
+ for f in middle_files:
639
+ from pathlib import Path
640
+ full_path = Path(os.path.join(dirpath, f))
641
+ output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
642
+ full_output_path = Path(output_deblurred_dir, *full_path.parts[-3:]).with_suffix(".mp4")
643
+ if not os.path.exists(full_output_path) or self.split in ["train", "val"]:
644
+ blur_image_paths.append(os.path.join(dirpath, f))
645
+
646
+ return blur_image_paths
647
+
648
+
649
+
650
+ self.image_paths = collect_blur_images(self.data_dir, test_folders)
651
+ # drop image paths that have no matching bounding-box annotation (.pkl)
652
+ self.image_paths = [path for path in self.image_paths if os.path.exists(path.replace("blur", "blur_anno").replace(".png", ".pkl"))]
653
+
654
+ filtered_image_paths = []
655
+ for blur_path in self.image_paths:
656
+ base_dir = blur_path.replace('/blur/', '/sharp/').replace('.png', '')
657
+ sharp_paths = [f"{base_dir}_{i:03d}.png" for i in range(7)]
658
+ if all(os.path.exists(p) for p in sharp_paths):
659
+ filtered_image_paths.append(blur_path)
660
+
661
+ self.image_paths = filtered_image_paths
662
+
663
+ if self.split == 'val':
664
+ # Optional: limit validation subset
665
+ self.image_paths = self.image_paths[:4]
666
+ self.length = len(self.image_paths)
667
+
668
+ def __len__(self):
669
+ return self.length
670
+
671
+
672
+ def __getitem__(self, idx):
673
+ image_path = self.image_paths[idx]
674
+ blur_img_original = load_as_srgb(image_path)
675
+
676
+ bbx_path = image_path.replace("blur", "blur_anno").replace(".png", ".pkl")
677
+
678
+ # load the bounding-box annotation
679
+ bbx = np.load(bbx_path, allow_pickle=True)['bbox'][0:4]
680
+ # Final crop box
681
+ # turn crop_box into a tuple
682
+ W,H = blur_img_original.size
683
+ blur_img = blur_img_original.resize((self.image_size[1], self.image_size[0]), resample=Image.BILINEAR)
684
+
685
+ # note: PIL image size is (width, height)
686
+ blur_np = np.array([blur_img])
687
+
688
+ base_dir = os.path.dirname(os.path.dirname(image_path)) # strip /blur
689
+ filename = os.path.splitext(os.path.basename(image_path))[0] # '00000000'
690
+ sharp_dir = os.path.join(base_dir, "sharp")
691
+
692
+ frame_paths = [
693
+ os.path.join(sharp_dir, f"{filename}_{i:03d}.png")
694
+ for i in range(7)
695
+ ]
696
+
697
+ _, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
698
+ frame_paths=frame_paths, window_max=7, in_start=0, in_end=7, out_start=0,out_end=7, center=3, mode="1x", fps=240
699
+ )
700
+
701
+ pixel_values = self.load_frames(np.stack(seq_frames, axis=0))
702
+ blur_pixel_values = self.load_frames(blur_np)
703
+
704
+ relative_file_name = os.path.relpath(image_path, self.data_dir)
705
+
706
+ out_bbx = bbx.copy()
707
+
708
+ scale_x = blur_pixel_values.shape[3]/W
709
+ scale_y = blur_pixel_values.shape[2]/H
710
+ #scale the bbx
711
+ out_bbx[0] = int(out_bbx[0] * scale_x)
712
+ out_bbx[1] = int(out_bbx[1] * scale_y)
713
+ out_bbx[2] = int(out_bbx[2] * scale_x)
714
+ out_bbx[3] = int(out_bbx[3] * scale_y)
715
+
716
+ out_bbx = torch.tensor(out_bbx, dtype=torch.uint32)
717
+
718
+ #crop image using the bbx
719
+ blur_img_npy = np.array(blur_img)
720
+ out_bbx_npy = out_bbx.numpy().astype(np.uint32)
721
+ blur_img_npy = blur_img_npy[out_bbx_npy[1]:out_bbx_npy[3], out_bbx_npy[0]:out_bbx_npy[2], :]
722
+
723
+ data = {
724
+ 'file_name': relative_file_name,
725
+ 'blur_img': blur_pixel_values,
726
+ 'video': pixel_values,
727
+ 'bbx': out_bbx,
728
+ 'caption': "",
729
+ 'input_interval': inp_int,
730
+ 'output_interval': out_int,
731
+ "num_frames": num_frames,
732
+ 'mode': "1x",
733
+ }
734
+ return data
735
+
training/helpers.py ADDED
@@ -0,0 +1,533 @@
1
+ import torch
2
+ import math
3
+ import random
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ def random_insert_latent_frame(
8
+ image_latent: torch.Tensor,
9
+ noisy_model_input: torch.Tensor,
10
+ target_latents: torch.Tensor,
11
+ input_intervals: torch.Tensor,
12
+ output_intervals: torch.Tensor,
13
+ special_info
14
+ ):
15
+ """
16
+ Inserts latent frames into noisy input, pads targets, and builds flattened intervals with flags.
17
+
18
+ Args:
19
+ image_latent: [B, latent_count, C, H, W]
20
+ noisy_model_input:[B, F, C, H, W]
21
+ target_latents: [B, F, C, H, W]
22
+ input_intervals: [B, N, frames_per_latent, L]
23
+ output_intervals: [B, M, frames_per_latent, L]
24
+
25
+ For each sample randomly choose:
26
+ Mode A (50%):
27
+ - Insert two image_latent frames at start of noisy input and targets.
28
+ - Pad target_latents by prepending two sentinel frames (large values that are masked out of the loss).
29
+ - Pad input_intervals by repeating its last group once.
30
+ Mode B (50%):
31
+ - Insert one image_latent frame at start and repeat last noisy frame at end.
32
+ - Pad target_latents by prepending one zero frame and appending the last target frame.
33
+ - Pad output_intervals by repeating its last group once.
34
+
35
+ After padding intervals, flatten each group from [frames_per_latent, L] to [frames_per_latent * L].
36
+ (The optional 4-element input/output flag append is currently disabled in this implementation.)
37
+
38
+ Returns:
39
+ outputs: Tensor [B, F', C, H, W], where F' = F+1 if special_info == "just_one" else F+2
40
+ new_targets: Tensor [B, F', C, H, W]
41
+ masks: Tensor [B, F'] bool mask of inserted latent frames
42
+ intervals: Tensor [B, N+M, fpl * L]
43
+ """
44
+ B, F, C, H, W = noisy_model_input.shape
45
+ _, N, fpl, L = input_intervals.shape
46
+ _, M, _, _ = output_intervals.shape
47
+ device = noisy_model_input.device
48
+
49
+ new_F = F + 1 if special_info == "just_one" else F + 2
50
+ outputs = torch.empty((B, new_F, C, H, W), device=device)
51
+ masks = torch.zeros((B, new_F), dtype=torch.bool, device=device)
52
+ combined_groups = N + M #+ 1
53
+ feature_len = fpl * L
54
+ # intervals = torch.empty((B, combined_groups, feature_len + 4), device=device,
55
+ # dtype=input_intervals.dtype)
56
+ intervals = torch.empty((B, combined_groups, feature_len), device=device,
57
+ dtype=input_intervals.dtype)
58
+ new_targets = torch.empty((B, new_F, C, H, W), device=device,
59
+ dtype=target_latents.dtype)
60
+
61
+ for b in range(B):
62
+ latent = image_latent[b, 0]
63
+ frames = noisy_model_input[b]
64
+ tgt = target_latents[b]
65
+
66
+ limit = 10 if special_info == "use_a" else 0.5
67
+ if special_info == "just_one": #ALWAYS_MODE_A
68
+ # "just_one": a single latent insert, sentinel-prefixed target
69
+ outputs[b, 0] = latent
70
+ masks[b, :1] = True
71
+ outputs[b, 1:] = frames
72
+
73
+ # pad target: one sentinel frame (a large value) that is ignored in the loss
74
+ large_number = torch.ones_like(tgt[0])*10000
75
+ new_targets[b, 0] = large_number
76
+ new_targets[b, 1:] = tgt
77
+
78
+ # pad intervals: input + replicated last input group
79
+ #pad_group = input_intervals[b, -1:].clone()
80
+ in_groups = input_intervals[b] #torch.cat([input_intervals[b], pad_group], dim=0)
81
+ out_groups = output_intervals[b]
82
+ elif random.random() < limit: #ALWAYS_MODE_A
83
+ # Mode A: two latent inserts, sentinel-prefixed targets
84
+ outputs[b, 0] = latent
85
+ outputs[b, 1] = latent
86
+ masks[b, :2] = True
87
+ outputs[b, 2:] = frames
88
+
89
+ # pad targets: two sentinel frames (large values) that are ignored in the loss
90
+ large_number = torch.ones_like(tgt[0])*10000
91
+ new_targets[b, 0] = large_number
92
+ new_targets[b, 1] = large_number
93
+ new_targets[b, 2:] = tgt
94
+
95
+ # pad intervals: input + replicated last input group
96
+ pad_group = input_intervals[b, -1:].clone()
97
+ in_groups = torch.cat([input_intervals[b], pad_group], dim=0)
98
+ out_groups = output_intervals[b]
99
+ else:
100
+ # Mode B: one latent insert & last-frame repeat, zero-prefixed / last-appended targets
101
+ outputs[b, 0] = latent
102
+ masks[b, 0] = True
103
+ outputs[b, 1:new_F-1] = frames
104
+ outputs[b, new_F-1] = frames[-1]
105
+
106
+ # pad targets: one zero frame, then the originals, then the last frame repeated
107
+ zero = torch.zeros_like(tgt[0])
108
+ new_targets[b, 0] = zero
109
+ new_targets[b, 1:new_F-1] = tgt
110
+ new_targets[b, new_F-1] = tgt[-1]
111
+
112
+ # pad intervals: output + replicated last output group
113
+ in_groups = input_intervals[b]
114
+ pad_group = output_intervals[b, -1:].clone()
115
+ out_groups = torch.cat([output_intervals[b], pad_group], dim=0)
116
+
117
+ # flatten & flag groups
118
+ flat_in = in_groups.reshape(-1, feature_len)
119
+ proc_in = torch.cat([flat_in], dim=1)
120
+
121
+ flat_out = out_groups.reshape(-1, feature_len)
122
+ proc_out = torch.cat([flat_out], dim=1)
123
+
124
+ intervals[b] = torch.cat([proc_in, proc_out], dim=0)
125
+
126
+ return outputs, new_targets, masks, intervals
127
+
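+ # Illustrative shape sketch (not part of the training code): with
+ # special_info == "just_one" and noisy_model_input of shape [B, F, C, H, W],
+ # a single conditioning latent is prepended, so
+ #   outputs.shape     == (B, F + 1, C, H, W)
+ #   masks[:, 0]       is True (the inserted latent frame)
+ #   new_targets[:, 0] is a sentinel (1e4) excluded from the loss via the mask
+ #   intervals.shape   == (B, N + M, frames_per_latent * L)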
128
+
129
+
130
+
131
+ def transform_intervals(
132
+ intervals: torch.Tensor,
133
+ frames_per_latent: int = 4,
134
+ repeat_first: bool = True
135
+ ) -> torch.Tensor:
136
+ """
137
+ Pad and reshape intervals into [B, num_latent_frames, frames_per_latent, L].
138
+
139
+ Args:
140
+ intervals: Tensor of shape [B, N, L]
141
+ frames_per_latent: number of frames per latent group (e.g., 4)
142
+ repeat_first: if True, pad at the beginning by repeating the first row; otherwise pad at the end by repeating the last row.
143
+
144
+ Returns:
145
+ Tensor of shape [B, num_latent_frames, frames_per_latent, L]
146
+ """
147
+ B, N, L = intervals.shape
148
+ num_latent = math.ceil(N / frames_per_latent)
149
+ target_N = num_latent * frames_per_latent
150
+ pad_count = target_N - N
151
+
152
+ if pad_count > 0:
153
+ # choose row to repeat
154
+ pad_row = intervals[:, :1, :] if repeat_first else intervals[:, -1:, :]
155
+ # replicate pad_row pad_count times
156
+ pad = pad_row.repeat(1, pad_count, 1)
157
+ # pad at beginning or end
158
+ if repeat_first:
159
+ expanded = torch.cat([pad, intervals], dim=1)
160
+ else:
161
+ expanded = torch.cat([intervals, pad], dim=1)
162
+ else:
163
+ expanded = intervals[:, :target_N, :]
164
+
165
+ # reshape into latent-frame groups
166
+ return expanded.view(B, num_latent, frames_per_latent, L)
167
+
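+ # Illustrative example (a sketch, not part of the training code):
+ #   ivals = torch.zeros(2, 17, 2)                    # [B=2, N=17, L=2]
+ #   out = transform_intervals(ivals, frames_per_latent=4)
+ #   out.shape  # torch.Size([2, 5, 4, 2]); 3 pad rows are prepended because
+ #              # ceil(17 / 4) * 4 - 17 == 3 and repeat_first defaults to True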
178
+
179
+
180
+ def build_blur(frame_paths, gamma=2.2):
181
+ """
182
+ Simulate motion blur using inverse-gamma (linear-light) summation:
183
+ - Load each image, convert to float32 sRGB [0,255]
184
+ - Linearize via inverse gamma: linear = (img/255)^gamma
185
+ - Sum linear values, average, then re-encode via gamma: (linear_avg)^(1/gamma)*255
186
+ Returns a uint8 numpy array.
187
+ """
188
+ acc_lin = None
189
+ for p in frame_paths:
190
+ img = np.array(Image.open(p).convert('RGB'), dtype=np.float32)
191
+ # normalize to [0,1] then linearize
192
+ lin = np.power(img / 255.0, gamma)
193
+ acc_lin = lin if acc_lin is None else acc_lin + lin
194
+ # average in linear domain
195
+ avg_lin = acc_lin / len(frame_paths)
196
+ # gamma-encode back to sRGB domain
197
+ srgb = np.power(avg_lin, 1.0 / gamma) * 255.0
198
+ return np.clip(srgb, 0, 255).astype(np.uint8)
199
+
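+ # Illustrative example (a sketch): averaging a black and a white frame in
+ # linear light and re-encoding gives roughly 186/255, not 128/255, since
+ #   (0.5) ** (1 / 2.2) * 255 ≈ 186
+ # which is why the blur is accumulated in linear light rather than on raw sRGB.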
200
+ def generate_1x_sequence(frame_paths, window_max=16, output_len=17, base_rate=1, start=None):
201
+ """
202
+ 1× mode at arbitrary base_rate (units of 1/240s):
203
+ - Treat each output step as the sum of `base_rate` consecutive raw frames.
204
+ - Pick window size W ∈ [1, output_len]
205
+ - Randomly choose start index so W*base_rate frames fit
206
+ - Group raw frames into W groups of length base_rate
207
+ - Build blur image over all W*base_rate frames for input
208
+ - For each group, build a blurred output frame by summing its base_rate frames
209
+ - Pad sequence of W blurred frames to output_len by repeating last blurred frame
210
+ - Input interval always [-0.5, 0.5]
211
+ - Output intervals reflect each group’s coverage within [-0.5,0.5]
212
+ """
213
+ N = len(frame_paths)
214
+ max_w = min(output_len, N // base_rate)
215
+ max_w = min(max_w, window_max)
216
+ W = random.randint(1, max_w)
217
+ if start is not None:
218
+ # a start index was supplied; just check that W*base_rate frames fit
219
+ assert N >= W * base_rate, f"Not enough frames for base_rate={base_rate}, need {W * base_rate}, got {N}"
220
+ else:
221
+ start = random.randint(0, N - W * base_rate)
222
+
223
+
224
+ # group start indices
225
+ group_starts = [start + i * base_rate for i in range(W)]
226
+ # flatten raw frame paths for blur input
227
+ blur_paths = []
228
+ for gs in group_starts:
229
+ blur_paths.extend(frame_paths[gs:gs + base_rate])
230
+ blur_img = build_blur(blur_paths)
231
+
232
+ # build blurred output frames per group
233
+ seq = []
234
+ for gs in group_starts:
235
+ group = frame_paths[gs:gs + base_rate]
236
+ seq.append(build_blur(group))
237
+ # pad with last blurred frame
238
+ seq += [seq[-1]] * (output_len - len(seq))
239
+
240
+ input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)
241
+ # each group covers interval of length 1/W
242
+ step = 1.0 / W
243
+ intervals = [[-0.5 + i * step, -0.5 + (i + 1) * step] for i in range(W)]
244
+ num_frames = len(intervals)
245
+ intervals += [intervals[-1]] * (output_len - W)
246
+ output_intervals = torch.tensor(intervals, dtype=torch.float)
247
+
248
+ return blur_img, seq, input_interval, output_intervals, num_frames
249
+
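+ # Illustrative example (a sketch): with W = 4 groups the output intervals are
+ #   [-0.50, -0.25], [-0.25, 0.00], [0.00, 0.25], [0.25, 0.50]
+ # i.e. the blur window [-0.5, 0.5] is split into W equal contiguous slices,
+ # then padded to output_len by repeating the last slice.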
250
+ def generate_2x_sequence(frame_paths, window_max=16, output_len=17, base_rate=1):
251
+ """
252
+ 2× mode:
253
+ - Logical window of W output-steps so that 2*W ≤ output_len
254
+ - Raw window spans W*base_rate frames
255
+ - Build blur only over that raw window (flattened) for input
256
+ - before_count = W//2, after_count = W - before_count
257
+ - Define groups for before, during, and after each of length base_rate
258
+ - Build blurred frames for each group
259
+ - Pad sequence of 2*W blurred frames to output_len by repeating last
260
+ - Input interval always [-0.5,0.5]
261
+ - Output intervals relative to window: each group’s center
262
+ """
263
+ N = len(frame_paths)
264
+ max_w = min(output_len // 2, N // base_rate)
265
+ max_w = min(max_w, window_max)
266
+ W = random.randint(1, max_w)
267
+ before_count = W // 2
268
+ after_count = W - before_count
269
+ # choose start so that before and after stay within bounds
270
+ min_start = before_count * base_rate
271
+ max_start = N - (W + after_count) * base_rate
272
+ # ensure we can pick a valid start, else fail
273
+ assert max_start >= min_start, f"Cannot satisfy before/after window for W={W}, base_rate={base_rate}, N={N}"
274
+ start = random.randint(min_start, max_start)
275
+
276
+
277
+ # window group starts
278
+ window_starts = [start + i * base_rate for i in range(W)]
279
+ # flatten for blur input
280
+ blur_paths = []
281
+ for gs in window_starts:
282
+ blur_paths.extend(frame_paths[gs:gs + base_rate])
283
+
284
+
285
+ blur_img = build_blur(blur_paths)
286
+
287
+ # define before/after group starts
288
+ before_count = W // 2
289
+ after_count = W - before_count
290
+ before_starts = [max(0, start - (i + 1) * base_rate) for i in range(before_count)][::-1]
291
+ after_starts = [min(N - base_rate, start + W * base_rate + i * base_rate) for i in range(after_count)]
292
+
293
+ # all group starts in sequence
294
+ group_starts = before_starts + window_starts + after_starts
295
+ # build blurred frames per group
296
+ seq = []
297
+ for gs in group_starts:
298
+ group = frame_paths[gs:gs + base_rate]
299
+ seq.append(build_blur(group))
300
+ # pad blurred frames to output_len
301
+ seq += [seq[-1]] * (output_len - len(seq))
302
+
303
+ input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)
304
+ # each group covers 1/(2W) around its center within [-0.5,0.5]
305
+ half = 0.5 / W
306
+ centers = [((gs - start) / (W * base_rate)) - 0.5 + half
307
+ for gs in group_starts]
308
+ intervals = [[c - half, c + half] for c in centers]
309
+ num_frames = len(intervals)
310
+ intervals += [intervals[-1]] * (output_len - len(intervals))
311
+ output_intervals = torch.tensor(intervals, dtype=torch.float)
312
+
313
+ return blur_img, seq, input_interval, output_intervals, num_frames
314
+
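+ # Illustrative example (a sketch): with W = 2, base_rate = 1 and start = 10,
+ # window_starts = [10, 11], before_starts = [9], after_starts = [12], giving
+ #   [-1.0, -0.5], [-0.5, 0.0], [0.0, 0.5], [0.5, 1.0]
+ # so the outputs span twice the blur window [-0.5, 0.5] (hence "2x").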
315
+
316
+ def generate_large_blur_sequence(frame_paths, window_max=16, output_len=17, base_rate=1):
317
+ """
318
+ Large blur mode (fixed output_len=25) with instantaneous outputs:
319
+ - Raw window spans 25 * base_rate consecutive frames
320
+ - Build blur over that full raw window for input
321
+ - For output sequence:
322
+ • Pick 1 raw frame every `base_rate` (group_starts)
323
+ • Each output frame is the instantaneous frame at that raw index
324
+ - Input interval always [-0.5, 0.5]
325
+ - Output intervals reflect each 1-frame slice’s coverage within the blur window,
326
+ leaving gaps between.
327
+ """
328
+ N = len(frame_paths)
329
+ total_raw = window_max * base_rate
330
+ assert N >= total_raw, f"Not enough frames for base_rate={base_rate}, need {total_raw}, got {N}"
331
+ start = random.randint(0, N - total_raw)
332
+
333
+ # build blur input over the full raw block
334
+ raw_block = frame_paths[start:start + total_raw]
335
+ blur_img = build_blur(raw_block)
336
+
337
+ # output sequence: instantaneous frames at each group_start
338
+ seq = []
339
+ group_starts = [start + i * base_rate for i in range(window_max)]
340
+ for gs in group_starts:
341
+ img = np.array(Image.open(frame_paths[gs]).convert('RGB'), dtype=np.uint8)
342
+ seq.append(img)
343
+ # pad blurred frames to output_len
344
+ seq += [seq[-1]] * (output_len - len(seq))
345
+
346
+ # compute intervals for each instantaneous frame:
347
+ # each covers [gs, gs+1) over total_raw, normalized to [-0.5, 0.5]
348
+ intervals = []
349
+ for gs in group_starts:
350
+ t0 = (gs - start) / total_raw - 0.5
351
+ t1 = (gs + 1 - start) / total_raw - 0.5
352
+ intervals.append([t0, t1])
353
+ num_frames = len(intervals)
354
+ intervals += [intervals[-1]] * (output_len - len(intervals))
355
+ output_intervals = torch.tensor(intervals, dtype=torch.float)
356
+
357
+ # input interval
358
+ input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)
359
+ return blur_img, seq, input_interval, output_intervals, num_frames
360
+
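+ # Illustrative example (a sketch): with window_max = 16 and base_rate = 2 the
+ # blur spans 32 raw frames while each output is a single raw frame, so the
+ # first two intervals are [-0.5, -0.46875] and [-0.4375, -0.40625]: narrow
+ # slices with gaps between them, unlike the contiguous 1x/2x intervals.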
361
+ def generate_test_case(frame_paths,
362
+ window_max=16,
363
+ output_len=17,
364
+ in_start=None,
365
+ in_end=None,
366
+ out_start=None,
367
+ out_end=None,
368
+ center=None,
369
+ mode="1x",
370
+ fps=240):
371
+ """
372
+ Generate blurred input + a target sequence + normalized intervals.
373
+
374
+ Args:
375
+ frame_paths: list of all frame filepaths
376
+ window_max: number of groups/bins W
377
+ output_len: desired length of the output sequence
378
+ in_start, in_end: integer indices defining the raw blur window [in_start, in_end)
379
+ out_start, out_end: integer indices defining the output frame range [out_start, out_end)
+ mode: one of "1x", "2x", or "lb"
380
+ fps: output frame rate; base_rate = 240 // fps is the raw-frame stride per output step
381
+
382
+ Returns:
383
+ blur_img: np.ndarray of the global blur over the input window
384
+ summed_seq: list of np.ndarray, length = output_len (blurred per-group frames)
385
+ input_interval: torch.Tensor [[-0.5, 0.5]]
386
+ output_intervals: torch.Tensor shape [output_len, 2], normalized in [-0.5, 0.5]
+ seq: list of np.ndarray raw frames covering [out_start, out_end)
+ num_frames: number of valid (non-padded) output intervals
387
+ """
388
+ # 1) slice and blur
389
+ raw_paths = frame_paths[in_start:in_end]
390
+
391
+ blur_img = build_blur(raw_paths)
392
+
393
+ # 2) build the sequence
394
+ # one target per frame
395
+ seq = [
396
+ np.array(Image.open(p).convert("RGB"), dtype=np.uint8)
397
+ for p in frame_paths[out_start:out_end]
398
+ ]
399
+
400
+ # 3) compute normalized intervals
401
+ input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)
402
+
403
408
+
409
+ # 3) define the raw intervals in absolute frame‐indices
410
+ base_rate = 240 // fps
411
+ if mode == "1x":
412
+ assert in_start == out_start and in_end == out_end
413
+ #assert fps == 240, "haven't implemented 120fps in 1x yet"
414
+ W = (out_end - out_start) // base_rate
415
+ # one frame per window
416
+ group_starts = [out_start + i * base_rate for i in range(W)]
417
+ group_ends = [out_start + (i + 1) * base_rate for i in range(W)]
418
+
419
+ elif mode == "2x":
420
+ W = (out_end - out_start) // base_rate
421
+ # every base_rate frames, starting at out_start
422
+ group_starts = [out_start + i * base_rate for i in range(W)]
423
+ group_ends = [out_start + (i + 1) * base_rate for i in range(W)]
424
+
425
+ elif mode == "lb":
426
+ W = (out_end - out_start) // base_rate
427
+ # sparse “key‐frame” windows from the raw input range
428
+ group_starts = [in_start + i * base_rate for i in range(W)]
429
+ group_ends = [s + 1 for s in group_starts]
430
+
431
+ else:
432
+ raise ValueError(f"Unsupported mode: {mode}")
433
+
434
+ # --- after mode‐switch, once you have raw group_starts & group_ends ---
435
+ # 4) build a summed video sequence by blurring each interval
436
+ summed_seq = []
437
+ for s, e in zip(group_starts, group_ends):
438
+ # make sure indices lie in [in_start, in_end)
439
+ s_clamped = max(in_start, min(s, in_end-1))
440
+ e_clamped = max(s_clamped+1, min(e, in_end))
441
+ # sum/blur the frames in [s_clamped:e_clamped)
442
+ summed = build_blur(frame_paths[s_clamped:e_clamped])
443
+ summed_seq.append(summed)
444
+
445
+ # pad to output_len
446
+ if len(summed_seq) < output_len:
447
+ summed_seq += [summed_seq[-1]] * (output_len - len(summed_seq))
448
+
449
+ # 5) normalize the group boundaries to the blur window [-0.5, 0.5]
450
+ def normalize(x):
451
+ return (x - in_start) / (in_end - in_start) - 0.5
452
+
453
+ intervals = [[normalize(s), normalize(e)] for s, e in zip(group_starts, group_ends)]
454
+ num_frames = len(intervals)
455
+ if len(intervals) < output_len:
456
+ intervals += [intervals[-1]] * (output_len - len(intervals))
457
+
458
+ output_intervals = torch.tensor(intervals, dtype=torch.float)
459
+
460
+ # final return now also includes summed_seq
461
+ return blur_img, summed_seq, input_interval, output_intervals, seq, num_frames
462
+
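+ # Illustrative example (a sketch): the BAIST dataset above calls this with
+ # window_max=7, in_start=0, in_end=7, out_start=0, out_end=7, mode="1x", fps=240,
+ # so base_rate = 1, W = 7 and the output intervals are [i/7 - 0.5, (i+1)/7 - 0.5]
+ # for i = 0..6, padded to output_len = 17 by repeating the last interval.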
463
+
464
+ def get_conditioning(
465
+ output_len=17,
466
+ in_start=None,
467
+ in_end=None,
468
+ out_start=None,
469
+ out_end=None,
470
+ mode="1x",
471
+ fps=240,
472
+ ):
473
+ """
474
+ Generate normalized interval conditioning signals. Just like the above function, but without
475
+ loading any images (for inference only).
476
+
477
+ Args:
478
+ output_len: desired length of the output sequence
479
+ in_start, in_end: integer indices defining the raw blur window [in_start, in_end)
480
+ out_start, out_end: integer indices defining the output frame range [out_start, out_end)
+ mode: one of "1x", "2x", or "lb"
481
+ fps: output frame rate; base_rate = 240 // fps is the raw-frame stride per output step
482
+
483
+ Returns:
484
+ input_interval: torch.Tensor [[-0.5, 0.5]]
485
+ output_intervals: torch.Tensor shape [output_len, 2], normalized in [-0.5,0.5]
486
+ """
487
+
488
+ # 3) compute normalized intervals
489
+ input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)
490
+
491
496
+
497
+ # 3) define the raw intervals in absolute frame‐indices
498
+ base_rate = 240 // fps
499
+ if mode == "1x":
500
+ assert in_start == out_start and in_end == out_end
501
+ #assert fps == 240, "haven't implemented 120fps in 1x yet"
502
+ W = (out_end - out_start) // base_rate
503
+ # one frame per window
504
+ group_starts = [out_start + i * base_rate for i in range(W)]
505
+ group_ends = [out_start + (i + 1) * base_rate for i in range(W)]
506
+
507
+ elif mode == "2x":
508
+ W = (out_end - out_start) // base_rate
509
+ # every base_rate frames, starting at out_start
510
+ group_starts = [out_start + i * base_rate for i in range(W)]
511
+ group_ends = [out_start + (i + 1) * base_rate for i in range(W)]
512
+
513
+ elif mode == "lb":
514
+ W = (out_end - out_start) // base_rate
515
+ # sparse “key‐frame” windows from the raw input range
516
+ group_starts = [in_start + i * base_rate for i in range(W)]
517
+ group_ends = [s + 1 for s in group_starts]
518
+
519
+ else:
520
+ raise ValueError(f"Unsupported mode: {mode}")
521
+
522
+ # 5) normalize the group boundaries to the blur window [-0.5, 0.5]
523
+ def normalize(x):
524
+ return (x - in_start) / (in_end - in_start) - 0.5
525
+
526
+ intervals = [[normalize(s), normalize(e)] for s, e in zip(group_starts, group_ends)]
527
+ num_frames = len(intervals)
528
+ if len(intervals) < output_len:
529
+ intervals += [intervals[-1]] * (output_len - len(intervals))
530
+
531
+ output_intervals = torch.tensor(intervals, dtype=torch.float)
532
+
533
+ return input_interval, output_intervals, num_frames
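+ # Illustrative example (a sketch, inference only), mirroring the GoPro 2x
+ # dataset call above:
+ #   inp, outp, n = get_conditioning(in_start=3, in_end=10, out_start=0,
+ #                                   out_end=13, mode="2x", fps=240)
+ # returns inp == [[-0.5, 0.5]], n == 13, and outp[0] ≈ [-0.93, -0.79].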
training/slurm_scripts/simple_multinode.sbatch ADDED
@@ -0,0 +1,88 @@
1
+ #!/bin/bash
2
+ #SBATCH --job-name=XYZ
3
+ #SBATCH --nodes=4
4
+ #SBATCH --mem=256gb
5
+ #SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node!
6
+ #SBATCH --cpus-per-task=28
7
+ #SBATCH --gpus-per-node=4
8
+ #SBATCH --exclusive
9
+ #SBATCH --output=output/slurm-%j-%N.out
10
+ #SBATCH --error=error/slurm-%j-%N.err
11
+ #SBATCH --qos=scavenger
12
+ #SBATCH --signal=B:USR1@300
13
+ #SBATCH --nodelist=lse-hpcnode[1,3,4,5,10-12]
14
+
15
+ #6 and 9 are messed up
16
+ #7 is sketchy as well
17
+
18
+ set -x -e
19
+
20
+ if [ -z "$1" ]
21
+ then
22
+ #quit if no job number is passed
23
+ echo "No config file passed, quitting"
24
+ exit 1
25
+ else
26
+ config_file=$1
27
+ fi
28
+
29
+ source ~/.bashrc
30
+ conda activate gencam
31
+ cd /datasets/sai/gencam/cogvideox/training
32
+
33
+ echo "START TIME: $(date)"
34
+
35
+ # needed until we fix IB issues
36
+ export NCCL_IB_DISABLE=1
37
+ export NCCL_SOCKET_IFNAME=ens
38
+
39
+ # Training setup
40
+ GPUS_PER_NODE=4
41
+ # so processes know who to talk to
42
+ MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
43
+ MASTER_PORT=6000
44
+ NNODES=$SLURM_NNODES
45
+ NODE_RANK=$SLURM_PROCID
46
+ WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
47
+
48
+
49
+ #CMD="accelerate_test.py"
50
+ CMD="train_controlnet.py --config $config_file"
51
+ LAUNCHER="accelerate launch \
52
+ --multi_gpu \
53
+ --gpu_ids 0,1,2,3 \
54
+ --num_processes $WORLD_SIZE \
55
+ --num_machines $NNODES \
56
+ --main_process_ip $MASTER_ADDR \
57
+ --main_process_port $MASTER_PORT \
58
+ --rdzv_backend=c10d \
59
+ --max_restarts 0 \
60
+ --tee 3 \
61
+ "
62
+
63
+ # # NOT SURE THE FOLLOWING ENV VARS IS STRICTLY NEEDED (PROBABLY NOT)
64
+ # export CUDA_HOME=/usr/local/cuda-11.6
65
+ # export LD_PRELOAD=$CUDA_HOME/lib/libnccl.so
66
+ # export LD_LIBRARY_PATH=$CUDA_HOME/efa/lib:$CUDA_HOME/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH
67
+
68
+ SRUN_ARGS=" \
69
+ --wait=60 \
70
+ --kill-on-bad-exit=1 \
71
+ "
72
+
73
+ handler()
74
+ {
75
+ echo "Signal handler triggered at $(date)"
76
+
77
+ sleep 120 # Let training save
78
+ sbatch ${BASH_SOURCE[0]} $config_file
79
+ }
80
+
81
+ # register signal handler
82
+ trap handler SIGUSR1
83
+
84
+ clear; srun --cpu-bind=none --jobid $SLURM_JOB_ID $LAUNCHER $CMD & srun_pid=$!
85
+
86
+ wait
87
+
88
+ echo "END TIME: $(date)"
training/slurm_scripts/slurm-bash.sh ADDED
@@ -0,0 +1 @@
1
+ srun --nodes=1 --gpus=4 --qos=gpu4-8h --pty bash
training/slurm_scripts/train.sbatch ADDED
@@ -0,0 +1,54 @@
1
+ #!/bin/bash
2
+ #SBATCH --job-name=train_deblur
3
+ #SBATCH --nodes=1
4
+ #SBATCH --gpus-per-node=4
5
+ #SBATCH --qos=gpu4-8h
6
+ #SBATCH --signal=B:USR1@600
7
+ #SBATCH --cpus-per-task=24
8
+ #SBATCH --output=output/slurm-%j.out
9
+ #SBATCH --error=error/slurm-%j.err
10
+ #SBATCH --nodelist=lse-hpcnode[8]
11
+
12
+ #the signal time needs to be larger than the sleep in the handler function
13
+
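+ # Usage (illustrative): sbatch train.sbatch <path/to/train_config.yaml>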
14
+ # prepare your environment here
15
+ source ~/.bashrc
16
+ conda activate gencam
17
+ cd /datasets/sai/gencam/cogvideox/training
18
+ export CUDA_VISIBLE_DEVICES=0,1,2,3
19
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
20
+
21
+ if [ -z "$1" ]
22
+ then
23
+ # quit if no config file is passed
24
+ echo "No config file passed, quitting"
25
+ exit 1
26
+ else
27
+ config_file=$1
28
+ fi
29
+
30
+ handler()
31
+ {
32
+ echo "function handler called at $(date)"
33
+ # Send SIGUSR1 to the captured PID of the accelerate job
34
+ if [ -n "$accelerate_pid" ]; then
35
+ echo "Sending SIGUSR1 to accelerate PID: $accelerate_pid"
36
+ python_id=$(ps --ppid $accelerate_pid -o pid=)
37
+ kill -USR1 $python_id # Send SIGUSR1 to the accelerate job
38
+ sleep 300 # Wait for 5 minutes
39
+ else
40
+ echo "No accelerate PID found"
41
+ fi
42
+ echo "Resubmitting job with config file: $config_file"
43
+ sbatch ${BASH_SOURCE[0]} $config_file
44
+ }
45
+
46
+ # register signal handler
47
+ trap handler SIGUSR1
48
+
49
+ echo "Starting job at $(date)"
50
+ #python train_controlnet.py #--config $config_file #& wait
51
+ accelerate launch --config_file accelerator_configs/accelerator_train_config.yaml --multi_gpu train_controlnet.py --config $config_file &
52
+ accelerate_pid=$!
53
+
54
+ wait
training/slurm_scripts/val.sbatch ADDED
@@ -0,0 +1,50 @@
1
+ #!/bin/bash
2
+ #SBATCH --job-name=train_deblur
3
+ #SBATCH --nodes=1
4
+ #SBATCH --gpus-per-node=4
5
+ #SBATCH --qos=scavenger
6
+ #SBATCH --signal=B:USR1@600
7
+ #SBATCH --cpus-per-task=24
8
+ #SBATCH --output=output/slurm-%j.out
9
+ #SBATCH --error=error/slurm-%j.err
10
+ #SBATCH --exclude=lse-hpcnode9
11
+ # prepare your environment here
12
+ source ~/.bashrc
13
+ conda activate gencam
14
+ cd /datasets/sai/gencam/cogvideox/training
15
+ export CUDA_VISIBLE_DEVICES=0,1,2,3
16
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
17
+
18
+ if [ -z "$1" ]
19
+ then
20
+ #quit if no job number is passed
21
+ echo "No config file passed, quitting"
22
+ exit 1
23
+ else
24
+ config_file=$1
25
+ fi
26
+
27
+ handler()
28
+ {
29
+ echo "function handler called at $(date)"
30
+ # Send SIGUSR1 to the captured PID of the accelerate job
31
+ if [ -n "$accelerate_pid" ]; then
32
+ echo "Sending SIGUSR1 to accelerate PID: $accelerate_pid"
33
+ python_id=$(ps --ppid $accelerate_pid -o pid=)
34
+ kill -USR1 $python_id # Send SIGUSR1 to the accelerate job
35
+ sleep 300 # Wait for 5 minutes
36
+ else
37
+ echo "No accelerate PID found"
38
+ fi
39
+ sbatch ${BASH_SOURCE[0]} $config_file
40
+ }
41
+
42
+ # register signal handler
43
+ trap handler SIGUSR1
44
+
45
+ echo "Starting job at $(date)"
46
+ #python train_controlnet.py #--config $config_file #& wait
47
+ accelerate launch --config_file accelerator_configs/accelerator_val_config.yaml --multi_gpu train_controlnet.py --config $config_file &
48
+ accelerate_pid=$!
49
+
50
+ wait
training/test_dataset.py ADDED
File without changes
training/train_controlnet.py ADDED
@@ -0,0 +1,724 @@
1
+ # Copyright 2024 The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import signal
17
+ import sys
18
+ import threading
19
+ import time
20
+ import cv2
21
+ sys.path.append('..')
22
+ from PIL import Image
23
+ import logging
24
+ import math
25
+ import os
26
+ from pathlib import Path
27
+
28
+ import torch
29
+ import transformers
30
+ from accelerate import Accelerator
31
+ from accelerate.logging import get_logger
32
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
33
+ from huggingface_hub import create_repo
34
+ from torch.utils.data import DataLoader
35
+ from tqdm.auto import tqdm
36
+ import numpy as np
37
+ from transformers import AutoTokenizer, T5EncoderModel
38
+
39
+ import diffusers
40
+ from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler
41
+ from diffusers.optimization import get_scheduler
42
+ from diffusers.training_utils import (
43
+ cast_training_params,
44
+ free_memory,
45
+ )
46
+ from diffusers.utils import check_min_version, export_to_video, is_wandb_available
47
+ from diffusers.utils.torch_utils import is_compiled_module
48
+
49
+ from controlnet_datasets import FullMotionBlurDataset, GoPro2xMotionBlurDataset, OutsidePhotosDataset, GoProMotionBlurDataset, BAISTDataset
50
+ from controlnet_pipeline import ControlnetCogVideoXPipeline
51
+ from cogvideo_transformer import CogVideoXTransformer3DModel
52
+ from helpers import random_insert_latent_frame, transform_intervals
53
+ import os
54
+ from utils import save_frames_as_pngs, compute_prompt_embeddings, prepare_rotary_positional_embeddings, encode_prompt, get_optimizer, atomic_save, get_args
55
+ if is_wandb_available():
56
+ import wandb
57
+
58
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
59
+ check_min_version("0.31.0.dev0")
60
+
61
+ logger = get_logger(__name__)
62
+
63
+
64
+ def log_validation(
65
+ pipe,
66
+ args,
67
+ accelerator,
68
+ pipeline_args,
69
+ ):
70
+ logger.info(
71
+ f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
72
+ )
73
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
74
+ scheduler_args = {}
75
+
76
+ if "variance_type" in pipe.scheduler.config:
77
+ variance_type = pipe.scheduler.config.variance_type
78
+
79
+ if variance_type in ["learned", "learned_range"]:
80
+ variance_type = "fixed_small"
81
+
82
+ scheduler_args["variance_type"] = variance_type
83
+
84
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
85
+ pipe = pipe.to(accelerator.device)
86
+
87
+ # run inference
88
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
89
+
90
+ videos = []
91
+ for _ in range(args.num_validation_videos):
92
+ video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
93
+ videos.append(video)
94
+
95
+ free_memory() # release cached GPU memory before returning
96
+
97
+ return videos
98
+
99
+
100
+
101
+ def main(args):
102
+ global signal_recieved_time
103
+ if args.report_to == "wandb" and args.hub_token is not None:
104
+ raise ValueError(
105
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
106
+ " Please use `huggingface-cli login` to authenticate with the Hub."
107
+ )
108
+
109
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
110
+ # due to pytorch#99272, MPS does not yet support bfloat16.
111
+ raise ValueError(
112
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
113
+ )
114
+
115
+ logging_dir = Path(args.output_dir, args.logging_dir)
116
+
117
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
118
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
119
+ accelerator = Accelerator(
120
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
121
+ mixed_precision=args.mixed_precision,
122
+ log_with=args.report_to,
123
+ project_config=accelerator_project_config,
124
+ kwargs_handlers=[kwargs],
125
+ )
126
+
127
+ # Disable AMP for MPS.
128
+ if torch.backends.mps.is_available():
129
+ accelerator.native_amp = False
130
+
131
+ if args.report_to == "wandb":
132
+ if not is_wandb_available():
133
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
134
+
135
+ # Make one log on every process with the configuration for debugging.
136
+ logging.basicConfig(
137
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
138
+ datefmt="%m/%d/%Y %H:%M:%S",
139
+ level=logging.INFO,
140
+ )
141
+ logger.info(accelerator.state, main_process_only=False)
142
+ if accelerator.is_local_main_process:
143
+ transformers.utils.logging.set_verbosity_warning()
144
+ diffusers.utils.logging.set_verbosity_info()
145
+ else:
146
+ transformers.utils.logging.set_verbosity_error()
147
+ diffusers.utils.logging.set_verbosity_error()
148
+
149
+ # If passed along, set the training seed now.
150
+ if args.seed is not None:
151
+ set_seed(args.seed)
152
+
153
+ # Handle the repository creation
154
+ if accelerator.is_main_process:
155
+ if args.output_dir is not None:
156
+ os.makedirs(args.output_dir, exist_ok=True)
157
+
158
+ if args.push_to_hub:
159
+ repo_id = create_repo(
160
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
161
+ exist_ok=True,
162
+ ).repo_id
163
+
164
+ # Prepare models and scheduler
165
+ tokenizer = AutoTokenizer.from_pretrained(
166
+ os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="tokenizer", revision=args.revision
167
+ )
168
+
169
+ text_encoder = T5EncoderModel.from_pretrained(
170
+ os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="text_encoder", revision=args.revision
171
+ )
172
+
173
+ # CogVideoX-2b weights are stored in float16
174
+ config = CogVideoXTransformer3DModel.load_config(
175
+ os.path.join(args.base_dir, args.pretrained_model_name_or_path),
176
+ subfolder="transformer",
177
+ revision=args.revision,
178
+ variant=args.variant,
179
+ )
180
+
181
+ load_dtype = torch.bfloat16 if "5b" in os.path.join(args.base_dir, args.pretrained_model_name_or_path).lower() else torch.float16
182
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
183
+ os.path.join(args.base_dir, args.pretrained_model_name_or_path),
184
+ subfolder="transformer",
185
+ torch_dtype=load_dtype,
186
+ revision=args.revision,
187
+ variant=args.variant,
188
+ low_cpu_mem_usage=False,
189
+ )
190
+
191
+ vae = AutoencoderKLCogVideoX.from_pretrained(
192
+ os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="vae", revision=args.revision, variant=args.variant
193
+ )
194
+
195
+ scheduler = CogVideoXDPMScheduler.from_pretrained(os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="scheduler")
196
+
197
+ if args.enable_slicing:
198
+ vae.enable_slicing()
199
+ if args.enable_tiling:
200
+ vae.enable_tiling()
201
+
202
+ # We only train the additional adapter controlnet layers
203
+ text_encoder.requires_grad_(False)
204
+ transformer.requires_grad_(True)
205
+ vae.requires_grad_(False)
206
+
207
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
208
+ # as these weights are only used for inference, keeping weights in full precision is not required.
209
+ weight_dtype = torch.float32
210
+ if accelerator.state.deepspeed_plugin:
211
+ # DeepSpeed is handling precision, use what's in the DeepSpeed config
212
+ if (
213
+ "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
214
+ and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
215
+ ):
216
+ weight_dtype = torch.float16
217
+ if (
218
+ "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
219
+ and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
220
+ ):
221
+ weight_dtype = torch.bfloat16
222
+ else:
223
+ if accelerator.mixed_precision == "fp16":
224
+ weight_dtype = torch.float16
225
+ elif accelerator.mixed_precision == "bf16":
226
+ weight_dtype = torch.bfloat16
227
+
228
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
229
+ # due to pytorch#99272, MPS does not yet support bfloat16.
230
+ raise ValueError(
231
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
232
+ )
233
+
234
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
235
+ transformer.to(accelerator.device, dtype=weight_dtype)
236
+ vae.to(accelerator.device, dtype=weight_dtype)
237
+
238
+ if args.gradient_checkpointing:
239
+ transformer.enable_gradient_checkpointing()
240
+
241
+ def unwrap_model(model):
242
+ model = accelerator.unwrap_model(model)
243
+ model = model._orig_mod if is_compiled_module(model) else model
244
+ return model
245
+
246
+ # Enable TF32 for faster training on Ampere GPUs,
247
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
248
+ if args.allow_tf32 and torch.cuda.is_available():
249
+ torch.backends.cuda.matmul.allow_tf32 = True
250
+
251
+ if args.scale_lr:
252
+ args.learning_rate = (
253
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
254
+ )
255
+
256
+ # Make sure the trainable params are in float32.
257
+ if args.mixed_precision == "fp16":
258
+ # only upcast trainable parameters into fp32
259
+ cast_training_params([transformer], dtype=torch.float32)
260
+
261
+ trainable_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
262
+
263
+ # Optimization parameters
264
+ trainable_parameters_with_lr = {"params": trainable_parameters, "lr": args.learning_rate}
265
+ params_to_optimize = [trainable_parameters_with_lr]
266
+
267
+ use_deepspeed_optimizer = (
268
+ accelerator.state.deepspeed_plugin is not None
269
+ and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
270
+ )
271
+ use_deepspeed_scheduler = (
272
+ accelerator.state.deepspeed_plugin is not None
273
+ and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
274
+ )
275
+
276
+ optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
277
+
278
+ # Dataset and DataLoader
279
+ DATASET_REGISTRY = {
280
+ "gopro": GoProMotionBlurDataset,
281
+ "gopro2x": GoPro2xMotionBlurDataset,
282
+ "full": FullMotionBlurDataset,
283
+ "baist": BAISTDataset,
284
+ "outsidephotos": OutsidePhotosDataset, # val-only special (no split)
285
+ }
286
+
287
+ if args.dataset not in DATASET_REGISTRY:
288
+ raise ValueError(f"Unknown dataset: {args.dataset}")
289
+
290
+ train_dataset_class = DATASET_REGISTRY[args.dataset]
291
+ val_dataset_class = train_dataset_class
292
+
293
+ common_kwargs = dict(
294
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
295
+ output_dir = args.output_dir,
296
+ image_size=(args.height, args.width),
297
+ stride=(args.stride_min, args.stride_max),
298
+ sample_n_frames=args.max_num_frames,
299
+ hflip_p=args.hflip_p,
300
+ )
301
+
302
+ def build_kwargs(is_train: bool):
303
+ """Return constructor kwargs, adding split"""
304
+ kw = dict(common_kwargs)
305
+ kw["split"] = "train" if is_train else args.val_split
306
+ return kw
307
+
308
+ train_dataset = train_dataset_class(**build_kwargs(is_train=True))
309
+ val_dataset = val_dataset_class(**build_kwargs(is_train=False))
310
+
311
+ def encode_video(video):
312
+ video = video.to(accelerator.device, dtype=vae.dtype)
313
+ video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
314
+ latent_dist = vae.encode(video).latent_dist.sample() * vae.config.scaling_factor
315
+ return latent_dist.permute(0, 2, 1, 3, 4).to(memory_format=torch.contiguous_format)
316
+
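+ # Note (illustrative): encode_video takes [B, F, C, H, W] pixel video, permutes
+ # to the [B, C, F, H, W] layout the CogVideoX VAE expects, samples the scaled
+ # latent distribution, and permutes back to [B, F', C', H', W'], where F' is
+ # the (temporally compressed) latent frame count.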
317
+ def collate_fn(examples):
318
+ blur_img = [example["blur_img"] for example in examples]
319
+ videos = [example["video"] for example in examples]
320
+ if "high_fps_video" in examples[0]:
321
+ high_fps_videos = [example["high_fps_video"] for example in examples]
322
+ high_fps_videos = torch.stack(high_fps_videos)
323
+ high_fps_videos = high_fps_videos.to(memory_format=torch.contiguous_format).float()
324
+ if "bbx" in examples[0]:
325
+ bbx = [example["bbx"] for example in examples]
326
+ bbx = torch.stack(bbx)
327
+ bbx = bbx.to(memory_format=torch.contiguous_format).float()
328
+ prompts = [example["caption"] for example in examples]
329
+ file_names = [example["file_name"] for example in examples]
330
+ num_frames = [example["num_frames"] for example in examples]
331
+ input_intervals = [example["input_interval"] for example in examples]
332
+ output_intervals = [example["output_interval"] for example in examples]
333
+
334
+ videos = torch.stack(videos)
335
+ videos = videos.to(memory_format=torch.contiguous_format).float()
336
+
337
+ blur_img = torch.stack(blur_img)
338
+ blur_img = blur_img.to(memory_format=torch.contiguous_format).float()
339
+
340
+
341
+ input_intervals = torch.stack(input_intervals)
342
+ input_intervals = input_intervals.to(memory_format=torch.contiguous_format).float()
343
+
344
+ output_intervals = torch.stack(output_intervals)
345
+ output_intervals = output_intervals.to(memory_format=torch.contiguous_format).float()
346
+
347
+
348
+ out_dict = {
349
+ "file_names": file_names,
350
+ "blur_img": blur_img,
351
+ "videos": videos,
352
+ "num_frames": num_frames,
353
+ "prompts": prompts,
354
+ "input_intervals": input_intervals,
355
+ "output_intervals": output_intervals,
356
+ }
357
+
358
+ if "high_fps_video" in examples[0]:
359
+ out_dict["high_fps_video"] = high_fps_videos
360
+ if "bbx" in examples[0]:
361
+ out_dict["bbx"] = bbx
362
+ return out_dict
363
+
364
+ train_dataloader = DataLoader(
365
+ train_dataset,
366
+ batch_size=args.train_batch_size,
367
+ shuffle=True,
368
+ collate_fn=collate_fn,
369
+ num_workers=args.dataloader_num_workers,
370
+ )
371
+
372
+ val_dataloader = DataLoader(
373
+ val_dataset,
374
+ batch_size=1,
375
+ shuffle=False,
376
+ collate_fn=collate_fn,
377
+ num_workers=args.dataloader_num_workers,
378
+ )
379
+
380
+ # Scheduler and math around the number of training steps.
381
+ overrode_max_train_steps = False
382
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
383
+ if args.max_train_steps is None:
384
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
385
+ overrode_max_train_steps = True
386
+
387
+ if use_deepspeed_scheduler:
388
+ from accelerate.utils import DummyScheduler
389
+
390
+ lr_scheduler = DummyScheduler(
391
+ name=args.lr_scheduler,
392
+ optimizer=optimizer,
393
+ total_num_steps=args.max_train_steps * accelerator.num_processes,
394
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
395
+ )
396
+ else:
397
+ lr_scheduler = get_scheduler(
398
+ args.lr_scheduler,
399
+ optimizer=optimizer,
400
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
401
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
402
+ num_cycles=args.lr_num_cycles,
403
+ power=args.lr_power,
404
+ )
405
+
406
+ # Prepare everything with our `accelerator`.
407
+ transformer, optimizer, train_dataloader, lr_scheduler, val_dataloader = accelerator.prepare(
408
+ transformer, optimizer, train_dataloader, lr_scheduler, val_dataloader
409
+ )
410
+
411
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
412
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
413
+ if overrode_max_train_steps:
414
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
415
+ # Afterwards we recalculate our number of training epochs
416
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
417
+
418
+ # We need to initialize the trackers we use, and also store our configuration.
419
+ # The trackers initialize automatically on the main process.
420
+ if accelerator.is_main_process:
421
+ tracker_name = args.tracker_name or "cogvideox-controlnet"
422
+ accelerator.init_trackers(tracker_name, config=vars(args))
423
+
424
+
425
+ accelerator.register_for_checkpointing(transformer, optimizer, lr_scheduler)
426
+ save_path = os.path.join(args.output_dir, f"checkpoint")
427
+
428
+ #check if the checkpoint already exists
429
+ if os.path.exists(save_path):
430
+ accelerator.load_state(save_path)
431
+ logger.info(f"Loaded state from {save_path}")
432
+
433
+ # Train!
434
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
435
+ num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
436
+
437
+ logger.info("***** Running training *****")
438
+ logger.info(f" Num trainable parameters = {num_trainable_parameters}")
439
+ logger.info(f" Num examples = {len(train_dataset)}")
440
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
441
+ logger.info(f" Num epochs = {args.num_train_epochs}")
442
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
443
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
444
+ logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
445
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
446
+ global_step = 0
447
+ first_epoch = 0
448
+ initial_global_step = 0
449
+
450
+ progress_bar = tqdm(
451
+ range(0, args.max_train_steps),
452
+ initial=initial_global_step,
453
+ desc="Steps",
454
+ # Only show the progress bar once on each machine.
455
+ disable=not accelerator.is_local_main_process,
456
+ )
457
+ vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
458
+
459
+ # For DeepSpeed training
460
+ model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
461
+
462
+ for epoch in range(first_epoch, args.num_train_epochs):
463
+ transformer.train()
464
+ for step, batch in enumerate(train_dataloader):
465
+ if not args.just_validate:
466
+ models_to_accumulate = [transformer]
467
+ with accelerator.accumulate(models_to_accumulate):
468
+ model_input = encode_video(batch["videos"]).to(dtype=weight_dtype) # [B, F, C, H, W]
469
+ prompts = batch["prompts"]
470
+ image_latent = encode_video(batch["blur_img"]).to(dtype=weight_dtype) # [B, F, C, H, W]
471
+ input_intervals = batch["input_intervals"]
472
+ output_intervals = batch["output_intervals"]
473
+
474
+ batch_size = len(prompts)
475
+ # True = use real prompt (conditional); False = drop to empty (unconditional)
476
+ guidance_mask = torch.rand(batch_size, device=accelerator.device) >= 0.2
477
+
478
+ # build a new prompts list: keep the original where mask True, else blank
479
+ per_sample_prompts = [
480
+ prompts[i] if guidance_mask[i] else ""
481
+ for i in range(batch_size)
482
+ ]
483
+ prompts = per_sample_prompts
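+ # (Illustrative) classifier-free-guidance style dropout: with probability 0.2 a
+ # sample's prompt is replaced by "" and, further below, its conditioning latent
+ # frames are zeroed out as well.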
484
+
485
+ # encode prompts
486
+ prompt_embeds = compute_prompt_embeddings(
487
+ tokenizer,
488
+ text_encoder,
489
+ prompts,
490
+ model_config.max_text_seq_length,
491
+ accelerator.device,
492
+ weight_dtype,
493
+ requires_grad=False,
494
+ )
495
+
496
+ # Sample noise that will be added to the latents
497
+ noise = torch.randn_like(model_input)
498
+ batch_size, num_frames, num_channels, height, width = model_input.shape
499
+
500
+ # Sample a random timestep for each image
501
+ timesteps = torch.randint(
502
+ 0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device
503
+ )
504
+ timesteps = timesteps.long()
505
+
506
+ # Prepare rotary embeds
507
+ image_rotary_emb = (
508
+ prepare_rotary_positional_embeddings(
509
+ height=args.height,
510
+ width=args.width,
511
+ num_frames=num_frames,
512
+ vae_scale_factor_spatial=vae_scale_factor_spatial,
513
+ patch_size=model_config.patch_size,
514
+ attention_head_dim=model_config.attention_head_dim,
515
+ device=accelerator.device,
516
+ )
517
+ if model_config.use_rotary_positional_embeddings
518
+ else None
519
+ )
520
+
521
+ # Add noise to the model input according to the noise magnitude at each timestep (this is the forward diffusion process)
522
+ noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
523
+
524
+ input_intervals = transform_intervals(input_intervals, frames_per_latent=4)
525
+ output_intervals = transform_intervals(output_intervals, frames_per_latent=4)
526
+
527
+ #first interval is always rep
528
+ noisy_model_input, target, condition_mask, intervals = random_insert_latent_frame(image_latent, noisy_model_input, model_input, input_intervals, output_intervals, special_info=args.special_info)
529
+
530
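+ # for the unconditional samples, also zero the inserted conditioning latents so image conditioning is dropped together with the prompt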
+ for i in range(batch_size):
531
+ if not guidance_mask[i]:
532
+ noisy_model_input[i][condition_mask[i]] = 0
533
+
534
+ # Predict the noise residual
535
+ model_output = transformer(
536
+ hidden_states=noisy_model_input,
537
+ encoder_hidden_states=prompt_embeds,
538
+ intervals=intervals,
539
+ condition_mask=condition_mask,
540
+ timestep=timesteps,
541
+ image_rotary_emb=image_rotary_emb,
542
+ return_dict=False,
543
+ )[0]
544
+
545
+ #note: get_velocity also rescales the inserted conditioning latents inside noisy_model_input, so the model would otherwise learn to rescale that input latent
546
+ #thus, the first (conditioning) frame is replaced with the original frame later
547
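+ # assuming diffusers' standard get_velocity(sample, noise, t) = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample,
+ # the call below computes sqrt(alpha_bar_t) * noisy_model_input - sqrt(1 - alpha_bar_t) * model_output,
+ # i.e. it converts the transformer's v-prediction into a prediction of the clean latents, compared against target below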
+ model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)
548
+
549
+ alphas_cumprod = scheduler.alphas_cumprod[timesteps]
550
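+ # 1 / (1 - alpha_bar_t) equals SNR(t) + 1, so the per-timestep MSE below is SNR+1-weighted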
+ weights = 1 / (1 - alphas_cumprod)
551
+ while len(weights.shape) < len(model_pred.shape):
552
+ weights = weights.unsqueeze(-1)
553
+
554
+
555
+
556
+ loss = torch.mean((weights * (model_pred[~condition_mask] - target[~condition_mask]) ** 2).reshape(batch_size, -1), dim=1)
557
+ loss = loss.mean()
558
+ accelerator.backward(loss)
559
+
560
+ if accelerator.state.deepspeed_plugin is None:
561
+ if not args.just_validate:
562
+ optimizer.step()
563
+ optimizer.zero_grad()
564
+ lr_scheduler.step()
565
+
566
+ #wait for all processes to finish
567
+ accelerator.wait_for_everyone()
568
+
569
+
570
+ # Checks if the accelerator has performed an optimization step behind the scenes
571
+ if accelerator.sync_gradients:
572
+ progress_bar.update(1)
573
+ global_step += 1
574
+
575
+ if signal_recieved_time != 0:
576
+ if time.time() - signal_recieved_time > 60:
577
+ print("Signal received, saving state and exiting")
578
+ atomic_save(save_path, accelerator)
579
+ signal_recieved_time = 0
580
+ exit(0)
581
+ else:
582
+ exit(0)
583
+
584
+ if accelerator.is_main_process:
585
+ if global_step % args.checkpointing_steps == 0:
586
+ atomic_save(save_path, accelerator)
587
+ logger.info(f"Saved state to {save_path}")
588
+
589
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
590
+ progress_bar.set_postfix(**logs)
591
+ accelerator.log(logs, step=global_step)
592
+
593
+ if global_step >= args.max_train_steps:
594
+ break
595
+
596
+ print("Step", step)
597
+ accelerator.wait_for_everyone()
598
+
599
+ if step == 0 or (args.validation_prompt is not None and (step + 1) % args.validation_steps == 0):
600
+ # Create pipeline
601
+ pipe = ControlnetCogVideoXPipeline.from_pretrained(
602
+ os.path.join(args.base_dir, args.pretrained_model_name_or_path),
603
+ transformer=unwrap_model(transformer),
604
+ text_encoder=unwrap_model(text_encoder),
605
+ vae=unwrap_model(vae),
606
+ scheduler=scheduler,
607
+ torch_dtype=weight_dtype,
608
+ )
609
+
610
+ print("Length of validation dataset: ", len(val_dataloader))
611
+ #create a pipeline per accelerator device (for faster inference)
612
+ with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
613
+ for batch in val_dataloader:
614
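+ # map the blurry conditioning image from [-1, 1] float to [0, 255] uint8 for the pipeline's image input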
+ frame = ((batch["blur_img"][0].permute(0,2,3,1).cpu().numpy() + 1)*127.5).astype(np.uint8)
615
+ pipeline_args = {
616
+ "prompt": "",
617
+ "negative_prompt": "",
618
+ "image": frame,
619
+ "input_intervals": batch["input_intervals"][0:1],
620
+ "output_intervals": batch["output_intervals"][0:1],
621
+ "guidance_scale": args.guidance_scale,
622
+ "use_dynamic_cfg": args.use_dynamic_cfg,
623
+ "height": args.height,
624
+ "width": args.width,
625
+ "num_frames": args.max_num_frames,
626
+ "num_inference_steps": args.num_inference_steps,
627
+ }
628
+
629
+ modified_filenames = []
630
+ filenames = batch['file_names']
631
+ for file in filenames:
632
+ modified_filenames.append(os.path.splitext(file)[0] + ".mp4")
633
+
634
+ num_frames = batch["num_frames"][0]
635
+ #save the gt_video output
636
+ if args.dataset not in ["outsidephotos"]:
637
+ gt_video = batch["videos"][0].permute(0,2,3,1).cpu().numpy()
638
+ gt_video = ((gt_video + 1) * 127.5)/255
639
+ gt_video = gt_video[0:num_frames]
640
+
641
+ for file in modified_filenames:
642
+ gt_file_name = os.path.join(args.output_dir, "gt", modified_filenames[0])
643
+ os.makedirs(os.path.dirname(gt_file_name), exist_ok=True)
644
+ if args.dataset == "baist":
645
+ bbox = batch["bbx"][0].cpu().numpy().astype(np.int32)
646
+ gt_video = gt_video[:, bbox[1]:bbox[3], bbox[0]:bbox[2], :]
647
+ gt_video = np.array([cv2.resize(frame, (160, 192)) for frame in gt_video]) #resize to 192x160
648
+
649
+ save_frames_as_pngs((gt_video*255).astype(np.uint8), gt_file_name.replace(".mp4", "").replace("gt", "gt_frames"))
650
+ export_to_video(gt_video, gt_file_name, fps=20)
651
+
652
+
653
+ if "high_fps_video" in batch:
654
+ high_fps_video = batch["high_fps_video"][0].permute(0,2,3,1).cpu().numpy()
655
+ high_fps_video = ((high_fps_video + 1) * 127.5)/255
656
+ gt_file_name = os.path.join(args.output_dir, "gt_highfps", modified_filenames[0])
657
+
658
+
659
+ #save the blurred image
660
+ if args.dataset in ["full", "outsidephotos", "gopro2x", "baist"]:
661
+ for file in modified_filenames:
662
+ blurry_file_name = os.path.join(args.output_dir, "blurry", modified_filenames[0].replace(".mp4", ".png"))
663
+ os.makedirs(os.path.dirname(blurry_file_name), exist_ok=True)
664
+ if args.dataset == "baist":
665
+ bbox = batch["bbx"][0].cpu().numpy().astype(np.int32)
666
+ frame0 = frame[0][bbox[1]:bbox[3], bbox[0]:bbox[2], :]
667
+ frame0 = cv2.resize(frame0, (160, 192)) #resize to 192x160
668
+ Image.fromarray(frame0).save(blurry_file_name)
669
+ else:
670
+ Image.fromarray(frame[0]).save(blurry_file_name)
671
+
672
+ videos = log_validation(
673
+ pipe=pipe,
674
+ args=args,
675
+ accelerator=accelerator,
676
+ pipeline_args=pipeline_args
677
+ )
678
+
679
+ #save the output video frames as pngs (uncompressed results) and mp4 (compressed results easily viewable)
680
+ for i, video in enumerate(videos):
681
+ video = video[0:num_frames]
682
+ filename = os.path.join(args.output_dir, "deblurred", modified_filenames[0])
683
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
684
+ if args.dataset == "baist":
685
+ bbox = batch["bbx"][0].cpu().numpy().astype(np.int32)
686
+ video = video[:, bbox[1]:bbox[3], bbox[0]:bbox[2], :]
687
+ video = np.array([cv2.resize(frame, (160, 192)) for frame in video]) #resize to 192x160
688
+ save_frames_as_pngs((video*255).astype(np.uint8), filename.replace(".mp4", "").replace("deblurred", "deblurred_frames"))
689
+ export_to_video(video, filename, fps=20)
690
+ accelerator.wait_for_everyone()
691
+
692
+ if args.just_validate:
693
+ exit(0)
694
+
695
+ accelerator.wait_for_everyone()
696
+ accelerator.end_training()
697
+
698
+ signal_recieved_time = 0
699
+
700
+ def handle_signal(signum, frame):
701
+ global signal_recieved_time
702
+ signal_recieved_time = time.time()
703
+
704
+ print(f"Signal {signum} received at {time.ctime()}")
705
+
706
+ with open("/datasets/sai/gencam/cogvideox/interrupted.txt", "w") as f:
707
+ f.write(f"Training was interrupted at {time.ctime()}")
708
+
709
+ if __name__ == "__main__":
710
+
711
+ args = get_args()
712
+
713
+ print("Registering signal handler")
714
+ #Register the signal handler (catch SIGUSR1)
715
+ signal.signal(signal.SIGUSR1, handle_signal)
716
+
717
+ main_thread = threading.Thread(target=main, args=(args,))
718
+ main_thread.start()
719
+
720
+ while signal_recieved_time != 0:
721
+ time.sleep(1)
722
+
723
+ #call main with args as a thread
724
+
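The loop above also wires in preemption handling: handle_signal records when SIGUSR1 arrives, and the training loop checks that timestamp on subsequent optimization steps, saving the checkpoint via atomic_save before exiting once enough time has elapsed. As a minimal, hypothetical illustration (this helper is not part of the repository and assumes you know the PID of the training process), the handler can be triggered from another process like this:

    import os
    import signal

    def request_graceful_stop(pid: int) -> None:
        # The SIGUSR1 handler in the training script records the signal time;
        # the training loop then saves state with atomic_save and exits.
        os.kill(pid, signal.SIGUSR1)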
training/train_controlnet_backup.py ADDED
@@ -0,0 +1,1235 @@
1
+ # Copyright 2024 The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import random
17
+ import signal
18
+ import sys
19
+ import threading
20
+ import time
21
+
22
+ import cv2
23
+ import yaml
24
+
25
+ sys.path.append('..')
26
+ import argparse
27
+ from PIL import Image
28
+ import logging
29
+ import math
30
+ import os
31
+ import shutil
32
+ from pathlib import Path
33
+ from typing import List, Optional, Tuple, Union
34
+
35
+ import torch
36
+ import transformers
37
+ from accelerate import Accelerator
38
+ from accelerate.logging import get_logger
39
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
40
+ from huggingface_hub import create_repo, upload_folder
41
+ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
42
+ from torch.utils.data import DataLoader, Dataset
43
+ from torchvision import transforms
44
+ from tqdm.auto import tqdm
45
+ import numpy as np
46
+ from decord import VideoReader
47
+ from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
48
+
49
+ import diffusers
50
+ from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler
51
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
52
+ from diffusers.optimization import get_scheduler
53
+ from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
54
+ from diffusers.training_utils import (
55
+ cast_training_params,
56
+ free_memory,
57
+ )
58
+ from diffusers.utils import check_min_version, export_to_video, is_wandb_available
59
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
60
+ from diffusers.utils.torch_utils import is_compiled_module
61
+
62
+ from controlnet_datasets import AblationFullMotionBlurDataset, AdobeMotionBlurDataset, FullMotionBlurDataset, GoPro2xMotionBlurDataset, GoProLargeMotionBlurDataset, OutsidePhotosDataset, GoProMotionBlurDataset, BAISTDataset, SimpleBAISTDataset
63
+ from controlnet_pipeline import ControlnetCogVideoXPipeline
64
+ from cogvideo_transformer import CogVideoXTransformer3DModel
65
+ from helpers import random_insert_latent_frame, transform_intervals
66
+ import os
67
+ import tempfile
68
+ from atomicwrites import atomic_write
69
+
70
+
71
+
72
+ if is_wandb_available():
73
+ import wandb
74
+
75
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
76
+ check_min_version("0.31.0.dev0")
77
+
78
+ logger = get_logger(__name__)
79
+
80
+ def save_frames_as_pngs(video_array, output_dir,
81
+ downsample_spatial=1, # e.g. 2 to halve width & height
82
+ downsample_temporal=1): # e.g. 2 to keep every 2nd frame
83
+ """
84
+ Save each frame of a (T, H, W, C) numpy array as a PNG with no compression.
85
+ """
86
+ assert video_array.ndim == 4 and video_array.shape[-1] == 3, \
87
+ "Expected (T, H, W, C=3) array"
88
+ assert video_array.dtype == np.uint8, "Expected uint8 array"
89
+
90
+ os.makedirs(output_dir, exist_ok=True)
91
+
92
+ # temporal downsample
93
+ frames = video_array[::downsample_temporal]
94
+
95
+ # compute spatially downsampled size
96
+ T, H, W, _ = frames.shape
97
+ new_size = (W // downsample_spatial, H // downsample_spatial)
98
+
99
+ # PNG compression param: 0 = no compression
100
+ png_params = [cv2.IMWRITE_PNG_COMPRESSION, 0]
101
+
102
+ for idx, frame in enumerate(frames):
103
+ # frame is RGB; convert to BGR for OpenCV
104
+ bgr = frame[..., ::-1]
105
+ if downsample_spatial > 1:
106
+ bgr = cv2.resize(bgr, new_size, interpolation=cv2.INTER_NEAREST)
107
+
108
+ filename = os.path.join(output_dir, "frame_{:05d}.png".format(idx))
109
+ success = cv2.imwrite(filename, bgr, png_params)
110
+ if not success:
111
+ raise RuntimeError("Failed to write frame ")
112
+
113
+
114
+ def get_args():
115
+ parser = argparse.ArgumentParser(description="Training script for CogVideoX using config file.")
116
+ parser.add_argument(
117
+ "--config",
118
+ type=str,
119
+ required=True,
120
+ help="Path to the YAML config file."
121
+ )
122
+ args = parser.parse_args()
123
+
124
+ with open(args.config, "r") as f:
125
+ config = yaml.safe_load(f)
126
+
127
+ args = argparse.Namespace(**config)
128
+
129
+ # Expose the flat YAML config as an argparse.Namespace for easier downstream access
130
+ return args
131
+
132
+
133
+ # def read_video(video_path, start_index=0, frames_count=49, stride=1):
134
+ # video_reader = VideoReader(video_path)
135
+ # end_index = min(start_index + frames_count * stride, len(video_reader)) - 1
136
+ # batch_index = np.linspace(start_index, end_index, frames_count, dtype=int)
137
+ # numpy_video = video_reader.get_batch(batch_index).asnumpy()
138
+ # return numpy_video
139
+
140
+
141
+ def log_validation(
142
+ pipe,
143
+ args,
144
+ accelerator,
145
+ pipeline_args,
146
+ epoch,
147
+ is_final_validation: bool = False,
148
+ ):
149
+ logger.info(
150
+ f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
151
+ )
152
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
153
+ scheduler_args = {}
154
+
155
+ if "variance_type" in pipe.scheduler.config:
156
+ variance_type = pipe.scheduler.config.variance_type
157
+
158
+ if variance_type in ["learned", "learned_range"]:
159
+ variance_type = "fixed_small"
160
+
161
+ scheduler_args["variance_type"] = variance_type
162
+
163
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
164
+ pipe = pipe.to(accelerator.device)
165
+ # pipe.set_progress_bar_config(disable=True)
166
+
167
+ # run inference
168
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
169
+
170
+ videos = []
171
+ for _ in range(args.num_validation_videos):
172
+ video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
173
+ videos.append(video)
174
+
175
+ free_memory()
176
+
177
+ return videos
178
+
179
+
180
+ def _get_t5_prompt_embeds(
181
+ tokenizer: T5Tokenizer,
182
+ text_encoder: T5EncoderModel,
183
+ prompt: Union[str, List[str]],
184
+ num_videos_per_prompt: int = 1,
185
+ max_sequence_length: int = 226,
186
+ device: Optional[torch.device] = None,
187
+ dtype: Optional[torch.dtype] = None,
188
+ text_input_ids=None,
189
+ ):
190
+ prompt = [prompt] if isinstance(prompt, str) else prompt
191
+ batch_size = len(prompt)
192
+
193
+ if tokenizer is not None:
194
+ text_inputs = tokenizer(
195
+ prompt,
196
+ padding="max_length",
197
+ max_length=max_sequence_length,
198
+ truncation=True,
199
+ add_special_tokens=True,
200
+ return_tensors="pt",
201
+ )
202
+ text_input_ids = text_inputs.input_ids
203
+ else:
204
+ if text_input_ids is None:
205
+ raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
206
+
207
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
208
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
209
+
210
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
211
+ _, seq_len, _ = prompt_embeds.shape
212
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
213
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
214
+
215
+ return prompt_embeds
216
+
217
+
218
+ def encode_prompt(
219
+ tokenizer: T5Tokenizer,
220
+ text_encoder: T5EncoderModel,
221
+ prompt: Union[str, List[str]],
222
+ num_videos_per_prompt: int = 1,
223
+ max_sequence_length: int = 226,
224
+ device: Optional[torch.device] = None,
225
+ dtype: Optional[torch.dtype] = None,
226
+ text_input_ids=None,
227
+ ):
228
+ prompt = [prompt] if isinstance(prompt, str) else prompt
229
+ prompt_embeds = _get_t5_prompt_embeds(
230
+ tokenizer,
231
+ text_encoder,
232
+ prompt=prompt,
233
+ num_videos_per_prompt=num_videos_per_prompt,
234
+ max_sequence_length=max_sequence_length,
235
+ device=device,
236
+ dtype=dtype,
237
+ text_input_ids=text_input_ids,
238
+ )
239
+ return prompt_embeds
240
+
241
+
242
+ def compute_prompt_embeddings(
243
+ tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
244
+ ):
245
+ if requires_grad:
246
+ prompt_embeds = encode_prompt(
247
+ tokenizer,
248
+ text_encoder,
249
+ prompt,
250
+ num_videos_per_prompt=1,
251
+ max_sequence_length=max_sequence_length,
252
+ device=device,
253
+ dtype=dtype,
254
+ )
255
+ else:
256
+ with torch.no_grad():
257
+ prompt_embeds = encode_prompt(
258
+ tokenizer,
259
+ text_encoder,
260
+ prompt,
261
+ num_videos_per_prompt=1,
262
+ max_sequence_length=max_sequence_length,
263
+ device=device,
264
+ dtype=dtype,
265
+ )
266
+ return prompt_embeds
267
+
268
+
269
+ def prepare_rotary_positional_embeddings(
270
+ height: int,
271
+ width: int,
272
+ num_frames: int,
273
+ vae_scale_factor_spatial: int = 8,
274
+ patch_size: int = 2,
275
+ attention_head_dim: int = 64,
276
+ device: Optional[torch.device] = None,
277
+ base_height: int = 480,
278
+ base_width: int = 720,
279
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
280
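+ # the latent token grid is the pixel resolution divided by the VAE spatial downsampling factor and the transformer patch size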
+ grid_height = height // (vae_scale_factor_spatial * patch_size)
281
+ grid_width = width // (vae_scale_factor_spatial * patch_size)
282
+ base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
283
+ base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
284
+
285
+ grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
286
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
287
+ embed_dim=attention_head_dim,
288
+ crops_coords=grid_crops_coords,
289
+ grid_size=(grid_height, grid_width),
290
+ temporal_size=num_frames,
291
+ )
292
+
293
+ freqs_cos = freqs_cos.to(device=device)
294
+ freqs_sin = freqs_sin.to(device=device)
295
+ return freqs_cos, freqs_sin
296
+
297
+
298
+ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
299
+ # Use the DeepSpeed optimizer
300
+ if use_deepspeed:
301
+ from accelerate.utils import DummyOptim
302
+
303
+
304
+ return DummyOptim(
305
+ params_to_optimize,
306
+ lr=args.learning_rate,
307
+ betas=(args.adam_beta1, args.adam_beta2),
308
+ eps=args.adam_epsilon,
309
+ weight_decay=args.adam_weight_decay,
310
+ )
311
+
312
+ # Optimizer creation
313
+ supported_optimizers = ["adam", "adamw", "prodigy"]
314
+ if args.optimizer not in supported_optimizers:
315
+ logger.warning(
316
+ f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
317
+ )
318
+ args.optimizer = "adamw"
319
+
320
+ if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
321
+ logger.warning(
322
+ f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
323
+ f"set to {args.optimizer.lower()}"
324
+ )
325
+
326
+ if args.use_8bit_adam:
327
+ try:
328
+ import bitsandbytes as bnb
329
+ except ImportError:
330
+ raise ImportError(
331
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
332
+ )
333
+
334
+ if args.optimizer.lower() == "adamw":
335
+ optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
336
+
337
+ optimizer = optimizer_class(
338
+ params_to_optimize,
339
+ betas=(args.adam_beta1, args.adam_beta2),
340
+ eps=args.adam_epsilon,
341
+ weight_decay=args.adam_weight_decay,
342
+ )
343
+ elif args.optimizer.lower() == "adam":
344
+ optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam
345
+
346
+
347
+ optimizer = optimizer_class(
348
+ params_to_optimize,
349
+ betas=(args.adam_beta1, args.adam_beta2),
350
+ eps=args.adam_epsilon,
351
+ weight_decay=args.adam_weight_decay,
352
+ )
353
+ elif args.optimizer.lower() == "prodigy":
354
+ try:
355
+ import prodigyopt
356
+ except ImportError:
357
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
358
+
359
+ optimizer_class = prodigyopt.Prodigy
360
+
361
+ if args.learning_rate <= 0.1:
362
+ logger.warning(
363
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
364
+ )
365
+
366
+ optimizer = optimizer_class(
367
+ params_to_optimize,
368
+ lr=args.learning_rate,
369
+ betas=(args.adam_beta1, args.adam_beta2),
370
+ beta3=args.prodigy_beta3,
371
+ weight_decay=args.adam_weight_decay,
372
+ eps=args.adam_epsilon,
373
+ decouple=args.prodigy_decouple,
374
+ use_bias_correction=args.prodigy_use_bias_correction,
375
+ safeguard_warmup=args.prodigy_safeguard_warmup,
376
+ )
377
+
378
+ return optimizer
379
+
380
+
381
+ def main(args):
382
+ global signal_recieved_time
383
+ if args.report_to == "wandb" and args.hub_token is not None:
384
+ raise ValueError(
385
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
386
+ " Please use `huggingface-cli login` to authenticate with the Hub."
387
+ )
388
+
389
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
390
+ # due to pytorch#99272, MPS does not yet support bfloat16.
391
+ raise ValueError(
392
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
393
+ )
394
+
395
+ logging_dir = Path(args.output_dir, args.logging_dir)
396
+
397
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
398
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
399
+ accelerator = Accelerator(
400
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
401
+ mixed_precision=args.mixed_precision,
402
+ log_with=args.report_to,
403
+ project_config=accelerator_project_config,
404
+ kwargs_handlers=[kwargs],
405
+ )
406
+
407
+ # Disable AMP for MPS.
408
+ if torch.backends.mps.is_available():
409
+ accelerator.native_amp = False
410
+
411
+ if args.report_to == "wandb":
412
+ if not is_wandb_available():
413
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
414
+
415
+ # Make one log on every process with the configuration for debugging.
416
+ logging.basicConfig(
417
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
418
+ datefmt="%m/%d/%Y %H:%M:%S",
419
+ level=logging.INFO,
420
+ )
421
+ logger.info(accelerator.state, main_process_only=False)
422
+ if accelerator.is_local_main_process:
423
+ transformers.utils.logging.set_verbosity_warning()
424
+ diffusers.utils.logging.set_verbosity_info()
425
+ else:
426
+ transformers.utils.logging.set_verbosity_error()
427
+ diffusers.utils.logging.set_verbosity_error()
428
+
429
+ # If passed along, set the training seed now.
430
+ if args.seed is not None:
431
+ set_seed(args.seed)
432
+
433
+ # Handle the repository creation
434
+ if accelerator.is_main_process:
435
+ if args.output_dir is not None:
436
+ os.makedirs(args.output_dir, exist_ok=True)
437
+
438
+ if args.push_to_hub:
439
+ repo_id = create_repo(
440
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
441
+ exist_ok=True,
442
+ ).repo_id
443
+
444
+ # Prepare models and scheduler
445
+ tokenizer = AutoTokenizer.from_pretrained(
446
+ os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="tokenizer", revision=args.revision
447
+ )
448
+
449
+ text_encoder = T5EncoderModel.from_pretrained(
450
+ os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="text_encoder", revision=args.revision
451
+ )
452
+
453
+ # CogVideoX-2b weights are stored in float16
454
+ # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
455
+
456
+ ## TRYING NEW CONFIG LOADING
457
+ config = CogVideoXTransformer3DModel.load_config(
458
+ os.path.join(args.base_dir, args.pretrained_model_name_or_path),
459
+ subfolder="transformer",
460
+ revision=args.revision,
461
+ variant=args.variant,
462
+ )
463
+ config["ablation_mode"] = args.ablation_mode if hasattr(args, "ablation_mode") else None
464
+
465
+ ##FINISH TRYING NEW CONFIG LOADING
466
+
467
+
468
+
469
+ load_dtype = torch.bfloat16 if "5b" in os.path.join(args.base_dir, args.pretrained_model_name_or_path).lower() else torch.float16
470
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
471
+ os.path.join(args.base_dir, args.pretrained_model_name_or_path),
472
+ subfolder="transformer",
473
+ torch_dtype=load_dtype,
474
+ ablation_mode=args.ablation_mode if hasattr(args, "ablation_mode") else None,
475
+ revision=args.revision,
476
+ variant=args.variant,
477
+ low_cpu_mem_usage=False,
478
+ )
479
+
480
+ vae = AutoencoderKLCogVideoX.from_pretrained(
481
+ os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="vae", revision=args.revision, variant=args.variant
482
+ )
483
+
484
+
485
+
486
+
487
+ scheduler = CogVideoXDPMScheduler.from_pretrained(os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="scheduler")
488
+
489
+ if args.enable_slicing:
490
+ vae.enable_slicing()
491
+ if args.enable_tiling:
492
+ vae.enable_tiling()
493
+
494
+ # Freeze the text encoder and VAE; the transformer (including the added controlnet layers) is trained
495
+ text_encoder.requires_grad_(False)
496
+ transformer.requires_grad_(True)
497
+ vae.requires_grad_(False)
498
+
499
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
500
+ # as these weights are only used for inference, keeping weights in full precision is not required.
501
+ weight_dtype = torch.float32
502
+ if accelerator.state.deepspeed_plugin:
503
+ # DeepSpeed is handling precision, use what's in the DeepSpeed config
504
+ if (
505
+ "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
506
+ and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
507
+ ):
508
+ weight_dtype = torch.float16
509
+ if (
510
+ "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
511
+ and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
512
+ ):
513
+ weight_dtype = torch.bfloat16
514
+ else:
515
+ if accelerator.mixed_precision == "fp16":
516
+ weight_dtype = torch.float16
517
+ elif accelerator.mixed_precision == "bf16":
518
+ weight_dtype = torch.bfloat16
519
+
520
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
521
+ # due to pytorch#99272, MPS does not yet support bfloat16.
522
+ raise ValueError(
523
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
524
+ )
525
+
526
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
527
+ transformer.to(accelerator.device, dtype=weight_dtype)
528
+ vae.to(accelerator.device, dtype=weight_dtype)
529
+
530
+ if args.gradient_checkpointing:
531
+ transformer.enable_gradient_checkpointing()
532
+
533
+ def unwrap_model(model):
534
+ model = accelerator.unwrap_model(model)
535
+ model = model._orig_mod if is_compiled_module(model) else model
536
+ return model
537
+
538
+ # Enable TF32 for faster training on Ampere GPUs,
539
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
540
+ if args.allow_tf32 and torch.cuda.is_available():
541
+ torch.backends.cuda.matmul.allow_tf32 = True
542
+
543
+ if args.scale_lr:
544
+ args.learning_rate = (
545
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
546
+ )
547
+
548
+ # Make sure the trainable params are in float32.
549
+ if args.mixed_precision == "fp16":
550
+ # only upcast trainable parameters into fp32
551
+ cast_training_params([transformer], dtype=torch.float32)
552
+
553
+ trainable_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
554
+
555
+ # Optimization parameters
556
+ trainable_parameters_with_lr = {"params": trainable_parameters, "lr": args.learning_rate}
557
+ params_to_optimize = [trainable_parameters_with_lr]
558
+
559
+ use_deepspeed_optimizer = (
560
+ accelerator.state.deepspeed_plugin is not None
561
+ and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
562
+ )
563
+ use_deepspeed_scheduler = (
564
+ accelerator.state.deepspeed_plugin is not None
565
+ and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
566
+ )
567
+
568
+ optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
569
+
570
+ # Dataset and DataLoader
571
+ if args.dataset == "adobe":
572
+ train_dataset = AdobeMotionBlurDataset(
573
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
574
+ split = "train",
575
+ image_size=(args.height, args.width),
576
+ stride=(args.stride_min, args.stride_max),
577
+ sample_n_frames=args.max_num_frames,
578
+ hflip_p=args.hflip_p,
579
+ )
580
+ elif args.dataset == "gopro":
581
+ train_dataset = GoProMotionBlurDataset(
582
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
583
+ split = "train",
584
+ image_size=(args.height, args.width),
585
+ stride=(args.stride_min, args.stride_max),
586
+ sample_n_frames=args.max_num_frames,
587
+ hflip_p=args.hflip_p,
588
+ )
589
+ elif args.dataset == "gopro2x":
590
+ train_dataset = GoPro2xMotionBlurDataset(
591
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
592
+ split = "train",
593
+ image_size=(args.height, args.width),
594
+ stride=(args.stride_min, args.stride_max),
595
+ sample_n_frames=args.max_num_frames,
596
+ hflip_p=args.hflip_p,
597
+ )
598
+ elif args.dataset == "goprolarge":
599
+ train_dataset = GoProLargeMotionBlurDataset(
600
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
601
+ split = "train",
602
+ image_size=(args.height, args.width),
603
+ stride=(args.stride_min, args.stride_max),
604
+ sample_n_frames=args.max_num_frames,
605
+ hflip_p=args.hflip_p,
606
+ )
607
+ elif args.dataset == "full":
608
+ train_dataset = FullMotionBlurDataset(
609
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
610
+ split = "train",
611
+ image_size=(args.height, args.width),
612
+ stride=(args.stride_min, args.stride_max),
613
+ sample_n_frames=args.max_num_frames,
614
+ hflip_p=args.hflip_p,
615
+ )
616
+ elif args.dataset == "fullablation":
617
+ train_dataset = AblationFullMotionBlurDataset(
618
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
619
+ split = "train",
620
+ image_size=(args.height, args.width),
621
+ stride=(args.stride_min, args.stride_max),
622
+ sample_n_frames=args.max_num_frames,
623
+ hflip_p=args.hflip_p,
624
+ ablation_mode = args.ablation_mode, #this is not called for now
625
+ )
626
+ elif args.dataset == "baist":
627
+ train_dataset = BAISTDataset(
628
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
629
+ split = "train",
630
+ image_size=(args.height, args.width),
631
+ stride=(args.stride_min, args.stride_max),
632
+ sample_n_frames=args.max_num_frames,
633
+ hflip_p=args.hflip_p,
634
+ ) #this is not called for now
635
+ elif args.dataset == "simplebaist":
636
+ train_dataset = SimpleBAISTDataset(
637
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
638
+ split = "train",
639
+ image_size=(args.height, args.width),
640
+ stride=(args.stride_min, args.stride_max),
641
+ sample_n_frames=args.max_num_frames,
642
+ hflip_p=args.hflip_p,
643
+ )
644
+
645
+
646
+ if args.dataset == "adobe":
647
+ val_dataset = AdobeMotionBlurDataset(
648
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
649
+ split = args.val_split,
650
+ image_size=(args.height, args.width),
651
+ stride=(args.stride_min, args.stride_max),
652
+ sample_n_frames=args.max_num_frames,
653
+ hflip_p=args.hflip_p,
654
+ )
655
+ elif args.dataset == "outsidephotos":
656
+
657
+ val_dataset = OutsidePhotosDataset(
658
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
659
+ image_size=(args.height, args.width),
660
+ stride=(args.stride_min, args.stride_max),
661
+ sample_n_frames=args.max_num_frames,
662
+ hflip_p=args.hflip_p,
663
+ )
664
+ train_dataset = val_dataset #dummy dataset
665
+ elif args.dataset == "gopro":
666
+ val_dataset = GoProMotionBlurDataset(
667
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
668
+ split = args.val_split,
669
+ image_size=(args.height, args.width),
670
+ stride=(args.stride_min, args.stride_max),
671
+ sample_n_frames=args.max_num_frames,
672
+ hflip_p=args.hflip_p,
673
+ )
674
+ elif args.dataset == "gopro2x":
675
+ val_dataset = GoPro2xMotionBlurDataset(
676
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
677
+ split = args.val_split,
678
+ image_size=(args.height, args.width),
679
+ stride=(args.stride_min, args.stride_max),
680
+ sample_n_frames=args.max_num_frames,
681
+ hflip_p=args.hflip_p,
682
+ )
683
+ elif args.dataset == "goprolarge":
684
+ val_dataset = GoProLargeMotionBlurDataset(
685
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
686
+ split = args.val_split,
687
+ image_size=(args.height, args.width),
688
+ stride=(args.stride_min, args.stride_max),
689
+ sample_n_frames=args.max_num_frames,
690
+ hflip_p=args.hflip_p,
691
+ )
692
+ elif args.dataset == "full":
693
+ val_dataset = FullMotionBlurDataset(
694
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
695
+ split = args.val_split,
696
+ image_size=(args.height, args.width),
697
+ stride=(args.stride_min, args.stride_max),
698
+ sample_n_frames=args.max_num_frames,
699
+ hflip_p=args.hflip_p,
700
+ )
701
+ elif args.dataset == "fullablation":
702
+ val_dataset = AblationFullMotionBlurDataset(
703
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
704
+ split = args.val_split,
705
+ image_size=(args.height, args.width),
706
+ stride=(args.stride_min, args.stride_max),
707
+ sample_n_frames=args.max_num_frames,
708
+ hflip_p=args.hflip_p,
709
+ ablation_mode = args.ablation_mode, #this is not called for now
710
+ )
711
+ elif args.dataset == "baist":
712
+ val_dataset = BAISTDataset(
713
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
714
+ split = args.val_split,
715
+ image_size=(args.height, args.width),
716
+ stride=(args.stride_min, args.stride_max),
717
+ sample_n_frames=args.max_num_frames,
718
+ hflip_p=args.hflip_p,
719
+ )
720
+ elif args.dataset == "simplebaist":
721
+ val_dataset = SimpleBAISTDataset(
722
+ data_dir=os.path.join(args.base_dir, args.video_root_dir),
723
+ split = args.val_split,
724
+ image_size=(args.height, args.width),
725
+ stride=(args.stride_min, args.stride_max),
726
+ sample_n_frames=args.max_num_frames,
727
+ hflip_p=args.hflip_p,
728
+ )
729
+
730
+ def encode_video(video):
731
+ video = video.to(accelerator.device, dtype=vae.dtype)
732
+ video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
733
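+ # sample from the VAE posterior and apply the VAE scaling factor expected by the diffusion model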
+ latent_dist = vae.encode(video).latent_dist.sample() * vae.config.scaling_factor
734
+ return latent_dist.permute(0, 2, 1, 3, 4).to(memory_format=torch.contiguous_format)
735
+
736
+ # def atomic_save(save_path, accelerator):
737
+
738
+ # dir_name = os.path.dirname(save_path)
739
+ # with tempfile.NamedTemporaryFile(delete=False, dir=dir_name) as tmp_file:
740
+ # tmp_path = tmp_file.name
741
+ # # Close the file so that it can be moved later
742
+ # #delete anything at the tmp_path
743
+ # if accelerator.is_main_process:
744
+ # accelerator.save_state(tmp_path) #just a backup incase things go crazy
745
+ # accelerator.save_state(save_path)
746
+ # os.remove(tmp_path)
747
+ # accelerator.wait_for_everyone()
748
+
749
+
750
+
751
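+ # note for the function below: os.rename is only atomic when source and destination live on the same
+ # filesystem, which is why the temporary directory is created inside the checkpoint's parent directory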
+ def atomic_save(save_path, accelerator):
752
+ parent = os.path.dirname(save_path)
753
+ tmp_dir = tempfile.mkdtemp(dir=parent)
754
+ backup_dir = save_path + "_backup"
755
+
756
+ try:
757
+ # Save state into the temp directory
758
+ accelerator.save_state(tmp_dir)
759
+
760
+ # Backup existing save_path if it exists
761
+ if os.path.exists(save_path):
762
+ os.rename(save_path, backup_dir)
763
+
764
+ # Atomically move temp directory into place
765
+ os.rename(tmp_dir, save_path)
766
+
767
+ # Clean up the backup directory
768
+ if os.path.exists(backup_dir):
769
+ shutil.rmtree(backup_dir)
770
+
771
+ except Exception as e:
772
+ # Clean up temp directory on failure
773
+ if os.path.exists(tmp_dir):
774
+ shutil.rmtree(tmp_dir)
775
+
776
+ # Restore from backup if replacement failed
777
+ if os.path.exists(backup_dir):
778
+ if os.path.exists(save_path):
779
+ shutil.rmtree(save_path)
780
+ os.rename(backup_dir, save_path)
781
+
782
+ raise e
783
+
784
+
785
+
786
+ def collate_fn(examples):
787
+ blur_img = [example["blur_img"] for example in examples]
788
+ videos = [example["video"] for example in examples]
789
+ if "high_fps_video" in examples[0]:
790
+ high_fps_videos = [example["high_fps_video"] for example in examples]
791
+ high_fps_videos = torch.stack(high_fps_videos)
792
+ high_fps_videos = high_fps_videos.to(memory_format=torch.contiguous_format).float()
793
+ if "bbx" in examples[0]:
794
+ bbx = [example["bbx"] for example in examples]
795
+ bbx = torch.stack(bbx)
796
+ bbx = bbx.to(memory_format=torch.contiguous_format).float()
797
+ prompts = [example["caption"] for example in examples]
798
+ file_names = [example["file_name"] for example in examples]
799
+ num_frames = [example["num_frames"] for example in examples]
800
+ # if full_file_names in examples[0]:
801
+ # full_file_names = [example["full_file_name"] for example in examples]
802
+ input_intervals = [example["input_interval"] for example in examples]
803
+ output_intervals = [example["output_interval"] for example in examples]
804
+ ablation_condition = [example["ablation_condition"] for example in examples] if "ablation_condition" in examples[0] else None
805
+
806
+
807
+ videos = torch.stack(videos)
808
+ videos = videos.to(memory_format=torch.contiguous_format).float()
809
+
810
+ blur_img = torch.stack(blur_img)
811
+ blur_img = blur_img.to(memory_format=torch.contiguous_format).float()
812
+
813
+
814
+ input_intervals = torch.stack(input_intervals)
815
+ if args.dataset == "gopro":
816
+ input_intervals = input_intervals.to(memory_format=torch.contiguous_format).long() #this is a bug, but I trained it like this on GoPro (it sets all intervals to 0); the model doesn't need intervals for this dataset because it's always 7 frames with the same spacing
817
+ else:
818
+ input_intervals = input_intervals.to(memory_format=torch.contiguous_format).float()
819
+
820
+ output_intervals = torch.stack(output_intervals)
821
+ if args.dataset == "gopro":
822
+ output_intervals = output_intervals.to(memory_format=torch.contiguous_format).long() #this is a bug, but I trained it like this on GoPro (it sets all intervals to 0); the model doesn't need intervals for this dataset because it's always 7 frames with the same spacing
823
+ else:
824
+ output_intervals = output_intervals.to(memory_format=torch.contiguous_format).float()
825
+
826
+ #just used for ablation studies
827
+ ablation_condition = torch.stack(ablation_condition) if ablation_condition is not None else None
828
+ if ablation_condition is not None:
829
+ ablation_condition = ablation_condition.to(memory_format=torch.contiguous_format).float()
830
+
831
+ out_dict = {
832
+ "file_names": file_names,
833
+ "blur_img": blur_img,
834
+ "videos": videos,
835
+ "num_frames": num_frames,
836
+ "prompts": prompts,
837
+ "input_intervals": input_intervals,
838
+ "output_intervals": output_intervals,
839
+ }
840
+
841
+ if "high_fps_video" in examples[0]:
842
+ out_dict["high_fps_video"] = high_fps_videos
843
+ if "bbx" in examples[0]:
844
+ out_dict["bbx"] = bbx
845
+ if ablation_condition is not None:
846
+ out_dict["ablation_condition"] = ablation_condition
847
+ return out_dict
848
+
849
+ train_dataloader = DataLoader(
850
+ train_dataset,
851
+ batch_size=args.train_batch_size,
852
+ shuffle=True,
853
+ collate_fn=collate_fn,
854
+ num_workers=args.dataloader_num_workers,
855
+ )
856
+
857
+ val_dataloader = DataLoader(
858
+ val_dataset,
859
+ batch_size=1,
860
+ shuffle=False,
861
+ collate_fn=collate_fn,
862
+ num_workers=args.dataloader_num_workers,
863
+ )
864
+
865
+ # Scheduler and math around the number of training steps.
866
+ overrode_max_train_steps = False
867
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
868
+ if args.max_train_steps is None:
869
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
870
+ overrode_max_train_steps = True
871
+
872
+ if use_deepspeed_scheduler:
873
+ from accelerate.utils import DummyScheduler
874
+
875
+ lr_scheduler = DummyScheduler(
876
+ name=args.lr_scheduler,
877
+ optimizer=optimizer,
878
+ total_num_steps=args.max_train_steps * accelerator.num_processes,
879
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
880
+ )
881
+ else:
882
+ lr_scheduler = get_scheduler(
883
+ args.lr_scheduler,
884
+ optimizer=optimizer,
885
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
886
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
887
+ num_cycles=args.lr_num_cycles,
888
+ power=args.lr_power,
889
+ )
890
+
891
+ # Prepare everything with our `accelerator`.
892
+ transformer, optimizer, train_dataloader, lr_scheduler, val_dataloader = accelerator.prepare(
893
+ transformer, optimizer, train_dataloader, lr_scheduler, val_dataloader
894
+ )
895
+
896
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
897
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
898
+ if overrode_max_train_steps:
899
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
900
+ # Afterwards we recalculate our number of training epochs
901
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
902
+
903
+ # We need to initialize the trackers we use, and also store our configuration.
904
+ # The trackers initializes automatically on the main process.
905
+ if accelerator.is_main_process:
906
+ tracker_name = args.tracker_name or "cogvideox-controlnet"
907
+ accelerator.init_trackers(tracker_name, config=vars(args))
908
+
909
+
910
+ accelerator.register_for_checkpointing(transformer, optimizer, lr_scheduler)
911
+ save_path = os.path.join(args.output_dir, f"checkpoint")
912
+
913
+ #check if the checkpoint already exists
914
+ if os.path.exists(save_path):
915
+ accelerator.load_state(save_path)
916
+ logger.info(f"Loaded state from {save_path}")
917
+
918
+
919
+
920
+ # Train!
921
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
922
+ num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
923
+
924
+ logger.info("***** Running training *****")
925
+ logger.info(f" Num trainable parameters = {num_trainable_parameters}")
926
+ logger.info(f" Num examples = {len(train_dataset)}")
927
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
928
+ logger.info(f" Num epochs = {args.num_train_epochs}")
929
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
930
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
931
+ logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
932
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
933
+ global_step = 0
934
+ first_epoch = 0
935
+ initial_global_step = 0
936
+
937
+ progress_bar = tqdm(
938
+ range(0, args.max_train_steps),
939
+ initial=initial_global_step,
940
+ desc="Steps",
941
+ # Only show the progress bar once on each machine.
942
+ disable=not accelerator.is_local_main_process,
943
+ )
944
+ vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
945
+
946
+ # For DeepSpeed training
947
+ model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
948
+
949
+ for epoch in range(first_epoch, args.num_train_epochs):
950
+ transformer.train()
951
+ for step, batch in enumerate(train_dataloader):
952
+ if not args.just_validate:
953
+ models_to_accumulate = [transformer]
954
+ with accelerator.accumulate(models_to_accumulate):
955
+ model_input = encode_video(batch["videos"]).to(dtype=weight_dtype) # [B, F, C, H, W]
956
+ prompts = batch["prompts"]
957
+ image_latent = encode_video(batch["blur_img"]).to(dtype=weight_dtype) # [B, F, C, H, W]
958
+ input_intervals = batch["input_intervals"]
959
+ output_intervals = batch["output_intervals"]
960
+ ablation_condition = batch["ablation_condition"] if "ablation_condition" in batch else None
961
+
962
+ batch_size = len(prompts)
963
+ # True = use real prompt (conditional); False = drop to empty (unconditional)
964
+ guidance_mask = torch.rand(batch_size, device=accelerator.device) >= 0.2
965
+
966
+ # build a new prompts list: keep the original where mask True, else blank
967
+ per_sample_prompts = [
968
+ prompts[i] if guidance_mask[i] else ""
969
+ for i in range(batch_size)
970
+ ]
971
+ prompts = per_sample_prompts
972
+
973
+ # encode prompts
974
+ prompt_embeds = compute_prompt_embeddings(
975
+ tokenizer,
976
+ text_encoder,
977
+ prompts,
978
+ model_config.max_text_seq_length,
979
+ accelerator.device,
980
+ weight_dtype,
981
+ requires_grad=False,
982
+ )
983
+
984
+ # Sample noise that will be added to the latents
985
+ noise = torch.randn_like(model_input)
986
+ batch_size, num_frames, num_channels, height, width = model_input.shape
987
+
988
+ # Sample a random timestep for each image
989
+ timesteps = torch.randint(
990
+ 0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device
991
+ )
992
+ timesteps = timesteps.long()
993
+
994
+ # Prepare rotary embeds
995
+ image_rotary_emb = (
996
+ prepare_rotary_positional_embeddings(
997
+ height=args.height,
998
+ width=args.width,
999
+ num_frames=num_frames,
1000
+ vae_scale_factor_spatial=vae_scale_factor_spatial,
1001
+ patch_size=model_config.patch_size,
1002
+ attention_head_dim=model_config.attention_head_dim,
1003
+ device=accelerator.device,
1004
+ )
1005
+ if model_config.use_rotary_positional_embeddings
1006
+ else None
1007
+ )
1008
+
1009
+ # Add noise to the model input according to the noise magnitude at each timestep (this is the forward diffusion process)
1010
+ noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
1011
+
1012
+ input_intervals = transform_intervals(input_intervals, frames_per_latent=4)
1013
+ output_intervals = transform_intervals(output_intervals, frames_per_latent=4)
1014
+
1015
+ #first interval is always rep
1016
+ noisy_model_input, target, condition_mask, intervals = random_insert_latent_frame(image_latent, noisy_model_input, model_input, input_intervals, output_intervals, special_info=args.special_info)
1017
+
1018
+ for i in range(batch_size):
1019
+ if not guidance_mask[i]:
1020
+ noisy_model_input[i][condition_mask[i]] = 0
1021
+
1022
+ # Predict the noise residual
1023
+ model_output = transformer(
1024
+ hidden_states=noisy_model_input,
1025
+ encoder_hidden_states=prompt_embeds,
1026
+ intervals=intervals,
1027
+ condition_mask=condition_mask,
1028
+ timestep=timesteps,
1029
+ image_rotary_emb=image_rotary_emb,
1030
+ return_dict=False,
1031
+ ablation_condition = ablation_condition
1032
+ )[0]
1033
+
1034
+ #note: get_velocity also rescales the inserted conditioning latents inside noisy_model_input, so the model would otherwise learn to rescale that input latent
1035
+ #thus, the first (conditioning) frame is replaced with the original frame later
1036
+ model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)
1037
+
1038
+
1039
+
1040
+ alphas_cumprod = scheduler.alphas_cumprod[timesteps]
1041
+ weights = 1 / (1 - alphas_cumprod)
1042
+ while len(weights.shape) < len(model_pred.shape):
1043
+ weights = weights.unsqueeze(-1)
1044
+
1045
+
1046
+
1047
+ loss = torch.mean((weights * (model_pred[~condition_mask] - target[~condition_mask]) ** 2).reshape(batch_size, -1), dim=1)
1048
+ loss = loss.mean()
1049
+ accelerator.backward(loss)
1050
+
1051
+ if accelerator.state.deepspeed_plugin is None:
1052
+ if not args.just_validate:
1053
+ optimizer.step()
1054
+ optimizer.zero_grad()
1055
+
1056
+ lr_scheduler.step()
1057
+
1058
+
1059
+ #wait for all processes to finish
1060
+ accelerator.wait_for_everyone()
1061
+
1062
+
1063
+ # Checks if the accelerator has performed an optimization step behind the scenes
1064
+ if accelerator.sync_gradients:
1065
+ progress_bar.update(1)
1066
+ global_step += 1
1067
+
1068
+ if signal_recieved_time != 0:
1069
+ if time.time() - signal_recieved_time > 60:
1070
+ print("Signal received, saving state and exiting")
1071
+ #accelerator.save_state(save_path)
1072
+ atomic_save(save_path, accelerator)
1073
+ signal_recieved_time = 0
1074
+ exit(0)
1075
+ else:
1076
+ exit(0)
1077
+
1078
+ if accelerator.is_main_process:
1079
+ if global_step % args.checkpointing_steps == 0:
1080
+ #accelerator.save_state(save_path)
1081
+ atomic_save(save_path, accelerator)
1082
+ logger.info(f"Saved state to {save_path}")
1083
+
1084
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1085
+ progress_bar.set_postfix(**logs)
1086
+ accelerator.log(logs, step=global_step)
1087
+
1088
+ if global_step >= args.max_train_steps:
1089
+ break
1090
+
1091
+ print("Step", step)
1092
+ accelerator.wait_for_everyone()
1093
+
1094
+ if step == 0 or (args.validation_prompt is not None and (step + 1) % args.validation_steps == 0):
1095
+ # Create pipeline
1096
+ pipe = ControlnetCogVideoXPipeline.from_pretrained(
1097
+ os.path.join(args.base_dir, args.pretrained_model_name_or_path),
1098
+ transformer=unwrap_model(transformer),
1099
+ text_encoder=unwrap_model(text_encoder),
1100
+ vae=unwrap_model(vae),
+ scheduler=scheduler,
+ torch_dtype=weight_dtype,
+ )
+
+ print("Length of validation dataset: ", len(val_dataloader))
+ #create a pipeline per accelerator device (for faster inference)
+ with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
+ for batch in val_dataloader:
+ frame = ((batch["blur_img"][0].permute(0,2,3,1).cpu().numpy() + 1)*127.5).astype(np.uint8)
+ pipeline_args = {
+ "prompt": "",
+ "negative_prompt": "",
+ "image": frame,
+ "input_intervals": batch["input_intervals"][0:1],
+ "output_intervals": batch["output_intervals"][0:1],
+ "ablation_condition": batch["ablation_condition"][0:1] if "ablation_condition" in batch else None,
+ "guidance_scale": args.guidance_scale,
+ "use_dynamic_cfg": args.use_dynamic_cfg,
+ "height": args.height,
+ "width": args.width,
+ "num_frames": args.max_num_frames,
+ "num_inference_steps": args.num_inference_steps,
+ }
+
+ modified_filenames = []
+ filenames = batch['file_names']
+ for file in filenames:
+ modified_filenames.append(os.path.splitext(file)[0] + ".mp4")
+
+ num_frames = batch["num_frames"][0]
+ #save the gt_video output
+ if args.dataset not in ["outsidephotos"]:
+ gt_video = batch["videos"][0].permute(0,2,3,1).cpu().numpy()
+ gt_video = ((gt_video + 1) * 127.5)/255
+ gt_video = gt_video[0:num_frames]
+
+ for file in modified_filenames:
+ #create the directory if it does not exist
+ gt_file_name = os.path.join(args.output_dir, "gt", modified_filenames[0])
+ os.makedirs(os.path.dirname(gt_file_name), exist_ok=True)
+ if args.dataset in ["baist", "simplebaist"]:
+ bbox = batch["bbx"][0].cpu().numpy().astype(np.int32)
+ gt_video = gt_video[:, bbox[1]:bbox[3], bbox[0]:bbox[2], :]
+ gt_video = np.array([cv2.resize(frame, (160, 192)) for frame in gt_video])
+
+ save_frames_as_pngs((gt_video*255).astype(np.uint8), gt_file_name.replace(".mp4", "").replace("gt", "gt_frames"))
+ export_to_video(gt_video, gt_file_name, fps=20)
+
+
+ if "high_fps_video" in batch:
+ high_fps_video = batch["high_fps_video"][0].permute(0,2,3,1).cpu().numpy()
+ high_fps_video = ((high_fps_video + 1) * 127.5)/255
+ gt_file_name = os.path.join(args.output_dir, "gt_highfps", modified_filenames[0])
+
+
+ if args.dataset in ["adobe", "full", "baist", "outsidephotos", "gopro2x", "goprolarge", "simplebaist"]:
+ for file in modified_filenames:
+ #create the directory if it does not exist
+ blurry_file_name = os.path.join(args.output_dir, "blurry", modified_filenames[0].replace(".mp4", ".png"))
+ #save the blurry image
+ os.makedirs(os.path.dirname(blurry_file_name), exist_ok=True)
+ if args.dataset in ["baist", "simplebaist"]:
+ bbox = batch["bbx"][0].cpu().numpy().astype(np.int32)
+ frame0 = frame[0][bbox[1]:bbox[3], bbox[0]:bbox[2], :]
+ #resize to 192x160
+ frame0 = cv2.resize(frame0, (160, 192))
+ Image.fromarray(frame0).save(blurry_file_name)
+ else:
+ Image.fromarray(frame[0]).save(blurry_file_name)
+
+ videos = log_validation(
+ pipe=pipe,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ )
+
+ for i, video in enumerate(videos):
+ prompt = (
+ pipeline_args["prompt"][:25]
+ .replace(" ", "_")
+ .replace(" ", "_")
+ .replace("'", "_")
+ .replace('"', "_")
+ .replace("/", "_")
+ )
+ video = video[0:num_frames]
+ filename = os.path.join(args.output_dir, "deblurred", modified_filenames[0])
+ print("Deblurred file name", filename)
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ if args.dataset in ["baist", "simplebaist"]:
+ bbox = batch["bbx"][0].cpu().numpy().astype(np.int32)
+ video = video[:, bbox[1]:bbox[3], bbox[0]:bbox[2], :]
+ #resize to 192x160
+ video = np.array([cv2.resize(frame, (160, 192)) for frame in video])
+ save_frames_as_pngs((video*255).astype(np.uint8), filename.replace(".mp4", "").replace("deblurred", "deblurred_frames"))
+ export_to_video(video, filename, fps=20)
+
+ accelerator.wait_for_everyone()
+
+ if args.just_validate:
+ exit(0)
+
+ accelerator.wait_for_everyone()
+ accelerator.end_training()
+
+ signal_recieved_time = 0
+
+ def handle_signal(signum, frame):
+ global signal_recieved_time
+ signal_recieved_time = time.time()
+
+ print(f"Signal {signum} received at {time.ctime()}")
+
+ with open("/datasets/sai/gencam/cogvideox/interrupted.txt", "w") as f:
+ f.write(f"Training was interrupted at {time.ctime()}")
+
+ if __name__ == "__main__":
+
+ args = get_args()
+
+ print("Registering signal handler")
+ # Register the signal handler (catch SIGUSR1)
+ signal.signal(signal.SIGUSR1, handle_signal)
+
+ # Call main with args as a worker thread
+ main_thread = threading.Thread(target=main, args=(args,))
+ main_thread.start()
+
+ print("SIGNAL RECEIVED TIME", signal_recieved_time)
+ while signal_recieved_time != 0:
+ time.sleep(1)
+
training/utils.py ADDED
@@ -0,0 +1,299 @@
+ import os
+ from typing import List, Optional, Union, Tuple
+ import torch
+ from transformers import T5EncoderModel, T5Tokenizer
+ import numpy as np
+ import cv2
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
+ from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
+ from accelerate.logging import get_logger
+ import tempfile
+ import argparse
+ import yaml
+ import shutil
+
+ logger = get_logger(__name__)
+
+ def get_args():
+ parser = argparse.ArgumentParser(description="Training script for CogVideoX using config file.")
+ parser.add_argument(
+ "--config",
+ type=str,
+ required=True,
+ help="Path to the YAML config file."
+ )
+ args = parser.parse_args()
+ with open(args.config, "r") as f:
+ config = yaml.safe_load(f)
+ # Expand the config dict into an argparse.Namespace for easier downstream attribute access
+ args = argparse.Namespace(**config)
+ return args
+
+
+
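
get_args() expands the YAML file's top-level keys directly into attributes of an argparse.Namespace, so the rest of the code can read args.optimizer, args.learning_rate, and so on. A minimal sketch of that round trip follows; apart from the optimizer fields used later in this file, the key names and values are placeholders, not the commit's actual config schema.

# Sketch (assumed keys/values): build a tiny YAML config and load it via get_args().
import sys
import yaml

example_config = {
    "optimizer": "adamw",        # referenced by get_optimizer below
    "learning_rate": 1e-4,
    "use_8bit_adam": False,
    "adam_beta1": 0.9,
    "adam_beta2": 0.95,
    "adam_epsilon": 1e-8,
    "adam_weight_decay": 1e-4,
}
with open("example_config.yaml", "w") as f:
    yaml.safe_dump(example_config, f)

sys.argv = ["train.py", "--config", "example_config.yaml"]
args = get_args()
print(args.optimizer, args.learning_rate)
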
+ def atomic_save(save_path, accelerator):
+ parent = os.path.dirname(save_path)
+ tmp_dir = tempfile.mkdtemp(dir=parent)
+ backup_dir = save_path + "_backup"
+
+ try:
+ # Save state into the temp directory
+ accelerator.save_state(tmp_dir)
+
+ # Backup existing save_path if it exists
+ if os.path.exists(save_path):
+ os.rename(save_path, backup_dir)
+
+ # Atomically move temp directory into place
+ os.rename(tmp_dir, save_path)
+
+ # Clean up the backup directory
+ if os.path.exists(backup_dir):
+ shutil.rmtree(backup_dir)
+
+ except Exception as e:
+ # Clean up temp directory on failure
+ if os.path.exists(tmp_dir):
+ shutil.rmtree(tmp_dir)
+
+ # Restore from backup if replacement failed
+ if os.path.exists(backup_dir):
+ if os.path.exists(save_path):
+ shutil.rmtree(save_path)
+ os.rename(backup_dir, save_path)
+
+ raise e
+
+
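
atomic_save stages the checkpoint in a temporary directory created next to the target, renames it into place, and restores the previous checkpoint from a backup if anything fails, so an interrupted job never sees a half-written state directory. A minimal usage sketch; the path and toy model are assumptions, not part of the commit.

# Sketch (assumed path): save Accelerate state atomically.
import os
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(8, 8))

os.makedirs("checkpoints", exist_ok=True)   # parent must exist for tempfile.mkdtemp
atomic_save(os.path.join("checkpoints", "checkpoint-100"), accelerator)
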
+ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
+ # Use DeepSpeed optimizer
+ if use_deepspeed:
+ from accelerate.utils import DummyOptim
+
+
+ return DummyOptim(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+
+ # Optimizer creation
+ supported_optimizers = ["adam", "adamw", "prodigy"]
+ if args.optimizer not in supported_optimizers:
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ if args.optimizer.lower() == "adamw":
+ optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+ elif args.optimizer.lower() == "adam":
+ optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam
+
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+ elif args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ return optimizer
+
+
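
Because get_optimizer pulls every hyperparameter off args, an argparse.Namespace carrying the Adam fields above is enough to drive it; the learning rate travels in the parameter group, matching how the AdamW branch omits lr in the constructor call. A hypothetical sketch with assumed values:

# Sketch (assumed values): build a plain AdamW optimizer for a toy module.
import argparse
import torch

toy_model = torch.nn.Linear(16, 16)
opt_args = argparse.Namespace(
    optimizer="adamw",
    use_8bit_adam=False,
    learning_rate=1e-4,
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_epsilon=1e-8,
    adam_weight_decay=1e-4,
)
# lr is supplied through the param group rather than the AdamW constructor.
params_to_optimize = [{"params": list(toy_model.parameters()), "lr": opt_args.learning_rate}]
optimizer = get_optimizer(opt_args, params_to_optimize)
print(type(optimizer).__name__)  # AdamW
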
+ def prepare_rotary_positional_embeddings(
+ height: int,
+ width: int,
+ num_frames: int,
+ vae_scale_factor_spatial: int = 8,
+ patch_size: int = 2,
+ attention_head_dim: int = 64,
+ device: Optional[torch.device] = None,
+ base_height: int = 480,
+ base_width: int = 720,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ grid_height = height // (vae_scale_factor_spatial * patch_size)
+ grid_width = width // (vae_scale_factor_spatial * patch_size)
+ base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
+ base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
+
+ grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ )
+
+ freqs_cos = freqs_cos.to(device=device)
+ freqs_sin = freqs_sin.to(device=device)
+ return freqs_cos, freqs_sin
+
+
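
With the defaults above (spatial VAE factor 8, patch size 2), a 480x720 clip maps to a 30x45 token grid per latent frame, and num_frames is the temporal size of the latent video rather than the pixel frame count. A hypothetical call with assumed sizes:

# Sketch (assumed sizes): RoPE frequencies for a 480x720 clip with 13 latent frames.
import torch

freqs_cos, freqs_sin = prepare_rotary_positional_embeddings(
    height=480,
    width=720,
    num_frames=13,   # latent frames (e.g. 13 for a 49-frame clip with 4x temporal compression)
    attention_head_dim=64,
    device=torch.device("cpu"),
)
print(freqs_cos.shape, freqs_sin.shape)  # one row per latent token (13 * 30 * 45 here)
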
+ def _get_t5_prompt_embeds(
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ text_input_ids=None,
+ ):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+
+ def encode_prompt(
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ text_input_ids=None,
+ ):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt_embeds = _get_t5_prompt_embeds(
+ tokenizer,
+ text_encoder,
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ text_input_ids=text_input_ids,
+ )
+ return prompt_embeds
+
+
+ def compute_prompt_embeddings(
+ tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
+ ):
+ if requires_grad:
+ prompt_embeds = encode_prompt(
+ tokenizer,
+ text_encoder,
+ prompt,
+ num_videos_per_prompt=1,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ else:
+ with torch.no_grad():
+ prompt_embeds = encode_prompt(
+ tokenizer,
+ text_encoder,
+ prompt,
+ num_videos_per_prompt=1,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ return prompt_embeds
+
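
compute_prompt_embeddings is a thin wrapper that runs encode_prompt with or without gradients, so it only needs the T5 tokenizer/encoder pair and a prompt list. A hypothetical sketch; "THUDM/CogVideoX-2b" is an assumed checkpoint id, and any diffusers-format CogVideoX repo with tokenizer and text_encoder subfolders should behave the same way.

# Sketch (assumed checkpoint): embed the empty prompt the way the training loop does.
import torch
from transformers import T5EncoderModel, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("THUDM/CogVideoX-2b", subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained("THUDM/CogVideoX-2b", subfolder="text_encoder")

prompt_embeds = compute_prompt_embeddings(
    tokenizer,
    text_encoder,
    prompt=[""],
    max_sequence_length=226,
    device=torch.device("cpu"),
    dtype=torch.float32,
    requires_grad=False,
)
print(prompt_embeds.shape)  # (1, 226, hidden_dim) for a single prompt
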
+ def save_frames_as_pngs(video_array, output_dir,
+ downsample_spatial=1, # e.g. 2 to halve width & height
+ downsample_temporal=1): # e.g. 2 to keep every 2nd frame
+ """
+ Save each frame of a (T, H, W, C) uint8 RGB array as a PNG with no compression.
+ """
+ assert video_array.ndim == 4 and video_array.shape[-1] == 3, \
+ "Expected (T, H, W, C=3) array"
+ assert video_array.dtype == np.uint8, "Expected uint8 array"
+
+ os.makedirs(output_dir, exist_ok=True)
+
+ # temporal downsample
+ frames = video_array[::downsample_temporal]
+
+ # compute spatially downsampled size
+ T, H, W, _ = frames.shape
+ new_size = (W // downsample_spatial, H // downsample_spatial)
+
+ # PNG compression param: 0 = no compression
+ png_params = [cv2.IMWRITE_PNG_COMPRESSION, 0]
+
+ for idx, frame in enumerate(frames):
+ # frame is RGB; convert to BGR for OpenCV
+ bgr = frame[..., ::-1]
+ if downsample_spatial > 1:
+ bgr = cv2.resize(bgr, new_size, interpolation=cv2.INTER_NEAREST)
+
+ filename = os.path.join(output_dir, "frame_{:05d}.png".format(idx))
+ success = cv2.imwrite(filename, bgr, png_params)
+ if not success:
+ raise RuntimeError(f"Failed to write frame {idx} to {filename}")
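
A quick self-contained check of save_frames_as_pngs on synthetic data; the output directory name is arbitrary.

# Sketch: dump 8 random RGB frames, keeping every 2nd frame at half resolution.
import numpy as np

dummy_video = np.random.randint(0, 256, size=(8, 192, 160, 3), dtype=np.uint8)
save_frames_as_pngs(dummy_video, "debug_frames", downsample_spatial=2, downsample_temporal=2)
# -> debug_frames/frame_00000.png ... frame_00003.png, each 80x96 pixels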