# optimization.py
"""Ahead-of-time (AOT) compilation helper for a Qwen-Image diffusion pipeline
running on a Hugging Face Space with ZeroGPU (``spaces``) support.

The single public entry point, :func:`optimize_pipeline_`, captures one real
call to the pipeline's transformer, exports it with ``torch.export`` using
static sequence-length dims, compiles it with AOTInductor, and patches the
compiled artifact back onto the pipeline in place.
"""

from typing import Any, Callable, ParamSpec

import spaces
import torch
from torch.utils._pytree import tree_map

P = ParamSpec("P")

# Fixed sequence lengths used to pin the export to static shapes.
# NOTE(review): these presumably match the prompt/image token counts produced
# by the surrounding Space's generation settings — confirm against the caller.
TEXT_SEQ_LENGTH = 12
IMAGE_SEQ_LENGTH = 4096

# AOTInductor / Inductor tuning knobs passed to ``spaces.aoti_compile``.
INDUCTOR_CONFIGS = {
    "conv_1x1_as_mm": True,
    "epilogue_fusion": False,
    "coordinate_descent_tuning": True,
    "coordinate_descent_check_all_directions": True,
    "max_autotune": True,
    "triton.cudagraphs": True,
}


def optimize_pipeline_(
    pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
) -> Callable[P, Any]:
    """Compile ``pipeline.transformer`` ahead of time and patch it in place.

    Runs one real pipeline call under ``spaces.aoti_capture`` to record the
    transformer's example inputs, exports the transformer with fully static
    shapes (avoiding ``torch.export`` dynamic-shape UserErrors, including the
    list-of-lists structure required for ``img_shapes``), compiles the
    exported program with AOTInductor, and applies the result to
    ``pipeline.transformer`` via ``spaces.aoti_apply``.

    Args:
        pipeline: The diffusion pipeline. Must expose a ``.transformer``
            attribute and be callable with ``*args``/``**kwargs``.
        *args: Positional arguments for one representative pipeline call.
        **kwargs: Keyword arguments for that same call.

    Returns:
        The same ``pipeline`` object, whether or not compilation succeeded.
        Optimization is best-effort: on any failure (or when CUDA is absent)
        the pipeline is returned unmodified.
    """
    if not torch.cuda.is_available():
        print("⚠️ CUDA no disponible. Se omite AOT.")
        return pipeline

    try:

        @spaces.GPU(duration=1500)
        def compile_transformer():
            # Record one real forward pass so we have concrete example
            # args/kwargs for torch.export.
            print("🏗️ Capturando modelo para AOT...")
            with spaces.aoti_capture(pipeline.transformer) as call:
                pipeline(*args, **kwargs)

            # Start from an all-None spec (every dim static), mirroring the
            # pytree structure of the captured kwargs.
            dynamic_shapes = tree_map(lambda t: None, call.kwargs)

            # Per-argument overrides pinning the sequence dims to known
            # static lengths. ``img_shapes`` must be a list of lists to
            # match the pipeline's nested input structure.
            static_shapes = {
                "hidden_states": {1: IMAGE_SEQ_LENGTH},
                "encoder_hidden_states": {1: TEXT_SEQ_LENGTH},
                "encoder_hidden_states_mask": {1: TEXT_SEQ_LENGTH},
                "img_shapes": [[None, None]],
            }

            # Apply overrides only for kwargs actually present in the call.
            for key, spec in static_shapes.items():
                if key in call.kwargs:
                    dynamic_shapes[key] = spec

            print("🚀 Exportando modelo con torch.export...")
            exported = torch.export.export(
                mod=pipeline.transformer,
                args=call.args,
                kwargs=call.kwargs,
                dynamic_shapes=dynamic_shapes,
            )

            print("⚙️ Compilando con AOTInductor...")
            return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)

        print("🧠 Aplicando AOT al transformer...")
        spaces.aoti_apply(compile_transformer(), pipeline.transformer)
        print("✅ AOT aplicado correctamente al transformer de Qwen-Image.")
    except Exception as e:
        # Deliberate best-effort: compilation failure must never break the
        # Space, so we log and fall back to the uncompiled transformer.
        print(f"⚠️ Error al aplicar AOT: {e}")

    return pipeline