Text-to-Image
Diffusers
Safetensors
recoilme commited on
Commit
8a0d7b4
·
1 Parent(s): d6894ad
samples/unet_384x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 63379e019e6e33e6d5235d44ea208f91f6658476136c4c7343b45156f085f8b3
  • Pointer size: 131 Bytes
  • Size of remote file: 127 kB

Git LFS Details

  • SHA256: caae396c9bc1ec3af1432f6d334a3d2663d54f45d1e8e63eae1031286a5edc05
  • Pointer size: 131 Bytes
  • Size of remote file: 469 kB
samples/unet_416x704_0.jpg CHANGED

Git LFS Details

  • SHA256: a8229f1817d97e30f219542c43b2adc432bba04dcc11d8b9f7ea8f30c9b70577
  • Pointer size: 131 Bytes
  • Size of remote file: 296 kB

Git LFS Details

  • SHA256: 504636ccecf603abf95d348c1f7416e5fe9483263209964b1fa77946187afa2f
  • Pointer size: 131 Bytes
  • Size of remote file: 298 kB
samples/unet_448x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 354a6d5120b9bb44dd4cca836ca19c517b47629a6c16cb556964660b824f4cdf
  • Pointer size: 131 Bytes
  • Size of remote file: 295 kB

Git LFS Details

  • SHA256: 3c691e6b385a903f0306223e77aa03728c9f18865ec000c4ff492f434a0101ff
  • Pointer size: 131 Bytes
  • Size of remote file: 321 kB
samples/unet_480x704_0.jpg CHANGED

Git LFS Details

  • SHA256: c7b75eaeaecbb71870f92892350b665ab3276249ed5e9d5c4bb2e39283ab2696
  • Pointer size: 131 Bytes
  • Size of remote file: 487 kB

Git LFS Details

  • SHA256: ffd42aae34d3b3fbe6b53978de3594d356636295c1aa34ba1fd715d6ae9b450d
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
samples/unet_512x704_0.jpg CHANGED

Git LFS Details

  • SHA256: f367172867b1112fa95969611a836f926f5d0e8e4f955dadd74fc196c8336862
  • Pointer size: 131 Bytes
  • Size of remote file: 683 kB

Git LFS Details

  • SHA256: 2150be572c33614fb9eda3844191f48a311e87eb30e9b90d20706481e47a8251
  • Pointer size: 131 Bytes
  • Size of remote file: 448 kB
samples/unet_544x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 307e54aa72c63ae06ce79f8ce0fefdd0ba6414920cbdec5463c1dfe521c6de1a
  • Pointer size: 131 Bytes
  • Size of remote file: 676 kB

Git LFS Details

  • SHA256: 3ff5d385c9643d2891f8cc433766476131e10a6618a5249184107a19a41243b8
  • Pointer size: 131 Bytes
  • Size of remote file: 340 kB
samples/unet_576x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 60f72963a286988cc09d582226fc60d94d01af4bc7f9719f3d7b39f426b75aed
  • Pointer size: 131 Bytes
  • Size of remote file: 642 kB

Git LFS Details

  • SHA256: b7ab0172172d9d65fe0622d70f67c0d934899577dffdc8f7d2409057269489c5
  • Pointer size: 131 Bytes
  • Size of remote file: 287 kB
samples/unet_608x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 97605de5e6a092bf6ee180e5ac4c79cf4146dee0f3c967d76265ffdeef5decfd
  • Pointer size: 131 Bytes
  • Size of remote file: 520 kB

Git LFS Details

  • SHA256: 3bfe307b6a0a923eca577889e10c553b22654e8b8ac23133c3962abc705c1a71
  • Pointer size: 131 Bytes
  • Size of remote file: 455 kB
samples/unet_640x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 9db887c6d0c89819192f2645419a057a6a487e166b5ab3ac758209183a4877d9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.26 MB

Git LFS Details

  • SHA256: 1df544279aa31c3478a7b1df1de8857123d863262f76374ddc72a6e0f71fec7f
  • Pointer size: 131 Bytes
  • Size of remote file: 519 kB
samples/unet_672x704_0.jpg CHANGED

Git LFS Details

  • SHA256: d4716ebdfc9e5a007256c3a176f5408e446de42dd599ac5b3bfcefe712a441cb
  • Pointer size: 131 Bytes
  • Size of remote file: 437 kB

Git LFS Details

  • SHA256: a8e80c3d4bee53ed565c194631dec1f17168ab9a489ff1eabd8a4168bfea54e8
  • Pointer size: 131 Bytes
  • Size of remote file: 776 kB
samples/unet_704x384_0.jpg CHANGED

Git LFS Details

  • SHA256: cfd89a1413d58ac37f454732f9958a7486b60fc1d38946752ac6390540910eaa
  • Pointer size: 131 Bytes
  • Size of remote file: 283 kB

Git LFS Details

  • SHA256: e1aa6f9430f01761fb5b3b89394e45687416f84a2a1539c02312ee7d3abd74f0
  • Pointer size: 131 Bytes
  • Size of remote file: 400 kB
samples/unet_704x416_0.jpg CHANGED

Git LFS Details

  • SHA256: 9bc7bc3a45da931477ea5e8264604a8d3a15e73b26f697dc3815e3bd808b7871
  • Pointer size: 131 Bytes
  • Size of remote file: 440 kB

Git LFS Details

  • SHA256: 7e0afebbebd313860e8ba231ae7767b0808b6c9ba1e6abb0b00362e76a41bc18
  • Pointer size: 131 Bytes
  • Size of remote file: 293 kB
samples/unet_704x448_0.jpg CHANGED

Git LFS Details

  • SHA256: 4e2c3c50ddbb9af350628f7734063bf0b980af1cac01946ea85cc28ea7202521
  • Pointer size: 131 Bytes
  • Size of remote file: 467 kB

Git LFS Details

  • SHA256: 9ded5a5e1517a21d7074177dbb8900824d63df67c6c3bc693f09726bb42b4e76
  • Pointer size: 131 Bytes
  • Size of remote file: 405 kB
samples/unet_704x480_0.jpg CHANGED

Git LFS Details

  • SHA256: 94d995dfd430bb8da79168361b9b7939ed82bfa3f82a065d00d67182fd5176ca
  • Pointer size: 131 Bytes
  • Size of remote file: 761 kB

Git LFS Details

  • SHA256: aa1d06bbfd20235456118f9430fee152baccd67f883e9188beabbd0a334409a6
  • Pointer size: 131 Bytes
  • Size of remote file: 292 kB
samples/unet_704x512_0.jpg CHANGED

Git LFS Details

  • SHA256: b53f44cf44b5cc1e5368d6ec6625b802ca985c8bd1073a9c2253c0635efce308
  • Pointer size: 131 Bytes
  • Size of remote file: 270 kB

Git LFS Details

  • SHA256: 6effaa1686df9131e6f0e6a2ee3ad45861dadde49c27e36aae99ee5a080f2772
  • Pointer size: 131 Bytes
  • Size of remote file: 273 kB
samples/unet_704x544_0.jpg CHANGED

Git LFS Details

  • SHA256: c95803edbadbc5982021746331877a43797f3162e5fb3fe9a9bcbe20f4b7aec5
  • Pointer size: 131 Bytes
  • Size of remote file: 569 kB

Git LFS Details

  • SHA256: 3e2b29989618f28c5f5c7e7a7fb76dfa78bb366f5c4f3951aa0e1dbadfc90dcd
  • Pointer size: 131 Bytes
  • Size of remote file: 219 kB
samples/unet_704x576_0.jpg CHANGED

Git LFS Details

  • SHA256: 672731f5d046e92d9450ef8f83a14b8440590d22dce079d31fa58cd4a646ea23
  • Pointer size: 131 Bytes
  • Size of remote file: 502 kB

Git LFS Details

  • SHA256: bd6ee1d03971d835fdefd5325b2e531797038de037ab3ce461c6ec41e6bef628
  • Pointer size: 131 Bytes
  • Size of remote file: 389 kB
samples/unet_704x608_0.jpg CHANGED

Git LFS Details

  • SHA256: 59fcb66948e50a0ae7d2c044b3f2fc5bff869318a9bce8ba2d4a6dafe05e6c62
  • Pointer size: 131 Bytes
  • Size of remote file: 585 kB

Git LFS Details

  • SHA256: e01b2d2d83b0e85a422e80115171c52772887f154778597d110b21b9864ad2c4
  • Pointer size: 131 Bytes
  • Size of remote file: 447 kB
samples/unet_704x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 50f580ca216fefdf2cdd946b5cced1b9623124b24cd57a7fe2259bfb03083321
  • Pointer size: 131 Bytes
  • Size of remote file: 638 kB

Git LFS Details

  • SHA256: 502e1a49ae41f42c56c1fa9dae9fe319f61135e0f03ae183a3feba99325321f1
  • Pointer size: 131 Bytes
  • Size of remote file: 560 kB
samples/unet_704x672_0.jpg CHANGED

Git LFS Details

  • SHA256: 0a3b85332ffc913914dfe61e29e64f9e28a055c31811df61bbe5fbae43c2fb3d
  • Pointer size: 131 Bytes
  • Size of remote file: 391 kB

Git LFS Details

  • SHA256: 8f8cdca1bffc596700e70da0bf0a73aba49198337eee07cffbcca8212a9b8a9b
  • Pointer size: 131 Bytes
  • Size of remote file: 195 kB
samples/unet_704x704_0.jpg CHANGED

Git LFS Details

  • SHA256: e40ed73cdb6af4f7d2ada2f5781756044595df42db218deba35e20536019258e
  • Pointer size: 131 Bytes
  • Size of remote file: 938 kB

Git LFS Details

  • SHA256: 4bc2c41e94a651395b131b65b58e62c99ca0997fed4a9a378ee1f8baf0d9842a
  • Pointer size: 131 Bytes
  • Size of remote file: 592 kB
tea_debug.log ADDED
File without changes
train.old.py ADDED
@@ -0,0 +1,837 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import torch
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import wandb,comet_ml
7
+ import random,time
8
+ import gc
9
+ import bitsandbytes as bnb
10
+ import torch.nn.functional as F
11
+ import argparse
12
+
13
+ from datetime import datetime
14
+ from diffusers import UNet2DConditionModel, AsymmetricAutoencoderKL, FlowMatchEulerDiscreteScheduler
15
+ from transformers import Qwen3_5Tokenizer, Qwen3_5ForConditionalGeneration
16
+ from torch.utils.data import DataLoader, Sampler
17
+ from torch.optim.lr_scheduler import LambdaLR
18
+ from collections import defaultdict
19
+ from accelerate import Accelerator
20
+ from datasets import load_from_disk
21
+ from tqdm import tqdm
22
+ from PIL import Image, ImageOps
23
+ from torch.utils.checkpoint import checkpoint
24
+ from diffusers.models.attention_processor import AttnProcessor2_0
25
+ from contextlib import nullcontext
26
+
27
+ # Muon not tested! pip install git+https://github.com/recoilme/muon_adamw8bit.git
28
+ from muon_adamw8bit import MuonAdamW8bit
29
+
30
+ os.environ["NCCL_P2P_DISABLE"] = "1"
31
+ os.environ["NCCL_IB_DISABLE"] = "1" # comment this on H100!
32
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
33
+
34
+ # --------------------------- Параметры ---------------------------
35
+ ds_path = "datasets/ds1234_noanime_704_vae8x16x"
36
+ project = "unet"
37
+
38
+ gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
39
+ local_bs = max(1, int((gpu_mem_gb / 32) * 7))
40
+ num_gpus = torch.cuda.device_count()
41
+ batch_size = local_bs * num_gpus
42
+
43
+ base_learning_rate = 4e-5
44
+ min_learning_rate = 4e-6
45
+
46
+ # 0.5 - pretrain (base forms)
47
+ # 1 - base train (composition)
48
+ # 3 - finetuning (anatomy)
49
+ # 5 - small details (faces)
50
+ learning_rate_scale = 2
51
+ base_learning_rate = base_learning_rate / learning_rate_scale
52
+ min_learning_rate = min_learning_rate / learning_rate_scale
53
+ print(f"Calculated params max-lr:{base_learning_rate} min-lr:{min_learning_rate} GPUs: {num_gpus}, Global BS: {batch_size}")
54
+
55
+ num_epochs = num_gpus
56
+ sink_interval_share = 20
57
+ cfg_dropout = 0.10
58
+ max_length = 248
59
+ use_precomputed_embeddings = False
60
+ use_wandb = True
61
+ use_comet_ml = False
62
+ save_model = True
63
+ use_decay = True
64
+ fbp = False
65
+ torch_compile = False
66
+ unet_gradient = True
67
+ loss_normalize = False
68
+ fixed_seed = False
69
+ shuffle = True
70
+ optimizer_type = "adam8bit"
71
+ if optimizer_type == "muon_adam8bit":
72
+ batch_size = num_gpus * max(1, int((gpu_mem_gb / 32) * 3))
73
+ muon_lr_scale = 500
74
+ comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r" # hardcoded for blind run, i don't care about key
75
+ comet_ml_workspace = "recoilme"
76
+ torch.backends.cuda.matmul.allow_tf32 = True
77
+ torch.backends.cudnn.allow_tf32 = True
78
+ # MAX_JOBS=4 pip install flash-attn --no-build-isolation
79
+ torch.backends.cuda.enable_flash_sdp(True)
80
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
81
+ torch.backends.cuda.enable_math_sdp(False) # Отключаем медленный вариант
82
+ save_barrier = 1.25
83
+ warmup_percent = 0.0025
84
+ betta2 = 0.997
85
+ eps = 1e-7
86
+ clip_grad_norm = 1.0
87
+ limit = 0
88
+ checkpoints_folder = ""
89
+ gradient_accumulation_steps = 1
90
+ dtype = torch.float32
91
+ mixed_precision = "no"
92
+
93
+ # Параметры для диффузии
94
+ n_diffusion_steps = 40
95
+ samples_to_generate = 12
96
+ guidance_scale = 4
97
+
98
+ # Папки для сохранения результатов
99
+ generated_folder = "samples"
100
+ os.makedirs(generated_folder, exist_ok=True)
101
+
102
+ # Настройка seed
103
+ current_date = datetime.now()
104
+ seed = int(current_date.strftime("%Y%m%d")) + 42
105
+ if fixed_seed:
106
+ torch.manual_seed(seed)
107
+ np.random.seed(seed)
108
+ random.seed(seed)
109
+ if torch.cuda.is_available():
110
+ torch.cuda.manual_seed_all(seed)
111
+
112
+ accelerator = Accelerator(
113
+ mixed_precision=mixed_precision,
114
+ gradient_accumulation_steps=gradient_accumulation_steps
115
+ )
116
+ device = accelerator.device
117
+
118
+ print("init")
119
+ # Создаём объект ArgumentParser с рассчитанными значениями по умолчанию
120
+ parser = argparse.ArgumentParser(description='Train a model on a dataset.')
121
+ parser.add_argument('--ds-path', type=str, default=ds_path, help='Path to the dataset')
122
+ parser.add_argument('--ep', type=int, default=num_epochs, help='Number of epochs to train the model')
123
+ parser.add_argument('--batch', type=int, default=batch_size, help='Total batch size')
124
+ parser.add_argument('--min-lr', type=float, default=min_learning_rate, help='Minimum learning rate')
125
+ parser.add_argument('--max-lr', type=float, default=base_learning_rate, help='Maximum learning rate')
126
+ parser.add_argument('--dry-run', action='store_true',default=False, help='Dry run train without saving/sampling')
127
+ parser.add_argument('--lvl', type=float, default=0.0, help='Train level, from 0.5 to 5')
128
+
129
+ # Парсим аргументы командной строки
130
+ args = parser.parse_args()
131
+
132
+ # Используем значения из аргументов
133
+ batch_size = args.batch
134
+ ds_path = args.ds_path
135
+ base_learning_rate = args.max_lr
136
+ min_learning_rate = args.min_lr
137
+ num_epochs = args.ep
138
+ lvl = args.lvl
139
+ if args.dry_run:
140
+ save_model = False
141
+ if lvl >= 0.1:
142
+ base_learning_rate = base_learning_rate / lvl
143
+ min_learning_rate = min_learning_rate / lvl
144
+ print(f"max-lr:{base_learning_rate} min-lr:{min_learning_rate}")
145
+
146
+
147
+ # --------------------------- Инициализация WandB ---------------------------
148
+ if accelerator.is_main_process:
149
+ if use_wandb:
150
+ wandb.init(project=project, config={
151
+ "batch_size": batch_size,
152
+ "base_learning_rate": base_learning_rate,
153
+ "num_epochs": num_epochs,
154
+ "optimizer_type": optimizer_type,
155
+ })
156
+ if use_comet_ml:
157
+ from comet_ml import Experiment
158
+ comet_experiment = Experiment(
159
+ api_key=comet_ml_api_key,
160
+ project_name=project,
161
+ workspace=comet_ml_workspace
162
+ )
163
+ hyper_params = {
164
+ "batch_size": batch_size,
165
+ "base_learning_rate": base_learning_rate,
166
+ "num_epochs": num_epochs,
167
+ }
168
+ comet_experiment.log_parameters(hyper_params)
169
+
170
+ # --------------------------- Загрузка моделей ---------------------------
171
+ vae = AsymmetricAutoencoderKL.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
172
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")
173
+ tokenizer = None
174
+ text_encoder = None
175
+
176
+ def load_text_encoder():
177
+ global tokenizer, text_encoder
178
+ if tokenizer is None:
179
+ tokenizer = Qwen3_5Tokenizer.from_pretrained("tokenizer")
180
+ if text_encoder is None:
181
+ text_encoder = Qwen3_5ForConditionalGeneration.from_pretrained(
182
+ "text_encoder",
183
+ torch_dtype=torch.float16
184
+ ).to(device).eval()
185
+
186
+ load_text_encoder()
187
+
188
+ @torch.no_grad()
189
+ def encode_texts(text, max_length=max_length):
190
+ if text is None: text = ""
191
+ if isinstance(text, str): text = [text]
192
+
193
+ formatted_prompts = []
194
+ for t in text:
195
+ messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
196
+ formatted_prompts.append(tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False))
197
+
198
+ toks = tokenizer(formatted_prompts, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt").to(device)
199
+ outputs = text_encoder(input_ids=toks.input_ids, attention_mask=toks.attention_mask, output_hidden_states=True)
200
+
201
+ last_hidden = outputs.hidden_states[-2]
202
+
203
+ return last_hidden.to(dtype=dtype), toks.attention_mask.to(dtype=torch.int64)
204
+
205
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
206
+ if shift_factor is None:
207
+ shift_factor = 0.0
208
+
209
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
210
+ if scaling_factor is None:
211
+ scaling_factor = 1.0
212
+
213
+ mean = getattr(vae.config, "latents_mean", None)
214
+ std = getattr(vae.config, "latents_std", None)
215
+ if mean is not None and std is not None:
216
+ latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)
217
+ latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)
218
+
219
+ import numpy as np
220
+ from torch.utils.data import Sampler
221
+
222
+
223
+ class DistributedResolutionBatchSampler(Sampler):
224
+ def __init__(self, dataset, batch_size, num_replicas, rank, drop_last=True, shuffle=True):
225
+ self.dataset = dataset
226
+ self.num_replicas = num_replicas
227
+ self.rank = rank
228
+ self.shuffle = shuffle
229
+ self.drop_last = drop_last
230
+ self.epoch = 0
231
+
232
+ # batch на одну GPU
233
+ self.batch_size = max(1, batch_size // num_replicas)
234
+ self.global_batch = self.batch_size * num_replicas
235
+
236
+ try:
237
+ widths = np.asarray(dataset["width"])
238
+ heights = np.asarray(dataset["height"])
239
+ except KeyError:
240
+ widths = np.zeros(len(dataset))
241
+ heights = np.zeros(len(dataset))
242
+
243
+ # --- группировка индексов ---
244
+ groups = {}
245
+ for i, (w, h) in enumerate(zip(widths, heights)):
246
+ groups.setdefault((w, h), []).append(i)
247
+
248
+ # --- создаём список всех глобальных батчей ---
249
+ all_batches = []
250
+
251
+ for indices in groups.values():
252
+
253
+ idx = np.asarray(indices, dtype=np.int64)
254
+
255
+ num_batches = len(idx) // self.global_batch
256
+ if num_batches == 0:
257
+ continue
258
+
259
+ idx = idx[: num_batches * self.global_batch]
260
+
261
+ batches = idx.reshape(num_batches, self.global_batch)
262
+
263
+ all_batches.append(batches)
264
+
265
+ if len(all_batches) > 0:
266
+ self.global_batches = np.concatenate(all_batches, axis=0)
267
+ else:
268
+ self.global_batches = np.empty((0, self.global_batch), dtype=np.int64)
269
+
270
+ self.num_batches = len(self.global_batches)
271
+
272
+ def __iter__(self):
273
+
274
+ rng = np.random.RandomState(self.epoch)
275
+
276
+ order = np.arange(self.num_batches)
277
+
278
+ if self.shuffle:
279
+ rng.shuffle(order)
280
+
281
+ start = self.rank * self.batch_size
282
+ end = start + self.batch_size
283
+
284
+ for i in order:
285
+ yield self.global_batches[i][start:end]
286
+
287
+ def __len__(self):
288
+ return self.num_batches
289
+
290
+ def set_epoch(self, epoch):
291
+ self.epoch = epoch
292
+
293
+
294
+
295
+ # --- [UPDATED] Функция для фиксированных семплов ---
296
+ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
297
+ size_groups = defaultdict(list)
298
+ try:
299
+ widths = dataset["width"]
300
+ heights = dataset["height"]
301
+ except KeyError:
302
+ widths = [0] * len(dataset)
303
+ heights = [0] * len(dataset)
304
+ for i, (w, h) in enumerate(zip(widths, heights)):
305
+ size = (w, h)
306
+ size_groups[size].append(i)
307
+
308
+ fixed_samples = {}
309
+ for size, indices in size_groups.items():
310
+ n_samples = min(samples_per_group, len(indices))
311
+ if len(size_groups)==1:
312
+ n_samples = samples_to_generate
313
+ if n_samples == 0:
314
+ continue
315
+ sample_indices = random.sample(indices, n_samples)
316
+ samples_data = [dataset[idx] for idx in sample_indices]
317
+
318
+ latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
319
+ texts = [item["text"] for item in samples_data]
320
+
321
+ # Кодируем тексты на лету, чтобы получить маски и пулинг
322
+ #embeddings, masks = encode_texts(texts)
323
+ if use_precomputed_embeddings:
324
+ embeddings = torch.tensor(
325
+ np.array([item["embeddings"] for item in samples_data]),
326
+ device=device,
327
+ dtype=dtype
328
+ )
329
+ masks = torch.tensor(
330
+ np.array([item["attention_mask"] for item in samples_data]),
331
+ device=device,
332
+ dtype=torch.int64
333
+ )
334
+ else:
335
+ embeddings, masks = encode_texts(texts)
336
+
337
+ fixed_samples[size] = (latents, embeddings, masks, texts)
338
+
339
+ print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
340
+ return fixed_samples
341
+
342
+ if limit > 0:
343
+ dataset = load_from_disk(ds_path).select(range(limit))
344
+ else:
345
+ dataset = load_from_disk(ds_path)
346
+
347
+
348
+ print(f"images: {len(dataset)}")
349
+
350
+ def collate_fn_simple(batch):
351
+
352
+ latents = torch.from_numpy(
353
+ np.array([item["vae"] for item in batch], dtype=np.float16)
354
+ ).to(device, dtype=dtype)
355
+
356
+ if use_precomputed_embeddings:
357
+ embeddings = torch.from_numpy(
358
+ np.array([item["embeddings"] for item in batch], dtype=np.float16)
359
+ ).to(device, dtype=dtype)
360
+
361
+ attention_mask = torch.from_numpy(
362
+ np.array([item["attention_mask"] for item in batch], dtype=np.int64)
363
+ ).to(device)
364
+
365
+ return latents, embeddings, attention_mask
366
+
367
+ raw_texts = [item["text"] for item in batch]
368
+
369
+ texts = [
370
+ "" if t.lower().startswith("zero")
371
+ else "" if random.random() < cfg_dropout
372
+ else t[1:].lstrip() if t.startswith(".")
373
+ else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ","").strip()
374
+ for t in raw_texts
375
+ ]
376
+
377
+ embeddings, attention_mask = encode_texts(texts)
378
+ attention_mask = attention_mask.to(dtype=torch.int64)
379
+
380
+ return latents, embeddings, attention_mask
381
+
382
+ batch_sampler = DistributedResolutionBatchSampler(
383
+ dataset=dataset,
384
+ batch_size=batch_size,
385
+ num_replicas=accelerator.num_processes,
386
+ rank=accelerator.process_index,
387
+ shuffle = shuffle
388
+ )
389
+
390
+ dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
391
+
392
+ if accelerator.is_main_process:
393
+ print("Total samples", len(dataloader))
394
+ dataloader = accelerator.prepare(dataloader)
395
+
396
+ start_epoch = 0
397
+ global_step = 0
398
+ total_training_steps = (len(dataloader) * num_epochs)
399
+ world_size = accelerator.state.num_processes
400
+
401
+ # Загрузка UNet
402
+ latest_checkpoint = os.path.join(checkpoints_folder, project)
403
+ if os.path.isdir(latest_checkpoint):
404
+ print("Загружаем UNet из чекпоинта:", latest_checkpoint)
405
+ unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
406
+ if unet_gradient:
407
+ unet.enable_gradient_checkpointing()
408
+ unet.set_use_memory_efficient_attention_xformers(False)
409
+ try:
410
+ unet.set_attn_processor(AttnProcessor2_0())
411
+ except Exception as e:
412
+ print(f"Ошибка при включении SDPA: {e}")
413
+ unet.set_use_memory_efficient_attention_xformers(True)
414
+ else:
415
+ raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")
416
+
417
+ def create_optimizer(name, params):
418
+ if name == "adam8bit":
419
+ return bnb.optim.AdamW8bit(
420
+ params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01
421
+ )
422
+ elif name == "adam":
423
+ return torch.optim.AdamW(
424
+ params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01
425
+ )
426
+ elif name == "muon_adam8bit":
427
+ return MuonAdamW8bit(
428
+ params,
429
+ lr=base_learning_rate,
430
+ betas=(0.9, betta2),
431
+ eps=eps,
432
+ weight_decay=0.01,
433
+ muon_lr_mult=muon_lr_scale,
434
+ )
435
+ else:
436
+ raise ValueError(f"Unknown optimizer: {name}")
437
+
438
+ if fbp:
439
+ trainable_params = list(unet.parameters())
440
+ optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
441
+ def optimizer_hook(param):
442
+ optimizer_dict[param].step()
443
+ optimizer_dict[param].zero_grad(set_to_none=True)
444
+ for param in trainable_params:
445
+ param.register_post_accumulate_grad_hook(optimizer_hook)
446
+ unet, optimizer = accelerator.prepare(unet, optimizer_dict)
447
+ else:
448
+ unet.requires_grad_(True)
449
+ optimizer = create_optimizer(optimizer_type, unet.parameters())
450
+ # 1. Сначала замораживаем ВСЕ параметры UNet
451
+ #unet.requires_grad_(False)
452
+
453
+ # 2. Размораживаем только нужные
454
+ #trainable_params_names = ["conv_in.weight", "conv_in.bias", "conv_out.weight", "conv_out.bias"]
455
+ #train_params = []
456
+
457
+ #for name, param in unet.named_parameters():
458
+ # if any(target in name for target in trainable_params_names):
459
+ # param.requires_grad = True
460
+ # train_params.append(param)
461
+ # print(f"Обучаемый слой: {name}")
462
+
463
+ def lr_schedule(step):
464
+ x = step / (total_training_steps * world_size)
465
+ warmup = warmup_percent
466
+ if not use_decay:
467
+ return base_learning_rate
468
+ if x < warmup:
469
+ return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
470
+ decay_ratio = (x - warmup) / (1 - warmup)
471
+ return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
472
+ (1 + math.cos(math.pi * decay_ratio))
473
+ lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
474
+
475
+ if torch_compile:
476
+ print("Compiling UNet... Это займет несколько минут, не прерывайте!")
477
+ unet = torch.compile(unet)
478
+ print("Compiling - ok")
479
+
480
+ if not fbp:
481
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
482
+
483
+ # Фиксированные семплы
484
+ fixed_samples = get_fixed_samples_by_resolution(dataset)
485
+
486
+ # --- [UPDATED] Функция для негативного эмбеддинга (возвращает 3 элемента) ---
487
+ def get_negative_embedding(neg_prompt="", batch_size=1):
488
+ if not neg_prompt:
489
+ hidden_dim = 2048
490
+ seq_len = max_length
491
+ empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
492
+ empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
493
+ return empty_emb, empty_mask
494
+
495
+ uncond_emb, uncond_mask = encode_texts([neg_prompt])
496
+ uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
497
+ uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)
498
+
499
+ return uncond_emb, uncond_mask
500
+
501
+ # Получаем негативные (пустые) условия для валидации
502
+ if use_precomputed_embeddings:
503
+ # 1. грузим encoder ВРЕМЕННО
504
+ load_text_encoder()
505
+
506
+ # 2. считаем negative
507
+ uncond_emb, uncond_mask = get_negative_embedding("low quality")
508
+
509
+ # 3. уносим на CPU (очень важно)
510
+ uncond_emb = uncond_emb.to("cpu")
511
+ uncond_mask = uncond_mask.to("cpu")
512
+
513
+ # 4. выгружаем encoder с GPU
514
+ del text_encoder
515
+ torch.cuda.empty_cache()
516
+ gc.collect()
517
+
518
+ text_encoder = None
519
+
520
+ else:
521
+ uncond_emb, uncond_mask = get_negative_embedding("low quality")
522
+
523
+ # --- Функция генерации семплов ---
524
+ @torch.compiler.disable()
525
+ @torch.no_grad()
526
+ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
527
+ uncond_emb, uncond_mask = uncond_data
528
+ uncond_emb = uncond_emb.to(device)
529
+ uncond_mask = uncond_mask.to(device)
530
+
531
+ original_model = None
532
+ try:
533
+ if not torch_compile:
534
+ original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
535
+ else:
536
+ original_model = unet.eval()
537
+
538
+ vae.to(device=device).eval()
539
+
540
+ all_generated_images = []
541
+ all_captions = []
542
+
543
+ # Распаковываем 5 элементов (добавились mask)
544
+ for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
545
+ width, height = size
546
+ sample_latents = sample_latents.to(dtype=dtype, device=device)
547
+ sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
548
+ sample_mask = sample_mask.to(device=device)
549
+
550
+ latents = torch.randn(
551
+ sample_latents.shape,
552
+ device=device,
553
+ dtype=sample_latents.dtype,
554
+ generator=torch.Generator(device=device).manual_seed(seed)
555
+ )
556
+
557
+ scheduler.set_timesteps(n_diffusion_steps, device=device)
558
+
559
+ for t in scheduler.timesteps:
560
+ if guidance_scale != 1:
561
+ latent_model_input = torch.cat([latents, latents], dim=0)
562
+
563
+ curr_batch_size = sample_text_embeddings.shape[0]
564
+ seq_len = sample_text_embeddings.shape[1]
565
+ hidden_dim = sample_text_embeddings.shape[2]
566
+
567
+ neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
568
+ text_embeddings_batch = torch.cat([neg_emb_batch, sample_text_embeddings], dim=0)
569
+
570
+ neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
571
+ attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)
572
+
573
+ else:
574
+ latent_model_input = latents
575
+ text_embeddings_batch = sample_text_embeddings
576
+ attention_mask_batch = sample_mask
577
+
578
+ # Теперь всё имеет одинаковый batch size
579
+ model_out = original_model(
580
+ latent_model_input,
581
+ t,
582
+ encoder_hidden_states=text_embeddings_batch,
583
+ encoder_attention_mask=attention_mask_batch,
584
+ )
585
+
586
+ flow = getattr(model_out, "sample", model_out)
587
+
588
+ if guidance_scale != 1:
589
+ flow_uncond, flow_cond = flow.chunk(2)
590
+ flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
591
+
592
+ latents = scheduler.step(flow, t, latents).prev_sample
593
+
594
+ current_latents = latents
595
+ if step==0:
596
+ current_latents = sample_latents
597
+
598
+ if latents_mean is not None and latents_std is not None:
599
+ latents = current_latents * latents_std + latents_mean
600
+
601
+ decoded = vae.decode(latents.to(torch.float32)).sample
602
+ decoded_fp32 = decoded.to(torch.float32)
603
+
604
+ for img_idx, img_tensor in enumerate(decoded_fp32):
605
+ img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
606
+ img = img.transpose(1, 2, 0)
607
+
608
+ if np.isnan(img).any():
609
+ print("NaNs found, saving stopped! Step:", step)
610
+ pil_img = Image.fromarray((img * 255).astype("uint8"))
611
+
612
+ max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
613
+ max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
614
+ max_w_overall = max(255, max_w_overall)
615
+ max_h_overall = max(255, max_h_overall)
616
+
617
+ padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
618
+ all_generated_images.append(padded_img)
619
+
620
+ caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
621
+ all_captions.append(caption_text)
622
+
623
+ sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
624
+ pil_img.save(sample_path, "JPEG", quality=95)
625
+
626
+ if use_wandb and accelerator.is_main_process:
627
+ wandb_images = [
628
+ wandb.Image(img, caption=f"{all_captions[i]}")
629
+ for i, img in enumerate(all_generated_images)
630
+ ]
631
+ wandb.log({"generated_images": wandb_images})
632
+ if use_comet_ml and accelerator.is_main_process:
633
+ for i, img in enumerate(all_generated_images):
634
+ comet_experiment.log_image(
635
+ image_data=img,
636
+ name=f"step_{step}_img_{i}",
637
+ step=step,
638
+ metadata={"caption": all_captions[i]}
639
+ )
640
+ finally:
641
+ vae.to("cpu")
642
+ uncond_emb = uncond_emb.to("cpu")
643
+ uncond_mask = uncond_mask.to("cpu")
644
+ try:
645
+ all_generated_images.clear()
646
+ all_captions.clear()
647
+ del all_generated_images, all_captions
648
+ del latents, current_latents, latent_model_input, flow
649
+ del decoded, decoded_fp32
650
+ del sample_latents, sample_text_embeddings, sample_mask # Копии на GPU
651
+ del model_out
652
+ except UnboundLocalError:
653
+ pass
654
+
655
+ # 3. Синхронизируем CUDA перед очисткой
656
+ torch.cuda.synchronize()
657
+ # 4. Теперь чистим кэш аллокатора и вызываем GC
658
+ torch.cuda.empty_cache()
659
+ gc.collect()
660
+
661
+ # --------------------------- Генерация сэмплов перед обучением ---------------------------
662
+ if accelerator.is_main_process:
663
+ if save_model:
664
+ print("Генерация сэмплов до старта обучения...")
665
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
666
+ accelerator.wait_for_everyone()
667
+
668
+ def save_checkpoint(unet, variant=""):
669
+ if accelerator.is_main_process:
670
+ model_to_save = None
671
+ if not torch_compile:
672
+ model_to_save = accelerator.unwrap_model(unet)
673
+ else:
674
+ model_to_save = unet
675
+
676
+ if variant != "":
677
+ model_to_save.to(dtype=torch.float16).save_pretrained(
678
+ os.path.join(checkpoints_folder, f"{project}"), variant=variant
679
+ )
680
+ else:
681
+ model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
682
+
683
+ torch.cuda.synchronize()
684
+ torch.cuda.empty_cache()
685
+ gc.collect()
686
+
687
+ # --------------------------- Тренировочный цикл ---------------------------
688
+ if accelerator.is_main_process:
689
+ print(f"Total steps per GPU: {total_training_steps}")
690
+
691
+ epoch_loss_points = []
692
+ progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
693
+
694
+ steps_per_epoch = len(dataloader)
695
+ sink_interval = max(1, steps_per_epoch // sink_interval_share)
696
+ min_loss = 4.
697
+ last_sample_time = time.time()
698
+ sample_interval_seconds = 60 * 60 # 60 минут
699
+
700
+ for epoch in range(start_epoch, start_epoch + num_epochs):
701
+ batch_losses = []
702
+ batch_grads = []
703
+ batch_sampler.set_epoch(epoch)
704
+ accelerator.wait_for_everyone()
705
+ unet.train()
706
+
707
+ for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
708
+
709
+ if save_model == False and epoch == 0 and step == 5 :
710
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
711
+ print(f"Шаг {step}: {used_gb:.2f} GB")
712
+
713
+ amp_context = accelerator.autocast() if torch_compile else nullcontext()
714
+ with accelerator.accumulate(unet):
715
+ with amp_context:
716
+ # шум
717
+ noise = torch.randn_like(latents, dtype=latents.dtype)
718
+
719
+ # 3. Время t, bias = -0.5 (Фокус на Деталях ~300) bias = 0.5 (Фокус на структуре) bias = 0 (колокол/ равномерно)
720
+ bias = 0.25
721
+ t = torch.sigmoid(torch.randn(latents.shape[0], device=latents.device, dtype=latents.dtype) + bias)
722
+
723
+ # интерполяция между x0 и шумом
724
+ noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
725
+ # делаем integer timesteps для UNet
726
+ timesteps = t.to(torch.float32).mul(999.0)
727
+ timesteps = timesteps.clamp(0, scheduler.config.num_train_timesteps - 1)
728
+
729
+ # --- Вызов UNet с маской ---
730
+ model_pred = unet(
731
+ noisy_latents,
732
+ timesteps,
733
+ encoder_hidden_states=embeddings,
734
+ encoder_attention_mask=attention_mask,
735
+ ).sample
736
+
737
+ target = noise - latents
738
+
739
+ mse_loss = F.mse_loss(model_pred.float(), target.float())
740
+ batch_losses.append(mse_loss.detach().item())
741
+
742
+ if (global_step % 100 == 0) or (global_step % sink_interval == 0):
743
+ accelerator.wait_for_everyone()
744
+
745
+ losses_dict = {}
746
+ losses_dict["mse"] = mse_loss
747
+
748
+ if (global_step % 100 == 0) or (global_step % sink_interval == 0):
749
+ accelerator.wait_for_everyone()
750
+
751
+ accelerator.backward(mse_loss)
752
+
753
+ if (global_step % 100 == 0) or (global_step % sink_interval == 0):
754
+ accelerator.wait_for_everyone()
755
+
756
+ grad = 0.0
757
+ if not fbp:
758
+ if accelerator.sync_gradients:
759
+ grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
760
+ grad = grad_val.float().item() if torch.is_tensor(grad_val) else float(grad_val)
761
+ optimizer.step()
762
+ lr_scheduler.step()
763
+ optimizer.zero_grad(set_to_none=True)
764
+
765
+ if accelerator.sync_gradients:
766
+ global_step += 1
767
+ progress_bar.update(1)
768
+ if accelerator.is_main_process:
769
+ if fbp:
770
+ current_lr = base_learning_rate
771
+ else:
772
+ current_lr = lr_scheduler.get_last_lr()[0]
773
+ batch_grads.append(grad)
774
+
775
+ log_data = {}
776
+ log_data["loss_mse"] = mse_loss.detach().item()
777
+ log_data["lr"] = current_lr
778
+ log_data["grad"] = grad
779
+ if accelerator.sync_gradients:
780
+ if use_wandb:
781
+ wandb.log(log_data, step=global_step)
782
+ if use_comet_ml:
783
+ comet_experiment.log_metrics(log_data, step=global_step)
784
+
785
+ current_time = time.time()
786
+ is_time_to_sample = (current_time - last_sample_time) >= sample_interval_seconds
787
+ if is_time_to_sample or global_step == 50:
788
+ # Передаем tuple (emb, mask) для негатива
789
+ if save_model:
790
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
791
+ elif epoch % 10 == 0:
792
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
793
+ last_n = sink_interval
794
+
795
+ if save_model:
796
+ has_losses = len(batch_losses) > 0
797
+ avg_sample_loss = np.mean(batch_losses[-sink_interval:]) if has_losses else 0.0
798
+ last_loss = batch_losses[-1] if has_losses else 0.0
799
+ max_loss = max(avg_sample_loss, last_loss)
800
+ should_save = max_loss < min_loss * save_barrier
801
+ print(
802
+ f"Saving: {should_save} | Max: {max_loss:.4f} | "
803
+ f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
804
+ )
805
+ # 6. Сохранение и обновление
806
+ if should_save:
807
+ min_loss = max_loss
808
+ save_checkpoint(unet)
809
+ last_sample_time = current_time
810
+ unet.train()
811
+
812
+ if accelerator.is_main_process:
813
+ avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
814
+ avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
815
+
816
+ print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
817
+ log_data_ep = {
818
+ "epoch_loss": avg_epoch_loss,
819
+ "epoch_grad": avg_epoch_grad,
820
+ "epoch": epoch + 1,
821
+ }
822
+ if use_wandb:
823
+ wandb.log(log_data_ep)
824
+ if use_comet_ml:
825
+ comet_experiment.log_metrics(log_data_ep)
826
+
827
+ if accelerator.is_main_process:
828
+ print("Обучение завершено! Сохраняем финальную модель...")
829
+ #if save_model:
830
+ save_checkpoint(unet,"fp16")
831
+ if use_comet_ml:
832
+ comet_experiment.end()
833
+ accelerator.free_memory()
834
+ if torch.distributed.is_initialized():
835
+ torch.distributed.destroy_process_group()
836
+
837
+ print("Готово!")
train.py CHANGED
@@ -207,10 +207,17 @@ def encode_texts(text, max_length=max_length):
207
 
208
  toks = tokenizer(
209
  formatted_prompts,
210
- padding=True, # 🔥 динамический padding
 
211
  truncation=True,
212
  return_tensors="pt"
213
  ).to(device)
 
 
 
 
 
 
214
 
215
  outputs = text_encoder(
216
  input_ids=toks.input_ids,
 
207
 
208
  toks = tokenizer(
209
  formatted_prompts,
210
+ padding="max_length",
211
+ max_length=max_length,
212
  truncation=True,
213
  return_tensors="pt"
214
  ).to(device)
215
+ #toks = tokenizer(
216
+ # formatted_prompts,
217
+ # padding=True, # 🔥 динамический padding
218
+ # truncation=True,
219
+ # return_tensors="pt"
220
+ #).to(device)
221
 
222
  outputs = text_encoder(
223
  input_ids=toks.input_ids,
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:28b10c6e834d7e64f1dd96efd40a5fc91a2046e6c0fa5b4cdb3482f2aa3d0e18
3
  size 6420443856
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8eb45e13a22d9a0d00549f0c204606a9bd6f9b439d572207d9dc7b732cfd300d
3
  size 6420443856