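"""Checkpoint save/load utilities for FSDP training in FastVideo.

Covers full-state and sharded checkpoints for the diffusion transformer,
an optional discriminator for generator/discriminator training, and LoRA
adapters.
"""
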
import json
import os

import torch
import torch.distributed.checkpoint as dist_cp
from peft import get_peft_model_state_dict
from safetensors.torch import load_file, save_file
from torch.distributed.checkpoint.default_planner import (DefaultLoadPlanner,
                                                          DefaultSavePlanner)
from torch.distributed.checkpoint.optimizer import \
    load_sharded_optimizer_state_dict
from torch.distributed.fsdp import (FullOptimStateDictConfig,
                                    FullStateDictConfig)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

from fastvideo.utils.logging_ import main_print


def save_checkpoint_optimizer(model,
                              optimizer,
                              rank,
                              output_dir,
                              step,
                              discriminator=False):
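    """Save full (unsharded) model and optimizer state.

    With rank0_only=True, only rank 0 receives the gathered state dicts,
    so all file writes happen on rank 0. The `discriminator` flag only
    changes the output file names.
    """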
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state = model.state_dict()
        optim_state = FSDP.optim_state_dict(
            model,
            optimizer,
        )

    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)
    # Only rank 0 holds the gathered state (rank0_only=True), so only rank 0
    # writes; other ranks would otherwise save empty state dicts.
    if rank <= 0:
        if not discriminator:
            weight_path = os.path.join(save_dir,
                                       "diffusion_pytorch_model.safetensors")
            save_file(cpu_state, weight_path)
            config_dict = dict(model.config)
            # torch.dtype values are not JSON-serializable.
            config_dict.pop("dtype", None)
            config_path = os.path.join(save_dir, "config.json")
            with open(config_path, "w") as f:
                json.dump(config_dict, f, indent=4)
            optimizer_path = os.path.join(save_dir, "optimizer.pt")
            torch.save(optim_state, optimizer_path)
        else:
            weight_path = os.path.join(
                save_dir, "discriminator_pytorch_model.safetensors")
            save_file(cpu_state, weight_path)
            optimizer_path = os.path.join(save_dir,
                                          "discriminator_optimizer.pt")
            torch.save(optim_state, optimizer_path)
    main_print(f"--> checkpoint saved at step {step}")


def save_checkpoint(transformer, rank, output_dir, step, epoch):
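    """Save the full transformer weights and config on rank 0.

    Writes `diffusion_pytorch_model.safetensors` plus `config.json` into
    `output_dir/checkpoint-{step}-{epoch}`; no optimizer state is saved.
    """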
    main_print(f"--> saving checkpoint at step {step}")
    with FSDP.state_dict_type(
            transformer,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state = transformer.state_dict()

    if rank <= 0:
        save_dir = os.path.join(output_dir, f"checkpoint-{step}-{epoch}")
        os.makedirs(save_dir, exist_ok=True)
        weight_path = os.path.join(save_dir,
                                   "diffusion_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)
        config_dict = dict(transformer.config)
        if "dtype" in config_dict:
            del config_dict["dtype"]
        config_path = os.path.join(save_dir, "config.json")
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
    main_print(f"--> checkpoint saved at step {step}")


def save_checkpoint_generator_discriminator(
    model,
    optimizer,
    discriminator,
    discriminator_optimizer,
    rank,
    output_dir,
    step,
):
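    """Save a combined generator/discriminator checkpoint.

    Writes three things under `output_dir/checkpoint-{step}`: full HF-style
    generator weights (`hf_weights/`), sharded generator weights and
    optimizer state via torch.distributed.checkpoint, and a single-file
    full state dict for the discriminator and its optimizer.
    """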
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state = model.state_dict()

    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)
    hf_weight_dir = os.path.join(save_dir, "hf_weights")
    os.makedirs(hf_weight_dir, exist_ok=True)

    if rank <= 0:
        config_dict = dict(model.config)
        # torch.dtype values are not JSON-serializable.
        config_dict.pop("dtype", None)
        config_path = os.path.join(hf_weight_dir, "config.json")
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
        weight_path = os.path.join(hf_weight_dir,
                                   "diffusion_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)

    main_print(f"--> saved HF weight checkpoint at path {hf_weight_dir}")

    model_weight_dir = os.path.join(save_dir, "model_weights_state")
    os.makedirs(model_weight_dir, exist_ok=True)
    model_optimizer_dir = os.path.join(save_dir, "model_optimizer_state")
    os.makedirs(model_optimizer_dir, exist_ok=True)
    # The sharded save runs collectively on all ranks. dist_cp.save_state_dict
    # is the legacy entry point; newer PyTorch also exposes dist_cp.save.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        optim_state = FSDP.optim_state_dict(model, optimizer)
        model_state = model.state_dict()
        weight_state_dict = {"model": model_state}
        dist_cp.save_state_dict(
            state_dict=weight_state_dict,
            storage_writer=dist_cp.FileSystemWriter(model_weight_dir),
            planner=DefaultSavePlanner(),
        )
        optimizer_state_dict = {"optimizer": optim_state}
        dist_cp.save_state_dict(
            state_dict=optimizer_state_dict,
            storage_writer=dist_cp.FileSystemWriter(model_optimizer_dir),
            planner=DefaultSavePlanner(),
        )

    discriminator_fsdp_state_dir = os.path.join(save_dir,
                                                "discriminator_fsdp_state")
    os.makedirs(discriminator_fsdp_state_dir, exist_ok=True)
    with FSDP.state_dict_type(
            discriminator,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        optim_state = FSDP.optim_state_dict(discriminator,
                                            discriminator_optimizer)
        model_state = discriminator.state_dict()
        state_dict = {"optimizer": optim_state, "model": model_state}
        if rank <= 0:
            discriminator_fsdp_state_file = os.path.join(
                discriminator_fsdp_state_dir, "discriminator_state.pt")
            torch.save(state_dict, discriminator_fsdp_state_file)
    main_print("--> saved FSDP state checkpoint")


def load_sharded_model(model, optimizer, model_dir, optimizer_dir):
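    """Load sharded model and optimizer state saved with dist_cp.

    Must be called collectively on all ranks; `model_dir` and
    `optimizer_dir` are the directories written by
    `save_checkpoint_generator_discriminator`.
    """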
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        weight_state_dict = {"model": model.state_dict()}

        optim_state = load_sharded_optimizer_state_dict(
            model_state_dict=weight_state_dict["model"],
            optimizer_key="optimizer",
            storage_reader=dist_cp.FileSystemReader(optimizer_dir),
        )
        optim_state = optim_state["optimizer"]
        flattened_osd = FSDP.optim_state_dict_to_load(
            model=model, optim=optimizer, optim_state_dict=optim_state)
        optimizer.load_state_dict(flattened_osd)

        dist_cp.load_state_dict(
            state_dict=weight_state_dict,
            storage_reader=dist_cp.FileSystemReader(model_dir),
            planner=DefaultLoadPlanner(),
        )
        model_state = weight_state_dict["model"]
        model.load_state_dict(model_state)
    main_print(f"--> loaded model and optimizer from path {model_dir}")
    return model, optimizer


def load_full_state_model(model, optimizer, checkpoint_file, rank):
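    """Load a full discriminator state dict saved by rank 0.

    Only rank 0 passes the optimizer state to
    `FSDP.optim_state_dict_to_load`, which then redistributes it across
    ranks (matching the rank0_only=True save path).
    """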
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        discriminator_state = torch.load(checkpoint_file, weights_only=False)
        model_state = discriminator_state["model"]
        if rank <= 0:
            optim_state = discriminator_state["optimizer"]
        else:
            optim_state = None
        model.load_state_dict(model_state)
        discriminator_optim_state = FSDP.optim_state_dict_to_load(
            model=model, optim=optimizer, optim_state_dict=optim_state)
        optimizer.load_state_dict(discriminator_optim_state)
    main_print(
        f"--> loaded discriminator and discriminator optimizer from path {checkpoint_file}"
    )
    return model, optimizer


def resume_training_generator_discriminator(model, optimizer, discriminator,
                                            discriminator_optimizer,
                                            checkpoint_dir, rank):
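    """Resume generator and discriminator from a combined checkpoint.

    `checkpoint_dir` is expected to be named `checkpoint-{step}`; the step
    is parsed from the trailing component of the directory name.
    """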
    step = int(checkpoint_dir.split("-")[-1])
    model_weight_dir = os.path.join(checkpoint_dir, "model_weights_state")
    model_optimizer_dir = os.path.join(checkpoint_dir, "model_optimizer_state")
    model, optimizer = load_sharded_model(model, optimizer, model_weight_dir,
                                          model_optimizer_dir)
    discriminator_ckpt_file = os.path.join(checkpoint_dir,
                                           "discriminator_fsdp_state",
                                           "discriminator_state.pt")
    discriminator, discriminator_optimizer = load_full_state_model(
        discriminator, discriminator_optimizer, discriminator_ckpt_file, rank)
    return model, optimizer, discriminator, discriminator_optimizer, step


def resume_training(model, optimizer, checkpoint_dir, discriminator=False):
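    """Resume from a full-state checkpoint saved by save_checkpoint_optimizer.

    Loads safetensors weights on top of the current state dict (so keys
    missing from the checkpoint keep their current values), restores the
    optimizer state, and returns the step parsed from the directory name.
    """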
    weight_path = os.path.join(checkpoint_dir,
                               "diffusion_pytorch_model.safetensors")
    if discriminator:
        weight_path = os.path.join(checkpoint_dir,
                                   "discriminator_pytorch_model.safetensors")
    model_weights = load_file(weight_path)

    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        current_state = model.state_dict()
        current_state.update(model_weights)
        model.load_state_dict(current_state, strict=False)
    if discriminator:
        optim_path = os.path.join(checkpoint_dir, "discriminator_optimizer.pt")
    else:
        optim_path = os.path.join(checkpoint_dir, "optimizer.pt")
    optimizer_state_dict = torch.load(optim_path, weights_only=False)
    optim_state = FSDP.optim_state_dict_to_load(
        model=model, optim=optimizer, optim_state_dict=optimizer_state_dict)
    optimizer.load_state_dict(optim_state)
    step = int(checkpoint_dir.split("-")[-1])
    return model, optimizer, step


def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step,
                         pipeline, epoch):
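    """Save LoRA adapter weights, optimizer state, and LoRA config.

    Adapter weights are extracted from the full FSDP state dict with
    `get_peft_model_state_dict` and written through
    `pipeline.save_lora_weights` on rank 0 only.
    """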
    with FSDP.state_dict_type(
            transformer,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        full_state_dict = transformer.state_dict()
        lora_optim_state = FSDP.optim_state_dict(
            transformer,
            optimizer,
        )

    if rank <= 0:
        main_print(f"--> saving LoRA checkpoint at step {step}")
        save_dir = os.path.join(output_dir, f"lora-checkpoint-{step}-{epoch}")
        os.makedirs(save_dir, exist_ok=True)

        optim_path = os.path.join(save_dir, "lora_optimizer.pt")
        torch.save(lora_optim_state, optim_path)

        transformer_lora_layers = get_peft_model_state_dict(
            model=transformer, state_dict=full_state_dict)
        pipeline.save_lora_weights(
            save_directory=save_dir,
            transformer_lora_layers=transformer_lora_layers,
            is_main_process=True,
        )

        lora_config = {
            "step": step,
            "lora_params": {
                "lora_rank": transformer.config.lora_rank,
                "lora_alpha": transformer.config.lora_alpha,
                "target_modules": transformer.config.lora_target_modules,
            },
        }
        config_path = os.path.join(save_dir, "lora_config.json")
        with open(config_path, "w") as f:
            json.dump(lora_config, f, indent=4)
        main_print(f"--> LoRA checkpoint saved at step {step}")


def resume_lora_optimizer(transformer, checkpoint_dir, optimizer):
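    """Restore the LoRA optimizer state and training step.

    Reads `lora_config.json` for the step and `lora_optimizer.pt` for the
    optimizer state; the adapter weights themselves are loaded elsewhere.
    """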
    config_path = os.path.join(checkpoint_dir, "lora_config.json")
    with open(config_path, "r") as f:
        config_dict = json.load(f)
    optim_path = os.path.join(checkpoint_dir, "lora_optimizer.pt")
    optimizer_state_dict = torch.load(optim_path, weights_only=False)
    optim_state = FSDP.optim_state_dict_to_load(
        model=transformer,
        optim=optimizer,
        optim_state_dict=optimizer_state_dict)
    optimizer.load_state_dict(optim_state)
    step = config_dict["step"]
    main_print(f"--> Successfully resumed LoRA optimizer from step {step}")
    return transformer, optimizer, step