initial commit
- LICENSE +201 -0
- cogvideo_embeddings.py +0 -0
- cogvideo_transformer.py +547 -0
- controlnet_pipeline.py +733 -0
- extra/checkpoints_to_hf.py +16 -0
- extra/moMets-parallel-baist.py +330 -0
- extra/moMets-parallel-gopro.py +343 -0
- gradio/app.py +85 -0
- inference.py +317 -0
- requirements.txt +22 -0
- setup/download_checkpoints.py +53 -0
- setup/download_cogvideo_weights.py +6 -0
- setup/environment.yaml +225 -0
- training/accelerator_configs/accelerate_test.py +17 -0
- training/accelerator_configs/accelerator_multigpu.yaml +6 -0
- training/accelerator_configs/accelerator_multinode.yaml +4 -0
- training/accelerator_configs/accelerator_singlegpu.yaml +25 -0
- training/accelerator_configs/accelerator_val_config.yaml +25 -0
- training/available-qos.txt +10 -0
- training/configs/baist_test.yaml +77 -0
- training/configs/baist_train.yaml +78 -0
- training/configs/full_test.yaml +78 -0
- training/configs/full_train.yaml +78 -0
- training/configs/gopro_2x_test.yaml +78 -0
- training/configs/gopro_test.yaml +78 -0
- training/configs/gopro_train.yaml +77 -0
- training/configs/outsidephotos.yaml +76 -0
- training/controlnet_datasets.py +735 -0
- training/helpers.py +533 -0
- training/slurm_scripts/simple_multinode.sbatch +88 -0
- training/slurm_scripts/slurm-bash.sh +1 -0
- training/slurm_scripts/train.sbatch +54 -0
- training/slurm_scripts/val.sbatch +50 -0
- training/test_dataset.py +0 -0
- training/train_controlnet.py +724 -0
- training/train_controlnet_backup.py +1235 -0
- training/utils.py +299 -0
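Before the per-file diffs, a hedged orientation sketch (not part of this commit): one way the added modules could be wired together, following the ControlnetCogVideoXPipeline constructor added in controlnet_pipeline.py below. The base checkpoint id THUDM/CogVideoX-2b, the fp16 dtype, and the subfolder names are assumptions rather than values taken from this commit, and loading the base transformer weights into the modified CogVideoXTransformer3DModel may depend on the extra modules defined in cogvideo_embeddings.py (whose diff is not rendered below).

# Hedged assembly sketch; the repository root is assumed to be on PYTHONPATH,
# and the base model id, dtype, and subfolders are assumptions, not pinned by this commit.
import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler

from cogvideo_transformer import CogVideoXTransformer3DModel   # modified transformer added in this commit
from controlnet_pipeline import ControlnetCogVideoXPipeline    # pipeline added in this commit

base = "THUDM/CogVideoX-2b"  # assumed base checkpoint
pipe = ControlnetCogVideoXPipeline(
    tokenizer=T5Tokenizer.from_pretrained(base, subfolder="tokenizer"),
    text_encoder=T5EncoderModel.from_pretrained(base, subfolder="text_encoder", torch_dtype=torch.float16),
    vae=AutoencoderKLCogVideoX.from_pretrained(base, subfolder="vae", torch_dtype=torch.float16),
    # May report missing/extra keys if the custom patch embedding adds parameters.
    transformer=CogVideoXTransformer3DModel.from_pretrained(base, subfolder="transformer", torch_dtype=torch.float16),
    scheduler=CogVideoXDDIMScheduler.from_pretrained(base, subfolder="scheduler"),
)
pipe.to("cuda")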
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2024 CogVideo Model Team @ Zhipu AI

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
cogvideo_embeddings.py
ADDED
The diff for this file is too large to render.
cogvideo_transformer.py
ADDED
@@ -0,0 +1,547 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
#from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from cogvideo_embeddings import CogVideoXPatchEmbedWBlur, TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool` defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self Attention
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # 2. Feed Forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # attention
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # feed-forward
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

        return hidden_states, encoder_hidden_states


class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    """
    A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).

    Parameters:
        num_attention_heads (`int`, defaults to `30`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        ofs_embed_dim (`int`, defaults to `512`):
            Output dimension of "ofs" embeddings used in CogVideoX-5b-I2B in version 1.5
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        attention_bias (`bool`, defaults to `True`):
            Whether to use bias in the attention projection layers.
        sample_width (`int`, defaults to `90`):
            The width of the input latents.
        sample_height (`int`, defaults to `60`):
            The height of the input latents.
        sample_frames (`int`, defaults to `49`):
            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
            instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
            but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
            K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        temporal_compression_ratio (`int`, defaults to `4`):
            The compression ratio across the temporal dimension. See documentation for `sample_frames`.
        max_text_seq_length (`int`, defaults to `226`):
            The maximum sequence length of the input text embeddings.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        timestep_activation_fn (`str`, defaults to `"silu"`):
            Activation function to use when generating the timestep embeddings.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value to use in normalization layers.
        spatial_interpolation_scale (`float`, defaults to `1.875`):
            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
        temporal_interpolation_scale (`float`, defaults to `1.0`):
            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        ofs_embed_dim: Optional[int] = None,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        patch_size_t: Optional[int] = None,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
        patch_bias: bool = True,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
            raise ValueError(
                "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
                "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
                "issue at https://github.com/huggingface/diffusers/issues."
            )

        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbedWBlur(
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            in_channels=in_channels,
            embed_dim=inner_dim,
            text_embed_dim=text_embed_dim,
            bias=patch_bias,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            max_text_seq_length=max_text_seq_length,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
            use_learned_positional_embeddings=use_learned_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Time embeddings and ofs embedding(Only CogVideoX1.5-5B I2V have)

        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        self.ofs_proj = None
        self.ofs_embedding = None
        if ofs_embed_dim:
            self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
            self.ofs_embedding = TimestepEmbedding(
                ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
            )  # same as time embeddings, for ofs

        # 3. Define spatio-temporal transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                CogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 4. Output blocks
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )

        if patch_size_t is None:
            # For CogVideox 1.0
            output_dim = patch_size * patch_size * out_channels
        else:
            # For CogVideoX 1.5
            output_dim = patch_size * patch_size * patch_size_t * out_channels

        self.proj_out = nn.Linear(inner_dim, output_dim)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        intervals: Optional[torch.Tensor],
        condition_mask: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):

        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_frames, channels, height, width = hidden_states.shape

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        if self.ofs_embedding is not None:
            ofs_emb = self.ofs_proj(ofs)
            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
            ofs_emb = self.ofs_embedding(ofs_emb)
            emb = emb + ofs_emb

        # 2. Patch embedding
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states, intervals, condition_mask)
        hidden_states = self.embedding_dropout(hidden_states)

        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        p = self.config.patch_size
        p_t = self.config.patch_size_t

        if p_t is None:
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            output = hidden_states.reshape(
                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
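As a quick orientation for the block defined in cogvideo_transformer.py above, here is a minimal smoke-test sketch. It assumes the repository root is on PYTHONPATH (so that cogvideo_transformer.py and its local cogvideo_embeddings.py import) and uses small, illustrative sizes; the released configuration uses dim = 30 * 64 and time_embed_dim = 512.

import torch
from cogvideo_transformer import CogVideoXBlock

# Tiny, illustrative sizes for a CPU smoke test (not the checkpoint configuration).
block = CogVideoXBlock(dim=64, num_attention_heads=4, attention_head_dim=16, time_embed_dim=32)

batch, video_tokens, text_tokens = 2, 128, 16
hidden_states = torch.randn(batch, video_tokens, 64)          # patchified video tokens
encoder_hidden_states = torch.randn(batch, text_tokens, 64)   # projected text tokens
temb = torch.randn(batch, 32)                                 # timestep embedding

# The block returns the updated (video, text) token streams, as in the forward pass above.
hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, temb)
print(hidden_states.shape, encoder_hidden_states.shape)  # torch.Size([2, 128, 64]) torch.Size([2, 16, 64])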
controlnet_pipeline.py
ADDED
@@ -0,0 +1,733 @@
import inspect
import math
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from einops import rearrange, repeat
from transformers import T5EncoderModel, T5Tokenizer
from diffusers.video_processor import VideoProcessor
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from diffusers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipelineOutput, CogVideoXLoraLoaderMixin
from training.helpers import random_insert_latent_frame, transform_intervals
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def resize_for_crop(image, crop_h, crop_w):
    img_h, img_w = image.shape[-2:]
    if img_h >= crop_h and img_w >= crop_w:
        coef = max(crop_h / img_h, crop_w / img_w)
    elif img_h <= crop_h and img_w <= crop_w:
        coef = max(crop_h / img_h, crop_w / img_w)
    else:
        coef = crop_h / img_h if crop_h > img_h else crop_w / img_w
    out_h, out_w = int(img_h * coef), int(img_w * coef)
    resized_image = transforms.functional.resize(image, (out_h, out_w), antialias=True)
    return resized_image


def prepare_frames(input_images, video_size, do_resize=True, do_crop=True):
    input_images = np.stack([np.array(x) for x in input_images])
    images_tensor = torch.from_numpy(input_images).permute(0, 3, 1, 2) / 127.5 - 1
    if do_resize:
        images_tensor = [resize_for_crop(x, crop_h=video_size[0], crop_w=video_size[1]) for x in images_tensor]
    if do_crop:
        images_tensor = [transforms.functional.center_crop(x, video_size) for x in images_tensor]
    if isinstance(images_tensor, list):
        images_tensor = torch.stack(images_tensor)
    return images_tensor.unsqueeze(0)


def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    tw = tgt_width
    th = tgt_height
    h, w = src
    r = h / w
    if r > (th / tw):
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        resize_width = tw
        resize_height = int(round(tw / w * h))

    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))

    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")

class ControlnetCogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
    _optional_components = []
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        self.vae_scale_factor_spatial = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        self.vae_scale_factor_temporal = (
            self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
        )

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)



    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        # Had to disable auto cast here, otherwise the text encoder produces NaNs.
        # Hope it doesn't break training
        with torch.autocast(device_type=device.type, enabled=False):
            prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
        # prompt embeds is nan here!
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 238 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 239 |
+
less than `1`).
|
| 240 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 241 |
+
Whether to use classifier free guidance or not.
|
| 242 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 243 |
+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
| 244 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 245 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 246 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 247 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 248 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 249 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 250 |
+
argument.
|
| 251 |
+
device: (`torch.device`, *optional*):
|
| 252 |
+
torch device
|
| 253 |
+
dtype: (`torch.dtype`, *optional*):
|
| 254 |
+
torch dtype
|
| 255 |
+
"""
|
| 256 |
+
device = device or self._execution_device
|
| 257 |
+
|
| 258 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 259 |
+
if prompt is not None:
|
| 260 |
+
batch_size = len(prompt)
|
| 261 |
+
else:
|
| 262 |
+
batch_size = prompt_embeds.shape[0]
|
| 263 |
+
|
| 264 |
+
if prompt_embeds is None:
|
| 265 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 266 |
+
prompt=prompt,
|
| 267 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 268 |
+
max_sequence_length=max_sequence_length,
|
| 269 |
+
device=device,
|
| 270 |
+
dtype=dtype,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 274 |
+
negative_prompt = negative_prompt or ""
|
| 275 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 276 |
+
|
| 277 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 278 |
+
raise TypeError(
|
| 279 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 280 |
+
f" {type(prompt)}."
|
| 281 |
+
)
|
| 282 |
+
elif batch_size != len(negative_prompt):
|
| 283 |
+
raise ValueError(
|
| 284 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 285 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 286 |
+
" the batch size of `prompt`."
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 290 |
+
prompt=negative_prompt,
|
| 291 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 292 |
+
max_sequence_length=max_sequence_length,
|
| 293 |
+
device=device,
|
| 294 |
+
dtype=dtype,
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
return prompt_embeds, negative_prompt_embeds
|
| 298 |
+
|
| 299 |
+
def prepare_latents(
|
| 300 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
| 301 |
+
):
|
| 302 |
+
shape = (
|
| 303 |
+
batch_size,
|
| 304 |
+
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
|
| 305 |
+
num_channels_latents,
|
| 306 |
+
height // self.vae_scale_factor_spatial,
|
| 307 |
+
width // self.vae_scale_factor_spatial,
|
| 308 |
+
)
|
| 309 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 310 |
+
raise ValueError(
|
| 311 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 312 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
if latents is None:
|
| 316 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 317 |
+
else:
|
| 318 |
+
latents = latents.to(device)
|
| 319 |
+
|
| 320 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 321 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 322 |
+
return latents
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def prepare_image_latents(self,
|
| 327 |
+
image: torch.Tensor,
|
| 328 |
+
batch_size: int = 1,
|
| 329 |
+
num_channels_latents: int = 16,
|
| 330 |
+
num_frames: int = 13,
|
| 331 |
+
height: int = 60,
|
| 332 |
+
width: int = 90,
|
| 333 |
+
dtype: Optional[torch.dtype] = None,
|
| 334 |
+
device: Optional[torch.device] = None,
|
| 335 |
+
generator: Optional[torch.Generator] = None,
|
| 336 |
+
latents: Optional[torch.Tensor] = None,):
|
| 337 |
+
|
| 338 |
+
image_prepared = prepare_frames(image, (height, width)).to(device).to(dtype=dtype).permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
| 339 |
+
|
| 340 |
+
image_latents = [retrieve_latents(self.vae.encode(image_prepared), generator)]
|
| 341 |
+
|
| 342 |
+
image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
|
| 343 |
+
|
| 344 |
+
if not self.vae.config.invert_scale_latents:
|
| 345 |
+
image_latents = self.vae_scaling_factor_image * image_latents
|
| 346 |
+
else:
|
| 347 |
+
# This is awkward but required because the CogVideoX team forgot to multiply the
|
| 348 |
+
# scaling factor during training :)
|
| 349 |
+
image_latents = 1 / self.vae_scaling_factor_image * image_latents
|
| 350 |
+
|
| 351 |
+
# else:
|
| 352 |
+
# # This is awkward but required because the CogVideoX team forgot to multiply the
|
| 353 |
+
# # scaling factor during training :)
|
| 354 |
+
# image_latents = 1 / self.vae_scaling_factor_image * image_latents
|
| 355 |
+
|
| 356 |
+
return image_prepared, image_latents
|
| 357 |
+
|
| 358 |
+
# def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 359 |
+
# latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
| 360 |
+
# latents = 1 / self.vae.config.scaling_factor * latents
|
| 361 |
+
|
| 362 |
+
# frames = self.vae.decode(latents).sample
|
| 363 |
+
# return frames
|
| 364 |
+
|
| 365 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 366 |
+
latents = latents.permute(0, 2, 1, 3, 4) # [B, C, T, H, W]
|
| 367 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 368 |
+
|
| 369 |
+
def decode_fn(x):
|
| 370 |
+
return self.vae.decode(x).sample
|
| 371 |
+
|
| 372 |
+
# Use checkpointing to save memory
|
| 373 |
+
frames = checkpoint(decode_fn, latents, use_reentrant=False)
|
| 374 |
+
return frames
|
| 375 |
+
|
| 376 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 377 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 378 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 379 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 380 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 381 |
+
# and should be between [0, 1]
|
| 382 |
+
|
| 383 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 384 |
+
extra_step_kwargs = {}
|
| 385 |
+
if accepts_eta:
|
| 386 |
+
extra_step_kwargs["eta"] = eta
|
| 387 |
+
|
| 388 |
+
# check if the scheduler accepts generator
|
| 389 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 390 |
+
if accepts_generator:
|
| 391 |
+
extra_step_kwargs["generator"] = generator
|
| 392 |
+
return extra_step_kwargs
|
| 393 |
+
|
| 394 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 395 |
+
def check_inputs(
|
| 396 |
+
self,
|
| 397 |
+
prompt,
|
| 398 |
+
height,
|
| 399 |
+
width,
|
| 400 |
+
negative_prompt,
|
| 401 |
+
callback_on_step_end_tensor_inputs,
|
| 402 |
+
prompt_embeds=None,
|
| 403 |
+
negative_prompt_embeds=None,
|
| 404 |
+
):
|
| 405 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 406 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 407 |
+
|
| 408 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 409 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 410 |
+
):
|
| 411 |
+
raise ValueError(
|
| 412 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 413 |
+
)
|
| 414 |
+
if prompt is not None and prompt_embeds is not None:
|
| 415 |
+
raise ValueError(
|
| 416 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 417 |
+
" only forward one of the two."
|
| 418 |
+
)
|
| 419 |
+
elif prompt is None and prompt_embeds is None:
|
| 420 |
+
raise ValueError(
|
| 421 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 422 |
+
)
|
| 423 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 424 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 425 |
+
|
| 426 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 427 |
+
raise ValueError(
|
| 428 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 429 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 433 |
+
raise ValueError(
|
| 434 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 435 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 439 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 440 |
+
raise ValueError(
|
| 441 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 442 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 443 |
+
f" {negative_prompt_embeds.shape}."
|
| 444 |
+
)
|
| 445 |
+
def fuse_qkv_projections(self) -> None:
|
| 446 |
+
r"""Enables fused QKV projections."""
|
| 447 |
+
self.fusing_transformer = True
|
| 448 |
+
self.transformer.fuse_qkv_projections()
|
| 449 |
+
|
| 450 |
+
def unfuse_qkv_projections(self) -> None:
|
| 451 |
+
r"""Disable QKV projection fusion if enabled."""
|
| 452 |
+
if not self.fusing_transformer:
|
| 453 |
+
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
| 454 |
+
else:
|
| 455 |
+
self.transformer.unfuse_qkv_projections()
|
| 456 |
+
self.fusing_transformer = False
|
| 457 |
+
|
| 458 |
+
def _prepare_rotary_positional_embeddings(
|
| 459 |
+
self,
|
| 460 |
+
height: int,
|
| 461 |
+
width: int,
|
| 462 |
+
num_frames: int,
|
| 463 |
+
device: torch.device,
|
| 464 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 465 |
+
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
| 466 |
+
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
| 467 |
+
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
| 468 |
+
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
| 469 |
+
|
| 470 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
| 471 |
+
(grid_height, grid_width), base_size_width, base_size_height
|
| 472 |
+
)
|
| 473 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
| 474 |
+
embed_dim=self.transformer.config.attention_head_dim,
|
| 475 |
+
crops_coords=grid_crops_coords,
|
| 476 |
+
grid_size=(grid_height, grid_width),
|
| 477 |
+
temporal_size=num_frames,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
freqs_cos = freqs_cos.to(device=device)
|
| 481 |
+
freqs_sin = freqs_sin.to(device=device)
|
| 482 |
+
return freqs_cos, freqs_sin
|
| 483 |
+
|
| 484 |
+
@property
|
| 485 |
+
def guidance_scale(self):
|
| 486 |
+
return self._guidance_scale
|
| 487 |
+
|
| 488 |
+
@property
|
| 489 |
+
def num_timesteps(self):
|
| 490 |
+
return self._num_timesteps
|
| 491 |
+
|
| 492 |
+
@property
|
| 493 |
+
def attention_kwargs(self):
|
| 494 |
+
return self._attention_kwargs
|
| 495 |
+
|
| 496 |
+
@property
|
| 497 |
+
def interrupt(self):
|
| 498 |
+
return self._interrupt
|
| 499 |
+
|
| 500 |
+
@torch.no_grad()
|
| 501 |
+
def __call__(
|
| 502 |
+
self,
|
| 503 |
+
image,
|
| 504 |
+
input_intervals,
|
| 505 |
+
output_intervals,
|
| 506 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 507 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 508 |
+
height: int = 480,
|
| 509 |
+
width: int = 720,
|
| 510 |
+
num_frames: int = 49,
|
| 511 |
+
num_inference_steps: int = 50,
|
| 512 |
+
timesteps: Optional[List[int]] = None,
|
| 513 |
+
guidance_scale: float = 6,
|
| 514 |
+
use_dynamic_cfg: bool = False,
|
| 515 |
+
num_videos_per_prompt: int = 1,
|
| 516 |
+
eta: float = 0.0,
|
| 517 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 518 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 519 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 520 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 521 |
+
output_type: str = "pil",
|
| 522 |
+
return_dict: bool = True,
|
| 523 |
+
callback_on_step_end: Optional[
|
| 524 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 525 |
+
] = None,
|
| 526 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 527 |
+
max_sequence_length: int = 226,
|
| 528 |
+
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
| 529 |
+
if num_frames > 49:
|
| 530 |
+
raise ValueError(
|
| 531 |
+
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 535 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 536 |
+
|
| 537 |
+
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
| 538 |
+
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
| 539 |
+
num_videos_per_prompt = 1
|
| 540 |
+
|
| 541 |
+
self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
# 1. Check inputs. Raise error if not correct
|
| 545 |
+
self.check_inputs(
|
| 546 |
+
prompt,
|
| 547 |
+
height,
|
| 548 |
+
width,
|
| 549 |
+
negative_prompt,
|
| 550 |
+
callback_on_step_end_tensor_inputs,
|
| 551 |
+
prompt_embeds,
|
| 552 |
+
negative_prompt_embeds,
|
| 553 |
+
)
|
| 554 |
+
self._guidance_scale = guidance_scale
|
| 555 |
+
self._interrupt = False
|
| 556 |
+
|
| 557 |
+
# 2. Default call parameters
|
| 558 |
+
if prompt is not None and isinstance(prompt, str):
|
| 559 |
+
batch_size = 1
|
| 560 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 561 |
+
batch_size = len(prompt)
|
| 562 |
+
else:
|
| 563 |
+
batch_size = prompt_embeds.shape[0]
|
| 564 |
+
|
| 565 |
+
device = self._execution_device
|
| 566 |
+
|
| 567 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 568 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 569 |
+
# corresponds to doing no classifier free guidance.
|
| 570 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 571 |
+
|
| 572 |
+
# 3. Encode input prompt
|
| 573 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 574 |
+
prompt,
|
| 575 |
+
negative_prompt,
|
| 576 |
+
do_classifier_free_guidance,
|
| 577 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 578 |
+
prompt_embeds=prompt_embeds,
|
| 579 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 580 |
+
max_sequence_length=max_sequence_length,
|
| 581 |
+
device=device,
|
| 582 |
+
)
|
| 583 |
+
if do_classifier_free_guidance:
|
| 584 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 585 |
+
|
| 586 |
+
# 4. Prepare timesteps
|
| 587 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 588 |
+
self._num_timesteps = len(timesteps)
|
| 589 |
+
|
| 590 |
+
# 5. Prepare latents.
|
| 591 |
+
latent_channels = 16 #self.transformer.config.in_channels
|
| 592 |
+
latents = self.prepare_latents(
|
| 593 |
+
batch_size * num_videos_per_prompt,
|
| 594 |
+
latent_channels,
|
| 595 |
+
num_frames,
|
| 596 |
+
height,
|
| 597 |
+
width,
|
| 598 |
+
prompt_embeds.dtype,
|
| 599 |
+
device,
|
| 600 |
+
generator,
|
| 601 |
+
latents,
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
image_prepared, image_latents = self.prepare_image_latents(
|
| 607 |
+
image,
|
| 608 |
+
batch_size=batch_size,
|
| 609 |
+
num_channels_latents=latent_channels,
|
| 610 |
+
num_frames=num_frames,
|
| 611 |
+
height=height,
|
| 612 |
+
width=width,
|
| 613 |
+
dtype=prompt_embeds.dtype,
|
| 614 |
+
device=device,
|
| 615 |
+
generator=generator,
|
| 616 |
+
)
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 620 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 621 |
+
|
| 622 |
+
# 8. Create rotary embeds if required - THIS IS NOT USED
|
| 623 |
+
image_rotary_emb = (
|
| 624 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
| 625 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
| 626 |
+
else None
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
# 9. Denoising loop
|
| 630 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
input_intervals = input_intervals.to(device)
|
| 634 |
+
output_intervals = output_intervals.to(device)
|
| 635 |
+
|
| 636 |
+
input_intervals = transform_intervals(input_intervals)
|
| 637 |
+
output_intervals = transform_intervals(output_intervals)
|
| 638 |
+
|
| 639 |
+
latents_initial, target, condition_mask, intervals = random_insert_latent_frame(image_latents, latents, latents, input_intervals, output_intervals, special_info="just_one")
|
| 640 |
+
|
| 641 |
+
latents = latents_initial.clone()
|
| 642 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 643 |
+
# for DPM-solver++
|
| 644 |
+
old_pred_original_sample = None
|
| 645 |
+
for i, t in enumerate(timesteps):
|
| 646 |
+
if self.interrupt:
|
| 647 |
+
continue
|
| 648 |
+
|
| 649 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 650 |
+
#replace first latent with image_latents
|
| 651 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 652 |
+
|
| 653 |
+
if do_classifier_free_guidance:
|
| 654 |
+
latent_model_input[0][condition_mask[0]] = 0 #set unconditioned latents to 0
|
| 655 |
+
#TODO: Replace the conditional latents with the input latents
|
| 656 |
+
latent_model_input[1][condition_mask[0]] = latents_initial[0][condition_mask[0]].to(latent_model_input.dtype)
|
| 657 |
+
else:
|
| 658 |
+
latent_model_input[:, condition_mask[0]] = latents_initial[0][condition_mask[0]].to(latent_model_input.dtype)
|
| 659 |
+
|
| 660 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 661 |
+
|
| 662 |
+
current_sampling_percent = i / len(timesteps)
|
| 663 |
+
|
| 664 |
+
latent_model_input = latent_model_input.to(dtype=self.transformer.dtype)
|
| 665 |
+
prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
|
| 666 |
+
# predict noise model_output
|
| 667 |
+
noise_pred = self.transformer(
|
| 668 |
+
hidden_states=latent_model_input,
|
| 669 |
+
encoder_hidden_states=prompt_embeds,
|
| 670 |
+
timestep=timestep,
|
| 671 |
+
intervals=intervals,
|
| 672 |
+
condition_mask=condition_mask,
|
| 673 |
+
image_rotary_emb=image_rotary_emb,
|
| 674 |
+
return_dict=False,
|
| 675 |
+
)[0]
|
| 676 |
+
noise_pred = noise_pred.float()
|
| 677 |
+
|
| 678 |
+
# perform guidance
|
| 679 |
+
if use_dynamic_cfg:
|
| 680 |
+
self._guidance_scale = 1 + guidance_scale * (
|
| 681 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
| 682 |
+
)
|
| 683 |
+
if do_classifier_free_guidance:
|
| 684 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
| 685 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 686 |
+
#so I think the problem is that the conditional noise doesn't have a realistic noise prediction on its own frame
|
| 687 |
+
#what I really need to do is replace the unconditional noise at that frame
|
| 688 |
+
|
| 689 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 690 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
| 691 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 692 |
+
else:
|
| 693 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
| 694 |
+
noise_pred,
|
| 695 |
+
old_pred_original_sample,
|
| 696 |
+
t,
|
| 697 |
+
timesteps[i - 1] if i > 0 else None,
|
| 698 |
+
latents,
|
| 699 |
+
**extra_step_kwargs,
|
| 700 |
+
return_dict=False,
|
| 701 |
+
)
|
| 702 |
+
latents = latents.to(prompt_embeds.dtype)
|
| 703 |
+
|
| 704 |
+
# call the callback, if provided
|
| 705 |
+
if callback_on_step_end is not None:
|
| 706 |
+
callback_kwargs = {}
|
| 707 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 708 |
+
callback_kwargs[k] = locals()[k]
|
| 709 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 710 |
+
|
| 711 |
+
latents = callback_outputs.pop("latents", latents)
|
| 712 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 713 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 714 |
+
|
| 715 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 716 |
+
progress_bar.update()
|
| 717 |
+
|
| 718 |
+
#after exiting replace the conditioning latent with image_latents
|
| 719 |
+
#latents[:, motion_blur_amount:motion_blur_amount+1] = image_latents[:, 0:1]
|
| 720 |
+
if not output_type == "latent":
|
| 721 |
+
latents = latents[~condition_mask].unsqueeze(0)
|
| 722 |
+
video = self.decode_latents(latents)
|
| 723 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 724 |
+
else:
|
| 725 |
+
video = latents
|
| 726 |
+
|
| 727 |
+
# Offload all models
|
| 728 |
+
self.maybe_free_model_hooks()
|
| 729 |
+
|
| 730 |
+
if not return_dict:
|
| 731 |
+
return (video,)
|
| 732 |
+
|
| 733 |
+
return CogVideoXPipelineOutput(frames=video)
|
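For orientation, here is a minimal, hedged sketch of how this pipeline's `__call__` might be exercised. It is not part of the commit: the checkpoint id, prompt, and the shapes/format of `image`, `input_intervals`, and `output_intervals` are illustrative assumptions, and `inference.py` remains the authoritative entry point.

# Hypothetical usage sketch -- checkpoint id, tensor shapes, and interval format are assumptions,
# not values taken from this repository; see inference.py for the real entry point.
import torch
from controlnet_pipeline import ControlnetCogVideoXPipeline

# Placeholder checkpoint; the actual weights live in the repo's own checkpoints.
pipe = ControlnetCogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16
).to("cuda")

# A single blurry input frame, batched as [B, F, C, H, W] with F = 1 (assumed layout).
blurry = torch.rand(1, 1, 3, 480, 720)

# Intervals describing which latent frames are conditioned vs. generated (assumed format).
input_intervals = torch.tensor([[0, 12]])
output_intervals = torch.tensor([[0, 12]])

out = pipe(
    image=blurry,
    input_intervals=input_intervals,
    output_intervals=output_intervals,
    prompt="a sharp video recovered from a motion-blurred photo",
    num_frames=49,
    num_inference_steps=50,
    guidance_scale=6.0,
)
video = out.frames[0]  # list of PIL frames when output_type="pil"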
extra/checkpoints_to_hf.py ADDED
@@ -0,0 +1,16 @@
from huggingface_hub import HfApi
import os
#run with HF_TOKEN = your_hf_token before python_command
api = HfApi(token=os.getenv("HF_TOKEN"))
folders = ["/datasets/sai/blur2vid/training/cogvideox-baist-test",
           "/datasets/sai/blur2vid/training/cogvideox-gopro-test",
           "/datasets/sai/blur2vid/training/cogvideox-gopro-2x-test",
           "/datasets/sai/blur2vid/training/cogvideox-full-test",
           "/datasets/sai/blur2vid/training/cogvideox-outsidephotos"]
for folder in folders:
    api.upload_folder(
        folder_path=folder,
        repo_id="tedlasai/blur2vid",
        repo_type="model",
        path_in_repo=os.path.basename(folder)
    )
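The upload script above pushes each training run directory into the `tedlasai/blur2vid` model repo under its own subfolder. A small companion sketch (not part of the commit) for pulling one run back down with `huggingface_hub.snapshot_download`; the run name and the decision to filter with `allow_patterns` are assumptions based on the upload layout:

# Hedged companion sketch: fetch one uploaded run back from the Hub.
import os
from huggingface_hub import snapshot_download

run = "cogvideox-gopro-test"  # any of the folder basenames uploaded above
local_path = snapshot_download(
    repo_id="tedlasai/blur2vid",
    repo_type="model",
    allow_patterns=[f"{run}/*"],   # only download that run's files
    token=os.getenv("HF_TOKEN"),   # only needed if the repo is private
)
print(os.path.join(local_path, run))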
extra/moMets-parallel-baist.py ADDED
@@ -0,0 +1,330 @@
# Motion Metrics
from concurrent.futures import ProcessPoolExecutor, as_completed
import numpy as np
np.float = np.float64
np.int = np.int_
import os
from cdfvd import fvd
from skimage.metrics import structural_similarity
import torch
import lpips
#from DISTS_pytorch import DISTS
#import colour as c
#from torchmetrics.image.fid import FrechetInceptionDistance
import torch.nn.functional as F
from epe_metric import compute_bidirectional_epe as epe
import pdb
import multiprocessing
import cv2
import glob
# init
dataDir = 'BAISTResultsImages' # 'dataGoPro' #
gtDir = 'GT' #'GT' #
methodDirs = ['Ours', 'Animation-from-blur'] #['Favaro','MotionETR','Ours','GOPROGeneralize'] #
depth = 8
resFile = './kellytest.npy'#resultsGoPro20250520.npy'#

patchDim = 32 #64 #
pixMax = 1.0

nMets = 7 # new results: scoreFVD, scorePWPSNR, scoreEPE, scorePatchSSIM, scorePatchLPIPS, scorePSNR
compute = True # if False, load previously computed
eps = 1e-8

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def read_pngs_to_array(path):
    """Read all PNGs in `path`, sort them by filename, convert BGR→RGB, and stack into an np.ndarray."""
    return np.stack([
        cv2.imread(f, cv2.IMREAD_UNCHANGED)[..., ::-1]
        for f in sorted(glob.glob(f"{path}/*.png"))
    ])


# Use 'spawn' to avoid CUDA context issues
multiprocessing.freeze_support() # on Windows
multiprocessing.set_start_method('spawn', force=True)

def compute_method(results_local, methodDir, files, countMethod):

    fnLPIPS = lpips.LPIPS(net='alex').to(device)
    #fnDISTS = DISTS().to(device)
    fnFVD = fvd.cdfvd(model='videomae', device=device)

    countFile = -1
    for file in files:
        countFile+=1

        # pull frames from MP4
        pathMethod = os.path.join(dataDir, methodDir, file)
        framesMethod = np.clip(read_pngs_to_array(pathMethod).astype(np.float32) / (2**depth-1),0,1)
        pathGT = os.path.join(dataDir, gtDir, file)
        framesGT = np.clip(read_pngs_to_array(pathGT).astype(np.float32) / (2**depth-1),0,1)

        #make sure the GT and method have the same shape
        assert framesGT.shape == framesMethod.shape, f"GT shape {framesGT.shape} does not match method shape {framesMethod.shape} for file {file}"
        # video metrics

        # vmaf
        #scoreVMAF = callVMAF(pathGT, pathMethod)

        # epe - we have to change to tensors here
        framesMethodTensor = torch.from_numpy(framesMethod)
        framesGTtensor = torch.from_numpy(framesGT)
        scoreEPE = epe(framesMethodTensor[0,:,:,:], framesMethodTensor[-1,:,:,:], framesGTtensor[0,:,:,:], framesGTtensor[-1,:,:,:], per_pixel_mode=True).cpu().detach().numpy()

        # motion blur baseline
        blurryGT = np.mean(framesGT ** 2.2,axis=0) ** (1/2.2)
        blurryMethod = np.mean(framesMethod ** 2.2,axis=0) ** (1/2.2)
        # MSE -> PSNR
        mapBlurryMSE = (blurryGT - blurryMethod)**2
        scoreBlurryMSE = np.mean(mapBlurryMSE)
        scoreBlurryPSNR = (10 * np.log10(pixMax**2 / scoreBlurryMSE))

        # fvd
        #scoreFVD = fnFVD.compute_fvd(real_videos=(np.expand_dims(framesGT, axis=0)*(2**depth-1)).astype(np.uint8), fake_videos=(np.expand_dims(framesMethod, axis=0)*(2**depth-1)).astype(np.uint8))
        framesGTfvd = np.expand_dims((framesGT * (2**depth-1)).astype(np.uint8), axis=0)
        fnFVD.add_real_stats(framesGTfvd)
        framesMethodFVD = np.expand_dims((framesMethod * (2**depth-1)).astype(np.uint8), axis=0)
        fnFVD.add_fake_stats(framesMethodFVD)

        # loop directions
        framesMSE = np.stack((framesGT,framesGT)) # pre allocate array for directional PSNR maps
        countDirect = -1
        for direction in directions:
            countDirect = countDirect+1
            order = direction

            # loop frames + image level metrics
            countFrames = -1
            for i in order:
                countFrames+=1

                frameMethod = framesMethod[i,:,:,:] # method frames can be re-ordered
                frameGT = framesGT[countFrames,:,:,:]

                #assert patch size is divisible by image size
                rows, cols, ch = frameGT.shape
                assert rows % patchDim == 0, f"rows {rows} is not divisible by patchDim {patchDim}"
                assert cols % patchDim == 0, f"cols {cols} is not divisible by patchDim {patchDim}"

                rPatch = np.ceil(rows/patchDim)
                cPatch = np.ceil(cols/patchDim)

                # LPIPS
                #pdb.set_trace()
                methodTensor = (torch.from_numpy(np.moveaxis(frameMethod, -1, 0)).unsqueeze(0) * 2 - 1).to(device)
                gtTensor = (torch.from_numpy(np.moveaxis(frameGT, -1, 0)).unsqueeze(0) * 2 - 1).to(device)
                #scoreLPIPS = fnLPIPS(gtTensor, methodTensor).squeeze(0,1,2).cpu().detach().numpy()[0]

                # FID
                #fnFID.update((gtTensor * (2**depth - 1)).to(torch.uint8), real=True)
                #fnFID.update((methodTensor * (2**depth - 1)).to(torch.uint8), real=False)

                # DISTS
                #scoreDISTS = fnDISTS(gtTensor.to(torch.float), methodTensor.to(torch.float), require_grad=True, batch_average=True).cpu().detach().numpy()

                # compute ssim
                #scoreSSIM = structural_similarity(frameGT, frameMethod, data_range=pixMax, channel_axis=2)

                # compute DE 2000
                #frameMethodXYZ = c.RGB_to_XYZ(frameMethod, c.models.RGB_COLOURSPACE_sRGB, apply_cctf_decoding=True)
                #frameMethodLAB = c.XYZ_to_Lab(frameMethodXYZ)
                #frameGTXYZ = c.RGB_to_XYZ(frameGT, c.models.RGB_COLOURSPACE_sRGB, apply_cctf_decoding=True)
                #frameGTLAB = c.XYZ_to_Lab(frameGTXYZ)
                #mapDE2000 = c.delta_E(frameGTLAB, frameMethodLAB, method='CIE 2000')
                #scoreDE2000 = np.mean(mapDE2000)

                # MSE
                mapMSE = (frameGT - frameMethod)**2
                scoreMSE = np.mean(mapMSE)

                # PSNR
                framesMSE[countDirect,countFrames,:,:,:] = mapMSE
                #framesPSNR[countDirect,countFrames,:,:,:] = np.clip((10 * np.log10(pixMax**2 / np.clip(mapMSE,a_min=1e-10,a_max=None))),0,100)
                scorePSNR = (10 * np.log10(pixMax**2 / scoreMSE))

                #for l in range(ch):

                # channel-wise metrics
                #chanFrameMethod = frameMethod[:,:,l]
                #chanFrameGT = frameGT[:,:,l]

                # loop patches rows
                for j in range(int(rPatch)):

                    # loop patches cols + patch level metrics
                    for k in range(int(cPatch)):

                        startR = j*patchDim
                        startC = k*patchDim
                        endR = j*patchDim+patchDim
                        endC = k*patchDim+patchDim

                        if endR > rows:
                            endR = rows
                        else:
                            pass

                        if endC > cols:
                            endC = cols
                        else:
                            pass

                        # patch metrics
                        #patchMSE = np.mean(mapMSE[startR:endR,startC:endC,:])
                        #scorePatchPSNR = np.clip((10 * np.log10(pixMax**2 / patchMSE)),0,100)
                        if dataDir == 'BAISTResultsImages':
                            patchGtTensor = F.interpolate(gtTensor[:,:,startR:endR,startC:endC], scale_factor=2.0, mode='bilinear', align_corners=False)
                            patchMethodTensor = F.interpolate(methodTensor[:,:,startR:endR,startC:endC], scale_factor=2.0, mode='bilinear', align_corners=False)
                            scorePatchLPIPS = fnLPIPS(patchGtTensor, patchMethodTensor).squeeze(0,1,2).cpu().detach().numpy()[0]
                        else:
                            scorePatchLPIPS = fnLPIPS(gtTensor[:,:,startR:endR,startC:endC], methodTensor[:,:,startR:endR,startC:endC]).squeeze(0,1,2).cpu().detach().numpy()[0]
                        scorePatchSSIM = structural_similarity(frameGT[startR:endR,startC:endC,:], frameMethod[startR:endR,startC:endC,:], data_range=pixMax, channel_axis=2)
                        #scorePatchDISTS = fnDISTS(gtTensor[:,:,startR:endR,startC:endC].to(torch.float), methodTensor[:,:,startR:endR,startC:endC].to(torch.float), require_grad=True, batch_average=True).cpu().detach().numpy()
                        #scorePatchDE2000 = np.mean(mapDE2000[startR:endR,startC:endC])

                        # i: frame number, j: patch row, k: patch col
                        #results[countMethod,countFile,countDirect,i,j,k,3:] = [scoreEPE, scoreBlurryPSNR, scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scorePatchDE2000]
                        results_local[countMethod,countFile,countDirect,i,j,k,2:] = [scoreEPE, scoreBlurryPSNR, scorePatchSSIM, scorePatchLPIPS, scorePSNR]
                        print('Method: ', methodDir, ' File: ', file, ' Frame: ', str(i), ' PSNR: ', scorePSNR, end='\r')
        #print('VMAF: ', str(scoreVMAF), ' FVD: ', str(scoreFVD), ' LPIPS: ', str(scoreLPIPS), ' FID: ', str(scoreFID), ' DISTS: ', str(scoreDISTS), ' SSIM: ', str(scoreSSIM), ' DE2000: ', str(scoreDE2000), ' PSNR: ', str(scorePSNR), ' Patch PSNR: ', str(scorePatchPSNR), end='\r')
        #pdb.set_trace()
        scorePWPSNR = (10 * np.log10(pixMax**2 / np.mean(np.min(np.mean(framesMSE, axis=(1)),axis=0)))) # take max pixel wise PSNR per direction, average over image dims
        #print('Method: ', methodDir, ' File: ', file, ' Frame: ', str(i), ' PWPSNR: ', scorePWPSNR, end='\n')
        #scorePWPSNR = np.clip((10 * np.log10(pixMax**2 / np.mean(np.min(framesPSNR, axis=0),axis=(1,2,3)))),0,100) # take max pixel wise PSNR per direction, average over image dims
        results_local[countMethod,countFile,:,:,:,:,1] = np.tile(scorePWPSNR, results_local.shape[2:-1]) #np.broadcast_to(scorePWPSNR[:, np.newaxis, np.newaxis], results.shape[3:-1])
        np.save(resFile, results_local) # save part of the way through the loop ..

    #scoreFID = fnFID.compute().cpu().detach().numpy()
    #fnFID.reset()
    #results[countMethod,:,:,:,:,:,0] = np.tile(scoreFID, results.shape[1:-1])
    scoreFVD = fnFVD.compute_fvd_from_stats()
    fnFVD.empty_real_stats()
    fnFVD.empty_fake_stats()
    results_local[countMethod,:,:,:,:,:,0] = np.tile(scoreFVD, results_local.shape[1:-1])
    print('Results computed .. analyzing ..')

    return results_local


# init results matrix
path = os.path.join(dataDir, gtDir)
clipDirs = [name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))]
files = []
if dataDir == 'BAISTResultsImages':
    extraFknDir = 'blur'
else:
    extraFknDir = ''
for clipDir in clipDirs:
    path = os.path.join(dataDir, gtDir, clipDir, extraFknDir)
    files = files + [os.path.join(clipDir,extraFknDir,name) for name in os.listdir(path)]
files = sorted(files)
path = os.path.join(dataDir, methodDirs[0], files[0])
testFileGT = read_pngs_to_array(path)
frams,rows,cols,ch = testFileGT.shape
framRange = [i for i in range(frams)]
directions = [framRange, framRange[::-1]]

#loop through all methods and make sure they all have the same directory structure and same number of files
for methodDir in methodDirs:
    path = os.path.join(dataDir, methodDir)
    clipDirs = [name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))]
    filesMethod = []
    for clipDir in clipDirs:
        path = os.path.join(dataDir, methodDir, clipDir, extraFknDir)
        filesMethod = filesMethod + [os.path.join(clipDir,extraFknDir,name) for name in os.listdir(path)]
    filesMethod = sorted(filesMethod)
    assert len(files) == len(filesMethod), f"Number of files in {methodDir} does not match GT number of files"
    assert files == filesMethod, f"Files in {methodDir} do not match GT files"

def main():

    results = np.zeros((len(methodDirs),len(files),len(directions),frams,int(np.ceil(rows/patchDim)),int(np.ceil(cols/patchDim)),nMets))

    if compute:

        # loop methods + compute dataset level metrics (after nested for loops)
        import multiprocessing as mp
        ctx = mp.get_context('spawn')
        with ProcessPoolExecutor(mp_context=ctx, max_workers=len(methodDirs)) as executor:
            # submit one job per method
            futures = {
                executor.submit(compute_method, np.copy(results), md, files, idx): idx
                for idx, md in enumerate(methodDirs)
            }
            # collect and merge results as they finish
            for fut in as_completed(futures):
                idx = futures[fut]
                res_local = fut.result()
                results[idx] = res_local[idx]

    else:

        results = np.load(resFile)

    np.save(resFile, results)
    # analyze

    # new results: scoreFID, scoreFVD, scorePWPSNR, scoreEPE, scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scorePatchDE2000
    upMetrics = [1,3,4,6]

    # 0508 results: scoreFID, scoreFVD, scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scorePatchDE2000
    #upMetrics = [4,6,7,8] # PSNR, SSIM, Patch PSNR, Patch SSIM
    print("Results shape 1: ", results.shape)
    forwardBackwardResults = np.mean(results,axis=(3))
    #print("Results shape 2: ", forwardResults.shape)
    maxDirResults = np.max(forwardBackwardResults,axis=(2))
    minDirResults = np.min(forwardBackwardResults,axis=(2))
    bestDirResults = minDirResults
    #pdb.set_trace()
    bestDirResults[:,:,:,:,upMetrics] = maxDirResults[:,:,:,:,upMetrics]
    import pdb
    #pdb.set_trace()

    meanResults = bestDirResults.mean(axis=(1, 2, 3)) # Shape becomes (3, 6)
    meanResultsT = meanResults.T

    '''
    maxDirResults = np.max(results,axis=2)
    minDirResults = np.min(results,axis=2)
    bestDirResults = minDirResults
    bestDirResults[:,:,:,:,:,upMetrics] = maxDirResults[:,:,:,:,:,upMetrics]
    meanResults = bestDirResults.mean(axis=(1, 2, 3, 4)) # Shape becomes (3, 6)
    meanResultsT = meanResults.T
    '''

    #
    #meanResults = forwardResults.mean(axis=(1, 2, 3, 4)) # Shape becomes (3, 6)
    #meanResultsT = meanResults.T

    # print latex table
    method_labels = methodDirs

    # results 0508: scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scoreFID, scoreFVD
    # metric_labels = ["FID $\downarrow$","FVD $\downarrow$","LPIPS $\downarrow$", "DISTS $\downarrow$", "SSIM $\downarrow$", "DE2000 $\downarrow$", "PSNR $\downarrow$", "Patch PSNR $\downarrow$", "Patch SSIM $\downarrow$", "Patch LPIPS $\downarrow$", "Patch DISTS $\downarrow$", "Patch DE2000 $\downarrow$"]
    # results 0517:
    # metric_labels = ["FID $\downarrow$","FVD $\downarrow$","PWPSNR $\downarrow$","EPE $\downarrow$","BlurryPSNR $\downarrow$", "LPIPS $\downarrow$", "DISTS $\downarrow$", "SSIM $\downarrow$", "DE2000 $\downarrow$", "PSNR $\downarrow$", "Patch PSNR $\downarrow$", "Patch SSIM $\downarrow$", "Patch LPIPS $\downarrow$", "Patch DISTS $\downarrow$", "Patch DE2000 $\downarrow$"]

    # results 0518:
    metric_labels = ["FVD $\downarrow$","PWPSNR $\downarrow$","EPE $\downarrow$","BlurryPSNR $\downarrow$","Patch SSIM $\downarrow$","Patch LPIPS $\downarrow$", "PSNR $\downarrow$"]

    # appropriate for results 0507
    #metric_labels = ["FID $\downarrow$", "FVD $\downarrow$", "LPIPS $\downarrow$", "DISTS $\downarrow$", "SSIM $\downarrow$", "DE2000 $\downarrow$", "PSNR $\downarrow$"]

    latex_table = "\\begin{tabular}{l" + "c" * len(method_labels) + "}\n"
    latex_table += "Metric & " + " & ".join(method_labels) + " \\\\\n"
    latex_table += "\\hline\n"

    for metric, row in zip(metric_labels, meanResultsT):
        row_values = " & ".join(f"{v:.4f}" for v in row)
        latex_table += f"{metric} & {row_values} \\\\\n"

    latex_table += "\\end{tabular}"
    print(latex_table)

if __name__ == '__main__':
    main()
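The densest expression in this script is the per-pixel, best-direction PSNR (`scorePWPSNR`). A toy NumPy illustration of the same reduction order, on synthetic arrays rather than real results: average the squared error over frames, keep the better of the two playback directions per pixel, then average spatially before converting to decibels. The array sizes below are arbitrary.

# Toy illustration of the PWPSNR reduction used above (synthetic data, pixMax = 1.0).
import numpy as np

rng = np.random.default_rng(0)
pixMax = 1.0
# framesMSE mimics the script's buffer: [direction, frame, H, W, C]
framesMSE = rng.uniform(0.0, 0.01, size=(2, 7, 4, 4, 3))

per_direction = np.mean(framesMSE, axis=1)       # average MSE over frames, per direction
best_per_pixel = np.min(per_direction, axis=0)   # lower MSE = better direction, per pixel
pwpsnr = 10 * np.log10(pixMax**2 / np.mean(best_per_pixel))
print(f"PWPSNR: {pwpsnr:.2f} dB")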
extra/moMets-parallel-gopro.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Motion Metrics
|
| 2 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 3 |
+
import numpy as np
|
| 4 |
+
np.float = np.float64
|
| 5 |
+
np.int = np.int_
|
| 6 |
+
import os
|
| 7 |
+
from cdfvd import fvd
|
| 8 |
+
from skimage.metrics import structural_similarity
|
| 9 |
+
import torch
|
| 10 |
+
import lpips
|
| 11 |
+
#from DISTS_pytorch import DISTS
|
| 12 |
+
#import colour as c
|
| 13 |
+
#from torchmetrics.image.fid import FrechetInceptionDistance
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from epe_metric import compute_bidirectional_epe as epe
|
| 16 |
+
import pdb
|
| 17 |
+
import multiprocessing
|
| 18 |
+
import cv2
|
| 19 |
+
import glob
|
| 20 |
+
# init
|
| 21 |
+
# dataDir = 'BaistCroppedOutput' # 'dataGoPro' #
|
| 22 |
+
# gtDir = 'gt_subset' #'GT' #
|
| 23 |
+
# methodDirs = ['deblurred', 'animation-from-blur', ] #['Favaro','MotionETR','Ours','GOPROGeneralize'] #
|
| 24 |
+
# fType = '.mp4'
|
| 25 |
+
# depth = 8
|
| 26 |
+
# resFile = './resultsBaist20250521.npy'#resultsGoPro20250520.npy'#
|
| 27 |
+
|
| 28 |
+
# patchDim = 32 #64 #
|
| 29 |
+
# pixMax = 1.0
|
| 30 |
+
|
| 31 |
+
# nMets = 7 # new results: scoreFVD, scorePWPSNR, scoreEPE, scorePatchSSIM, scorePatchLPIPS, scorePSNR
|
| 32 |
+
# compute = True # if False, load previously computed
|
| 33 |
+
# eps = 1e-8
|
| 34 |
+
|
| 35 |
+
dataDir = 'GOPROResultsImages' # 'dataBaist' #
|
| 36 |
+
gtDir = 'GT' #'gt' #
|
| 37 |
+
methodDirs = ['Jin','MotionETR','Ours'] #'GOPROGeneralize',# ['animation-from-blur'] #
|
| 38 |
+
depth = 8
|
| 39 |
+
resFile = 'resultsGoPro20250521.npy'# './resultsBaist20250521.npy'#
|
| 40 |
+
patchDim = 40 #32 #
|
| 41 |
+
pixMax = 1.0
|
| 42 |
+
nMets = 7 # new results: scoreFVD, scorePWPSNR, scoreEPE, scorePatchSSIM, scorePatchLPIPS, scorePSNR
|
| 43 |
+
compute = False # if False, load previously computed
|
| 44 |
+
eps = 1e-8
|
| 45 |
+
|
| 46 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 47 |
+
|
| 48 |
+
# Use 'spawn' to avoid CUDA context issues
|
| 49 |
+
multiprocessing.freeze_support() # on Windows
|
| 50 |
+
multiprocessing.set_start_method('spawn', force=True)
|
| 51 |
+
|
| 52 |
+
def read_pngs_to_array(path):
|
| 53 |
+
"""Read all PNGs in `path`, sort them by filename, convert BGR→RGB, and stack into an np.ndarray."""
|
| 54 |
+
return np.stack([
|
| 55 |
+
cv2.imread(f, cv2.IMREAD_UNCHANGED)[..., ::-1]
|
| 56 |
+
for f in sorted(glob.glob(f"{path}/*.png"))
|
| 57 |
+
])
|
| 58 |
+
def compute_method(results_local, methodDir, files, countMethod):
|
| 59 |
+
|
| 60 |
+
fnLPIPS = lpips.LPIPS(net='alex').to(device)
|
| 61 |
+
#fnDISTS = DISTS().to(device)
|
| 62 |
+
fnFVD = fvd.cdfvd(model='videomae', device=device)
|
| 63 |
+
|
| 64 |
+
countFile = -1
|
| 65 |
+
for file in files:
|
| 66 |
+
countFile+=1
|
| 67 |
+
|
| 68 |
+
# pull frames from MP4
|
| 69 |
+
pathMethod = os.path.join(dataDir, methodDir, file)
|
| 70 |
+
framesMethod = np.clip(read_pngs_to_array(pathMethod).astype(np.float32) / (2**depth-1),0,1)
|
| 71 |
+
pathGT = os.path.join(dataDir, gtDir, file)
|
| 72 |
+
framesGT = np.clip(read_pngs_to_array(pathGT).astype(np.float32) / (2**depth-1),0,1)
|
| 73 |
+
|
| 74 |
+
#make sure the GT and method have the same shape
|
| 75 |
+
assert framesGT.shape == framesMethod.shape, f"GT shape {framesGT.shape} does not match method shape {framesMethod.shape} for file {file}"
|
| 76 |
+
|
| 77 |
+
# video metrics
|
| 78 |
+
|
| 79 |
+
# vmaf
|
| 80 |
+
#scoreVMAF = callVMAF(pathGT, pathMethod)
|
| 81 |
+
|
| 82 |
+
# epe - we have to change to tensors here
|
| 83 |
+
framesMethodTensor = torch.from_numpy(framesMethod)
|
| 84 |
+
framesGTtensor = torch.from_numpy(framesGT)
|
| 85 |
+
scoreEPE = epe(framesMethodTensor[0,:,:,:], framesMethodTensor[-1,:,:,:], framesGTtensor[0,:,:,:], framesGTtensor[-1,:,:,:], per_pixel_mode=True).cpu().detach().numpy()
|
| 86 |
+
|
| 87 |
+
# motion blur baseline
|
| 88 |
+
blurryGT = np.mean(framesGT ** 2.2,axis=0) ** (1/2.2)
|
| 89 |
+
blurryMethod = np.mean(framesMethod ** 2.2,axis=0) ** (1/2.2)
|
| 90 |
+
# MSE -> PSNR
|
| 91 |
+
mapBlurryMSE = (blurryGT - blurryMethod)**2
|
| 92 |
+
scoreBlurryMSE = np.mean(mapBlurryMSE)
|
| 93 |
+
scoreBlurryPSNR = (10 * np.log10(pixMax**2 / scoreBlurryMSE))
|
| 94 |
+
|
| 95 |
+
# fvd
|
| 96 |
+
#scoreFVD = fnFVD.compute_fvd(real_videos=(np.expand_dims(framesGT, axis=0)*(2**depth-1)).astype(np.uint8), fake_videos=(np.expand_dims(framesMethod, axis=0)*(2**depth-1)).astype(np.uint8))
|
| 97 |
+
framesGTfvd = np.expand_dims((framesGT * (2**depth-1)).astype(np.uint8), axis=0)
|
| 98 |
+
fnFVD.add_real_stats(framesGTfvd)
|
| 99 |
+
framesMethodFVD = np.expand_dims((framesMethod * (2**depth-1)).astype(np.uint8), axis=0)
|
| 100 |
+
fnFVD.add_fake_stats(framesMethodFVD)
|
| 101 |
+
|
| 102 |
+
# loop directions
|
| 103 |
+
framesMSE = np.stack((framesGT,framesGT)) # pre allocate array for directional PSNR maps
|
| 104 |
+
countDirect = -1
|
| 105 |
+
for direction in directions:
|
| 106 |
+
countDirect = countDirect+1
|
| 107 |
+
order = direction
|
| 108 |
+
|
| 109 |
+
# loop frames + image level metrics
|
| 110 |
+
countFrames = -1
|
| 111 |
+
for i in order:
|
| 112 |
+
countFrames+=1
|
| 113 |
+
|
| 114 |
+
frameMethod = framesMethod[i,:,:,:] # method frames can be re-ordered
|
| 115 |
+
frameGT = framesGT[countFrames,:,:,:]
|
| 116 |
+
|
| 117 |
+
#assert patch size is divisible by image size
|
| 118 |
+
rows, cols, ch = frameGT.shape
|
| 119 |
+
assert rows % patchDim == 0, f"rows {rows} is not divisible by patchDim {patchDim}"
|
| 120 |
+
assert cols % patchDim == 0, f"cols {cols} is not divisible by patchDim {patchDim}"
|
| 121 |
+
|
| 122 |
+
rPatch = np.ceil(rows/patchDim)
|
| 123 |
+
cPatch = np.ceil(cols/patchDim)
|
| 124 |
+
|
| 125 |
+
# LPIPS
|
| 126 |
+
#pdb.set_trace()
|
| 127 |
+
methodTensor = (torch.from_numpy(np.moveaxis(frameMethod, -1, 0)).unsqueeze(0) * 2 - 1).to(device)
|
| 128 |
+
gtTensor = (torch.from_numpy(np.moveaxis(frameGT, -1, 0)).unsqueeze(0) * 2 - 1).to(device)
|
| 129 |
+
#scoreLPIPS = fnLPIPS(gtTensor, methodTensor).squeeze(0,1,2).cpu().detach().numpy()[0]
|
| 130 |
+
|
| 131 |
+
# FID
|
| 132 |
+
#fnFID.update((gtTensor * (2**depth - 1)).to(torch.uint8), real=True)
|
| 133 |
+
#fnFID.update((methodTensor * (2**depth - 1)).to(torch.uint8), real=False)
|
| 134 |
+
|
| 135 |
+
# DISTS
|
| 136 |
+
#scoreDISTS = fnDISTS(gtTensor.to(torch.float), methodTensor.to(torch.float), require_grad=True, batch_average=True).cpu().detach().numpy()
|
| 137 |
+
|
| 138 |
+
# compute ssim
|
| 139 |
+
#scoreSSIM = structural_similarity(frameGT, frameMethod, data_range=pixMax, channel_axis=2)
|
| 140 |
+
|
| 141 |
+
# compute DE 2000
|
| 142 |
+
#frameMethodXYZ = c.RGB_to_XYZ(frameMethod, c.models.RGB_COLOURSPACE_sRGB, apply_cctf_decoding=True)
|
| 143 |
+
#frameMethodLAB = c.XYZ_to_Lab(frameMethodXYZ)
|
| 144 |
+
#frameGTXYZ = c.RGB_to_XYZ(frameGT, c.models.RGB_COLOURSPACE_sRGB, apply_cctf_decoding=True)
|
| 145 |
+
#frameGTLAB = c.XYZ_to_Lab(frameGTXYZ)
|
| 146 |
+
#mapDE2000 = c.delta_E(frameGTLAB, frameMethodLAB, method='CIE 2000')
|
| 147 |
+
#scoreDE2000 = np.mean(mapDE2000)
|
| 148 |
+
|
| 149 |
+
# MSE
|
| 150 |
+
mapMSE = (frameGT - frameMethod)**2
|
| 151 |
+
scoreMSE = np.mean(mapMSE)
|
| 152 |
+
|
| 153 |
+
# PSNR
|
| 154 |
+
framesMSE[countDirect,countFrames,:,:,:] = mapMSE
|
| 155 |
+
#framesPSNR[countDirect,countFrames,:,:,:] = np.clip((10 * np.log10(pixMax**2 / np.clip(mapMSE,a_min=1e-10,a_max=None))),0,100)
|
| 156 |
+
scorePSNR = (10 * np.log10(pixMax**2 / scoreMSE))
|
| 157 |
+
|
| 158 |
+
#for l in range(ch):
|
| 159 |
+
|
| 160 |
+
# channel-wise metrics
|
| 161 |
+
#chanFrameMethod = frameMethod[:,:,l]
|
| 162 |
+
#chanFrameGT = frameGT[:,:,l]
|
| 163 |
+
|
| 164 |
+
# loop patches rows
|
| 165 |
+
for j in range(int(rPatch)):
|
| 166 |
+
|
| 167 |
+
# loop patches cols + patch level metrics
|
| 168 |
+
for k in range(int(cPatch)):
|
| 169 |
+
|
| 170 |
+
startR = j*patchDim
|
| 171 |
+
startC = k*patchDim
|
| 172 |
+
endR = j*patchDim+patchDim
|
| 173 |
+
endC = k*patchDim+patchDim
|
| 174 |
+
|
| 175 |
+
if endR > rows:
|
| 176 |
+
endR = rows
|
| 177 |
+
else:
|
| 178 |
+
pass
|
| 179 |
+
|
| 180 |
+
if endC > cols:
|
| 181 |
+
endC = cols
|
| 182 |
+
else:
|
| 183 |
+
pass
|
| 184 |
+
|
| 185 |
+
# patch metrics
|
| 186 |
+
#patchMSE = np.mean(mapMSE[startR:endR,startC:endC,:])
|
| 187 |
+
#scorePatchPSNR = np.clip((10 * np.log10(pixMax**2 / patchMSE)),0,100)
|
| 188 |
+
if dataDir == 'BaistCroppedOutput':
|
| 189 |
+
patchGtTensor = F.interpolate(gtTensor[:,:,startR:endR,startC:endC], scale_factor=2.0, mode='bilinear', align_corners=False)
|
| 190 |
+
patchMethodTensor = F.interpolate(methodTensor[:,:,startR:endR,startC:endC], scale_factor=2.0, mode='bilinear', align_corners=False)
|
| 191 |
+
scorePatchLPIPS = fnLPIPS(patchGtTensor, patchMethodTensor).squeeze(0,1,2).cpu().detach().numpy()[0]
|
| 192 |
+
else:
|
| 193 |
+
scorePatchLPIPS = fnLPIPS(gtTensor[:,:,startR:endR,startC:endC], methodTensor[:,:,startR:endR,startC:endC]).squeeze(0,1,2).cpu().detach().numpy()[0]
|
| 194 |
+
scorePatchSSIM = structural_similarity(frameGT[startR:endR,startC:endC,:], frameMethod[startR:endR,startC:endC,:], data_range=pixMax, channel_axis=2)
|
| 195 |
+
#scorePatchDISTS = fnDISTS(gtTensor[:,:,startR:endR,startC:endC].to(torch.float), methodTensor[:,:,startR:endR,startC:endC].to(torch.float), require_grad=True, batch_average=True).cpu().detach().numpy()
|
| 196 |
+
#scorePatchDE2000 = np.mean(mapDE2000[startR:endR,startC:endC])
|
| 197 |
+
|
| 198 |
+
# i: frame number, j: patch row, k: patch col
|
| 199 |
+
#results[countMethod,countFile,countDirect,i,j,k,3:] = [scoreEPE, scoreBlurryPSNR, scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scorePatchDE2000]
|
| 200 |
+
results_local[countMethod,countFile,countDirect,i,j,k,2:] = [scoreEPE, scoreBlurryPSNR, scorePatchSSIM, scorePatchLPIPS, scorePSNR]
|
| 201 |
+
print('Method: ', methodDir, ' File: ', file, ' Frame: ', str(i), ' PSNR: ', scorePSNR, end='\r')
|
| 202 |
+
|
| 203 |
+
#print('VMAF: ', str(scoreVMAF), ' FVD: ', str(scoreFVD), ' LPIPS: ', str(scoreLPIPS), ' FID: ', str(scoreFID), ' DISTS: ', str(scoreDISTS), ' SSIM: ', str(scoreSSIM), ' DE2000: ', str(scoreDE2000), ' PSNR: ', str(scorePSNR), ' Patch PSNR: ', str(scorePatchPSNR), end='\r')
|
| 204 |
+
#pdb.set_trace()
|
| 205 |
+
scorePWPSNR = (10 * np.log10(pixMax**2 / np.mean(np.min(np.mean(framesMSE, axis=1), axis=0))))  # per-pixel best (lowest) time-averaged MSE over the two playback directions, averaged over the image, then converted to PSNR
|
| 206 |
+
#print('Method: ', methodDir, ' File: ', file, ' Frame: ', str(i), ' PWPSNR: ', scorePWPSNR, end='\n')
|
| 207 |
+
#scorePWPSNR = np.clip((10 * np.log10(pixMax**2 / np.mean(np.min(framesPSNR, axis=0),axis=(1,2,3)))),0,100) # take max pixel wise PSNR per direction, average over image dims
|
| 208 |
+
results_local[countMethod,countFile,:,:,:,:,1] = np.tile(scorePWPSNR, results_local.shape[2:-1])#np.broadcast_to(scorePWPSNR[:, np.newaxis, np.newaxis], results.shape[3:-1])
|
| 209 |
+
np.save(resFile, results_local) # save part of the way through the loop ..
|
| 210 |
+
|
| 211 |
+
#scoreFID = fnFID.compute().cpu().detach().numpy()
|
| 212 |
+
#fnFID.reset()
|
| 213 |
+
#results[countMethod,:,:,:,:,:,0] = np.tile(scoreFID, results.shape[1:-1])
|
| 214 |
+
scoreFVD = fnFVD.compute_fvd_from_stats()
|
| 215 |
+
fnFVD.empty_real_stats()
|
| 216 |
+
fnFVD.empty_fake_stats()
|
| 217 |
+
results_local[countMethod,:,:,:,:,:,0] = np.tile(scoreFVD, results_local.shape[1:-1])
|
| 218 |
+
print('Results computed .. analyzing ..')
|
| 219 |
+
|
| 220 |
+
return results_local
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# init results matrix
|
| 224 |
+
path = os.path.join(dataDir, gtDir)
|
| 225 |
+
clipDirs = [name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))]
|
| 226 |
+
files = []
|
| 227 |
+
if dataDir == 'BaistCroppedOutput':
|
| 228 |
+
extraFknDir = 'blur'
|
| 229 |
+
else:
|
| 230 |
+
extraFknDir = ''
|
| 231 |
+
for clipDir in clipDirs:
|
| 232 |
+
path = os.path.join(dataDir, gtDir, clipDir, extraFknDir)
|
| 233 |
+
files = files + [os.path.join(clipDir,extraFknDir,name) for name in os.listdir(path)]
|
| 234 |
+
files = sorted(files)
|
| 235 |
+
path = os.path.join(dataDir, methodDirs[0], files[0])
|
| 236 |
+
testFileGT = read_pngs_to_array(path)
|
| 237 |
+
frams,rows,cols,ch = testFileGT.shape
|
| 238 |
+
framRange = [i for i in range(frams)]
|
| 239 |
+
directions = [framRange, framRange[::-1]]
|
| 240 |
+
|
| 241 |
+
#loop through all methods and make sure they all have the same directory structure and same number of files
|
| 242 |
+
for methodDir in methodDirs:
|
| 243 |
+
path = os.path.join(dataDir, methodDir)
|
| 244 |
+
clipDirs = [name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))]
|
| 245 |
+
filesMethod = []
|
| 246 |
+
for clipDir in clipDirs:
|
| 247 |
+
path = os.path.join(dataDir, methodDir, clipDir, extraFknDir)
|
| 248 |
+
filesMethod = filesMethod + [os.path.join(clipDir,extraFknDir,name) for name in os.listdir(path)]
|
| 249 |
+
filesMethod = sorted(filesMethod)
|
| 250 |
+
print('Method: ', methodDir, ' Number of files: ', len(filesMethod))
|
| 251 |
+
assert len(files) == len(filesMethod), f"Number of files in {methodDir} does not match GT number of files"
|
| 252 |
+
assert files == filesMethod, f"Files in {methodDir} do not match GT files"
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def main():
|
| 256 |
+
|
| 257 |
+
results = np.zeros((len(methodDirs),len(files),len(directions),frams,int(np.ceil(rows/patchDim)),int(np.ceil(cols/patchDim)),nMets))
|
| 258 |
+
|
| 259 |
+
if compute:
|
| 260 |
+
|
| 261 |
+
# loop methods + compute dataset level metrics (after nested for loops)
|
| 262 |
+
import multiprocessing as mp
|
| 263 |
+
ctx = mp.get_context('spawn')
|
| 264 |
+
with ProcessPoolExecutor(mp_context=ctx, max_workers=len(methodDirs)) as executor:
|
| 265 |
+
# submit one job per method
|
| 266 |
+
futures = {
|
| 267 |
+
executor.submit(compute_method, np.copy(results), md, files, idx): idx
|
| 268 |
+
for idx, md in enumerate(methodDirs)
|
| 269 |
+
}
|
| 270 |
+
# collect and merge results as they finish
|
| 271 |
+
for fut in as_completed(futures):
|
| 272 |
+
idx = futures[fut]
|
| 273 |
+
res_local = fut.result()
|
| 274 |
+
results[idx] = res_local[idx]
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
else:
|
| 278 |
+
|
| 279 |
+
results = np.load(resFile)
|
| 280 |
+
|
| 281 |
+
np.save(resFile, results)
|
| 282 |
+
# analyze
|
| 283 |
+
|
| 284 |
+
# new results: scoreFID, scoreFVD, scorePWPSNR, scoreEPE, scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scorePatchDE2000
|
| 285 |
+
upMetrics = [1,3,4,6]
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# 0508 results: scoreFID, scoreFVD, scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scorePatchDE2000
|
| 289 |
+
#upMetrics = [4,6,7,8] # PSNR, SSIM, Patch PSNR, Patch SSIM
|
| 290 |
+
print("Results shape 1: ", results.shape)
|
| 291 |
+
forwardBackwardResults = np.mean(results,axis=(3))
|
| 292 |
+
#print("Results shape 2: ", forwardResults.shape)
|
| 293 |
+
maxDirResults = np.max(forwardBackwardResults,axis=(2))
|
| 294 |
+
minDirResults = np.min(forwardBackwardResults,axis=(2))
|
| 295 |
+
bestDirResults = minDirResults
|
| 296 |
+
#pdb.set_trace()
|
| 297 |
+
bestDirResults[:,:,:,:,upMetrics] = maxDirResults[:,:,:,:,upMetrics]
|
| 298 |
+
import pdb
|
| 299 |
+
#pdb.set_trace()
|
| 300 |
+
|
| 301 |
+
meanResults = bestDirResults.mean(axis=(1, 2, 3)) # Shape becomes (3, 6)
|
| 302 |
+
meanResultsT = meanResults.T
|
| 303 |
+
|
| 304 |
+
'''
|
| 305 |
+
maxDirResults = np.max(results,axis=2)
|
| 306 |
+
minDirResults = np.min(results,axis=2)
|
| 307 |
+
bestDirResults = minDirResults
|
| 308 |
+
bestDirResults[:,:,:,:,:,upMetrics] = maxDirResults[:,:,:,:,:,upMetrics]
|
| 309 |
+
meanResults = bestDirResults.mean(axis=(1, 2, 3, 4)) # Shape becomes (3, 6)
|
| 310 |
+
meanResultsT = meanResults.T
|
| 311 |
+
'''
|
| 312 |
+
|
| 313 |
+
#
|
| 314 |
+
#meanResults = forwardResults.mean(axis=(1, 2, 3, 4)) # Shape becomes (3, 6)
|
| 315 |
+
#meanResultsT = meanResults.T
|
| 316 |
+
|
| 317 |
+
# print latex table
|
| 318 |
+
method_labels = methodDirs
|
| 319 |
+
|
| 320 |
+
# results 0508: scoreLPIPS, scoreDISTS, scoreSSIM, scoreDE2000, scorePSNR, scorePatchPSNR, scorePatchSSIM, scorePatchLPIPS, scorePatchDISTS, scoreFID, scoreFVD
|
| 321 |
+
# metric_labels = ["FID $\downarrow$","FVD $\downarrow$","LPIPS $\downarrow$", "DISTS $\downarrow$", "SSIM $\downarrow$", "DE2000 $\downarrow$", "PSNR $\downarrow$", "Patch PSNR $\downarrow$", "Patch SSIM $\downarrow$", "Patch LPIPS $\downarrow$", "Patch DISTS $\downarrow$", "Patch DE2000 $\downarrow$"]
|
| 322 |
+
# results 0517:
|
| 323 |
+
# metric_labels = ["FID $\downarrow$","FVD $\downarrow$","PWPSNR $\downarrow$","EPE $\downarrow$","BlurryPSNR $\downarrow$", "LPIPS $\downarrow$", "DISTS $\downarrow$", "SSIM $\downarrow$", "DE2000 $\downarrow$", "PSNR $\downarrow$", "Patch PSNR $\downarrow$", "Patch SSIM $\downarrow$", "Patch LPIPS $\downarrow$", "Patch DISTS $\downarrow$", "Patch DE2000 $\downarrow$"]
|
| 324 |
+
|
| 325 |
+
# results 0518:
|
| 326 |
+
metric_labels = ["FVD $\downarrow$","PWPSNR $\downarrow$","EPE $\downarrow$","BlurryPSNR $\downarrow$","Patch SSIM $\downarrow$","Patch LPIPS $\downarrow$", "PSNR $\downarrow$"]
|
| 327 |
+
|
| 328 |
+
# appropriate for results 0507
|
| 329 |
+
#metric_labels = ["FID $\downarrow$", "FVD $\downarrow$", "LPIPS $\downarrow$", "DISTS $\downarrow$", "SSIM $\downarrow$", "DE2000 $\downarrow$", "PSNR $\downarrow$"]
|
| 330 |
+
|
| 331 |
+
latex_table = "\\begin{tabular}{l" + "c" * len(method_labels) + "}\n"
|
| 332 |
+
latex_table += "Metric & " + " & ".join(method_labels) + " \\\\\n"
|
| 333 |
+
latex_table += "\\hline\n"
|
| 334 |
+
|
| 335 |
+
for metric, row in zip(metric_labels, meanResultsT):
|
| 336 |
+
row_values = " & ".join(f"{v:.4f}" for v in row)
|
| 337 |
+
latex_table += f"{metric} & {row_values} \\\\\n"
|
| 338 |
+
|
| 339 |
+
latex_table += "\\end{tabular}"
|
| 340 |
+
print(latex_table)
|
| 341 |
+
|
| 342 |
+
if __name__ == '__main__':
|
| 343 |
+
main()
|
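For reference, the direction-tolerant "PWPSNR" computed above reduces to: per pixel, keep the lower time-averaged MSE of the two playback directions, then average over the image and convert to PSNR. A minimal, self-contained sketch on random dummy data (the shapes and the pixMax value are illustrative assumptions, not the script's actual inputs):

# Sketch of the PWPSNR aggregation on dummy data.
import numpy as np

pixMax = 1.0
rng = np.random.default_rng(0)
framesGT = rng.random((16, 64, 64, 3))       # (frames, rows, cols, channels)
framesMethod = rng.random((16, 64, 64, 3))

# Squared error for forward playback and for the reversed frame order.
mse_forward = (framesGT - framesMethod) ** 2
mse_backward = (framesGT - framesMethod[::-1]) ** 2
framesMSE = np.stack((mse_forward, mse_backward))  # (direction, frames, H, W, C)

# Per pixel, keep the direction with the lower time-averaged error,
# then average over the image and convert to PSNR.
per_pixel_best = np.min(np.mean(framesMSE, axis=1), axis=0)
scorePWPSNR = 10 * np.log10(pixMax ** 2 / np.mean(per_pixel_best))
print(f"PWPSNR: {scorePWPSNR:.2f} dB")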
gradio/app.py
ADDED
|
@@ -0,0 +1,85 @@
import os
import uuid
from pathlib import Path

import gradio as gr
from PIL import Image

# -----------------------
# 1. Load your model here
# -----------------------
# Example:
# from my_model_lib import MyVideoModel
# model = MyVideoModel.from_pretrained("your/model/hub/id")

OUTPUT_DIR = Path("/tmp/generated_videos")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


def generate_video_from_image(image: Image.Image) -> str:
    video_id = uuid.uuid4().hex
    output_path = OUTPUT_DIR / f"{video_id}.mp4"

    # 1. Preprocess image
    # img_tensor = preprocess(image)  # your code

    # 2. Run model
    # frames = model(img_tensor)  # e.g. np.ndarray of shape (T, H, W, 3), dtype=uint8

    # 3. Save frames to video (requires `import imageio.v3 as iio` in the imports above)
    # iio.imwrite(
    #     uri=output_path,
    #     image=frames,
    #     fps=16,
    #     codec="h264",
    # )

    return str(output_path)


def demo_predict(image: Image.Image) -> str:
    """
    Wrapper for Gradio. Takes an image and returns a video path.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    video_path = generate_video_from_image(image)
    if not os.path.exists(video_path):
        raise gr.Error("Video generation failed: output file not found.")
    return video_path


with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown(
        """
        # 🖼️ ➜ 🎬 Recover motion from a blurry image!

        Upload an image and the model will generate a short video.
        """
    )

    with gr.Row():
        with gr.Column():
            image_in = gr.Image(
                type="pil",
                label="Input image",
                interactive=True,
            )
            generate_btn = gr.Button("Generate video", variant="primary")
        with gr.Column():
            video_out = gr.Video(
                label="Generated video",
                format="mp4",  # ensures browser-friendly output
                autoplay=True,
            )

    generate_btn.click(
        fn=demo_predict,
        inputs=image_in,
        outputs=video_out,
        api_name="predict",
    )

if __name__ == "__main__":
    demo.launch()
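One way to fill in step 3 of the stub above ("Save frames to video") is imageio's v2 writer, which is already pinned in requirements.txt together with imageio-ffmpeg. A minimal sketch where random frames stand in for real model output (the output path is an arbitrary example):

# Sketch of writing (T, H, W, 3) uint8 frames to an mp4 with imageio.
import numpy as np
import imageio

frames = (np.random.rand(16, 256, 256, 3) * 255).astype(np.uint8)  # dummy model output
imageio.mimwrite("/tmp/demo.mp4", list(frames), fps=16)             # H.264 via imageio-ffmpeg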
inference.py
ADDED
|
@@ -0,0 +1,317 @@
| 1 |
+
# Copyright 2024 The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import os
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
import io
|
| 20 |
+
import yaml
|
| 21 |
+
|
| 22 |
+
from PIL import Image, ImageCms
|
| 23 |
+
import torch
|
| 24 |
+
import numpy as np
|
| 25 |
+
from transformers import T5Tokenizer, T5EncoderModel
|
| 26 |
+
from safetensors.torch import load_file
|
| 27 |
+
import diffusers
|
| 28 |
+
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler
|
| 29 |
+
from diffusers.utils import check_min_version, export_to_video
|
| 30 |
+
|
| 31 |
+
from controlnet_pipeline import ControlnetCogVideoXPipeline
|
| 32 |
+
from cogvideo_transformer import CogVideoXTransformer3DModel
|
| 33 |
+
|
| 34 |
+
from training.utils import save_frames_as_pngs
|
| 35 |
+
from training.helpers import get_conditioning
|
| 36 |
+
|
| 37 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 38 |
+
check_min_version("0.31.0.dev0")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def convert_to_srgb(img: Image):
|
| 42 |
+
if 'icc_profile' in img.info:
|
| 43 |
+
icc = img.info['icc_profile']
|
| 44 |
+
src_profile = ImageCms.ImageCmsProfile(io.BytesIO(icc))
|
| 45 |
+
dst_profile = ImageCms.createProfile("sRGB")
|
| 46 |
+
img = ImageCms.profileToProfile(img, src_profile, dst_profile, outputMode='RGB')
|
| 47 |
+
else:
|
| 48 |
+
img = img.convert("RGB") # Assume sRGB
|
| 49 |
+
return img
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
INTERVALS = {
|
| 53 |
+
"present": {
|
| 54 |
+
"in_start": 0,
|
| 55 |
+
"in_end": 16,
|
| 56 |
+
"out_start": 0,
|
| 57 |
+
"out_end": 16,
|
| 58 |
+
"center": 8,
|
| 59 |
+
"window_size": 16,
|
| 60 |
+
"mode": "1x",
|
| 61 |
+
"fps": 240
|
| 62 |
+
},
|
| 63 |
+
"past_present_and_future": {
|
| 64 |
+
"in_start": 4,
|
| 65 |
+
"in_end": 12,
|
| 66 |
+
"out_start": 0,
|
| 67 |
+
"out_end": 16,
|
| 68 |
+
"center": 8,
|
| 69 |
+
"window_size": 16,
|
| 70 |
+
"mode": "2x",
|
| 71 |
+
"fps": 240,
|
| 72 |
+
},
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def convert_to_batch(
|
| 77 |
+
image,
|
| 78 |
+
interval_key="present",
|
| 79 |
+
image_size=(720, 1280),
|
| 80 |
+
):
|
| 81 |
+
interval = INTERVALS[interval_key]
|
| 82 |
+
|
| 83 |
+
inp_int, out_int, num_frames = get_conditioning(
|
| 84 |
+
in_start=interval['in_start'],
|
| 85 |
+
in_end=interval['in_end'],
|
| 86 |
+
out_start=interval['out_start'],
|
| 87 |
+
out_end=interval['out_end'],
|
| 88 |
+
mode=interval['mode'],
|
| 89 |
+
fps=interval['fps'],
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
blur_img_original = convert_to_srgb(image)
|
| 93 |
+
W, H = blur_img_original.size  # PIL .size is (width, height)
|
| 94 |
+
|
| 95 |
+
blur_img = blur_img_original.resize((image_size[1], image_size[0])) # pil is width, height
|
| 96 |
+
blur_img = torch.from_numpy(np.array(blur_img)[None]).permute(0, 3, 1, 2).contiguous().float()
|
| 97 |
+
blur_img = blur_img / 127.5 - 1.0
|
| 98 |
+
|
| 99 |
+
data = {
|
| 100 |
+
"original_size": (H, W),
|
| 101 |
+
'blur_img': blur_img,
|
| 102 |
+
'caption': "",
|
| 103 |
+
'input_interval': inp_int,
|
| 104 |
+
'output_interval': out_int,
|
| 105 |
+
'height': image_size[0],
|
| 106 |
+
'width': image_size[1],
|
| 107 |
+
'num_frames': num_frames,
|
| 108 |
+
}
|
| 109 |
+
return data
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def load_model(args):
|
| 113 |
+
with open(args.model_config_path) as f:
|
| 114 |
+
model_config = yaml.safe_load(f)
|
| 115 |
+
|
| 116 |
+
load_dtype = torch.float16
|
| 117 |
+
transformer = CogVideoXTransformer3DModel.from_pretrained(
|
| 118 |
+
args.pretrained_model_path,
|
| 119 |
+
subfolder="transformer",
|
| 120 |
+
torch_dtype=load_dtype,
|
| 121 |
+
revision=model_config["revision"],
|
| 122 |
+
variant=model_config["variant"],
|
| 123 |
+
low_cpu_mem_usage=False,
|
| 124 |
+
)
|
| 125 |
+
transformer.load_state_dict(load_file(args.weight_path))
|
| 126 |
+
|
| 127 |
+
text_encoder = T5EncoderModel.from_pretrained(
|
| 128 |
+
args.pretrained_model_path,
|
| 129 |
+
subfolder="text_encoder",
|
| 130 |
+
revision=model_config["revision"],
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
tokenizer = T5Tokenizer.from_pretrained(
|
| 134 |
+
args.pretrained_model_path,
|
| 135 |
+
subfolder="tokenizer",
|
| 136 |
+
revision=model_config["revision"],
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
vae = AutoencoderKLCogVideoX.from_pretrained(
|
| 140 |
+
args.pretrained_model_path,
|
| 141 |
+
subfolder="vae",
|
| 142 |
+
revision=model_config["revision"],
|
| 143 |
+
variant=model_config["variant"],
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
scheduler = CogVideoXDPMScheduler.from_pretrained(
|
| 147 |
+
args.pretrained_model_path,
|
| 148 |
+
subfolder="scheduler"
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# Enable slicing or tiling if VRAM is low!
|
| 152 |
+
vae.enable_slicing()
|
| 153 |
+
vae.enable_tiling()
|
| 154 |
+
|
| 155 |
+
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
|
| 156 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
| 157 |
+
weight_dtype = torch.bfloat16
|
| 158 |
+
|
| 159 |
+
text_encoder.to(args.device, dtype=weight_dtype)
|
| 160 |
+
transformer.to(args.device, dtype=weight_dtype)
|
| 161 |
+
vae.to(args.device, dtype=weight_dtype)
|
| 162 |
+
|
| 163 |
+
pipe = ControlnetCogVideoXPipeline.from_pretrained(
|
| 164 |
+
args.pretrained_model_path,
|
| 165 |
+
tokenizer=tokenizer,
|
| 166 |
+
transformer=transformer,
|
| 167 |
+
text_encoder=text_encoder,
|
| 168 |
+
vae=vae,
|
| 169 |
+
scheduler=scheduler,
|
| 170 |
+
torch_dtype=weight_dtype,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
scheduler_args = {}
|
| 174 |
+
|
| 175 |
+
if "variance_type" in pipe.scheduler.config:
|
| 176 |
+
variance_type = pipe.scheduler.config.variance_type
|
| 177 |
+
|
| 178 |
+
if variance_type in ["learned", "learned_range"]:
|
| 179 |
+
variance_type = "fixed_small"
|
| 180 |
+
|
| 181 |
+
scheduler_args["variance_type"] = variance_type
|
| 182 |
+
|
| 183 |
+
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
|
| 184 |
+
pipe = pipe.to(args.device)
|
| 185 |
+
|
| 186 |
+
return pipe, model_config
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def inference_on_image(pipe, image, interval_key, model_config, args):
|
| 190 |
+
# If passed along, set the training seed now.
|
| 191 |
+
if args.seed is not None:
|
| 192 |
+
np.random.seed(args.seed)
|
| 193 |
+
torch.manual_seed(args.seed)
|
| 194 |
+
|
| 195 |
+
# run inference
|
| 196 |
+
generator = torch.Generator(device=args.device).manual_seed(args.seed) if args.seed else None
|
| 197 |
+
|
| 198 |
+
with torch.autocast(args.device, enabled=True):
|
| 199 |
+
batch = convert_to_batch(image, interval_key, (args.video_height, args.video_width))
|
| 200 |
+
|
| 201 |
+
frame = batch["blur_img"].permute(0, 2, 3, 1).cpu().numpy()
|
| 202 |
+
frame = (frame + 1.0) * 127.5
|
| 203 |
+
frame = frame.astype(np.uint8)
|
| 204 |
+
pipeline_args = {
|
| 205 |
+
"prompt": "",
|
| 206 |
+
"negative_prompt": "",
|
| 207 |
+
"image": frame,
|
| 208 |
+
"input_intervals": torch.stack([batch["input_interval"]]),
|
| 209 |
+
"output_intervals": torch.stack([batch["output_interval"]]),
|
| 210 |
+
"guidance_scale": model_config["guidance_scale"],
|
| 211 |
+
"use_dynamic_cfg": model_config["use_dynamic_cfg"],
|
| 212 |
+
"height": batch["height"],
|
| 213 |
+
"width": batch["width"],
|
| 214 |
+
"num_frames": torch.tensor([[model_config["max_num_frames"]]]), # torch.tensor([[batch["num_frames"]]]),
|
| 215 |
+
"num_inference_steps": model_config["num_inference_steps"],
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
input_image = frame
|
| 219 |
+
|
| 220 |
+
num_frames = batch["num_frames"] # this is the actual number of frames, the video generation is padded by one frame
|
| 221 |
+
|
| 222 |
+
print(f"Running inference for interval {interval_key}...")
|
| 223 |
+
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
|
| 224 |
+
|
| 225 |
+
video = video[0:num_frames]
|
| 226 |
+
|
| 227 |
+
return input_image, video
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def main(args):
|
| 231 |
+
output_path = Path(args.output_path)
|
| 232 |
+
output_path.mkdir(exist_ok=True)
|
| 233 |
+
|
| 234 |
+
image_path = Path(args.image_path)
|
| 235 |
+
|
| 236 |
+
is_dir = image_path.is_dir()
|
| 237 |
+
|
| 238 |
+
if is_dir:
|
| 239 |
+
image_paths = sorted(list(image_path.glob("*.*")))
|
| 240 |
+
else:
|
| 241 |
+
image_paths = [image_path]
|
| 242 |
+
|
| 243 |
+
pipe, model_config = load_model(args)
|
| 244 |
+
|
| 245 |
+
for image_path in image_paths:
|
| 246 |
+
image = Image.open(image_path)
|
| 247 |
+
|
| 248 |
+
processed_image, video = inference_on_image(pipe, image, "past_present_and_future", model_config, args)
|
| 249 |
+
|
| 250 |
+
vid_output_path = output_path / f"{image_path.stem}.mp4"
|
| 251 |
+
export_to_video(video, vid_output_path, fps=20)
|
| 252 |
+
|
| 253 |
+
# save input image as well
|
| 254 |
+
input_image_output_path = output_path / f"{image_path.stem}_input.png"
|
| 255 |
+
Image.fromarray(processed_image[0]).save(input_image_output_path)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
if __name__ == "__main__":
|
| 259 |
+
parser = argparse.ArgumentParser()
|
| 260 |
+
parser.add_argument(
|
| 261 |
+
"--image_path",
|
| 262 |
+
type=str,
|
| 263 |
+
required=True,
|
| 264 |
+
help="Path to image input or directory containing input images",
|
| 265 |
+
)
|
| 266 |
+
parser.add_argument(
|
| 267 |
+
"--weight_path",
|
| 268 |
+
type=str,
|
| 269 |
+
default="training/cogvideox-outsidephotos/checkpoint/model.safetensors",
|
| 270 |
+
help="directory containing weight files",
|
| 271 |
+
)
|
| 272 |
+
parser.add_argument(
|
| 273 |
+
"--pretrained_model_path",
|
| 274 |
+
type=str,
|
| 275 |
+
default="THUDM/CogVideoX-2b",
|
| 276 |
+
help="repo id or path for pretrained CogVideoX model",
|
| 277 |
+
)
|
| 278 |
+
parser.add_argument(
|
| 279 |
+
"--model_config_path",
|
| 280 |
+
type=str,
|
| 281 |
+
default="training/configs/outsidephotos.yaml",
|
| 282 |
+
help="path to model config yaml",
|
| 283 |
+
)
|
| 284 |
+
parser.add_argument(
|
| 285 |
+
"--output_path",
|
| 286 |
+
type=str,
|
| 287 |
+
required=True,
|
| 288 |
+
help="path to output",
|
| 289 |
+
)
|
| 290 |
+
parser.add_argument(
|
| 291 |
+
"--video_width",
|
| 292 |
+
type=int,
|
| 293 |
+
default=1280,
|
| 294 |
+
help="video resolution width",
|
| 295 |
+
)
|
| 296 |
+
parser.add_argument(
|
| 297 |
+
"--video_height",
|
| 298 |
+
type=int,
|
| 299 |
+
default=720,
|
| 300 |
+
help="video resolution height",
|
| 301 |
+
)
|
| 302 |
+
parser.add_argument(
|
| 303 |
+
"--seed",
|
| 304 |
+
type=int,
|
| 305 |
+
default=None,
|
| 306 |
+
help="random generator seed",
|
| 307 |
+
)
|
| 308 |
+
parser.add_argument(
|
| 309 |
+
"--device",
|
| 310 |
+
type=str,
|
| 311 |
+
default="cuda",
|
| 312 |
+
help="inference device",
|
| 313 |
+
)
|
| 314 |
+
args = parser.parse_args()
|
| 315 |
+
main(args)
|
| 316 |
+
|
| 317 |
+
# python inference.py --image_path assets/dummy_image.png --output_path output/
|
requirements.txt
ADDED
|
@@ -0,0 +1,22 @@
spaces>=0.29.3
safetensors>=0.4.5
spandrel>=0.4.0
tqdm>=4.66.5
scikit-video>=1.1.11
git+https://github.com/huggingface/diffusers.git@main
transformers>=4.44.0
accelerate>=0.34.2
opencv-python>=4.10.0.84
sentencepiece>=0.2.0
numpy==1.26.0
torch>=2.4.0
torchvision>=0.19.0
gradio>=4.44.0
imageio>=2.34.2
imageio-ffmpeg>=0.5.1
openai>=1.45.0
moviepy>=1.0.3
pillow==9.5.0
denku==0.0.51
controlnet-aux==0.0.9
setup/download_checkpoints.py
ADDED
|
@@ -0,0 +1,53 @@
from huggingface_hub import snapshot_download
import os
import sys

# Make sure HF_TOKEN is set in your env beforehand:
# export HF_TOKEN=your_hf_token

# The checkpoint group to download is taken from the first command line argument.
mode = sys.argv[1] if len(sys.argv) > 1 else "outsidephotos"

REPO_ID = "tedlasai/blur2vid"
REPO_TYPE = "model"

# These are the subfolders you previously used as path_in_repo
if mode == "outsidephotos":
    checkpoints = [
        "cogvideox-outsidephotos",
    ]
elif mode == "gopro":
    checkpoints = [
        "cogvideox-gopro-test",
        "cogvideox-gopro-2x-test",
    ]
elif mode == "baist":
    checkpoints = [
        "cogvideox-baist-test",
    ]
elif mode == "full":
    checkpoints = [
        "cogvideox-baist-test",
        "cogvideox-gopro-test",
        "cogvideox-gopro-2x-test",
        "cogvideox-full-test",
        "cogvideox-outsidephotos",
    ]

# This is the root local directory where you want everything saved,
# resolved relative to the location of this file.
LOCAL_TRAINING_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "training")
os.makedirs(LOCAL_TRAINING_ROOT, exist_ok=True)

# Download only those folders from the repo and place them under LOCAL_TRAINING_ROOT
snapshot_download(
    repo_id=REPO_ID,
    repo_type=REPO_TYPE,
    local_dir=LOCAL_TRAINING_ROOT,
    local_dir_use_symlinks=False,
    allow_patterns=[f"{name}/*" for name in checkpoints],
    token=os.getenv("HF_TOKEN"),
)

print(f"Done! Checkpoints downloaded under: {LOCAL_TRAINING_ROOT}")
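Typical invocation, assuming HF_TOKEN is already exported as described in the comments above (the positional argument selects the checkpoint group and defaults to "outsidephotos"):

# export HF_TOKEN=your_hf_token
# python setup/download_checkpoints.py gopro   # or: outsidephotos, baist, full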
setup/download_cogvideo_weights.py
ADDED
|
@@ -0,0 +1,6 @@
from huggingface_hub import snapshot_download

# Download the entire model repository and store it locally
model_path = snapshot_download(repo_id="THUDM/CogVideoX-2b", cache_dir="./CogVideoX-2b")

print(f"Model downloaded to: {model_path}")
setup/environment.yaml
ADDED
|
@@ -0,0 +1,225 @@
| 1 |
+
name: blur2vid
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- _libgcc_mutex=0.1=main
|
| 7 |
+
- _openmp_mutex=5.1=1_gnu
|
| 8 |
+
- asttokens=3.0.0=pyhd8ed1ab_1
|
| 9 |
+
- bzip2=1.0.8=h5eee18b_6
|
| 10 |
+
- ca-certificates=2025.4.26=hbd8a1cb_0
|
| 11 |
+
- comm=0.2.2=pyhd8ed1ab_1
|
| 12 |
+
- debugpy=1.6.0=py310hd8f1fbe_0
|
| 13 |
+
- entrypoints=0.4=pyhd8ed1ab_1
|
| 14 |
+
- exceptiongroup=1.2.2=pyhd8ed1ab_1
|
| 15 |
+
- executing=2.2.0=pyhd8ed1ab_0
|
| 16 |
+
- ffmpeg=4.3.2=hca11adc_0
|
| 17 |
+
- freetype=2.10.4=h0708190_1
|
| 18 |
+
- gmp=6.2.1=h58526e2_0
|
| 19 |
+
- gnutls=3.6.13=h85f3911_1
|
| 20 |
+
- ipykernel=6.20.2=pyh210e3f2_0
|
| 21 |
+
- ipython=8.36.0=pyh907856f_0
|
| 22 |
+
- jedi=0.19.2=pyhd8ed1ab_1
|
| 23 |
+
- jupyter_client=7.3.4=pyhd8ed1ab_0
|
| 24 |
+
- jupyter_core=5.7.2=pyh31011fe_1
|
| 25 |
+
- lame=3.100=h7f98852_1001
|
| 26 |
+
- ld_impl_linux-64=2.40=h12ee557_0
|
| 27 |
+
- libevent=2.1.12=hdbd6064_1
|
| 28 |
+
- libffi=3.4.4=h6a678d5_1
|
| 29 |
+
- libgcc-ng=11.2.0=h1234567_1
|
| 30 |
+
- libgomp=11.2.0=h1234567_1
|
| 31 |
+
- libpng=1.6.37=h21135ba_2
|
| 32 |
+
- libsodium=1.0.18=h36c2ea0_1
|
| 33 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
| 34 |
+
- libuuid=1.41.5=h5eee18b_0
|
| 35 |
+
- matplotlib-inline=0.1.7=pyhd8ed1ab_1
|
| 36 |
+
- ncurses=6.4=h6a678d5_0
|
| 37 |
+
- nest-asyncio=1.6.0=pyhd8ed1ab_1
|
| 38 |
+
- nettle=3.6=he412f7d_0
|
| 39 |
+
- openh264=2.1.1=h780b84a_0
|
| 40 |
+
- openssl=3.0.16=h5eee18b_0
|
| 41 |
+
- parso=0.8.4=pyhd8ed1ab_1
|
| 42 |
+
- pexpect=4.9.0=pyhd8ed1ab_1
|
| 43 |
+
- pickleshare=0.7.5=pyhd8ed1ab_1004
|
| 44 |
+
- pip=25.0=py310h06a4308_0
|
| 45 |
+
- platformdirs=4.3.7=pyh29332c3_0
|
| 46 |
+
- prompt-toolkit=3.0.51=pyha770c72_0
|
| 47 |
+
- ptyprocess=0.7.0=pyhd8ed1ab_1
|
| 48 |
+
- pure_eval=0.2.3=pyhd8ed1ab_1
|
| 49 |
+
- pygments=2.19.1=pyhd8ed1ab_0
|
| 50 |
+
- python=3.10.16=he870216_1
|
| 51 |
+
- python-dateutil=2.9.0.post0=pyhff2d567_1
|
| 52 |
+
- python_abi=3.10=2_cp310
|
| 53 |
+
- pyzmq=23.0.0=py310h330234f_0
|
| 54 |
+
- readline=8.2=h5eee18b_0
|
| 55 |
+
- setuptools=75.8.0=py310h06a4308_0
|
| 56 |
+
- six=1.17.0=pyhd8ed1ab_0
|
| 57 |
+
- sqlite=3.45.3=h5eee18b_0
|
| 58 |
+
- stack_data=0.6.3=pyhd8ed1ab_1
|
| 59 |
+
- tk=8.6.14=h39e8969_0
|
| 60 |
+
- tmux=3.3a=h5eee18b_1
|
| 61 |
+
- tornado=6.1=py310h5764c6d_3
|
| 62 |
+
- traitlets=5.14.3=pyhd8ed1ab_1
|
| 63 |
+
- typing_extensions=4.13.2=pyh29332c3_0
|
| 64 |
+
- wcwidth=0.2.13=pyhd8ed1ab_1
|
| 65 |
+
- wheel=0.45.1=py310h06a4308_0
|
| 66 |
+
- x264=1!161.3030=h7f98852_1
|
| 67 |
+
- xz=5.6.4=h5eee18b_1
|
| 68 |
+
- zeromq=4.3.4=h9c3ff4c_1
|
| 69 |
+
- zlib=1.2.13=h5eee18b_1
|
| 70 |
+
- pip:
|
| 71 |
+
- absl-py==2.2.0
|
| 72 |
+
- accelerate==1.5.2
|
| 73 |
+
- aiofiles==23.2.1
|
| 74 |
+
- aiohappyeyeballs==2.6.1
|
| 75 |
+
- aiohttp==3.12.14
|
| 76 |
+
- aiosignal==1.4.0
|
| 77 |
+
- annotated-types==0.7.0
|
| 78 |
+
- anyio==4.9.0
|
| 79 |
+
- async-timeout==5.0.1
|
| 80 |
+
- atomicwrites==1.4.1
|
| 81 |
+
- attrs==25.3.0
|
| 82 |
+
- beautifulsoup4==4.13.4
|
| 83 |
+
- certifi==2025.1.31
|
| 84 |
+
- cffi==1.17.1
|
| 85 |
+
- charset-normalizer==3.4.1
|
| 86 |
+
- click==8.1.8
|
| 87 |
+
- colour-science==0.4.6
|
| 88 |
+
- contourpy==1.3.1
|
| 89 |
+
- controlnet-aux==0.0.9
|
| 90 |
+
- cycler==0.12.1
|
| 91 |
+
- decorator==4.4.2
|
| 92 |
+
- decord==0.6.0
|
| 93 |
+
- denku==0.0.51
|
| 94 |
+
- diffusers==0.32.0
|
| 95 |
+
- distro==1.9.0
|
| 96 |
+
- docker-pycreds==0.4.0
|
| 97 |
+
- einops==0.8.1
|
| 98 |
+
- einops-exts==0.0.4
|
| 99 |
+
- fastapi==0.115.11
|
| 100 |
+
- ffmpeg-python==0.2.0
|
| 101 |
+
- ffmpy==0.5.0
|
| 102 |
+
- filelock==3.18.0
|
| 103 |
+
- flatbuffers==25.2.10
|
| 104 |
+
- fonttools==4.56.0
|
| 105 |
+
- frozenlist==1.7.0
|
| 106 |
+
- fsspec==2025.3.0
|
| 107 |
+
- future==1.0.0
|
| 108 |
+
- gdown==5.2.0
|
| 109 |
+
- gitdb==4.0.12
|
| 110 |
+
- gitpython==3.1.44
|
| 111 |
+
- gradio==5.22.0
|
| 112 |
+
- gradio-client==1.8.0
|
| 113 |
+
- groovy==0.1.2
|
| 114 |
+
- h11==0.14.0
|
| 115 |
+
- hf-transfer==0.1.9
|
| 116 |
+
- httpcore==1.0.7
|
| 117 |
+
- httpx==0.28.1
|
| 118 |
+
- huggingface-hub==0.29.3
|
| 119 |
+
- idna==3.10
|
| 120 |
+
- imageio==2.37.0
|
| 121 |
+
- imageio-ffmpeg==0.6.0
|
| 122 |
+
- importlib-metadata==8.6.1
|
| 123 |
+
- jax==0.5.3
|
| 124 |
+
- jaxlib==0.5.3
|
| 125 |
+
- jinja2==3.1.6
|
| 126 |
+
- jiter==0.9.0
|
| 127 |
+
- kiwisolver==1.4.8
|
| 128 |
+
- lazy-loader==0.4
|
| 129 |
+
- lightning==2.5.2
|
| 130 |
+
- lightning-utilities==0.14.3
|
| 131 |
+
- markdown-it-py==3.0.0
|
| 132 |
+
- markupsafe==3.0.2
|
| 133 |
+
- matplotlib==3.10.1
|
| 134 |
+
- mdurl==0.1.2
|
| 135 |
+
- mediapipe==0.10.21
|
| 136 |
+
- ml-dtypes==0.5.1
|
| 137 |
+
- moviepy==1.0.3
|
| 138 |
+
- mpmath==1.3.0
|
| 139 |
+
- multidict==6.6.3
|
| 140 |
+
- networkx==3.4.2
|
| 141 |
+
- numpy==1.26.0
|
| 142 |
+
- nvidia-cublas-cu12==12.4.5.8
|
| 143 |
+
- nvidia-cuda-cupti-cu12==12.4.127
|
| 144 |
+
- nvidia-cuda-nvrtc-cu12==12.4.127
|
| 145 |
+
- nvidia-cuda-runtime-cu12==12.4.127
|
| 146 |
+
- nvidia-cudnn-cu12==9.1.0.70
|
| 147 |
+
- nvidia-cufft-cu12==11.2.1.3
|
| 148 |
+
- nvidia-curand-cu12==10.3.5.147
|
| 149 |
+
- nvidia-cusolver-cu12==11.6.1.9
|
| 150 |
+
- nvidia-cusparse-cu12==12.3.1.170
|
| 151 |
+
- nvidia-cusparselt-cu12==0.6.2
|
| 152 |
+
- nvidia-ml-py==12.570.86
|
| 153 |
+
- nvidia-nccl-cu12==2.21.5
|
| 154 |
+
- nvidia-nvjitlink-cu12==12.4.127
|
| 155 |
+
- nvidia-nvtx-cu12==12.4.127
|
| 156 |
+
- nvitop==1.4.2
|
| 157 |
+
- openai==1.68.2
|
| 158 |
+
- opencv-contrib-python==4.11.0.86
|
| 159 |
+
- opencv-python==4.11.0.86
|
| 160 |
+
- opencv-python-headless==4.11.0.86
|
| 161 |
+
- opt-einsum==3.4.0
|
| 162 |
+
- orjson==3.10.15
|
| 163 |
+
- packaging==24.2
|
| 164 |
+
- pandas==2.2.3
|
| 165 |
+
- peft==0.15.0
|
| 166 |
+
- pillow==9.5.0
|
| 167 |
+
- proglog==0.1.10
|
| 168 |
+
- propcache==0.3.2
|
| 169 |
+
- protobuf==4.25.6
|
| 170 |
+
- psutil==5.9.8
|
| 171 |
+
- ptflops==0.7.4
|
| 172 |
+
- pycparser==2.22
|
| 173 |
+
- pydantic==2.10.6
|
| 174 |
+
- pydantic-core==2.27.2
|
| 175 |
+
- pydub==0.25.1
|
| 176 |
+
- pyparsing==3.2.1
|
| 177 |
+
- pysocks==1.7.1
|
| 178 |
+
- python-dotenv==1.0.1
|
| 179 |
+
- python-multipart==0.0.20
|
| 180 |
+
- pytorch-lightning==2.5.2
|
| 181 |
+
- pytz==2025.1
|
| 182 |
+
- pyyaml==6.0.2
|
| 183 |
+
- regex==2024.11.6
|
| 184 |
+
- requests==2.32.3
|
| 185 |
+
- rich==13.9.4
|
| 186 |
+
- ruff==0.11.2
|
| 187 |
+
- safehttpx==0.1.6
|
| 188 |
+
- safetensors==0.5.3
|
| 189 |
+
- scikit-image==0.24.0
|
| 190 |
+
- scikit-video==1.1.11
|
| 191 |
+
- scipy==1.15.2
|
| 192 |
+
- semantic-version==2.10.0
|
| 193 |
+
- sentencepiece==0.2.0
|
| 194 |
+
- sentry-sdk==2.24.0
|
| 195 |
+
- setproctitle==1.3.5
|
| 196 |
+
- shellingham==1.5.4
|
| 197 |
+
- smmap==5.0.2
|
| 198 |
+
- sniffio==1.3.1
|
| 199 |
+
- sounddevice==0.5.1
|
| 200 |
+
- soupsieve==2.7
|
| 201 |
+
- spaces==0.32.0
|
| 202 |
+
- spandrel==0.4.1
|
| 203 |
+
- starlette==0.46.1
|
| 204 |
+
- sympy==1.13.1
|
| 205 |
+
- tifffile==2025.3.13
|
| 206 |
+
- timm==0.6.7
|
| 207 |
+
- tokenizers==0.21.1
|
| 208 |
+
- tomlkit==0.13.2
|
| 209 |
+
- torch==2.6.0
|
| 210 |
+
- torch-fidelity==0.3.0
|
| 211 |
+
- torchmetrics==1.7.4
|
| 212 |
+
- torchvision==0.21.0
|
| 213 |
+
- tqdm==4.67.1
|
| 214 |
+
- transformers==4.50.0
|
| 215 |
+
- triton==3.2.0
|
| 216 |
+
- typer==0.15.2
|
| 217 |
+
- typing-extensions==4.12.2
|
| 218 |
+
- tzdata==2025.1
|
| 219 |
+
- urllib3==2.3.0
|
| 220 |
+
- uvicorn==0.34.0
|
| 221 |
+
- videoio==0.3.0
|
| 222 |
+
- wandb==0.19.8
|
| 223 |
+
- websockets==15.0.1
|
| 224 |
+
- yarl==1.20.1
|
| 225 |
+
- zipp==3.21.0
|
training/accelerator_configs/accelerate_test.py
ADDED
|
@@ -0,0 +1,17 @@
# accelerate_test.py
from accelerate import Accelerator
import os
print("MADE IT HERE")
# Force unbuffered printing
import sys; sys.stdout.reconfigure(line_buffering=True)

acc = Accelerator()
print(acc.num_processes)
print(
    f"[host {os.uname().nodename}] "
    f"global rank {acc.process_index}/{acc.num_processes}, "
    f"local rank {acc.local_process_index}"
)

# Print out assigned CUDA device
print(f"Device: {acc.device}")
training/accelerator_configs/accelerator_multigpu.yaml
ADDED
|
@@ -0,0 +1,6 @@
# Specify distributed_type as `MULTI_GPU` for DDP
distributed_type: "MULTI_GPU"
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
mixed_precision: "bf16"
# Specify the number of GPUs to use
num_processes: 4
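A config like this is consumed via `accelerate launch --config_file ...`; for example, the sanity-check script above can be run under it with the following illustrative command from the repo root:

accelerate launch --config_file training/accelerator_configs/accelerator_multigpu.yaml training/accelerator_configs/accelerate_test.py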
training/accelerator_configs/accelerator_multinode.yaml
ADDED
|
@@ -0,0 +1,4 @@
distributed_type: "MULTI_GPU"
mixed_precision: "bf16"
num_processes: 16
num_machines: 4
training/accelerator_configs/accelerator_singlegpu.yaml
ADDED
|
@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
main_process_port: 29501
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
dynamo_backend: 'no'
mixed_precision: 'no'
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
training/accelerator_configs/accelerator_val_config.yaml
ADDED
|
@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
main_process_port: 29501
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
dynamo_backend: 'no'
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
training/available-qos.txt
ADDED
|
@@ -0,0 +1,10 @@
Name Priority GraceTime Preempt PreemptExemptTime PreemptMode Flags UsageThres UsageFactor GrpTRES GrpTRESMins GrpTRESRunMin GrpJobs GrpSubmit GrpWall MaxTRES MaxTRESPerNode MaxTRESMins MaxWall MaxTRESPU MaxJobsPU MaxSubmitPU MaxTRESPA MaxJobsPA MaxSubmitPA MinTRES
---------- ---------- ---------- ---------- ------------------- ----------- ---------------------------------------- ---------- ----------- ------------- ------------- ------------- ------- --------- ----------- ------------- -------------- ------------- ----------- ------------- --------- ----------- ------------- --------- ----------- -------------
normal 0 00:00:00 cluster 1.000000
gpu1-32h 10000 00:00:00 scavenger cluster DenyOnLimit 1.000000 1-08:00:00 cpu=28,gres/+
gpu2-16h 10000 00:00:00 scavenger cluster DenyOnLimit 1.000000 16:00:00 cpu=56,gres/+
gpu4-8h 10000 00:00:00 scavenger cluster DenyOnLimit 1.000000 08:00:00 cpu=112,gres+
gpu8-4h 10000 00:00:00 scavenger cluster DenyOnLimit 1.000000 04:00:00 cpu=224,gres+
gpu16-2h 10000 00:00:00 scavenger cluster DenyOnLimit 1.000000 02:00:00 cpu=448,gres+
gpu32-1h 10000 00:00:00 scavenger cluster DenyOnLimit 1.000000 01:00:00 cpu=896,gres+
scavenger 0 00:00:00 01:00:00 cluster 0.250000
training/configs/baist_test.yaml
ADDED
|
@@ -0,0 +1,77 @@
| 1 |
+
# === Required or overridden ===
|
| 2 |
+
base_dir: "/datasets/sai/gencam/blur2vid"
|
| 3 |
+
pretrained_model_name_or_path: "CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
|
| 4 |
+
video_root_dir: "datasets/b-aist"
|
| 5 |
+
csv_path: "set-path-to-csv-file" # Replace with actual CSV path
|
| 6 |
+
output_dir: "cogvideox-baist-test"
|
| 7 |
+
tracker_name: "cogvideox-baist-test"
|
| 8 |
+
|
| 9 |
+
# === Data-related ===
|
| 10 |
+
stride_min: 1
|
| 11 |
+
stride_max: 3
|
| 12 |
+
hflip_p: 0.5
|
| 13 |
+
downscale_coef: 8
|
| 14 |
+
init_from_transformer: true
|
| 15 |
+
dataloader_num_workers: 32
|
| 16 |
+
val_split: "test"
|
| 17 |
+
dataset: "baist"
|
| 18 |
+
|
| 19 |
+
# === Validation ===
|
| 20 |
+
num_inference_steps: 50
|
| 21 |
+
validation_prompt: ""
|
| 22 |
+
validation_video: "../resources/car.mp4:::../resources/ship.mp4"
|
| 23 |
+
validation_prompt_separator: ":::"
|
| 24 |
+
num_validation_videos: 1
|
| 25 |
+
validation_steps: 400
|
| 26 |
+
guidance_scale: 1.1
|
| 27 |
+
use_dynamic_cfg: false
|
| 28 |
+
just_validate: true
|
| 29 |
+
special_info: "just_one"
|
| 30 |
+
|
| 31 |
+
# === Training ===
|
| 32 |
+
seed: 42
|
| 33 |
+
mixed_precision: "bf16"
|
| 34 |
+
height: 720
|
| 35 |
+
width: 1280
|
| 36 |
+
fps: 8
|
| 37 |
+
max_num_frames: 17
|
| 38 |
+
train_batch_size: 2
|
| 39 |
+
num_train_epochs: 100
|
| 40 |
+
max_train_steps: null
|
| 41 |
+
checkpointing_steps: 200
|
| 42 |
+
checkpoints_total_limit: null
|
| 43 |
+
gradient_accumulation_steps: 1
|
| 44 |
+
gradient_checkpointing: true
|
| 45 |
+
learning_rate: 0.0001
|
| 46 |
+
scale_lr: false
|
| 47 |
+
lr_scheduler: "constant"
|
| 48 |
+
lr_warmup_steps: 250
|
| 49 |
+
lr_num_cycles: 1
|
| 50 |
+
lr_power: 1.0
|
| 51 |
+
enable_slicing: true
|
| 52 |
+
enable_tiling: true
|
| 53 |
+
|
| 54 |
+
# === Optimizer ===
|
| 55 |
+
optimizer: "adamw"
|
| 56 |
+
use_8bit_adam: false
|
| 57 |
+
adam_beta1: 0.9
|
| 58 |
+
adam_beta2: 0.95
|
| 59 |
+
prodigy_beta3: null
|
| 60 |
+
prodigy_decouple: false
|
| 61 |
+
adam_weight_decay: 0.0001
|
| 62 |
+
adam_epsilon: 0.0000001
|
| 63 |
+
max_grad_norm: 1.0
|
| 64 |
+
prodigy_use_bias_correction: false
|
| 65 |
+
prodigy_safeguard_warmup: false
|
| 66 |
+
|
| 67 |
+
# === Logging & Reporting ===
|
| 68 |
+
push_to_hub: false
|
| 69 |
+
hub_token: null
|
| 70 |
+
hub_model_id: null
|
| 71 |
+
logging_dir: "logs"
|
| 72 |
+
allow_tf32: true
|
| 73 |
+
report_to: null
|
| 74 |
+
|
| 75 |
+
# === Optional HuggingFace model variant ===
|
| 76 |
+
revision: null
|
| 77 |
+
variant: null
|
training/configs/baist_train.yaml
ADDED
|
@@ -0,0 +1,78 @@
| 1 |
+
# === Required or overridden ===
|
| 2 |
+
base_dir: "/datasets/sai/gencam/blur2vid"
|
| 3 |
+
pretrained_model_name_or_path: "CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
|
| 4 |
+
video_root_dir: "datasets/b-aist"
|
| 5 |
+
csv_path: "set-path-to-csv-file" # Replace with actual CSV path
|
| 6 |
+
output_dir: "cogvideox-baist-train"
|
| 7 |
+
tracker_name: "cogvideox-baist-train"
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# === Data-related ===
|
| 11 |
+
stride_min: 1
|
| 12 |
+
stride_max: 3
|
| 13 |
+
hflip_p: 0.5
|
| 14 |
+
downscale_coef: 8
|
| 15 |
+
init_from_transformer: true
|
| 16 |
+
dataloader_num_workers: 32
|
| 17 |
+
val_split: "val"
|
| 18 |
+
dataset: "baist"
|
| 19 |
+
|
| 20 |
+
# === Validation ===
|
| 21 |
+
num_inference_steps: 50
|
| 22 |
+
validation_prompt: ""
|
| 23 |
+
validation_video: "../resources/car.mp4:::../resources/ship.mp4"
|
| 24 |
+
validation_prompt_separator: ":::"
|
| 25 |
+
num_validation_videos: 1
|
| 26 |
+
validation_steps: 400
|
| 27 |
+
guidance_scale: 1.1
|
| 28 |
+
use_dynamic_cfg: false
|
| 29 |
+
just_validate: false
|
| 30 |
+
special_info: "just_one"
|
| 31 |
+
|
| 32 |
+
# === Training ===
|
| 33 |
+
seed: 42
|
| 34 |
+
mixed_precision: "bf16"
|
| 35 |
+
height: 720
|
| 36 |
+
width: 1280
|
| 37 |
+
fps: 8
|
| 38 |
+
max_num_frames: 17
|
| 39 |
+
train_batch_size: 2
|
| 40 |
+
num_train_epochs: 100
|
| 41 |
+
max_train_steps: null
|
| 42 |
+
checkpointing_steps: 200
|
| 43 |
+
checkpoints_total_limit: null
|
| 44 |
+
gradient_accumulation_steps: 1
|
| 45 |
+
gradient_checkpointing: true
|
| 46 |
+
learning_rate: 0.0001
|
| 47 |
+
scale_lr: false
|
| 48 |
+
lr_scheduler: "constant"
|
| 49 |
+
lr_warmup_steps: 250
|
| 50 |
+
lr_num_cycles: 1
|
| 51 |
+
lr_power: 1.0
|
| 52 |
+
enable_slicing: true
|
| 53 |
+
enable_tiling: true
|
| 54 |
+
|
| 55 |
+
# === Optimizer ===
|
| 56 |
+
optimizer: "adamw"
|
| 57 |
+
use_8bit_adam: false
|
| 58 |
+
adam_beta1: 0.9
|
| 59 |
+
adam_beta2: 0.95
|
| 60 |
+
prodigy_beta3: null
|
| 61 |
+
prodigy_decouple: false
|
| 62 |
+
adam_weight_decay: 0.0001
|
| 63 |
+
adam_epsilon: 0.0000001
|
| 64 |
+
max_grad_norm: 1.0
|
| 65 |
+
prodigy_use_bias_correction: false
|
| 66 |
+
prodigy_safeguard_warmup: false
|
| 67 |
+
|
| 68 |
+
# === Logging & Reporting ===
|
| 69 |
+
push_to_hub: false
|
| 70 |
+
hub_token: null
|
| 71 |
+
hub_model_id: null
|
| 72 |
+
logging_dir: "logs"
|
| 73 |
+
allow_tf32: true
|
| 74 |
+
report_to: null
|
| 75 |
+
|
| 76 |
+
# === Optional HuggingFace model variant ===
|
| 77 |
+
revision: null
|
| 78 |
+
variant: null
|
training/configs/full_test.yaml
ADDED
|
@@ -0,0 +1,78 @@
| 1 |
+
# === Required or overridden ===
|
| 2 |
+
base_dir: "/datasets/sai/gencam/blur2vid"
|
| 3 |
+
pretrained_model_name_or_path: "CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
|
| 4 |
+
video_root_dir: "datasets/FullDataset"
|
| 5 |
+
csv_path: "set-path-to-csv-file" # Replace with actual CSV path
|
| 6 |
+
output_dir: "cogvideox-full-test"
|
| 7 |
+
tracker_name: "cogvideox-full-test"
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# === Data-related ===
|
| 11 |
+
stride_min: 1
|
| 12 |
+
stride_max: 3
|
| 13 |
+
hflip_p: 0.5
|
| 14 |
+
downscale_coef: 8
|
| 15 |
+
init_from_transformer: true
|
| 16 |
+
dataloader_num_workers: 32
|
| 17 |
+
val_split: "test"
|
| 18 |
+
dataset: "full"
|
| 19 |
+
|
| 20 |
+
# === Validation ===
|
| 21 |
+
num_inference_steps: 50
|
| 22 |
+
validation_prompt: ""
|
| 23 |
+
validation_video: "../resources/car.mp4:::../resources/ship.mp4"
|
| 24 |
+
validation_prompt_separator: ":::"
|
| 25 |
+
num_validation_videos: 1
|
| 26 |
+
validation_steps: 400
|
| 27 |
+
guidance_scale: 1.1
|
| 28 |
+
use_dynamic_cfg: false
|
| 29 |
+
just_validate: true
|
| 30 |
+
special_info: "just_one"
|
| 31 |
+
|
| 32 |
+
# === Training ===
|
| 33 |
+
seed: 42
|
| 34 |
+
mixed_precision: "bf16"
|
| 35 |
+
height: 720
|
| 36 |
+
width: 1280
|
| 37 |
+
fps: 8
|
| 38 |
+
max_num_frames: 17
|
| 39 |
+
train_batch_size: 2
|
| 40 |
+
num_train_epochs: 200
|
| 41 |
+
max_train_steps: null
|
| 42 |
+
checkpointing_steps: 200
|
| 43 |
+
checkpoints_total_limit: null
|
| 44 |
+
gradient_accumulation_steps: 2
|
| 45 |
+
gradient_checkpointing: true
|
| 46 |
+
learning_rate: 0.0001
|
| 47 |
+
scale_lr: false
|
| 48 |
+
lr_scheduler: "constant"
|
| 49 |
+
lr_warmup_steps: 250
|
| 50 |
+
lr_num_cycles: 1
|
| 51 |
+
lr_power: 1.0
|
| 52 |
+
enable_slicing: true
|
| 53 |
+
enable_tiling: true
|
| 54 |
+
|
| 55 |
+
# === Optimizer ===
|
| 56 |
+
optimizer: "adamw"
|
| 57 |
+
use_8bit_adam: false
|
| 58 |
+
adam_beta1: 0.9
|
| 59 |
+
adam_beta2: 0.95
|
| 60 |
+
prodigy_beta3: null
|
| 61 |
+
prodigy_decouple: false
|
| 62 |
+
adam_weight_decay: 0.0001
|
| 63 |
+
adam_epsilon: 0.0000001
|
| 64 |
+
max_grad_norm: 1.0
|
| 65 |
+
prodigy_use_bias_correction: false
|
| 66 |
+
prodigy_safeguard_warmup: false
|
| 67 |
+
|
| 68 |
+
# === Logging & Reporting ===
|
| 69 |
+
push_to_hub: false
|
| 70 |
+
hub_token: null
|
| 71 |
+
hub_model_id: null
|
| 72 |
+
logging_dir: "logs"
|
| 73 |
+
allow_tf32: true
|
| 74 |
+
report_to: null
|
| 75 |
+
|
| 76 |
+
# === Optional HuggingFace model variant ===
|
| 77 |
+
revision: null
|
| 78 |
+
variant: null
|
training/configs/full_train.yaml
ADDED
|
@@ -0,0 +1,78 @@
| 1 |
+
# === Required or overridden ===
|
| 2 |
+
base_dir: "/datasets/sai/gencam/blur2vid"
|
| 3 |
+
pretrained_model_name_or_path: "CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
|
| 4 |
+
video_root_dir: "datasets/FullDataset"
|
| 5 |
+
csv_path: "set-path-to-csv-file" # Replace with actual CSV path
|
| 6 |
+
output_dir: "cogvideox-full-train"
|
| 7 |
+
tracker_name: "cogvideox-full-train"
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# === Data-related ===
|
| 11 |
+
stride_min: 1
|
| 12 |
+
stride_max: 3
|
| 13 |
+
hflip_p: 0.5
|
| 14 |
+
downscale_coef: 8
|
| 15 |
+
init_from_transformer: true
|
| 16 |
+
dataloader_num_workers: 2
|
| 17 |
+
val_split: "val"
|
| 18 |
+
dataset: "full"
|
| 19 |
+
|
| 20 |
+
# === Validation ===
|
| 21 |
+
num_inference_steps: 50
|
| 22 |
+
validation_prompt: ""
|
| 23 |
+
validation_video: "../resources/car.mp4:::../resources/ship.mp4"
|
| 24 |
+
validation_prompt_separator: ":::"
|
| 25 |
+
num_validation_videos: 1
|
| 26 |
+
validation_steps: 400
|
| 27 |
+
guidance_scale: 1.0
|
| 28 |
+
use_dynamic_cfg: false
|
| 29 |
+
just_validate: false
|
| 30 |
+
special_info: "just_one"
|
| 31 |
+
|
| 32 |
+
# === Training ===
|
| 33 |
+
seed: 42
|
| 34 |
+
mixed_precision: "bf16"
|
| 35 |
+
height: 720
|
| 36 |
+
width: 1280
|
| 37 |
+
fps: 8
|
| 38 |
+
max_num_frames: 17
|
| 39 |
+
train_batch_size: 2
|
| 40 |
+
num_train_epochs: 200
|
| 41 |
+
max_train_steps: null
|
| 42 |
+
checkpointing_steps: 200
|
| 43 |
+
checkpoints_total_limit: null
|
| 44 |
+
gradient_accumulation_steps: 2
|
| 45 |
+
gradient_checkpointing: true
|
| 46 |
+
learning_rate: 0.0001
|
| 47 |
+
scale_lr: false
|
| 48 |
+
lr_scheduler: "constant"
|
| 49 |
+
lr_warmup_steps: 250
|
| 50 |
+
lr_num_cycles: 1
|
| 51 |
+
lr_power: 1.0
|
| 52 |
+
enable_slicing: true
|
| 53 |
+
enable_tiling: true
|
| 54 |
+
|
| 55 |
+
# === Optimizer ===
|
| 56 |
+
optimizer: "adamw"
|
| 57 |
+
use_8bit_adam: false
|
| 58 |
+
adam_beta1: 0.9
|
| 59 |
+
adam_beta2: 0.95
|
| 60 |
+
prodigy_beta3: null
|
| 61 |
+
prodigy_decouple: false
|
| 62 |
+
adam_weight_decay: 0.0001
|
| 63 |
+
adam_epsilon: 0.0000001
|
| 64 |
+
max_grad_norm: 1.0
|
| 65 |
+
prodigy_use_bias_correction: false
|
| 66 |
+
prodigy_safeguard_warmup: false
|
| 67 |
+
|
| 68 |
+
# === Logging & Reporting ===
|
| 69 |
+
push_to_hub: false
|
| 70 |
+
hub_token: null
|
| 71 |
+
hub_model_id: null
|
| 72 |
+
logging_dir: "logs"
|
| 73 |
+
allow_tf32: true
|
| 74 |
+
report_to: null
|
| 75 |
+
|
| 76 |
+
# === Optional HuggingFace model variant ===
|
| 77 |
+
revision: null
|
| 78 |
+
variant: null
|
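One practical reading of full_train.yaml: the effective optimization batch size is train_batch_size times gradient_accumulation_steps times the number of data-parallel processes launched by accelerate. The process count is not part of the config, so the value below is only an example:

train_batch_size = 2
gradient_accumulation_steps = 2
num_processes = 4  # example only: one process per GPU under accelerate
effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes
print(effective_batch_size)  # 16 samples per optimizer step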
training/configs/gopro_2x_test.yaml
ADDED
@@ -0,0 +1,78 @@
# === Required or overridden ===
base_dir: "/datasets/sai/gencam/blur2vid"
pretrained_model_name_or_path: "CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
video_root_dir: "datasets/GOPRO_7"
csv_path: "set-path-to-csv-file" # Replace with actual CSV path
output_dir: "cogvideox-gopro-2x-test"
tracker_name: "cogvideox-gopro-2x-test"


# === Data-related ===
stride_min: 1
stride_max: 3
hflip_p: 0.5
downscale_coef: 8
init_from_transformer: true
dataloader_num_workers: 32
val_split: "test"
dataset: "gopro2x"

# === Validation ===
num_inference_steps: 50
validation_prompt: ""
validation_video: "../resources/car.mp4:::../resources/ship.mp4"
validation_prompt_separator: ":::"
num_validation_videos: 1
validation_steps: 400
guidance_scale: 1.1
use_dynamic_cfg: false
just_validate: true
special_info: "just_one"

# === Training ===
seed: 42
mixed_precision: "bf16"
height: 720
width: 1280
fps: 8
max_num_frames: 17
train_batch_size: 4
num_train_epochs: 100
max_train_steps: null
checkpointing_steps: 400
checkpoints_total_limit: null
gradient_accumulation_steps: 1
gradient_checkpointing: true
learning_rate: 0.0001
scale_lr: false
lr_scheduler: "constant"
lr_warmup_steps: 250
lr_num_cycles: 1
lr_power: 1.0
enable_slicing: true
enable_tiling: true

# === Optimizer ===
optimizer: "adamw"
use_8bit_adam: false
adam_beta1: 0.9
adam_beta2: 0.95
prodigy_beta3: null
prodigy_decouple: false
adam_weight_decay: 0.0001
adam_epsilon: 0.0000001
max_grad_norm: 1.0
prodigy_use_bias_correction: false
prodigy_safeguard_warmup: false

# === Logging & Reporting ===
push_to_hub: false
hub_token: null
hub_model_id: null
logging_dir: "logs"
allow_tf32: true
report_to: null

# === Optional HuggingFace model variant ===
revision: null
variant: null
training/configs/gopro_test.yaml
ADDED
@@ -0,0 +1,78 @@
# === Required or overridden ===
base_dir: "/datasets/sai/gencam/blur2vid"
pretrained_model_name_or_path: "CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
video_root_dir: "datasets/GOPRO_7"
csv_path: "set-path-to-csv-file" # Replace with actual CSV path
output_dir: "cogvideox-gopro-test"
tracker_name: "cogvideox-gopro-test"


# === Data-related ===
stride_min: 1
stride_max: 3
hflip_p: 0.5
downscale_coef: 8
init_from_transformer: true
dataloader_num_workers: 32
val_split: "test"
dataset: "gopro"

# === Validation ===
num_inference_steps: 50
validation_prompt: ""
validation_video: "../resources/car.mp4:::../resources/ship.mp4"
validation_prompt_separator: ":::"
num_validation_videos: 1
validation_steps: 400
guidance_scale: 1.1
use_dynamic_cfg: false
just_validate: true
special_info: "just_one"

# === Training ===
seed: 42
mixed_precision: "bf16"
height: 720
width: 1280
fps: 8
max_num_frames: 9
train_batch_size: 4
num_train_epochs: 500
max_train_steps: null
checkpointing_steps: 100
checkpoints_total_limit: null
gradient_accumulation_steps: 1
gradient_checkpointing: true
learning_rate: 0.0001
scale_lr: false
lr_scheduler: "constant"
lr_warmup_steps: 250
lr_num_cycles: 1
lr_power: 1.0
enable_slicing: true
enable_tiling: true

# === Optimizer ===
optimizer: "adamw"
use_8bit_adam: false
adam_beta1: 0.9
adam_beta2: 0.95
prodigy_beta3: null
prodigy_decouple: false
adam_weight_decay: 0.0001
adam_epsilon: 0.0000001
max_grad_norm: 1.0
prodigy_use_bias_correction: false
prodigy_safeguard_warmup: false

# === Logging & Reporting ===
push_to_hub: false
hub_token: null
hub_model_id: null
logging_dir: "logs"
allow_tf32: true
report_to: null

# === Optional HuggingFace model variant ===
revision: null
variant: null
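One detail worth noting before the next config: gopro_test.yaml and gopro_train.yaml set max_num_frames: 9, while the full-dataset configs use 17. The 9 comes directly from the GoPro dataset class defined later in this commit (training/controlnet_datasets.py), which pads the 7 available sharp frames around each blurred frame with 2 repeats; the 17 matches the output_len=17 used by the FullMotionBlurDataset sequence generators. A quick check of that arithmetic:

window_size = 7            # sharp frames surrounding each blurred GoPro frame
pad = 2                    # the last sharp frame is repeated twice
output_length = window_size + pad
assert output_length == 9  # matches max_num_frames in the GoPro configs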
training/configs/gopro_train.yaml
ADDED
@@ -0,0 +1,77 @@
# === Required or overridden ===
base_dir: "/datasets/sai/gencam/blur2vid"
pretrained_model_name_or_path: "CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
video_root_dir: "datasets/GOPRO_7"
csv_path: "set-path-to-csv-file" # Replace with actual CSV path
output_dir: "cogvideox-gopro-train"
tracker_name: "cogvideox-gopro-train"

# === Data-related ===
stride_min: 1
stride_max: 3
hflip_p: 0.5
downscale_coef: 8
init_from_transformer: true
dataloader_num_workers: 2
val_split: "val"
dataset: "gopro"

# === Validation ===
num_inference_steps: 50
validation_prompt: ""
validation_video: "../resources/car.mp4:::../resources/ship.mp4"
validation_prompt_separator: ":::"
num_validation_videos: 1
validation_steps: 100
guidance_scale: 1.0
use_dynamic_cfg: false
just_validate: false
special_info: "just_one"

# === Training ===
seed: 42
mixed_precision: "bf16"
height: 720
width: 1280
fps: 8
max_num_frames: 9
train_batch_size: 4
num_train_epochs: 500
max_train_steps: null
checkpointing_steps: 100
checkpoints_total_limit: null
gradient_accumulation_steps: 1
gradient_checkpointing: true
learning_rate: 0.0001
scale_lr: false
lr_scheduler: "constant"
lr_warmup_steps: 250
lr_num_cycles: 1
lr_power: 1.0
enable_slicing: true
enable_tiling: true

# === Optimizer ===
optimizer: "adamw"
use_8bit_adam: false
adam_beta1: 0.9
adam_beta2: 0.95
prodigy_beta3: null
prodigy_decouple: false
adam_weight_decay: 0.0001
adam_epsilon: 0.0000001
max_grad_norm: 1.0
prodigy_use_bias_correction: false
prodigy_safeguard_warmup: false

# === Logging & Reporting ===
push_to_hub: false
hub_token: null
hub_model_id: null
logging_dir: "logs"
allow_tf32: true
report_to: null

# === Optional HuggingFace model variant ===
revision: null
variant: null
training/configs/outsidephotos.yaml
ADDED
@@ -0,0 +1,76 @@
# === Required or overridden ===
base_dir: "/datasets/sai/gencam/blur2vid"
pretrained_model_name_or_path: "cogvideox/CogVideoX-2b/models--THUDM--CogVideoX-2b/snapshots/1137dacfc2c9c012bed6a0793f4ecf2ca8e7ba01" # Replace with actual path or env var expansion
video_root_dir: "datasets/my_motion_blurred_images"
csv_path: "set-path-to-csv-file" # Replace with actual CSV path
output_dir: "cogvideox-outsidephotos"
tracker_name: "cogvideox-outsidephotos"

# === Data-related ===
stride_min: 1
stride_max: 3
hflip_p: 0.5
downscale_coef: 8
init_from_transformer: true
dataloader_num_workers: 0
val_split: "test"
dataset: "outsidephotos"

# === Validation ===
num_inference_steps: 50
just_validate: true
validation_prompt: ""
validation_video: "../resources/car.mp4:::../resources/ship.mp4"
validation_prompt_separator: ":::"
num_validation_videos: 1
validation_steps: 100
guidance_scale: 1.1
use_dynamic_cfg: false

# === Training ===
seed: 42
mixed_precision: "bf16"
height: 720
width: 1280
fps: 8
max_num_frames: 17
train_batch_size: 1
num_train_epochs: 100
max_train_steps: null
checkpointing_steps: 100
checkpoints_total_limit: null
gradient_accumulation_steps: 1
gradient_checkpointing: true
learning_rate: 0.0001
scale_lr: false
lr_scheduler: "constant"
lr_warmup_steps: 250
lr_num_cycles: 1
lr_power: 1.0
enable_slicing: true
enable_tiling: true

# === Optimizer ===
optimizer: "adamw"
use_8bit_adam: false
adam_beta1: 0.9
adam_beta2: 0.95
prodigy_beta3: null
prodigy_decouple: false
adam_weight_decay: 0.0001
adam_epsilon: 0.0000001
max_grad_norm: 1.0
prodigy_use_bias_correction: false
prodigy_safeguard_warmup: false

# === Logging & Reporting ===
push_to_hub: false
hub_token: null
hub_model_id: null
logging_dir: "logs"
allow_tf32: true
report_to: null

# === Optional HuggingFace model variant ===
revision: null
variant: null
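outsidephotos.yaml is an inference-only config (just_validate: true) for running the model on a folder of your own blurred photos. The OutsidePhotosDataset class in the next file simply globs every file under video_root_dir recursively, so preparing data is a matter of dropping images into that folder. A minimal sketch of the discovery step it performs (the directory name below is the placeholder from the config):

import glob, os

root = "datasets/my_motion_blurred_images"
image_paths = sorted(glob.glob(os.path.join(root, "**", "*.*"), recursive=True))
print(f"Found {len(image_paths)} blurred photos")  # each photo yields one 1x and one 2x sample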
training/controlnet_datasets.py
ADDED
@@ -0,0 +1,735 @@
import io
import os
import glob
from pathlib import Path
import pickle
import random
import time


import cv2
import torch
import numpy as np
import pandas as pd
import torchvision.transforms as transforms
from PIL import Image, ImageOps, ImageCms
from decord import VideoReader
from torch.utils.data.dataset import Dataset
from controlnet_aux import CannyDetector, HEDdetector
import torch.nn.functional as F
from helpers import generate_1x_sequence, generate_2x_sequence, generate_large_blur_sequence, generate_test_case


def unpack_mm_params(p):
    if isinstance(p, (tuple, list)):
        return p[0], p[1]
    elif isinstance(p, (int, float)):
        return p, p
    raise Exception(f'Unknown input parameter type.\nParameter: {p}.\nType: {type(p)}')


def resize_for_crop(image, min_h, min_w):
    img_h, img_w = image.shape[-2:]

    if img_h >= min_h and img_w >= min_w:
        coef = min(min_h / img_h, min_w / img_w)
    elif img_h <= min_h and img_w <= min_w:
        coef = max(min_h / img_h, min_w / img_w)
    else:
        coef = min_h / img_h if min_h > img_h else min_w / img_w

    out_h, out_w = int(img_h * coef), int(img_w * coef)
    resized_image = transforms.functional.resize(image, (out_h, out_w), antialias=True)
    return resized_image


class BaseClass(Dataset):
    def __init__(
        self,
        data_dir,
        output_dir,
        image_size=(320, 512),
        hflip_p=0.5,
        controlnet_type='canny',
        split='train',
        *args,
        **kwargs
    ):
        self.split = split
        self.height, self.width = unpack_mm_params(image_size)
        self.data_dir = data_dir
        self.output_dir = output_dir
        self.hflip_p = hflip_p
        self.image_size = image_size
        self.length = 0

    def __len__(self):
        return self.length

    def load_frames(self, frames):
        # frames: numpy array of shape (N, H, W, C), 0–255
        # → tensor of shape (N, C, H, W) as float
        pixel_values = torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().float()
        # normalize to [-1, 1]
        pixel_values = pixel_values / 127.5 - 1.0
        # resize to (self.height, self.width)
        pixel_values = F.interpolate(
            pixel_values,
            size=(self.height, self.width),
            mode="bilinear",
            align_corners=False
        )
        return pixel_values

    def get_batch(self, idx):
        raise Exception('Get batch method is not realized.')

    def __getitem__(self, idx):
        while True:
            try:
                video, caption, motion_blur = self.get_batch(idx)
                break
            except Exception as e:
                print(e)
                idx = random.randint(0, self.length - 1)

        video, = [
            resize_for_crop(x, self.height, self.width) for x in [video]
        ]
        video, = [
            transforms.functional.center_crop(x, (self.height, self.width)) for x in [video]
        ]
        data = {
            'video': video,
            'caption': caption,
        }
        return data


def load_as_srgb(path):
    img = Image.open(path)
    img = ImageOps.exif_transpose(img)

    if 'icc_profile' in img.info:
        icc = img.info['icc_profile']
        src_profile = ImageCms.ImageCmsProfile(io.BytesIO(icc))
        dst_profile = ImageCms.createProfile("sRGB")
        img = ImageCms.profileToProfile(img, src_profile, dst_profile, outputMode='RGB')
    else:
        img = img.convert("RGB")  # Assume sRGB
    return img


class GoProMotionBlurDataset(BaseClass):  # 7 frame go pro dataset
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set blur and sharp directories based on split
        if self.split == 'train':
            self.blur_root = os.path.join(self.data_dir, 'train', 'blur')
            self.sharp_root = os.path.join(self.data_dir, 'train', 'sharp')
        elif self.split in ['val', 'test']:
            self.blur_root = os.path.join(self.data_dir, 'test', 'blur')
            self.sharp_root = os.path.join(self.data_dir, 'test', 'sharp')
        else:
            raise ValueError(f"Unsupported split: {self.split}")

        # Collect all blurred image paths
        pattern = os.path.join(self.blur_root, '*', '*.png')

        self.blur_paths = sorted(glob.glob(pattern))

        if self.split == 'val':
            # Optional: limit validation subset
            self.blur_paths = self.blur_paths[:5]

        filtered_blur_paths = []
        for path in self.blur_paths:
            output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
            full_output_path = Path(output_deblurred_dir, *path.split('/')[-2:]).with_suffix(".mp4")
            if not os.path.exists(full_output_path):
                filtered_blur_paths.append(path)

        # Window and padding parameters
        self.window_size = 7   # original number of sharp frames
        self.pad = 2           # number of times to repeat last frame
        self.output_length = self.window_size + self.pad
        self.half_window = self.window_size // 2
        self.length = len(self.blur_paths)

        # Normalized input interval: always [-0.5, 0.5]
        self.input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)

        # Precompute normalized output intervals: first for window_size frames, then pad duplicates
        step = 1.0 / (self.window_size - 1)
        # intervals for the original 7 frames
        window_intervals = []
        for i in range(self.window_size):
            start = -0.5 + i * step
            if i < self.window_size - 1:
                end = -0.5 + (i + 1) * step
            else:
                end = 0.5
            window_intervals.append([start, end])
        # append the last interval pad times
        intervals = window_intervals + [window_intervals[-1]] * self.pad
        self.output_interval = torch.tensor(intervals, dtype=torch.float)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Path to the blurred (center) frame
        blur_path = self.blur_paths[idx]
        seq_name = os.path.basename(os.path.dirname(blur_path))
        frame_name = os.path.basename(blur_path)
        center_idx = int(os.path.splitext(frame_name)[0])

        # Compute sharp frame range [center-half, center+half]
        start_idx = center_idx - self.half_window
        end_idx = center_idx + self.half_window

        # Load sharp frames
        sharp_dir = os.path.join(self.sharp_root, seq_name)
        frames = []
        for i in range(start_idx, end_idx + 1):
            sharp_filename = f"{i:06d}.png"
            sharp_path = os.path.join(sharp_dir, sharp_filename)
            img = Image.open(sharp_path).convert('RGB')
            frames.append(img)

        # Repeat last sharp frame so total frames == output_length
        while len(frames) < self.output_length:
            frames.append(frames[-1])

        # Load blurred image
        blur_img = Image.open(blur_path).convert('RGB')

        # Convert to pixel values via BaseClass loader
        video = self.load_frames(np.array(frames))  # shape: (output_length, H, W, C)
        blur_input = self.load_frames(np.expand_dims(np.array(blur_img), 0))  # shape: (1, H, W, C)
        end_time = time.time()
        data = {
            'file_name': os.path.join(seq_name, frame_name),
            'blur_img': blur_input,
            'video': video,
            "caption": "",
            'motion_blur_amount': torch.tensor(self.half_window, dtype=torch.long),
            'input_interval': self.input_interval,
            'output_interval': self.output_interval,
            "num_frames": self.window_size,
            "mode": "1x",
        }
        return data


class OutsidePhotosDataset(BaseClass):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.image_paths = sorted(glob.glob(os.path.join(self.data_dir, '**', '*.*'), recursive=True))

        INTERVALS = [
            {"in_start": 0, "in_end": 16, "out_start": 0, "out_end": 16, "center": 8, "window_size": 16, "mode": "1x", "fps": 240},
            {"in_start": 4, "in_end": 12, "out_start": 0, "out_end": 16, "center": 8, "window_size": 16, "mode": "2x", "fps": 240},]
        # other modes commented out for faster processing
        # {"in_start": 0, "in_end": 4, "out_start": 0, "out_end": 4, "center": 2, "window_size": 4, "mode": "1x", "fps": 240},
        # {"in_start": 0, "in_end": 8, "out_start": 0, "out_end": 8, "center": 4, "window_size": 8, "mode": "1x", "fps": 240},
        # {"in_start": 0, "in_end": 12, "out_start": 0, "out_end": 12, "center": 6, "window_size": 12, "mode": "1x", "fps": 240},
        # {"in_start": 0, "in_end": 32, "out_start": 0, "out_end": 32, "center": 12, "window_size": 32, "mode": "lb", "fps": 120}
        # {"in_start": 0, "in_end": 48, "out_start": 0, "out_end": 48, "center": 24, "window_size": 48, "mode": "lb", "fps": 80}

        self.cleaned_intervals = []
        for image_path in self.image_paths:
            for interval in INTERVALS:
                # create a copy of the interval dictionary
                i = interval.copy()
                # add the image path to the interval dictionary
                i['video_name'] = image_path
                video_name = i['video_name']
                mode = i['mode']

                vid_name_w_extension = os.path.relpath(video_name, self.data_dir).split('.')[0]  # "frame_00000"
                output_name = (
                    f"{vid_name_w_extension}_{mode}.mp4"
                )

                full_output_path = os.path.join("/datasets/sai/gencam/cogvideox/training/cogvideox-outsidephotos/deblurred", output_name)  # THIS IS A HACK - YOU NEED TO UPDATE THIS TO YOUR OUTPUT DIRECTORY

                # Keep only if output doesn't exist
                if not os.path.exists(full_output_path):
                    self.cleaned_intervals.append(i)

        self.length = len(self.cleaned_intervals)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):

        interval = self.cleaned_intervals[idx]

        in_start = interval['in_start']
        in_end = interval['in_end']
        out_start = interval['out_start']
        out_end = interval['out_end']
        center = interval['center']
        window = interval['window_size']
        mode = interval['mode']
        fps = interval['fps']

        image_path = interval['video_name']
        blur_img_original = load_as_srgb(image_path)
        H, W = blur_img_original.size

        frame_paths = []
        frame_paths = ["../assets/dummy_image.png" for _ in range(window)]  # any random path replicated

        # generate test case
        _, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
            frame_paths=frame_paths, window_max=window, in_start=in_start, in_end=in_end, out_start=out_start, out_end=out_end, center=center, mode=mode, fps=fps
        )
        file_name = image_path

        # Get base directory and frame prefix
        relative_file_name = os.path.relpath(file_name, self.data_dir)
        base_dir = os.path.dirname(relative_file_name)
        frame_stem = os.path.splitext(os.path.basename(file_name))[0]  # "frame_00000"
        # Build new filename
        new_filename = (
            f"{frame_stem}_{mode}.png"
        )

        blur_img = blur_img_original.resize((self.image_size[1], self.image_size[0]))  # cause pil is width, height

        # Final path
        relative_file_name = os.path.join(base_dir, new_filename)

        blur_input = self.load_frames(np.expand_dims(blur_img, 0).copy())
        # seq_frames is list of frames; stack along time dim
        video = self.load_frames(np.stack(seq_frames, axis=0))

        data = {
            'file_name': relative_file_name,
            "original_size": (H, W),
            'blur_img': blur_input,
            'video': video,
            'caption': "",
            'input_interval': inp_int,
            'output_interval': out_int,
            "num_frames": num_frames,
        }
        return data


class FullMotionBlurDataset(BaseClass):
    """
    A dataset that randomly selects among 1×, 2×, or large-blur modes per sample.
    Uses category-specific <split>_list.txt files under each subfolder of FullDataset to assemble sequences.
    In 'test' split, it instead loads precomputed intervals from intervals_test.pkl and uses generate_test_case.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.seq_dirs = []

        # TEST split: load fixed intervals early
        if self.split == 'test':
            pkl_path = os.path.join(self.data_dir, 'intervals_test.pkl')
            with open(pkl_path, 'rb') as f:
                self.test_intervals = pickle.load(f)
            assert self.test_intervals, f"No test intervals found in {pkl_path}"

            cleaned_intervals = []
            for interval in self.test_intervals:
                # Extract interval values
                in_start = interval['in_start']
                in_end = interval['in_end']
                out_start = interval['out_start']
                out_end = interval['out_end']
                center = interval['center']
                window = interval['window_size']
                mode = interval['mode']
                fps = interval['fps']  # e.g. "lower_fps_frames/720p_240fps_1/frame_00247.png"
                category, seq = interval['video_name'].split('/')
                seq_dir = os.path.join(self.data_dir, category, 'lower_fps_frames', seq)
                frame_paths = sorted(glob.glob(os.path.join(seq_dir, '*.png')))
                rel_path = os.path.relpath(frame_paths[center], self.data_dir)
                rel_path = os.path.splitext(rel_path)[0]  # remove the file extension

                output_name = (
                    f"{rel_path}_"
                    f"in{in_start:04d}_ie{in_end:04d}_"
                    f"os{out_start:04d}_oe{out_end:04d}_"
                    f"ctr{center:04d}_win{window:04d}_"
                    f"fps{fps:04d}_{mode}.mp4"
                )
                output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
                full_output_path = os.path.join(output_deblurred_dir, output_name)

                # Keep only if output doesn't exist
                if not os.path.exists(full_output_path):
                    cleaned_intervals.append(interval)
            print("Len of test intervals after cleaning: ", len(cleaned_intervals))
            print("Len of test intervals before cleaning: ", len(self.test_intervals))
            self.test_intervals = cleaned_intervals

        # TRAIN/VAL: build seq_dirs from each category's list or fallback
        list_file = 'train_list.txt' if self.split == 'train' else 'test_list.txt'
        for category in sorted(os.listdir(self.data_dir)):
            cat_dir = os.path.join(self.data_dir, category)
            if not os.path.isdir(cat_dir):
                continue
            list_path = os.path.join(cat_dir, list_file)
            if os.path.isfile(list_path):
                with open(list_path, 'r') as f:
                    for line in f:
                        rel = line.strip()
                        if not rel:
                            continue
                        seq_dir = os.path.join(self.data_dir, rel)
                        if os.path.isdir(seq_dir):
                            self.seq_dirs.append(seq_dir)
            else:
                fps_root = os.path.join(cat_dir, 'lower_fps_frames')
                if os.path.isdir(fps_root):
                    for seq in sorted(os.listdir(fps_root)):
                        seq_path = os.path.join(fps_root, seq)
                        if os.path.isdir(seq_path):
                            self.seq_dirs.append(seq_path)

        if self.split == 'val':
            self.seq_dirs = self.seq_dirs[:5]
        if self.split == 'train':
            self.seq_dirs *= 10

        assert self.seq_dirs, \
            f"No sequences found for split '{self.split}' in {self.data_dir}"

    def __len__(self):
        return len(self.test_intervals) if self.split == 'test' else len(self.seq_dirs)

    def __getitem__(self, idx):
        # Prepare base items
        if self.split == 'test':
            interval = self.test_intervals[idx]
            category, seq = interval['video_name'].split('/')
            seq_dir = os.path.join(self.data_dir, category, 'lower_fps_frames', seq)
            frame_paths = sorted(glob.glob(os.path.join(seq_dir, '*.png')))

            in_start = interval['in_start']
            in_end = interval['in_end']
            out_start = interval['out_start']
            out_end = interval['out_end']
            center = interval['center']
            window = interval['window_size']
            mode = interval['mode']
            fps = interval['fps']

            # generate test case
            blur_img, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
                frame_paths=frame_paths, window_max=window, in_start=in_start, in_end=in_end, out_start=out_start, out_end=out_end, center=center, mode=mode, fps=fps
            )
            file_name = frame_paths[center]

        else:
            seq_dir = self.seq_dirs[idx]
            frame_paths = sorted(glob.glob(os.path.join(seq_dir, '*.png')))
            mode = random.choice(['1x', '2x', 'large_blur'])

            if mode == '1x' or len(frame_paths) < 50:
                base_rate = random.choice([1, 2])
                blur_img, seq_frames, inp_int, out_int, _ = generate_1x_sequence(
                    frame_paths, window_max=16, output_len=17, base_rate=base_rate
                )
            elif mode == '2x':
                base_rate = random.choice([1, 2])
                blur_img, seq_frames, inp_int, out_int, _ = generate_2x_sequence(
                    frame_paths, window_max=16, output_len=17, base_rate=base_rate
                )
            else:
                max_base = min((len(frame_paths) - 1) // 17, 3)
                base_rate = random.randint(1, max_base)
                blur_img, seq_frames, inp_int, out_int, _ = generate_large_blur_sequence(
                    frame_paths, window_max=16, output_len=17, base_rate=base_rate
                )
            file_name = frame_paths[0]
            num_frames = 16

        # blur_img is a single frame; wrap in batch dim
        blur_input = self.load_frames(np.expand_dims(blur_img, 0))
        # seq_frames is list of frames; stack along time dim
        video = self.load_frames(np.stack(seq_frames, axis=0))

        relative_file_name = os.path.relpath(file_name, self.data_dir)

        if self.split == 'test':
            # Get base directory and frame prefix
            base_dir = os.path.dirname(relative_file_name)
            frame_stem = os.path.splitext(os.path.basename(relative_file_name))[0]  # "frame_00000"

            # Build new filename
            new_filename = (
                f"{frame_stem}_"
                f"in{in_start:04d}_ie{in_end:04d}_"
                f"os{out_start:04d}_oe{out_end:04d}_"
                f"ctr{center:04d}_win{window:04d}_"
                f"fps{fps:04d}_{mode}.png"
            )

            # Final path
            relative_file_name = os.path.join(base_dir, new_filename)

        data = {
            'file_name': relative_file_name,
            'blur_img': blur_input,
            'num_frames': num_frames,
            'video': video,
            'caption': "",
            'mode': mode,
            'input_interval': inp_int,
            'output_interval': out_int,
        }
        if self.split == 'test':
            high_fps_video = self.load_frames(np.stack(high_fps_video, axis=0))
            data['high_fps_video'] = high_fps_video
        return data


class GoPro2xMotionBlurDataset(BaseClass):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set blur and sharp directories based on split
        if self.split == 'train':
            self.blur_root = os.path.join(self.data_dir, 'train', 'blur')
            self.sharp_root = os.path.join(self.data_dir, 'train', 'sharp')
        elif self.split in ['val', 'test']:
            self.blur_root = os.path.join(self.data_dir, 'test', 'blur')
            self.sharp_root = os.path.join(self.data_dir, 'test', 'sharp')
        else:
            raise ValueError(f"Unsupported split: {self.split}")

        # Collect all blurred image paths
        pattern = os.path.join(self.blur_root, '*', '*.png')

        def get_sharp_paths(blur_paths):
            sharp_paths = []
            for blur_path in blur_paths:
                base_dir = blur_path.replace('/blur/', '/sharp/')
                frame_num = int(os.path.basename(blur_path).split('.')[0])
                dir_path = os.path.dirname(base_dir)
                sequence = [
                    os.path.join(dir_path, f"{frame_num + offset:06d}.png")
                    for offset in range(-6, 7)
                ]
                if all(os.path.exists(path) for path in sequence):
                    sharp_paths.append(sequence)
            return sharp_paths

        self.blur_paths = sorted(glob.glob(pattern))
        filtered_blur_paths = []
        for path in self.blur_paths:
            output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
            full_output_path = Path(output_deblurred_dir, *path.split('/')[-2:]).with_suffix(".mp4")
            if not os.path.exists(full_output_path):
                filtered_blur_paths.append(path)
        self.blur_paths = filtered_blur_paths

        self.sharp_paths = get_sharp_paths(self.blur_paths)
        if self.split == 'val':
            # Optional: limit validation subset
            self.sharp_paths = self.sharp_paths[:5]
        self.length = len(self.sharp_paths)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Path to the blurred (center) frame
        sharp_path = self.sharp_paths[idx]

        # Load sharp frames
        blur_img, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
            frame_paths=sharp_path, window_max=13, in_start=3, in_end=10, out_start=0, out_end=13, center=6, mode="2x", fps=240
        )

        # Convert to pixel values via BaseClass loader
        video = self.load_frames(np.array(seq_frames))  # shape: (output_length, H, W, C)
        blur_input = self.load_frames(np.expand_dims(np.array(blur_img), 0))  # shape: (1, H, W, C)
        last_two_parts_of_path = os.path.join(*sharp_path[6].split(os.sep)[-2:])
        # print(f"Time taken to load and process data: {end_time - start_time:.2f} seconds")
        data = {
            'file_name': last_two_parts_of_path,
            'blur_img': blur_input,
            'video': video,
            "caption": "",
            'input_interval': inp_int,
            'output_interval': out_int,
            "num_frames": num_frames,
            "mode": "2x",
        }
        return data


class BAISTDataset(BaseClass):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        test_folders = {
            "gWA_sBM_c01_d26_mWA0_ch06_cropped_32X": None,
            "gBR_sBM_c01_d05_mBR0_ch01_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch04_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch05_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch08_cropped_32X": None,
            "gWA_sBM_c01_d26_mWA0_ch02_cropped_32X": None,
            "gJS_sBM_c01_d02_mJS0_ch08_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch07_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch06_cropped_32X": None,
            "gBR_sBM_c01_d05_mBR0_ch03_cropped_32X": None,
            "gBR_sBM_c01_d05_mBR0_ch05_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch02_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch03_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch09_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch10_cropped_32X": None,
            "gWA_sBM_c01_d26_mWA0_ch10_cropped_32X": None,
            "gBR_sBM_c01_d05_mBR0_ch06_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch08_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch06_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch10_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch09_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch02_cropped_32X": None,
            "gBR_sBM_c01_d05_mBR0_ch04_cropped_32X": None,
            "gPO_sBM_c01_d10_mPO0_ch09_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch01_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch07_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch03_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch04_cropped_32X": None,
            "gBR_sBM_c01_d05_mBR0_ch02_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch01_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch05_cropped_32X": None,
            "gPO_sBM_c01_d10_mPO0_ch10_cropped_32X": None,
        }

        def collect_blur_images(root_dir, allowed_folders, skip_start=40, skip_end=40):
            blur_image_paths = []

            for dirpath, dirnames, filenames in os.walk(root_dir):
                if os.path.basename(dirpath) == "blur":
                    parent_folder = os.path.basename(os.path.dirname(dirpath))
                    if (self.split in ["test", "val"] and parent_folder in test_folders) or (self.split in "train" and parent_folder not in test_folders):
                        # Filter and sort valid image filenames
                        valid_files = [
                            f for f in filenames
                            if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')) and os.path.splitext(f)[0].isdigit()
                        ]
                        valid_files.sort(key=lambda x: int(os.path.splitext(x)[0]))

                        # Skip first and last N files
                        middle_files = valid_files[skip_start:len(valid_files) - skip_end]

                        for f in middle_files:
                            from pathlib import Path
                            full_path = Path(os.path.join(dirpath, f))
                            output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
                            full_output_path = Path(output_deblurred_dir, *full_path.parts[-3:]).with_suffix(".mp4")
                            if not os.path.exists(full_output_path) or self.split in ["train", "val"]:
                                blur_image_paths.append(os.path.join(dirpath, f))

            return blur_image_paths

        self.image_paths = collect_blur_images(self.data_dir, test_folders)
        # if bbx path does not exist, remove the image path
        self.image_paths = [path for path in self.image_paths if os.path.exists(path.replace("blur", "blur_anno").replace(".png", ".pkl"))]

        filtered_image_paths = []
        for blur_path in self.image_paths:
            base_dir = blur_path.replace('/blur/', '/sharp/').replace('.png', '')
            sharp_paths = [f"{base_dir}_{i:03d}.png" for i in range(7)]
            if all(os.path.exists(p) for p in sharp_paths):
                filtered_image_paths.append(blur_path)

        self.image_paths = filtered_image_paths

        if self.split == 'val':
            # Optional: limit validation subset
            self.image_paths = self.image_paths[:4]
        self.length = len(self.image_paths)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        blur_img_original = load_as_srgb(image_path)

        bbx_path = image_path.replace("blur", "blur_anno").replace(".png", ".pkl")

        # load the bbx path
        bbx = np.load(bbx_path, allow_pickle=True)['bbox'][0:4]
        # Final crop box
        # turn crop_box into tuple
        W, H = blur_img_original.size
        blur_img = blur_img_original.resize((self.image_size[1], self.image_size[0]), resample=Image.BILINEAR)

        # cause pil is width, height
        blur_np = np.array([blur_img])

        base_dir = os.path.dirname(os.path.dirname(image_path))  # strip /blur
        filename = os.path.splitext(os.path.basename(image_path))[0]  # '00000000'
        sharp_dir = os.path.join(base_dir, "sharp")

        frame_paths = [
            os.path.join(sharp_dir, f"{filename}_{i:03d}.png")
            for i in range(7)
        ]

        _, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
            frame_paths=frame_paths, window_max=7, in_start=0, in_end=7, out_start=0, out_end=7, center=3, mode="1x", fps=240
        )

        pixel_values = self.load_frames(np.stack(seq_frames, axis=0))
        blur_pixel_values = self.load_frames(blur_np)

        relative_file_name = os.path.relpath(image_path, self.data_dir)

        out_bbx = bbx.copy()

        scale_x = blur_pixel_values.shape[3] / W
        scale_y = blur_pixel_values.shape[2] / H
        # scale the bbx
        out_bbx[0] = int(out_bbx[0] * scale_x)
        out_bbx[1] = int(out_bbx[1] * scale_y)
        out_bbx[2] = int(out_bbx[2] * scale_x)
        out_bbx[3] = int(out_bbx[3] * scale_y)

        out_bbx = torch.tensor(out_bbx, dtype=torch.uint32)

        # crop image using the bbx
        blur_img_npy = np.array(blur_img)
        out_bbx_npy = out_bbx.numpy().astype(np.uint32)
        blur_img_npy = blur_img_npy[out_bbx_npy[1]:out_bbx_npy[3], out_bbx_npy[0]:out_bbx_npy[2], :]

        data = {
            'file_name': relative_file_name,
            'blur_img': blur_pixel_values,
            'video': pixel_values,
            'bbx': out_bbx,
            'caption': "",
            'input_interval': inp_int,
            'output_interval': out_int,
            "num_frames": num_frames,
            'mode': "1x",
        }
        return data
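A minimal sketch of exercising one of the datasets above on its own, assuming it is run from the training/ directory with the GoPro data already downloaded to datasets/GOPRO_7; the real wiring into the trainer happens in training/train_controlnet.py, which is listed separately in this commit:

from torch.utils.data import DataLoader
from controlnet_datasets import GoProMotionBlurDataset

dataset = GoProMotionBlurDataset(
    data_dir="datasets/GOPRO_7",         # expects train/ and test/ subfolders with blur/ and sharp/
    output_dir="cogvideox-gopro-test",   # required by the base class
    image_size=(720, 1280),              # (height, width), matching the GoPro configs
    split="test",
)
loader = DataLoader(dataset, batch_size=1, num_workers=2)
sample = next(iter(loader))
print(sample["blur_img"].shape)  # torch.Size([1, 1, 3, 720, 1280])
print(sample["video"].shape)     # torch.Size([1, 9, 3, 720, 1280])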
training/helpers.py
ADDED
|
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import math
import random
import numpy as np
from PIL import Image

def random_insert_latent_frame(
    image_latent: torch.Tensor,
    noisy_model_input: torch.Tensor,
    target_latents: torch.Tensor,
    input_intervals: torch.Tensor,
    output_intervals: torch.Tensor,
    special_info
):
    """
    Inserts conditioning (blur-image) latent frames into the noisy input, pads the
    targets to match, and builds flattened interval groups.

    Args:
        image_latent:      [B, latent_count, C, H, W]
        noisy_model_input: [B, F, C, H, W]
        target_latents:    [B, F, C, H, W]
        input_intervals:   [B, N, frames_per_latent, L]
        output_intervals:  [B, M, frames_per_latent, L]
        special_info:      "just_one" always inserts a single latent frame,
                           "use_a" always picks Mode A, anything else picks
                           Mode A or Mode B with probability 0.5 each.

    Per sample:
        "just_one":
            - Insert one image_latent frame at the start of the noisy input.
            - Prepend one sentinel frame (value 10000) to the targets; it is excluded from the loss.
            - Keep input_intervals and output_intervals unchanged.
        Mode A:
            - Insert two image_latent frames at the start of the noisy input.
            - Prepend two sentinel frames (value 10000) to the targets.
            - Pad input_intervals by repeating its last group once.
        Mode B:
            - Insert one image_latent frame at the start and repeat the last noisy frame at the end.
            - Prepend one zero frame to the targets and append the last target frame.
            - Pad output_intervals by repeating its last group once.

    After padding, each interval group is flattened from [frames_per_latent, L]
    to [frames_per_latent * L].

    Returns:
        outputs:     Tensor [B, new_F, C, H, W]  (new_F = F+1 for "just_one", else F+2)
        new_targets: Tensor [B, new_F, C, H, W]
        masks:       Tensor [B, new_F] bool mask marking the inserted latent frames
        intervals:   Tensor [B, N+M, frames_per_latent * L]
    """
    B, F, C, H, W = noisy_model_input.shape
    _, N, fpl, L = input_intervals.shape
    _, M, _, _ = output_intervals.shape
    device = noisy_model_input.device

    new_F = F + 1 if special_info == "just_one" else F + 2
    outputs = torch.empty((B, new_F, C, H, W), device=device)
    masks = torch.zeros((B, new_F), dtype=torch.bool, device=device)
    combined_groups = N + M  # + 1
    feature_len = fpl * L
    # intervals = torch.empty((B, combined_groups, feature_len + 4), device=device,
    #                         dtype=input_intervals.dtype)
    # NOTE: the Mode A / Mode B branches below pad one extra group (N + M + 1 rows),
    # while this buffer is sized for N + M; the "just_one" path matches it exactly.
    intervals = torch.empty((B, combined_groups, feature_len), device=device,
                            dtype=input_intervals.dtype)
    new_targets = torch.empty((B, new_F, C, H, W), device=device,
                              dtype=target_latents.dtype)

    for b in range(B):
        latent = image_latent[b, 0]
        frames = noisy_model_input[b]
        tgt = target_latents[b]

        limit = 10 if special_info == "use_a" else 0.5
        if special_info == "just_one":
            # single latent insert, one sentinel-prefixed target
            outputs[b, 0] = latent
            masks[b, :1] = True
            outputs[b, 1:] = frames

            # pad targets: one large-number frame - it should be ignored by the loss
            large_number = torch.ones_like(tgt[0]) * 10000
            new_targets[b, 0] = large_number
            new_targets[b, 1:] = tgt

            # intervals are used unchanged
            # pad_group = input_intervals[b, -1:].clone()
            in_groups = input_intervals[b]  # torch.cat([input_intervals[b], pad_group], dim=0)
            out_groups = output_intervals[b]
        elif random.random() < limit:  # limit = 10 means "always Mode A"
            # Mode A: two latent inserts, sentinel-prefixed targets
            outputs[b, 0] = latent
            outputs[b, 1] = latent
            masks[b, :2] = True
            outputs[b, 2:] = frames

            # pad targets: two large-number frames - they should be ignored by the loss
            large_number = torch.ones_like(tgt[0]) * 10000
            new_targets[b, 0] = large_number
            new_targets[b, 1] = large_number
            new_targets[b, 2:] = tgt

            # pad intervals: input + replicated last input group
            pad_group = input_intervals[b, -1:].clone()
            in_groups = torch.cat([input_intervals[b], pad_group], dim=0)
            out_groups = output_intervals[b]
        else:
            # Mode B: one latent insert & last-frame repeat, zero-prefixed / appended targets
            outputs[b, 0] = latent
            masks[b, 0] = True
            outputs[b, 1:new_F-1] = frames
            outputs[b, new_F-1] = frames[-1]

            # pad targets: one zero frame, then the original frames, then the last frame again
            zero = torch.zeros_like(tgt[0])
            new_targets[b, 0] = zero
            new_targets[b, 1:new_F-1] = tgt
            new_targets[b, new_F-1] = tgt[-1]

            # pad intervals: output + replicated last output group
            in_groups = input_intervals[b]
            pad_group = output_intervals[b, -1:].clone()
            out_groups = torch.cat([output_intervals[b], pad_group], dim=0)

        # flatten groups (the per-group flags were removed)
        flat_in = in_groups.reshape(-1, feature_len)
        proc_in = torch.cat([flat_in], dim=1)

        flat_out = out_groups.reshape(-1, feature_len)
        proc_out = torch.cat([flat_out], dim=1)

        intervals[b] = torch.cat([proc_in, proc_out], dim=0)

    return outputs, new_targets, masks, intervals
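The snippet below is a minimal shape-check sketch for the function above; it is not part of the original file, and the tensor sizes and the `helpers` import path are assumptions made only for illustration.

import torch
from helpers import random_insert_latent_frame

B, F, C, H, W = 2, 13, 16, 60, 90           # toy latent-space sizes
image_latent = torch.randn(B, 1, C, H, W)   # one blur-image latent per sample
noisy = torch.randn(B, F, C, H, W)
targets = torch.randn(B, F, C, H, W)
in_iv = torch.randn(B, 1, 4, 2)             # one input group of 4 intervals, each [t0, t1]
out_iv = torch.randn(B, 4, 4, 2)            # four output groups

outs, new_tgts, masks, ivs = random_insert_latent_frame(
    image_latent, noisy, targets, in_iv, out_iv, special_info="just_one"
)
print(outs.shape, new_tgts.shape, masks.shape, ivs.shape)
# torch.Size([2, 14, 16, 60, 90]) torch.Size([2, 14, 16, 60, 90]) torch.Size([2, 14]) torch.Size([2, 5, 8])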


def transform_intervals(
    intervals: torch.Tensor,
    frames_per_latent: int = 4,
    repeat_first: bool = True
) -> torch.Tensor:
    """
    Pad and reshape intervals into [B, num_latent_frames, frames_per_latent, L].

    Args:
        intervals: Tensor of shape [B, N, L]
        frames_per_latent: number of frames per latent group (e.g., 4)
        repeat_first: if True, pad at the beginning by repeating the first row;
                      otherwise pad at the end by repeating the last row.

    Returns:
        Tensor of shape [B, num_latent_frames, frames_per_latent, L]
    """
    B, N, L = intervals.shape
    num_latent = math.ceil(N / frames_per_latent)
    target_N = num_latent * frames_per_latent
    pad_count = target_N - N

    if pad_count > 0:
        # choose row to repeat
        pad_row = intervals[:, :1, :] if repeat_first else intervals[:, -1:, :]
        # replicate pad_row pad_count times
        pad = pad_row.repeat(1, pad_count, 1)
        # pad at beginning or end
        if repeat_first:
            expanded = torch.cat([pad, intervals], dim=1)
        else:
            expanded = torch.cat([intervals, pad], dim=1)
    else:
        expanded = intervals[:, :target_N, :]

    # reshape into latent-frame groups
    return expanded.view(B, num_latent, frames_per_latent, L)
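An illustrative sketch of the padding behaviour above (assumed usage, not part of the original file): 17 per-frame intervals are padded to 20 = 5 * 4 rows and grouped to match 5 latent frames.

import torch
from helpers import transform_intervals

iv = torch.arange(17 * 2, dtype=torch.float).reshape(1, 17, 2)   # [B=1, N=17, L=2]
grouped = transform_intervals(iv, frames_per_latent=4, repeat_first=True)
print(grouped.shape)   # torch.Size([1, 5, 4, 2])
print(grouped[0, 0])   # first group is row 0 repeated (3 padded copies + the original row)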
def build_blur(frame_paths, gamma=2.2):
    """
    Simulate motion blur using inverse-gamma (linear-light) summation:
      - Load each image, convert to float32 sRGB [0, 255]
      - Linearize via inverse gamma: linear = (img/255)^gamma
      - Sum linear values, average, then re-encode via gamma: (linear_avg)^(1/gamma) * 255
    Returns a uint8 numpy array.
    """
    acc_lin = None
    for p in frame_paths:
        img = np.array(Image.open(p).convert('RGB'), dtype=np.float32)
        # normalize to [0,1] then linearize
        lin = np.power(img / 255.0, gamma)
        acc_lin = lin if acc_lin is None else acc_lin + lin
    # average in linear domain
    avg_lin = acc_lin / len(frame_paths)
    # gamma-encode back to sRGB domain
    srgb = np.power(avg_lin, 1.0 / gamma) * 255.0
    return np.clip(srgb, 0, 255).astype(np.uint8)
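A usage sketch for build_blur (the frame paths below are hypothetical and only illustrate the call): average eight consecutive sharp frames in linear light to synthesize one motion-blurred frame.

from PIL import Image
from helpers import build_blur

frame_paths = [f"frames/{i:06d}.png" for i in range(8)]   # assumed high-frame-rate frame dump
blur = build_blur(frame_paths, gamma=2.2)                 # uint8 array, H x W x 3
Image.fromarray(blur).save("blur_000000_000007.png")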
def generate_1x_sequence(frame_paths, window_max=16, output_len=17, base_rate=1, start=None):
    """
    1x mode at arbitrary base_rate (units of 1/240s):
      - Treat each output step as the sum of `base_rate` consecutive raw frames.
      - Pick window size W in [1, min(output_len, window_max, N // base_rate)]
      - Randomly choose a start index so W*base_rate frames fit (unless `start` is given)
      - Group raw frames into W groups of length base_rate
      - Build blur image over all W*base_rate frames for input
      - For each group, build a blurred output frame by summing its base_rate frames
      - Pad sequence of W blurred frames to output_len by repeating last blurred frame
      - Input interval always [-0.5, 0.5]
      - Output intervals reflect each group's coverage within [-0.5, 0.5]
    """
    N = len(frame_paths)
    max_w = min(output_len, N // base_rate)
    max_w = min(max_w, window_max)
    W = random.randint(1, max_w)
    if start is not None:
        # a start was supplied; just check that W*base_rate frames fit
        assert N >= W * base_rate, f"Not enough frames for base_rate={base_rate}, need {W * base_rate}, got {N}"
    else:
        start = random.randint(0, N - W * base_rate)

    # group start indices
    group_starts = [start + i * base_rate for i in range(W)]
    # flatten raw frame paths for blur input
    blur_paths = []
    for gs in group_starts:
        blur_paths.extend(frame_paths[gs:gs + base_rate])
    blur_img = build_blur(blur_paths)

    # build blurred output frames per group
    seq = []
    for gs in group_starts:
        group = frame_paths[gs:gs + base_rate]
        seq.append(build_blur(group))
    # pad with last blurred frame
    seq += [seq[-1]] * (output_len - len(seq))

    input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)
    # each group covers an interval of length 1/W
    step = 1.0 / W
    intervals = [[-0.5 + i * step, -0.5 + (i + 1) * step] for i in range(W)]
    num_frames = len(intervals)
    intervals += [intervals[-1]] * (output_len - W)
    output_intervals = torch.tensor(intervals, dtype=torch.float)

    return blur_img, seq, input_interval, output_intervals, num_frames
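A worked example of the 1x interval layout above (no image loading; the values are computed by hand for W = 4): the input blur always covers [-0.5, 0.5], and the W output groups tile it evenly.

W = 4
step = 1.0 / W
intervals = [[-0.5 + i * step, -0.5 + (i + 1) * step] for i in range(W)]
print(intervals)
# [[-0.5, -0.25], [-0.25, 0.0], [0.0, 0.25], [0.25, 0.5]]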
def generate_2x_sequence(frame_paths, window_max=16, output_len=17, base_rate=1):
    """
    2x mode:
      - Logical window of W output-steps so that 2*W <= output_len
      - Raw window spans W*base_rate frames
      - Build blur only over that raw window (flattened) for input
      - before_count = W//2, after_count = W - before_count
      - Define groups for before, during, and after, each of length base_rate
      - Build blurred frames for each group
      - Pad sequence of 2*W blurred frames to output_len by repeating last
      - Input interval always [-0.5, 0.5]
      - Output intervals relative to window: each group's center
    """
    N = len(frame_paths)
    max_w = min(output_len // 2, N // base_rate)
    max_w = min(max_w, window_max)
    W = random.randint(1, max_w)
    before_count = W // 2
    after_count = W - before_count
    # choose start so that before and after stay within bounds
    min_start = before_count * base_rate
    max_start = N - (W + after_count) * base_rate
    # ensure we can pick a valid start, else fail
    assert max_start >= min_start, f"Cannot satisfy before/after window for W={W}, base_rate={base_rate}, N={N}"
    start = random.randint(min_start, max_start)

    # window group starts
    window_starts = [start + i * base_rate for i in range(W)]
    # flatten for blur input
    blur_paths = []
    for gs in window_starts:
        blur_paths.extend(frame_paths[gs:gs + base_rate])

    blur_img = build_blur(blur_paths)

    # define before/after group starts
    before_count = W // 2
    after_count = W - before_count
    before_starts = [max(0, start - (i + 1) * base_rate) for i in range(before_count)][::-1]
    after_starts = [min(N - base_rate, start + W * base_rate + i * base_rate) for i in range(after_count)]

    # all group starts in sequence
    group_starts = before_starts + window_starts + after_starts
    # build blurred frames per group
    seq = []
    for gs in group_starts:
        group = frame_paths[gs:gs + base_rate]
        seq.append(build_blur(group))
    # pad blurred frames to output_len
    seq += [seq[-1]] * (output_len - len(seq))

    input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)
    # each group covers 1/(2W) around its center within [-0.5, 0.5]
    half = 0.5 / W
    centers = [((gs - start) / (W * base_rate)) - 0.5 + half
               for gs in group_starts]
    intervals = [[c - half, c + half] for c in centers]
    num_frames = len(intervals)
    intervals += [intervals[-1]] * (output_len - len(intervals))
    output_intervals = torch.tensor(intervals, dtype=torch.float)

    return blur_img, seq, input_interval, output_intervals, num_frames
def generate_large_blur_sequence(frame_paths, window_max=16, output_len=17, base_rate=1):
    """
    Large-blur mode with instantaneous outputs:
      - Raw window spans window_max * base_rate consecutive frames
      - Build blur over that full raw window for input
      - For output sequence:
          * Pick 1 raw frame every `base_rate` (group_starts)
          * Each output frame is the instantaneous frame at that raw index
      - Input interval always [-0.5, 0.5]
      - Output intervals reflect each 1-frame slice's coverage within the blur window,
        leaving gaps between.
    """
    N = len(frame_paths)
    total_raw = window_max * base_rate
    assert N >= total_raw, f"Not enough frames for base_rate={base_rate}, need {total_raw}, got {N}"
    start = random.randint(0, N - total_raw)

    # build blur input over the full raw block
    raw_block = frame_paths[start:start + total_raw]
    blur_img = build_blur(raw_block)

    # output sequence: instantaneous frames at each group_start
    seq = []
    group_starts = [start + i * base_rate for i in range(window_max)]
    for gs in group_starts:
        img = np.array(Image.open(frame_paths[gs]).convert('RGB'), dtype=np.uint8)
        seq.append(img)
    # pad the instantaneous frames to output_len
    seq += [seq[-1]] * (output_len - len(seq))

    # compute intervals for each instantaneous frame:
    # each covers [gs, gs+1) over total_raw, normalized to [-0.5, 0.5]
    intervals = []
    for gs in group_starts:
        t0 = (gs - start) / total_raw - 0.5
        t1 = (gs + 1 - start) / total_raw - 0.5
        intervals.append([t0, t1])
    num_frames = len(intervals)
    intervals += [intervals[-1]] * (output_len - len(intervals))
    output_intervals = torch.tensor(intervals, dtype=torch.float)

    # input interval
    input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)
    return blur_img, seq, input_interval, output_intervals, num_frames
def generate_test_case(frame_paths,
                       window_max=16,
                       output_len=17,
                       in_start=None,
                       in_end=None,
                       out_start=None,
                       out_end=None,
                       center=None,
                       mode="1x",
                       fps=240):
    """
    Generate blurred input + a target sequence + normalized intervals.

    Args:
        frame_paths: list of all frame filepaths
        window_max: number of groups/bins W
        output_len: desired length of the output sequence
        in_start, in_end: integer indices defining the raw window [in_start, in_end)
        out_start, out_end: integer indices defining the output range [out_start, out_end)
        mode: one of "1x", "2x", or "lb"
        fps: frames-per-second of the raw capture (base_rate = 240 // fps)

    Returns:
        blur_img: np.ndarray of the global blur over the window
        summed_seq: list of np.ndarray, length = output_len (blurred groups)
        input_interval: torch.Tensor [[-0.5, 0.5]]
        output_intervals: torch.Tensor of shape [output_len, 2], normalized in [-0.5, 0.5]
        seq: list of np.ndarray, one raw target frame per index in [out_start, out_end)
        num_frames: number of non-padded interval groups
    """
    # 1) slice and blur
    raw_paths = frame_paths[in_start:in_end]

    blur_img = build_blur(raw_paths)

    # 2) build the raw target sequence: one frame per index
    seq = [
        np.array(Image.open(p).convert("RGB"), dtype=np.uint8)
        for p in frame_paths[out_start:out_end]
    ]

    # 3) input interval is always the whole blur window
    input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)

    base_rate = 240 // fps

    # 4) define the raw intervals in absolute frame indices
    if mode == "1x":
        assert in_start == out_start and in_end == out_end
        # assert fps == 240, "haven't implemented 120fps in 1x yet"
        W = (out_end - out_start) // base_rate
        # one group per output frame
        group_starts = [out_start + i * base_rate for i in range(W)]
        group_ends = [out_start + (i + 1) * base_rate for i in range(W)]

    elif mode == "2x":
        W = (out_end - out_start) // base_rate
        # every base_rate frames, starting at out_start
        group_starts = [out_start + i * base_rate for i in range(W)]
        group_ends = [out_start + (i + 1) * base_rate for i in range(W)]

    elif mode == "lb":
        W = (out_end - out_start) // base_rate
        # sparse "key-frame" windows from the raw input range
        group_starts = [in_start + i * base_rate for i in range(W)]
        group_ends = [s + 1 for s in group_starts]

    else:
        raise ValueError(f"Unsupported mode: {mode}")

    # 5) build a summed video sequence by blurring each interval
    summed_seq = []
    for s, e in zip(group_starts, group_ends):
        # make sure indices lie in [in_start, in_end)
        s_clamped = max(in_start, min(s, in_end - 1))
        e_clamped = max(s_clamped + 1, min(e, in_end))
        # sum/blur the frames in [s_clamped:e_clamped)
        summed = build_blur(frame_paths[s_clamped:e_clamped])
        summed_seq.append(summed)

    # pad to output_len
    if len(summed_seq) < output_len:
        summed_seq += [summed_seq[-1]] * (output_len - len(summed_seq))

    # 6) normalize the intervals relative to the input window
    def normalize(x):
        return (x - in_start) / (in_end - in_start) - 0.5

    intervals = [[normalize(s), normalize(e)] for s, e in zip(group_starts, group_ends)]
    num_frames = len(intervals)
    if len(intervals) < output_len:
        intervals += [intervals[-1]] * (output_len - len(intervals))

    output_intervals = torch.tensor(intervals, dtype=torch.float)

    # the return also includes summed_seq and the raw target frames
    return blur_img, summed_seq, input_interval, output_intervals, seq, num_frames
def get_conditioning(
    output_len=17,
    in_start=None,
    in_end=None,
    out_start=None,
    out_end=None,
    mode="1x",
    fps=240,
):
    """
    Generate normalized interval conditioning signals. Just like generate_test_case,
    but without loading any images (for inference only).

    Args:
        output_len: desired length of the output sequence
        in_start, in_end: integer indices defining the raw window [in_start, in_end)
        out_start, out_end: integer indices defining the output range [out_start, out_end)
        mode: one of "1x", "2x", or "lb"
        fps: frames-per-second of the raw capture (base_rate = 240 // fps)

    Returns:
        input_interval: torch.Tensor [[-0.5, 0.5]]
        output_intervals: torch.Tensor of shape [output_len, 2], normalized in [-0.5, 0.5]
        num_frames: number of non-padded interval groups
    """
    # input interval is always the whole blur window
    input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)

    base_rate = 240 // fps

    # define the raw intervals in absolute frame indices
    if mode == "1x":
        assert in_start == out_start and in_end == out_end
        # assert fps == 240, "haven't implemented 120fps in 1x yet"
        W = (out_end - out_start) // base_rate
        # one group per output frame
        group_starts = [out_start + i * base_rate for i in range(W)]
        group_ends = [out_start + (i + 1) * base_rate for i in range(W)]

    elif mode == "2x":
        W = (out_end - out_start) // base_rate
        # every base_rate frames, starting at out_start
        group_starts = [out_start + i * base_rate for i in range(W)]
        group_ends = [out_start + (i + 1) * base_rate for i in range(W)]

    elif mode == "lb":
        W = (out_end - out_start) // base_rate
        # sparse "key-frame" windows from the raw input range
        group_starts = [in_start + i * base_rate for i in range(W)]
        group_ends = [s + 1 for s in group_starts]

    else:
        raise ValueError(f"Unsupported mode: {mode}")

    # normalize the intervals relative to the input window
    def normalize(x):
        return (x - in_start) / (in_end - in_start) - 0.5

    intervals = [[normalize(s), normalize(e)] for s, e in zip(group_starts, group_ends)]
    num_frames = len(intervals)
    if len(intervals) < output_len:
        intervals += [intervals[-1]] * (output_len - len(intervals))

    output_intervals = torch.tensor(intervals, dtype=torch.float)

    return input_interval, output_intervals, num_frames
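A usage sketch for get_conditioning (the indices are chosen only for illustration and assume the file is importable as `helpers`): conditioning for deblurring a window of 16 raw frames into 16 sharp outputs, padded to 17 intervals.

from helpers import get_conditioning

input_interval, output_intervals, num_frames = get_conditioning(
    output_len=17, in_start=0, in_end=16, out_start=0, out_end=16, mode="1x", fps=240
)
print(input_interval)            # tensor([[-0.5000,  0.5000]])
print(output_intervals.shape)    # torch.Size([17, 2])
print(num_frames)                # 16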
training/slurm_scripts/simple_multinode.sbatch
ADDED
|
@@ -0,0 +1,88 @@
#!/bin/bash
#SBATCH --job-name=XYZ
#SBATCH --nodes=4
#SBATCH --mem=256gb
#SBATCH --ntasks-per-node=1   # crucial - only 1 task per dist per node!
#SBATCH --cpus-per-task=28
#SBATCH --gpus-per-node=4
#SBATCH --exclusive
#SBATCH --output=output/slurm-%j-%N.out
#SBATCH --error=error/slurm-%j-%N.err
#SBATCH --qos=scavenger
#SBATCH --signal=B:USR1@300
#SBATCH --nodelist=lse-hpcnode[1,3,4,5,10-12]

# 6 and 9 are messed up
# 7 is sketchy as well

set -x -e

if [ -z "$1" ]
then
    # quit if no config file is passed
    echo "No config file passed, quitting"
    exit 1
else
    config_file=$1
fi

source ~/.bashrc
conda activate gencam
cd /datasets/sai/gencam/cogvideox/training

echo "START TIME: $(date)"

# needed until we fix IB issues
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_IFNAME=ens

# Training setup
GPUS_PER_NODE=4
# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000
NNODES=$SLURM_NNODES
NODE_RANK=$SLURM_PROCID
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))


#CMD="accelerate_test.py"
CMD="train_controlnet.py --config $config_file"
LAUNCHER="accelerate launch \
    --multi_gpu \
    --gpu_ids 0,1,2,3 \
    --num_processes $WORLD_SIZE \
    --num_machines $NNODES \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --rdzv_backend=c10d \
    --max_restarts 0 \
    --tee 3 \
    "

# # NOT SURE THE FOLLOWING ENV VARS IS STRICTLY NEEDED (PROBABLY NOT)
# export CUDA_HOME=/usr/local/cuda-11.6
# export LD_PRELOAD=$CUDA_HOME/lib/libnccl.so
# export LD_LIBRARY_PATH=$CUDA_HOME/efa/lib:$CUDA_HOME/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH

SRUN_ARGS=" \
    --wait=60 \
    --kill-on-bad-exit=1 \
    "

handler()
{
    echo "Signal handler triggered at $(date)"

    sleep 120  # Let training save
    sbatch ${BASH_SOURCE[0]} $config_file
}

# register signal handler
trap handler SIGUSR1

clear; srun --cpu-bind=none --jobid $SLURM_JOB_ID $LAUNCHER $CMD & srun_pid=$!

wait

echo "END TIME: $(date)"
training/slurm_scripts/slurm-bash.sh
ADDED
|
@@ -0,0 +1 @@
srun --nodes=1 --gpus=4 --qos=gpu4-8h --pty bash
training/slurm_scripts/train.sbatch
ADDED
|
@@ -0,0 +1,54 @@
#!/bin/bash
#SBATCH --job-name=train_deblur
#SBATCH --nodes=1
#SBATCH --gpus-per-node=4
#SBATCH --qos=gpu4-8h
#SBATCH --signal=B:USR1@600
#SBATCH --cpus-per-task=24
#SBATCH --output=output/slurm-%j.out
#SBATCH --error=error/slurm-%j.err
#SBATCH --nodelist=lse-hpcnode[8]

# the signal time needs to be larger than the sleep in the handler function

# prepare your environment here
source ~/.bashrc
conda activate gencam
cd /datasets/sai/gencam/cogvideox/training
export CUDA_VISIBLE_DEVICES=0,1,2,3
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

if [ -z "$1" ]
then
    # quit if no config file is passed
    echo "No config file passed, quitting"
    exit 1
else
    config_file=$1
fi

handler()
{
    echo "function handler called at $(date)"
    # Send SIGUSR1 to the captured PID of the accelerate job
    if [ -n "$accelerate_pid" ]; then
        echo "Sending SIGUSR1 to accelerate PID: $accelerate_pid"
        python_id=$(ps --ppid $accelerate_pid -o pid=)
        kill -USR1 $python_id  # Send SIGUSR1 to the accelerate job
        sleep 300  # Wait for 5 minutes
    else
        echo "No accelerate PID found"
    fi
    echo "Resubmitting job with config file: $config_file"
    sbatch ${BASH_SOURCE[0]} $config_file
}

# register signal handler
trap handler SIGUSR1

echo "Starting job at $(date)"
#python train_controlnet.py #--config $config_file #& wait
accelerate launch --config_file accelerator_configs/accelerator_train_config.yaml --multi_gpu train_controlnet.py --config $config_file &
accelerate_pid=$!

wait
training/slurm_scripts/val.sbatch
ADDED
|
@@ -0,0 +1,50 @@
#!/bin/bash
#SBATCH --job-name=train_deblur
#SBATCH --nodes=1
#SBATCH --gpus-per-node=4
#SBATCH --qos=scavenger
#SBATCH --signal=B:USR1@600
#SBATCH --cpus-per-task=24
#SBATCH --output=output/slurm-%j.out
#SBATCH --error=error/slurm-%j.err
#SBATCH --exclude=lse-hpcnode9
# prepare your environment here
source ~/.bashrc
conda activate gencam
cd /datasets/sai/gencam/cogvideox/training
export CUDA_VISIBLE_DEVICES=0,1,2,3
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

if [ -z "$1" ]
then
    # quit if no config file is passed
    echo "No config file passed, quitting"
    exit 1
else
    config_file=$1
fi

handler()
{
    echo "function handler called at $(date)"
    # Send SIGUSR1 to the captured PID of the accelerate job
    if [ -n "$accelerate_pid" ]; then
        echo "Sending SIGUSR1 to accelerate PID: $accelerate_pid"
        python_id=$(ps --ppid $accelerate_pid -o pid=)
        kill -USR1 $python_id  # Send SIGUSR1 to the accelerate job
        sleep 300  # Wait for 5 minutes
    else
        echo "No accelerate PID found"
    fi
    sbatch ${BASH_SOURCE[0]} $config_file
}

# register signal handler
trap handler SIGUSR1

echo "Starting job at $(date)"
#python train_controlnet.py #--config $config_file #& wait
accelerate launch --config_file accelerator_configs/accelerator_val_config.yaml --multi_gpu train_controlnet.py --config $config_file &
accelerate_pid=$!

wait
training/test_dataset.py
ADDED
|
File without changes
|
training/train_controlnet.py
ADDED
|
@@ -0,0 +1,724 @@
|
|
|
# Copyright 2024 The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import signal
import sys
import threading
import time
import cv2
sys.path.append('..')
from PIL import Image
import logging
import math
import os
from pathlib import Path

import torch
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import numpy as np
from transformers import AutoTokenizer, T5EncoderModel

import diffusers
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    cast_training_params,
    free_memory,
)
from diffusers.utils import check_min_version, export_to_video, is_wandb_available
from diffusers.utils.torch_utils import is_compiled_module

from controlnet_datasets import FullMotionBlurDataset, GoPro2xMotionBlurDataset, OutsidePhotosDataset, GoProMotionBlurDataset, BAISTDataset
from controlnet_pipeline import ControlnetCogVideoXPipeline
from cogvideo_transformer import CogVideoXTransformer3DModel
from helpers import random_insert_latent_frame, transform_intervals
import os
from utils import save_frames_as_pngs, compute_prompt_embeddings, prepare_rotary_positional_embeddings, encode_prompt, get_optimizer, atomic_save, get_args
if is_wandb_available():
    import wandb

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")

logger = get_logger(__name__)


def log_validation(
    pipe,
    args,
    accelerator,
    pipeline_args,
):
    logger.info(
        f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
    )
    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
    scheduler_args = {}

    if "variance_type" in pipe.scheduler.config:
        variance_type = pipe.scheduler.config.variance_type

        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"

        scheduler_args["variance_type"] = variance_type

    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
    pipe = pipe.to(accelerator.device)

    # run inference
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None

    videos = []
    for _ in range(args.num_validation_videos):
        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
        videos.append(video)

    free_memory()  # delete the pipeline to free up memory

    return videos


def main(args):
    global signal_recieved_time
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[kwargs],
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name,
                exist_ok=True,
            ).repo_id

    # Prepare models and scheduler
    tokenizer = AutoTokenizer.from_pretrained(
        os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="tokenizer", revision=args.revision
    )

    text_encoder = T5EncoderModel.from_pretrained(
        os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="text_encoder", revision=args.revision
    )

    # CogVideoX-2b weights are stored in float16
    config = CogVideoXTransformer3DModel.load_config(
        os.path.join(args.base_dir, args.pretrained_model_name_or_path),
        subfolder="transformer",
        revision=args.revision,
        variant=args.variant,
    )

    load_dtype = torch.bfloat16 if "5b" in os.path.join(args.base_dir, args.pretrained_model_name_or_path).lower() else torch.float16
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        os.path.join(args.base_dir, args.pretrained_model_name_or_path),
        subfolder="transformer",
        torch_dtype=load_dtype,
        revision=args.revision,
        variant=args.variant,
        low_cpu_mem_usage=False,
    )

    vae = AutoencoderKLCogVideoX.from_pretrained(
        os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="vae", revision=args.revision, variant=args.variant
    )

    scheduler = CogVideoXDPMScheduler.from_pretrained(os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="scheduler")

    if args.enable_slicing:
        vae.enable_slicing()
    if args.enable_tiling:
        vae.enable_tiling()

    # We only train the additional adapter controlnet layers
    text_encoder.requires_grad_(False)
    transformer.requires_grad_(True)
    vae.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.state.deepspeed_plugin:
        # DeepSpeed is handling precision, use what's in the DeepSpeed config
        if (
            "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
            and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
        ):
            weight_dtype = torch.float16
        if (
            "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
            and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
        ):
            weight_dtype = torch.float16
    else:
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    text_encoder.to(accelerator.device, dtype=weight_dtype)
    transformer.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()

    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32 and torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Make sure the trainable params are in float32.
    if args.mixed_precision == "fp16":
        # only upcast trainable parameters into fp32
        cast_training_params([transformer], dtype=torch.float32)

    trainable_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))

    # Optimization parameters
    trainable_parameters_with_lr = {"params": trainable_parameters, "lr": args.learning_rate}
    params_to_optimize = [trainable_parameters_with_lr]

    use_deepspeed_optimizer = (
        accelerator.state.deepspeed_plugin is not None
        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
    )
    use_deepspeed_scheduler = (
        accelerator.state.deepspeed_plugin is not None
        and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
    )

    optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)

    # Dataset and DataLoader
    DATASET_REGISTRY = {
        "gopro": GoProMotionBlurDataset,
        "gopro2x": GoPro2xMotionBlurDataset,
        "full": FullMotionBlurDataset,
        "baist": BAISTDataset,
        "outsidephotos": OutsidePhotosDataset,  # val-only special (no split)
    }

    if args.dataset not in DATASET_REGISTRY:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    train_dataset_class = DATASET_REGISTRY[args.dataset]
    val_dataset_class = train_dataset_class

    common_kwargs = dict(
        data_dir=os.path.join(args.base_dir, args.video_root_dir),
        output_dir=args.output_dir,
        image_size=(args.height, args.width),
        stride=(args.stride_min, args.stride_max),
        sample_n_frames=args.max_num_frames,
        hflip_p=args.hflip_p,
    )

    def build_kwargs(is_train: bool):
        """Return constructor kwargs, adding split"""
        kw = dict(common_kwargs)
        kw["split"] = "train" if is_train else args.val_split
        return kw

    train_dataset = train_dataset_class(**build_kwargs(is_train=True))
    val_dataset = val_dataset_class(**build_kwargs(is_train=False))

    def encode_video(video):
        video = video.to(accelerator.device, dtype=vae.dtype)
        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
        latent_dist = vae.encode(video).latent_dist.sample() * vae.config.scaling_factor
        return latent_dist.permute(0, 2, 1, 3, 4).to(memory_format=torch.contiguous_format)
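As an aside on encode_video above: the VAE expects channels-first video, so the helper permutes [B, F, C, H, W] to [B, C, F, H, W] and permutes the latents back. A standalone sketch of just that layout round trip (dummy tensor, no VAE; not part of the original file):

import torch

video = torch.randn(2, 17, 3, 48, 72)             # [B, F, C, H, W] as produced by the dataloader (toy sizes)
channels_first = video.permute(0, 2, 1, 3, 4)     # [B, C, F, H, W] as expected by the VAE encoder
restored = channels_first.permute(0, 2, 1, 3, 4)  # back to [B, F, C, H, W]
print(channels_first.shape, restored.shape)
# torch.Size([2, 3, 17, 48, 72]) torch.Size([2, 17, 3, 48, 72])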
    def collate_fn(examples):
        blur_img = [example["blur_img"] for example in examples]
        videos = [example["video"] for example in examples]
        if "high_fps_video" in examples[0]:
            high_fps_videos = [example["high_fps_video"] for example in examples]
            high_fps_videos = torch.stack(high_fps_videos)
            high_fps_videos = high_fps_videos.to(memory_format=torch.contiguous_format).float()
        if "bbx" in examples[0]:
            bbx = [example["bbx"] for example in examples]
            bbx = torch.stack(bbx)
            bbx = bbx.to(memory_format=torch.contiguous_format).float()
        prompts = [example["caption"] for example in examples]
        file_names = [example["file_name"] for example in examples]
        num_frames = [example["num_frames"] for example in examples]
        input_intervals = [example["input_interval"] for example in examples]
        output_intervals = [example["output_interval"] for example in examples]

        videos = torch.stack(videos)
        videos = videos.to(memory_format=torch.contiguous_format).float()

        blur_img = torch.stack(blur_img)
        blur_img = blur_img.to(memory_format=torch.contiguous_format).float()

        input_intervals = torch.stack(input_intervals)
        input_intervals = input_intervals.to(memory_format=torch.contiguous_format).float()

        output_intervals = torch.stack(output_intervals)
        output_intervals = output_intervals.to(memory_format=torch.contiguous_format).float()

        out_dict = {
            "file_names": file_names,
            "blur_img": blur_img,
            "videos": videos,
            "num_frames": num_frames,
            "prompts": prompts,
            "input_intervals": input_intervals,
            "output_intervals": output_intervals,
        }

        if "high_fps_video" in examples[0]:
            out_dict["high_fps_video"] = high_fps_videos
        if "bbx" in examples[0]:
            out_dict["bbx"] = bbx
        return out_dict
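For reference, a small standalone sketch of the batch layout that the collate function above assembles; the example dicts, key values, and tensor sizes are made-up stand-ins for what the dataset classes return.

import torch

examples = [
    {
        "blur_img": torch.randn(1, 3, 48, 72),          # single blurred conditioning frame (toy size)
        "video": torch.randn(17, 3, 48, 72),            # sharp target frames
        "caption": "a person walking",
        "file_name": f"clip_{i:04d}",
        "num_frames": 16,
        "input_interval": torch.tensor([[-0.5, 0.5]]),
        "output_interval": torch.rand(17, 2) - 0.5,
    }
    for i in range(2)
]

batch = {
    "blur_img": torch.stack([e["blur_img"] for e in examples]).float(),
    "videos": torch.stack([e["video"] for e in examples]).float(),
    "input_intervals": torch.stack([e["input_interval"] for e in examples]).float(),
    "output_intervals": torch.stack([e["output_interval"] for e in examples]).float(),
    "prompts": [e["caption"] for e in examples],
}
print(batch["videos"].shape, batch["output_intervals"].shape)
# torch.Size([2, 17, 3, 48, 72]) torch.Size([2, 17, 2])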
| 364 |
+
train_dataloader = DataLoader(
|
| 365 |
+
train_dataset,
|
| 366 |
+
batch_size=args.train_batch_size,
|
| 367 |
+
shuffle=True,
|
| 368 |
+
collate_fn=collate_fn,
|
| 369 |
+
num_workers=args.dataloader_num_workers,
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
val_dataloader = DataLoader(
|
| 373 |
+
val_dataset,
|
| 374 |
+
batch_size=1,
|
| 375 |
+
shuffle=False,
|
| 376 |
+
collate_fn=collate_fn,
|
| 377 |
+
num_workers=args.dataloader_num_workers,
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
# Scheduler and math around the number of training steps.
|
| 381 |
+
overrode_max_train_steps = False
|
| 382 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 383 |
+
if args.max_train_steps is None:
|
| 384 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 385 |
+
overrode_max_train_steps = True
|
| 386 |
+
|
| 387 |
+
if use_deepspeed_scheduler:
|
| 388 |
+
from accelerate.utils import DummyScheduler
|
| 389 |
+
|
| 390 |
+
lr_scheduler = DummyScheduler(
|
| 391 |
+
name=args.lr_scheduler,
|
| 392 |
+
optimizer=optimizer,
|
| 393 |
+
total_num_steps=args.max_train_steps * accelerator.num_processes,
|
| 394 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 395 |
+
)
|
| 396 |
+
else:
|
| 397 |
+
lr_scheduler = get_scheduler(
|
| 398 |
+
args.lr_scheduler,
|
| 399 |
+
optimizer=optimizer,
|
| 400 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 401 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 402 |
+
num_cycles=args.lr_num_cycles,
|
| 403 |
+
power=args.lr_power,
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
# Prepare everything with our `accelerator`.
|
| 407 |
+
transformer, optimizer, train_dataloader, lr_scheduler, val_dataloader = accelerator.prepare(
|
| 408 |
+
transformer, optimizer, train_dataloader, lr_scheduler, val_dataloader
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 412 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 413 |
+
if overrode_max_train_steps:
|
| 414 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 415 |
+
# Afterwards we recalculate our number of training epochs
|
| 416 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 417 |
+
|
| 418 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 419 |
+
# The trackers initializes automatically on the main process.
|
| 420 |
+
if accelerator.is_main_process:
|
| 421 |
+
tracker_name = args.tracker_name or "cogvideox-controlnet"
|
| 422 |
+
accelerator.init_trackers(tracker_name, config=vars(args))
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
accelerator.register_for_checkpointing(transformer, optimizer, lr_scheduler)
|
| 426 |
+
save_path = os.path.join(args.output_dir, f"checkpoint")
|
| 427 |
+
|
| 428 |
+
#check if the checkpoint already exists
|
| 429 |
+
if os.path.exists(save_path):
|
| 430 |
+
accelerator.load_state(save_path)
|
| 431 |
+
logger.info(f"Loaded state from {save_path}")
|
| 432 |
+
|
| 433 |
+
# Train!
|
| 434 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 435 |
+
num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
|
| 436 |
+
|
| 437 |
+
logger.info("***** Running training *****")
|
| 438 |
+
logger.info(f" Num trainable parameters = {num_trainable_parameters}")
|
| 439 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 440 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 441 |
+
logger.info(f" Num epochs = {args.num_train_epochs}")
|
| 442 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 443 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 444 |
+
logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
|
| 445 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 446 |
+
global_step = 0
|
| 447 |
+
first_epoch = 0
|
| 448 |
+
initial_global_step = 0
|
| 449 |
+
|
| 450 |
+
progress_bar = tqdm(
|
| 451 |
+
range(0, args.max_train_steps),
|
| 452 |
+
initial=initial_global_step,
|
| 453 |
+
desc="Steps",
|
| 454 |
+
# Only show the progress bar once on each machine.
|
| 455 |
+
disable=not accelerator.is_local_main_process,
|
| 456 |
+
)
|
| 457 |
+
vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
|
| 458 |
+
|
| 459 |
+
# For DeepSpeed training
|
| 460 |
+
model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
|
| 461 |
+
|
| 462 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 463 |
+
transformer.train()
|
| 464 |
+
for step, batch in enumerate(train_dataloader):
|
| 465 |
+
if not args.just_validate:
|
| 466 |
+
models_to_accumulate = [transformer]
|
| 467 |
+
with accelerator.accumulate(models_to_accumulate):
|
| 468 |
+
model_input = encode_video(batch["videos"]).to(dtype=weight_dtype) # [B, F, C, H, W]
|
| 469 |
+
prompts = batch["prompts"]
|
| 470 |
+
image_latent = encode_video(batch["blur_img"]).to(dtype=weight_dtype) # [B, F, C, H, W]
|
| 471 |
+
input_intervals = batch["input_intervals"]
|
| 472 |
+
output_intervals = batch["output_intervals"]
|
| 473 |
+
|
| 474 |
+
batch_size = len(prompts)
|
| 475 |
+
# True = use real prompt (conditional); False = drop to empty (unconditional)
|
| 476 |
+
guidance_mask = torch.rand(batch_size, device=accelerator.device) >= 0.2
|
| 477 |
+
|
| 478 |
+
# build a new prompts list: keep the original where mask True, else blank
|
| 479 |
+
per_sample_prompts = [
|
| 480 |
+
prompts[i] if guidance_mask[i] else ""
|
| 481 |
+
for i in range(batch_size)
|
| 482 |
+
]
|
| 483 |
+
prompts = per_sample_prompts
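# Standalone sketch (illustration only, hypothetical prompts) of the ~20% prompt
# dropout above: training some samples against an empty caption is what makes
# classifier-free guidance with the validation-time guidance_scale possible.
#   example_prompts = ["a blurry street scene", "a waving hand", "a spinning wheel"]
#   keep = torch.rand(len(example_prompts)) >= 0.2        # True for ~80% of samples
#   dropped = [p if k else "" for p, k in zip(example_prompts, keep.tolist())]
#   # `dropped` mixes real captions (conditional) and "" (unconditional)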
|
| 484 |
+
|
| 485 |
+
# encode prompts
|
| 486 |
+
prompt_embeds = compute_prompt_embeddings(
|
| 487 |
+
tokenizer,
|
| 488 |
+
text_encoder,
|
| 489 |
+
prompts,
|
| 490 |
+
model_config.max_text_seq_length,
|
| 491 |
+
accelerator.device,
|
| 492 |
+
weight_dtype,
|
| 493 |
+
requires_grad=False,
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
# Sample noise that will be added to the latents
|
| 497 |
+
noise = torch.randn_like(model_input)
|
| 498 |
+
batch_size, num_frames, num_channels, height, width = model_input.shape
|
| 499 |
+
|
| 500 |
+
# Sample a random timestep for each image
|
| 501 |
+
timesteps = torch.randint(
|
| 502 |
+
0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device
|
| 503 |
+
)
|
| 504 |
+
timesteps = timesteps.long()
|
| 505 |
+
|
| 506 |
+
# Prepare rotary embeds
|
| 507 |
+
image_rotary_emb = (
|
| 508 |
+
prepare_rotary_positional_embeddings(
|
| 509 |
+
height=args.height,
|
| 510 |
+
width=args.width,
|
| 511 |
+
num_frames=num_frames,
|
| 512 |
+
vae_scale_factor_spatial=vae_scale_factor_spatial,
|
| 513 |
+
patch_size=model_config.patch_size,
|
| 514 |
+
attention_head_dim=model_config.attention_head_dim,
|
| 515 |
+
device=accelerator.device,
|
| 516 |
+
)
|
| 517 |
+
if model_config.use_rotary_positional_embeddings
|
| 518 |
+
else None
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
# Add noise to the model input according to the noise magnitude at each timestep (this is the forward diffusion process)
|
| 522 |
+
noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
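# Sketch of what a DDPM-style add_noise computes for the line above (stated in terms
# of the scheduler's alphas_cumprod, which is also indexed a few lines below):
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
#   def add_noise_sketch(x0, eps, t, alphas_cumprod):
#       a = alphas_cumprod[t] ** 0.5
#       b = (1.0 - alphas_cumprod[t]) ** 0.5
#       while a.dim() < x0.dim():        # broadcast per-sample scalars over [B, F, C, H, W]
#           a, b = a.unsqueeze(-1), b.unsqueeze(-1)
#       return a * x0 + b * eps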
|
| 523 |
+
|
| 524 |
+
input_intervals = transform_intervals(input_intervals, frames_per_latent=4)
|
| 525 |
+
output_intervals = transform_intervals(output_intervals, frames_per_latent=4)
|
| 526 |
+
|
| 527 |
+
#first interval is always rep
|
| 528 |
+
noisy_model_input, target, condition_mask, intervals = random_insert_latent_frame(image_latent, noisy_model_input, model_input, input_intervals, output_intervals, special_info=args.special_info)
|
| 529 |
+
|
| 530 |
+
for i in range(batch_size):
|
| 531 |
+
if not guidance_mask[i]:
|
| 532 |
+
noisy_model_input[i][condition_mask[i]] = 0
|
| 533 |
+
|
| 534 |
+
# Predict the noise residual
|
| 535 |
+
model_output = transformer(
|
| 536 |
+
hidden_states=noisy_model_input,
|
| 537 |
+
encoder_hidden_states=prompt_embeds,
|
| 538 |
+
intervals=intervals,
|
| 539 |
+
condition_mask=condition_mask,
|
| 540 |
+
timestep=timesteps,
|
| 541 |
+
image_rotary_emb=image_rotary_emb,
|
| 542 |
+
return_dict=False,
|
| 543 |
+
)[0]
|
| 544 |
+
|
| 545 |
+
# The line below also scales the input, which is undesirable - the model also ends up learning to scale this input latent;
|
| 546 |
+
# as a result, the first frame needs to be replaced with the original frame later.
|
| 547 |
+
model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)
|
| 548 |
+
|
| 549 |
+
alphas_cumprod = scheduler.alphas_cumprod[timesteps]
|
| 550 |
+
weights = 1 / (1 - alphas_cumprod)
|
| 551 |
+
while len(weights.shape) < len(model_pred.shape):
|
| 552 |
+
weights = weights.unsqueeze(-1)
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
loss = torch.mean((weights * (model_pred[~condition_mask] - target[~condition_mask]) ** 2).reshape(batch_size, -1), dim=1)
|
| 557 |
+
loss = loss.mean()
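# Worked check of the objective above (a sketch using the standard v-parameterization;
# only scheduler.alphas_cumprod is assumed, as indexed a few lines earlier):
#   x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps            (forward process)
#   v   = sqrt(abar_t) * eps - sqrt(1 - abar_t) * x_0            (velocity target)
#   get_velocity(v_hat, x_t, t) = sqrt(abar_t) * x_t - sqrt(1 - abar_t) * v_hat = x0_hat
# so model_pred is a predicted clean latent, compared against the clean target
# returned by random_insert_latent_frame. Since
#   x0_hat - x_0 = sqrt(1 - abar_t) * (v - v_hat),
# weighting the squared error by 1 / (1 - abar_t) recovers a plain v-space MSE,
# evaluated only on the non-conditioning positions (~condition_mask).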
|
| 558 |
+
accelerator.backward(loss)
|
| 559 |
+
|
| 560 |
+
if accelerator.state.deepspeed_plugin is None:
|
| 561 |
+
if not args.just_validate:
|
| 562 |
+
optimizer.step()
|
| 563 |
+
optimizer.zero_grad()
|
| 564 |
+
lr_scheduler.step()
|
| 565 |
+
|
| 566 |
+
#wait for all processes to finish
|
| 567 |
+
accelerator.wait_for_everyone()
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 571 |
+
if accelerator.sync_gradients:
|
| 572 |
+
progress_bar.update(1)
|
| 573 |
+
global_step += 1
|
| 574 |
+
|
| 575 |
+
if signal_recieved_time != 0:
|
| 576 |
+
if time.time() - signal_recieved_time > 60:
|
| 577 |
+
print("Signal received, saving state and exiting")
|
| 578 |
+
atomic_save(save_path, accelerator)
|
| 579 |
+
signal_recieved_time = 0
|
| 580 |
+
exit(0)
|
| 581 |
+
else:
|
| 582 |
+
exit(0)
|
| 583 |
+
|
| 584 |
+
if accelerator.is_main_process:
|
| 585 |
+
if global_step % args.checkpointing_steps == 0:
|
| 586 |
+
atomic_save(save_path, accelerator)
|
| 587 |
+
logger.info(f"Saved state to {save_path}")
|
| 588 |
+
|
| 589 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 590 |
+
progress_bar.set_postfix(**logs)
|
| 591 |
+
accelerator.log(logs, step=global_step)
|
| 592 |
+
|
| 593 |
+
if global_step >= args.max_train_steps:
|
| 594 |
+
break
|
| 595 |
+
|
| 596 |
+
print("Step", step)
|
| 597 |
+
accelerator.wait_for_everyone()
|
| 598 |
+
|
| 599 |
+
if step == 0 or (args.validation_prompt is not None and (step + 1) % args.validation_steps == 0):
|
| 600 |
+
# Create pipeline
|
| 601 |
+
pipe = ControlnetCogVideoXPipeline.from_pretrained(
|
| 602 |
+
os.path.join(args.base_dir, args.pretrained_model_name_or_path),
|
| 603 |
+
transformer=unwrap_model(transformer),
|
| 604 |
+
text_encoder=unwrap_model(text_encoder),
|
| 605 |
+
vae=unwrap_model(vae),
|
| 606 |
+
scheduler=scheduler,
|
| 607 |
+
torch_dtype=weight_dtype,
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
print("Length of validation dataset: ", len(val_dataloader))
|
| 611 |
+
#create a pipeline per accelerator device (for faster inference)
|
| 612 |
+
with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
|
| 613 |
+
for batch in val_dataloader:
|
| 614 |
+
frame = ((batch["blur_img"][0].permute(0,2,3,1).cpu().numpy() + 1)*127.5).astype(np.uint8)
|
| 615 |
+
pipeline_args = {
|
| 616 |
+
"prompt": "",
|
| 617 |
+
"negative_prompt": "",
|
| 618 |
+
"image": frame,
|
| 619 |
+
"input_intervals": batch["input_intervals"][0:1],
|
| 620 |
+
"output_intervals": batch["output_intervals"][0:1],
|
| 621 |
+
"guidance_scale": args.guidance_scale,
|
| 622 |
+
"use_dynamic_cfg": args.use_dynamic_cfg,
|
| 623 |
+
"height": args.height,
|
| 624 |
+
"width": args.width,
|
| 625 |
+
"num_frames": args.max_num_frames,
|
| 626 |
+
"num_inference_steps": args.num_inference_steps,
|
| 627 |
+
}
|
| 628 |
+
|
| 629 |
+
modified_filenames = []
|
| 630 |
+
filenames = batch['file_names']
|
| 631 |
+
for file in filenames:
|
| 632 |
+
modified_filenames.append(os.path.splitext(file)[0] + ".mp4")
|
| 633 |
+
|
| 634 |
+
num_frames = batch["num_frames"][0]
|
| 635 |
+
#save the gt_video output
|
| 636 |
+
if args.dataset not in ["outsidephotos"]:
|
| 637 |
+
gt_video = batch["videos"][0].permute(0,2,3,1).cpu().numpy()
|
| 638 |
+
gt_video = ((gt_video + 1) * 127.5)/255
|
| 639 |
+
gt_video = gt_video[0:num_frames]
|
| 640 |
+
|
| 641 |
+
for file in modified_filenames:
|
| 642 |
+
gt_file_name = os.path.join(args.output_dir, "gt", file)
|
| 643 |
+
os.makedirs(os.path.dirname(gt_file_name), exist_ok=True)
|
| 644 |
+
if args.dataset == "baist":
|
| 645 |
+
bbox = batch["bbx"][0].cpu().numpy().astype(np.int32)
|
| 646 |
+
gt_video = gt_video[:, bbox[1]:bbox[3], bbox[0]:bbox[2], :]
|
| 647 |
+
gt_video = np.array([cv2.resize(frame, (160, 192)) for frame in gt_video]) #resize to 192x160
|
| 648 |
+
|
| 649 |
+
save_frames_as_pngs((gt_video*255).astype(np.uint8), gt_file_name.replace(".mp4", "").replace("gt", "gt_frames"))
|
| 650 |
+
export_to_video(gt_video, gt_file_name, fps=20)
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
if "high_fps_video" in batch:
|
| 654 |
+
high_fps_video = batch["high_fps_video"][0].permute(0,2,3,1).cpu().numpy()
|
| 655 |
+
high_fps_video = ((high_fps_video + 1) * 127.5)/255
|
| 656 |
+
gt_file_name = os.path.join(args.output_dir, "gt_highfps", modified_filenames[0])
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
#save the blurred image
|
| 660 |
+
if args.dataset in ["full", "outsidephotos", "gopro2x", "baist"]:
|
| 661 |
+
for file in modified_filenames:
|
| 662 |
+
blurry_file_name = os.path.join(args.output_dir, "blurry", file.replace(".mp4", ".png"))
|
| 663 |
+
os.makedirs(os.path.dirname(blurry_file_name), exist_ok=True)
|
| 664 |
+
if args.dataset == "baist":
|
| 665 |
+
bbox = batch["bbx"][0].cpu().numpy().astype(np.int32)
|
| 666 |
+
frame0 = frame[0][bbox[1]:bbox[3], bbox[0]:bbox[2], :]
|
| 667 |
+
frame0 = cv2.resize(frame0, (160, 192)) #resize to 192x160
|
| 668 |
+
Image.fromarray(frame0).save(blurry_file_name)
|
| 669 |
+
else:
|
| 670 |
+
Image.fromarray(frame[0]).save(blurry_file_name)
|
| 671 |
+
|
| 672 |
+
videos = log_validation(
|
| 673 |
+
pipe=pipe,
|
| 674 |
+
args=args,
|
| 675 |
+
accelerator=accelerator,
|
| 676 |
+
pipeline_args=pipeline_args
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
#save the output video frames as pngs (uncompressed results) and mp4 (compressed results easily viewable)
|
| 680 |
+
for i, video in enumerate(videos):
|
| 681 |
+
video = video[0:num_frames]
|
| 682 |
+
filename = os.path.join(args.output_dir, "deblurred", modified_filenames[0])
|
| 683 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
| 684 |
+
if args.dataset == "baist":
|
| 685 |
+
bbox = batch["bbx"][0].cpu().numpy().astype(np.int32)
|
| 686 |
+
video = video[:, bbox[1]:bbox[3], bbox[0]:bbox[2], :]
|
| 687 |
+
video = np.array([cv2.resize(frame, (160, 192)) for frame in video]) #resize to 192x160
|
| 688 |
+
save_frames_as_pngs((video*255).astype(np.uint8), filename.replace(".mp4", "").replace("deblurred", "deblurred_frames"))
|
| 689 |
+
export_to_video(video, filename, fps=20)
|
| 690 |
+
accelerator.wait_for_everyone()
|
| 691 |
+
|
| 692 |
+
if args.just_validate:
|
| 693 |
+
exit(0)
|
| 694 |
+
|
| 695 |
+
accelerator.wait_for_everyone()
|
| 696 |
+
accelerator.end_training()
|
| 697 |
+
|
| 698 |
+
signal_recieved_time = 0
|
| 699 |
+
|
| 700 |
+
def handle_signal(signum, frame):
|
| 701 |
+
global signal_recieved_time
|
| 702 |
+
signal_recieved_time = time.time()
|
| 703 |
+
|
| 704 |
+
print(f"Signal {signum} received at {time.ctime()}")
|
| 705 |
+
|
| 706 |
+
with open("/datasets/sai/gencam/cogvideox/interrupted.txt", "w") as f:
|
| 707 |
+
f.write(f"Training was interrupted at {time.ctime()}")
|
| 708 |
+
|
| 709 |
+
if __name__ == "__main__":
|
| 710 |
+
|
| 711 |
+
args = get_args()
|
| 712 |
+
|
| 713 |
+
print("Registering signal handler")
|
| 714 |
+
#Register the signal handler (catch SIGUSR1)
|
| 715 |
+
signal.signal(signal.SIGUSR1, handle_signal)
|
| 716 |
+
|
| 717 |
+
main_thread = threading.Thread(target=main, args=(args,))
|
| 718 |
+
main_thread.start()
|
| 719 |
+
|
| 720 |
+
while main_thread.is_alive(): # keep the main thread alive so the SIGUSR1 handler can run (Python delivers signals to the main thread)
|
| 721 |
+
time.sleep(1)
|
| 722 |
+
|
| 723 |
+
#call main with args as a thread
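# Standalone sketch (illustration only, not executed here) of why main() runs in a
# worker thread: CPython always delivers signals such as SIGUSR1 to the main thread,
# so the main thread is kept free to notice the cluster's preemption signal.
#   import os, signal, threading, time
#   received = []
#   signal.signal(signal.SIGUSR1, lambda signum, frame: received.append(signum))
#   worker = threading.Thread(target=lambda: time.sleep(0.5))
#   worker.start()
#   os.kill(os.getpid(), signal.SIGUSR1)   # simulate the preemption signal
#   time.sleep(0.1)                        # handler runs on the main thread
#   worker.join()
#   print(received)                        # -> [10] on Linux (SIGUSR1)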
|
| 724 |
+
|
training/train_controlnet_backup.py
ADDED
|
@@ -0,0 +1,1235 @@
|
| 1 |
+
# Copyright 2024 The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import random
|
| 17 |
+
import signal
|
| 18 |
+
import sys
|
| 19 |
+
import threading
|
| 20 |
+
import time
|
| 21 |
+
|
| 22 |
+
import cv2
|
| 23 |
+
import yaml
|
| 24 |
+
|
| 25 |
+
sys.path.append('..')
|
| 26 |
+
import argparse
|
| 27 |
+
from PIL import Image
|
| 28 |
+
import logging
|
| 29 |
+
import math
|
| 30 |
+
import os
|
| 31 |
+
import shutil
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
from typing import List, Optional, Tuple, Union
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
import transformers
|
| 37 |
+
from accelerate import Accelerator
|
| 38 |
+
from accelerate.logging import get_logger
|
| 39 |
+
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
| 40 |
+
from huggingface_hub import create_repo, upload_folder
|
| 41 |
+
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
|
| 42 |
+
from torch.utils.data import DataLoader, Dataset
|
| 43 |
+
from torchvision import transforms
|
| 44 |
+
from tqdm.auto import tqdm
|
| 45 |
+
import numpy as np
|
| 46 |
+
from decord import VideoReader
|
| 47 |
+
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
|
| 48 |
+
|
| 49 |
+
import diffusers
|
| 50 |
+
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler
|
| 51 |
+
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
| 52 |
+
from diffusers.optimization import get_scheduler
|
| 53 |
+
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
|
| 54 |
+
from diffusers.training_utils import (
|
| 55 |
+
cast_training_params,
|
| 56 |
+
free_memory,
|
| 57 |
+
)
|
| 58 |
+
from diffusers.utils import check_min_version, export_to_video, is_wandb_available
|
| 59 |
+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
| 60 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
| 61 |
+
|
| 62 |
+
from controlnet_datasets import AblationFullMotionBlurDataset, AdobeMotionBlurDataset, FullMotionBlurDataset, GoPro2xMotionBlurDataset, GoProLargeMotionBlurDataset, OutsidePhotosDataset, GoProMotionBlurDataset, BAISTDataset, SimpleBAISTDataset
|
| 63 |
+
from controlnet_pipeline import ControlnetCogVideoXPipeline
|
| 64 |
+
from cogvideo_transformer import CogVideoXTransformer3DModel
|
| 65 |
+
from helpers import random_insert_latent_frame, transform_intervals
|
| 66 |
+
import os
|
| 67 |
+
import tempfile
|
| 68 |
+
from atomicwrites import atomic_write
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
if is_wandb_available():
|
| 73 |
+
import wandb
|
| 74 |
+
|
| 75 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 76 |
+
check_min_version("0.31.0.dev0")
|
| 77 |
+
|
| 78 |
+
logger = get_logger(__name__)
|
| 79 |
+
|
| 80 |
+
def save_frames_as_pngs(video_array, output_dir,
|
| 81 |
+
downsample_spatial=1, # e.g. 2 to halve width & height
|
| 82 |
+
downsample_temporal=1): # e.g. 2 to keep every 2nd frame
|
| 83 |
+
"""
|
| 84 |
+
Save each frame of a (T, H, W, C) numpy array as a PNG with no compression.
|
| 85 |
+
"""
|
| 86 |
+
assert video_array.ndim == 4 and video_array.shape[-1] == 3, \
|
| 87 |
+
"Expected (T, H, W, C=3) array"
|
| 88 |
+
assert video_array.dtype == np.uint8, "Expected uint8 array"
|
| 89 |
+
|
| 90 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 91 |
+
|
| 92 |
+
# temporal downsample
|
| 93 |
+
frames = video_array[::downsample_temporal]
|
| 94 |
+
|
| 95 |
+
# compute spatially downsampled size
|
| 96 |
+
T, H, W, _ = frames.shape
|
| 97 |
+
new_size = (W // downsample_spatial, H // downsample_spatial)
|
| 98 |
+
|
| 99 |
+
# PNG compression param: 0 = no compression
|
| 100 |
+
png_params = [cv2.IMWRITE_PNG_COMPRESSION, 0]
|
| 101 |
+
|
| 102 |
+
for idx, frame in enumerate(frames):
|
| 103 |
+
# frame is RGB; convert to BGR for OpenCV
|
| 104 |
+
bgr = frame[..., ::-1]
|
| 105 |
+
if downsample_spatial > 1:
|
| 106 |
+
bgr = cv2.resize(bgr, new_size, interpolation=cv2.INTER_NEAREST)
|
| 107 |
+
|
| 108 |
+
filename = os.path.join(output_dir, "frame_{:05d}.png".format(idx))
|
| 109 |
+
success = cv2.imwrite(filename, bgr, png_params)
|
| 110 |
+
if not success:
|
| 111 |
+
raise RuntimeError("Failed to write frame ")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def get_args():
|
| 115 |
+
parser = argparse.ArgumentParser(description="Training script for CogVideoX using config file.")
|
| 116 |
+
parser.add_argument(
|
| 117 |
+
"--config",
|
| 118 |
+
type=str,
|
| 119 |
+
required=True,
|
| 120 |
+
help="Path to the YAML config file."
|
| 121 |
+
)
|
| 122 |
+
args = parser.parse_args()
|
| 123 |
+
|
| 124 |
+
with open(args.config, "r") as f:
|
| 125 |
+
config = yaml.safe_load(f)
|
| 126 |
+
|
| 127 |
+
args = argparse.Namespace(**config)
|
| 128 |
+
|
| 129 |
+
# Convert nested config dict to an argparse.Namespace for easier downstream usage
|
| 130 |
+
return args
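# Usage sketch (hypothetical path; the YAML is expected to be a flat mapping of option names):
#   python train_controlnet_backup.py --config path/to/config.yaml
# with entries such as
#   output_dir: ./outputs/run1
#   train_batch_size: 1
# becoming attributes on args (args.output_dir, args.train_batch_size, ...).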
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# def read_video(video_path, start_index=0, frames_count=49, stride=1):
|
| 134 |
+
# video_reader = VideoReader(video_path)
|
| 135 |
+
# end_index = min(start_index + frames_count * stride, len(video_reader)) - 1
|
| 136 |
+
# batch_index = np.linspace(start_index, end_index, frames_count, dtype=int)
|
| 137 |
+
# numpy_video = video_reader.get_batch(batch_index).asnumpy()
|
| 138 |
+
# return numpy_video
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def log_validation(
|
| 142 |
+
pipe,
|
| 143 |
+
args,
|
| 144 |
+
accelerator,
|
| 145 |
+
pipeline_args,
|
| 146 |
+
epoch,
|
| 147 |
+
is_final_validation: bool = False,
|
| 148 |
+
):
|
| 149 |
+
logger.info(
|
| 150 |
+
f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
|
| 151 |
+
)
|
| 152 |
+
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
| 153 |
+
scheduler_args = {}
|
| 154 |
+
|
| 155 |
+
if "variance_type" in pipe.scheduler.config:
|
| 156 |
+
variance_type = pipe.scheduler.config.variance_type
|
| 157 |
+
|
| 158 |
+
if variance_type in ["learned", "learned_range"]:
|
| 159 |
+
variance_type = "fixed_small"
|
| 160 |
+
|
| 161 |
+
scheduler_args["variance_type"] = variance_type
|
| 162 |
+
|
| 163 |
+
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
|
| 164 |
+
pipe = pipe.to(accelerator.device)
|
| 165 |
+
# pipe.set_progress_bar_config(disable=True)
|
| 166 |
+
|
| 167 |
+
# run inference
|
| 168 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
| 169 |
+
|
| 170 |
+
videos = []
|
| 171 |
+
for _ in range(args.num_validation_videos):
|
| 172 |
+
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
|
| 173 |
+
videos.append(video)
|
| 174 |
+
|
| 175 |
+
free_memory()
|
| 176 |
+
|
| 177 |
+
return videos
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _get_t5_prompt_embeds(
|
| 181 |
+
tokenizer: T5Tokenizer,
|
| 182 |
+
text_encoder: T5EncoderModel,
|
| 183 |
+
prompt: Union[str, List[str]],
|
| 184 |
+
num_videos_per_prompt: int = 1,
|
| 185 |
+
max_sequence_length: int = 226,
|
| 186 |
+
device: Optional[torch.device] = None,
|
| 187 |
+
dtype: Optional[torch.dtype] = None,
|
| 188 |
+
text_input_ids=None,
|
| 189 |
+
):
|
| 190 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 191 |
+
batch_size = len(prompt)
|
| 192 |
+
|
| 193 |
+
if tokenizer is not None:
|
| 194 |
+
text_inputs = tokenizer(
|
| 195 |
+
prompt,
|
| 196 |
+
padding="max_length",
|
| 197 |
+
max_length=max_sequence_length,
|
| 198 |
+
truncation=True,
|
| 199 |
+
add_special_tokens=True,
|
| 200 |
+
return_tensors="pt",
|
| 201 |
+
)
|
| 202 |
+
text_input_ids = text_inputs.input_ids
|
| 203 |
+
else:
|
| 204 |
+
if text_input_ids is None:
|
| 205 |
+
raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
|
| 206 |
+
|
| 207 |
+
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
| 208 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 209 |
+
|
| 210 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 211 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 212 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 213 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 214 |
+
|
| 215 |
+
return prompt_embeds
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def encode_prompt(
|
| 219 |
+
tokenizer: T5Tokenizer,
|
| 220 |
+
text_encoder: T5EncoderModel,
|
| 221 |
+
prompt: Union[str, List[str]],
|
| 222 |
+
num_videos_per_prompt: int = 1,
|
| 223 |
+
max_sequence_length: int = 226,
|
| 224 |
+
device: Optional[torch.device] = None,
|
| 225 |
+
dtype: Optional[torch.dtype] = None,
|
| 226 |
+
text_input_ids=None,
|
| 227 |
+
):
|
| 228 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 229 |
+
prompt_embeds = _get_t5_prompt_embeds(
|
| 230 |
+
tokenizer,
|
| 231 |
+
text_encoder,
|
| 232 |
+
prompt=prompt,
|
| 233 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 234 |
+
max_sequence_length=max_sequence_length,
|
| 235 |
+
device=device,
|
| 236 |
+
dtype=dtype,
|
| 237 |
+
text_input_ids=text_input_ids,
|
| 238 |
+
)
|
| 239 |
+
return prompt_embeds
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def compute_prompt_embeddings(
|
| 243 |
+
tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
|
| 244 |
+
):
|
| 245 |
+
if requires_grad:
|
| 246 |
+
prompt_embeds = encode_prompt(
|
| 247 |
+
tokenizer,
|
| 248 |
+
text_encoder,
|
| 249 |
+
prompt,
|
| 250 |
+
num_videos_per_prompt=1,
|
| 251 |
+
max_sequence_length=max_sequence_length,
|
| 252 |
+
device=device,
|
| 253 |
+
dtype=dtype,
|
| 254 |
+
)
|
| 255 |
+
else:
|
| 256 |
+
with torch.no_grad():
|
| 257 |
+
prompt_embeds = encode_prompt(
|
| 258 |
+
tokenizer,
|
| 259 |
+
text_encoder,
|
| 260 |
+
prompt,
|
| 261 |
+
num_videos_per_prompt=1,
|
| 262 |
+
max_sequence_length=max_sequence_length,
|
| 263 |
+
device=device,
|
| 264 |
+
dtype=dtype,
|
| 265 |
+
)
|
| 266 |
+
return prompt_embeds
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def prepare_rotary_positional_embeddings(
|
| 270 |
+
height: int,
|
| 271 |
+
width: int,
|
| 272 |
+
num_frames: int,
|
| 273 |
+
vae_scale_factor_spatial: int = 8,
|
| 274 |
+
patch_size: int = 2,
|
| 275 |
+
attention_head_dim: int = 64,
|
| 276 |
+
device: Optional[torch.device] = None,
|
| 277 |
+
base_height: int = 480,
|
| 278 |
+
base_width: int = 720,
|
| 279 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 280 |
+
grid_height = height // (vae_scale_factor_spatial * patch_size)
|
| 281 |
+
grid_width = width // (vae_scale_factor_spatial * patch_size)
|
| 282 |
+
base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
|
| 283 |
+
base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
|
| 284 |
+
|
| 285 |
+
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
|
| 286 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
| 287 |
+
embed_dim=attention_head_dim,
|
| 288 |
+
crops_coords=grid_crops_coords,
|
| 289 |
+
grid_size=(grid_height, grid_width),
|
| 290 |
+
temporal_size=num_frames,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
freqs_cos = freqs_cos.to(device=device)
|
| 294 |
+
freqs_sin = freqs_sin.to(device=device)
|
| 295 |
+
return freqs_cos, freqs_sin
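# Worked example for the grid above (a sketch using this function's 480x720 defaults,
# vae_scale_factor_spatial=8 and patch_size=2):
#   grid_height = 480 // (8 * 2) = 30
#   grid_width  = 720 // (8 * 2) = 45
# so every latent frame contributes a 30 x 45 grid of patch positions, and
# temporal_size is the number of latent frames (the VAE compresses time 4x,
# matching the frames_per_latent=4 used with transform_intervals in the training loop).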
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
|
| 299 |
+
# Use DeepSpeed optimizer
|
| 300 |
+
if use_deepspeed:
|
| 301 |
+
from accelerate.utils import DummyOptim
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
return DummyOptim(
|
| 305 |
+
params_to_optimize,
|
| 306 |
+
lr=args.learning_rate,
|
| 307 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 308 |
+
eps=args.adam_epsilon,
|
| 309 |
+
weight_decay=args.adam_weight_decay,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# Optimizer creation
|
| 313 |
+
supported_optimizers = ["adam", "adamw", "prodigy"]
|
| 314 |
+
if args.optimizer not in supported_optimizers:
|
| 315 |
+
logger.warning(
|
| 316 |
+
f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
|
| 317 |
+
)
|
| 318 |
+
args.optimizer = "adamw"
|
| 319 |
+
|
| 320 |
+
if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
|
| 321 |
+
logger.warning(
|
| 322 |
+
f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
|
| 323 |
+
f"set to {args.optimizer.lower()}"
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
if args.use_8bit_adam:
|
| 327 |
+
try:
|
| 328 |
+
import bitsandbytes as bnb
|
| 329 |
+
except ImportError:
|
| 330 |
+
raise ImportError(
|
| 331 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
if args.optimizer.lower() == "adamw":
|
| 335 |
+
optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
|
| 336 |
+
|
| 337 |
+
optimizer = optimizer_class(
|
| 338 |
+
params_to_optimize,
|
| 339 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 340 |
+
eps=args.adam_epsilon,
|
| 341 |
+
weight_decay=args.adam_weight_decay,
|
| 342 |
+
)
|
| 343 |
+
elif args.optimizer.lower() == "adam":
|
| 344 |
+
optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
optimizer = optimizer_class(
|
| 348 |
+
params_to_optimize,
|
| 349 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 350 |
+
eps=args.adam_epsilon,
|
| 351 |
+
weight_decay=args.adam_weight_decay,
|
| 352 |
+
)
|
| 353 |
+
elif args.optimizer.lower() == "prodigy":
|
| 354 |
+
try:
|
| 355 |
+
import prodigyopt
|
| 356 |
+
except ImportError:
|
| 357 |
+
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
|
| 358 |
+
|
| 359 |
+
optimizer_class = prodigyopt.Prodigy
|
| 360 |
+
|
| 361 |
+
if args.learning_rate <= 0.1:
|
| 362 |
+
logger.warning(
|
| 363 |
+
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
optimizer = optimizer_class(
|
| 367 |
+
params_to_optimize,
|
| 368 |
+
lr=args.learning_rate,
|
| 369 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 370 |
+
beta3=args.prodigy_beta3,
|
| 371 |
+
weight_decay=args.adam_weight_decay,
|
| 372 |
+
eps=args.adam_epsilon,
|
| 373 |
+
decouple=args.prodigy_decouple,
|
| 374 |
+
use_bias_correction=args.prodigy_use_bias_correction,
|
| 375 |
+
safeguard_warmup=args.prodigy_safeguard_warmup,
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
return optimizer
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def main(args):
|
| 382 |
+
global signal_recieved_time
|
| 383 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
| 384 |
+
raise ValueError(
|
| 385 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
| 386 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
| 390 |
+
# due to pytorch#99272, MPS does not yet support bfloat16.
|
| 391 |
+
raise ValueError(
|
| 392 |
+
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 396 |
+
|
| 397 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 398 |
+
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
| 399 |
+
accelerator = Accelerator(
|
| 400 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 401 |
+
mixed_precision=args.mixed_precision,
|
| 402 |
+
log_with=args.report_to,
|
| 403 |
+
project_config=accelerator_project_config,
|
| 404 |
+
kwargs_handlers=[kwargs],
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
# Disable AMP for MPS.
|
| 408 |
+
if torch.backends.mps.is_available():
|
| 409 |
+
accelerator.native_amp = False
|
| 410 |
+
|
| 411 |
+
if args.report_to == "wandb":
|
| 412 |
+
if not is_wandb_available():
|
| 413 |
+
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
| 414 |
+
|
| 415 |
+
# Make one log on every process with the configuration for debugging.
|
| 416 |
+
logging.basicConfig(
|
| 417 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 418 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 419 |
+
level=logging.INFO,
|
| 420 |
+
)
|
| 421 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 422 |
+
if accelerator.is_local_main_process:
|
| 423 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 424 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 425 |
+
else:
|
| 426 |
+
transformers.utils.logging.set_verbosity_error()
|
| 427 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 428 |
+
|
| 429 |
+
# If passed along, set the training seed now.
|
| 430 |
+
if args.seed is not None:
|
| 431 |
+
set_seed(args.seed)
|
| 432 |
+
|
| 433 |
+
# Handle the repository creation
|
| 434 |
+
if accelerator.is_main_process:
|
| 435 |
+
if args.output_dir is not None:
|
| 436 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 437 |
+
|
| 438 |
+
if args.push_to_hub:
|
| 439 |
+
repo_id = create_repo(
|
| 440 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
| 441 |
+
exist_ok=True,
|
| 442 |
+
).repo_id
|
| 443 |
+
|
| 444 |
+
# Prepare models and scheduler
|
| 445 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 446 |
+
os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="tokenizer", revision=args.revision
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
text_encoder = T5EncoderModel.from_pretrained(
|
| 450 |
+
os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="text_encoder", revision=args.revision
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
# CogVideoX-2b weights are stored in float16
|
| 454 |
+
# CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
|
| 455 |
+
|
| 456 |
+
## TRYING NEW CONFIG LOADING
|
| 457 |
+
config = CogVideoXTransformer3DModel.load_config(
|
| 458 |
+
os.path.join(args.base_dir, args.pretrained_model_name_or_path),
|
| 459 |
+
subfolder="transformer",
|
| 460 |
+
revision=args.revision,
|
| 461 |
+
variant=args.variant,
|
| 462 |
+
)
|
| 463 |
+
config["ablation_mode"] = args.ablation_mode if hasattr(args, "ablation_mode") else None
|
| 464 |
+
|
| 465 |
+
##FINISH TRYING NEW CONFIG LOADING
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
load_dtype = torch.bfloat16 if "5b" in os.path.join(args.base_dir, args.pretrained_model_name_or_path).lower() else torch.float16
|
| 470 |
+
transformer = CogVideoXTransformer3DModel.from_pretrained(
|
| 471 |
+
os.path.join(args.base_dir, args.pretrained_model_name_or_path),
|
| 472 |
+
subfolder="transformer",
|
| 473 |
+
torch_dtype=load_dtype,
|
| 474 |
+
ablation_mode=args.ablation_mode if hasattr(args, "ablation_mode") else None,
|
| 475 |
+
revision=args.revision,
|
| 476 |
+
variant=args.variant,
|
| 477 |
+
low_cpu_mem_usage=False,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
vae = AutoencoderKLCogVideoX.from_pretrained(
|
| 481 |
+
os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="vae", revision=args.revision, variant=args.variant
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
scheduler = CogVideoXDPMScheduler.from_pretrained(os.path.join(args.base_dir, args.pretrained_model_name_or_path), subfolder="scheduler")
|
| 488 |
+
|
| 489 |
+
if args.enable_slicing:
|
| 490 |
+
vae.enable_slicing()
|
| 491 |
+
if args.enable_tiling:
|
| 492 |
+
vae.enable_tiling()
|
| 493 |
+
|
| 494 |
+
# We only train the additional adapter controlnet layers
|
| 495 |
+
text_encoder.requires_grad_(False)
|
| 496 |
+
transformer.requires_grad_(True)
|
| 497 |
+
vae.requires_grad_(False)
|
| 498 |
+
|
| 499 |
+
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
|
| 500 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
| 501 |
+
weight_dtype = torch.float32
|
| 502 |
+
if accelerator.state.deepspeed_plugin:
|
| 503 |
+
# DeepSpeed is handling precision, use what's in the DeepSpeed config
|
| 504 |
+
if (
|
| 505 |
+
"fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
| 506 |
+
and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
|
| 507 |
+
):
|
| 508 |
+
weight_dtype = torch.float16
|
| 509 |
+
if (
|
| 510 |
+
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
| 511 |
+
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
| 512 |
+
):
|
| 513 |
+
weight_dtype = torch.bfloat16
|
| 514 |
+
else:
|
| 515 |
+
if accelerator.mixed_precision == "fp16":
|
| 516 |
+
weight_dtype = torch.float16
|
| 517 |
+
elif accelerator.mixed_precision == "bf16":
|
| 518 |
+
weight_dtype = torch.bfloat16
|
| 519 |
+
|
| 520 |
+
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
|
| 521 |
+
# due to pytorch#99272, MPS does not yet support bfloat16.
|
| 522 |
+
raise ValueError(
|
| 523 |
+
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 527 |
+
transformer.to(accelerator.device, dtype=weight_dtype)
|
| 528 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 529 |
+
|
| 530 |
+
if args.gradient_checkpointing:
|
| 531 |
+
transformer.enable_gradient_checkpointing()
|
| 532 |
+
|
| 533 |
+
def unwrap_model(model):
|
| 534 |
+
model = accelerator.unwrap_model(model)
|
| 535 |
+
model = model._orig_mod if is_compiled_module(model) else model
|
| 536 |
+
return model
|
| 537 |
+
|
| 538 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
| 539 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
| 540 |
+
if args.allow_tf32 and torch.cuda.is_available():
|
| 541 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 542 |
+
|
| 543 |
+
if args.scale_lr:
|
| 544 |
+
args.learning_rate = (
|
| 545 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
# Make sure the trainable params are in float32.
|
| 549 |
+
if args.mixed_precision == "fp16":
|
| 550 |
+
# only upcast trainable parameters into fp32
|
| 551 |
+
cast_training_params([transformer], dtype=torch.float32)
|
| 552 |
+
|
| 553 |
+
trainable_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
| 554 |
+
|
| 555 |
+
# Optimization parameters
|
| 556 |
+
trainable_parameters_with_lr = {"params": trainable_parameters, "lr": args.learning_rate}
|
| 557 |
+
params_to_optimize = [trainable_parameters_with_lr]
|
| 558 |
+
|
| 559 |
+
use_deepspeed_optimizer = (
|
| 560 |
+
accelerator.state.deepspeed_plugin is not None
|
| 561 |
+
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
|
| 562 |
+
)
|
| 563 |
+
use_deepspeed_scheduler = (
|
| 564 |
+
accelerator.state.deepspeed_plugin is not None
|
| 565 |
+
and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
|
| 569 |
+
|
| 570 |
+
# Dataset and DataLoader
|
| 571 |
+
if args.dataset == "adobe":
|
| 572 |
+
train_dataset = AdobeMotionBlurDataset(
|
| 573 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 574 |
+
split = "train",
|
| 575 |
+
image_size=(args.height, args.width),
|
| 576 |
+
stride=(args.stride_min, args.stride_max),
|
| 577 |
+
sample_n_frames=args.max_num_frames,
|
| 578 |
+
hflip_p=args.hflip_p,
|
| 579 |
+
)
|
| 580 |
+
elif args.dataset == "gopro":
|
| 581 |
+
train_dataset = GoProMotionBlurDataset(
|
| 582 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 583 |
+
split = "train",
|
| 584 |
+
image_size=(args.height, args.width),
|
| 585 |
+
stride=(args.stride_min, args.stride_max),
|
| 586 |
+
sample_n_frames=args.max_num_frames,
|
| 587 |
+
hflip_p=args.hflip_p,
|
| 588 |
+
)
|
| 589 |
+
elif args.dataset == "gopro2x":
|
| 590 |
+
train_dataset = GoPro2xMotionBlurDataset(
|
| 591 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 592 |
+
split = "train",
|
| 593 |
+
image_size=(args.height, args.width),
|
| 594 |
+
stride=(args.stride_min, args.stride_max),
|
| 595 |
+
sample_n_frames=args.max_num_frames,
|
| 596 |
+
hflip_p=args.hflip_p,
|
| 597 |
+
)
|
| 598 |
+
elif args.dataset == "goprolarge":
|
| 599 |
+
train_dataset = GoProLargeMotionBlurDataset(
|
| 600 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 601 |
+
split = "train",
|
| 602 |
+
image_size=(args.height, args.width),
|
| 603 |
+
stride=(args.stride_min, args.stride_max),
|
| 604 |
+
sample_n_frames=args.max_num_frames,
|
| 605 |
+
hflip_p=args.hflip_p,
|
| 606 |
+
)
|
| 607 |
+
elif args.dataset == "full":
|
| 608 |
+
train_dataset = FullMotionBlurDataset(
|
| 609 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 610 |
+
split = "train",
|
| 611 |
+
image_size=(args.height, args.width),
|
| 612 |
+
stride=(args.stride_min, args.stride_max),
|
| 613 |
+
sample_n_frames=args.max_num_frames,
|
| 614 |
+
hflip_p=args.hflip_p,
|
| 615 |
+
)
|
| 616 |
+
elif args.dataset == "fullablation":
|
| 617 |
+
train_dataset = AblationFullMotionBlurDataset(
|
| 618 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 619 |
+
split = "train",
|
| 620 |
+
image_size=(args.height, args.width),
|
| 621 |
+
stride=(args.stride_min, args.stride_max),
|
| 622 |
+
sample_n_frames=args.max_num_frames,
|
| 623 |
+
hflip_p=args.hflip_p,
|
| 624 |
+
ablation_mode = args.ablation_mode, #this is not called for now
|
| 625 |
+
)
|
| 626 |
+
elif args.dataset == "baist":
|
| 627 |
+
train_dataset = BAISTDataset(
|
| 628 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 629 |
+
split = "train",
|
| 630 |
+
image_size=(args.height, args.width),
|
| 631 |
+
stride=(args.stride_min, args.stride_max),
|
| 632 |
+
sample_n_frames=args.max_num_frames,
|
| 633 |
+
hflip_p=args.hflip_p,
|
| 634 |
+
) #this is not called for now
|
| 635 |
+
elif args.dataset == "simplebaist":
|
| 636 |
+
train_dataset = SimpleBAISTDataset(
|
| 637 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 638 |
+
split = "train",
|
| 639 |
+
image_size=(args.height, args.width),
|
| 640 |
+
stride=(args.stride_min, args.stride_max),
|
| 641 |
+
sample_n_frames=args.max_num_frames,
|
| 642 |
+
hflip_p=args.hflip_p,
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
if args.dataset == "adobe":
|
| 647 |
+
val_dataset = AdobeMotionBlurDataset(
|
| 648 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 649 |
+
split = args.val_split,
|
| 650 |
+
image_size=(args.height, args.width),
|
| 651 |
+
stride=(args.stride_min, args.stride_max),
|
| 652 |
+
sample_n_frames=args.max_num_frames,
|
| 653 |
+
hflip_p=args.hflip_p,
|
| 654 |
+
)
|
| 655 |
+
elif args.dataset == "outsidephotos":
|
| 656 |
+
|
| 657 |
+
val_dataset = OutsidePhotosDataset(
|
| 658 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 659 |
+
image_size=(args.height, args.width),
|
| 660 |
+
stride=(args.stride_min, args.stride_max),
|
| 661 |
+
sample_n_frames=args.max_num_frames,
|
| 662 |
+
hflip_p=args.hflip_p,
|
| 663 |
+
)
|
| 664 |
+
train_dataset = val_dataset #dummy dataset
|
| 665 |
+
elif args.dataset == "gopro":
|
| 666 |
+
val_dataset = GoProMotionBlurDataset(
|
| 667 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 668 |
+
split = args.val_split,
|
| 669 |
+
image_size=(args.height, args.width),
|
| 670 |
+
stride=(args.stride_min, args.stride_max),
|
| 671 |
+
sample_n_frames=args.max_num_frames,
|
| 672 |
+
hflip_p=args.hflip_p,
|
| 673 |
+
)
|
| 674 |
+
elif args.dataset == "gopro2x":
|
| 675 |
+
val_dataset = GoPro2xMotionBlurDataset(
|
| 676 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 677 |
+
split = args.val_split,
|
| 678 |
+
image_size=(args.height, args.width),
|
| 679 |
+
stride=(args.stride_min, args.stride_max),
|
| 680 |
+
sample_n_frames=args.max_num_frames,
|
| 681 |
+
hflip_p=args.hflip_p,
|
| 682 |
+
)
|
| 683 |
+
elif args.dataset == "goprolarge":
|
| 684 |
+
val_dataset = GoProLargeMotionBlurDataset(
|
| 685 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 686 |
+
split = args.val_split,
|
| 687 |
+
image_size=(args.height, args.width),
|
| 688 |
+
stride=(args.stride_min, args.stride_max),
|
| 689 |
+
sample_n_frames=args.max_num_frames,
|
| 690 |
+
hflip_p=args.hflip_p,
|
| 691 |
+
)
|
| 692 |
+
elif args.dataset == "full":
|
| 693 |
+
val_dataset = FullMotionBlurDataset(
|
| 694 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 695 |
+
split = args.val_split,
|
| 696 |
+
image_size=(args.height, args.width),
|
| 697 |
+
stride=(args.stride_min, args.stride_max),
|
| 698 |
+
sample_n_frames=args.max_num_frames,
|
| 699 |
+
hflip_p=args.hflip_p,
|
| 700 |
+
)
|
| 701 |
+
elif args.dataset == "fullablation":
|
| 702 |
+
val_dataset = AblationFullMotionBlurDataset(
|
| 703 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 704 |
+
split = args.val_split,
|
| 705 |
+
image_size=(args.height, args.width),
|
| 706 |
+
stride=(args.stride_min, args.stride_max),
|
| 707 |
+
sample_n_frames=args.max_num_frames,
|
| 708 |
+
hflip_p=args.hflip_p,
|
| 709 |
+
ablation_mode = args.ablation_mode, #this is not called for now
|
| 710 |
+
)
|
| 711 |
+
elif args.dataset == "baist":
|
| 712 |
+
val_dataset = BAISTDataset(
|
| 713 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 714 |
+
split = args.val_split,
|
| 715 |
+
image_size=(args.height, args.width),
|
| 716 |
+
stride=(args.stride_min, args.stride_max),
|
| 717 |
+
sample_n_frames=args.max_num_frames,
|
| 718 |
+
hflip_p=args.hflip_p,
|
| 719 |
+
)
|
| 720 |
+
elif args.dataset == "simplebaist":
|
| 721 |
+
val_dataset = SimpleBAISTDataset(
|
| 722 |
+
data_dir=os.path.join(args.base_dir, args.video_root_dir),
|
| 723 |
+
split = args.val_split,
|
| 724 |
+
image_size=(args.height, args.width),
|
| 725 |
+
stride=(args.stride_min, args.stride_max),
|
| 726 |
+
sample_n_frames=args.max_num_frames,
|
| 727 |
+
hflip_p=args.hflip_p,
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
def encode_video(video):
|
| 731 |
+
video = video.to(accelerator.device, dtype=vae.dtype)
|
| 732 |
+
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
| 733 |
+
latent_dist = vae.encode(video).latent_dist.sample() * vae.config.scaling_factor
|
| 734 |
+
return latent_dist.permute(0, 2, 1, 3, 4).to(memory_format=torch.contiguous_format)
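# Shape sketch (assuming the stock CogVideoX VAE: 8x spatial / 4x temporal
# compression, 16 latent channels):
#   input video : [B, F, 3, H, W]                       e.g. [1, 49, 3, 480, 720]
#   vae.encode  : [B, 16, (F - 1) // 4 + 1, H/8, W/8]   e.g. [1, 16, 13, 60, 90]
#   returned    : [B, 13, 16, 60, 90]                   back to frame-major [B, F, C, H, W]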
|
| 735 |
+
|
| 736 |
+
# def atomic_save(save_path, accelerator):
|
| 737 |
+
|
| 738 |
+
# dir_name = os.path.dirname(save_path)
|
| 739 |
+
# with tempfile.NamedTemporaryFile(delete=False, dir=dir_name) as tmp_file:
|
| 740 |
+
# tmp_path = tmp_file.name
|
| 741 |
+
# # Close the file so that it can be moved later
|
| 742 |
+
# #delete anything at the tmp_path
|
| 743 |
+
# if accelerator.is_main_process:
|
| 744 |
+
# accelerator.save_state(tmp_path) #just a backup incase things go crazy
|
| 745 |
+
# accelerator.save_state(save_path)
|
| 746 |
+
# os.remove(tmp_path)
|
| 747 |
+
# accelerator.wait_for_everyone()
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
def atomic_save(save_path, accelerator):
|
| 752 |
+
parent = os.path.dirname(save_path)
|
| 753 |
+
tmp_dir = tempfile.mkdtemp(dir=parent)
|
| 754 |
+
backup_dir = save_path + "_backup"
|
| 755 |
+
|
| 756 |
+
try:
|
| 757 |
+
# Save state into the temp directory
|
| 758 |
+
accelerator.save_state(tmp_dir)
|
| 759 |
+
|
| 760 |
+
# Backup existing save_path if it exists
|
| 761 |
+
if os.path.exists(save_path):
|
| 762 |
+
os.rename(save_path, backup_dir)
|
| 763 |
+
|
| 764 |
+
# Atomically move temp directory into place
|
| 765 |
+
os.rename(tmp_dir, save_path)
|
| 766 |
+
|
| 767 |
+
# Clean up the backup directory
|
| 768 |
+
if os.path.exists(backup_dir):
|
| 769 |
+
shutil.rmtree(backup_dir)
|
| 770 |
+
|
| 771 |
+
except Exception as e:
|
| 772 |
+
# Clean up temp directory on failure
|
| 773 |
+
if os.path.exists(tmp_dir):
|
| 774 |
+
shutil.rmtree(tmp_dir)
|
| 775 |
+
|
| 776 |
+
# Restore from backup if replacement failed
|
| 777 |
+
if os.path.exists(backup_dir):
|
| 778 |
+
if os.path.exists(save_path):
|
| 779 |
+
shutil.rmtree(save_path)
|
| 780 |
+
os.rename(backup_dir, save_path)
|
| 781 |
+
|
| 782 |
+
raise e
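# Design note (sketch): os.rename is atomic only within a single filesystem, which is
# why the temporary directory is created with dir=parent (same filesystem as
# save_path) instead of the system temp location. Typical call, mirroring the
# training loop below:
#   atomic_save(os.path.join(args.output_dir, "checkpoint"), accelerator)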
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
def collate_fn(examples):
|
| 787 |
+
blur_img = [example["blur_img"] for example in examples]
|
| 788 |
+
videos = [example["video"] for example in examples]
|
| 789 |
+
if "high_fps_video" in examples[0]:
|
| 790 |
+
high_fps_videos = [example["high_fps_video"] for example in examples]
|
| 791 |
+
high_fps_videos = torch.stack(high_fps_videos)
|
| 792 |
+
high_fps_videos = high_fps_videos.to(memory_format=torch.contiguous_format).float()
|
| 793 |
+
if "bbx" in examples[0]:
|
| 794 |
+
bbx = [example["bbx"] for example in examples]
|
| 795 |
+
bbx = torch.stack(bbx)
|
| 796 |
+
bbx = bbx.to(memory_format=torch.contiguous_format).float()
|
| 797 |
+
prompts = [example["caption"] for example in examples]
|
| 798 |
+
file_names = [example["file_name"] for example in examples]
|
| 799 |
+
num_frames = [example["num_frames"] for example in examples]
|
| 800 |
+
# if full_file_names in examples[0]:
|
| 801 |
+
# full_file_names = [example["full_file_name"] for example in examples]
|
| 802 |
+
input_intervals = [example["input_interval"] for example in examples]
|
| 803 |
+
output_intervals = [example["output_interval"] for example in examples]
|
| 804 |
+
ablation_condition = [example["ablation_condition"] for example in examples] if "ablation_condition" in examples[0] else None
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
videos = torch.stack(videos)
|
| 808 |
+
videos = videos.to(memory_format=torch.contiguous_format).float()
|
| 809 |
+
|
| 810 |
+
blur_img = torch.stack(blur_img)
|
| 811 |
+
blur_img = blur_img.to(memory_format=torch.contiguous_format).float()
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
input_intervals = torch.stack(input_intervals)
|
| 815 |
+
if args.dataset == "gopro":
|
| 816 |
+
input_intervals = input_intervals.to(memory_format=torch.contiguous_format).long() #this is a bug, but the model was trained like this on GoPro (it sets all intervals to 0); the model does not need intervals for this dataset because it is always 7 frames with the same spacing
|
| 817 |
+
else:
|
| 818 |
+
input_intervals = input_intervals.to(memory_format=torch.contiguous_format).float()
|
| 819 |
+
|
| 820 |
+
output_intervals = torch.stack(output_intervals)
|
| 821 |
+
if args.dataset == "gopro":
|
| 822 |
+
output_intervals = output_intervals.to(memory_format=torch.contiguous_format).long() #this is a bug, but the model was trained like this on GoPro (it sets all intervals to 0); the model does not need intervals for this dataset because it is always 7 frames with the same spacing
|
| 823 |
+
else:
|
| 824 |
+
output_intervals = output_intervals.to(memory_format=torch.contiguous_format).float()
|
| 825 |
+
|
| 826 |
+
#just used for ablation studies
|
| 827 |
+
ablation_condition = torch.stack(ablation_condition) if ablation_condition is not None else None
|
| 828 |
+
if ablation_condition is not None:
|
| 829 |
+
ablation_condition = ablation_condition.to(memory_format=torch.contiguous_format).float()
|
| 830 |
+
|
| 831 |
+
out_dict = {
|
| 832 |
+
"file_names": file_names,
|
| 833 |
+
"blur_img": blur_img,
|
| 834 |
+
"videos": videos,
|
| 835 |
+
"num_frames": num_frames,
|
| 836 |
+
"prompts": prompts,
|
| 837 |
+
"input_intervals": input_intervals,
|
| 838 |
+
"output_intervals": output_intervals,
|
| 839 |
+
}
|
| 840 |
+
|
| 841 |
+
if "high_fps_video" in examples[0]:
|
| 842 |
+
out_dict["high_fps_video"] = high_fps_videos
|
| 843 |
+
if "bbx" in examples[0]:
|
| 844 |
+
out_dict["bbx"] = bbx
|
| 845 |
+
if ablation_condition is not None:
|
| 846 |
+
out_dict["ablation_condition"] = ablation_condition
|
| 847 |
+
return out_dict
|
| 848 |
+
|
| 849 |
+
train_dataloader = DataLoader(
|
| 850 |
+
train_dataset,
|
| 851 |
+
batch_size=args.train_batch_size,
|
| 852 |
+
shuffle=True,
|
| 853 |
+
collate_fn=collate_fn,
|
| 854 |
+
num_workers=args.dataloader_num_workers,
|
| 855 |
+
)
|
| 856 |
+
|
| 857 |
+
val_dataloader = DataLoader(
|
| 858 |
+
val_dataset,
|
| 859 |
+
batch_size=1,
|
| 860 |
+
shuffle=False,
|
| 861 |
+
collate_fn=collate_fn,
|
| 862 |
+
num_workers=args.dataloader_num_workers,
|
| 863 |
+
)
|
| 864 |
+
|
| 865 |
+
# Scheduler and math around the number of training steps.
|
| 866 |
+
overrode_max_train_steps = False
|
| 867 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 868 |
+
if args.max_train_steps is None:
|
| 869 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 870 |
+
overrode_max_train_steps = True
|
| 871 |
+
|
| 872 |
+
if use_deepspeed_scheduler:
|
| 873 |
+
from accelerate.utils import DummyScheduler
|
| 874 |
+
|
| 875 |
+
lr_scheduler = DummyScheduler(
|
| 876 |
+
name=args.lr_scheduler,
|
| 877 |
+
optimizer=optimizer,
|
| 878 |
+
total_num_steps=args.max_train_steps * accelerator.num_processes,
|
| 879 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 880 |
+
)
|
| 881 |
+
else:
|
| 882 |
+
lr_scheduler = get_scheduler(
|
| 883 |
+
args.lr_scheduler,
|
| 884 |
+
optimizer=optimizer,
|
| 885 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 886 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 887 |
+
num_cycles=args.lr_num_cycles,
|
| 888 |
+
power=args.lr_power,
|
| 889 |
+
)
|
| 890 |
+
|
| 891 |
+
# Prepare everything with our `accelerator`.
|
| 892 |
+
transformer, optimizer, train_dataloader, lr_scheduler, val_dataloader = accelerator.prepare(
|
| 893 |
+
transformer, optimizer, train_dataloader, lr_scheduler, val_dataloader
|
| 894 |
+
)
|
| 895 |
+
|
| 896 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 897 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 898 |
+
if overrode_max_train_steps:
|
| 899 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 900 |
+
# Afterwards we recalculate our number of training epochs
|
| 901 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 902 |
+
|
| 903 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 904 |
+
# The trackers initialize automatically on the main process.
|
| 905 |
+
if accelerator.is_main_process:
|
| 906 |
+
tracker_name = args.tracker_name or "cogvideox-controlnet"
|
| 907 |
+
accelerator.init_trackers(tracker_name, config=vars(args))
|
| 908 |
+
|
| 909 |
+
|
| 910 |
+
accelerator.register_for_checkpointing(transformer, optimizer, lr_scheduler)
|
| 911 |
+
save_path = os.path.join(args.output_dir, "checkpoint")
|
| 912 |
+
|
| 913 |
+
# Resume from this checkpoint if one already exists
|
| 914 |
+
if os.path.exists(save_path):
|
| 915 |
+
accelerator.load_state(save_path)
|
| 916 |
+
logger.info(f"Loaded state from {save_path}")
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
|
| 920 |
+
# Train!
|
| 921 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
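# e.g. train_batch_size=1 on 8 processes with 4 accumulation steps -> effective batch size of 32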
|
| 922 |
+
num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
|
| 923 |
+
|
| 924 |
+
logger.info("***** Running training *****")
|
| 925 |
+
logger.info(f" Num trainable parameters = {num_trainable_parameters}")
|
| 926 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 927 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 928 |
+
logger.info(f" Num epochs = {args.num_train_epochs}")
|
| 929 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 930 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 931 |
+
logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
|
| 932 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 933 |
+
global_step = 0
|
| 934 |
+
first_epoch = 0
|
| 935 |
+
initial_global_step = 0
|
| 936 |
+
|
| 937 |
+
progress_bar = tqdm(
|
| 938 |
+
range(0, args.max_train_steps),
|
| 939 |
+
initial=initial_global_step,
|
| 940 |
+
desc="Steps",
|
| 941 |
+
# Only show the progress bar once on each machine.
|
| 942 |
+
disable=not accelerator.is_local_main_process,
|
| 943 |
+
)
|
| 944 |
+
vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
|
| 945 |
+
|
| 946 |
+
# For DeepSpeed training
|
| 947 |
+
model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
|
| 948 |
+
|
| 949 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 950 |
+
transformer.train()
|
| 951 |
+
for step, batch in enumerate(train_dataloader):
|
| 952 |
+
if not args.just_validate:
|
| 953 |
+
models_to_accumulate = [transformer]
|
| 954 |
+
with accelerator.accumulate(models_to_accumulate):
|
| 955 |
+
model_input = encode_video(batch["videos"]).to(dtype=weight_dtype) # [B, F, C, H, W]
|
| 956 |
+
prompts = batch["prompts"]
|
| 957 |
+
image_latent = encode_video(batch["blur_img"]).to(dtype=weight_dtype) # [B, F, C, H, W]
|
| 958 |
+
input_intervals = batch["input_intervals"]
|
| 959 |
+
output_intervals = batch["output_intervals"]
|
| 960 |
+
ablation_condition = batch["ablation_condition"] if "ablation_condition" in batch else None
|
| 961 |
+
|
| 962 |
+
batch_size = len(prompts)
|
| 963 |
+
# True = use real prompt (conditional); False = drop to empty (unconditional)
|
| 964 |
+
guidance_mask = torch.rand(batch_size, device=accelerator.device) >= 0.2
|
| 965 |
+
|
| 966 |
+
# build a new prompts list: keep the original where mask True, else blank
|
| 967 |
+
per_sample_prompts = [
|
| 968 |
+
prompts[i] if guidance_mask[i] else ""
|
| 969 |
+
for i in range(batch_size)
|
| 970 |
+
]
|
| 971 |
+
prompts = per_sample_prompts
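# Roughly 20% of samples are trained with an empty prompt (see guidance_mask above); this unconditional branch is what enables classifier-free guidance at inference time.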
|
| 972 |
+
|
| 973 |
+
# encode prompts
|
| 974 |
+
prompt_embeds = compute_prompt_embeddings(
|
| 975 |
+
tokenizer,
|
| 976 |
+
text_encoder,
|
| 977 |
+
prompts,
|
| 978 |
+
model_config.max_text_seq_length,
|
| 979 |
+
accelerator.device,
|
| 980 |
+
weight_dtype,
|
| 981 |
+
requires_grad=False,
|
| 982 |
+
)
|
| 983 |
+
|
| 984 |
+
# Sample noise that will be added to the latents
|
| 985 |
+
noise = torch.randn_like(model_input)
|
| 986 |
+
batch_size, num_frames, num_channels, height, width = model_input.shape
|
| 987 |
+
|
| 988 |
+
# Sample a random timestep for each image
|
| 989 |
+
timesteps = torch.randint(
|
| 990 |
+
0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device
|
| 991 |
+
)
|
| 992 |
+
timesteps = timesteps.long()
|
| 993 |
+
|
| 994 |
+
# Prepare rotary embeds
|
| 995 |
+
image_rotary_emb = (
|
| 996 |
+
prepare_rotary_positional_embeddings(
|
| 997 |
+
height=args.height,
|
| 998 |
+
width=args.width,
|
| 999 |
+
num_frames=num_frames,
|
| 1000 |
+
vae_scale_factor_spatial=vae_scale_factor_spatial,
|
| 1001 |
+
patch_size=model_config.patch_size,
|
| 1002 |
+
attention_head_dim=model_config.attention_head_dim,
|
| 1003 |
+
device=accelerator.device,
|
| 1004 |
+
)
|
| 1005 |
+
if model_config.use_rotary_positional_embeddings
|
| 1006 |
+
else None
|
| 1007 |
+
)
|
| 1008 |
+
|
| 1009 |
+
# Add noise to the model input according to the noise magnitude at each timestep (this is the forward diffusion process)
|
| 1010 |
+
noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
|
| 1011 |
+
|
| 1012 |
+
input_intervals = transform_intervals(input_intervals, frames_per_latent=4)
|
| 1013 |
+
output_intervals = transform_intervals(output_intervals, frames_per_latent=4)
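# frames_per_latent=4 matches the 4x temporal compression of CogVideoX's 3D VAE, so pixel-frame intervals are mapped onto latent-frame intervals here.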
|
| 1014 |
+
|
| 1015 |
+
#first interval is always rep
|
| 1016 |
+
noisy_model_input, target, condition_mask, intervals = random_insert_latent_frame(image_latent, noisy_model_input, model_input, input_intervals, output_intervals, special_info=args.special_info)
|
| 1017 |
+
|
| 1018 |
+
for i in range(batch_size):
|
| 1019 |
+
if not guidance_mask[i]:
|
| 1020 |
+
noisy_model_input[i][condition_mask[i]] = 0
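# For the prompt-dropped samples, the conditioning (blur) latents are zeroed as well, so the unconditional branch sees neither the text nor the image condition.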
|
| 1021 |
+
|
| 1022 |
+
# Predict the noise residual
|
| 1023 |
+
model_output = transformer(
|
| 1024 |
+
hidden_states=noisy_model_input,
|
| 1025 |
+
encoder_hidden_states=prompt_embeds,
|
| 1026 |
+
intervals=intervals,
|
| 1027 |
+
condition_mask=condition_mask,
|
| 1028 |
+
timestep=timesteps,
|
| 1029 |
+
image_rotary_emb=image_rotary_emb,
|
| 1030 |
+
return_dict=False,
|
| 1031 |
+
ablation_condition=ablation_condition,
|
| 1032 |
+
)[0]
|
| 1033 |
+
|
| 1034 |
+
# The line below also scales the conditioning input, which is undesirable: the model ends up learning to rescale that input latent as well.
|
| 1035 |
+
# As a result, the first frame needs to be replaced with the original frame later.
|
| 1036 |
+
model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)
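# CogVideoX is trained with v-prediction: get_velocity converts the predicted velocity back into a prediction of the clean latents, which the loss below compares against the target with weight 1/(1 - alpha_bar_t).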
|
| 1037 |
+
|
| 1038 |
+
|
| 1039 |
+
|
| 1040 |
+
alphas_cumprod = scheduler.alphas_cumprod[timesteps]
|
| 1041 |
+
weights = 1 / (1 - alphas_cumprod)
|
| 1042 |
+
while len(weights.shape) < len(model_pred.shape):
|
| 1043 |
+
weights = weights.unsqueeze(-1)
|
| 1044 |
+
|
| 1045 |
+
|
| 1046 |
+
|
| 1047 |
+
loss = torch.mean((weights * (model_pred[~condition_mask] - target[~condition_mask]) ** 2).reshape(batch_size, -1), dim=1)
|
| 1048 |
+
loss = loss.mean()
|
| 1049 |
+
accelerator.backward(loss)
|
| 1050 |
+
|
| 1051 |
+
if accelerator.state.deepspeed_plugin is None:
|
| 1052 |
+
if not args.just_validate:
|
| 1053 |
+
optimizer.step()
|
| 1054 |
+
optimizer.zero_grad()
|
| 1055 |
+
|
| 1056 |
+
lr_scheduler.step()
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
#wait for all processes to finish
|
| 1060 |
+
accelerator.wait_for_everyone()
|
| 1061 |
+
|
| 1062 |
+
|
| 1063 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 1064 |
+
if accelerator.sync_gradients:
|
| 1065 |
+
progress_bar.update(1)
|
| 1066 |
+
global_step += 1
|
| 1067 |
+
|
| 1068 |
+
if signal_recieved_time != 0:
|
| 1069 |
+
if time.time() - signal_recieved_time > 60:
|
| 1070 |
+
print("Signal received, saving state and exiting")
|
| 1071 |
+
#accelerator.save_state(save_path)
|
| 1072 |
+
atomic_save(save_path, accelerator)
|
| 1073 |
+
signal_recieved_time = 0
|
| 1074 |
+
exit(0)
|
| 1075 |
+
else:
|
| 1076 |
+
exit(0)
|
| 1077 |
+
|
| 1078 |
+
if accelerator.is_main_process:
|
| 1079 |
+
if global_step % args.checkpointing_steps == 0:
|
| 1080 |
+
#accelerator.save_state(save_path)
|
| 1081 |
+
atomic_save(save_path, accelerator)
|
| 1082 |
+
logger.info(f"Saved state to {save_path}")
|
| 1083 |
+
|
| 1084 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 1085 |
+
progress_bar.set_postfix(**logs)
|
| 1086 |
+
accelerator.log(logs, step=global_step)
|
| 1087 |
+
|
| 1088 |
+
if global_step >= args.max_train_steps:
|
| 1089 |
+
break
|
| 1090 |
+
|
| 1091 |
+
print("Step", step)
|
| 1092 |
+
accelerator.wait_for_everyone()
|
| 1093 |
+
|
| 1094 |
+
if step == 0 or (args.validation_prompt is not None and (step + 1) % args.validation_steps == 0):
|
| 1095 |
+
# Create pipeline
|
| 1096 |
+
pipe = ControlnetCogVideoXPipeline.from_pretrained(
|
| 1097 |
+
os.path.join(args.base_dir, args.pretrained_model_name_or_path),
|
| 1098 |
+
transformer=unwrap_model(transformer),
|
| 1099 |
+
text_encoder=unwrap_model(text_encoder),
|
| 1100 |
+
vae=unwrap_model(vae),
|
| 1101 |
+
scheduler=scheduler,
|
| 1102 |
+
torch_dtype=weight_dtype,
|
| 1103 |
+
)
|
| 1104 |
+
|
| 1105 |
+
print("Length of validation dataset: ", len(val_dataloader))
|
| 1106 |
+
#create a pipeline per accelerator device (for faster inference)
|
| 1107 |
+
with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
|
| 1108 |
+
for batch in val_dataloader:
|
| 1109 |
+
frame = ((batch["blur_img"][0].permute(0,2,3,1).cpu().numpy() + 1)*127.5).astype(np.uint8)
|
| 1110 |
+
pipeline_args = {
|
| 1111 |
+
"prompt": "",
|
| 1112 |
+
"negative_prompt": "",
|
| 1113 |
+
"image": frame,
|
| 1114 |
+
"input_intervals": batch["input_intervals"][0:1],
|
| 1115 |
+
"output_intervals": batch["output_intervals"][0:1],
|
| 1116 |
+
"ablation_condition": batch["ablation_condition"][0:1] if "ablation_condition" in batch else None,
|
| 1117 |
+
"guidance_scale": args.guidance_scale,
|
| 1118 |
+
"use_dynamic_cfg": args.use_dynamic_cfg,
|
| 1119 |
+
"height": args.height,
|
| 1120 |
+
"width": args.width,
|
| 1121 |
+
"num_frames": args.max_num_frames,
|
| 1122 |
+
"num_inference_steps": args.num_inference_steps,
|
| 1123 |
+
}
|
| 1124 |
+
|
| 1125 |
+
modified_filenames = []
|
| 1126 |
+
filenames = batch['file_names']
|
| 1127 |
+
for file in filenames:
|
| 1128 |
+
modified_filenames.append(os.path.splitext(file)[0] + ".mp4")
|
| 1129 |
+
|
| 1130 |
+
num_frames = batch["num_frames"][0]
|
| 1131 |
+
#save the gt_video output
|
| 1132 |
+
if args.dataset not in ["outsidephotos"]:
|
| 1133 |
+
gt_video = batch["videos"][0].permute(0,2,3,1).cpu().numpy()
|
| 1134 |
+
gt_video = ((gt_video + 1) * 127.5)/255
|
| 1135 |
+
gt_video = gt_video[0:num_frames]
|
| 1136 |
+
|
| 1137 |
+
for file in modified_filenames:
|
| 1138 |
+
#create the directory if it does not exist
|
| 1139 |
+
gt_file_name = os.path.join(args.output_dir, "gt", modified_filenames[0])
|
| 1140 |
+
os.makedirs(os.path.dirname(gt_file_name), exist_ok=True)
|
| 1141 |
+
if args.dataset in ["baist", "simplebaist"]:
|
| 1142 |
+
bbox = batch["bbx"][0].cpu().numpy().astype(np.int32)
|
| 1143 |
+
gt_video = gt_video[:, bbox[1]:bbox[3], bbox[0]:bbox[2], :]
|
| 1144 |
+
gt_video = np.array([cv2.resize(frame, (160, 192)) for frame in gt_video])
|
| 1145 |
+
|
| 1146 |
+
save_frames_as_pngs((gt_video*255).astype(np.uint8), gt_file_name.replace(".mp4", "").replace("gt", "gt_frames"))
|
| 1147 |
+
export_to_video(gt_video, gt_file_name, fps=20)
|
| 1148 |
+
|
| 1149 |
+
|
| 1150 |
+
if "high_fps_video" in batch:
|
| 1151 |
+
high_fps_video = batch["high_fps_video"][0].permute(0,2,3,1).cpu().numpy()
|
| 1152 |
+
high_fps_video = ((high_fps_video + 1) * 127.5)/255
|
| 1153 |
+
gt_file_name = os.path.join(args.output_dir, "gt_highfps", modified_filenames[0])
|
| 1154 |
+
|
| 1155 |
+
|
| 1156 |
+
if args.dataset in ["adobe", "full", "baist", "outsidephotos", "gopro2x", "goprolarge", "simplebaist"]:
|
| 1157 |
+
for file in modified_filenames:
|
| 1158 |
+
#create the directory if it does not exist
|
| 1159 |
+
blurry_file_name = os.path.join(args.output_dir, "blurry", modified_filenames[0].replace(".mp4", ".png"))
|
| 1160 |
+
#save the blurry image
|
| 1161 |
+
os.makedirs(os.path.dirname(blurry_file_name), exist_ok=True)
|
| 1162 |
+
if args.dataset in ["baist", "simplebaist"]:
|
| 1163 |
+
bbox = batch["bbx"][0].cpu().numpy().astype(np.int32)
|
| 1164 |
+
frame0 = frame[0][bbox[1]:bbox[3], bbox[0]:bbox[2], :]
|
| 1165 |
+
#resize to 192x160
|
| 1166 |
+
frame0 = cv2.resize(frame0, (160, 192))
|
| 1167 |
+
Image.fromarray(frame0).save(blurry_file_name)
|
| 1168 |
+
else:
|
| 1169 |
+
Image.fromarray(frame[0]).save(blurry_file_name)
|
| 1170 |
+
|
| 1171 |
+
videos = log_validation(
|
| 1172 |
+
pipe=pipe,
|
| 1173 |
+
args=args,
|
| 1174 |
+
accelerator=accelerator,
|
| 1175 |
+
pipeline_args=pipeline_args,
|
| 1176 |
+
epoch=epoch,
|
| 1177 |
+
)
|
| 1178 |
+
|
| 1179 |
+
for i, video in enumerate(videos):
|
| 1180 |
+
prompt = (
|
| 1181 |
+
pipeline_args["prompt"][:25]
|
| 1182 |
+
.replace(" ", "_")
|
| 1183 |
+
.replace(" ", "_")
|
| 1184 |
+
.replace("'", "_")
|
| 1185 |
+
.replace('"', "_")
|
| 1186 |
+
.replace("/", "_")
|
| 1187 |
+
)
|
| 1188 |
+
video = video[0:num_frames]
|
| 1189 |
+
filename = os.path.join(args.output_dir, "deblurred", modified_filenames[0])
|
| 1190 |
+
print("Deblurred file name", filename)
|
| 1191 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
| 1192 |
+
if args.dataset in ["baist", "simplebaist"]:
|
| 1193 |
+
bbox = batch["bbx"][0].cpu().numpy().astype(np.int32)
|
| 1194 |
+
video = video[:, bbox[1]:bbox[3], bbox[0]:bbox[2], :]
|
| 1195 |
+
#resize to 192x160
|
| 1196 |
+
video = np.array([cv2.resize(frame, (160, 192)) for frame in video])
|
| 1197 |
+
save_frames_as_pngs((video*255).astype(np.uint8), filename.replace(".mp4", "").replace("deblurred", "deblurred_frames"))
|
| 1198 |
+
export_to_video(video, filename, fps=20)
|
| 1199 |
+
|
| 1200 |
+
accelerator.wait_for_everyone()
|
| 1201 |
+
|
| 1202 |
+
if args.just_validate:
|
| 1203 |
+
exit(0)
|
| 1204 |
+
|
| 1205 |
+
accelerator.wait_for_everyone()
|
| 1206 |
+
accelerator.end_training()
|
| 1207 |
+
|
| 1208 |
+
signal_recieved_time = 0
|
| 1209 |
+
|
| 1210 |
+
def handle_signal(signum, frame):
|
| 1211 |
+
global signal_recieved_time
|
| 1212 |
+
signal_recieved_time = time.time()
|
| 1213 |
+
|
| 1214 |
+
print(f"Signal {signum} received at {time.ctime()}")
|
| 1215 |
+
|
| 1216 |
+
with open("/datasets/sai/gencam/cogvideox/interrupted.txt", "w") as f:
|
| 1217 |
+
f.write(f"Training was interrupted at {time.ctime()}")
|
| 1218 |
+
|
| 1219 |
+
if __name__ == "__main__":
|
| 1220 |
+
|
| 1221 |
+
args = get_args()
|
| 1222 |
+
|
| 1223 |
+
print("Registering signal handler")
|
| 1224 |
+
#Register the signal handler (catch SIGUSR1)
|
| 1225 |
+
signal.signal(signal.SIGUSR1, handle_signal)
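# SIGUSR1 is typically forwarded by SLURM shortly before a job hits its time limit (e.g. via --signal=USR1@<seconds>); the handler only records the time so the training loop can checkpoint and exit cleanly.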
|
| 1226 |
+
|
| 1227 |
+
main_thread = threading.Thread(target=main, args=(args,))
|
| 1228 |
+
main_thread.start()
|
| 1229 |
+
|
| 1230 |
+
print("SIGNAL RECIEVED TIME", signal_recieved_time)
|
| 1231 |
+
while signal_recieved_time != 0:
|
| 1232 |
+
time.sleep(1)
|
| 1233 |
+
|
| 1234 |
+
#call main with args as a thread
|
| 1235 |
+
|
training/utils.py
ADDED
|
@@ -0,0 +1,299 @@
| 1 |
+
import os
|
| 2 |
+
from typing import List, Optional, Union, Tuple
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 5 |
+
import numpy as np
|
| 6 |
+
import cv2
|
| 7 |
+
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
| 8 |
+
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
|
| 9 |
+
from accelerate.logging import get_logger
|
| 10 |
+
import tempfile
|
| 11 |
+
import argparse
|
| 12 |
+
import yaml
|
| 13 |
+
import shutil
|
| 14 |
+
|
| 15 |
+
logger = get_logger(__name__)
|
| 16 |
+
|
| 17 |
+
def get_args():
|
| 18 |
+
parser = argparse.ArgumentParser(description="Training script for CogVideoX using config file.")
|
| 19 |
+
parser.add_argument(
|
| 20 |
+
"--config",
|
| 21 |
+
type=str,
|
| 22 |
+
required=True,
|
| 23 |
+
help="Path to the YAML config file."
|
| 24 |
+
)
|
| 25 |
+
args = parser.parse_args()
|
| 26 |
+
with open(args.config, "r") as f:
|
| 27 |
+
config = yaml.safe_load(f)
|
| 28 |
+
args = argparse.Namespace(**config)
|
| 29 |
+
# Convert nested config dict to an argparse.Namespace for easier downstream usage
|
| 30 |
+
return args
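# Example invocation (illustrative; config names come from training/configs/):
#   accelerate launch training/train_controlnet.py --config training/configs/gopro_train.yaml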
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def atomic_save(save_path, accelerator):
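# Write the new checkpoint into a temp directory first, then swap it into place with os.rename,
# so an interrupted job never leaves a half-written checkpoint behind. os.rename is only atomic
# when source and destination are on the same filesystem, which is why the temp directory is
# created inside the checkpoint's parent directory.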
|
| 35 |
+
parent = os.path.dirname(save_path)
|
| 36 |
+
tmp_dir = tempfile.mkdtemp(dir=parent)
|
| 37 |
+
backup_dir = save_path + "_backup"
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
# Save state into the temp directory
|
| 41 |
+
accelerator.save_state(tmp_dir)
|
| 42 |
+
|
| 43 |
+
# Backup existing save_path if it exists
|
| 44 |
+
if os.path.exists(save_path):
|
| 45 |
+
os.rename(save_path, backup_dir)
|
| 46 |
+
|
| 47 |
+
# Atomically move temp directory into place
|
| 48 |
+
os.rename(tmp_dir, save_path)
|
| 49 |
+
|
| 50 |
+
# Clean up the backup directory
|
| 51 |
+
if os.path.exists(backup_dir):
|
| 52 |
+
shutil.rmtree(backup_dir)
|
| 53 |
+
|
| 54 |
+
except Exception as e:
|
| 55 |
+
# Clean up temp directory on failure
|
| 56 |
+
if os.path.exists(tmp_dir):
|
| 57 |
+
shutil.rmtree(tmp_dir)
|
| 58 |
+
|
| 59 |
+
# Restore from backup if replacement failed
|
| 60 |
+
if os.path.exists(backup_dir):
|
| 61 |
+
if os.path.exists(save_path):
|
| 62 |
+
shutil.rmtree(save_path)
|
| 63 |
+
os.rename(backup_dir, save_path)
|
| 64 |
+
|
| 65 |
+
raise e
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
|
| 69 |
+
# Use the DeepSpeed optimizer
|
| 70 |
+
if use_deepspeed:
|
| 71 |
+
from accelerate.utils import DummyOptim
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
return DummyOptim(
|
| 75 |
+
params_to_optimize,
|
| 76 |
+
lr=args.learning_rate,
|
| 77 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 78 |
+
eps=args.adam_epsilon,
|
| 79 |
+
weight_decay=args.adam_weight_decay,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Optimizer creation
|
| 83 |
+
supported_optimizers = ["adam", "adamw", "prodigy"]
|
| 84 |
+
if args.optimizer not in supported_optimizers:
|
| 85 |
+
logger.warning(
|
| 86 |
+
f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
|
| 87 |
+
)
|
| 88 |
+
args.optimizer = "adamw"
|
| 89 |
+
|
| 90 |
+
if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
|
| 91 |
+
logger.warning(
|
| 92 |
+
f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
|
| 93 |
+
f"set to {args.optimizer.lower()}"
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
if args.use_8bit_adam:
|
| 97 |
+
try:
|
| 98 |
+
import bitsandbytes as bnb
|
| 99 |
+
except ImportError:
|
| 100 |
+
raise ImportError(
|
| 101 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
if args.optimizer.lower() == "adamw":
|
| 105 |
+
optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
|
| 106 |
+
|
| 107 |
+
optimizer = optimizer_class(
|
| 108 |
+
params_to_optimize,
|
| 109 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 110 |
+
eps=args.adam_epsilon,
|
| 111 |
+
weight_decay=args.adam_weight_decay,
|
| 112 |
+
)
|
| 113 |
+
elif args.optimizer.lower() == "adam":
|
| 114 |
+
optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
optimizer = optimizer_class(
|
| 118 |
+
params_to_optimize,
|
| 119 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 120 |
+
eps=args.adam_epsilon,
|
| 121 |
+
weight_decay=args.adam_weight_decay,
|
| 122 |
+
)
|
| 123 |
+
elif args.optimizer.lower() == "prodigy":
|
| 124 |
+
try:
|
| 125 |
+
import prodigyopt
|
| 126 |
+
except ImportError:
|
| 127 |
+
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
|
| 128 |
+
|
| 129 |
+
optimizer_class = prodigyopt.Prodigy
|
| 130 |
+
|
| 131 |
+
if args.learning_rate <= 0.1:
|
| 132 |
+
logger.warning(
|
| 133 |
+
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
optimizer = optimizer_class(
|
| 137 |
+
params_to_optimize,
|
| 138 |
+
lr=args.learning_rate,
|
| 139 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 140 |
+
beta3=args.prodigy_beta3,
|
| 141 |
+
weight_decay=args.adam_weight_decay,
|
| 142 |
+
eps=args.adam_epsilon,
|
| 143 |
+
decouple=args.prodigy_decouple,
|
| 144 |
+
use_bias_correction=args.prodigy_use_bias_correction,
|
| 145 |
+
safeguard_warmup=args.prodigy_safeguard_warmup,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
return optimizer
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def prepare_rotary_positional_embeddings(
|
| 152 |
+
height: int,
|
| 153 |
+
width: int,
|
| 154 |
+
num_frames: int,
|
| 155 |
+
vae_scale_factor_spatial: int = 8,
|
| 156 |
+
patch_size: int = 2,
|
| 157 |
+
attention_head_dim: int = 64,
|
| 158 |
+
device: Optional[torch.device] = None,
|
| 159 |
+
base_height: int = 480,
|
| 160 |
+
base_width: int = 720,
|
| 161 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 162 |
+
grid_height = height // (vae_scale_factor_spatial * patch_size)
|
| 163 |
+
grid_width = width // (vae_scale_factor_spatial * patch_size)
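# e.g. height=480, width=720 with vae_scale_factor_spatial=8 and patch_size=2 gives a 30 x 45 patch grid per frame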
|
| 164 |
+
base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
|
| 165 |
+
base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
|
| 166 |
+
|
| 167 |
+
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
|
| 168 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
| 169 |
+
embed_dim=attention_head_dim,
|
| 170 |
+
crops_coords=grid_crops_coords,
|
| 171 |
+
grid_size=(grid_height, grid_width),
|
| 172 |
+
temporal_size=num_frames,
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
freqs_cos = freqs_cos.to(device=device)
|
| 176 |
+
freqs_sin = freqs_sin.to(device=device)
|
| 177 |
+
return freqs_cos, freqs_sin
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _get_t5_prompt_embeds(
|
| 181 |
+
tokenizer: T5Tokenizer,
|
| 182 |
+
text_encoder: T5EncoderModel,
|
| 183 |
+
prompt: Union[str, List[str]],
|
| 184 |
+
num_videos_per_prompt: int = 1,
|
| 185 |
+
max_sequence_length: int = 226,
|
| 186 |
+
device: Optional[torch.device] = None,
|
| 187 |
+
dtype: Optional[torch.dtype] = None,
|
| 188 |
+
text_input_ids=None,
|
| 189 |
+
):
|
| 190 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 191 |
+
batch_size = len(prompt)
|
| 192 |
+
|
| 193 |
+
if tokenizer is not None:
|
| 194 |
+
text_inputs = tokenizer(
|
| 195 |
+
prompt,
|
| 196 |
+
padding="max_length",
|
| 197 |
+
max_length=max_sequence_length,
|
| 198 |
+
truncation=True,
|
| 199 |
+
add_special_tokens=True,
|
| 200 |
+
return_tensors="pt",
|
| 201 |
+
)
|
| 202 |
+
text_input_ids = text_inputs.input_ids
|
| 203 |
+
else:
|
| 204 |
+
if text_input_ids is None:
|
| 205 |
+
raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
|
| 206 |
+
|
| 207 |
+
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
| 208 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 209 |
+
|
| 210 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 211 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 212 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 213 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 214 |
+
|
| 215 |
+
return prompt_embeds
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def encode_prompt(
|
| 219 |
+
tokenizer: T5Tokenizer,
|
| 220 |
+
text_encoder: T5EncoderModel,
|
| 221 |
+
prompt: Union[str, List[str]],
|
| 222 |
+
num_videos_per_prompt: int = 1,
|
| 223 |
+
max_sequence_length: int = 226,
|
| 224 |
+
device: Optional[torch.device] = None,
|
| 225 |
+
dtype: Optional[torch.dtype] = None,
|
| 226 |
+
text_input_ids=None,
|
| 227 |
+
):
|
| 228 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 229 |
+
prompt_embeds = _get_t5_prompt_embeds(
|
| 230 |
+
tokenizer,
|
| 231 |
+
text_encoder,
|
| 232 |
+
prompt=prompt,
|
| 233 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 234 |
+
max_sequence_length=max_sequence_length,
|
| 235 |
+
device=device,
|
| 236 |
+
dtype=dtype,
|
| 237 |
+
text_input_ids=text_input_ids,
|
| 238 |
+
)
|
| 239 |
+
return prompt_embeds
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def compute_prompt_embeddings(
|
| 243 |
+
tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
|
| 244 |
+
):
|
| 245 |
+
if requires_grad:
|
| 246 |
+
prompt_embeds = encode_prompt(
|
| 247 |
+
tokenizer,
|
| 248 |
+
text_encoder,
|
| 249 |
+
prompt,
|
| 250 |
+
num_videos_per_prompt=1,
|
| 251 |
+
max_sequence_length=max_sequence_length,
|
| 252 |
+
device=device,
|
| 253 |
+
dtype=dtype,
|
| 254 |
+
)
|
| 255 |
+
else:
|
| 256 |
+
with torch.no_grad():
|
| 257 |
+
prompt_embeds = encode_prompt(
|
| 258 |
+
tokenizer,
|
| 259 |
+
text_encoder,
|
| 260 |
+
prompt,
|
| 261 |
+
num_videos_per_prompt=1,
|
| 262 |
+
max_sequence_length=max_sequence_length,
|
| 263 |
+
device=device,
|
| 264 |
+
dtype=dtype,
|
| 265 |
+
)
|
| 266 |
+
return prompt_embeds
|
| 267 |
+
|
| 268 |
+
def save_frames_as_pngs(video_array, output_dir,
|
| 269 |
+
downsample_spatial=1, # e.g. 2 to halve width & height
|
| 270 |
+
downsample_temporal=1): # e.g. 2 to keep every 2nd frame
|
| 271 |
+
"""
|
| 272 |
+
Save each frame of a (T, H, W, C) numpy array as a PNG with no compression.
|
| 273 |
+
"""
|
| 274 |
+
assert video_array.ndim == 4 and video_array.shape[-1] == 3, \
|
| 275 |
+
"Expected (T, H, W, C=3) array"
|
| 276 |
+
assert video_array.dtype == np.uint8, "Expected uint8 array"
|
| 277 |
+
|
| 278 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 279 |
+
|
| 280 |
+
# temporal downsample
|
| 281 |
+
frames = video_array[::downsample_temporal]
|
| 282 |
+
|
| 283 |
+
# compute spatially downsampled size
|
| 284 |
+
T, H, W, _ = frames.shape
|
| 285 |
+
new_size = (W // downsample_spatial, H // downsample_spatial)
|
| 286 |
+
|
| 287 |
+
# PNG compression param: 0 = no compression
|
| 288 |
+
png_params = [cv2.IMWRITE_PNG_COMPRESSION, 0]
|
| 289 |
+
|
| 290 |
+
for idx, frame in enumerate(frames):
|
| 291 |
+
# frame is RGB; convert to BGR for OpenCV
|
| 292 |
+
bgr = frame[..., ::-1]
|
| 293 |
+
if downsample_spatial > 1:
|
| 294 |
+
bgr = cv2.resize(bgr, new_size, interpolation=cv2.INTER_NEAREST)
|
| 295 |
+
|
| 296 |
+
filename = os.path.join(output_dir, "frame_{:05d}.png".format(idx))
|
| 297 |
+
success = cv2.imwrite(filename, bgr, png_params)
|
| 298 |
+
if not success:
|
| 299 |
+
raise RuntimeError("Failed to write frame ")
|
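# --- Usage sketch for save_frames_as_pngs (illustrative addition, not part of the original file;
# the clip shape and output path below are made up) ---
if __name__ == "__main__":
    # Synthetic (T, H, W, C) uint8 RGB clip purely for demonstration
    demo_clip = (np.random.rand(7, 192, 160, 3) * 255).astype(np.uint8)
    # Writes frame_00000.png ... frame_00006.png with PNG compression disabled
    save_frames_as_pngs(demo_clip, "outputs/example_frames")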