# Source snapshot: file size 6,749 bytes, revision feb33a0 (file-viewer line-number gutter omitted).
import torch, os, json
from diffsynth import load_state_dict
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.pipelines.flux_image_new import ControlNetInput
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task
from diffsynth.trainers.unified_dataset import UnifiedDataset
# Set TOKENIZERS_PARALLELISM to "false" for this process as soon as the module is loaded.
# NOTE(review): presumably this avoids the Hugging Face tokenizers fork/parallelism
# warning when worker processes are spawned — confirm against the tokenizers docs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class QwenImageTrainingModule(DiffusionTrainingModule):
    """Training module that wraps a ``QwenImagePipeline``.

    The constructor loads the pipeline on CPU in bfloat16, switches it to
    training mode via ``switch_pipe_to_training_mode`` and stores the options
    that ``forward_preprocess`` / ``forward`` need later.
    """

    # Prefixes that route an extra input into a ControlNet input object instead
    # of passing it to the pipeline as a plain shared input.  The blockwise
    # prefix must be tested first: both start differently, but keeping the
    # longer, more specific one first mirrors the original branch order.
    _BLOCKWISE_CONTROLNET_PREFIX = "blockwise_controlnet_"
    _CONTROLNET_PREFIX = "controlnet_"

    def __init__(
        self,
        model_paths=None, model_id_with_origin_paths=None,
        tokenizer_path=None, processor_path=None,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
        enable_fp8_training=False,
        task="sft",
    ):
        """Load the pipeline and prepare it for training.

        Args:
            model_paths: Forwarded to ``parse_model_configs``.
            model_id_with_origin_paths: Forwarded to ``parse_model_configs``.
            tokenizer_path: Tokenizer location; when ``None`` the ``tokenizer/``
                files of ``Qwen/Qwen-Image`` are used.
            processor_path: Processor location; when ``None`` the ``processor/``
                files of ``Qwen/Qwen-Image-Edit`` are used.
            trainable_models: Forwarded to ``switch_pipe_to_training_mode``.
            lora_base_model: Forwarded to ``switch_pipe_to_training_mode``.
            lora_target_modules: Forwarded to ``switch_pipe_to_training_mode``.
            lora_rank: Forwarded to ``switch_pipe_to_training_mode``.
            lora_checkpoint: Forwarded to ``switch_pipe_to_training_mode``.
            use_gradient_checkpointing: Passed to the pipeline with every sample.
            use_gradient_checkpointing_offload: Passed to the pipeline with
                every sample.
            extra_inputs: Comma-separated names of additional sample keys to
                feed to the pipeline, or ``None``.  Empty entries and
                surrounding whitespace are ignored.
            enable_fp8_training: Forwarded to ``parse_model_configs`` and
                ``switch_pipe_to_training_mode``.
            task: One of ``"sft"``, ``"data_process"`` or ``"direct_distill"``;
                interpreted in ``forward``.
        """
        super().__init__()
        # Load models
        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training)
        tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
        processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
        self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)

        # Training mode
        self.switch_pipe_to_training_mode(
            self.pipe, trainable_models,
            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
            enable_fp8_training=enable_fp8_training,
        )

        # Store other configs
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        # Drop blank entries (e.g. from "" or a trailing comma) so they do not
        # surface later as a confusing ``KeyError: ''`` on the sample dict.
        self.extra_inputs = [
            name.strip()
            for name in (extra_inputs or "").split(",")
            if name.strip()
        ]
        self.task = task

    def forward_preprocess(self, data):
        """Convert one dataset sample into the keyword inputs of the pipeline.

        Args:
            data: Mapping that provides ``"prompt"``, ``"image"`` (an object
                whose ``size`` is ``(width, height)``) and every key named in
                ``self.extra_inputs``.

        Returns:
            A dict merging the shared inputs with the positive-prompt inputs
            after all pipeline units have processed them.  The negative-prompt
            inputs are run through the units but not returned.

        Raises:
            KeyError: If ``data`` lacks a required or requested key.
        """
        # CFG-sensitive parameters
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {"negative_prompt": ""}

        # CFG-insensitive parameters
        inputs_shared = {
            # Assume you are using this pipeline for inference,
            # please fill in the input parameters.
            "input_image": data["image"],
            "height": data["image"].size[1],
            "width": data["image"].size[0],
            # Please do not modify the following parameters
            # unless you clearly know what this will cause.
            "cfg_scale": 1,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
            "edit_image_auto_resize": True,
        }

        # Extra inputs
        blockwise_prefix = self._BLOCKWISE_CONTROLNET_PREFIX
        controlnet_prefix = self._CONTROLNET_PREFIX
        controlnet_input, blockwise_controlnet_input = {}, {}
        for extra_input in self.extra_inputs:
            # Slice off only the leading prefix; ``str.replace`` would also
            # delete later occurrences of the prefix text inside the name.
            if extra_input.startswith(blockwise_prefix):
                blockwise_controlnet_input[extra_input[len(blockwise_prefix):]] = data[extra_input]
            elif extra_input.startswith(controlnet_prefix):
                controlnet_input[extra_input[len(controlnet_prefix):]] = data[extra_input]
            else:
                inputs_shared[extra_input] = data[extra_input]
        if controlnet_input:
            inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
        if blockwise_controlnet_input:
            inputs_shared["blockwise_controlnet_inputs"] = [ControlNetInput(**blockwise_controlnet_input)]

        # Pipeline units will automatically process the input parameters.
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
        return {**inputs_shared, **inputs_posi}

    def forward(self, data, inputs=None, return_inputs=False):
        """Compute the result for one sample according to ``self.task``.

        Args:
            data: Raw dataset sample; only used when ``inputs`` is ``None``.
            inputs: Already preprocessed inputs.  When given, they are moved to
                the pipeline's device and dtype instead of being recomputed
                from ``data``.
            return_inputs: When true, return the prepared inputs and skip the
                task-specific computation.

        Returns:
            The prepared inputs if ``return_inputs`` is true or the task is
            ``"data_process"``; otherwise the value returned by the pipeline's
            ``training_loss`` (``"sft"``) or ``direct_distill_loss``
            (``"direct_distill"``).

        Raises:
            NotImplementedError: If ``self.task`` is not a supported task.
        """
        # Inputs
        if inputs is None:
            inputs = self.forward_preprocess(data)
        else:
            inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
        if return_inputs:
            return inputs

        # Loss
        if self.task == "sft":
            models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
            loss = self.pipe.training_loss(**models, **inputs)
        elif self.task == "data_process":
            # Nothing to optimise: the preprocessed inputs are the result.
            loss = inputs
        elif self.task == "direct_distill":
            loss = self.pipe.direct_distill_loss(**inputs)
        else:
            raise NotImplementedError(f"Unsupported task: {self.task}.")
        return loss
if __name__ == "__main__":
    # Read the command-line options defined by the Qwen-Image argument parser.
    args = qwen_image_parser().parse_args()

    # Dataset: images are loaded through the default image operator, with
    # height and width constrained to multiples of 16.
    file_keys = args.data_file_keys.split(",")
    image_operator = UnifiedDataset.default_image_operator(
        base_path=args.dataset_base_path,
        max_pixels=args.max_pixels,
        height=args.height,
        width=args.width,
        height_division_factor=16,
        width_division_factor=16,
    )
    dataset = UnifiedDataset(
        base_path=args.dataset_base_path,
        metadata_path=args.dataset_metadata_path,
        repeat=args.dataset_repeat,
        data_file_keys=file_keys,
        main_data_operator=image_operator,
    )

    # Training module: each constructor option is read from the command-line
    # argument that carries the same name.
    module_option_names = (
        "model_paths",
        "model_id_with_origin_paths",
        "tokenizer_path",
        "processor_path",
        "trainable_models",
        "lora_base_model",
        "lora_target_modules",
        "lora_rank",
        "lora_checkpoint",
        "use_gradient_checkpointing",
        "use_gradient_checkpointing_offload",
        "extra_inputs",
        "enable_fp8_training",
        "task",
    )
    module_options = {name: getattr(args, name) for name in module_option_names}
    model = QwenImageTrainingModule(**module_options)

    # Logger that receives the output path and checkpoint-prefix option.
    model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)

    # Pick the launcher for the requested task; an unknown task raises KeyError.
    task_launchers = {
        "sft": launch_training_task,
        "data_process": launch_data_process_task,
        "direct_distill": launch_training_task,
    }
    launch = task_launchers[args.task]
    launch(dataset, model, model_logger, args=args)
# End of file.