Lisandro committed on
Commit
67cee0c
·
1 Parent(s): d8cb9cc

feat: Add Gradio app for image generation with LoRA support and history management

Browse files

- Updated requirements.txt to include necessary packages: torch, gradio, spaces, huggingface_hub, Pillow, and numpy.
- Created a new Gradio app (app copy.py) that integrates image generation using a diffusion pipeline with LoRA configurations.
- Implemented functions to load LoRA configurations from a JSON file and manage generation history.
- Added UI components for selecting LoRAs, adjusting generation parameters, and displaying generated images.
- Introduced functionality to clear and update the generation history.
- Added runtime.txt to specify Python and PyTorch versions for compatibility.

Files changed (6) hide show
  1. .gitattributes +13 -0
  2. .gitignore +7 -0
  3. app copy.py +613 -0
  4. app.py +132 -504
  5. requirements.txt +7 -1
  6. runtime.txt +5 -0
.gitattributes CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+
37
+ # No modificar finales de línea en binarios
38
+ *.safetensors binary
39
+ *.pt binary
40
+ *.pth binary
41
+ *.pt2 binary
42
+ *.zip binary
43
+ aoti_artifacts/** binary
44
+
45
+ # Forzar texto LF en archivos de código
46
+ *.py text eol=lf
47
+ *.json text eol=lf
48
+ *.txt text eol=lf
.gitignore CHANGED
@@ -6,3 +6,10 @@ env/
6
  # Python cache
7
  __pycache__/
8
  *.pyc
 
 
 
 
 
 
 
 
6
  # Python cache
7
  __pycache__/
8
  *.pyc
9
+
10
+ # AOT
11
+ *.pt
12
+ *.pth
13
+ *.safetensors
14
+ *.zip
15
+ aoti_artifacts/
app copy.py ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import json
4
+ import logging
5
+ import torch
6
+ from PIL import Image
7
+ import spaces
8
+ from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
9
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
10
+ import copy
11
+ import random
12
+ import time
13
+ import re
14
+ import math
15
+ import numpy as np
16
+ import traceback
17
+
18
# Load LoRAs from JSON file
def load_loras_from_file():
    """Read the list of LoRA configurations from ``loras.json``.

    Returns:
        list: Parsed LoRA config dicts, or an empty list when the file is
        missing or malformed so the app can still start without presets.
    """
    try:
        handle = open('loras.json', 'r', encoding='utf-8')
    except FileNotFoundError:
        print("Warning: loras.json file not found. Using empty list.")
        return []
    with handle:
        try:
            return json.load(handle)
        except json.JSONDecodeError as e:
            print(f"Error parsing loras.json: {e}")
            return []
30
+
31
# Load the LoRAs
loras = load_loras_from_file()

# Initialize the base model
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "Qwen/Qwen-Image"

# Scheduler configuration from the Qwen-Image-Lightning repository
# (flow-matching Euler scheduler with exponential dynamic time shifting).
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": math.log(3),
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": math.log(3),
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}

scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
# Built once at import time; generate_image moves it to "cuda" again before
# each run (a no-op when it is already there).
pipe = DiffusionPipeline.from_pretrained(
    base_model, scheduler=scheduler, torch_dtype=dtype
).to(device)

# Lightning LoRA info (no global state): repo plus one weight file per
# accelerated mode offered in the UI.
LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors"
LIGHTNING8_LORA_WEIGHT = "Qwen-Image-Lightning-8steps-V2.0-bf16.safetensors"
LIGHTNING_FP8_4STEPS_LORA_WEIGHT = "Qwen-Image-fp8-e4m3fn-Lightning-4steps-V1.0-bf16.safetensors"


# Upper bound for the seed slider / randomized seeds.
MAX_SEED = np.iinfo(np.int32).max
70
+
71
### MODIFICATION 1: ADD FUNCTIONS TO MANAGE THE GENERATION HISTORY ###
def update_history(new_images, history):
    """Prepend freshly generated images to the history gallery value.

    Gradio passes the history gallery's current value as a list (or None on
    first use). The history is capped at 24 entries to bound memory usage.
    """
    current = history if history is not None else []
    if not new_images:
        return current
    combined = list(new_images) + list(current)
    return combined[:24]
83
+
84
def clear_history():
    """Wipe the history gallery by returning a fresh empty list."""
    return []
### END OF MODIFICATION 1 ###
88
+
89
+
90
class calculateDuration:
    """Context manager that times a block and prints the elapsed seconds."""

    def __init__(self, activity_name=""):
        # Optional label included in the printed report.
        self.activity_name = activity_name

    def __enter__(self):
        # Record the wall-clock start; exposed so callers can inspect timing.
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
105
+
106
def get_image_size(aspect_ratio):
    """Map an aspect-ratio label to a (width, height) pixel pair.

    Unknown labels fall back to a 1024x1024 square canvas.
    """
    sizes = {
        "1:1": (1024, 1024),
        "16:9": (1152, 640),
        "9:16": (640, 1152),
        "4:3": (1024, 768),
        "3:4": (768, 1024),
        "3:2": (1024, 688),
        "2:3": (688, 1024),
        "4:1": (2560, 640),
        "3:1": (1920, 640),
        "2:1": (1280, 640),
    }
    return sizes.get(aspect_ratio, (1024, 1024))
130
+
131
def update_selection(evt: gr.SelectData, aspect_ratio):
    """Handle a click on the LoRA gallery: update prompt placeholder and selection.

    Returns updates for (prompt, selected_info markdown, selected_index state,
    aspect_ratio passthrough, generate button interactivity).
    """
    selected_lora = loras[evt.index]
    new_placeholder = f"Type a prompt for {selected_lora['title']}"
    lora_repo = selected_lora["repo"]
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"

    # NOTE(review): examples_list and image_url are built below but never
    # returned or otherwise used — looks like leftover/dead code intended to
    # feed the (hidden) Examples component; confirm before removing.
    examples_list = []
    try:
        model_card = ModelCard.load(lora_repo)
        widget_data = model_card.data.get("widget", [])
        if widget_data and len(widget_data) > 0:
            # Take at most the first four widget examples from the model card.
            for example in widget_data[:4]:
                if "output" in example and "url" in example["output"]:
                    image_url = f"https://huggingface.co/{lora_repo}/resolve/main/{example['output']['url']}"
                    prompt_text = example.get("text", "")
                    examples_list.append([prompt_text])
    except Exception as e:
        # Model card loading is best-effort; selection still succeeds.
        print(f"Could not load model card for {lora_repo}: {e}")

    return (
        gr.update(placeholder=new_placeholder),
        updated_text,
        evt.index,
        aspect_ratio,
        gr.update(interactive=True)
    )
157
+
158
def handle_speed_mode(speed_mode):
    """Translate the speed/quality radio choice into (status update, steps, cfg).

    Any value other than the three "light" presets selects normal quality.
    """
    presets = {
        "light 4": ("Light mode (4 steps) selected", 4, 1.0),
        "light 4 fp8": ("Light mode (4 steps fp8) selected", 4, 1.0),
        "light 8": ("Light mode (8 steps) selected", 8, 1.0),
    }
    message, steps, cfg = presets.get(
        speed_mode, ("Normal quality (45 steps) selected", 45, 3.5)
    )
    return gr.update(value=message), steps, cfg
168
+
169
@spaces.GPU(duration=70)
def generate_image(
    prompt_mash,
    steps,
    seed,
    cfg_scale,
    width,
    height,
    lora_scale,
    negative_prompt="",
    num_images=1,
):
    """Run the diffusion pipeline and return (image, seed) pairs.

    NOTE(review): lora_scale is accepted but never used here — adapter
    weights are configured by run_lora via pipe.set_adapters; confirm
    before dropping the parameter.
    """
    pipe.to("cuda")

    # One generator per image, seeds spaced 100 apart so every image in the
    # batch is individually reproducible.
    seeds = [seed + (i * 100) for i in range(num_images)]
    generators = [torch.Generator(device="cuda").manual_seed(s) for s in seeds]

    with calculateDuration("Generating image"):
        result = pipe(
            prompt=prompt_mash,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            true_cfg_scale=cfg_scale,
            width=width,
            height=height,
            num_images_per_prompt=num_images,
            generator=generators,
        )

    # Pair each generated image with the seed that produced it.
    images = [(img, s) for img, s in zip(result.images, seeds)]
    return images
200
+
201
@spaces.GPU(duration=70)
def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, quality_multiplier, quantity, progress=gr.Progress(track_tqdm=True)):
    """Main generation entry point wired to the Generate button.

    Composes the prompt with the selected LoRA's trigger word, loads the
    appropriate adapters (optionally stacking a Lightning speed-up LoRA),
    resolves the output size, and delegates to generate_image.

    Returns:
        tuple: (list of (image, seed-string) pairs for the gallery, seed used).

    Raises:
        gr.Error: when no LoRA has been selected.
    """
    if selected_index is None:
        raise gr.Error("You must select a LoRA before proceeding.")

    selected_lora = loras[selected_index]
    lora_path = selected_lora["repo"]
    trigger_word = selected_lora["trigger_word"]

    # Trigger word is prepended unless the config explicitly says otherwise.
    if trigger_word:
        if selected_lora.get("trigger_position", "prepend") == "prepend":
            prompt_mash = f"{trigger_word} {prompt}"
        else:
            prompt_mash = f"{prompt} {trigger_word}"
    else:
        prompt_mash = prompt

    with calculateDuration("Unloading existing LoRAs"):
        pipe.unload_lora_weights()

    # One Lightning weight file per accelerated mode; "normal" mode loads
    # only the style LoRA. (Previously three verbatim copies of this branch.)
    lightning_weights = {
        "light 4": LIGHTNING_LORA_WEIGHT,
        "light 4 fp8": LIGHTNING_FP8_4STEPS_LORA_WEIGHT,
        "light 8": LIGHTNING8_LORA_WEIGHT,
    }
    weight_name = selected_lora.get("weights", None)
    if speed_mode in lightning_weights:
        with calculateDuration("Loading Lightning LoRA and style LoRA"):
            pipe.load_lora_weights(
                LIGHTNING_LORA_REPO,
                weight_name=lightning_weights[speed_mode],
                adapter_name="lightning"
            )
            pipe.load_lora_weights(
                lora_path,
                weight_name=weight_name,
                low_cpu_mem_usage=True,
                adapter_name="style"
            )
            # Lightning stays at full strength; style LoRA uses the UI slider.
            pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
    else:
        with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
            pipe.load_lora_weights(
                lora_path,
                weight_name=weight_name,
                low_cpu_mem_usage=True,
                adapter_name="style"
            )
            pipe.set_adapters(["style"], adapter_weights=[lora_scale])

    with calculateDuration("Randomizing seed"):
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

    # Base resolution from the aspect ratio, scaled by the quality multiplier
    # ("0.5x" / "1x" / "1.5x" radio labels).
    width, height = get_image_size(aspect_ratio)
    multiplier = float(quality_multiplier.replace('x', ''))
    width = int(width * multiplier)
    height = int(height * multiplier)
    # `quantity` arrives as a 0-based Radio index (type="index"), hence +1.
    num_images = int(quantity) + 1

    pairs = generate_image(
        prompt_mash,
        steps,
        seed,
        cfg_scale,
        width,
        height,
        lora_scale,
        negative_prompt="",
        num_images=num_images,
    )

    # Gallery items are (image, caption) tuples; caption is the seed string.
    images_for_gallery = [(img, str(s)) for (img, s) in pairs]

    return images_for_gallery, seed
308
+
309
# ... (The remaining functions such as get_huggingface_safetensors, check_custom_model, etc. are unchanged) ...
def get_huggingface_safetensors(link):
    """Validate a "user/repo" Hugging Face id and extract LoRA metadata.

    Returns:
        tuple: (title, repo_id, safetensors_filename, trigger_word, image_url).

    Raises:
        Exception: when the link is malformed, the repo is not a Qwen-Image
        LoRA, or the repo contains no *.safetensors file.
    """
    split_link = link.split("/")
    if len(split_link) != 2:
        raise Exception("Invalid Hugging Face repository link format.")
    print(f"Repository attempted: {split_link}")
    model_card = ModelCard.load(link)
    base_model = model_card.data.get("base_model")
    print(f"Base model: {base_model}")
    # The card's base_model field may be a single id or a list of ids.
    acceptable_models = {"Qwen/Qwen-Image"}
    models_to_check = base_model if isinstance(base_model, list) else [base_model]
    if not any(model in acceptable_models for model in models_to_check):
        raise Exception("Not a Qwen-Image LoRA!")
    # The first widget example (if any) supplies the preview image.
    image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
    trigger_word = model_card.data.get("instance_prompt", "")
    image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
    fs = HfFileSystem()
    try:
        # Pick the first *.safetensors file found at the repo root.
        list_of_files = fs.ls(link, detail=False)
        safetensors_name = None
        for file in list_of_files:
            filename = file.split("/")[-1]
            if filename.endswith(".safetensors"):
                safetensors_name = filename
                break
        if not safetensors_name:
            raise Exception("No valid *.safetensors file found in the repository.")
    except Exception as e:
        print(e)
        raise Exception("You didn't include a valid Hugging Face repository with a *.safetensors LoRA")
    return split_link[1], link, safetensors_name, trigger_word, image_url
340
+
341
def check_custom_model(link):
    """Resolve a user-provided LoRA reference into repo metadata.

    Accepts three input shapes:
      * a direct URL to a .safetensors file on huggingface.co,
      * a full https:// repository URL,
      * a plain "username/repo" identifier.

    Returns:
        tuple: (title, repo, safetensors_name, trigger_word, image_url).

    Raises:
        Exception: for malformed .safetensors URLs, or propagated from
        get_huggingface_safetensors for repo-style links.
    """
    print(f"Checking a custom model on: {link}")
    if link.endswith('.safetensors'):
        if 'huggingface.co' in link:
            parts = link.split('/')
            try:
                hf_index = parts.index('huggingface.co')
                username = parts[hf_index + 1]
                repo_name = parts[hf_index + 2]
                repo = f"{username}/{repo_name}"
                safetensors_name = parts[-1]
                try:
                    # Card metadata is best-effort for direct file links.
                    model_card = ModelCard.load(repo)
                    trigger_word = model_card.data.get("instance_prompt", "")
                    image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
                    image_url = f"https://huggingface.co/{repo}/resolve/main/{image_path}" if image_path else None
                except Exception:  # was a bare except: — narrowed
                    trigger_word = ""
                    image_url = None
                return repo_name, repo, safetensors_name, trigger_word, image_url
            except Exception:  # was a bare except: — narrowed
                raise Exception("Invalid safetensors URL format")
    if link.startswith("https://"):
        if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
            link_split = link.split("huggingface.co/")
            return get_huggingface_safetensors(link_split[1])
        else:
            return get_huggingface_safetensors(link)
    # Fix: plain "username/repo" identifiers (the UI placeholder's format)
    # previously fell through every branch and implicitly returned None,
    # crashing the caller during tuple unpacking.
    return get_huggingface_safetensors(link)
369
+
370
def add_custom_lora(custom_lora):
    """Resolve a user-entered LoRA link/id, register it, and build its UI card.

    Appends the LoRA to the module-level `loras` list when it is not already
    present, then returns Gradio updates for: (custom_lora_info HTML,
    remove button visibility, gallery, selected_info text, selected_index,
    prompt/trigger word, generate button interactivity). An empty input
    hides the custom-LoRA UI.
    """
    global loras
    if custom_lora:
        try:
            title, repo, path, trigger_word, image = check_custom_model(custom_lora)
            print(f"Loaded custom LoRA: {repo}")
            model_card_examples = ""
            try:
                # Best-effort: render up to four sample images from the card's
                # widget section into an HTML grid.
                model_card = ModelCard.load(repo)
                widget_data = model_card.data.get("widget", [])
                if widget_data and len(widget_data) > 0:
                    examples_html = '<div style="margin-top: 10px;">'
                    examples_html += '<h4 style="margin-bottom: 8px; font-size: 0.9em;">Sample Images:</h4>'
                    examples_html += '<div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 8px;">'
                    for i, example in enumerate(widget_data[:4]):
                        if "output" in example and "url" in example["output"]:
                            image_url = f"https://huggingface.co/{repo}/resolve/main/{example['output']['url']}"
                            caption = example.get("text", f"Example {i+1}")
                            examples_html += f'''
                            <div style="text-align: center;">
                                <img src="{image_url}" style="width: 100%; height: auto; border-radius: 4px;" />
                                <p style="font-size: 0.7em; margin: 2px 0;">{caption[:30]}{'...' if len(caption) > 30 else ''}</p>
                            </div>
                            '''
                    examples_html += '</div></div>'
                    model_card_examples = examples_html
            except Exception as e:
                # Card examples are optional; failure only drops the previews.
                print(f"Could not load model card examples for custom LoRA: {e}")
            card = f'''
            <div class="custom_lora_card">
              <span>Loaded custom LoRA:</span>
              <div class="card_internal">
                <img src="{image}" />
                <div>
                  <h3>{title}</h3>
                  <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
                </div>
              </div>
              {model_card_examples}
            </div>
            '''
            # Deduplicate by repo: reuse the existing entry's index if present.
            existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
            if existing_item_index is None:
                new_item = {"image": image, "title": title, "repo": repo, "weights": path, "trigger_word": trigger_word}
                print(new_item)
                loras.append(new_item)
                existing_item_index = len(loras) - 1
            return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word, gr.update(interactive=True)
        except Exception as e:
            full_traceback = traceback.format_exc()
            print(f"Full traceback:\n{full_traceback}")
            gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-Qwen-Image LoRA, this was the issue: {e}")
            return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-Qwen-Image LoRA"), gr.update(visible=True), gr.update(), "", None, "", gr.update(interactive=False)
    else:
        # Empty textbox: hide the custom-LoRA card and reset the selection.
        return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, "", gr.update(interactive=False)
425
+
426
def remove_custom_lora():
    """Reset the custom-LoRA UI: hide card and button, clear selection/prompt."""
    return (
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(),
        "",
        None,
        "",
        gr.update(interactive=False),
    )
428
+
429
+
430
# Mark the entry point for ZeroGPU scheduling.
run_lora.zerogpu = True

css = '''
#gen_btn{height: 100%}
#gen_column{align-self: stretch}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
#lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
.card_internal{display: flex;height: 100px;margin-top: .5em}
.card_internal img{margin-right: 1em}
.styler{--form-gap-width: 0px !important}
#speed_status{padding: .5em; border-radius: 5px; margin: 1em 0}
'''

with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 60)) as app:
    title = gr.HTML(
        """<img src=\"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/qwen_image_logo.png\" alt=\"Qwen-Image\" style=\"width: 280px; margin: 0 auto\">
        <h3 style=\"margin-top: -10px\">LoRA🦜 ChoquinLabs Explorer</h3>""",
        elem_id="title",
    )

    # Index into `loras` of the currently selected LoRA (None = nothing picked).
    selected_index = gr.State(None)

    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
        with gr.Column(scale=1, elem_id="gen_column"):
            # Disabled until a LoRA is selected (see update_selection).
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn", interactive=False)

    with gr.Row():
        with gr.Column():
            selected_info = gr.Markdown("")
            examples_component = gr.Examples(examples=[], inputs=[prompt], label="Sample Prompts", visible=False)
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="LoRA Gallery",
                allow_preview=False,
                columns=3,
                elem_id="gallery",
                show_share_button=False
            )
            with gr.Group():
                custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="username/qwen-image-custom-lora")
                gr.Markdown("[Check Qwen-Image LoRAs](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image)", elem_id="lora_list")
            custom_lora_info = gr.HTML(visible=False)
            custom_lora_button = gr.Button("Remove custom LoRA", visible=False)

        with gr.Column():
            result = gr.Gallery(label="Generated Images", show_label=True, elem_id="result_gallery")

            ### MODIFICATION 2: HISTORY UI COMPONENTS ###
            with gr.Group():
                with gr.Row():
                    gr.Markdown("### 📜 History")
                    clear_history_button = gr.Button("🗑️ Clear History", size="sm")

                history_gallery = gr.Gallery(
                    label="Generation History",
                    show_label=False,
                    columns=4,
                    object_fit="contain",
                    height="auto",
                    interactive=False
                )
            ### END OF MODIFICATION 2 ###

    with gr.Row():
        with gr.Column():
            speed_mode = gr.Radio(
                label="Generation Mode",
                choices=["light 4", "light 4 fp8", "light 8", "normal"],
                value="light 4",
                info="'light' modes use Lightning LoRA for faster generation"
            )
        with gr.Column():
            # type="index" means run_lora receives a 0-based index, not the label.
            quantity = gr.Radio(
                label="Quantity",
                choices=["1", "2", "3", "4"],
                value="1",
                type="index"
            )

    speed_status = gr.Markdown("Quality mode active", elem_id="speed_status")

    with gr.Row():
        aspect_ratio = gr.Radio(
            label="Aspect Ratio",
            choices=["1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3", "4:1", "3:1", "2:1"],
            value="16:9"
        )

    with gr.Row():
        quality_multiplier = gr.Radio(
            label="Quality (Size Multiplier)",
            choices=["0.5x", "1x", "1.5x"],
            value="1x"
        )

    with gr.Row():
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Column():
                with gr.Row():
                    cfg_scale = gr.Slider(
                        label="Guidance Scale (True CFG)",
                        minimum=1.0,
                        maximum=5.0,
                        step=0.1,
                        value=3.5,
                        info="Lower for speed mode, higher for quality"
                    )
                    steps = gr.Slider(
                        label="Steps",
                        minimum=4,
                        maximum=50,
                        step=1,
                        value=45,
                        info="Automatically set by speed mode"
                    )

                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Randomize seed")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=1.0)

    # Event handlers
    gallery.select(
        update_selection,
        inputs=[aspect_ratio],
        outputs=[prompt, selected_info, selected_index, aspect_ratio, generate_button]
    )

    speed_mode.change(
        handle_speed_mode,
        inputs=[speed_mode],
        outputs=[speed_status, steps, cfg_scale]
    )

    custom_lora.input(
        add_custom_lora,
        inputs=[custom_lora],
        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt, generate_button]
    )

    custom_lora_button.click(
        remove_custom_lora,
        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora, generate_button]
    )

    ### MODIFICATION 3: WIRE UP THE HISTORY EVENTS ###
    # Main generation event (button click or Enter in the prompt box).
    generate_event = gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, quality_multiplier, quantity],
        outputs=[result, seed]
    )

    # Chain the history update so it runs AFTER generation completes.
    generate_event.then(
        fn=update_history,
        inputs=[result, history_gallery],
        outputs=history_gallery,
        show_api=False  # No need to expose this in the API
    )

    # Clear-history button.
    clear_history_button.click(
        fn=clear_history,
        inputs=None,
        outputs=history_gallery,
        show_api=False
    )
    ### END OF MODIFICATION 3 ###

    # Initialize the status/steps/cfg widgets to match the default speed mode.
    app.load(
        fn=handle_speed_mode,
        inputs=[gr.State("light 4")],
        outputs=[speed_status, steps, cfg_scale]
    )

app.queue()
app.launch()
app.py CHANGED
@@ -14,29 +14,31 @@ import re
14
  import math
15
  import numpy as np
16
  import traceback
 
17
 
18
- # Load LoRAs from JSON file
 
 
19
  def load_loras_from_file():
20
- """Load LoRA configurations from external JSON file."""
21
  try:
22
  with open('loras.json', 'r', encoding='utf-8') as f:
23
  return json.load(f)
24
  except FileNotFoundError:
25
- print("Warning: loras.json file not found. Using empty list.")
26
  return []
27
  except json.JSONDecodeError as e:
28
- print(f"Error parsing loras.json: {e}")
29
  return []
30
 
31
- # Load the LoRAs
32
  loras = load_loras_from_file()
33
 
34
- # Initialize the base model
 
 
35
  dtype = torch.bfloat16
36
  device = "cuda" if torch.cuda.is_available() else "cpu"
37
  base_model = "Qwen/Qwen-Image"
38
 
39
- # Scheduler configuration from the Qwen-Image-Lightning repository
40
  scheduler_config = {
41
  "base_image_seq_len": 256,
42
  "base_shift": math.log(3),
@@ -55,559 +57,185 @@ scheduler_config = {
55
  }
56
 
57
  scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
58
- pipe = DiffusionPipeline.from_pretrained(
59
- base_model, scheduler=scheduler, torch_dtype=dtype
60
- ).to(device)
61
 
62
- # Lightning LoRA info (no global state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
64
  LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors"
65
  LIGHTNING8_LORA_WEIGHT = "Qwen-Image-Lightning-8steps-V2.0-bf16.safetensors"
66
  LIGHTNING_FP8_4STEPS_LORA_WEIGHT = "Qwen-Image-fp8-e4m3fn-Lightning-4steps-V1.0-bf16.safetensors"
67
-
68
-
69
  MAX_SEED = np.iinfo(np.int32).max
70
 
71
- ### MODIFICACIÓN 1: AÑADIR FUNCIONES PARA GESTIONAR EL HISTORIAL ###
 
 
72
  def update_history(new_images, history):
73
- """Añade las nuevas imágenes generadas al principio de la lista del historial."""
74
- # Gradio pasa el valor actual de la galería de historial como una lista
75
  if history is None:
76
  history = []
77
- if new_images is not None and len(new_images) > 0:
78
- # Añade las nuevas imágenes al principio de la lista existente
79
- updated_history = new_images + history
80
- # Limita el historial a un tamaño razonable (ej. 24 imágenes) para no usar demasiada memoria
81
- return updated_history[:24]
82
  return history
83
 
84
  def clear_history():
85
- """Devuelve una lista vacía para limpiar la galería de historial."""
86
  return []
87
- ### FIN DE LA MODIFICACIÓN 1 ###
88
-
89
 
90
  class calculateDuration:
91
- def __init__(self, activity_name=""):
92
- self.activity_name = activity_name
93
-
94
  def __enter__(self):
95
- self.start_time = time.time()
96
  return self
97
-
98
- def __exit__(self, exc_type, exc_value, traceback):
99
- self.end_time = time.time()
100
- self.elapsed_time = self.end_time - self.start_time
101
- if self.activity_name:
102
- print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
103
- else:
104
- print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
105
-
106
- def get_image_size(aspect_ratio):
107
- """Converts aspect ratio string to width, height tuple."""
108
- if aspect_ratio == "1:1":
109
- return 1024, 1024
110
- elif aspect_ratio == "16:9":
111
- return 1152, 640
112
- elif aspect_ratio == "9:16":
113
- return 640, 1152
114
- elif aspect_ratio == "4:3":
115
- return 1024, 768
116
- elif aspect_ratio == "3:4":
117
- return 768, 1024
118
- elif aspect_ratio == "3:2":
119
- return 1024, 688
120
- elif aspect_ratio == "2:3":
121
- return 688, 1024
122
- elif aspect_ratio == "4:1":
123
- return 2560, 640
124
- elif aspect_ratio == "3:1":
125
- return 1920, 640
126
- elif aspect_ratio == "2:1":
127
- return 1280, 640
128
- else:
129
- return 1024, 1024
130
-
131
- def update_selection(evt: gr.SelectData, aspect_ratio):
132
- selected_lora = loras[evt.index]
133
- new_placeholder = f"Type a prompt for {selected_lora['title']}"
134
- lora_repo = selected_lora["repo"]
135
- updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
136
-
137
- examples_list = []
138
- try:
139
- model_card = ModelCard.load(lora_repo)
140
- widget_data = model_card.data.get("widget", [])
141
- if widget_data and len(widget_data) > 0:
142
- for example in widget_data[:4]:
143
- if "output" in example and "url" in example["output"]:
144
- image_url = f"https://huggingface.co/{lora_repo}/resolve/main/{example['output']['url']}"
145
- prompt_text = example.get("text", "")
146
- examples_list.append([prompt_text])
147
- except Exception as e:
148
- print(f"Could not load model card for {lora_repo}: {e}")
149
-
150
- return (
151
- gr.update(placeholder=new_placeholder),
152
- updated_text,
153
- evt.index,
154
- aspect_ratio,
155
- gr.update(interactive=True)
156
- )
157
-
158
- def handle_speed_mode(speed_mode):
159
- """Update UI based on speed/quality toggle."""
160
- if speed_mode == "light 4":
161
  return gr.update(value="Light mode (4 steps) selected"), 4, 1.0
162
- elif speed_mode == "light 4 fp8":
163
  return gr.update(value="Light mode (4 steps fp8) selected"), 4, 1.0
164
- elif speed_mode == "light 8":
165
  return gr.update(value="Light mode (8 steps) selected"), 8, 1.0
166
- else:
167
  return gr.update(value="Normal quality (45 steps) selected"), 45, 3.5
168
 
 
 
 
169
  @spaces.GPU(duration=70)
170
- def generate_image(
171
- prompt_mash,
172
- steps,
173
- seed,
174
- cfg_scale,
175
- width,
176
- height,
177
- lora_scale,
178
- negative_prompt="",
179
- num_images=1,
180
- ):
181
  pipe.to("cuda")
182
-
183
  seeds = [seed + (i * 100) for i in range(num_images)]
184
  generators = [torch.Generator(device="cuda").manual_seed(s) for s in seeds]
185
-
186
- with calculateDuration("Generating image"):
187
- result = pipe(
188
- prompt=prompt_mash,
189
- negative_prompt=negative_prompt,
190
- num_inference_steps=steps,
191
- true_cfg_scale=cfg_scale,
192
- width=width,
193
- height=height,
194
- num_images_per_prompt=num_images,
195
- generator=generators,
196
- )
197
-
198
- images = [(img, s) for img, s in zip(result.images, seeds)]
199
- return images
200
 
201
  @spaces.GPU(duration=70)
202
- def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, quality_multiplier, quantity, progress=gr.Progress(track_tqdm=True)):
 
203
  if selected_index is None:
204
- raise gr.Error("You must select a LoRA before proceeding.")
205
-
206
  selected_lora = loras[selected_index]
207
  lora_path = selected_lora["repo"]
208
  trigger_word = selected_lora["trigger_word"]
209
-
210
- if trigger_word:
211
- if "trigger_position" in selected_lora:
212
- if selected_lora["trigger_position"] == "prepend":
213
- prompt_mash = f"{trigger_word} {prompt}"
214
- else:
215
- prompt_mash = f"{prompt} {trigger_word}"
216
- else:
217
- prompt_mash = f"{trigger_word} {prompt}"
218
- else:
219
- prompt_mash = prompt
220
 
221
  with calculateDuration("Unloading existing LoRAs"):
222
  pipe.unload_lora_weights()
223
 
 
224
  if speed_mode == "light 4":
225
- with calculateDuration("Loading Lightning LoRA and style LoRA"):
226
- pipe.load_lora_weights(
227
- LIGHTNING_LORA_REPO,
228
- weight_name=LIGHTNING_LORA_WEIGHT,
229
- adapter_name="lightning"
230
- )
231
- weight_name = selected_lora.get("weights", None)
232
- pipe.load_lora_weights(
233
- lora_path,
234
- weight_name=weight_name,
235
- low_cpu_mem_usage=True,
236
- adapter_name="style"
237
- )
238
- pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
239
- elif speed_mode == "light 4 fp8":
240
- with calculateDuration("Loading Lightning LoRA and style LoRA"):
241
- pipe.load_lora_weights(
242
- LIGHTNING_LORA_REPO,
243
- weight_name=LIGHTNING_FP8_4STEPS_LORA_WEIGHT,
244
- adapter_name="lightning"
245
- )
246
- weight_name = selected_lora.get("weights", None)
247
- pipe.load_lora_weights(
248
- lora_path,
249
- weight_name=weight_name,
250
- low_cpu_mem_usage=True,
251
- adapter_name="style"
252
- )
253
- pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
254
  elif speed_mode == "light 8":
255
- with calculateDuration("Loading Lightning LoRA and style LoRA"):
256
- pipe.load_lora_weights(
257
- LIGHTNING_LORA_REPO,
258
- weight_name=LIGHTNING8_LORA_WEIGHT,
259
- adapter_name="lightning"
260
- )
261
- weight_name = selected_lora.get("weights", None)
262
- pipe.load_lora_weights(
263
- lora_path,
264
- weight_name=weight_name,
265
- low_cpu_mem_usage=True,
266
- adapter_name="style"
267
- )
268
- pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
269
  else:
270
- with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
271
- weight_name = selected_lora.get("weights", None)
272
- pipe.load_lora_weights(
273
- lora_path,
274
- weight_name=weight_name,
275
- low_cpu_mem_usage=True,
276
- adapter_name="style"
277
- )
 
278
  pipe.set_adapters(["style"], adapter_weights=[lora_scale])
279
-
280
- with calculateDuration("Randomizing seed"):
281
- if randomize_seed:
282
- seed = random.randint(0, MAX_SEED)
283
-
284
  width, height = get_image_size(aspect_ratio)
285
  multiplier = float(quality_multiplier.replace('x', ''))
286
- width = int(width * multiplier)
287
- height = int(height * multiplier)
288
  num_images = int(quantity) + 1
289
 
290
- pairs = generate_image(
291
- prompt_mash,
292
- steps,
293
- seed,
294
- cfg_scale,
295
- width,
296
- height,
297
- lora_scale,
298
- negative_prompt="",
299
- num_images=num_images,
300
- )
301
-
302
- images_for_gallery = [
303
- (img, str(s))
304
- for (img, s) in pairs
305
- ]
306
-
307
- return images_for_gallery, seed
308
-
309
- # ... (El resto de las funciones como get_huggingface_safetensors, check_custom_model, etc., permanecen sin cambios) ...
310
- def get_huggingface_safetensors(link):
311
- split_link = link.split("/")
312
- if len(split_link) != 2:
313
- raise Exception("Invalid Hugging Face repository link format.")
314
- print(f"Repository attempted: {split_link}")
315
- model_card = ModelCard.load(link)
316
- base_model = model_card.data.get("base_model")
317
- print(f"Base model: {base_model}")
318
- acceptable_models = {"Qwen/Qwen-Image"}
319
- models_to_check = base_model if isinstance(base_model, list) else [base_model]
320
- if not any(model in acceptable_models for model in models_to_check):
321
- raise Exception("Not a Qwen-Image LoRA!")
322
- image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
323
- trigger_word = model_card.data.get("instance_prompt", "")
324
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
325
- fs = HfFileSystem()
326
- try:
327
- list_of_files = fs.ls(link, detail=False)
328
- safetensors_name = None
329
- for file in list_of_files:
330
- filename = file.split("/")[-1]
331
- if filename.endswith(".safetensors"):
332
- safetensors_name = filename
333
- break
334
- if not safetensors_name:
335
- raise Exception("No valid *.safetensors file found in the repository.")
336
- except Exception as e:
337
- print(e)
338
- raise Exception("You didn't include a valid Hugging Face repository with a *.safetensors LoRA")
339
- return split_link[1], link, safetensors_name, trigger_word, image_url
340
-
341
- def check_custom_model(link):
342
- print(f"Checking a custom model on: {link}")
343
- if link.endswith('.safetensors'):
344
- if 'huggingface.co' in link:
345
- parts = link.split('/')
346
- try:
347
- hf_index = parts.index('huggingface.co')
348
- username = parts[hf_index + 1]
349
- repo_name = parts[hf_index + 2]
350
- repo = f"{username}/{repo_name}"
351
- safetensors_name = parts[-1]
352
- try:
353
- model_card = ModelCard.load(repo)
354
- trigger_word = model_card.data.get("instance_prompt", "")
355
- image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
356
- image_url = f"https://huggingface.co/{repo}/resolve/main/{image_path}" if image_path else None
357
- except:
358
- trigger_word = ""
359
- image_url = None
360
- return repo_name, repo, safetensors_name, trigger_word, image_url
361
- except:
362
- raise Exception("Invalid safetensors URL format")
363
- if link.startswith("https://"):
364
- if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
365
- link_split = link.split("huggingface.co/")
366
- return get_huggingface_safetensors(link_split[1])
367
- else:
368
- return get_huggingface_safetensors(link)
369
-
370
- def add_custom_lora(custom_lora):
371
- global loras
372
- if custom_lora:
373
- try:
374
- title, repo, path, trigger_word, image = check_custom_model(custom_lora)
375
- print(f"Loaded custom LoRA: {repo}")
376
- model_card_examples = ""
377
- try:
378
- model_card = ModelCard.load(repo)
379
- widget_data = model_card.data.get("widget", [])
380
- if widget_data and len(widget_data) > 0:
381
- examples_html = '<div style="margin-top: 10px;">'
382
- examples_html += '<h4 style="margin-bottom: 8px; font-size: 0.9em;">Sample Images:</h4>'
383
- examples_html += '<div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 8px;">'
384
- for i, example in enumerate(widget_data[:4]):
385
- if "output" in example and "url" in example["output"]:
386
- image_url = f"https://huggingface.co/{repo}/resolve/main/{example['output']['url']}"
387
- caption = example.get("text", f"Example {i+1}")
388
- examples_html += f'''
389
- <div style="text-align: center;">
390
- <img src="{image_url}" style="width: 100%; height: auto; border-radius: 4px;" />
391
- <p style="font-size: 0.7em; margin: 2px 0;">{caption[:30]}{'...' if len(caption) > 30 else ''}</p>
392
- </div>
393
- '''
394
- examples_html += '</div></div>'
395
- model_card_examples = examples_html
396
- except Exception as e:
397
- print(f"Could not load model card examples for custom LoRA: {e}")
398
- card = f'''
399
- <div class="custom_lora_card">
400
- <span>Loaded custom LoRA:</span>
401
- <div class="card_internal">
402
- <img src="{image}" />
403
- <div>
404
- <h3>{title}</h3>
405
- <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
406
- </div>
407
- </div>
408
- {model_card_examples}
409
- </div>
410
- '''
411
- existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
412
- if existing_item_index is None:
413
- new_item = {"image": image, "title": title, "repo": repo, "weights": path, "trigger_word": trigger_word}
414
- print(new_item)
415
- loras.append(new_item)
416
- existing_item_index = len(loras) - 1
417
- return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word, gr.update(interactive=True)
418
- except Exception as e:
419
- full_traceback = traceback.format_exc()
420
- print(f"Full traceback:\n{full_traceback}")
421
- gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-Qwen-Image LoRA, this was the issue: {e}")
422
- return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-Qwen-Image LoRA"), gr.update(visible=True), gr.update(), "", None, "", gr.update(interactive=False)
423
- else:
424
- return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, "", gr.update(interactive=False)
425
-
426
- def remove_custom_lora():
427
- return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, "", gr.update(interactive=False)
428
-
429
 
430
  run_lora.zerogpu = True
431
 
 
 
 
432
  css = '''
433
- #gen_btn{height: 100%}
434
- #gen_column{align-self: stretch}
435
- #title{text-align: center}
436
- #title h1{font-size: 3em; display:inline-flex; align-items:center}
437
- #title img{width: 100px; margin-right: 0.5em}
438
- #gallery .grid-wrap{height: 10vh}
439
- #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
440
- .card_internal{display: flex;height: 100px;margin-top: .5em}
441
- .card_internal img{margin-right: 1em}
442
- .styler{--form-gap-width: 0px !important}
443
- #speed_status{padding: .5em; border-radius: 5px; margin: 1em 0}
444
  '''
445
 
446
- with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 60)) as app:
447
- title = gr.HTML(
448
- """<img src=\"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/qwen_image_logo.png\" alt=\"Qwen-Image\" style=\"width: 280px; margin: 0 auto\">
449
- <h3 style=\"margin-top: -10px\">LoRA🦜 ChoquinLabs Explorer</h3>""",
450
- elem_id="title",
451
- )
452
-
453
  selected_index = gr.State(None)
454
-
455
- with gr.Row():
456
- with gr.Column(scale=3):
457
- prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
458
- with gr.Column(scale=1, elem_id="gen_column"):
459
- generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn", interactive=False)
460
-
461
- with gr.Row():
462
- with gr.Column():
463
- selected_info = gr.Markdown("")
464
- examples_component = gr.Examples(examples=[], inputs=[prompt], label="Sample Prompts", visible=False)
465
- gallery = gr.Gallery(
466
- [(item["image"], item["title"]) for item in loras],
467
- label="LoRA Gallery",
468
- allow_preview=False,
469
- columns=3,
470
- elem_id="gallery",
471
- show_share_button=False
472
- )
473
- with gr.Group():
474
- custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="username/qwen-image-custom-lora")
475
- gr.Markdown("[Check Qwen-Image LoRAs](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image)", elem_id="lora_list")
476
- custom_lora_info = gr.HTML(visible=False)
477
- custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
478
-
479
- with gr.Column():
480
- result = gr.Gallery(label="Generated Images", show_label=True, elem_id="result_gallery")
481
-
482
- ### MODIFICACIÓN 2: AÑADIR LOS COMPONENTES DE LA UI DEL HISTORIAL ###
483
- with gr.Group():
484
- with gr.Row():
485
- gr.Markdown("### 📜 History")
486
- clear_history_button = gr.Button("🗑️ Clear History", size="sm")
487
-
488
- history_gallery = gr.Gallery(
489
- label="Generation History",
490
- show_label=False,
491
- columns=4,
492
- object_fit="contain",
493
- height="auto",
494
- interactive=False
495
- )
496
- ### FIN DE LA MODIFICACIÓN 2 ###
497
-
498
- with gr.Row():
499
- with gr.Column():
500
- speed_mode = gr.Radio(
501
- label="Generation Mode",
502
- choices=["light 4", "light 4 fp8", "light 8", "normal"],
503
- value="light 4",
504
- info="'light' modes use Lightning LoRA for faster generation"
505
- )
506
- with gr.Column():
507
- quantity = gr.Radio(
508
- label="Quantity",
509
- choices=["1", "2", "3", "4"],
510
- value="1",
511
- type="index"
512
- )
513
-
514
- speed_status = gr.Markdown("Quality mode active", elem_id="speed_status")
515
-
516
- with gr.Row():
517
- aspect_ratio = gr.Radio(
518
- label="Aspect Ratio",
519
- choices=["1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3", "4:1", "3:1", "2:1"],
520
- value="16:9"
521
- )
522
-
523
- with gr.Row():
524
- quality_multiplier = gr.Radio(
525
- label="Quality (Size Multiplier)",
526
- choices=["0.5x", "1x", "1.5x"],
527
- value="1x"
528
- )
529
-
530
- with gr.Row():
531
- with gr.Accordion("Advanced Settings", open=False):
532
- with gr.Column():
533
- with gr.Row():
534
- cfg_scale = gr.Slider(
535
- label="Guidance Scale (True CFG)",
536
- minimum=1.0,
537
- maximum=5.0,
538
- step=0.1,
539
- value=3.5,
540
- info="Lower for speed mode, higher for quality"
541
- )
542
- steps = gr.Slider(
543
- label="Steps",
544
- minimum=4,
545
- maximum=50,
546
- step=1,
547
- value=45,
548
- info="Automatically set by speed mode"
549
- )
550
-
551
- with gr.Row():
552
- randomize_seed = gr.Checkbox(True, label="Randomize seed")
553
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
554
- lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=1.0)
555
 
556
- # Event handlers
557
- gallery.select(
558
- update_selection,
559
- inputs=[aspect_ratio],
560
- outputs=[prompt, selected_info, selected_index, aspect_ratio, generate_button]
561
- )
562
-
563
- speed_mode.change(
564
- handle_speed_mode,
565
- inputs=[speed_mode],
566
- outputs=[speed_status, steps, cfg_scale]
567
- )
568
-
569
- custom_lora.input(
570
- add_custom_lora,
571
- inputs=[custom_lora],
572
- outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt, generate_button]
573
- )
574
-
575
- custom_lora_button.click(
576
- remove_custom_lora,
577
- outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora, generate_button]
578
- )
579
-
580
- ### MODIFICACIÓN 3: CONECTAR LOS EVENTOS DEL HISTORIAL ###
581
- # Evento principal de generación
582
  generate_event = gr.on(
583
  triggers=[generate_button.click, prompt.submit],
584
  fn=run_lora,
585
- inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, quality_multiplier, quantity],
586
- outputs=[result, seed]
587
- )
588
-
589
- # Encadenar la actualización del historial para que se ejecute DESPUÉS de la generación
590
- generate_event.then(
591
- fn=update_history,
592
- inputs=[result, history_gallery],
593
- outputs=history_gallery,
594
- show_api=False # No es necesario mostrar esto en la API
595
- )
596
-
597
- # Evento para el botón de limpiar historial
598
- clear_history_button.click(
599
- fn=clear_history,
600
- inputs=None,
601
- outputs=history_gallery,
602
- show_api=False
603
- )
604
- ### FIN DE LA MODIFICACIÓN 3 ###
605
-
606
- app.load(
607
- fn=handle_speed_mode,
608
- inputs=[gr.State("light 4")],
609
- outputs=[speed_status, steps, cfg_scale]
610
  )
 
 
611
 
612
  app.queue()
613
- app.launch()
 
14
  import math
15
  import numpy as np
16
  import traceback
17
+ from spaces import aoti_capture, aoti_compile, aoti_apply # ✅ ZeroGPU AOT helpers
18
 
19
+ # =========================================================
20
+ # Load LoRAs
21
+ # =========================================================
22
  def load_loras_from_file():
 
23
  try:
24
  with open('loras.json', 'r', encoding='utf-8') as f:
25
  return json.load(f)
26
  except FileNotFoundError:
27
+ print("⚠️ Warning: loras.json not found. Using empty list.")
28
  return []
29
  except json.JSONDecodeError as e:
30
+ print(f"Error parsing loras.json: {e}")
31
  return []
32
 
 
33
  loras = load_loras_from_file()
34
 
35
+ # =========================================================
36
+ # Base model + scheduler
37
+ # =========================================================
38
  dtype = torch.bfloat16
39
  device = "cuda" if torch.cuda.is_available() else "cpu"
40
  base_model = "Qwen/Qwen-Image"
41
 
 
42
  scheduler_config = {
43
  "base_image_seq_len": 256,
44
  "base_shift": math.log(3),
 
57
  }
58
 
59
  scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
60
+ pipe = DiffusionPipeline.from_pretrained(base_model, scheduler=scheduler, torch_dtype=dtype).to(device)
 
 
61
 
62
+ # =========================================================
63
+ # ZeroGPU AOT Compilation Logic
64
+ # =========================================================
65
+ AOT_DIR = "./aoti_artifacts"
66
+ EXAMPLE_PROMPT = "a cute cat in a spacesuit"
67
+
68
+ if torch.cuda.is_available():
69
+ try:
70
+ if os.environ.get("SPACE_BUILD") == "1": # ✅ Solo durante build
71
+ print("🏗️ Space build detected. Compiling AOT artifacts...")
72
+ example_inputs = aoti_capture(
73
+ pipe,
74
+ inputs={
75
+ "prompt": EXAMPLE_PROMPT,
76
+ "num_inference_steps": 4,
77
+ "true_cfg_scale": 3.5,
78
+ "width": 1024,
79
+ "height": 1024,
80
+ "num_images_per_prompt": 1,
81
+ }
82
+ )
83
+ aoti_compile(pipe, example_inputs, output_dir=AOT_DIR, dynamic=False)
84
+ print("✅ AOT compilation completed successfully.")
85
+ else:
86
+ # En runtime normal: aplicar los artefactos ya compilados
87
+ if os.path.exists(AOT_DIR):
88
+ pipe = aoti_apply(pipe, AOT_DIR)
89
+ print("✅ Loaded precompiled AOT model.")
90
+ else:
91
+ print("⚠️ No AOT artifacts found. Running in JIT mode.")
92
+ except Exception as e:
93
+ print(f"⚠️ Skipping AOT setup: {e}")
94
+
95
+ # =========================================================
96
+ # Lightning LoRA configuration
97
+ # =========================================================
98
  LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
99
  LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors"
100
  LIGHTNING8_LORA_WEIGHT = "Qwen-Image-Lightning-8steps-V2.0-bf16.safetensors"
101
  LIGHTNING_FP8_4STEPS_LORA_WEIGHT = "Qwen-Image-fp8-e4m3fn-Lightning-4steps-V1.0-bf16.safetensors"
 
 
102
  MAX_SEED = np.iinfo(np.int32).max
103
 
104
+ # =========================================================
105
+ # Utility functions
106
+ # =========================================================
107
  def update_history(new_images, history):
 
 
108
  if history is None:
109
  history = []
110
+ if new_images:
111
+ return (new_images + history)[:24]
 
 
 
112
  return history
113
 
114
  def clear_history():
 
115
  return []
 
 
116
 
117
  class calculateDuration:
118
+ def __init__(self, name=""):
119
+ self.name = name
 
120
  def __enter__(self):
121
+ self.start = time.time()
122
  return self
123
+ def __exit__(self, *args):
124
+ elapsed = time.time() - self.start
125
+ print(f"⏱️ {self.name}: {elapsed:.3f}s")
126
+
127
+ def get_image_size(ratio):
128
+ sizes = {
129
+ "1:1": (1024, 1024),
130
+ "16:9": (1152, 640),
131
+ "9:16": (640, 1152),
132
+ "4:3": (1024, 768),
133
+ "3:4": (768, 1024),
134
+ "3:2": (1024, 688),
135
+ "2:3": (688, 1024),
136
+ "4:1": (2560, 640),
137
+ "3:1": (1920, 640),
138
+ "2:1": (1280, 640),
139
+ }
140
+ return sizes.get(ratio, (1024, 1024))
141
+
142
+ def handle_speed_mode(mode):
143
+ if mode == "light 4":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  return gr.update(value="Light mode (4 steps) selected"), 4, 1.0
145
+ elif mode == "light 4 fp8":
146
  return gr.update(value="Light mode (4 steps fp8) selected"), 4, 1.0
147
+ elif mode == "light 8":
148
  return gr.update(value="Light mode (8 steps) selected"), 8, 1.0
149
+ else:
150
  return gr.update(value="Normal quality (45 steps) selected"), 45, 3.5
151
 
152
+ # =========================================================
153
+ # Core generation functions
154
+ # =========================================================
155
  @spaces.GPU(duration=70)
156
+ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt="", num_images=1):
 
 
 
 
 
 
 
 
 
 
157
  pipe.to("cuda")
 
158
  seeds = [seed + (i * 100) for i in range(num_images)]
159
  generators = [torch.Generator(device="cuda").manual_seed(s) for s in seeds]
160
+ with calculateDuration("Image generation"):
161
+ result = pipe(prompt=prompt_mash, negative_prompt=negative_prompt,
162
+ num_inference_steps=steps, true_cfg_scale=cfg_scale,
163
+ width=width, height=height, num_images_per_prompt=num_images,
164
+ generator=generators)
165
+ return [(img, s) for img, s in zip(result.images, seeds)]
 
 
 
 
 
 
 
 
 
166
 
167
  @spaces.GPU(duration=70)
168
+ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed,
169
+ aspect_ratio, lora_scale, speed_mode, quality_multiplier, quantity, progress=gr.Progress(track_tqdm=True)):
170
  if selected_index is None:
171
+ raise gr.Error("Select a LoRA first.")
 
172
  selected_lora = loras[selected_index]
173
  lora_path = selected_lora["repo"]
174
  trigger_word = selected_lora["trigger_word"]
175
+ prompt_mash = f"{trigger_word} {prompt}" if trigger_word else prompt
 
 
 
 
 
 
 
 
 
 
176
 
177
  with calculateDuration("Unloading existing LoRAs"):
178
  pipe.unload_lora_weights()
179
 
180
+ # Load LoRAs
181
  if speed_mode == "light 4":
182
+ weights = LIGHTNING_LORA_WEIGHT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  elif speed_mode == "light 8":
184
+ weights = LIGHTNING8_LORA_WEIGHT
185
+ elif speed_mode == "light 4 fp8":
186
+ weights = LIGHTNING_FP8_4STEPS_LORA_WEIGHT
 
 
 
 
 
 
 
 
 
 
 
187
  else:
188
+ weights = None
189
+
190
+ with calculateDuration("Loading LoRA weights"):
191
+ if weights:
192
+ pipe.load_lora_weights(LIGHTNING_LORA_REPO, weight_name=weights, adapter_name="lightning")
193
+ pipe.load_lora_weights(lora_path, adapter_name="style")
194
+ pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
195
+ else:
196
+ pipe.load_lora_weights(lora_path, adapter_name="style")
197
  pipe.set_adapters(["style"], adapter_weights=[lora_scale])
198
+
199
+ if randomize_seed:
200
+ seed = random.randint(0, MAX_SEED)
201
+
 
202
  width, height = get_image_size(aspect_ratio)
203
  multiplier = float(quality_multiplier.replace('x', ''))
204
+ width, height = int(width * multiplier), int(height * multiplier)
 
205
  num_images = int(quantity) + 1
206
 
207
+ pairs = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, "", num_images)
208
+ return [(img, str(s)) for img, s in pairs], seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  run_lora.zerogpu = True
211
 
212
+ # =========================================================
213
+ # Gradio UI
214
+ # =========================================================
215
  css = '''
216
+ #gen_btn{height:100%}
217
+ #gen_column{align-self:stretch}
 
 
 
 
 
 
 
 
 
218
  '''
219
 
220
+ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60,60)) as app:
221
+ title = gr.HTML("<h3>LoRA🦜 ChoquinLabs Explorer</h3>")
 
 
 
 
 
222
  selected_index = gr.State(None)
223
+ prompt = gr.Textbox(label="Prompt", placeholder="Type a prompt after selecting a LoRA")
224
+ generate_button = gr.Button("Generate", variant="primary", interactive=False)
225
+ result = gr.Gallery(label="Generated Images")
226
+ history_gallery = gr.Gallery(label="History", columns=4)
227
+ clear_history_button = gr.Button("🗑️ Clear History")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  generate_event = gr.on(
230
  triggers=[generate_button.click, prompt.submit],
231
  fn=run_lora,
232
+ inputs=[prompt, gr.State(3.5), gr.State(45), selected_index,
233
+ gr.State(True), gr.State(0), gr.State("16:9"), gr.State(1.0),
234
+ gr.State("light 4"), gr.State("1x"), gr.State("1")],
235
+ outputs=[result, gr.State(0)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  )
237
+ generate_event.then(fn=update_history, inputs=[result, history_gallery], outputs=history_gallery)
238
+ clear_history_button.click(fn=clear_history, outputs=history_gallery)
239
 
240
  app.queue()
241
+ app.launch()
requirements.txt CHANGED
@@ -3,4 +3,10 @@ transformers
3
  accelerate
4
  safetensors
5
  peft
6
- sentencepiece
 
 
 
 
 
 
 
3
  accelerate
4
  safetensors
5
  peft
6
+ sentencepiece
7
+ torch>=2.3
8
+ gradio>=4.38
9
+ spaces
10
+ huggingface_hub
11
+ Pillow
12
+ numpy
runtime.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Python runtime version
2
+ python-3.11
3
+
4
+ # PyTorch version compatible con AOTInductor
5
+ torch-2.3.1+cu121