aryadomain committed
Commit 533920b · verified · 1 Parent(s): ef8f3ad

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. Reward_sana_idealized/README.md +41 -0
  2. Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_1/evaluation_results.txt +4 -0
  3. Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_1/log.log +203 -0
  4. Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/evaluation_results.txt +4 -0
  5. Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/log.log +258 -0
  6. Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/lr_curve.png +0 -0
  7. Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/evaluation_results.txt +4 -0
  8. Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/log.log +218 -0
  9. Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/lr_curve.png +0 -0
  10. Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/evaluation_results.txt +4 -0
  11. Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/log.log +218 -0
  12. Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/lr_curve.png +0 -0
  13. Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/rewards_curve.png +0 -0
  14. Reward_sana_idealized/__pycache__/eval.cpython-311.pyc +0 -0
  15. Reward_sana_idealized/__pycache__/gradient_ascent_utils.cpython-311.pyc +0 -0
  16. Reward_sana_idealized/blip/__init__.py +1 -0
  17. Reward_sana_idealized/blip/__pycache__/__init__.cpython-311.pyc +0 -0
  18. Reward_sana_idealized/blip/__pycache__/blip.cpython-311.pyc +0 -0
  19. Reward_sana_idealized/blip/__pycache__/blip_pretrain.cpython-311.pyc +0 -0
  20. Reward_sana_idealized/blip/__pycache__/med.cpython-311.pyc +0 -0
  21. Reward_sana_idealized/blip/blip.py +70 -0
  22. Reward_sana_idealized/blip/blip_pretrain.py +43 -0
  23. Reward_sana_idealized/config_analysis_tuning.ipynb +218 -0
  24. Reward_sana_idealized/eval.py +1447 -0
  25. Reward_sana_idealized/examples.sh +162 -0
  26. Reward_sana_idealized/grad_ascent_configs.py +67 -0
  27. Reward_sana_idealized/gradient_ascent_utils.py +391 -0
  28. Reward_sana_idealized/hpsv2_score.py +110 -0
  29. Reward_sana_idealized/imagereward_score.py +221 -0
  30. Reward_sana_idealized/lr_scheduler.py +233 -0
  31. Reward_sana_idealized/models/__pycache__/__init__.cpython-311.pyc +0 -0
  32. Reward_sana_idealized/open_clip/__pycache__/coca_model.cpython-311.pyc +0 -0
  33. Reward_sana_idealized/open_clip/__pycache__/factory.cpython-311.pyc +0 -0
  34. Reward_sana_idealized/open_clip/__pycache__/model.cpython-311.pyc +0 -0
  35. Reward_sana_idealized/open_clip/__pycache__/modified_resnet.cpython-311.pyc +0 -0
  36. Reward_sana_idealized/open_clip/__pycache__/pretrained.cpython-311.pyc +0 -0
  37. Reward_sana_idealized/open_clip/__pycache__/push_to_hf_hub.cpython-311.pyc +0 -0
  38. Reward_sana_idealized/open_clip/__pycache__/timm_model.cpython-311.pyc +0 -0
  39. Reward_sana_idealized/open_clip/__pycache__/tokenizer.cpython-311.pyc +0 -0
  40. Reward_sana_idealized/open_clip/__pycache__/transformer.cpython-311.pyc +0 -0
  41. Reward_sana_idealized/open_clip/model_configs/convnext_xlarge.json +19 -0
  42. Reward_sana_idealized/pick_score.py +141 -0
  43. Reward_sana_idealized/test.ipynb +47 -0
  44. Reward_sana_idealized/tune_hyperparams.py +514 -0
  45. Reward_sana_idealized/tune_parallel.sh +253 -0
  46. Reward_sdxl_idealized/models/__pycache__/__init__.cpython-310.pyc +0 -0
  47. Reward_sdxl_idealized/models/__pycache__/__init__.cpython-313.pyc +0 -0
  48. Reward_sdxl_idealized/models/__pycache__/__init__.cpython-39.pyc +0 -0
  49. Reward_sdxl_idealized/models/__pycache__/reward_model.cpython-39.pyc +0 -0
  50. Reward_sdxl_idealized/models/__pycache__/reward_model_sdxl.cpython-310.pyc +0 -0
Reward_sana_idealized/README.md ADDED
@@ -0,0 +1,41 @@
+ # Reward SANA Idealized
+
+ This folder is a SANA-only reward-guided inference package.
+
+ ## What is inside
+
+ - `models/reward_model.py`
+   - Local SANA reward wrapper (no trainer import from other directories).
+   - Loads base SANA diffusers modules and local reward checkpoint weights.
+ - `pipelines/sana_reward_pipeline.py`
+   - SANA pipeline with per-step reward tracking.
+ - `pipelines/sana_gradient_ascent_pipeline.py`
+   - SANA pipeline with gradient-ascent latent updates.
+ - `eval.py`
+   - End-to-end evaluation script.
+ - `examples.sh`
+   - Cluster entrypoint for prefetch and evaluation.
+
+ ## Default checkpoint
+
+ `examples.sh` defaults to:
+
+ `/g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep76000`
+
+ Override with:
+
+ ```bash
+ LRM_MODEL_PATH=/path/to/checkpoint-dir-or-model.safetensors
+ ```
+
+ ## Run (10-sample smoke test)
+
+ ```bash
+ cd /g/data/rr81/LPO/Reward_sana_idealized
+ OFFLINE_MODE=1 MAX_SAMPLES=10 MODE=gradient_ascent MODEL_PROFILE=sana_600m_512 ./examples.sh
+ ```
+
+ ## Notes
+
+ - Uses existing Python env: `/g/data/rr81/aev/bin/python`.
+ - GPU nodes should run with offline HF cache.
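
The override and the smoke test in this README combine into a single invocation. A minimal sketch, assuming `examples.sh` reads `LRM_MODEL_PATH` from the environment as described above; the checkpoint path is a placeholder:

```bash
# Hypothetical combined run: point the evaluation at a specific reward
# checkpoint, then run the offline 10-sample gradient-ascent smoke test.
cd /g/data/rr81/LPO/Reward_sana_idealized
LRM_MODEL_PATH=/path/to/checkpoint-dir-or-model.safetensors \
OFFLINE_MODE=1 MAX_SAMPLES=10 MODE=gradient_ascent MODEL_PROFILE=sana_600m_512 \
./examples.sh
```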
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_1/evaluation_results.txt ADDED
@@ -0,0 +1,4 @@
+ mode: baseline
+ metrics: ['clip', 'aesthetic', 'pickscore', 'hpsv2', 'hpsv21', 'imagereward']
+ config: {'num_samples': 500, 'num_steps': 20, 'cfg_scale': 4.5, 'grad_range': [0, 700], 'grad_steps': 5, 'grad_step_size': 0.1}
+ baseline: {'avg_reward': np.float64(0.6854833755493164), 'clip_score': np.float64(26.60960610508919), 'aesthetic_score': np.float64(5.930574191093445), 'pickscore': np.float64(21.89574451446533), 'hpsv2_score': np.float16(0.2805), 'hpsv21_score': np.float16(0.292), 'imagereward_score': np.float64(1.001599932681769)}
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_1/log.log ADDED
@@ -0,0 +1,203 @@
+ ======================================================================
+ FID EVALUATION: BASELINE vs GRADIENT ASCENT
+ ======================================================================
+
+ Logging to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_1/log.log
+
+ Device: cuda:0
+ Dataset: PICKAPIC
+ Data directory: ./data
+ Base model: Efficient-Large-Model/Sana_600M_512px_diffusers
+ Model variant: sana_600m_512
+ LRM model: /g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep76000
+ HF cache dir: /scratch/rr81/ma5430/.cache/huggingface/hub
+ HF offline mode: True
+ Inference steps: 20
+ CFG scale: 4.5
+ Batch size: 1
+ Max samples: All
+ Output directory: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_1
+ Save images: False
+ Evaluation mode: baseline
+ Metrics to evaluate: CLIP, AESTHETIC, PICKSCORE, HPSV2, HPSV21, IMAGEREWARD
+ Gradient ascent config: one_step_rectification_config
+
+ ======================================================================
+ 1. LOADING VALIDATION DATA
+ ======================================================================
+ Loading Pick-a-Pic validation prompts...
+ Loading cached Pick-a-Pic split 'validation_unique' from 1 parquet shards
+ cache=/scratch/rr81/ma5430/.cache/huggingface/hub/datasets--pickapic-anonymous--pickapic_v1
+ Loaded 500 Pick-a-Pic validation samples
+
+ ======================================================================
+ 2. LOADING REWARD MODEL
+ ======================================================================
+ Loading SANA base reward backbone from Efficient-Large-Model/Sana_600M_512px_diffusers...
+ Loading SANA reward checkpoint from /g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep76000/model.safetensors...
+ ✓ Loaded checkpoint keys: 1214
+ ✓ Missing keys: 0 | Unexpected keys: 0
+ ✓ SANA LRM Reward Model initialized successfully!
+ ✓ Reward model loaded
+
+ ======================================================================
+ 3. LOADING PIPELINE
+ ======================================================================
+ ✓ Loaded SANA base model: Efficient-Large-Model/Sana_600M_512px_diffusers
+ ✓ Reward model attached to SANA pipeline
+ ✓ Pipeline loaded
+ GPU memory before scorer load: 126.00 GB free / 140.06 GB total
+ Scorer device: cuda:0
+
+ ======================================================================
+ 3.5. LOADING CLIP AND AESTHETIC SCORERS
+ ======================================================================
+ ✓ CLIP scorer loaded
+ ✓ Aesthetic scorer loaded
+ ✓ PickScore scorer loaded
+ ✓ HPSv2 scorer loaded
+ ✓ HPSv2.1 scorer loaded
+ load checkpoint from /scratch/rr81/ma5430/.cache/huggingface/hub/models--THUDM--ImageReward/snapshots/5736be03b2652728fb87788c9797b0570450ab72/ImageReward.pt
+ checkpoint loaded
+ ✓ ImageReward scorer loaded
+
+ ======================================================================
+ 4. CONFIGURING GRADIENT ASCENT
+ ======================================================================
+ Loading gradient ascent config: one_step_rectification_config
+ Config loaded: {'grad_timestep_range': (200, 800), 'num_grad_steps': 1, 'grad_step_size': 1.0, 'grad_scale': 1.0, 'lr_scheduler_type': 'constant', 'use_momentum': False, 'use_nesterov': False, 'use_iso_projection': False}
+ Gradient timestep range: (200, 800)
+ Gradient steps: 1
+ Gradient step size (initial LR): 1.0
+ LR Scheduler: constant
+ ✓ Gradient ascent enabled for timesteps (200, 800)
+
+ ======================================================================
+ 5. EVALUATING BASELINE
+ ======================================================================
+
+ Generating images with baseline mode...
+
+ [baseline] Batch 10/500 | Samples: 10/500 | Reward (t=136.0): 0.9946 | Reward (Avg): 0.7891 | CLIP: 25.9197 | Aesthetic: 6.1225 | PickScore: 21.9931 | HPSv2: 0.2854 | HPSv2.1: 0.3103 | ImageReward: 1.2588
+
+ [baseline] Batch 20/500 | Samples: 20/500 | Reward (t=136.0): 0.1036 | Reward (Avg): 0.7397 | CLIP: 26.1207 | Aesthetic: 6.0740 | PickScore: 22.1701 | HPSv2: 0.2849 | HPSv2.1: 0.3103 | ImageReward: 1.1364
+
+ [baseline] Batch 30/500 | Samples: 30/500 | Reward (t=136.0): 0.9990 | Reward (Avg): 0.7536 | CLIP: 26.2919 | Aesthetic: 6.0006 | PickScore: 22.3621 | HPSv2: 0.2859 | HPSv2.1: 0.3066 | ImageReward: 1.1024
+
+ [baseline] Batch 40/500 | Samples: 40/500 | Reward (t=136.0): 0.4502 | Reward (Avg): 0.7308 | CLIP: 26.6222 | Aesthetic: 6.0754 | PickScore: 22.3074 | HPSv2: 0.2844 | HPSv2.1: 0.3040 | ImageReward: 0.9845
+
+ [baseline] Batch 50/500 | Samples: 50/500 | Reward (t=136.0): 0.2013 | Reward (Avg): 0.7104 | CLIP: 26.4461 | Aesthetic: 6.0134 | PickScore: 22.1421 | HPSv2: 0.2832 | HPSv2.1: 0.3013 | ImageReward: 1.0774
+
+ [baseline] Batch 60/500 | Samples: 60/500 | Reward (t=136.0): 0.8906 | Reward (Avg): 0.7145 | CLIP: 26.3397 | Aesthetic: 6.0189 | PickScore: 22.1258 | HPSv2: 0.2837 | HPSv2.1: 0.3013 | ImageReward: 1.0667
+
+ [baseline] Batch 70/500 | Samples: 70/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.6977 | CLIP: 26.5825 | Aesthetic: 5.9899 | PickScore: 22.1627 | HPSv2: 0.2839 | HPSv2.1: 0.3013 | ImageReward: 1.0167
+
+ [baseline] Batch 80/500 | Samples: 80/500 | Reward (t=136.0): 0.9014 | Reward (Avg): 0.6814 | CLIP: 26.4912 | Aesthetic: 5.9676 | PickScore: 22.0616 | HPSv2: 0.2832 | HPSv2.1: 0.2974 | ImageReward: 0.9867
+
+ [baseline] Batch 90/500 | Samples: 90/500 | Reward (t=136.0): 0.9990 | Reward (Avg): 0.6827 | CLIP: 26.6833 | Aesthetic: 5.9604 | PickScore: 22.0380 | HPSv2: 0.2830 | HPSv2.1: 0.2976 | ImageReward: 1.0014
+
+ [baseline] Batch 100/500 | Samples: 100/500 | Reward (t=136.0): 0.0362 | Reward (Avg): 0.6914 | CLIP: 26.9976 | Aesthetic: 5.9972 | PickScore: 21.9799 | HPSv2: 0.2822 | HPSv2.1: 0.2957 | ImageReward: 1.0027
+
+ [baseline] Batch 110/500 | Samples: 110/500 | Reward (t=136.0): 0.5342 | Reward (Avg): 0.6817 | CLIP: 27.1451 | Aesthetic: 5.9820 | PickScore: 21.9669 | HPSv2: 0.2825 | HPSv2.1: 0.2959 | ImageReward: 1.0009
+
+ [baseline] Batch 120/500 | Samples: 120/500 | Reward (t=136.0): 0.0418 | Reward (Avg): 0.6660 | CLIP: 27.0372 | Aesthetic: 5.9733 | PickScore: 21.9549 | HPSv2: 0.2825 | HPSv2.1: 0.2964 | ImageReward: 1.0380
+
+ [baseline] Batch 130/500 | Samples: 130/500 | Reward (t=136.0): 0.9771 | Reward (Avg): 0.6797 | CLIP: 27.0679 | Aesthetic: 5.9902 | PickScore: 22.0210 | HPSv2: 0.2830 | HPSv2.1: 0.2974 | ImageReward: 1.0360
+
+ [baseline] Batch 140/500 | Samples: 140/500 | Reward (t=136.0): 0.9722 | Reward (Avg): 0.6796 | CLIP: 27.2356 | Aesthetic: 5.9616 | PickScore: 22.0095 | HPSv2: 0.2830 | HPSv2.1: 0.2961 | ImageReward: 1.0353
+
+ [baseline] Batch 150/500 | Samples: 150/500 | Reward (t=136.0): 0.9639 | Reward (Avg): 0.6779 | CLIP: 27.0927 | Aesthetic: 5.9419 | PickScore: 21.9896 | HPSv2: 0.2825 | HPSv2.1: 0.2952 | ImageReward: 1.0313
+
+ [baseline] Batch 160/500 | Samples: 160/500 | Reward (t=136.0): 0.8735 | Reward (Avg): 0.6787 | CLIP: 27.1935 | Aesthetic: 5.9386 | PickScore: 22.0361 | HPSv2: 0.2827 | HPSv2.1: 0.2954 | ImageReward: 1.0422
+
+ [baseline] Batch 170/500 | Samples: 170/500 | Reward (t=136.0): 0.8418 | Reward (Avg): 0.6797 | CLIP: 26.9886 | Aesthetic: 5.9230 | PickScore: 21.9763 | HPSv2: 0.2820 | HPSv2.1: 0.2939 | ImageReward: 1.0346
+
+ [baseline] Batch 180/500 | Samples: 180/500 | Reward (t=136.0): 0.3572 | Reward (Avg): 0.6742 | CLIP: 27.0903 | Aesthetic: 5.9289 | PickScore: 21.9842 | HPSv2: 0.2825 | HPSv2.1: 0.2947 | ImageReward: 1.0590
+
+ [baseline] Batch 190/500 | Samples: 190/500 | Reward (t=136.0): 0.3916 | Reward (Avg): 0.6795 | CLIP: 27.0595 | Aesthetic: 5.9303 | PickScore: 21.9667 | HPSv2: 0.2817 | HPSv2.1: 0.2937 | ImageReward: 1.0438
+
+ [baseline] Batch 200/500 | Samples: 200/500 | Reward (t=136.0): 0.3701 | Reward (Avg): 0.6735 | CLIP: 27.0636 | Aesthetic: 5.9264 | PickScore: 21.9664 | HPSv2: 0.2820 | HPSv2.1: 0.2944 | ImageReward: 1.0526
+
+ [baseline] Batch 210/500 | Samples: 210/500 | Reward (t=136.0): 0.7476 | Reward (Avg): 0.6796 | CLIP: 27.0620 | Aesthetic: 5.9352 | PickScore: 21.9646 | HPSv2: 0.2820 | HPSv2.1: 0.2949 | ImageReward: 1.0607
+
+ [baseline] Batch 220/500 | Samples: 220/500 | Reward (t=136.0): 0.9932 | Reward (Avg): 0.6813 | CLIP: 27.1314 | Aesthetic: 5.9390 | PickScore: 21.9501 | HPSv2: 0.2820 | HPSv2.1: 0.2947 | ImageReward: 1.0481
+
+ [baseline] Batch 230/500 | Samples: 230/500 | Reward (t=136.0): 0.4731 | Reward (Avg): 0.6819 | CLIP: 27.1906 | Aesthetic: 5.9441 | PickScore: 21.9595 | HPSv2: 0.2820 | HPSv2.1: 0.2944 | ImageReward: 1.0323
+
+ [baseline] Batch 240/500 | Samples: 240/500 | Reward (t=136.0): 0.2905 | Reward (Avg): 0.6844 | CLIP: 27.0801 | Aesthetic: 5.9538 | PickScore: 21.9540 | HPSv2: 0.2815 | HPSv2.1: 0.2937 | ImageReward: 1.0183
+
+ [baseline] Batch 250/500 | Samples: 250/500 | Reward (t=136.0): 0.9868 | Reward (Avg): 0.6836 | CLIP: 27.0973 | Aesthetic: 5.9652 | PickScore: 21.9579 | HPSv2: 0.2817 | HPSv2.1: 0.2939 | ImageReward: 1.0174
+
+ [baseline] Batch 260/500 | Samples: 260/500 | Reward (t=136.0): 0.6987 | Reward (Avg): 0.6825 | CLIP: 27.0369 | Aesthetic: 5.9730 | PickScore: 21.9534 | HPSv2: 0.2817 | HPSv2.1: 0.2939 | ImageReward: 1.0270
+
+ [baseline] Batch 270/500 | Samples: 270/500 | Reward (t=136.0): 1.0000 | Reward (Avg): 0.6849 | CLIP: 27.0198 | Aesthetic: 5.9743 | PickScore: 21.9475 | HPSv2: 0.2820 | HPSv2.1: 0.2939 | ImageReward: 1.0323
+
+ [baseline] Batch 280/500 | Samples: 280/500 | Reward (t=136.0): 0.9316 | Reward (Avg): 0.6886 | CLIP: 27.0667 | Aesthetic: 5.9771 | PickScore: 21.9809 | HPSv2: 0.2822 | HPSv2.1: 0.2949 | ImageReward: 1.0492
+
+ [baseline] Batch 290/500 | Samples: 290/500 | Reward (t=136.0): 0.8652 | Reward (Avg): 0.6863 | CLIP: 26.9701 | Aesthetic: 5.9660 | PickScore: 21.9452 | HPSv2: 0.2820 | HPSv2.1: 0.2939 | ImageReward: 1.0336
+
+ [baseline] Batch 300/500 | Samples: 300/500 | Reward (t=136.0): 0.9995 | Reward (Avg): 0.6888 | CLIP: 26.9522 | Aesthetic: 5.9680 | PickScore: 21.9559 | HPSv2: 0.2822 | HPSv2.1: 0.2947 | ImageReward: 1.0375
+
+ [baseline] Batch 310/500 | Samples: 310/500 | Reward (t=136.0): 0.9971 | Reward (Avg): 0.6894 | CLIP: 27.0262 | Aesthetic: 5.9725 | PickScore: 21.9700 | HPSv2: 0.2825 | HPSv2.1: 0.2949 | ImageReward: 1.0583
+
+ [baseline] Batch 320/500 | Samples: 320/500 | Reward (t=136.0): 0.8667 | Reward (Avg): 0.6890 | CLIP: 27.0493 | Aesthetic: 5.9674 | PickScore: 21.9830 | HPSv2: 0.2825 | HPSv2.1: 0.2947 | ImageReward: 1.0534
+
+ [baseline] Batch 330/500 | Samples: 330/500 | Reward (t=136.0): 0.9683 | Reward (Avg): 0.6922 | CLIP: 27.0304 | Aesthetic: 5.9733 | PickScore: 21.9784 | HPSv2: 0.2822 | HPSv2.1: 0.2947 | ImageReward: 1.0603
+
+ [baseline] Batch 340/500 | Samples: 340/500 | Reward (t=136.0): 0.8975 | Reward (Avg): 0.6945 | CLIP: 27.0049 | Aesthetic: 5.9703 | PickScore: 21.9790 | HPSv2: 0.2822 | HPSv2.1: 0.2949 | ImageReward: 1.0708
+
+ [baseline] Batch 350/500 | Samples: 350/500 | Reward (t=136.0): 0.0694 | Reward (Avg): 0.6900 | CLIP: 27.0073 | Aesthetic: 5.9688 | PickScore: 21.9897 | HPSv2: 0.2822 | HPSv2.1: 0.2949 | ImageReward: 1.0626
+
+ [baseline] Batch 360/500 | Samples: 360/500 | Reward (t=136.0): 0.9307 | Reward (Avg): 0.6921 | CLIP: 27.0431 | Aesthetic: 5.9667 | PickScore: 21.9860 | HPSv2: 0.2825 | HPSv2.1: 0.2952 | ImageReward: 1.0531
+
+ [baseline] Batch 370/500 | Samples: 370/500 | Reward (t=136.0): 0.9175 | Reward (Avg): 0.6917 | CLIP: 26.9788 | Aesthetic: 5.9615 | PickScore: 21.9723 | HPSv2: 0.2822 | HPSv2.1: 0.2949 | ImageReward: 1.0486
+
+ [baseline] Batch 380/500 | Samples: 380/500 | Reward (t=136.0): 0.3616 | Reward (Avg): 0.6916 | CLIP: 27.0571 | Aesthetic: 5.9690 | PickScore: 21.9754 | HPSv2: 0.2822 | HPSv2.1: 0.2952 | ImageReward: 1.0540
+
+ [baseline] Batch 390/500 | Samples: 390/500 | Reward (t=136.0): 0.9912 | Reward (Avg): 0.6914 | CLIP: 26.9386 | Aesthetic: 5.9658 | PickScore: 21.9601 | HPSv2: 0.2820 | HPSv2.1: 0.2944 | ImageReward: 1.0402
+
+ [baseline] Batch 400/500 | Samples: 400/500 | Reward (t=136.0): 0.0252 | Reward (Avg): 0.6910 | CLIP: 26.8978 | Aesthetic: 5.9574 | PickScore: 21.9578 | HPSv2: 0.2820 | HPSv2.1: 0.2942 | ImageReward: 1.0424
+
+ [baseline] Batch 410/500 | Samples: 410/500 | Reward (t=136.0): 0.2007 | Reward (Avg): 0.6909 | CLIP: 26.8640 | Aesthetic: 5.9528 | PickScore: 21.9488 | HPSv2: 0.2815 | HPSv2.1: 0.2937 | ImageReward: 1.0236
+
+ [baseline] Batch 420/500 | Samples: 420/500 | Reward (t=136.0): 0.9917 | Reward (Avg): 0.6889 | CLIP: 26.8175 | Aesthetic: 5.9515 | PickScore: 21.9370 | HPSv2: 0.2815 | HPSv2.1: 0.2935 | ImageReward: 1.0200
+
+ [baseline] Batch 430/500 | Samples: 430/500 | Reward (t=136.0): 0.2085 | Reward (Avg): 0.6878 | CLIP: 26.8390 | Aesthetic: 5.9464 | PickScore: 21.9401 | HPSv2: 0.2815 | HPSv2.1: 0.2935 | ImageReward: 1.0237
+
+ [baseline] Batch 440/500 | Samples: 440/500 | Reward (t=136.0): 0.7144 | Reward (Avg): 0.6867 | CLIP: 26.7663 | Aesthetic: 5.9488 | PickScore: 21.9390 | HPSv2: 0.2815 | HPSv2.1: 0.2935 | ImageReward: 1.0194
+
+ [baseline] Batch 450/500 | Samples: 450/500 | Reward (t=136.0): 0.8086 | Reward (Avg): 0.6877 | CLIP: 26.7831 | Aesthetic: 5.9422 | PickScore: 21.9401 | HPSv2: 0.2812 | HPSv2.1: 0.2932 | ImageReward: 1.0113
+
+ [baseline] Batch 460/500 | Samples: 460/500 | Reward (t=136.0): 0.6558 | Reward (Avg): 0.6879 | CLIP: 26.7559 | Aesthetic: 5.9414 | PickScore: 21.9357 | HPSv2: 0.2810 | HPSv2.1: 0.2927 | ImageReward: 1.0081
+
+ [baseline] Batch 470/500 | Samples: 470/500 | Reward (t=136.0): 0.6001 | Reward (Avg): 0.6873 | CLIP: 26.6761 | Aesthetic: 5.9320 | PickScore: 21.9162 | HPSv2: 0.2808 | HPSv2.1: 0.2920 | ImageReward: 0.9957
+
+ [baseline] Batch 480/500 | Samples: 480/500 | Reward (t=136.0): 0.7988 | Reward (Avg): 0.6846 | CLIP: 26.6003 | Aesthetic: 5.9285 | PickScore: 21.9059 | HPSv2: 0.2805 | HPSv2.1: 0.2920 | ImageReward: 0.9969
+
+ [baseline] Batch 490/500 | Samples: 490/500 | Reward (t=136.0): 0.8433 | Reward (Avg): 0.6850 | CLIP: 26.6023 | Aesthetic: 5.9283 | PickScore: 21.8971 | HPSv2: 0.2805 | HPSv2.1: 0.2915 | ImageReward: 0.9980
+
+ [baseline] Batch 500/500 | Samples: 500/500 | Reward (t=136.0): 0.9902 | Reward (Avg): 0.6855 | CLIP: 26.6096 | Aesthetic: 5.9306 | PickScore: 21.8957 | HPSv2: 0.2805 | HPSv2.1: 0.2920 | ImageReward: 1.0016
+ ✓ Baseline Avg Reward: 0.6855
+ ✓ Baseline Avg CLIP Score: 26.6096
+ ✓ Baseline Avg Aesthetic Score: 5.9306
+ ✓ Baseline Avg PickScore: 21.8957
+ ✓ Baseline Avg HPSv2 Score: 0.2805
+ ✓ Baseline Avg HPSv2.1 Score: 0.2920
+ ✓ Baseline Avg ImageReward: 1.0016
+
+ ======================================================================
+ FINAL RESULTS
+ ======================================================================
+
+ Baseline:
+ Avg Reward: 0.6855
+ Avg CLIP Score: 26.6096
+ Avg Aesthetic: 5.9306
+ Avg PickScore: 21.8957
+ Avg HPSv2: 0.2805
+ Avg HPSv2.1: 0.2920
+ Avg ImageReward: 1.0016
+
+ ✓ Results saved to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_1/evaluation_results.txt
+
+ ======================================================================
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/evaluation_results.txt ADDED
@@ -0,0 +1,4 @@
+ mode: gradient_ascent
+ metrics: ['clip', 'aesthetic', 'pickscore', 'hpsv2', 'hpsv21', 'imagereward']
+ config: {'num_samples': 500, 'num_steps': 20, 'cfg_scale': 4.5, 'grad_range': [0, 700], 'grad_steps': 5, 'grad_step_size': 0.1}
+ gradient_ascent: {'avg_reward': np.float64(0.910064338684082), 'clip_score': np.float64(26.665978475391864), 'aesthetic_score': np.float64(5.9369088306427), 'pickscore': np.float64(21.87682523727417), 'hpsv2_score': np.float16(0.28), 'hpsv21_score': np.float16(0.2903), 'imagereward_score': np.float64(0.9915356585062655), 'stats': {'num_applications': 10, 'total_reward_improvement': 0.00537109375, 'avg_reward_improvement': 0.000537109375, 'avg_grad_norm': 0.016338474582880735, 'max_grad_norm': 0.022640923038125038, 'detailed_stats': [{'timestep': 785, 'initial_reward': 0.986328125, 'final_reward': 0.9873046875, 'reward_improvement': 0.0009765625, 'grad_norms': [0.022640923038125038], 'reward_history': [0.986328125, 0.986328125], 'lr_history': [1.0], 'latent_change': 0.9999995827674866}, {'timestep': 749, 'initial_reward': 0.98779296875, 'final_reward': 0.98828125, 'reward_improvement': 0.00048828125, 'grad_norms': [0.021537061780691147], 'reward_history': [0.98779296875, 0.98779296875], 'lr_history': [1.0], 'latent_change': 0.9999995827674866}, {'timestep': 710, 'initial_reward': 0.98828125, 'final_reward': 0.98876953125, 'reward_improvement': 0.00048828125, 'grad_norms': [0.018791837617754936], 'reward_history': [0.98828125, 0.98828125], 'lr_history': [1.0], 'latent_change': 0.9999994039535522}, {'timestep': 666, 'initial_reward': 0.98876953125, 'final_reward': 0.990234375, 'reward_improvement': 0.00146484375, 'grad_norms': [0.020797280594706535], 'reward_history': [0.98876953125, 0.98876953125], 'lr_history': [1.0], 'latent_change': 0.9999995231628418}, {'timestep': 617, 'initial_reward': 0.98974609375, 'final_reward': 0.98974609375, 'reward_improvement': 0.0, 'grad_norms': [0.021422632038593292], 'reward_history': [0.98974609375, 0.98974609375], 'lr_history': [1.0], 'latent_change': 0.9999995231628418}, {'timestep': 562, 'initial_reward': 0.98974609375, 'final_reward': 0.990234375, 'reward_improvement': 0.00048828125, 'grad_norms': [0.014982366934418678], 'reward_history': [0.98974609375, 0.98974609375], 'lr_history': [1.0], 'latent_change': 0.9999992251396179}, {'timestep': 499, 'initial_reward': 0.990234375, 'final_reward': 0.99072265625, 'reward_improvement': 0.00048828125, 'grad_norms': [0.012734953314065933], 'reward_history': [0.990234375, 0.990234375], 'lr_history': [1.0], 'latent_change': 0.9999991655349731}, {'timestep': 428, 'initial_reward': 0.99072265625, 'final_reward': 0.9912109375, 'reward_improvement': 0.00048828125, 'grad_norms': [0.011904634535312653], 'reward_history': [0.99072265625, 0.99072265625], 'lr_history': [1.0], 'latent_change': 0.9999989867210388}, {'timestep': 345, 'initial_reward': 0.99169921875, 'final_reward': 0.99169921875, 'reward_improvement': 0.0, 'grad_norms': [0.009561818093061447], 'reward_history': [0.99169921875, 0.99169921875], 'lr_history': [1.0], 'latent_change': 0.9999989867210388}, {'timestep': 249, 'initial_reward': 0.99169921875, 'final_reward': 0.9921875, 'reward_improvement': 0.00048828125, 'grad_norms': [0.009011237882077694], 'reward_history': [0.99169921875, 0.99169921875], 'lr_history': [1.0], 'latent_change': 0.9999988675117493}]}}
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/log.log ADDED
@@ -0,0 +1,258 @@
+ ======================================================================
+ FID EVALUATION: BASELINE vs GRADIENT ASCENT
+ ======================================================================
+
+ Logging to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/log.log
+
+ Device: cuda:0
+ Dataset: PICKAPIC
+ Data directory: ./data
+ Base model: Efficient-Large-Model/Sana_600M_512px_diffusers
+ Model variant: sana_600m_512
+ LRM model: /g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep76000
+ HF cache dir: /scratch/rr81/ma5430/.cache/huggingface/hub
+ HF offline mode: True
+ Inference steps: 20
+ CFG scale: 4.5
+ Batch size: 1
+ Max samples: All
+ Output directory: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2
+ Save images: False
+ Evaluation mode: gradient_ascent
+ Metrics to evaluate: CLIP, AESTHETIC, PICKSCORE, HPSV2, HPSV21, IMAGEREWARD
+ Gradient ascent config: one_step_rectification_config
+
+ ======================================================================
+ 1. LOADING VALIDATION DATA
+ ======================================================================
+ Loading Pick-a-Pic validation prompts...
+ Loading cached Pick-a-Pic split 'validation_unique' from 1 parquet shards
+ cache=/scratch/rr81/ma5430/.cache/huggingface/hub/datasets--pickapic-anonymous--pickapic_v1
+ Loaded 500 Pick-a-Pic validation samples
+
+ ======================================================================
+ 2. LOADING REWARD MODEL
+ ======================================================================
+ Loading SANA base reward backbone from Efficient-Large-Model/Sana_600M_512px_diffusers...
+ Loading SANA reward checkpoint from /g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep76000/model.safetensors...
+ ✓ Loaded checkpoint keys: 1214
+ ✓ Missing keys: 0 | Unexpected keys: 0
+ ✓ SANA LRM Reward Model initialized successfully!
+ ✓ Reward model loaded
+
+ ======================================================================
+ 3. LOADING PIPELINE
+ ======================================================================
+ ✓ Loaded SANA base model: Efficient-Large-Model/Sana_600M_512px_diffusers
+ ✓ Reward model attached to SANA pipeline
+ ✓ Pipeline loaded
+ GPU memory before scorer load: 93.09 GB free / 140.06 GB total
+ Scorer device: cuda:0
+
+ ======================================================================
+ 3.5. LOADING CLIP AND AESTHETIC SCORERS
+ ======================================================================
+ ✓ CLIP scorer loaded
+ ✓ Aesthetic scorer loaded
+ ✓ PickScore scorer loaded
+ ✓ HPSv2 scorer loaded
+ ✓ HPSv2.1 scorer loaded
+ load checkpoint from /scratch/rr81/ma5430/.cache/huggingface/hub/models--THUDM--ImageReward/snapshots/5736be03b2652728fb87788c9797b0570450ab72/ImageReward.pt
+ checkpoint loaded
+ ✓ ImageReward scorer loaded
+
+ ======================================================================
+ 4. CONFIGURING GRADIENT ASCENT
+ ======================================================================
+ Loading gradient ascent config: one_step_rectification_config
+ Config loaded: {'grad_timestep_range': (200, 800), 'num_grad_steps': 1, 'grad_step_size': 1.0, 'grad_scale': 1.0, 'lr_scheduler_type': 'constant', 'use_momentum': False, 'use_nesterov': False, 'use_iso_projection': False}
+ Gradient timestep range: (200, 800)
+ Gradient steps: 1
+ Gradient step size (initial LR): 1.0
+ LR Scheduler: constant
+ ✓ Gradient ascent enabled for timesteps (200, 800)
+
+ ======================================================================
+ 6. EVALUATING GRADIENT ASCENT
+ ======================================================================
+
+ Generating images with gradient_ascent mode...
+
+ [gradient_ascent] Batch 10/500 | Samples: 10/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9693 | CLIP: 26.4394 | Aesthetic: 6.0884 | PickScore: 21.9325 | HPSv2: 0.2861 | HPSv2.1: 0.3071 | ImageReward: 1.2097
+
+ [gradient_ascent] Batch 20/500 | Samples: 20/500 | Reward (t=136.0): 0.4675 | Reward (Avg): 0.9257 | CLIP: 26.3258 | Aesthetic: 6.0671 | PickScore: 22.1625 | HPSv2: 0.2861 | HPSv2.1: 0.3086 | ImageReward: 1.1082
+
+ [gradient_ascent] Batch 30/500 | Samples: 30/500 | Reward (t=136.0): 0.9980 | Reward (Avg): 0.9233 | CLIP: 26.4155 | Aesthetic: 5.9805 | PickScore: 22.3315 | HPSv2: 0.2861 | HPSv2.1: 0.3057 | ImageReward: 1.0953
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+
+ [gradient_ascent] Batch 40/500 | Samples: 40/500 | Reward (t=136.0): 0.9722 | Reward (Avg): 0.9284 | CLIP: 26.8101 | Aesthetic: 6.0611 | PickScore: 22.3091 | HPSv2: 0.2849 | HPSv2.1: 0.3037 | ImageReward: 0.9902
+
+ [gradient_ascent] Batch 50/500 | Samples: 50/500 | Reward (t=136.0): 0.6968 | Reward (Avg): 0.9302 | CLIP: 26.6150 | Aesthetic: 5.9952 | PickScore: 22.1516 | HPSv2: 0.2832 | HPSv2.1: 0.3005 | ImageReward: 1.0679
+
+ [gradient_ascent] Batch 60/500 | Samples: 60/500 | Reward (t=136.0): 0.9868 | Reward (Avg): 0.9303 | CLIP: 26.4753 | Aesthetic: 6.0003 | PickScore: 22.1196 | HPSv2: 0.2837 | HPSv2.1: 0.3000 | ImageReward: 1.0506
+
+ [gradient_ascent] Batch 70/500 | Samples: 70/500 | Reward (t=136.0): 0.9966 | Reward (Avg): 0.9293 | CLIP: 26.7216 | Aesthetic: 5.9899 | PickScore: 22.1551 | HPSv2: 0.2842 | HPSv2.1: 0.3005 | ImageReward: 0.9863
+
+ [gradient_ascent] Batch 80/500 | Samples: 80/500 | Reward (t=136.0): 0.9819 | Reward (Avg): 0.9147 | CLIP: 26.6154 | Aesthetic: 5.9622 | PickScore: 22.0316 | HPSv2: 0.2830 | HPSv2.1: 0.2959 | ImageReward: 0.9452
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+
+ [gradient_ascent] Batch 90/500 | Samples: 90/500 | Reward (t=136.0): 0.9990 | Reward (Avg): 0.9137 | CLIP: 26.7968 | Aesthetic: 5.9589 | PickScore: 22.0089 | HPSv2: 0.2825 | HPSv2.1: 0.2959 | ImageReward: 0.9702
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+
+ [gradient_ascent] Batch 100/500 | Samples: 100/500 | Reward (t=136.0): 0.7866 | Reward (Avg): 0.9190 | CLIP: 27.0602 | Aesthetic: 6.0023 | PickScore: 21.9608 | HPSv2: 0.2817 | HPSv2.1: 0.2942 | ImageReward: 0.9681
+
+ [gradient_ascent] Batch 110/500 | Samples: 110/500 | Reward (t=136.0): 0.9692 | Reward (Avg): 0.9170 | CLIP: 27.2418 | Aesthetic: 5.9877 | PickScore: 21.9353 | HPSv2: 0.2820 | HPSv2.1: 0.2939 | ImageReward: 0.9614
+
+ [gradient_ascent] Batch 120/500 | Samples: 120/500 | Reward (t=136.0): 0.8379 | Reward (Avg): 0.9146 | CLIP: 27.1626 | Aesthetic: 5.9762 | PickScore: 21.9197 | HPSv2: 0.2820 | HPSv2.1: 0.2942 | ImageReward: 0.9982
+
+ [gradient_ascent] Batch 130/500 | Samples: 130/500 | Reward (t=136.0): 0.9858 | Reward (Avg): 0.9196 | CLIP: 27.2047 | Aesthetic: 5.9938 | PickScore: 21.9912 | HPSv2: 0.2825 | HPSv2.1: 0.2952 | ImageReward: 0.9948
+
+ [gradient_ascent] Batch 140/500 | Samples: 140/500 | Reward (t=136.0): 0.9912 | Reward (Avg): 0.9157 | CLIP: 27.3311 | Aesthetic: 5.9660 | PickScore: 21.9761 | HPSv2: 0.2822 | HPSv2.1: 0.2944 | ImageReward: 0.9935
+
+ [gradient_ascent] Batch 150/500 | Samples: 150/500 | Reward (t=136.0): 0.9932 | Reward (Avg): 0.9114 | CLIP: 27.1722 | Aesthetic: 5.9463 | PickScore: 21.9579 | HPSv2: 0.2817 | HPSv2.1: 0.2930 | ImageReward: 0.9898
+
+ [gradient_ascent] Batch 160/500 | Samples: 160/500 | Reward (t=136.0): 0.9814 | Reward (Avg): 0.9088 | CLIP: 27.2961 | Aesthetic: 5.9422 | PickScore: 21.9990 | HPSv2: 0.2820 | HPSv2.1: 0.2932 | ImageReward: 1.0014
+
+ [gradient_ascent] Batch 170/500 | Samples: 170/500 | Reward (t=136.0): 0.9531 | Reward (Avg): 0.9071 | CLIP: 27.0807 | Aesthetic: 5.9318 | PickScore: 21.9435 | HPSv2: 0.2815 | HPSv2.1: 0.2917 | ImageReward: 0.9945
+
+ [gradient_ascent] Batch 180/500 | Samples: 180/500 | Reward (t=136.0): 0.8799 | Reward (Avg): 0.9030 | CLIP: 27.1614 | Aesthetic: 5.9334 | PickScore: 21.9554 | HPSv2: 0.2817 | HPSv2.1: 0.2927 | ImageReward: 1.0219
+
+ [gradient_ascent] Batch 190/500 | Samples: 190/500 | Reward (t=136.0): 0.8799 | Reward (Avg): 0.9062 | CLIP: 27.1369 | Aesthetic: 5.9368 | PickScore: 21.9370 | HPSv2: 0.2812 | HPSv2.1: 0.2920 | ImageReward: 1.0060
+
+ [gradient_ascent] Batch 200/500 | Samples: 200/500 | Reward (t=136.0): 0.9653 | Reward (Avg): 0.9037 | CLIP: 27.1264 | Aesthetic: 5.9329 | PickScore: 21.9405 | HPSv2: 0.2815 | HPSv2.1: 0.2925 | ImageReward: 1.0175
+
+ [gradient_ascent] Batch 210/500 | Samples: 210/500 | Reward (t=136.0): 0.9897 | Reward (Avg): 0.9063 | CLIP: 27.1528 | Aesthetic: 5.9437 | PickScore: 21.9455 | HPSv2: 0.2815 | HPSv2.1: 0.2932 | ImageReward: 1.0271
+
+ [gradient_ascent] Batch 220/500 | Samples: 220/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9069 | CLIP: 27.2116 | Aesthetic: 5.9480 | PickScore: 21.9343 | HPSv2: 0.2815 | HPSv2.1: 0.2932 | ImageReward: 1.0189
+
+ [gradient_ascent] Batch 230/500 | Samples: 230/500 | Reward (t=136.0): 0.9409 | Reward (Avg): 0.9084 | CLIP: 27.2825 | Aesthetic: 5.9535 | PickScore: 21.9417 | HPSv2: 0.2815 | HPSv2.1: 0.2927 | ImageReward: 1.0007
+
+ [gradient_ascent] Batch 240/500 | Samples: 240/500 | Reward (t=136.0): 0.8716 | Reward (Avg): 0.9101 | CLIP: 27.1896 | Aesthetic: 5.9654 | PickScore: 21.9351 | HPSv2: 0.2810 | HPSv2.1: 0.2920 | ImageReward: 0.9896
+
+ [gradient_ascent] Batch 250/500 | Samples: 250/500 | Reward (t=136.0): 0.9966 | Reward (Avg): 0.9120 | CLIP: 27.2003 | Aesthetic: 5.9738 | PickScore: 21.9403 | HPSv2: 0.2812 | HPSv2.1: 0.2922 | ImageReward: 0.9901
+
+ [gradient_ascent] Batch 260/500 | Samples: 260/500 | Reward (t=136.0): 0.8730 | Reward (Avg): 0.9100 | CLIP: 27.1515 | Aesthetic: 5.9809 | PickScore: 21.9393 | HPSv2: 0.2812 | HPSv2.1: 0.2922 | ImageReward: 1.0035
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+
+ [gradient_ascent] Batch 270/500 | Samples: 270/500 | Reward (t=136.0): 1.0000 | Reward (Avg): 0.9094 | CLIP: 27.1001 | Aesthetic: 5.9807 | PickScore: 21.9312 | HPSv2: 0.2812 | HPSv2.1: 0.2922 | ImageReward: 1.0090
+
+ [gradient_ascent] Batch 280/500 | Samples: 280/500 | Reward (t=136.0): 0.9683 | Reward (Avg): 0.9115 | CLIP: 27.1456 | Aesthetic: 5.9814 | PickScore: 21.9610 | HPSv2: 0.2817 | HPSv2.1: 0.2932 | ImageReward: 1.0251
+
+ [gradient_ascent] Batch 290/500 | Samples: 290/500 | Reward (t=136.0): 0.9956 | Reward (Avg): 0.9111 | CLIP: 27.0581 | Aesthetic: 5.9709 | PickScore: 21.9298 | HPSv2: 0.2812 | HPSv2.1: 0.2925 | ImageReward: 1.0123
+ ⚠️ WARNING: Gradient exists but max value is 0.0
+
+ [gradient_ascent] Batch 300/500 | Samples: 300/500 | Reward (t=136.0): 0.9995 | Reward (Avg): 0.9110 | CLIP: 27.0435 | Aesthetic: 5.9736 | PickScore: 21.9395 | HPSv2: 0.2817 | HPSv2.1: 0.2930 | ImageReward: 1.0170
+
+ [gradient_ascent] Batch 310/500 | Samples: 310/500 | Reward (t=136.0): 0.9980 | Reward (Avg): 0.9117 | CLIP: 27.0912 | Aesthetic: 5.9763 | PickScore: 21.9508 | HPSv2: 0.2817 | HPSv2.1: 0.2932 | ImageReward: 1.0371
+
+ [gradient_ascent] Batch 320/500 | Samples: 320/500 | Reward (t=136.0): 0.9775 | Reward (Avg): 0.9116 | CLIP: 27.1126 | Aesthetic: 5.9707 | PickScore: 21.9669 | HPSv2: 0.2817 | HPSv2.1: 0.2930 | ImageReward: 1.0342
+
+ [gradient_ascent] Batch 330/500 | Samples: 330/500 | Reward (t=136.0): 0.9941 | Reward (Avg): 0.9121 | CLIP: 27.1159 | Aesthetic: 5.9762 | PickScore: 21.9632 | HPSv2: 0.2817 | HPSv2.1: 0.2932 | ImageReward: 1.0420
+
+ [gradient_ascent] Batch 340/500 | Samples: 340/500 | Reward (t=136.0): 0.9785 | Reward (Avg): 0.9129 | CLIP: 27.1023 | Aesthetic: 5.9739 | PickScore: 21.9643 | HPSv2: 0.2817 | HPSv2.1: 0.2935 | ImageReward: 1.0503
+
+ [gradient_ascent] Batch 350/500 | Samples: 350/500 | Reward (t=136.0): 0.5396 | Reward (Avg): 0.9110 | CLIP: 27.1031 | Aesthetic: 5.9720 | PickScore: 21.9747 | HPSv2: 0.2817 | HPSv2.1: 0.2935 | ImageReward: 1.0437
+
+ [gradient_ascent] Batch 360/500 | Samples: 360/500 | Reward (t=136.0): 0.9805 | Reward (Avg): 0.9119 | CLIP: 27.1421 | Aesthetic: 5.9696 | PickScore: 21.9725 | HPSv2: 0.2817 | HPSv2.1: 0.2937 | ImageReward: 1.0372
+
+ [gradient_ascent] Batch 370/500 | Samples: 370/500 | Reward (t=136.0): 0.9692 | Reward (Avg): 0.9117 | CLIP: 27.0698 | Aesthetic: 5.9642 | PickScore: 21.9583 | HPSv2: 0.2817 | HPSv2.1: 0.2935 | ImageReward: 1.0346
+
+ [gradient_ascent] Batch 380/500 | Samples: 380/500 | Reward (t=136.0): 0.7930 | Reward (Avg): 0.9110 | CLIP: 27.1452 | Aesthetic: 5.9731 | PickScore: 21.9609 | HPSv2: 0.2817 | HPSv2.1: 0.2937 | ImageReward: 1.0401
+
+ [gradient_ascent] Batch 390/500 | Samples: 390/500 | Reward (t=136.0): 0.9932 | Reward (Avg): 0.9114 | CLIP: 27.0228 | Aesthetic: 5.9694 | PickScore: 21.9430 | HPSv2: 0.2812 | HPSv2.1: 0.2930 | ImageReward: 1.0260
+
+ [gradient_ascent] Batch 400/500 | Samples: 400/500 | Reward (t=136.0): 0.4897 | Reward (Avg): 0.9117 | CLIP: 26.9789 | Aesthetic: 5.9603 | PickScore: 21.9398 | HPSv2: 0.2812 | HPSv2.1: 0.2927 | ImageReward: 1.0275
+
+ [gradient_ascent] Batch 410/500 | Samples: 410/500 | Reward (t=136.0): 0.8018 | Reward (Avg): 0.9123 | CLIP: 26.9305 | Aesthetic: 5.9569 | PickScore: 21.9279 | HPSv2: 0.2810 | HPSv2.1: 0.2922 | ImageReward: 1.0088
+
+ [gradient_ascent] Batch 420/500 | Samples: 420/500 | Reward (t=136.0): 0.9966 | Reward (Avg): 0.9113 | CLIP: 26.8775 | Aesthetic: 5.9550 | PickScore: 21.9134 | HPSv2: 0.2808 | HPSv2.1: 0.2920 | ImageReward: 1.0062
+
+ [gradient_ascent] Batch 430/500 | Samples: 430/500 | Reward (t=136.0): 0.7876 | Reward (Avg): 0.9102 | CLIP: 26.8827 | Aesthetic: 5.9510 | PickScore: 21.9192 | HPSv2: 0.2808 | HPSv2.1: 0.2920 | ImageReward: 1.0116
+
+ [gradient_ascent] Batch 440/500 | Samples: 440/500 | Reward (t=136.0): 0.9561 | Reward (Avg): 0.9096 | CLIP: 26.8088 | Aesthetic: 5.9552 | PickScore: 21.9176 | HPSv2: 0.2808 | HPSv2.1: 0.2920 | ImageReward: 1.0110
+
+ [gradient_ascent] Batch 450/500 | Samples: 450/500 | Reward (t=136.0): 0.9443 | Reward (Avg): 0.9086 | CLIP: 26.8014 | Aesthetic: 5.9476 | PickScore: 21.9168 | HPSv2: 0.2805 | HPSv2.1: 0.2915 | ImageReward: 1.0002
+
+ [gradient_ascent] Batch 460/500 | Samples: 460/500 | Reward (t=136.0): 0.9888 | Reward (Avg): 0.9095 | CLIP: 26.7852 | Aesthetic: 5.9451 | PickScore: 21.9099 | HPSv2: 0.2805 | HPSv2.1: 0.2913 | ImageReward: 0.9956
+
+ [gradient_ascent] Batch 470/500 | Samples: 470/500 | Reward (t=136.0): 0.9800 | Reward (Avg): 0.9107 | CLIP: 26.7099 | Aesthetic: 5.9361 | PickScore: 21.8909 | HPSv2: 0.2800 | HPSv2.1: 0.2905 | ImageReward: 0.9830
+
+ [gradient_ascent] Batch 480/500 | Samples: 480/500 | Reward (t=136.0): 0.9692 | Reward (Avg): 0.9096 | CLIP: 26.6490 | Aesthetic: 5.9346 | PickScore: 21.8834 | HPSv2: 0.2800 | HPSv2.1: 0.2903 | ImageReward: 0.9868
+
+ [gradient_ascent] Batch 490/500 | Samples: 490/500 | Reward (t=136.0): 0.9629 | Reward (Avg): 0.9101 | CLIP: 26.6586 | Aesthetic: 5.9349 | PickScore: 21.8766 | HPSv2: 0.2800 | HPSv2.1: 0.2900 | ImageReward: 0.9869
+
+ [gradient_ascent] Batch 500/500 | Samples: 500/500 | Reward (t=136.0): 0.9927 | Reward (Avg): 0.9101 | CLIP: 26.6660 | Aesthetic: 5.9369 | PickScore: 21.8768 | HPSv2: 0.2800 | HPSv2.1: 0.2903 | ImageReward: 0.9915
+ ✓ Gradient Ascent Avg Reward: 0.9101
+ ✓ Gradient Ascent Avg CLIP Score: 26.6660
+ ✓ Gradient Ascent Avg Aesthetic Score: 5.9369
+ ✓ Gradient Ascent Avg PickScore: 21.8768
+ ✓ Gradient Ascent Avg HPSv2 Score: 0.2800
+ ✓ Gradient Ascent Avg HPSv2.1 Score: 0.2903
+ ✓ Gradient Ascent Avg ImageReward: 0.9915
+
+ Gradient Ascent Statistics:
+ Applications: 10
+ Total reward improvement: +0.0054
+ Avg reward improvement: +0.0005
+
+ ✓ Saved LR curve plot to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/lr_curve.png
+ Total gradient steps: 10
+ LR range: 1.000000 → 1.000000
+
+ ✓ Saved Rewards curve plot to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/rewards_curve.png
+ Total gradient steps: 20
+ Reward range: 0.9971 → 0.9990
+ Total improvement: +0.0020
+
+ ======================================================================
+ FINAL RESULTS
+ ======================================================================
+
+ Gradient Ascent:
+ Avg Reward: 0.9101
+ Avg CLIP Score: 26.6660
+ Avg Aesthetic: 5.9369
+ Avg PickScore: 21.8768
+ Avg HPSv2: 0.2800
+ Avg HPSv2.1: 0.2903
+ Avg ImageReward: 0.9915
+
+ ✓ Results saved to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/evaluation_results.txt
+
+ ======================================================================
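
Each run writes a four-line `evaluation_results.txt` in the format shown above (mode, metrics, config, aggregate scores). A minimal sketch for comparing runs, assuming the `RESULTS/...` layout from the "Results saved to" lines in the logs:

```bash
# Sketch: print the mode (line 1) and the aggregate-score dict (line 4)
# of every run saved under this config directory.
for run in RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_*; do
  echo "== ${run##*/} =="
  sed -n '1p;4p' "$run/evaluation_results.txt"
done
```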
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/lr_curve.png ADDED
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/evaluation_results.txt ADDED
@@ -0,0 +1,4 @@
+ mode: gradient_ascent
+ metrics: ['clip', 'aesthetic', 'pickscore', 'hpsv2', 'hpsv21', 'imagereward']
+ config: {'num_samples': 500, 'num_steps': 20, 'cfg_scale': 4.5, 'grad_range': [0, 700], 'grad_steps': 5, 'grad_step_size': 0.1}
+ gradient_ascent: {'avg_reward': np.float64(0.92294921875), 'clip_score': np.float64(26.616577001571656), 'aesthetic_score': np.float64(5.955972215652466), 'pickscore': np.float64(21.881759201049803), 'hpsv2_score': np.float16(0.2798), 'hpsv21_score': np.float16(0.2905), 'imagereward_score': np.float64(0.9889077876545489), 'stats': {'num_applications': 11, 'total_reward_improvement': 0.078125, 'avg_reward_improvement': 0.007102272727272727, 'avg_grad_norm': 0.21211126311258835, 'max_grad_norm': 0.277950644493103, 'detailed_stats': [{'timestep': 785, 'initial_reward': 0.9140625, 'final_reward': 0.93359375, 'reward_improvement': 0.01953125, 'grad_norms': [0.26926133036613464], 'reward_history': [0.9140625, 0.9140625], 'lr_history': [1.0], 'latent_change': 1.0}, {'timestep': 749, 'initial_reward': 0.9375, 'final_reward': 0.94921875, 'reward_improvement': 0.01171875, 'grad_norms': [0.277950644493103], 'reward_history': [0.9375, 0.9375], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}, {'timestep': 710, 'initial_reward': 0.94921875, 'final_reward': 0.9609375, 'reward_improvement': 0.01171875, 'grad_norms': [0.27435681223869324], 'reward_history': [0.94921875, 0.94921875], 'lr_history': [1.0], 'latent_change': 1.0000001192092896}, {'timestep': 666, 'initial_reward': 0.9609375, 'final_reward': 0.96875, 'reward_improvement': 0.0078125, 'grad_norms': [0.24942916631698608], 'reward_history': [0.9609375, 0.9609375], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 617, 'initial_reward': 0.96875, 'final_reward': 0.97265625, 'reward_improvement': 0.00390625, 'grad_norms': [0.21902038156986237], 'reward_history': [0.96875, 0.96875], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}, {'timestep': 562, 'initial_reward': 0.96875, 'final_reward': 0.9765625, 'reward_improvement': 0.0078125, 'grad_norms': [0.19418643414974213], 'reward_history': [0.96875, 0.96875], 'lr_history': [1.0], 'latent_change': 1.0}, {'timestep': 499, 'initial_reward': 0.97265625, 'final_reward': 0.9765625, 'reward_improvement': 0.00390625, 'grad_norms': [0.18358778953552246], 'reward_history': [0.97265625, 0.97265625], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 428, 'initial_reward': 0.9765625, 'final_reward': 0.98046875, 'reward_improvement': 0.00390625, 'grad_norms': [0.1757386326789856], 'reward_history': [0.9765625, 0.9765625], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}, {'timestep': 345, 'initial_reward': 0.9765625, 'final_reward': 0.98046875, 'reward_improvement': 0.00390625, 'grad_norms': [0.16876810789108276], 'reward_history': [0.9765625, 0.9765625], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 249, 'initial_reward': 0.98046875, 'final_reward': 0.98046875, 'reward_improvement': 0.0, 'grad_norms': [0.16214041411876678], 'reward_history': [0.98046875, 0.98046875], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 136, 'initial_reward': 0.9765625, 'final_reward': 0.98046875, 'reward_improvement': 0.00390625, 'grad_norms': [0.1587841808795929], 'reward_history': [0.9765625, 0.9765625], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}]}}
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/log.log ADDED
@@ -0,0 +1,218 @@
+ ======================================================================
+ FID EVALUATION: BASELINE vs GRADIENT ASCENT
+ ======================================================================
+
+ Logging to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/log.log
+
+ Device: cuda:0
+ Dataset: PICKAPIC
+ Data directory: ./data
+ Base model: Efficient-Large-Model/Sana_600M_512px_diffusers
+ Model variant: sana_600m_512
+ LRM model: /g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep33000
+ HF cache dir: /scratch/rr81/ma5430/.cache/huggingface/hub
+ HF offline mode: True
+ Inference steps: 20
+ CFG scale: 4.5
+ Batch size: 1
+ Max samples: All
+ Generation dtype: bf16
+ Output directory: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3
+ Save images: False
+ Evaluation mode: gradient_ascent
+ Metrics to evaluate: CLIP, AESTHETIC, PICKSCORE, HPSV2, HPSV21, IMAGEREWARD
+ Gradient ascent config: one_step_rectification_config
+
+ ======================================================================
+ 1. LOADING VALIDATION DATA
+ ======================================================================
+ Loading Pick-a-Pic validation prompts...
+ Loading cached Pick-a-Pic split 'validation_unique' from 1 parquet shards
+ cache=/scratch/rr81/ma5430/.cache/huggingface/hub/datasets--pickapic-anonymous--pickapic_v1
+ Loaded 500 Pick-a-Pic validation samples
+
+ ======================================================================
+ 2. LOADING REWARD MODEL
+ ======================================================================
+ Loading SANA base reward backbone from Efficient-Large-Model/Sana_600M_512px_diffusers...
+ Loading SANA reward checkpoint from /g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep33000/model.safetensors...
+ ✓ Loaded checkpoint keys: 1214
+ ✓ Missing keys: 0 | Unexpected keys: 0
+ ✓ SANA LRM Reward Model initialized successfully!
+ ✓ Reward model loaded
+
+ ======================================================================
+ 3. LOADING PIPELINE
+ ======================================================================
+ ✓ Loaded SANA base model: Efficient-Large-Model/Sana_600M_512px_diffusers
+ ✓ Reward model attached to SANA pipeline
+ ✓ Pipeline loaded
+ GPU memory before scorer load: 125.86 GB free / 140.06 GB total
+ Scorer device: cuda:0
+
+ ======================================================================
+ 3.5. LOADING CLIP AND AESTHETIC SCORERS
+ ======================================================================
+ ✓ CLIP scorer loaded
+ ✓ Aesthetic scorer loaded
+ ✓ PickScore scorer loaded
+ ✓ HPSv2 scorer loaded
+ ✓ HPSv2.1 scorer loaded
+ load checkpoint from /scratch/rr81/ma5430/.cache/huggingface/hub/models--THUDM--ImageReward/snapshots/5736be03b2652728fb87788c9797b0570450ab72/ImageReward.pt
+ checkpoint loaded
+ ✓ ImageReward scorer loaded
+
+ ======================================================================
+ 4. CONFIGURING GRADIENT ASCENT
+ ======================================================================
+ Loading gradient ascent config: one_step_rectification_config
+ Config loaded: {'grad_timestep_range': (100, 800), 'num_grad_steps': 1, 'grad_step_size': 1.0, 'grad_scale': 1.0, 'lr_scheduler_type': 'constant', 'use_momentum': False, 'use_nesterov': False, 'use_iso_projection': False}
+ Gradient timestep range: (100, 800)
+ Gradient steps: 1
+ Gradient step size (initial LR): 1.0
+ LR Scheduler: constant
+ ✓ Gradient ascent enabled for timesteps (100, 800)
+
+ ======================================================================
+ 6. EVALUATING GRADIENT ASCENT
+ ======================================================================
+
+ Generating images with gradient_ascent mode...
+
+ [gradient_ascent] Batch 10/500 | Samples: 10/500 | Reward (t=136.0): 1.0000 | Reward (Avg): 0.9625 | CLIP: 26.4663 | Aesthetic: 6.1413 | PickScore: 22.1120 | HPSv2: 0.2876 | HPSv2.1: 0.3157 | ImageReward: 1.2690
+
+ [gradient_ascent] Batch 20/500 | Samples: 20/500 | Reward (t=136.0): 0.7266 | Reward (Avg): 0.9354 | CLIP: 26.3375 | Aesthetic: 6.1225 | PickScore: 22.1887 | HPSv2: 0.2869 | HPSv2.1: 0.3118 | ImageReward: 1.0984
+
+ [gradient_ascent] Batch 30/500 | Samples: 30/500 | Reward (t=136.0): 1.0000 | Reward (Avg): 0.9434 | CLIP: 26.4766 | Aesthetic: 6.0533 | PickScore: 22.3414 | HPSv2: 0.2864 | HPSv2.1: 0.3079 | ImageReward: 1.0883
+
+ [gradient_ascent] Batch 40/500 | Samples: 40/500 | Reward (t=136.0): 0.9609 | Reward (Avg): 0.9415 | CLIP: 26.6829 | Aesthetic: 6.0924 | PickScore: 22.2873 | HPSv2: 0.2844 | HPSv2.1: 0.3047 | ImageReward: 0.9498
+
+ [gradient_ascent] Batch 50/500 | Samples: 50/500 | Reward (t=136.0): 0.8711 | Reward (Avg): 0.9413 | CLIP: 26.3142 | Aesthetic: 6.0252 | PickScore: 22.1119 | HPSv2: 0.2830 | HPSv2.1: 0.3010 | ImageReward: 1.0441
+
+ [gradient_ascent] Batch 60/500 | Samples: 60/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9441 | CLIP: 26.1909 | Aesthetic: 6.0327 | PickScore: 22.1030 | HPSv2: 0.2837 | HPSv2.1: 0.3013 | ImageReward: 1.0347
+
+ [gradient_ascent] Batch 70/500 | Samples: 70/500 | Reward (t=136.0): 1.0000 | Reward (Avg): 0.9421 | CLIP: 26.4568 | Aesthetic: 6.0299 | PickScore: 22.1226 | HPSv2: 0.2839 | HPSv2.1: 0.3013 | ImageReward: 1.0243
+
+ [gradient_ascent] Batch 80/500 | Samples: 80/500 | Reward (t=136.0): 0.9922 | Reward (Avg): 0.9356 | CLIP: 26.2856 | Aesthetic: 6.0066 | PickScore: 22.0058 | HPSv2: 0.2830 | HPSv2.1: 0.2964 | ImageReward: 0.9735
+
+ [gradient_ascent] Batch 90/500 | Samples: 90/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9348 | CLIP: 26.5472 | Aesthetic: 5.9998 | PickScore: 21.9813 | HPSv2: 0.2825 | HPSv2.1: 0.2964 | ImageReward: 0.9939
+
+ [gradient_ascent] Batch 100/500 | Samples: 100/500 | Reward (t=136.0): 0.9844 | Reward (Avg): 0.9393 | CLIP: 26.8104 | Aesthetic: 6.0321 | PickScore: 21.9362 | HPSv2: 0.2817 | HPSv2.1: 0.2947 | ImageReward: 0.9892
+
+ [gradient_ascent] Batch 110/500 | Samples: 110/500 | Reward (t=136.0): 0.9844 | Reward (Avg): 0.9385 | CLIP: 26.9949 | Aesthetic: 6.0235 | PickScore: 21.9147 | HPSv2: 0.2817 | HPSv2.1: 0.2942 | ImageReward: 0.9906
+
+ [gradient_ascent] Batch 120/500 | Samples: 120/500 | Reward (t=136.0): 0.7188 | Reward (Avg): 0.9285 | CLIP: 26.8847 | Aesthetic: 6.0040 | PickScore: 21.8963 | HPSv2: 0.2815 | HPSv2.1: 0.2942 | ImageReward: 1.0242
+
+ [gradient_ascent] Batch 130/500 | Samples: 130/500 | Reward (t=136.0): 0.9883 | Reward (Avg): 0.9322 | CLIP: 26.9108 | Aesthetic: 6.0127 | PickScore: 21.9706 | HPSv2: 0.2822 | HPSv2.1: 0.2949 | ImageReward: 1.0272
+
+ [gradient_ascent] Batch 140/500 | Samples: 140/500 | Reward (t=136.0): 0.9727 | Reward (Avg): 0.9334 | CLIP: 27.1204 | Aesthetic: 5.9798 | PickScore: 21.9565 | HPSv2: 0.2820 | HPSv2.1: 0.2939 | ImageReward: 1.0218
+
+ [gradient_ascent] Batch 150/500 | Samples: 150/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9266 | CLIP: 27.0076 | Aesthetic: 5.9507 | PickScore: 21.9421 | HPSv2: 0.2815 | HPSv2.1: 0.2927 | ImageReward: 1.0071
+
+ [gradient_ascent] Batch 160/500 | Samples: 160/500 | Reward (t=136.0): 0.9375 | Reward (Avg): 0.9261 | CLIP: 27.1794 | Aesthetic: 5.9462 | PickScore: 21.9868 | HPSv2: 0.2815 | HPSv2.1: 0.2927 | ImageReward: 1.0162
+
+ [gradient_ascent] Batch 170/500 | Samples: 170/500 | Reward (t=136.0): 0.8789 | Reward (Avg): 0.9254 | CLIP: 27.0033 | Aesthetic: 5.9378 | PickScore: 21.9386 | HPSv2: 0.2810 | HPSv2.1: 0.2915 | ImageReward: 1.0123
+
+ [gradient_ascent] Batch 180/500 | Samples: 180/500 | Reward (t=136.0): 0.9570 | Reward (Avg): 0.9212 | CLIP: 27.1150 | Aesthetic: 5.9410 | PickScore: 21.9510 | HPSv2: 0.2815 | HPSv2.1: 0.2922 | ImageReward: 1.0373
+
+ [gradient_ascent] Batch 190/500 | Samples: 190/500 | Reward (t=136.0): 0.7188 | Reward (Avg): 0.9218 | CLIP: 27.0857 | Aesthetic: 5.9454 | PickScore: 21.9378 | HPSv2: 0.2810 | HPSv2.1: 0.2915 | ImageReward: 1.0217
+
+ [gradient_ascent] Batch 200/500 | Samples: 200/500 | Reward (t=136.0): 0.9609 | Reward (Avg): 0.9220 | CLIP: 27.0541 | Aesthetic: 5.9446 | PickScore: 21.9444 | HPSv2: 0.2812 | HPSv2.1: 0.2922 | ImageReward: 1.0335
+
+ [gradient_ascent] Batch 210/500 | Samples: 210/500 | Reward (t=136.0): 0.9922 | Reward (Avg): 0.9241 | CLIP: 27.0499 | Aesthetic: 5.9568 | PickScore: 21.9480 | HPSv2: 0.2812 | HPSv2.1: 0.2930 | ImageReward: 1.0426
+
+ [gradient_ascent] Batch 220/500 | Samples: 220/500 | Reward (t=136.0): 0.9922 | Reward (Avg): 0.9244 | CLIP: 27.1198 | Aesthetic: 5.9586 | PickScore: 21.9354 | HPSv2: 0.2812 | HPSv2.1: 0.2932 | ImageReward: 1.0333
+
+ [gradient_ascent] Batch 230/500 | Samples: 230/500 | Reward (t=136.0): 0.8281 | Reward (Avg): 0.9250 | CLIP: 27.1823 | Aesthetic: 5.9633 | PickScore: 21.9462 | HPSv2: 0.2812 | HPSv2.1: 0.2927 | ImageReward: 1.0163
+
+ [gradient_ascent] Batch 240/500 | Samples: 240/500 | Reward (t=136.0): 0.9453 | Reward (Avg): 0.9262 | CLIP: 27.0908 | Aesthetic: 5.9750 | PickScore: 21.9335 | HPSv2: 0.2808 | HPSv2.1: 0.2920 | ImageReward: 1.0059
+
+ [gradient_ascent] Batch 250/500 | Samples: 250/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9277 | CLIP: 27.0939 | Aesthetic: 5.9814 | PickScore: 21.9393 | HPSv2: 0.2810 | HPSv2.1: 0.2922 | ImageReward: 1.0046
+
+ [gradient_ascent] Batch 260/500 | Samples: 260/500 | Reward (t=136.0): 0.9648 | Reward (Avg): 0.9273 | CLIP: 27.0651 | Aesthetic: 5.9886 | PickScore: 21.9368 | HPSv2: 0.2810 | HPSv2.1: 0.2922 | ImageReward: 1.0170
+
+ [gradient_ascent] Batch 270/500 | Samples: 270/500 | Reward (t=136.0): 1.0000 | Reward (Avg): 0.9268 | CLIP: 26.9957 | Aesthetic: 5.9890 | PickScore: 21.9255 | HPSv2: 0.2810 | HPSv2.1: 0.2922 | ImageReward: 1.0180
135
+
136
+ [gradient_ascent] Batch 280/500 | Samples: 280/500 | Reward (t=136.0): 0.9688 | Reward (Avg): 0.9280 | CLIP: 27.0482 | Aesthetic: 5.9911 | PickScore: 21.9605 | HPSv2: 0.2812 | HPSv2.1: 0.2930 | ImageReward: 1.0319
137
+
138
+ [gradient_ascent] Batch 290/500 | Samples: 290/500 | Reward (t=136.0): 0.9688 | Reward (Avg): 0.9285 | CLIP: 26.9579 | Aesthetic: 5.9804 | PickScore: 21.9239 | HPSv2: 0.2810 | HPSv2.1: 0.2922 | ImageReward: 1.0166
139
+
140
+ [gradient_ascent] Batch 300/500 | Samples: 300/500 | Reward (t=136.0): 1.0000 | Reward (Avg): 0.9295 | CLIP: 26.9421 | Aesthetic: 5.9830 | PickScore: 21.9336 | HPSv2: 0.2812 | HPSv2.1: 0.2927 | ImageReward: 1.0217
141
+
142
+ [gradient_ascent] Batch 310/500 | Samples: 310/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9297 | CLIP: 27.0328 | Aesthetic: 5.9872 | PickScore: 21.9477 | HPSv2: 0.2815 | HPSv2.1: 0.2932 | ImageReward: 1.0419
143
+
144
+ [gradient_ascent] Batch 320/500 | Samples: 320/500 | Reward (t=136.0): 0.9492 | Reward (Avg): 0.9282 | CLIP: 27.0486 | Aesthetic: 5.9830 | PickScore: 21.9641 | HPSv2: 0.2812 | HPSv2.1: 0.2930 | ImageReward: 1.0378
145
+
146
+ [gradient_ascent] Batch 330/500 | Samples: 330/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9291 | CLIP: 27.0375 | Aesthetic: 5.9877 | PickScore: 21.9561 | HPSv2: 0.2812 | HPSv2.1: 0.2930 | ImageReward: 1.0415
147
+
148
+ [gradient_ascent] Batch 340/500 | Samples: 340/500 | Reward (t=136.0): 0.9688 | Reward (Avg): 0.9295 | CLIP: 27.0379 | Aesthetic: 5.9848 | PickScore: 21.9602 | HPSv2: 0.2815 | HPSv2.1: 0.2935 | ImageReward: 1.0553
149
+
150
+ [gradient_ascent] Batch 350/500 | Samples: 350/500 | Reward (t=136.0): 0.7539 | Reward (Avg): 0.9292 | CLIP: 27.0467 | Aesthetic: 5.9841 | PickScore: 21.9707 | HPSv2: 0.2812 | HPSv2.1: 0.2935 | ImageReward: 1.0468
151
+
152
+ [gradient_ascent] Batch 360/500 | Samples: 360/500 | Reward (t=136.0): 0.9609 | Reward (Avg): 0.9277 | CLIP: 27.0976 | Aesthetic: 5.9827 | PickScore: 21.9736 | HPSv2: 0.2815 | HPSv2.1: 0.2939 | ImageReward: 1.0415
153
+
154
+ [gradient_ascent] Batch 370/500 | Samples: 370/500 | Reward (t=136.0): 0.9453 | Reward (Avg): 0.9267 | CLIP: 27.0320 | Aesthetic: 5.9800 | PickScore: 21.9612 | HPSv2: 0.2815 | HPSv2.1: 0.2935 | ImageReward: 1.0372
155
+
156
+ [gradient_ascent] Batch 380/500 | Samples: 380/500 | Reward (t=136.0): 0.7891 | Reward (Avg): 0.9272 | CLIP: 27.1124 | Aesthetic: 5.9873 | PickScore: 21.9642 | HPSv2: 0.2815 | HPSv2.1: 0.2939 | ImageReward: 1.0436
157
+
158
+ [gradient_ascent] Batch 390/500 | Samples: 390/500 | Reward (t=136.0): 0.9922 | Reward (Avg): 0.9274 | CLIP: 26.9883 | Aesthetic: 5.9833 | PickScore: 21.9461 | HPSv2: 0.2810 | HPSv2.1: 0.2930 | ImageReward: 1.0313
159
+
160
+ [gradient_ascent] Batch 400/500 | Samples: 400/500 | Reward (t=136.0): 0.8711 | Reward (Avg): 0.9275 | CLIP: 26.9427 | Aesthetic: 5.9776 | PickScore: 21.9448 | HPSv2: 0.2810 | HPSv2.1: 0.2930 | ImageReward: 1.0291
161
+
162
+ [gradient_ascent] Batch 410/500 | Samples: 410/500 | Reward (t=136.0): 0.3535 | Reward (Avg): 0.9265 | CLIP: 26.8892 | Aesthetic: 5.9751 | PickScore: 21.9345 | HPSv2: 0.2808 | HPSv2.1: 0.2925 | ImageReward: 1.0099
163
+
164
+ [gradient_ascent] Batch 420/500 | Samples: 420/500 | Reward (t=136.0): 0.9844 | Reward (Avg): 0.9249 | CLIP: 26.8305 | Aesthetic: 5.9748 | PickScore: 21.9225 | HPSv2: 0.2805 | HPSv2.1: 0.2922 | ImageReward: 1.0087
165
+
166
+ [gradient_ascent] Batch 430/500 | Samples: 430/500 | Reward (t=136.0): 0.9609 | Reward (Avg): 0.9245 | CLIP: 26.8481 | Aesthetic: 5.9711 | PickScore: 21.9269 | HPSv2: 0.2805 | HPSv2.1: 0.2922 | ImageReward: 1.0131
167
+
168
+ [gradient_ascent] Batch 440/500 | Samples: 440/500 | Reward (t=136.0): 0.8984 | Reward (Avg): 0.9250 | CLIP: 26.7605 | Aesthetic: 5.9752 | PickScore: 21.9258 | HPSv2: 0.2805 | HPSv2.1: 0.2922 | ImageReward: 1.0092
169
+
170
+ [gradient_ascent] Batch 450/500 | Samples: 450/500 | Reward (t=136.0): 0.9531 | Reward (Avg): 0.9243 | CLIP: 26.7746 | Aesthetic: 5.9678 | PickScore: 21.9262 | HPSv2: 0.2803 | HPSv2.1: 0.2917 | ImageReward: 1.0003
171
+
172
+ [gradient_ascent] Batch 460/500 | Samples: 460/500 | Reward (t=136.0): 0.9805 | Reward (Avg): 0.9246 | CLIP: 26.7613 | Aesthetic: 5.9665 | PickScore: 21.9229 | HPSv2: 0.2803 | HPSv2.1: 0.2915 | ImageReward: 0.9966
173
+
174
+ [gradient_ascent] Batch 470/500 | Samples: 470/500 | Reward (t=136.0): 0.9766 | Reward (Avg): 0.9251 | CLIP: 26.6828 | Aesthetic: 5.9578 | PickScore: 21.9037 | HPSv2: 0.2800 | HPSv2.1: 0.2908 | ImageReward: 0.9823
175
+
176
+ [gradient_ascent] Batch 480/500 | Samples: 480/500 | Reward (t=136.0): 0.9688 | Reward (Avg): 0.9234 | CLIP: 26.6169 | Aesthetic: 5.9549 | PickScore: 21.8937 | HPSv2: 0.2798 | HPSv2.1: 0.2905 | ImageReward: 0.9846
177
+
178
+ [gradient_ascent] Batch 490/500 | Samples: 490/500 | Reward (t=136.0): 0.9688 | Reward (Avg): 0.9227 | CLIP: 26.6094 | Aesthetic: 5.9548 | PickScore: 21.8840 | HPSv2: 0.2798 | HPSv2.1: 0.2903 | ImageReward: 0.9847
179
+
180
+ [gradient_ascent] Batch 500/500 | Samples: 500/500 | Reward (t=136.0): 0.9805 | Reward (Avg): 0.9229 | CLIP: 26.6166 | Aesthetic: 5.9560 | PickScore: 21.8818 | HPSv2: 0.2798 | HPSv2.1: 0.2905 | ImageReward: 0.9889
181
+ ✓ Gradient Ascent Avg Reward: 0.9229
182
+ ✓ Gradient Ascent Avg CLIP Score: 26.6166
183
+ ✓ Gradient Ascent Avg Aesthetic Score: 5.9560
184
+ ✓ Gradient Ascent Avg PickScore: 21.8818
185
+ ✓ Gradient Ascent Avg HPSv2 Score: 0.2798
186
+ ✓ Gradient Ascent Avg HPSv2.1 Score: 0.2905
187
+ ✓ Gradient Ascent Avg ImageReward: 0.9889
188
+
189
+ Gradient Ascent Statistics:
190
+ Applications: 11
191
+ Total reward improvement: +0.0781
192
+ Avg reward improvement: +0.0071
193
+
194
+ ✓ Saved LR curve plot to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/lr_curve.png
195
+ Total gradient steps: 11
196
+ LR range: 1.000000 → 1.000000
197
+
198
+ ✓ Saved Rewards curve plot to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/rewards_curve.png
199
+ Total gradient steps: 22
200
+ Reward range: 0.9922 → 1.0000
201
+ Total improvement: +0.0078
202
+
203
+ ======================================================================
204
+ FINAL RESULTS
205
+ ======================================================================
206
+
207
+ Gradient Ascent:
208
+ Avg Reward: 0.9229
209
+ Avg CLIP Score: 26.6166
210
+ Avg Aesthetic: 5.9560
211
+ Avg PickScore: 21.8818
212
+ Avg HPSv2: 0.2798
213
+ Avg HPSv2.1: 0.2905
214
+ Avg ImageReward: 0.9889
215
+
216
+ ✓ Results saved to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/evaluation_results.txt
217
+
218
+ ======================================================================
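The run above applies one_step_rectification_config: a single gradient step of size 1.0 (constant LR) is taken on the latent at each sampled timestep that falls inside grad_timestep_range=(100, 800), which is 11 of the 20 inference steps here. A minimal sketch of that update, assuming a differentiable reward_model(latents, t, prompt_embeds) returning one scalar per sample; the actual implementation lives in pipelines/sana_gradient_ascent_pipeline.py and gradient_ascent_utils.py, which are not part of this diff:

    import torch

    def one_step_rectification(latents, t, prompt_embeds, reward_model,
                               lr=1.0, grad_scale=1.0):
        # Differentiate the reward w.r.t. the current latents only.
        z = latents.detach().requires_grad_(True)
        reward = reward_model(z, t, prompt_embeds).sum()
        grad = torch.autograd.grad(reward, z)[0]
        # One ascent step; lr stays fixed since lr_scheduler_type='constant'.
        return (z + lr * grad_scale * grad).detach()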
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/lr_curve.png ADDED
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/evaluation_results.txt ADDED
@@ -0,0 +1,4 @@
1
+ mode: gradient_ascent
2
+ metrics: ['clip', 'aesthetic', 'pickscore', 'hpsv2', 'hpsv21', 'imagereward']
3
+ config: {'num_samples': 500, 'num_steps': 20, 'cfg_scale': 4.5, 'grad_range': [0, 700], 'grad_steps': 5, 'grad_step_size': 0.1}
4
+ gradient_ascent: {'avg_reward': np.float64(0.86123828125), 'clip_score': np.float64(26.688936717987062), 'aesthetic_score': np.float64(5.964459408760071), 'pickscore': np.float64(21.888245372772218), 'hpsv2_score': np.float16(0.2798), 'hpsv21_score': np.float16(0.2896), 'imagereward_score': np.float64(0.9574673203025014), 'stats': {'num_applications': 11, 'total_reward_improvement': 0.3046875, 'avg_reward_improvement': 0.027698863636363636, 'avg_grad_norm': 0.20624772933396426, 'max_grad_norm': 0.24269965291023254, 'detailed_stats': [{'timestep': 785, 'initial_reward': 0.65234375, 'final_reward': 0.703125, 'reward_improvement': 0.05078125, 'grad_norms': [0.237853541970253], 'reward_history': [0.65234375, 0.65234375], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}, {'timestep': 749, 'initial_reward': 0.71484375, 'final_reward': 0.76171875, 'reward_improvement': 0.046875, 'grad_norms': [0.24269965291023254], 'reward_history': [0.71484375, 0.71484375], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}, {'timestep': 710, 'initial_reward': 0.76171875, 'final_reward': 0.80078125, 'reward_improvement': 0.0390625, 'grad_norms': [0.23808617889881134], 'reward_history': [0.76171875, 0.76171875], 'lr_history': [1.0], 'latent_change': 1.0}, {'timestep': 666, 'initial_reward': 0.80078125, 'final_reward': 0.83203125, 'reward_improvement': 0.03125, 'grad_norms': [0.2161739617586136], 'reward_history': [0.80078125, 0.80078125], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 617, 'initial_reward': 0.828125, 'final_reward': 0.85546875, 'reward_improvement': 0.02734375, 'grad_norms': [0.20163142681121826], 'reward_history': [0.828125, 0.828125], 'lr_history': [1.0], 'latent_change': 1.0}, {'timestep': 562, 'initial_reward': 0.8515625, 'final_reward': 0.87109375, 'reward_improvement': 0.01953125, 'grad_norms': [0.1897481232881546], 'reward_history': [0.8515625, 0.8515625], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}, {'timestep': 499, 'initial_reward': 0.8671875, 'final_reward': 0.88671875, 'reward_improvement': 0.01953125, 'grad_norms': [0.181509330868721], 'reward_history': [0.8671875, 0.8671875], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 428, 'initial_reward': 0.875, 'final_reward': 0.89453125, 'reward_improvement': 0.01953125, 'grad_norms': [0.18168200552463531], 'reward_history': [0.875, 0.875], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 345, 'initial_reward': 0.88671875, 'final_reward': 0.90234375, 'reward_improvement': 0.015625, 'grad_norms': [0.185561865568161], 'reward_history': [0.88671875, 0.88671875], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}, {'timestep': 249, 'initial_reward': 0.890625, 'final_reward': 0.90625, 'reward_improvement': 0.015625, 'grad_norms': [0.19334888458251953], 'reward_history': [0.890625, 0.890625], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 136, 'initial_reward': 0.890625, 'final_reward': 0.91015625, 'reward_improvement': 0.01953125, 'grad_norms': [0.20043005049228668], 'reward_history': [0.890625, 0.890625], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}]}}
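Note that evaluation_results.txt stores raw Python reprs (np.float64(...), np.float16(...)), so it is not valid JSON. A small loader sketch for the four-line "key: value" format above; eval() is tolerable here only because the file is produced by this repo's own eval script:

    import numpy as np

    def load_results(path="evaluation_results.txt"):
        results = {}
        with open(path) as f:
            for line in f:
                key, sep, value = line.partition(": ")
                if not sep:
                    continue
                try:
                    # Evaluates the dict/list literals and np.floatXX reprs.
                    results[key.strip()] = eval(value, {"np": np})
                except Exception:
                    results[key.strip()] = value.strip()  # e.g. 'gradient_ascent'
        return results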
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/log.log ADDED
@@ -0,0 +1,218 @@
1
+ ======================================================================
2
+ FID EVALUATION: BASELINE vs GRADIENT ASCENT
3
+ ======================================================================
4
+
5
+ Logging to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/log.log
6
+
7
+ Device: cuda:0
8
+ Dataset: PICKAPIC
9
+ Data directory: ./data
10
+ Base model: Efficient-Large-Model/Sana_600M_512px_diffusers
11
+ Model variant: sana_600m_512
12
+ LRM model: /g/data/rr81/LPO/lrm/lrm_sana/logs/v7/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep32000
13
+ HF cache dir: /scratch/rr81/ma5430/.cache/huggingface/hub
14
+ HF offline mode: True
15
+ Inference steps: 20
16
+ CFG scale: 4.5
17
+ Batch size: 1
18
+ Max samples: All
19
+ Generation dtype: bf16
20
+ Output directory: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4
21
+ Save images: False
22
+ Evaluation mode: gradient_ascent
23
+ Metrics to evaluate: CLIP, AESTHETIC, PICKSCORE, HPSV2, HPSV21, IMAGEREWARD
24
+ Gradient ascent config: one_step_rectification_config
25
+
26
+ ======================================================================
27
+ 1. LOADING VALIDATION DATA
28
+ ======================================================================
29
+ Loading Pick-a-Pic validation prompts...
30
+ Loading cached Pick-a-Pic split 'validation_unique' from 1 parquet shards
31
+ cache=/scratch/rr81/ma5430/.cache/huggingface/hub/datasets--pickapic-anonymous--pickapic_v1
32
+ Loaded 500 Pick-a-Pic validation samples
33
+
34
+ ======================================================================
35
+ 2. LOADING REWARD MODEL
36
+ ======================================================================
37
+ Loading SANA base reward backbone from Efficient-Large-Model/Sana_600M_512px_diffusers...
38
+ Loading SANA reward checkpoint from /g/data/rr81/LPO/lrm/lrm_sana/logs/v7/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep32000/model.safetensors...
39
+ ✓ Loaded checkpoint keys: 1214
40
+ ✓ Missing keys: 0 | Unexpected keys: 0
41
+ ✓ SANA LRM Reward Model initialized successfully!
42
+ ✓ Reward model loaded
43
+
44
+ ======================================================================
45
+ 3. LOADING PIPELINE
46
+ ======================================================================
47
+ ✓ Loaded SANA base model: Efficient-Large-Model/Sana_600M_512px_diffusers
48
+ ✓ Reward model attached to SANA pipeline
49
+ ✓ Pipeline loaded
50
+ GPU memory before scorer load: 85.91 GB free / 140.06 GB total
51
+ Scorer device: cuda:0
52
+
53
+ ======================================================================
54
+ 3.5. LOADING CLIP AND AESTHETIC SCORERS
55
+ ======================================================================
56
+ ✓ CLIP scorer loaded
57
+ ✓ Aesthetic scorer loaded
58
+ ✓ PickScore scorer loaded
59
+ ✓ HPSv2 scorer loaded
60
+ ✓ HPSv2.1 scorer loaded
61
+ load checkpoint from /scratch/rr81/ma5430/.cache/huggingface/hub/models--THUDM--ImageReward/snapshots/5736be03b2652728fb87788c9797b0570450ab72/ImageReward.pt
62
+ checkpoint loaded
63
+ ✓ ImageReward scorer loaded
64
+
65
+ ======================================================================
66
+ 4. CONFIGURING GRADIENT ASCENT
67
+ ======================================================================
68
+ Loading gradient ascent config: one_step_rectification_config
69
+ Config loaded: {'grad_timestep_range': (100, 800), 'num_grad_steps': 1, 'grad_step_size': 1.0, 'grad_scale': 1.0, 'lr_scheduler_type': 'constant', 'use_momentum': False, 'use_nesterov': False, 'use_iso_projection': False}
70
+ Gradient timestep range: (100, 800)
71
+ Gradient steps: 1
72
+ Gradient step size (initial LR): 1.0
73
+ LR Scheduler: constant
74
+ ✓ Gradient ascent enabled for timesteps (100, 800)
75
+
76
+ ======================================================================
77
+ 6. EVALUATING GRADIENT ASCENT
78
+ ======================================================================
79
+
80
+ Generating images with gradient_ascent mode...
81
+
82
+ [gradient_ascent] Batch 10/500 | Samples: 10/500 | Reward (t=136.0): 0.9883 | Reward (Avg): 0.9230 | CLIP: 27.2524 | Aesthetic: 6.1769 | PickScore: 22.0294 | HPSv2: 0.2856 | HPSv2.1: 0.3088 | ImageReward: 1.2917
83
+
84
+ [gradient_ascent] Batch 20/500 | Samples: 20/500 | Reward (t=136.0): 0.8672 | Reward (Avg): 0.9189 | CLIP: 26.8319 | Aesthetic: 6.1308 | PickScore: 22.1532 | HPSv2: 0.2854 | HPSv2.1: 0.3091 | ImageReward: 1.1375
85
+
86
+ [gradient_ascent] Batch 30/500 | Samples: 30/500 | Reward (t=136.0): 0.9922 | Reward (Avg): 0.9047 | CLIP: 26.7728 | Aesthetic: 6.0477 | PickScore: 22.3591 | HPSv2: 0.2854 | HPSv2.1: 0.3052 | ImageReward: 1.1096
87
+
88
+ [gradient_ascent] Batch 40/500 | Samples: 40/500 | Reward (t=136.0): 0.8672 | Reward (Avg): 0.9026 | CLIP: 26.8790 | Aesthetic: 6.1062 | PickScore: 22.3062 | HPSv2: 0.2839 | HPSv2.1: 0.3013 | ImageReward: 0.9804
89
+
90
+ [gradient_ascent] Batch 50/500 | Samples: 50/500 | Reward (t=136.0): 0.6523 | Reward (Avg): 0.8916 | CLIP: 26.6755 | Aesthetic: 6.0488 | PickScore: 22.1421 | HPSv2: 0.2827 | HPSv2.1: 0.2991 | ImageReward: 1.0652
91
+
92
+ [gradient_ascent] Batch 60/500 | Samples: 60/500 | Reward (t=136.0): 0.9648 | Reward (Avg): 0.8891 | CLIP: 26.5230 | Aesthetic: 6.0566 | PickScore: 22.1336 | HPSv2: 0.2832 | HPSv2.1: 0.2991 | ImageReward: 1.0473
93
+
94
+ [gradient_ascent] Batch 70/500 | Samples: 70/500 | Reward (t=136.0): 0.9805 | Reward (Avg): 0.8851 | CLIP: 26.6211 | Aesthetic: 6.0387 | PickScore: 22.1567 | HPSv2: 0.2834 | HPSv2.1: 0.2996 | ImageReward: 1.0103
95
+
96
+ [gradient_ascent] Batch 80/500 | Samples: 80/500 | Reward (t=136.0): 0.9141 | Reward (Avg): 0.8869 | CLIP: 26.5639 | Aesthetic: 6.0182 | PickScore: 22.0309 | HPSv2: 0.2825 | HPSv2.1: 0.2947 | ImageReward: 0.9646
97
+
98
+ [gradient_ascent] Batch 90/500 | Samples: 90/500 | Reward (t=136.0): 0.6836 | Reward (Avg): 0.8837 | CLIP: 26.7986 | Aesthetic: 6.0132 | PickScore: 22.0176 | HPSv2: 0.2825 | HPSv2.1: 0.2952 | ImageReward: 0.9975
99
+
100
+ [gradient_ascent] Batch 100/500 | Samples: 100/500 | Reward (t=136.0): 0.5234 | Reward (Avg): 0.8824 | CLIP: 27.0476 | Aesthetic: 6.0459 | PickScore: 21.9703 | HPSv2: 0.2815 | HPSv2.1: 0.2932 | ImageReward: 0.9856
101
+
102
+ [gradient_ascent] Batch 110/500 | Samples: 110/500 | Reward (t=136.0): 0.8086 | Reward (Avg): 0.8786 | CLIP: 27.2317 | Aesthetic: 6.0274 | PickScore: 21.9682 | HPSv2: 0.2820 | HPSv2.1: 0.2935 | ImageReward: 0.9922
103
+
104
+ [gradient_ascent] Batch 120/500 | Samples: 120/500 | Reward (t=136.0): 0.5312 | Reward (Avg): 0.8745 | CLIP: 27.0959 | Aesthetic: 6.0128 | PickScore: 21.9585 | HPSv2: 0.2817 | HPSv2.1: 0.2935 | ImageReward: 1.0282
105
+
106
+ [gradient_ascent] Batch 130/500 | Samples: 130/500 | Reward (t=136.0): 0.8984 | Reward (Avg): 0.8748 | CLIP: 27.1302 | Aesthetic: 6.0250 | PickScore: 22.0331 | HPSv2: 0.2825 | HPSv2.1: 0.2944 | ImageReward: 1.0238
107
+
108
+ [gradient_ascent] Batch 140/500 | Samples: 140/500 | Reward (t=136.0): 0.8438 | Reward (Avg): 0.8711 | CLIP: 27.3238 | Aesthetic: 5.9973 | PickScore: 22.0136 | HPSv2: 0.2822 | HPSv2.1: 0.2935 | ImageReward: 1.0193
109
+
110
+ [gradient_ascent] Batch 150/500 | Samples: 150/500 | Reward (t=136.0): 0.9844 | Reward (Avg): 0.8711 | CLIP: 27.2450 | Aesthetic: 5.9757 | PickScore: 21.9954 | HPSv2: 0.2817 | HPSv2.1: 0.2922 | ImageReward: 1.0077
111
+
112
+ [gradient_ascent] Batch 160/500 | Samples: 160/500 | Reward (t=136.0): 0.8828 | Reward (Avg): 0.8731 | CLIP: 27.3797 | Aesthetic: 5.9691 | PickScore: 22.0329 | HPSv2: 0.2817 | HPSv2.1: 0.2922 | ImageReward: 1.0169
113
+
114
+ [gradient_ascent] Batch 170/500 | Samples: 170/500 | Reward (t=136.0): 0.8477 | Reward (Avg): 0.8723 | CLIP: 27.1880 | Aesthetic: 5.9651 | PickScore: 21.9791 | HPSv2: 0.2812 | HPSv2.1: 0.2910 | ImageReward: 1.0030
115
+
116
+ [gradient_ascent] Batch 180/500 | Samples: 180/500 | Reward (t=136.0): 0.8906 | Reward (Avg): 0.8697 | CLIP: 27.2816 | Aesthetic: 5.9708 | PickScore: 21.9840 | HPSv2: 0.2815 | HPSv2.1: 0.2917 | ImageReward: 1.0276
117
+
118
+ [gradient_ascent] Batch 190/500 | Samples: 190/500 | Reward (t=136.0): 0.8672 | Reward (Avg): 0.8692 | CLIP: 27.2303 | Aesthetic: 5.9694 | PickScore: 21.9639 | HPSv2: 0.2810 | HPSv2.1: 0.2910 | ImageReward: 1.0080
119
+
120
+ [gradient_ascent] Batch 200/500 | Samples: 200/500 | Reward (t=136.0): 0.8281 | Reward (Avg): 0.8684 | CLIP: 27.2174 | Aesthetic: 5.9641 | PickScore: 21.9716 | HPSv2: 0.2815 | HPSv2.1: 0.2917 | ImageReward: 1.0208
121
+
122
+ [gradient_ascent] Batch 210/500 | Samples: 210/500 | Reward (t=136.0): 0.8789 | Reward (Avg): 0.8703 | CLIP: 27.1992 | Aesthetic: 5.9754 | PickScore: 21.9724 | HPSv2: 0.2812 | HPSv2.1: 0.2925 | ImageReward: 1.0306
123
+
124
+ [gradient_ascent] Batch 220/500 | Samples: 220/500 | Reward (t=136.0): 0.9375 | Reward (Avg): 0.8702 | CLIP: 27.2434 | Aesthetic: 5.9779 | PickScore: 21.9538 | HPSv2: 0.2815 | HPSv2.1: 0.2922 | ImageReward: 1.0132
125
+
126
+ [gradient_ascent] Batch 230/500 | Samples: 230/500 | Reward (t=136.0): 0.6250 | Reward (Avg): 0.8715 | CLIP: 27.2932 | Aesthetic: 5.9839 | PickScore: 21.9659 | HPSv2: 0.2812 | HPSv2.1: 0.2920 | ImageReward: 0.9930
127
+
128
+ [gradient_ascent] Batch 240/500 | Samples: 240/500 | Reward (t=136.0): 0.8594 | Reward (Avg): 0.8721 | CLIP: 27.1944 | Aesthetic: 5.9938 | PickScore: 21.9538 | HPSv2: 0.2808 | HPSv2.1: 0.2913 | ImageReward: 0.9821
129
+
130
+ [gradient_ascent] Batch 250/500 | Samples: 250/500 | Reward (t=136.0): 0.8047 | Reward (Avg): 0.8723 | CLIP: 27.1940 | Aesthetic: 5.9992 | PickScore: 21.9601 | HPSv2: 0.2810 | HPSv2.1: 0.2913 | ImageReward: 0.9797
131
+
132
+ [gradient_ascent] Batch 260/500 | Samples: 260/500 | Reward (t=136.0): 0.9297 | Reward (Avg): 0.8711 | CLIP: 27.1640 | Aesthetic: 6.0070 | PickScore: 21.9547 | HPSv2: 0.2810 | HPSv2.1: 0.2915 | ImageReward: 0.9918
133
+
134
+ [gradient_ascent] Batch 270/500 | Samples: 270/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.8714 | CLIP: 27.1096 | Aesthetic: 6.0077 | PickScore: 21.9464 | HPSv2: 0.2810 | HPSv2.1: 0.2913 | ImageReward: 0.9900
135
+
136
+ [gradient_ascent] Batch 280/500 | Samples: 280/500 | Reward (t=136.0): 0.9258 | Reward (Avg): 0.8715 | CLIP: 27.1485 | Aesthetic: 6.0107 | PickScore: 21.9772 | HPSv2: 0.2812 | HPSv2.1: 0.2920 | ImageReward: 1.0051
137
+
138
+ [gradient_ascent] Batch 290/500 | Samples: 290/500 | Reward (t=136.0): 0.9062 | Reward (Avg): 0.8722 | CLIP: 27.0509 | Aesthetic: 5.9991 | PickScore: 21.9401 | HPSv2: 0.2810 | HPSv2.1: 0.2913 | ImageReward: 0.9888
139
+
140
+ [gradient_ascent] Batch 300/500 | Samples: 300/500 | Reward (t=136.0): 0.9844 | Reward (Avg): 0.8718 | CLIP: 27.0241 | Aesthetic: 6.0017 | PickScore: 21.9507 | HPSv2: 0.2812 | HPSv2.1: 0.2920 | ImageReward: 0.9941
141
+
142
+ [gradient_ascent] Batch 310/500 | Samples: 310/500 | Reward (t=136.0): 0.9375 | Reward (Avg): 0.8709 | CLIP: 27.0985 | Aesthetic: 6.0055 | PickScore: 21.9663 | HPSv2: 0.2815 | HPSv2.1: 0.2925 | ImageReward: 1.0157
143
+
144
+ [gradient_ascent] Batch 320/500 | Samples: 320/500 | Reward (t=136.0): 0.9492 | Reward (Avg): 0.8703 | CLIP: 27.1293 | Aesthetic: 6.0013 | PickScore: 21.9831 | HPSv2: 0.2815 | HPSv2.1: 0.2922 | ImageReward: 1.0116
145
+
146
+ [gradient_ascent] Batch 330/500 | Samples: 330/500 | Reward (t=136.0): 0.9609 | Reward (Avg): 0.8704 | CLIP: 27.1128 | Aesthetic: 6.0068 | PickScore: 21.9762 | HPSv2: 0.2815 | HPSv2.1: 0.2925 | ImageReward: 1.0155
147
+
148
+ [gradient_ascent] Batch 340/500 | Samples: 340/500 | Reward (t=136.0): 0.9648 | Reward (Avg): 0.8709 | CLIP: 27.1100 | Aesthetic: 6.0010 | PickScore: 21.9779 | HPSv2: 0.2812 | HPSv2.1: 0.2927 | ImageReward: 1.0264
149
+
150
+ [gradient_ascent] Batch 350/500 | Samples: 350/500 | Reward (t=136.0): 0.7969 | Reward (Avg): 0.8694 | CLIP: 27.1092 | Aesthetic: 5.9996 | PickScore: 21.9887 | HPSv2: 0.2812 | HPSv2.1: 0.2927 | ImageReward: 1.0177
151
+
152
+ [gradient_ascent] Batch 360/500 | Samples: 360/500 | Reward (t=136.0): 0.7461 | Reward (Avg): 0.8684 | CLIP: 27.1495 | Aesthetic: 5.9991 | PickScore: 21.9878 | HPSv2: 0.2815 | HPSv2.1: 0.2930 | ImageReward: 1.0115
153
+
154
+ [gradient_ascent] Batch 370/500 | Samples: 370/500 | Reward (t=136.0): 0.9336 | Reward (Avg): 0.8669 | CLIP: 27.0887 | Aesthetic: 5.9964 | PickScore: 21.9767 | HPSv2: 0.2815 | HPSv2.1: 0.2925 | ImageReward: 1.0073
155
+
156
+ [gradient_ascent] Batch 380/500 | Samples: 380/500 | Reward (t=136.0): 0.6836 | Reward (Avg): 0.8673 | CLIP: 27.1722 | Aesthetic: 6.0026 | PickScore: 21.9784 | HPSv2: 0.2812 | HPSv2.1: 0.2927 | ImageReward: 1.0130
157
+
158
+ [gradient_ascent] Batch 390/500 | Samples: 390/500 | Reward (t=136.0): 0.9688 | Reward (Avg): 0.8669 | CLIP: 27.0503 | Aesthetic: 5.9985 | PickScore: 21.9617 | HPSv2: 0.2810 | HPSv2.1: 0.2920 | ImageReward: 1.0016
159
+
160
+ [gradient_ascent] Batch 400/500 | Samples: 400/500 | Reward (t=136.0): 0.6484 | Reward (Avg): 0.8664 | CLIP: 27.0083 | Aesthetic: 5.9924 | PickScore: 21.9558 | HPSv2: 0.2810 | HPSv2.1: 0.2920 | ImageReward: 0.9993
161
+
162
+ [gradient_ascent] Batch 410/500 | Samples: 410/500 | Reward (t=136.0): 0.4082 | Reward (Avg): 0.8651 | CLIP: 26.9581 | Aesthetic: 5.9885 | PickScore: 21.9445 | HPSv2: 0.2808 | HPSv2.1: 0.2913 | ImageReward: 0.9801
163
+
164
+ [gradient_ascent] Batch 420/500 | Samples: 420/500 | Reward (t=136.0): 0.6602 | Reward (Avg): 0.8630 | CLIP: 26.9090 | Aesthetic: 5.9864 | PickScore: 21.9316 | HPSv2: 0.2805 | HPSv2.1: 0.2910 | ImageReward: 0.9757
165
+
166
+ [gradient_ascent] Batch 430/500 | Samples: 430/500 | Reward (t=136.0): 0.8164 | Reward (Avg): 0.8628 | CLIP: 26.9218 | Aesthetic: 5.9813 | PickScore: 21.9366 | HPSv2: 0.2805 | HPSv2.1: 0.2910 | ImageReward: 0.9807
167
+
168
+ [gradient_ascent] Batch 440/500 | Samples: 440/500 | Reward (t=136.0): 0.8516 | Reward (Avg): 0.8632 | CLIP: 26.8428 | Aesthetic: 5.9854 | PickScore: 21.9359 | HPSv2: 0.2805 | HPSv2.1: 0.2910 | ImageReward: 0.9783
169
+
170
+ [gradient_ascent] Batch 450/500 | Samples: 450/500 | Reward (t=136.0): 0.7812 | Reward (Avg): 0.8623 | CLIP: 26.8449 | Aesthetic: 5.9793 | PickScore: 21.9350 | HPSv2: 0.2803 | HPSv2.1: 0.2908 | ImageReward: 0.9707
171
+
172
+ [gradient_ascent] Batch 460/500 | Samples: 460/500 | Reward (t=136.0): 0.8086 | Reward (Avg): 0.8627 | CLIP: 26.8162 | Aesthetic: 5.9783 | PickScore: 21.9268 | HPSv2: 0.2803 | HPSv2.1: 0.2903 | ImageReward: 0.9679
173
+
174
+ [gradient_ascent] Batch 470/500 | Samples: 470/500 | Reward (t=136.0): 0.6836 | Reward (Avg): 0.8619 | CLIP: 26.7554 | Aesthetic: 5.9671 | PickScore: 21.9094 | HPSv2: 0.2798 | HPSv2.1: 0.2898 | ImageReward: 0.9538
175
+
176
+ [gradient_ascent] Batch 480/500 | Samples: 480/500 | Reward (t=136.0): 0.9141 | Reward (Avg): 0.8607 | CLIP: 26.6819 | Aesthetic: 5.9631 | PickScore: 21.8992 | HPSv2: 0.2798 | HPSv2.1: 0.2896 | ImageReward: 0.9561
177
+
178
+ [gradient_ascent] Batch 490/500 | Samples: 490/500 | Reward (t=136.0): 0.8906 | Reward (Avg): 0.8607 | CLIP: 26.6874 | Aesthetic: 5.9629 | PickScore: 21.8907 | HPSv2: 0.2795 | HPSv2.1: 0.2893 | ImageReward: 0.9554
179
+
180
+ [gradient_ascent] Batch 500/500 | Samples: 500/500 | Reward (t=136.0): 0.9102 | Reward (Avg): 0.8612 | CLIP: 26.6889 | Aesthetic: 5.9645 | PickScore: 21.8882 | HPSv2: 0.2798 | HPSv2.1: 0.2896 | ImageReward: 0.9575
181
+ ✓ Gradient Ascent Avg Reward: 0.8612
182
+ ✓ Gradient Ascent Avg CLIP Score: 26.6889
183
+ ✓ Gradient Ascent Avg Aesthetic Score: 5.9645
184
+ ✓ Gradient Ascent Avg PickScore: 21.8882
185
+ ✓ Gradient Ascent Avg HPSv2 Score: 0.2798
186
+ ✓ Gradient Ascent Avg HPSv2.1 Score: 0.2896
187
+ ✓ Gradient Ascent Avg ImageReward: 0.9575
188
+
189
+ Gradient Ascent Statistics:
190
+ Applications: 11
191
+ Total reward improvement: +0.3047
192
+ Avg reward improvement: +0.0277
193
+
194
+ ✓ Saved LR curve plot to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/lr_curve.png
195
+ Total gradient steps: 11
196
+ LR range: 1.000000 → 1.000000
197
+
198
+ ✓ Saved Rewards curve plot to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/rewards_curve.png
199
+ Total gradient steps: 22
200
+ Reward range: 0.9727 → 0.9961
201
+ Total improvement: +0.0234
202
+
203
+ ======================================================================
204
+ FINAL RESULTS
205
+ ======================================================================
206
+
207
+ Gradient Ascent:
208
+ Avg Reward: 0.8612
209
+ Avg CLIP Score: 26.6889
210
+ Avg Aesthetic: 5.9645
211
+ Avg PickScore: 21.8882
212
+ Avg HPSv2: 0.2798
213
+ Avg HPSv2.1: 0.2896
214
+ Avg ImageReward: 0.9575
215
+
216
+ ✓ Results saved to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/evaluation_results.txt
217
+
218
+ ======================================================================
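As a sanity check, the summary above is just total/count: 0.3047 / 11 ≈ 0.0277, matching avg_reward_improvement in evaluation_results.txt. From the detailed_stats recorded there, a per-timestep improvement plot in the spirit of rewards_curve.png can be reproduced with a short sketch (assuming stats is the parsed 'gradient_ascent' entry from the loader sketched earlier):

    import matplotlib
    matplotlib.use("Agg")  # headless, as in eval.py
    import matplotlib.pyplot as plt

    def plot_reward_improvement(stats, out_path="rewards_curve.png"):
        detailed = stats["stats"]["detailed_stats"]
        timesteps = [d["timestep"] for d in detailed]
        deltas = [d["reward_improvement"] for d in detailed]
        plt.figure(figsize=(6, 4))
        plt.plot(timesteps, deltas, marker="o")
        plt.gca().invert_xaxis()  # sampling runs from high t to low t
        plt.xlabel("timestep")
        plt.ylabel("reward improvement")
        plt.tight_layout()
        plt.savefig(out_path)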
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/lr_curve.png ADDED
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/rewards_curve.png ADDED
Reward_sana_idealized/__pycache__/eval.cpython-311.pyc ADDED
Binary file (75.8 kB).
 
Reward_sana_idealized/__pycache__/gradient_ascent_utils.cpython-311.pyc ADDED
Binary file (17.1 kB).
 
Reward_sana_idealized/blip/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .blip_pretrain import *
Reward_sana_idealized/blip/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (202 Bytes).
 
Reward_sana_idealized/blip/__pycache__/blip.cpython-311.pyc ADDED
Binary file (4.03 kB).
 
Reward_sana_idealized/blip/__pycache__/blip_pretrain.cpython-311.pyc ADDED
Binary file (2.35 kB).
 
Reward_sana_idealized/blip/__pycache__/med.cpython-311.pyc ADDED
Binary file (46.9 kB).
 
Reward_sana_idealized/blip/blip.py ADDED
@@ -0,0 +1,70 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ '''
4
+
5
+ import warnings
6
+ warnings.filterwarnings("ignore")
7
+
8
+ import torch
9
+ import os
10
+ from urllib.parse import urlparse
11
+ from timm.models.hub import download_cached_file
12
+ from transformers import BertTokenizer
13
+ from .vit import VisionTransformer, interpolate_pos_embed
14
+
15
+
16
+ def init_tokenizer():
17
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
18
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
19
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
20
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
21
+ return tokenizer
22
+
23
+
24
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
25
+
26
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
27
+ if vit=='base':
28
+ vision_width = 768
29
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
30
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
31
+ drop_path_rate=0 or drop_path_rate
32
+ )
33
+ elif vit=='large':
34
+ vision_width = 1024
35
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
36
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
37
+ drop_path_rate=0.1 or drop_path_rate
38
+ )
39
+ return visual_encoder, vision_width
40
+
41
+
42
+ def is_url(url_or_filename):
43
+ parsed = urlparse(url_or_filename)
44
+ return parsed.scheme in ("http", "https")
45
+
46
+ def load_checkpoint(model,url_or_filename):
47
+ if is_url(url_or_filename):
48
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
49
+ checkpoint = torch.load(cached_file, map_location='cpu')
50
+ elif os.path.isfile(url_or_filename):
51
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
52
+ else:
53
+ raise RuntimeError('checkpoint url or path is invalid')
54
+
55
+ state_dict = checkpoint['model']
56
+
57
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
58
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
59
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
60
+ model.visual_encoder_m)
61
+ for key in model.state_dict().keys():
62
+ if key in state_dict.keys():
63
+ if state_dict[key].shape!=model.state_dict()[key].shape:
64
+ print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
65
+ del state_dict[key]
66
+
67
+ msg = model.load_state_dict(state_dict,strict=False)
68
+ print('load checkpoint from %s'%url_or_filename)
69
+ return model,msg
70
+
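Two notes on the adapted BLIP code above. First, in create_vit the expression "0 or drop_path_rate" evaluates to drop_path_rate (0 is falsy), while "0.1 or drop_path_rate" always evaluates to 0.1, so the argument is silently ignored for the 'large' variant; this quirk is inherited verbatim from upstream BLIP. Second, load_checkpoint interpolates position embeddings and drops shape-mismatched keys before a non-strict load. A usage sketch (the checkpoint path is hypothetical):

    encoder, width = create_vit("base", image_size=224)  # width == 768 for 'base'
    tokenizer = init_tokenizer()  # bert-base-uncased plus [DEC]/[ENC] special tokens
    # For a full BLIP-style module exposing .visual_encoder:
    # model, msg = load_checkpoint(model, "/path/to/model_base.pth")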
Reward_sana_idealized/blip/blip_pretrain.py ADDED
@@ -0,0 +1,43 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ '''
4
+
5
+ import transformers
6
+ transformers.logging.set_verbosity_error()
7
+
8
+ from torch import nn
9
+ import os
10
+ from .med import BertConfig, BertModel
11
+ from .blip import create_vit, init_tokenizer
12
+
13
+ class BLIP_Pretrain(nn.Module):
14
+ def __init__(self,
15
+ med_config = "med_config.json",
16
+ image_size = 224,
17
+ vit = 'base',
18
+ vit_grad_ckpt = False,
19
+ vit_ckpt_layer = 0,
20
+ embed_dim = 256,
21
+ queue_size = 57600,
22
+ momentum = 0.995,
23
+ ):
24
+ """
25
+ Args:
26
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
27
+ image_size (int): input image size
28
+ vit (str): model size of vision transformer
29
+ """
30
+ super().__init__()
31
+
32
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
33
+
34
+ self.tokenizer = init_tokenizer()
35
+ encoder_config = BertConfig.from_json_file(med_config)
36
+ encoder_config.encoder_width = vision_width
37
+ self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
38
+
39
+ text_width = self.text_encoder.config.hidden_size
40
+
41
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
42
+ self.text_proj = nn.Linear(text_width, embed_dim)
43
+
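BLIP_Pretrain is trimmed here to the pieces ImageReward needs: the ViT visual encoder, the BERT text encoder, and two projection heads into a shared 256-dim embedding space. A minimal instantiation sketch (assumes med_config.json is resolvable from the working directory):

    blip = BLIP_Pretrain(med_config="med_config.json", image_size=224, vit="large")
    # vision_proj maps 1024 -> 256 (ViT-large width); text_proj maps the BERT
    # hidden size (768 with the default config) -> 256.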
Reward_sana_idealized/config_analysis_tuning.ipynb ADDED
@@ -0,0 +1,218 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "a24d02a2",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import json\n",
11
+ "import pandas as pd\n",
12
+ "import numpy as np\n",
13
+ "from pathlib import Path\n",
14
+ "from datetime import datetime\n",
15
+ "import warnings\n",
16
+ "warnings.filterwarnings('ignore')\n",
17
+ "\n",
18
+ "# ============================================================================\n",
19
+ "# SECTION 1: Load and Parse Results from GPU Tuning Runs\n",
20
+ "# ============================================================================\n",
21
+ "print(\"=\" * 80)\n",
22
+ "print(\"LOADING TUNING RESULTS FROM GPU RUNS\")\n",
23
+ "print(\"=\" * 80)\n",
24
+ "\n",
25
+ "results_dir = Path(\"RESULTS_TURNING/run_2\")\n",
26
+ "all_experiments = []\n",
27
+ "baseline_metrics = None\n",
28
+ "\n",
29
+ "# Collect results from all GPU runs\n",
30
+ "for gpu_id in range(8):\n",
31
+ " gpu_dir = results_dir / f\"gpu_{gpu_id}\"\n",
32
+ " results_file = gpu_dir / \"tuning_results.json\"\n",
33
+ " \n",
34
+ " if results_file.exists():\n",
35
+ " with open(results_file, 'r') as f:\n",
36
+ " data = json.load(f)\n",
37
+ " \n",
38
+ " # Extract baseline (same across all GPUs)\n",
39
+ " if baseline_metrics is None and \"baseline\" in data:\n",
40
+ " baseline_metrics = data[\"baseline\"][\"metrics\"]\n",
41
+ " print(f\"\\n📊 Baseline Metrics (cfg_scale=5.0):\")\n",
42
+ " for metric, value in baseline_metrics.items():\n",
43
+ " print(f\" {metric:15s}: {value:.6f}\")\n",
44
+ " \n",
45
+ " # Collect all experiments\n",
46
+ " if \"experiments\" in data:\n",
47
+ " all_experiments.extend(data[\"experiments\"])\n",
48
+ " print(f\"✓ GPU {gpu_id}: {len(data['experiments'])} results loaded\")\n",
49
+ "\n",
50
+ "print(f\"\\n✓ Total experiments loaded: {len(all_experiments)}\")\n",
51
+ "\n",
52
+ "# ============================================================================\n",
53
+ "# SECTION 2: Filter Top Configs with Improvements Across All Metrics\n",
54
+ "# ============================================================================\n",
55
+ "print(\"\\n\" + \"=\" * 80)\n",
56
+ "print(\"FILTERING CONFIGURATIONS WITH IMPROVEMENTS IN ALL METRICS\")\n",
57
+ "print(\"=\" * 80)\n",
58
+ "\n",
59
+ "# Define improvement metrics to track (using ImageReward instead of Reward)\n",
60
+ "improvement_metrics = [\n",
61
+ " \"aesthetic_improvement\", \n",
62
+ " \"imagereward_improvement\", \n",
63
+ " \"clip_improvement\", \n",
64
+ " \"pickscore_improvement\", \n",
65
+ " \"hpsv2_improvement\"\n",
66
+ " ]\n",
67
+ "\n",
68
+ "# Filter experiments with improvements in ALL metrics\n",
69
+ "top_configs = []\n",
70
+ "\n",
71
+ "for exp in all_experiments:\n",
72
+ " if \"improvements\" not in exp or \"config\" not in exp or \"metrics\" not in exp:\n",
73
+ " continue\n",
74
+ " \n",
75
+ " improvements = exp[\"improvements\"]\n",
76
+ " config = exp[\"config\"]\n",
77
+ " metrics = exp[\"metrics\"]\n",
78
+ " \n",
79
+ " # Check if ALL improvements are positive (>0)\n",
80
+ " all_positive = all(improvements.get(metric, -1) > 0 for metric in improvement_metrics)\n",
81
+ " \n",
82
+ " if all_positive:\n",
83
+ " # Calculate aggregate improvement score\n",
84
+ " avg_improvement = np.mean([improvements.get(metric, 0) for metric in improvement_metrics])\n",
85
+ " \n",
86
+ " top_configs.append({\n",
87
+ " \"config\": config,\n",
88
+ " \"metrics\": metrics,\n",
89
+ " \"improvements\": improvements,\n",
90
+ " \"avg_improvement\": avg_improvement\n",
91
+ " })\n",
92
+ "\n",
93
+ "print(f\"✓ Found {len(top_configs)} configurations with improvements in ALL metrics\")\n",
94
+ "\n",
95
+ "# Sort by average improvement\n",
96
+ "top_configs.sort(key=lambda x: x[\"avg_improvement\"], reverse=True)\n",
97
+ "\n",
98
+ "# Get top 10\n",
99
+ "top_10 = top_configs[:10]\n",
100
+ "print(f\"✓ Extracted top 10 best performing configurations\")\n",
101
+ "\n",
102
+ "# ============================================================================\n",
103
+ "# SECTION 3: Create Comprehensive Results Table\n",
104
+ "# ============================================================================\n",
105
+ "print(\"\\n\" + \"=\" * 80)\n",
106
+ "print(\"CREATING COMPREHENSIVE RESULTS TABLE\")\n",
107
+ "print(\"=\" * 80)\n",
108
+ "\n",
109
+ "# Build detailed table data\n",
110
+ "table_data = []\n",
111
+ "\n",
112
+ "for rank, result in enumerate(top_10, 1):\n",
113
+ " cfg = result[\"config\"]\n",
114
+ " metrics = result[\"metrics\"]\n",
115
+ " improvements = result[\"improvements\"]\n",
116
+ " \n",
117
+ " row = {\n",
118
+ " \"Rank\": rank,\n",
119
+ " \"CFG Scale\": cfg.get(\"cfg_scale\", \"N/A\"),\n",
120
+ " \"Grad Config\": cfg.get(\"grad_config\", \"N/A\"),\n",
121
+ " \"Steps\": cfg.get(\"num_grad_steps\", \"N/A\"),\n",
122
+ " \"LR\": cfg.get(\"grad_step_size\", \"N/A\"),\n",
123
+ " \"Momentum\": cfg.get(\"momentum\", \"N/A\"),\n",
124
+ " \"ImageReward\": f\"{metrics.get('imagereward', 0):.6f}\",\n",
125
+ " \"ImageReward ↑\": f\"{improvements.get('imagereward_improvement', 0):+.2f}%\",\n",
126
+ " \"CLIP\": f\"{metrics.get('clip', 0):.4f}\",\n",
127
+ " \"CLIP ↑\": f\"{improvements.get('clip_improvement', 0):+.2f}%\",\n",
128
+ " \"Aesthetic\": f\"{metrics.get('aesthetic', 0):.4f}\",\n",
129
+ " \"Aesthetic ↑\": f\"{improvements.get('aesthetic_improvement', 0):+.2f}%\",\n",
130
+ " \"PickScore\": f\"{metrics.get('pickscore', 0):.4f}\",\n",
131
+ " \"PickScore ↑\": f\"{improvements.get('pickscore_improvement', 0):+.2f}%\",\n",
132
+ " \"HPSv2\": f\"{metrics.get('hpsv2', 0):.4f}\",\n",
133
+ " \"HPSv2 ↑\": f\"{improvements.get('hpsv2_improvement', 0):+.2f}%\",\n",
134
+ " \"Avg Improvement\": f\"{result['avg_improvement']:+.2f}%\",\n",
135
+ " }\n",
136
+ " \n",
137
+ " table_data.append(row)\n",
138
+ "\n",
139
+ "df_top_10 = pd.DataFrame(table_data)\n",
140
+ "\n",
141
+ "print(\"\\n📋 TOP 10 CONFIGURATIONS WITH IMPROVEMENTS IN ALL METRICS:\")\n",
142
+ "print(\"=\" * 180)\n",
143
+ "print(df_top_10.to_string(index=False))\n",
144
+ "print(\"=\" * 180)\n",
145
+ "\n",
146
+ "# ============================================================================\n",
147
+ "# SECTION 4: Visualize and Summary Statistics\n",
148
+ "# ============================================================================\n",
149
+ "print(\"\\n\" + \"=\" * 80)\n",
150
+ "print(\"SUMMARY STATISTICS\")\n",
151
+ "print(\"=\" * 80)\n",
152
+ "\n",
153
+ "# Extract numeric improvement values for analysis\n",
154
+ "improvement_summary = []\n",
155
+ "for result in top_10:\n",
156
+ " improvements = result[\"improvements\"]\n",
157
+ " for metric in [\"imagereward_improvement\", \"clip_improvement\", \"aesthetic_improvement\", \n",
158
+ " \"pickscore_improvement\", \"hpsv2_improvement\"]:\n",
159
+ " metric_name = metric.replace(\"_improvement\", \"\").upper()\n",
160
+ " improvement_summary.append({\n",
161
+ " \"Metric\": metric_name,\n",
162
+ " \"Improvement %\": improvements.get(metric, 0)\n",
163
+ " })\n",
164
+ "\n",
165
+ "df_summary = pd.DataFrame(improvement_summary)\n",
166
+ "\n",
167
+ "print(\"\\n📊 Average Improvements by Metric (Top 10):\")\n",
168
+ "metric_stats = df_summary.groupby(\"Metric\")[\"Improvement %\"].agg([\"mean\", \"std\", \"min\", \"max\"])\n",
169
+ "print(metric_stats.round(2))\n",
170
+ "\n",
171
+ "print(\"\\n📈 Best Configuration Details:\")\n",
172
+ "best = top_10[0]\n",
173
+ "best_cfg = best[\"config\"]\n",
174
+ "best_metrics = best[\"metrics\"]\n",
175
+ "best_improvements = best[\"improvements\"]\n",
176
+ "\n",
177
+ "print(f\"\\n✓ RANK #1 - Best Performing Configuration:\")\n",
178
+ "print(f\" Configuration:\")\n",
179
+ "print(f\" • CFG Scale: {best_cfg.get('cfg_scale')}\")\n",
180
+ "print(f\" • Gradient Config: {best_cfg.get('grad_config')}\")\n",
181
+ "print(f\" • Gradient Steps: {best_cfg.get('num_grad_steps')}\")\n",
182
+ "print(f\" • Step Size: {best_cfg.get('grad_step_size')}\")\n",
183
+ "print(f\" • Momentum: {best_cfg.get('momentum')}\")\n",
184
+ "print(f\"\\n Metrics:\")\n",
185
+ "for metric in [\"imagereward\", \"clip\", \"aesthetic\", \"pickscore\", \"hpsv2\"]:\n",
186
+ " baseline_val = baseline_metrics.get(metric, 0)\n",
187
+ " current_val = best_metrics.get(metric, 0)\n",
188
+ " improvement = best_improvements.get(f\"{metric}_improvement\", 0)\n",
189
+ " print(f\" • {metric:12s}: {current_val:8.6f} (baseline: {baseline_val:8.6f}) ↑ {improvement:+6.2f}%\")\n",
190
+ "\n",
191
+ "print(\"\\n\" + \"=\" * 80)\n",
192
+ "print(\"✓ ANALYSIS COMPLETE - TOP 10 CONFIGURATIONS IDENTIFIED\")\n",
193
+ "print(\"=\" * 80)"
194
+ ]
195
+ }
196
+ ],
197
+ "metadata": {
198
+ "kernelspec": {
199
+ "display_name": "Python 3",
200
+ "language": "python",
201
+ "name": "python3"
202
+ },
203
+ "language_info": {
204
+ "codemirror_mode": {
205
+ "name": "ipython",
206
+ "version": 3
207
+ },
208
+ "file_extension": ".py",
209
+ "mimetype": "text/x-python",
210
+ "name": "python",
211
+ "nbconvert_exporter": "python",
212
+ "pygments_lexer": "ipython3",
213
+ "version": "3.10.18"
214
+ }
215
+ },
216
+ "nbformat": 4,
217
+ "nbformat_minor": 5
218
+ }
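One loose end in the notebook: datetime is imported but never used. A natural follow-up cell would stamp an export of the top-10 table (a sketch reusing names defined above):

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    df_top_10.to_csv(results_dir / f"top10_configs_{stamp}.csv", index=False)
    print(f"✓ Saved top-10 table to top10_configs_{stamp}.csv")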
Reward_sana_idealized/eval.py ADDED
@@ -0,0 +1,1447 @@
1
+ """
2
+ Evaluation script for comparing baseline and gradient ascent pipelines using multiple metrics.
3
+
4
+ This script evaluates both pipelines on COCO or Pick-a-Pic validation sets and computes
5
+ various preference and quality metrics.
6
+ """
7
+ import warnings
8
+ warnings.filterwarnings("ignore")
9
+ import torch
10
+ import torch.nn as nn
11
+ import json
12
+ import os
13
+ import sys
14
+ import logging
15
+ from glob import glob
16
+ from pathlib import Path
17
+ from PIL import Image
18
+ from diffusers import SanaPipeline
19
+ from models import LRMRewardModel
20
+ from pipelines.sana_gradient_ascent_pipeline import SanaGradientAscentPipeline
21
+ from torchmetrics.image.fid import FrechetInceptionDistance
22
+ from torchmetrics.multimodal import CLIPScore
23
+ from transformers import CLIPModel, CLIPProcessor
24
+ from tqdm import tqdm
25
+ import numpy as np
26
+ import argparse
27
+ from datasets import load_dataset
28
+ from grad_ascent_configs import get_config, list_configs
29
+ import matplotlib.pyplot as plt
30
+ import matplotlib
31
+ matplotlib.use('Agg') # Use non-interactive backend
32
+
33
+ from huggingface_hub import hf_hub_download
34
+
35
+ import random
36
+
37
+
38
+ SANA_PROFILE_TO_MODEL_ID = {
39
+ "sana_600m_512": "Efficient-Large-Model/Sana_600M_512px_diffusers",
40
+ "sana_1600m_512": "Efficient-Large-Model/Sana_1600M_512px_diffusers",
41
+ "sana_sprint_0_6b_1024": "Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers",
42
+ "sana_sprint_1_6b_1024": "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
43
+ }
44
+
45
+
46
+ def configure_hf_runtime(hf_cache_dir=None, force_offline=False):
47
+ """Set Hugging Face cache/offline environment for cluster-safe execution."""
48
+ cache_dir = hf_cache_dir or os.getenv("HF_HUB_CACHE") or os.getenv("HUGGINGFACE_HUB_CACHE")
49
+ if cache_dir:
50
+ os.environ["HF_HUB_CACHE"] = cache_dir
51
+ os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
52
+ os.environ["HF_HOME"] = os.path.dirname(cache_dir)
53
+
54
+ env_offline = os.getenv("HF_HUB_OFFLINE", "0").strip().lower() in {"1", "true", "yes", "on"}
55
+ offline_enabled = bool(force_offline or env_offline)
56
+ if offline_enabled:
57
+ os.environ["HF_DATASETS_OFFLINE"] = "1"
58
+ os.environ["HF_METRICS_OFFLINE"] = "1"
59
+ os.environ["HF_MODULES_OFFLINE"] = "1"
60
+ os.environ["TRANSFORMERS_OFFLINE"] = "1"
61
+ os.environ["DIFFUSERS_OFFLINE"] = "1"
62
+ os.environ["HF_HUB_OFFLINE"] = "1"
63
+
64
+ return cache_dir, offline_enabled
65
+
66
+
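+ # Usage sketch (paths are site-specific; values mirror the run logs above):
+ #   cache_dir, offline = configure_hf_runtime(
+ #       "/scratch/rr81/ma5430/.cache/huggingface/hub", force_offline=True)
+ #   exports HF_HUB_CACHE/HF_HOME and switches every HF library offline.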
67
+ def resolve_default_lrm_model():
68
+ """Prefer the local SANA reward checkpoint when available."""
69
+ project_root = Path(__file__).resolve().parents[1]
70
+ default_ckpt_dir = (
71
+ project_root
72
+ / "lrm"
73
+ / "lrm_sana"
74
+ / "logs"
75
+ / "v8"
76
+ / "reward_model"
77
+ / "step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951"
78
+ / "checkpoint-gstep76000"
79
+ )
80
+ if default_ckpt_dir.exists():
81
+ return str(default_ckpt_dir)
82
+ return ""
83
+
84
+
85
+ def load_pickapic_prompts(max_samples=None, cache_dir=None, offline=False):
86
+ """Load Pick-a-Pic prompts with robust offline fallback to cached parquet shards."""
87
+ split = "validation_unique"
88
+
89
+ if not offline:
90
+ try:
91
+ ds = load_dataset("pickapic-anonymous/pickapic_v1", split=split, streaming=True)
92
+ prompts = []
93
+ for i, sample in enumerate(ds):
94
+ prompts.append(sample["caption"])
95
+ if max_samples and i + 1 >= max_samples:
96
+ break
97
+ return prompts
98
+ except Exception as e:
99
+ print(f"Warning: online streaming load failed ({e}). Trying cached offline parquet shards.")
100
+
101
+ cache_candidates = []
102
+ for p in [
103
+ cache_dir,
104
+ os.getenv("HF_HUB_CACHE"),
105
+ os.getenv("HUGGINGFACE_HUB_CACHE"),
106
+ (os.path.join(os.getenv("HF_HOME"), "hub") if os.getenv("HF_HOME") else None),
107
+ os.path.expanduser("~/.cache/huggingface/hub"),
108
+ "/scratch/rr81/ma5430/.cache/huggingface/hub",
109
+ ]:
110
+ if p and p not in cache_candidates:
111
+ cache_candidates.append(p)
112
+
113
+ for cache_root in cache_candidates:
114
+ repo_cache = os.path.join(cache_root, "datasets--pickapic-anonymous--pickapic_v1")
115
+ if not os.path.isdir(repo_cache):
116
+ continue
117
+
118
+ snapshot_dir = None
119
+ ref_main = os.path.join(repo_cache, "refs", "main")
120
+ if os.path.isfile(ref_main):
121
+ revision = open(ref_main, "r", encoding="utf-8").read().strip()
122
+ candidate = os.path.join(repo_cache, "snapshots", revision)
123
+ if os.path.isdir(candidate):
124
+ snapshot_dir = candidate
125
+
126
+ if snapshot_dir is None:
127
+ snapshots = sorted(glob(os.path.join(repo_cache, "snapshots", "*")))
128
+ if snapshots:
129
+ snapshot_dir = snapshots[-1]
130
+
131
+ if snapshot_dir is None:
132
+ continue
133
+
134
+ data_dir = os.path.join(snapshot_dir, "data")
135
+ if not os.path.isdir(data_dir):
136
+ continue
137
+
138
+ selected_split = split
139
+ parquet_files = sorted(glob(os.path.join(data_dir, f"{selected_split}-*.parquet")))
140
+ if not parquet_files:
141
+ for alt_split in ("test_unique", "test"):
142
+ alt_files = sorted(glob(os.path.join(data_dir, f"{alt_split}-*.parquet")))
143
+ if alt_files:
144
+ selected_split = alt_split
145
+ parquet_files = alt_files
146
+ print(f"Offline cache missing split '{split}', falling back to '{selected_split}'.")
147
+ break
148
+
149
+ if not parquet_files:
150
+ continue
151
+
152
+ print(
153
+ f"Loading cached Pick-a-Pic split '{selected_split}' from {len(parquet_files)} parquet shards\n"
154
+ f"cache={repo_cache}"
155
+ )
156
+ ds = load_dataset("parquet", data_files=parquet_files, split="train")
157
+ prompts = ds["caption"]
158
+ if max_samples:
159
+ prompts = prompts[:max_samples]
160
+ return list(prompts)
161
+
162
+ raise RuntimeError(
163
+ "Could not load pickapic prompts in offline mode. "
164
+ "Set --hf_cache_dir to a cache that contains datasets--pickapic-anonymous--pickapic_v1."
165
+ )
166
+
167
+
168
+ def resolve_scorer_device(requested_device, generation_device, min_free_gb_for_gpu=14.0):
169
+ """Choose where metric scorers should run to avoid GPU OOM/cudnn init failures."""
170
+ if requested_device == "cpu":
171
+ return "cpu"
172
+
173
+ if not torch.cuda.is_available() or not str(generation_device).startswith("cuda"):
174
+ return "cpu"
175
+
176
+ if requested_device == "cuda":
177
+ return generation_device
178
+
179
+ # Auto mode: only keep scorers on GPU if enough headroom remains after loading generation models.
180
+ try:
181
+ free_bytes, total_bytes = torch.cuda.mem_get_info(torch.device(generation_device))
182
+ free_gb = free_bytes / (1024 ** 3)
183
+ total_gb = total_bytes / (1024 ** 3)
184
+ print(f"GPU memory before scorer load: {free_gb:.2f} GB free / {total_gb:.2f} GB total")
185
+ if free_gb >= min_free_gb_for_gpu:
186
+ return generation_device
187
+ print(
188
+ f"⚠ Low free VRAM ({free_gb:.2f} GB). Running scorers on CPU to keep diffusion stable. "
189
+ f"Use --scorer_device cuda to force GPU scorers."
190
+ )
191
+ return "cpu"
192
+ except Exception as e:
193
+ print(f"Warning: could not inspect CUDA free memory ({e}). Falling back to CPU scorers.")
194
+ return "cpu"
195
+
196
+
197
+ def configure_cudnn_safely(device):
198
+ """Disable cuDNN when the current GPU or runtime cannot initialize it safely."""
199
+ if not torch.cuda.is_available() or not str(device).startswith("cuda"):
200
+ return
201
+
202
+ try:
203
+ major, minor = torch.cuda.get_device_capability(torch.device(device))
204
+ if (major, minor) < (7, 5):
205
+ print(
206
+ f"⚠ Detected compute capability sm_{major}{minor} (< 75). "
207
+ "Disabling cuDNN to prevent runtime initialization failures."
208
+ )
209
+ torch.backends.cudnn.enabled = False
210
+ return
211
+
212
+ # Force a cuDNN init probe early so failures are handled once at startup.
213
+ _ = torch.backends.cudnn.version()
214
+ except Exception as e:
215
+ print(f"⚠ cuDNN init probe failed ({e}). Disabling cuDNN for this run.")
216
+ torch.backends.cudnn.enabled = False
217
+
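+ # Example: an sm_70 card (e.g. V100) trips the capability check above and
+ # cuDNN is disabled up front; on sm_80+ (A100/H100) only the init probe runs.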
218
+
219
+ def resolve_generation_dtype(requested_dtype, device):
220
+ """Pick a safe generation dtype for the current device."""
221
+ req = str(requested_dtype).strip().lower()
222
+
223
+ if req == "fp32":
224
+ return torch.float32
225
+
226
+ if not str(device).startswith("cuda"):
227
+ if req in {"fp16", "bf16", "auto"}:
228
+ print("⚠ Non-CUDA device detected. Falling back to fp32.")
229
+ return torch.float32
230
+
231
+ if req == "fp16":
232
+ return torch.float16
233
+
234
+ if req == "bf16":
235
+ if torch.cuda.is_bf16_supported():
236
+ return torch.bfloat16
237
+ print("⚠ bf16 requested but not supported on this GPU. Falling back to fp16.")
238
+ return torch.float16
239
+
240
+ # auto: prefer bf16 on supported GPUs to avoid fp16 underflow in tiny gradients.
241
+ if torch.cuda.is_bf16_supported():
242
+ return torch.bfloat16
243
+ return torch.float16
244
+
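+ # Resolution sketch (assuming a bf16-capable GPU such as an A100):
+ #   resolve_generation_dtype("auto", "cuda:0") -> torch.bfloat16
+ #   resolve_generation_dtype("fp16", "cuda:0") -> torch.float16
+ #   resolve_generation_dtype("bf16", "cpu")    -> torch.float32 (downgraded, with a warning)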
245
+
246
+ def dtype_to_name(dtype: torch.dtype) -> str:
247
+ if dtype == torch.float16:
248
+ return "fp16"
249
+ if dtype == torch.bfloat16:
250
+ return "bf16"
251
+ return "fp32"
252
+
253
+
254
+ def seed_everything(seed: int):
255
+ """Locks down all random number generators for absolute reproducibility."""
256
+ # 1. Python & Numpy
257
+ random.seed(seed)
258
+ np.random.seed(seed)
259
+
260
+ # 2. PyTorch Base
261
+ torch.manual_seed(seed)
262
+ if torch.cuda.is_available():
263
+ torch.cuda.manual_seed(seed)
264
+ torch.cuda.manual_seed_all(seed) # For multi-GPU
265
+
266
+ # 3. cuDNN Determinism (Crucial for consistent gradients)
267
+ torch.backends.cudnn.deterministic = True
268
+ torch.backends.cudnn.benchmark = False
269
+
270
+ # 4. Optional: Force deterministic algorithms for PyTorch 2.0+
271
+ # Uncomment if variance persists, but it may slow down generation slightly
272
+ # torch.use_deterministic_algorithms(True)
273
+
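+ # Note: seed_everything() pins cuDNN to deterministic kernels; if
+ # configure_cudnn_safely() later disables cuDNN entirely, PyTorch simply
+ # falls back to its native (still seeded) implementations.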
274
+
275
+ class MLP(nn.Module):
276
+ """MLP for aesthetic scoring."""
277
+ def __init__(self):
278
+ super().__init__()
279
+ self.layers = nn.Sequential(
280
+ nn.Linear(768, 1024),
281
+ nn.Dropout(0.2),
282
+ nn.Linear(1024, 128),
283
+ nn.Dropout(0.2),
284
+ nn.Linear(128, 64),
285
+ nn.Dropout(0.1),
286
+ nn.Linear(64, 16),
287
+ nn.Linear(16, 1),
288
+ )
289
+
290
+ @torch.no_grad()
291
+ def forward(self, embed):
292
+ return self.layers(embed)
293
+
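+ # Shape note: this head mirrors the LAION aesthetic-predictor layout and maps
+ # L2-normalised CLIP ViT-L/14 image embeddings of shape [B, 768] to [B, 1] scores.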
294
+
295
+ class AestheticScorer(torch.nn.Module):
296
+ """Aesthetic scorer using CLIP and MLP."""
297
+ def __init__(self, dtype, device, clip_name_or_path="openai/clip-vit-large-patch14",
298
+ aesthetic_path="./sac+logos+ava1-l14-linearMSE.pth"):
299
+ super().__init__()
300
+ self.clip = CLIPModel.from_pretrained(clip_name_or_path)
301
+ self.processor = CLIPProcessor.from_pretrained(clip_name_or_path)
302
+ self.mlp = MLP()
303
+
304
+ # Load aesthetic weights
305
+ if os.path.exists(aesthetic_path):
306
+ state_dict = torch.load(aesthetic_path, map_location='cpu')
307
+ self.mlp.load_state_dict(state_dict)
308
+ else:
309
+ print(f"Warning: Aesthetic weights not found at {aesthetic_path}")
310
+
311
+ self.dtype = dtype
+ # Cast the whole module to the scorer dtype so the CLIP weights match the
+ # dtype the inputs are cast to in __call__ (avoids fp32/bf16 matmul mismatches).
+ self.to(device=device, dtype=dtype)
+ self.eval()
314
+
315
+ @torch.no_grad()
316
+ def __call__(self, images):
317
+ device = next(self.parameters()).device
318
+ inputs = self.processor(images=images, return_tensors="pt")
319
+ inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
320
+ embed = self.clip.get_image_features(**inputs)
321
+ # normalize embedding
322
+ embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
323
+ return self.mlp(embed).squeeze(1)
324
+
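+ # Minimal usage sketch (assumes the aesthetic weight file above is present):
+ #   scorer = AestheticScorer(dtype=torch.float32, device="cpu")
+ #   scores = scorer([Image.open("sample.png").convert("RGB")])  # tensor of shape [1]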
325
+
326
+ class TeeLogger:
327
+ """Logger that writes to both console and file."""
328
+ def __init__(self, log_file):
329
+ self.terminal = sys.stdout
330
+ self.log = open(log_file, 'w', encoding='utf-8')  # logs contain non-ASCII markers (✓, ⚠)
331
+
332
+ def write(self, message):
333
+ self.terminal.write(message)
334
+ self.log.write(message)
335
+ self.log.flush()
336
+
337
+ def flush(self):
338
+ self.terminal.flush()
339
+ self.log.flush()
340
+
341
+ def close(self):
342
+ self.log.close()
343
+
344
+
345
+ def setup_logging(output_dir):
346
+ """Setup logging to both console and file."""
347
+ output_path = Path(output_dir)
348
+ output_path.mkdir(parents=True, exist_ok=True)
349
+ log_file = output_path / "log.log"
350
+
351
+ # Redirect stdout to both console and file
352
+ tee = TeeLogger(log_file)
353
+ sys.stdout = tee
354
+
355
+ return tee, log_file
356
+
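+ # Usage sketch: after setup_logging(), print() output goes to both the console
+ # and <output_dir>/log.log:
+ #   tee, log_path = setup_logging("eval_outputs/run_1")  # hypothetical dir
+ #   print("hello")  # shown on screen and appended to run_1/log.log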
357
+
358
+ def load_validation_data(data_dir, max_samples=None, dataset_type="coco", hf_cache_dir=None, offline=False):
359
+ """Load validation prompts and image paths.
360
+
361
+ Args:
362
+ data_dir: Path to data directory
363
+ max_samples: Maximum number of samples to load
364
+ dataset_type: Type of dataset ("coco" or "pickapic")
+ hf_cache_dir: Optional HF cache directory (used for the pickapic path)
+ offline: If True, resolve pickapic prompts from the local cache only
365
+
366
+ Returns:
367
+ prompts: List of text prompts
368
+ image_paths: List of image paths (None for pickapic streaming dataset)
369
+ """
370
+ if dataset_type == "coco":
371
+ data_dir = Path(data_dir)
372
+ val_json = data_dir / "coco" / "caption_val.json"
373
+
374
+ if not val_json.exists():
375
+ raise FileNotFoundError(f"Validation JSON not found: {val_json}")
376
+
377
+ with open(val_json, 'r') as f:
378
+ data = json.load(f)
379
+
380
+ # Validate that image folder exists
381
+ val_img_dir = data_dir / "coco" / "images" / "val"
382
+ if not val_img_dir.exists():
383
+ raise FileNotFoundError(f"Validation image directory not found: {val_img_dir}")
384
+
385
+ # Parse data
386
+ prompts = []
387
+ image_paths = []
388
+ for img_path, caption in data.items():
389
+ full_path = data_dir / "coco" / img_path
390
+ if full_path.exists():
391
+ prompts.append(caption)
392
+ image_paths.append(str(full_path))
393
+ else:
394
+ print(f"Warning: Image not found: {full_path}")
395
+
396
+ if max_samples:
397
+ prompts = prompts[:max_samples]
398
+ image_paths = image_paths[:max_samples]
399
+
400
+ print(f"Loaded {len(prompts)} COCO validation samples")
401
+ return prompts, image_paths
402
+
403
+ elif dataset_type == "pickapic":
404
+ print("Loading Pick-a-Pic validation prompts...")
405
+ prompts = load_pickapic_prompts(max_samples=max_samples, cache_dir=hf_cache_dir, offline=offline)
406
+
407
+ print(f"Loaded {len(prompts)} Pick-a-Pic validation samples")
408
+ return prompts, None # No reference images for Pick-a-Pic
409
+
410
+ else:
411
+ raise ValueError(f"Unknown dataset type: {dataset_type}. Choose 'coco' or 'pickapic'.")
412
+
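+ # Return-shape sketch:
+ #   prompts, paths = load_validation_data("./data", 100, "coco")      # two lists, len <= 100
+ #   prompts, paths = load_validation_data("./data", 100, "pickapic")  # paths is None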
413
+
414
+ def generate_and_evaluate(
415
+ pipeline,
416
+ prompts,
417
+ image_paths,
418
+ device,
419
+ dtype,
420
+ num_inference_steps=20,
421
+ guidance_scale=7.5,
422
+ seed=42,
423
+ batch_size=1,
424
+ apply_gradient_ascent=False,
425
+ mode_name="baseline",
426
+ log_interval=10,
427
+ output_dir=None,
428
+ save_images=False,
429
+ clip_scorer=None,
430
+ aesthetic_scorer=None,
431
+ pick_scorer=None,
432
+ hpsv2_scorer=None,
433
+ hpsv21_scorer=None,
434
+ imagereward_scorer=None,
435
+ compute_fid=True,
436
+ capture_trajectory=False
437
+ ):
438
+ """Generate images for the given prompts, score them with the enabled metrics, and optionally accumulate a per-batch FID metric."""
439
+ pipeline.to(device)
440
+
441
+ print(f"\nGenerating images with {mode_name} mode...")
442
+
443
+
444
+ all_rewards = []
445
+ all_clip_scores = []
446
+ all_aesthetic_scores = []
447
+ all_pick_scores = []
448
+ all_hpsv2_scores = []
449
+ all_hpsv21_scores = []
450
+ all_imagereward_scores = []
451
+ lr_history_first_image = None # Store LR history for first image
452
+ trajectory_first_image = []
453
+ num_batches = (len(prompts) + batch_size - 1) // batch_size
454
+
455
+ # Create output directory if saving images
456
+ if save_images and output_dir:
457
+ mode_output_dir = Path(output_dir) / mode_name
458
+ mode_output_dir.mkdir(parents=True, exist_ok=True)
459
+
460
+ # Disable internal progress bars
461
+ pipeline.set_progress_bar_config(disable=True)
462
+
463
+ for idx, i in enumerate(tqdm(range(0, len(prompts), batch_size), desc=f"Generating {mode_name}")):
464
+ batch_prompts = prompts[i:i+batch_size]
465
+ batch_real_paths = image_paths[i:i+batch_size] if image_paths is not None else None
466
+ batch_num = idx + 1
467
+
468
+ # Initialize FID metric if needed
469
+ fid_metric = None
470
+ real_images_tensor = None
471
+
472
+ if compute_fid and batch_real_paths is not None:
473
+ fid_metric = FrechetInceptionDistance().to(device)
474
+
475
+ # Load and update FID with real images for this batch
476
+ real_images = []
477
+ for path in batch_real_paths:
478
+ img = Image.open(path).convert("RGB")
479
+ img = img.resize((512, 512))  # uniform size for FID (the Inception backbone resizes to 299x299 itself)
480
+ img_array = np.array(img)
481
+ real_images.append(img_array)
482
+
483
+ # Convert to tensor [B, H, W, C] -> [B, C, H, W]
484
+ real_images_tensor = torch.from_numpy(np.stack(real_images)).permute(0, 3, 1, 2).float()
485
+ real_images_tensor = real_images_tensor.to(device)
486
+
487
+ # Generate images
488
+ generator = torch.Generator(device=device).manual_seed(seed + i)
489
+
490
+ # Only capture trajectory for the very first batch to save RAM
491
+ def trajectory_callback(step, timestep, latents):
492
+ if idx == 0 and capture_trajectory:
493
+ # Detach and move to CPU immediately to prevent VRAM OOM
494
+ trajectory_first_image.append(latents.detach().cpu().clone())
495
+
496
+ with torch.no_grad():
497
+ result = pipeline(
498
+ prompt=batch_prompts,
499
+ num_inference_steps=num_inference_steps,
500
+ guidance_scale=guidance_scale,
501
+ generator=generator,
502
+ track_rewards=True,
503
+ print_rewards=False,
504
+ apply_gradient_ascent=apply_gradient_ascent,
505
+ verbose_grad=False,
506
+ callback=trajectory_callback if capture_trajectory else None,
507
+ callback_steps=1
508
+ )
509
+
510
+ # Process generated images
511
+ images = result.images
512
+
513
+ # Update FID metric if computing it
514
+ if compute_fid and fid_metric is not None:
515
+ image_tensors = []
516
+
517
+ for img in images:
518
+ img_resized = img.resize((512, 512))  # uniform size for FID (the Inception backbone resizes to 299x299 itself)
519
+ img_array = np.array(img_resized)
520
+ image_tensors.append(img_array)
521
+
522
+ # Convert to tensor and update FID
523
+ images_tensor = torch.from_numpy(np.stack(image_tensors)).permute(0, 3, 1, 2).float()
524
+ images_tensor = images_tensor.to(device)
525
+
526
+ # FID expects uint8 inputs (normalize=False) and needs at least two samples
+ # per distribution to estimate covariance, so duplicate when batch_size == 1.
+ real_images_tensor = real_images_tensor.to(dtype=torch.uint8)
+ images_tensor = images_tensor.to(dtype=torch.uint8)
+ if batch_size == 1:
+ real_images_tensor = torch.cat([real_images_tensor, real_images_tensor], dim=0)
+ images_tensor = torch.cat([images_tensor, images_tensor], dim=0)
+ fid_metric.update(real_images_tensor, real=True)
+ fid_metric.update(images_tensor, real=False)
531
+
532
+ # Track rewards - get the final timestep reward (t=0)
533
+ current_batch_final_reward = None
534
+ current_batch_final_timestep = None
535
+ if hasattr(pipeline, 'reward_history') and pipeline.reward_history:
536
+ # Use the reward from the last denoising step (t=0 or closest to 0)
538
+
539
+ # Get the last entry which corresponds to the final timestep of the last image in batch
540
+ final_entry = pipeline.reward_history[-1]
541
+ current_batch_final_reward = final_entry['reward_score']
542
+ current_batch_final_timestep = final_entry['timestep']
543
+ all_rewards.append(current_batch_final_reward)
544
+
545
+ # Capture LR history from first image if gradient ascent is enabled
546
+ if apply_gradient_ascent and idx == 0 and lr_history_first_image is None:
547
+ if hasattr(pipeline, 'grad_guidance') and pipeline.grad_guidance:
548
+ grad_stats = pipeline.grad_guidance.get_statistics()
549
+ if grad_stats and 'detailed_stats' in grad_stats:
550
+ # Extract LR history from the gradient ascent statistics
551
+ lr_history_first_image = {
552
+ 'prompt': batch_prompts[0],
553
+ 'timesteps': [],
554
+ 'learning_rates': [], # All LR values from all gradient steps
555
+ 'rewards': []
556
+ }
557
+ for stat in grad_stats['detailed_stats']:
558
+ lr_history_first_image['timesteps'].append(stat['timestep'])
559
+ if 'lr_history' in stat:
560
+ # Extend with all LR values from this timestep's gradient steps
561
+ lr_history_first_image['learning_rates'].extend(stat['lr_history'])
562
+ # Collect all rewards from reward_history for each gradient step
563
+ if 'reward_history' in stat:
564
+ lr_history_first_image['rewards'].extend(stat['reward_history'])
565
+
566
+ # Compute CLIP score
567
+ if clip_scorer is not None:
568
+ clip_device = next(clip_scorer.parameters()).device
569
+ # Convert PIL images to tensor format for CLIP score [C, H, W] in range [0, 1]
570
+ for img, prompt in zip(images, batch_prompts):
571
+ img_array = np.array(img).astype(np.float32)
572
+ img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).to(clip_device)
573
+ clip_score = clip_scorer(img_tensor, [prompt]).item()
574
+ all_clip_scores.append(clip_score)
575
+
576
+ # Compute aesthetic score
577
+ if aesthetic_scorer is not None:
578
+ aesthetic_scores = aesthetic_scorer(images)
+ if isinstance(aesthetic_scores, torch.Tensor):
+ aesthetic_scores = aesthetic_scores.detach().cpu().numpy()
+ # np.atleast_1d handles both scalar (0-d) and batched outputs uniformly.
+ all_aesthetic_scores.extend(np.atleast_1d(aesthetic_scores).tolist())
584
+
585
+ # Compute PickScore
586
+ if pick_scorer is not None:
587
+ for img, prompt in zip(images, batch_prompts):
588
+ pick_score = pick_scorer(prompt, [img])[0]
589
+ all_pick_scores.append(pick_score)
590
+
591
+ # Compute HPSv2 score
592
+ if hpsv2_scorer is not None:
593
+ for img, prompt in zip(images, batch_prompts):
594
+ hpsv2_score = hpsv2_scorer.score(img, prompt)[0]
595
+ all_hpsv2_scores.append(hpsv2_score)
596
+
597
+ # Compute HPSv2.1 score
598
+ if hpsv21_scorer is not None:
599
+ for img, prompt in zip(images, batch_prompts):
600
+ hpsv21_score = hpsv21_scorer.score(img, prompt)[0]
601
+ all_hpsv21_scores.append(hpsv21_score)
602
+
603
+ # Compute ImageReward score
604
+ if imagereward_scorer is not None:
605
+ for img, prompt in zip(images, batch_prompts):
606
+ imagereward_score = imagereward_scorer.score(prompt, img)
607
+ all_imagereward_scores.append(imagereward_score)
608
+
609
+ # Save generated images if requested
610
+ if save_images and output_dir:
611
+ for img_idx, img in enumerate(images):
612
+ global_idx = i + img_idx
613
+ img_path = mode_output_dir / f"sample_{global_idx:05d}.png"
614
+ img.save(img_path)
615
+
616
+ # Log intermediate FID and metrics every log_interval batches
617
+ if batch_num % log_interval == 0 or batch_num == num_batches:
618
+ num_samples_processed = min(i + batch_size, len(prompts))
619
+ log_msg = f"\n[{mode_name}] Batch {batch_num}/{num_batches} | Samples: {num_samples_processed}/{len(prompts)}"
620
+
621
+ # Add FID if computing
622
+ if compute_fid and fid_metric is not None:
623
+ try:
624
+ current_fid = fid_metric.compute().item()
625
+ log_msg += f" | FID: {current_fid:.4f}"
626
+ except Exception:
+ # compute() fails until FID has enough samples in both distributions.
+ log_msg += " | FID: Computing..."
628
+
629
+ # Add reward - show both final timestep reward and average
630
+ if all_rewards:
631
+ avg_reward = np.mean(all_rewards)
632
+ if current_batch_final_reward is not None:
633
+ log_msg += f" | Reward (t={current_batch_final_timestep}): {current_batch_final_reward:.4f}"
634
+ log_msg += f" | Reward (Avg): {avg_reward:.4f}"
635
+ else:
636
+ log_msg += f" | Reward (Avg): {avg_reward:.4f}"
637
+
638
+ # Add CLIP if computing
639
+ if clip_scorer is not None and all_clip_scores:
640
+ log_msg += f" | CLIP: {np.mean(all_clip_scores):.4f}"
641
+
642
+ # Add aesthetic if computing
643
+ if aesthetic_scorer is not None and all_aesthetic_scores:
644
+ log_msg += f" | Aesthetic: {np.mean(all_aesthetic_scores):.4f}"
645
+
646
+ # Add PickScore
647
+ if pick_scorer is not None and all_pick_scores:
648
+ log_msg += f" | PickScore: {np.mean(all_pick_scores):.4f}"
649
+
650
+ # Add HPSv2
651
+ if hpsv2_scorer is not None and all_hpsv2_scores:
652
+ log_msg += f" | HPSv2: {np.mean(all_hpsv2_scores):.4f}"
653
+
654
+ # Add HPSv2.1
655
+ if hpsv21_scorer is not None and all_hpsv21_scores:
656
+ log_msg += f" | HPSv2.1: {np.mean(all_hpsv21_scores):.4f}"
657
+
658
+ # Add ImageReward
659
+ if imagereward_scorer is not None and all_imagereward_scores:
660
+ log_msg += f" | ImageReward: {np.mean(all_imagereward_scores):.4f}"
661
+
662
+ print(log_msg)
663
+
664
+ # Re-enable progress bars
665
+ pipeline.set_progress_bar_config(disable=False)
666
+
667
+ avg_reward = np.mean(all_rewards) if all_rewards else 0.0
668
+ avg_clip_score = np.mean(all_clip_scores) if all_clip_scores else 0.0
669
+ avg_aesthetic_score = np.mean(all_aesthetic_scores) if all_aesthetic_scores else 0.0
670
+ avg_pick_score = np.mean(all_pick_scores) if all_pick_scores else 0.0
671
+ avg_hpsv2_score = np.mean(all_hpsv2_scores) if all_hpsv2_scores else 0.0
672
+ avg_hpsv21_score = np.mean(all_hpsv21_scores) if all_hpsv21_scores else 0.0
673
+ avg_imagereward_score = np.mean(all_imagereward_scores) if all_imagereward_scores else 0.0
674
+
675
+ return avg_reward, fid_metric, avg_clip_score, avg_aesthetic_score, avg_pick_score, avg_hpsv2_score, avg_hpsv21_score, avg_imagereward_score, lr_history_first_image, trajectory_first_image
676
+
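+ # Positional return order (callers in main() unpack all ten values):
+ #   (avg_reward, fid_metric, clip, aesthetic, pickscore, hpsv2, hpsv21,
+ #    imagereward, lr_history_first_image, trajectory_first_image)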
677
+
678
+ def auto_increment_path(base_path):
679
+ """
680
+ Create an auto-incrementing run folder inside base_path.
681
+ Returns: base_path/run_1, base_path/run_2, etc.
682
+ """
683
+ base_path = Path(base_path)
684
+ base_path.mkdir(parents=True, exist_ok=True) # Ensure base directory exists
685
+
686
+ i = 1
687
+ while True:
688
+ new_path = base_path / f"run_{i}"
689
+ if not new_path.exists():
690
+ return new_path
691
+ i += 1
692
+
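+ # e.g. auto_increment_path("eval_outputs") -> eval_outputs/run_1 on a fresh
+ # tree, eval_outputs/run_2 once run_1 exists, and so on.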
693
+
694
+ def main():
695
+ parser = argparse.ArgumentParser(description="Evaluate baseline and gradient ascent pipelines")
696
+ parser.add_argument("--data_dir", type=str, default="./data", help="Path to data directory")
697
+ parser.add_argument("--dataset_type", type=str, default="coco", choices=["coco", "pickapic"],
698
+ help="Dataset to use for evaluation: coco or pickapic (default: coco)")
699
+ parser.add_argument("--base_model", type=str, default=None, help="Override SANA base model repo id")
700
+ parser.add_argument(
701
+ "--model_variant",
702
+ type=str,
703
+ default="sana_600m_512",
704
+ choices=list(SANA_PROFILE_TO_MODEL_ID.keys()),
705
+ help="SANA model profile to use",
706
+ )
707
+ parser.add_argument("--lrm_model", type=str, default=None, help="SANA reward checkpoint path (directory or model.safetensors).")
708
+ parser.add_argument("--hf_cache_dir", type=str, default="/scratch/rr81/ma5430/.cache/huggingface/hub", help="Shared HF cache directory")
709
+ parser.add_argument("--offline", action="store_true", help="Force fully offline mode (recommended on GPU nodes)")
710
+ parser.add_argument("--num_steps", type=int, default=50, help="Number of inference steps")
711
+ parser.add_argument("--cfg_scale", type=float, default=4.5, help="Classifier-free guidance scale")
712
+ parser.add_argument("--dtype", type=str, default="bf16", choices=["auto", "bf16", "fp16", "fp32"],
713
+ help="Generation/reward dtype. bf16 is recommended for tiny gradient stability.")
714
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
715
+ parser.add_argument("--max_samples", type=int, default=None, help="Max samples to evaluate (None for all)")
716
+ parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generation (use 1 for reward model compatibility)")
717
+ parser.add_argument("--fid_batch_size", type=int, default=32, help="Batch size for FID computation (currently unused; FID updates follow --batch_size)")
718
+ parser.add_argument("--log_interval", type=int, default=10, help="Log FID and metrics every N batches")
719
+ parser.add_argument("--output_dir", type=str, default="eval_outputs", help="Directory to save generated images and results")
720
+ parser.add_argument("--save_images", action="store_true", help="Save all generated images to output directory")
721
+ parser.add_argument("--mode", type=str, default="both", choices=["baseline", "gradient_ascent", "both"],
722
+ help="Which evaluation to run: baseline, gradient_ascent, or both (default: both)")
723
+
724
+ # Metrics selection
725
+ parser.add_argument("--metrics", type=str, nargs="+", default=["clip", "aesthetic"],
726
+ choices=["fid", "clip", "aesthetic", "pickscore", "hpsv2", "hpsv21", "imagereward"],
727
+ help="Which metrics to evaluate (default: clip aesthetic)")
728
+ parser.add_argument("--scorer_device", type=str, default="auto", choices=["auto", "cpu", "cuda"],
729
+ help="Device for metric scorers. auto keeps scorers on GPU only when enough VRAM is free.")
730
+
731
+ # Gradient ascent config
732
+ parser.add_argument("--grad_config", type=str, default=None,
733
+ help=f"Gradient ascent config preset (available: {', '.join(list_configs())}). "
734
+ "If provided, overrides individual grad_* arguments.")
735
+ parser.add_argument("--grad_range_start", type=int, default=0, help="Gradient timestep range start")
736
+ parser.add_argument("--grad_range_end", type=int, default=700, help="Gradient timestep range end")
737
+ parser.add_argument("--grad_steps", type=int, default=5, help="Number of gradient steps per timestep (use 5 for better reward improvement)")
738
+ parser.add_argument("--grad_step_size", type=float, default=0.1, help="Gradient step size (initial LR)")
739
+
740
+ # Config overrides (these override values from grad_config if specified)
741
+ parser.add_argument("--override_momentum", type=float, default=None, help="Override momentum value from grad_config")
742
+ parser.add_argument("--override_num_grad_steps", type=int, default=None, help="Override num_grad_steps from grad_config")
743
+ parser.add_argument("--override_grad_step_size", type=float, default=None, help="Override grad_step_size from grad_config")
744
+
745
+ # Cuda
746
+ parser.add_argument("--cuda", type=int, default=0, help="Use CUDA device id")
747
+
748
+ args = parser.parse_args()
749
+
750
+ hf_cache_dir, offline_enabled = configure_hf_runtime(args.hf_cache_dir, force_offline=args.offline)
751
+ if args.base_model is None:
752
+ args.base_model = SANA_PROFILE_TO_MODEL_ID[args.model_variant]
753
+ if args.lrm_model is None:
754
+ args.lrm_model = resolve_default_lrm_model()
755
+ if not args.lrm_model:
756
+ raise ValueError(
757
+ "No SANA reward checkpoint found. Provide --lrm_model pointing to checkpoint dir or model.safetensors"
758
+ )
759
+
760
+ seed_everything(args.seed)
761
+
762
+ # Configuration
763
+ device = f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"
764
+ dtype = resolve_generation_dtype(args.dtype, device)
765
+ os.environ["ACCELERATE_MIXED_PRECISION"] = dtype_to_name(dtype)
766
+ configure_cudnn_safely(device)
767
+
768
+ # Create auto-incremented output directory
769
+ args.output_dir = auto_increment_path(args.output_dir)
770
+
771
+ # Setup logging to file
772
+ tee_logger, log_file = setup_logging(args.output_dir)
773
+
774
+ print("="*70)
775
+ print("EVALUATION: BASELINE vs GRADIENT ASCENT")
776
+ print("="*70)
777
+ print(f"\nLogging to: {log_file}")
778
+ print(f"\nDevice: {device}")
779
+ print(f"Dataset: {args.dataset_type.upper()}")
780
+ print(f"Data directory: {args.data_dir}")
781
+ print(f"Base model: {args.base_model}")
782
+ print(f"Model variant: {args.model_variant}")
783
+ print(f"LRM model: {args.lrm_model}")
784
+ print(f"HF cache dir: {hf_cache_dir or 'default'}")
785
+ print(f"HF offline mode: {offline_enabled}")
786
+ print(f"Inference steps: {args.num_steps}")
787
+ print(f"CFG scale: {args.cfg_scale}")
788
+ print(f"Batch size: {args.batch_size}")
789
+ print(f"Max samples: {args.max_samples or 'All'}")
790
+ print(f"Generation dtype: {dtype_to_name(dtype)}")
791
+ print(f"Output directory: {args.output_dir}")
792
+ print(f"Save images: {args.save_images}")
793
+ print(f"Evaluation mode: {args.mode}")
794
+ print(f"Metrics to evaluate: {', '.join(args.metrics).upper()}")
795
+ if args.grad_config:
796
+ print(f"Gradient ascent config: {args.grad_config}")
797
+
798
+ # Load validation data
799
+ print("\n" + "="*70)
800
+ print("1. LOADING VALIDATION DATA")
801
+ print("="*70)
802
+ prompts, image_paths = load_validation_data(
803
+ args.data_dir,
804
+ args.max_samples,
805
+ args.dataset_type,
806
+ hf_cache_dir=hf_cache_dir,
807
+ offline=offline_enabled,
808
+ )
809
+
810
+ # Automatically disable FID if no reference images available (e.g., Pick-a-Pic dataset)
811
+ can_compute_fid = image_paths is not None
812
+ if not can_compute_fid and "fid" in args.metrics:
813
+ print("\n⚠ Warning: FID metric requested but no reference images available. FID will be skipped.")
814
+ args.metrics = [m for m in args.metrics if m != "fid"]
815
+
816
+ # Load reward model
817
+ print("\n" + "="*70)
818
+ print("2. LOADING REWARD MODEL")
819
+ print("="*70)
820
+ reward_model = LRMRewardModel(
821
+ pretrained_model_name_or_path=args.base_model,
822
+ lrm_model_path=args.lrm_model,
823
+ model_profile=args.model_variant,
824
+ guidance_scale=args.cfg_scale,
825
+ device=device
826
+ )
827
+ if dtype == torch.float16:
828
+ reward_model = reward_model.half()
829
+ elif dtype == torch.bfloat16:
830
+ reward_model = reward_model.to(dtype=torch.bfloat16)
831
+ else:
832
+ reward_model = reward_model.to(dtype=torch.float32)
833
+ reward_model.eval()
834
+ print("✓ Reward model loaded")
835
+
836
+ # Load pipeline
837
+ print("\n" + "="*70)
838
+ print("3. LOADING PIPELINE")
839
+ print("="*70)
840
+
841
+ pretrained_kwargs = {"local_files_only": offline_enabled}
842
+ if hf_cache_dir:
843
+ pretrained_kwargs["cache_dir"] = hf_cache_dir
844
+
845
+ base_pipeline = SanaPipeline.from_pretrained(
846
+ args.base_model,
847
+ torch_dtype=dtype,
848
+ **pretrained_kwargs,
849
+ )
850
+ print(f"✓ Loaded SANA base model: {args.base_model}")
851
+
852
+ pipeline = SanaGradientAscentPipeline(**base_pipeline.components)
853
+ pipeline = pipeline.to(device)
854
+ pipeline.set_reward_model(reward_model)
855
+ print("✓ Pipeline loaded")
856
+
857
+ scorer_device = resolve_scorer_device(args.scorer_device, device)
858
+ scorer_dtype = dtype if str(scorer_device).startswith("cuda") else torch.float32
859
+ print(f"Scorer device: {scorer_device}")
860
+
861
+ if torch.cuda.is_available():
862
+ torch.cuda.empty_cache()
863
+
864
+ # Load CLIP scorer
865
+ print("\n" + "="*70)
866
+ print("3.5. LOADING CLIP AND AESTHETIC SCORERS")
867
+ print("="*70)
868
+
869
+ # Only load scorers for requested metrics
870
+ clip_scorer = None
871
+ aesthetic_scorer = None
872
+ pick_scorer = None
873
+ hpsv2_scorer = None
874
+ hpsv21_scorer = None
875
+ imagereward_scorer = None
876
+
877
+ if "clip" in args.metrics:
878
+ try:
879
+ clip_scorer = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to(scorer_device)
880
+ print("✓ CLIP scorer loaded")
881
+ except Exception as e:
882
+ print(f"Warning: Could not load CLIP scorer: {e}")
883
+ clip_scorer = None
884
+ else:
885
+ print("⊘ CLIP scorer skipped (not in selected metrics)")
886
+
887
+ if "aesthetic" in args.metrics:
888
+ try:
889
+ aesthetic_scorer = AestheticScorer(dtype=scorer_dtype, device=scorer_device)
890
+ print("✓ Aesthetic scorer loaded")
891
+ except Exception as e:
892
+ print(f"Warning: Could not load Aesthetic scorer: {e}")
893
+ aesthetic_scorer = None
894
+ else:
895
+ print("⊘ Aesthetic scorer skipped (not in selected metrics)")
896
+
897
+ if "pickscore" in args.metrics:
898
+ try:
899
+ from pick_score import PickScorer
900
+ pick_scorer = PickScorer(
901
+ processor_name_or_path="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
902
+ model_pretrained_name_or_path="yuvalkirstain/PickScore_v1",
903
+ device=scorer_device
904
+ )
905
+ print("✓ PickScore scorer loaded")
906
+ except Exception as e:
907
+ print(f"Warning: Could not load PickScore scorer: {e}")
908
+ pick_scorer = None
909
+ else:
910
+ print("⊘ PickScore scorer skipped (not in selected metrics)")
911
+
912
+ if "hpsv2" in args.metrics:
913
+ try:
914
+ from hpsv2_score import HPSv2Scorer
915
+ hf_dl_kwargs = {"local_files_only": offline_enabled}
916
+ if hf_cache_dir:
917
+ hf_dl_kwargs["cache_dir"] = hf_cache_dir
918
+ hpsv2_scorer = HPSv2Scorer(
919
+ clip_pretrained_name_or_path=hf_hub_download(
920
+ repo_id="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
921
+ filename="open_clip_pytorch_model.bin",
922
+ **hf_dl_kwargs,
923
+ ),
924
+ model_pretrained_name_or_path=hf_hub_download(
925
+ repo_id="xswu/HPSv2",
926
+ filename="HPS_v2_compressed.pt",
927
+ **hf_dl_kwargs,
928
+ ),
929
+ device=scorer_device
930
+ )
931
+ print("✓ HPSv2 scorer loaded")
932
+ except Exception as e:
933
+ print(f"Warning: Could not load HPSv2 scorer: {e}")
934
+ hpsv2_scorer = None
935
+ else:
936
+ print("⊘ HPSv2 scorer skipped (not in selected metrics)")
937
+
938
+ if "hpsv21" in args.metrics:
939
+ try:
940
+ from hpsv2_score import HPSv2Scorer
941
+ hf_dl_kwargs = {"local_files_only": offline_enabled}
942
+ if hf_cache_dir:
943
+ hf_dl_kwargs["cache_dir"] = hf_cache_dir
944
+ hpsv21_scorer = HPSv2Scorer(
945
+ clip_pretrained_name_or_path=hf_hub_download(
946
+ repo_id="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
947
+ filename="open_clip_pytorch_model.bin",
948
+ **hf_dl_kwargs,
949
+ ),
950
+ model_pretrained_name_or_path=hf_hub_download(
951
+ repo_id="xswu/HPSv2",
952
+ filename="HPS_v2.1_compressed.pt",
953
+ **hf_dl_kwargs,
954
+ ),
955
+ device=scorer_device
956
+ )
957
+ print("✓ HPSv2.1 scorer loaded")
958
+ except Exception as e:
959
+ print(f"Warning: Could not load HPSv2.1 scorer: {e}")
960
+ hpsv21_scorer = None
961
+ else:
962
+ print("⊘ HPSv2.1 scorer skipped (not in selected metrics)")
963
+
964
+ if "imagereward" in args.metrics:
965
+ try:
966
+ from imagereward_score import load_imagereward
967
+ hf_dl_kwargs = {"local_files_only": offline_enabled}
968
+ if hf_cache_dir:
969
+ hf_dl_kwargs["cache_dir"] = hf_cache_dir
970
+ imagereward_scorer = load_imagereward(
971
+ model_path=hf_hub_download(repo_id="THUDM/ImageReward", filename="ImageReward.pt", **hf_dl_kwargs),
972
+ med_config=hf_hub_download(repo_id="THUDM/ImageReward", filename="med_config.json", **hf_dl_kwargs),
973
+ device=scorer_device
974
+ )
975
+ print("✓ ImageReward scorer loaded")
976
+ except Exception as e:
977
+ print(f"Warning: Could not load ImageReward scorer: {e}")
978
+ imagereward_scorer = None
979
+ else:
980
+ print("⊘ ImageReward scorer skipped (not in selected metrics)")
981
+
982
+ # Configure gradient ascent
983
+ print("\n" + "="*70)
984
+ print("4. CONFIGURING GRADIENT ASCENT")
985
+ print("="*70)
986
+
987
+ # Use config preset if provided, otherwise use individual args
988
+ if args.grad_config:
989
+ print(f"Loading gradient ascent config: {args.grad_config}")
990
+ grad_config = get_config(args.grad_config)
991
+ print(f"Config loaded: {grad_config}")
992
+
993
+ # Apply overrides if specified
994
+ if args.override_momentum is not None:
995
+ grad_config['momentum'] = args.override_momentum
996
+ print(f" Overriding momentum: {args.override_momentum}")
997
+ if args.override_num_grad_steps is not None:
998
+ grad_config['num_grad_steps'] = args.override_num_grad_steps
999
+ print(f" Overriding num_grad_steps: {args.override_num_grad_steps}")
1000
+ if args.override_grad_step_size is not None:
1001
+ grad_config['grad_step_size'] = args.override_grad_step_size
1002
+ print(f" Overriding grad_step_size: {args.override_grad_step_size}")
1003
+ else:
1004
+ grad_config = {
1005
+ "grad_timestep_range": (args.grad_range_start, args.grad_range_end),
1006
+ "num_grad_steps": args.grad_steps,
1007
+ "grad_step_size": args.grad_step_size,
1008
+ }
1009
+ print("Using manual gradient ascent configuration")
1010
+
1011
+ print(f"Gradient timestep range: {grad_config.get('grad_timestep_range', (args.grad_range_start, args.grad_range_end))}")
1012
+ print(f"Gradient steps: {grad_config.get('num_grad_steps', args.grad_steps)}")
1013
+ print(f"Gradient step size (initial LR): {grad_config.get('grad_step_size', args.grad_step_size)}")
1014
+ if grad_config.get('lr_scheduler_type'):
1015
+ print(f"LR Scheduler: {grad_config['lr_scheduler_type']}")
1016
+ if grad_config.get('use_momentum'):
1017
+ print(f"Momentum: {grad_config.get('momentum', 0.9)} (Nesterov: {grad_config.get('use_nesterov', False)})")
1018
+
1019
+ pipeline.enable_gradient_ascent(**grad_config)
1020
+
1021
+ # Initialize result variables
1022
+ fid_score_baseline = None
1023
+ avg_reward_baseline = None
1024
+ clip_score_baseline = None
1025
+ aesthetic_score_baseline = None
1026
+ pick_score_baseline = None
1027
+ hpsv2_score_baseline = None
1028
+ hpsv21_score_baseline = None
1029
+ imagereward_score_baseline = None
1030
+ fid_score_grad = None
1031
+ avg_reward_grad = None
1032
+ clip_score_grad = None
1033
+ aesthetic_score_grad = None
1034
+ pick_score_grad = None
1035
+ hpsv2_score_grad = None
1036
+ hpsv21_score_grad = None
1037
+ imagereward_score_grad = None
1038
+ grad_stats = None
1039
+
1040
+ # ========== BASELINE EVALUATION ==========
1041
+ if args.mode in ["baseline", "both"]:
1042
+ print("\n" + "="*70)
1043
+ print("5. EVALUATING BASELINE")
1044
+ print("="*70)
1045
+
1046
+ # Generate and evaluate baseline
1047
+ avg_reward_baseline, fid_baseline, clip_score_baseline, aesthetic_score_baseline, pick_score_baseline, hpsv2_score_baseline, hpsv21_score_baseline, imagereward_score_baseline, _, baseline_trajectory = generate_and_evaluate(
1048
+ pipeline=pipeline,
1049
+ prompts=prompts,
1050
+ image_paths=image_paths,
1051
+ device=device,
1052
+ dtype=dtype,
1053
+ num_inference_steps=args.num_steps,
1054
+ guidance_scale=args.cfg_scale,
1055
+ seed=args.seed,
1056
+ batch_size=args.batch_size,
1057
+ apply_gradient_ascent=False,
1058
+ mode_name="baseline",
1059
+ log_interval=args.log_interval,
1060
+ output_dir=args.output_dir,
1061
+ save_images=args.save_images,
1062
+ clip_scorer=clip_scorer,
1063
+ aesthetic_scorer=aesthetic_scorer,
1064
+ pick_scorer=pick_scorer,
1065
+ hpsv2_scorer=hpsv2_scorer,
1066
+ hpsv21_scorer=hpsv21_scorer,
1067
+ imagereward_scorer=imagereward_scorer,
1068
+ compute_fid=("fid" in args.metrics and can_compute_fid),
1069
+ capture_trajectory=True
1070
+ )
1071
+
1072
+ # Compute FID for baseline if requested
1073
+ if "fid" in args.metrics and fid_baseline is not None:
1074
+ fid_score_baseline = fid_baseline.compute().item()
1075
+ print(f"\n✓ Baseline FID: {fid_score_baseline:.4f}")
1076
+ print(f"✓ Baseline Avg Reward: {avg_reward_baseline:.4f}")
1077
+ if "clip" in args.metrics:
1078
+ print(f"✓ Baseline Avg CLIP Score: {clip_score_baseline:.4f}")
1079
+ if "aesthetic" in args.metrics:
1080
+ print(f"✓ Baseline Avg Aesthetic Score: {aesthetic_score_baseline:.4f}")
1081
+ if "pickscore" in args.metrics and pick_score_baseline is not None:
1082
+ print(f"✓ Baseline Avg PickScore: {pick_score_baseline:.4f}")
1083
+ if "hpsv2" in args.metrics and hpsv2_score_baseline is not None:
1084
+ print(f"✓ Baseline Avg HPSv2 Score: {hpsv2_score_baseline:.4f}")
1085
+ if "hpsv21" in args.metrics and hpsv21_score_baseline is not None:
1086
+ print(f"✓ Baseline Avg HPSv2.1 Score: {hpsv21_score_baseline:.4f}")
1087
+ if "imagereward" in args.metrics and imagereward_score_baseline is not None:
1088
+ print(f"✓ Baseline Avg ImageReward: {imagereward_score_baseline:.4f}")
1089
+
1090
+ # ========== GRADIENT ASCENT EVALUATION ==========
1091
+ if args.mode in ["gradient_ascent", "both"]:
1092
+ print("\n" + "="*70)
1093
+ print("6. EVALUATING GRADIENT ASCENT")
1094
+ print("="*70)
1095
+
1096
+ # Generate and evaluate with gradient ascent
1097
+ avg_reward_grad, fid_grad, clip_score_grad, aesthetic_score_grad, pick_score_grad, hpsv2_score_grad, hpsv21_score_grad, imagereward_score_grad, lr_history, guided_trajectory = generate_and_evaluate(
1098
+ pipeline=pipeline,
1099
+ prompts=prompts,
1100
+ image_paths=image_paths,
1101
+ device=device,
1102
+ dtype=dtype,
1103
+ num_inference_steps=args.num_steps,
1104
+ guidance_scale=args.cfg_scale,
1105
+ seed=args.seed,
1106
+ batch_size=args.batch_size,
1107
+ apply_gradient_ascent=True,
1108
+ mode_name="gradient_ascent",
1109
+ log_interval=args.log_interval,
1110
+ output_dir=args.output_dir,
1111
+ save_images=args.save_images,
1112
+ clip_scorer=clip_scorer,
1113
+ aesthetic_scorer=aesthetic_scorer,
1114
+ pick_scorer=pick_scorer,
1115
+ hpsv2_scorer=hpsv2_scorer,
1116
+ hpsv21_scorer=hpsv21_scorer,
1117
+ imagereward_scorer=imagereward_scorer,
1118
+ compute_fid=("fid" in args.metrics and can_compute_fid),
1119
+ capture_trajectory=True
1120
+ )
1121
+
1122
+ # Compute FID for gradient ascent if requested
1123
+ if "fid" in args.metrics and fid_grad is not None:
1124
+ fid_score_grad = fid_grad.compute().item()
1125
+ print(f"\n✓ Gradient Ascent FID: {fid_score_grad:.4f}")
1126
+ print(f"✓ Gradient Ascent Avg Reward: {avg_reward_grad:.4f}")
1127
+ if "clip" in args.metrics:
1128
+ print(f"✓ Gradient Ascent Avg CLIP Score: {clip_score_grad:.4f}")
1129
+ if "aesthetic" in args.metrics:
1130
+ print(f"✓ Gradient Ascent Avg Aesthetic Score: {aesthetic_score_grad:.4f}")
1131
+ if "pickscore" in args.metrics and pick_score_grad is not None:
1132
+ print(f"✓ Gradient Ascent Avg PickScore: {pick_score_grad:.4f}")
1133
+ if "hpsv2" in args.metrics and hpsv2_score_grad is not None:
1134
+ print(f"✓ Gradient Ascent Avg HPSv2 Score: {hpsv2_score_grad:.4f}")
1135
+ if "hpsv21" in args.metrics and hpsv21_score_grad is not None:
1136
+ print(f"✓ Gradient Ascent Avg HPSv2.1 Score: {hpsv21_score_grad:.4f}")
1137
+ if "imagereward" in args.metrics and imagereward_score_grad is not None:
1138
+ print(f"✓ Gradient Ascent Avg ImageReward: {imagereward_score_grad:.4f}")
1139
+
1140
+ # Get gradient stats
1141
+ grad_stats = pipeline.grad_guidance.get_statistics()
1142
+ if grad_stats:
1143
+ print(f"\nGradient Ascent Statistics:")
1144
+ print(f" Applications: {grad_stats['num_applications']}")
1145
+ print(f" Total reward improvement: {grad_stats['total_reward_improvement']:+.4f}")
1146
+ print(f" Avg reward improvement: {grad_stats['avg_reward_improvement']:+.4f}")
1147
+
1148
+ # Plot LR curve if we captured it
1149
+ if lr_history is not None and lr_history['learning_rates']:
1150
+ plot_path = Path(args.output_dir) / "lr_curve.png"
1151
+
1152
+ # LR values are now continuous across all gradient steps
1153
+ lrs = lr_history['learning_rates']
1154
+ steps = list(range(len(lrs))) # Step indices (0 to total_steps-1)
1155
+
1156
+ plt.figure(figsize=(12, 6))
1157
+ plt.plot(steps, lrs, linewidth=2, color='blue', alpha=0.8)
1158
+
1159
+ # Mark the first step with a star
1160
+ plt.plot(steps[0], lrs[0], marker='*', markersize=20, color='gold',
1161
+ markeredgecolor='darkgoldenrod', markeredgewidth=2, zorder=5)
1162
+
1163
+ # Mark timestep boundaries
1164
+ num_timesteps = len(lr_history['timesteps'])
1165
+ num_grad_steps_per_timestep = len(lrs) // num_timesteps if num_timesteps > 0 else 0
1166
+ if num_grad_steps_per_timestep > 0:
1167
+ for i in range(num_timesteps + 1):
1168
+ step_idx = i * num_grad_steps_per_timestep
1169
+ if step_idx <= len(lrs):
1170
+ plt.axvline(x=step_idx, color='red', linestyle='--', alpha=0.3, linewidth=1)
1171
+ if i < num_timesteps:
1172
+ plt.text(step_idx, plt.ylim()[1] * 0.95, f't={lr_history["timesteps"][i]}',
1173
+ fontsize=8, color='red', alpha=0.7, ha='left')
1174
+
1175
+ plt.xlabel('Global Gradient Step', fontsize=12)
1176
+ plt.ylabel('Learning Rate', fontsize=12)
1177
+ plt.title(f'Learning Rate Evolution Across All Gradient Steps\nPrompt: "{lr_history["prompt"][:60]}..."',
1178
+ fontsize=12, fontweight='bold')
1179
+ plt.grid(True, alpha=0.3)
1180
+
1181
+ # Add info text (num_timesteps / num_grad_steps_per_timestep computed above)
1184
+ plt.text(0.02, 0.98,
1185
+ f'Total timesteps: {num_timesteps}\nGrad steps/timestep: {num_grad_steps_per_timestep}\nTotal grad steps: {len(lrs)}',
1186
+ transform=plt.gca().transAxes, fontsize=10, verticalalignment='top',
1187
+ bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
1188
+
1189
+ plt.tight_layout()
1190
+ plt.savefig(plot_path, dpi=150, bbox_inches='tight')
1191
+ plt.close()
1192
+ print(f"\n✓ Saved LR curve plot to: {plot_path}")
1193
+ print(f" Total gradient steps: {len(lrs)}")
1194
+ print(f" LR range: {min(lrs):.6f} → {max(lrs):.6f}")
1195
+
1196
+ # Plot Rewards curve if we captured it
1197
+ if lr_history is not None and lr_history['rewards']:
1198
+ plot_path = Path(args.output_dir) / "rewards_curve.png"
1199
+
1200
+ # Reward values are now continuous across all gradient steps
1201
+ rewards = lr_history['rewards']
1202
+ steps = list(range(len(rewards))) # Step indices (0 to total_steps-1)
1203
+
1204
+ plt.figure(figsize=(12, 6))
1205
+ plt.plot(steps, rewards, linewidth=2, color='green', alpha=0.8)
1206
+
1207
+ # Mark the first step with a star
1208
+ plt.plot(steps[0], rewards[0], marker='*', markersize=20, color='gold',
1209
+ markeredgecolor='darkgoldenrod', markeredgewidth=2, zorder=5)
1210
+
1211
+ # Mark timestep boundaries
1212
+ num_timesteps = len(lr_history['timesteps'])
1213
+ # rewards has one extra value at the start (initial) compared to gradient steps
1214
+ num_grad_steps_per_timestep = (len(rewards) - num_timesteps) // num_timesteps if num_timesteps > 0 else 0
1215
+ if num_grad_steps_per_timestep > 0:
1216
+ for i in range(num_timesteps + 1):
1217
+ step_idx = i * (num_grad_steps_per_timestep + 1) # +1 because reward_history includes initial
1218
+ if step_idx <= len(rewards):
1219
+ plt.axvline(x=step_idx, color='red', linestyle='--', alpha=0.3, linewidth=1)
1220
+ if i < num_timesteps:
1221
+ plt.text(step_idx, plt.ylim()[1] * 0.95, f't={lr_history["timesteps"][i]}',
1222
+ fontsize=8, color='red', alpha=0.7, ha='left')
1223
+
1224
+ plt.xlabel('Global Gradient Step', fontsize=12)
1225
+ plt.ylabel('Reward Score', fontsize=12)
1226
+ plt.title(f'Reward Evolution Across All Gradient Steps\nPrompt: "{lr_history["prompt"][:60]}..."',
1227
+ fontsize=12, fontweight='bold')
1228
+ plt.grid(True, alpha=0.3)
1229
+
1230
+ # Add info text (num_timesteps computed above)
1232
+ reward_improvement = rewards[-1] - rewards[0] if len(rewards) > 1 else 0
1233
+ plt.text(0.02, 0.98,
1234
+ f'Total timesteps: {num_timesteps}\nReward samples: {len(rewards)}\n'
1235
+ f'Initial reward: {rewards[0]:.4f}\nFinal reward: {rewards[-1]:.4f}\n'
1236
+ f'Improvement: {reward_improvement:+.4f}',
1237
+ transform=plt.gca().transAxes, fontsize=10, verticalalignment='top',
1238
+ bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.5))
1239
+
1240
+ plt.tight_layout()
1241
+ plt.savefig(plot_path, dpi=150, bbox_inches='tight')
1242
+ plt.close()
1243
+ print(f"\n✓ Saved Rewards curve plot to: {plot_path}")
1244
+ print(f" Total reward samples (incl. one initial value per timestep): {len(rewards)}")
1245
+ print(f" Reward range: {min(rewards):.4f} → {max(rewards):.4f}")
1246
+ print(f" Total improvement: {reward_improvement:+.4f}")
1247
+
1248
+ # ---> NEW: PLOT TRAJECTORY DIVERGENCE (MANIFOLD DRIFT) <---
1249
+ if args.mode == "both" and 'baseline_trajectory' in locals() and 'guided_trajectory' in locals():
1250
+ if len(baseline_trajectory) == len(guided_trajectory) and len(baseline_trajectory) > 0:
1251
+ print("\n" + "="*70)
1252
+ print("7. CALCULATING TRAJECTORY DIVERGENCE (THEOREM 1 & 2)")
1253
+ print("="*70)
1254
+
1255
+ drift_path = Path(args.output_dir) / "trajectory_drift.png"
1256
+
1257
+ l2_distances = []
1258
+ # Calculate L2 norm ||z_t_guided - z_t_base||_2 for each step
1259
+ for b_lat, g_lat in zip(baseline_trajectory, guided_trajectory):
1260
+ dist = torch.norm(g_lat.float() - b_lat.float(), p=2).item()
1261
+ l2_distances.append(dist)
1262
+
1263
+ steps = list(range(len(l2_distances)))
1264
+
1265
+ plt.figure(figsize=(10, 6))
1266
+ plt.plot(steps, l2_distances, linewidth=2.5, color='purple', marker='o', markersize=4)
1267
+
1268
+ plt.xlabel('Denoising Step', fontsize=12)
1269
+ plt.ylabel('L2 Distance: ||z_guided - z_base||_2', fontsize=12)
1270
+ plt.title('Latent Trajectory Divergence (Manifold Drift)', fontsize=14, fontweight='bold')
1271
+ plt.grid(True, alpha=0.3)
1272
+
1273
+ # Add interpretation text based on your theory
1274
+ max_drift = max(l2_distances)
1275
+ plt.text(0.02, 0.98,
1276
+ f'Max Drift: {max_drift:.4f}\n'
1277
+ f'Final Drift: {l2_distances[-1]:.4f}\n'
1278
+ f'(Matches bounded drift from Thm 1\n'
1279
+ f'or ODE stiffness collapse from Thm 2)',
1280
+ transform=plt.gca().transAxes, fontsize=10, verticalalignment='top',
1281
+ bbox=dict(boxstyle='round', facecolor='thistle', alpha=0.5))
1282
+
1283
+ plt.tight_layout()
1284
+ plt.savefig(drift_path, dpi=150, bbox_inches='tight')
1285
+ plt.close()
1286
+ print(f"✓ Saved Manifold Drift curve to: {drift_path}")
1287
+ print(f" Max L2 Distance from baseline: {max_drift:.4f}")
1288
+
1289
+ # ========== FINAL RESULTS ==========
1290
+ print("\n" + "="*70)
1291
+ print("FINAL RESULTS")
1292
+ print("="*70)
1293
+
1294
+ if avg_reward_baseline is not None:
1295
+ print(f"\nBaseline:")
1296
+ if fid_score_baseline is not None:
1297
+ print(f" FID Score: {fid_score_baseline:.4f}")
1298
+ print(f" Avg Reward: {avg_reward_baseline:.4f}")
1299
+ if "clip" in args.metrics and clip_score_baseline is not None:
1300
+ print(f" Avg CLIP Score: {clip_score_baseline:.4f}")
1301
+ if "aesthetic" in args.metrics and aesthetic_score_baseline is not None:
1302
+ print(f" Avg Aesthetic: {aesthetic_score_baseline:.4f}")
1303
+ if "pickscore" in args.metrics and pick_score_baseline is not None:
1304
+ print(f" Avg PickScore: {pick_score_baseline:.4f}")
1305
+ if "hpsv2" in args.metrics and hpsv2_score_baseline is not None:
1306
+ print(f" Avg HPSv2: {hpsv2_score_baseline:.4f}")
1307
+ if "hpsv21" in args.metrics and hpsv21_score_baseline is not None:
1308
+ print(f" Avg HPSv2.1: {hpsv21_score_baseline:.4f}")
1309
+ if "imagereward" in args.metrics and imagereward_score_baseline is not None:
1310
+ print(f" Avg ImageReward: {imagereward_score_baseline:.4f}")
1311
+
1312
+ if avg_reward_grad is not None:
1313
+ print(f"\nGradient Ascent:")
1314
+ if fid_score_grad is not None:
1315
+ print(f" FID Score: {fid_score_grad:.4f}")
1316
+ print(f" Avg Reward: {avg_reward_grad:.4f}")
1317
+ if "clip" in args.metrics and clip_score_grad is not None:
1318
+ print(f" Avg CLIP Score: {clip_score_grad:.4f}")
1319
+ if "aesthetic" in args.metrics and aesthetic_score_grad is not None:
1320
+ print(f" Avg Aesthetic: {aesthetic_score_grad:.4f}")
1321
+ if "pickscore" in args.metrics and pick_score_grad is not None:
1322
+ print(f" Avg PickScore: {pick_score_grad:.4f}")
1323
+ if "hpsv2" in args.metrics and hpsv2_score_grad is not None:
1324
+ print(f" Avg HPSv2: {hpsv2_score_grad:.4f}")
1325
+ if "hpsv21" in args.metrics and hpsv21_score_grad is not None:
1326
+ print(f" Avg HPSv2.1: {hpsv21_score_grad:.4f}")
1327
+ if "imagereward" in args.metrics and imagereward_score_grad is not None:
1328
+ print(f" Avg ImageReward: {imagereward_score_grad:.4f}")
1329
+
1330
+ if avg_reward_baseline is not None and avg_reward_grad is not None:
1331
+ print(f"\nComparison:")
1332
+ if fid_score_baseline is not None and fid_score_grad is not None:
1333
+ fid_diff = fid_score_grad - fid_score_baseline
1334
+ print(f" FID Change: {fid_diff:+.4f} ({'worse' if fid_diff > 0 else 'better'}, lower is better)")
1335
+ reward_diff = avg_reward_grad - avg_reward_baseline
1336
+ print(f" Reward Change: {reward_diff:+.4f} ({'better' if reward_diff > 0 else 'worse'}, higher is better)")
1337
+ if "clip" in args.metrics and clip_score_baseline is not None and clip_score_grad is not None:
1338
+ clip_diff = clip_score_grad - clip_score_baseline
1339
+ print(f" CLIP Change: {clip_diff:+.4f} ({'better' if clip_diff > 0 else 'worse'}, higher is better)")
1340
+ if "aesthetic" in args.metrics and aesthetic_score_baseline is not None and aesthetic_score_grad is not None:
1341
+ aesthetic_diff = aesthetic_score_grad - aesthetic_score_baseline
1342
+ print(f" Aesthetic Change: {aesthetic_diff:+.4f} ({'better' if aesthetic_diff > 0 else 'worse'}, higher is better)")
1343
+ if "pickscore" in args.metrics and pick_score_baseline is not None and pick_score_grad is not None:
1344
+ pick_diff = pick_score_grad - pick_score_baseline
1345
+ print(f" PickScore Change: {pick_diff:+.4f} ({'better' if pick_diff > 0 else 'worse'}, higher is better)")
1346
+ if "hpsv2" in args.metrics and hpsv2_score_baseline is not None and hpsv2_score_grad is not None:
1347
+ hpsv2_diff = hpsv2_score_grad - hpsv2_score_baseline
1348
+ print(f" HPSv2 Change: {hpsv2_diff:+.4f} ({'better' if hpsv2_diff > 0 else 'worse'}, higher is better)")
1349
+ if "hpsv21" in args.metrics and hpsv21_score_baseline is not None and hpsv21_score_grad is not None:
1350
+ hpsv21_diff = hpsv21_score_grad - hpsv21_score_baseline
1351
+ print(f" HPSv2.1 Change: {hpsv21_diff:+.4f} ({'better' if hpsv21_diff > 0 else 'worse'}, higher is better)")
1352
+ if "imagereward" in args.metrics and imagereward_score_baseline is not None and imagereward_score_grad is not None:
1353
+ imagereward_diff = imagereward_score_grad - imagereward_score_baseline
1354
+ print(f" ImageReward Chg: {imagereward_diff:+.4f} ({'better' if imagereward_diff > 0 else 'worse'}, higher is better)")
1355
+
1356
+ # Save results to file
1357
+ results = {
1358
+ "mode": args.mode,
1359
+ "metrics": args.metrics,
1360
+ "config": {
1361
+ "num_samples": len(prompts),
1362
+ "num_steps": args.num_steps,
1363
+ "cfg_scale": args.cfg_scale,
1364
+ "grad_range": [args.grad_range_start, args.grad_range_end],
1365
+ "grad_steps": args.grad_steps,
1366
+ "grad_step_size": args.grad_step_size
1367
+ }
1368
+ }
1369
+
1370
+ if avg_reward_baseline is not None:
1371
+ results["baseline"] = {"avg_reward": avg_reward_baseline}
1372
+ if fid_score_baseline is not None:
1373
+ results["baseline"]["fid"] = fid_score_baseline
1374
+ if "clip" in args.metrics and clip_score_baseline is not None:
1375
+ results["baseline"]["clip_score"] = clip_score_baseline
1376
+ if "aesthetic" in args.metrics and aesthetic_score_baseline is not None:
1377
+ results["baseline"]["aesthetic_score"] = aesthetic_score_baseline
1378
+ if "pickscore" in args.metrics and pick_score_baseline is not None:
1379
+ results["baseline"]["pickscore"] = pick_score_baseline
1380
+ if "hpsv2" in args.metrics and hpsv2_score_baseline is not None:
1381
+ results["baseline"]["hpsv2_score"] = hpsv2_score_baseline
1382
+ if "hpsv21" in args.metrics and hpsv21_score_baseline is not None:
1383
+ results["baseline"]["hpsv21_score"] = hpsv21_score_baseline
1384
+ if "imagereward" in args.metrics and imagereward_score_baseline is not None:
1385
+ results["baseline"]["imagereward_score"] = imagereward_score_baseline
1386
+
1387
+ if avg_reward_grad is not None:
1388
+ results["gradient_ascent"] = {"avg_reward": avg_reward_grad}
1389
+ if fid_score_grad is not None:
1390
+ results["gradient_ascent"]["fid"] = fid_score_grad
1391
+ if "clip" in args.metrics and clip_score_grad is not None:
1392
+ results["gradient_ascent"]["clip_score"] = clip_score_grad
1393
+ if "aesthetic" in args.metrics and aesthetic_score_grad is not None:
1394
+ results["gradient_ascent"]["aesthetic_score"] = aesthetic_score_grad
1395
+ if "pickscore" in args.metrics and pick_score_grad is not None:
1396
+ results["gradient_ascent"]["pickscore"] = pick_score_grad
1397
+ if "hpsv2" in args.metrics and hpsv2_score_grad is not None:
1398
+ results["gradient_ascent"]["hpsv2_score"] = hpsv2_score_grad
1399
+ if "hpsv21" in args.metrics and hpsv21_score_grad is not None:
1400
+ results["gradient_ascent"]["hpsv21_score"] = hpsv21_score_grad
1401
+ if "imagereward" in args.metrics and imagereward_score_grad is not None:
1402
+ results["gradient_ascent"]["imagereward_score"] = imagereward_score_grad
1403
+ if grad_stats:
1404
+ results["gradient_ascent"]["stats"] = grad_stats
1405
+
1406
+ if avg_reward_baseline is not None and avg_reward_grad is not None:
1407
+ results["comparison"] = {
1408
+ "reward_difference": avg_reward_grad - avg_reward_baseline
1409
+ }
1410
+ if fid_score_baseline is not None and fid_score_grad is not None:
1411
+ results["comparison"]["fid_difference"] = fid_score_grad - fid_score_baseline
1412
+ if "clip" in args.metrics and clip_score_baseline is not None and clip_score_grad is not None:
1413
+ results["comparison"]["clip_difference"] = clip_score_grad - clip_score_baseline
1414
+ if "aesthetic" in args.metrics and aesthetic_score_baseline is not None and aesthetic_score_grad is not None:
1415
+ results["comparison"]["aesthetic_difference"] = aesthetic_score_grad - aesthetic_score_baseline
1416
+ if "pickscore" in args.metrics and pick_score_baseline is not None and pick_score_grad is not None:
1417
+ results["comparison"]["pickscore_difference"] = pick_score_grad - pick_score_baseline
1418
+ if "hpsv2" in args.metrics and hpsv2_score_baseline is not None and hpsv2_score_grad is not None:
1419
+ results["comparison"]["hpsv2_difference"] = hpsv2_score_grad - hpsv2_score_baseline
1420
+ if "hpsv21" in args.metrics and hpsv21_score_baseline is not None and hpsv21_score_grad is not None:
1421
+ results["comparison"]["hpsv21_difference"] = hpsv21_score_grad - hpsv21_score_baseline
1422
+ if "imagereward" in args.metrics and imagereward_score_baseline is not None and imagereward_score_grad is not None:
1423
+ results["comparison"]["imagereward_difference"] = imagereward_score_grad - imagereward_score_baseline
1424
+
1425
+ # Save results to output directory
1426
+ output_path = Path(args.output_dir)
1427
+ output_path.mkdir(parents=True, exist_ok=True)
1428
+ results_path = output_path / "evaluation_results.txt"
1429
+
1430
+ with open(results_path, "w") as f:
1431
+ for k, v in results.items():
1432
+ f.write(f"{k}: {v}\n")
1433
+
1434
+
1435
+ print(f"\n✓ Results saved to: {results_path}")
1436
+ if args.save_images:
1437
+ print(f"✓ Generated images saved under {output_path}/<mode_name>/ for each evaluated mode")
1438
+ print("\n" + "="*70)
1439
+
1440
+ # Close logger
1441
+ tee_logger.close()
1442
+ sys.stdout = tee_logger.terminal
1443
+
1444
+
1445
+ if __name__ == "__main__":
1446
+ main()
1447
+
Reward_sana_idealized/examples.sh ADDED
@@ -0,0 +1,162 @@
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+ # bash examples.sh
4
+ if [[ -n "${TERM:-}" ]]; then
5
+ clear
6
+ fi
7
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
8
+ cd "$SCRIPT_DIR"
9
+
10
+ # Shared HF cache used on this cluster.
11
+ HF_HUB_CACHE_DIR="${HF_HUB_CACHE_DIR:-/scratch/rr81/ma5430/.cache/huggingface/hub}"
12
+ export HF_HUB_CACHE="$HF_HUB_CACHE_DIR"
13
+ export HUGGINGFACE_HUB_CACHE="$HF_HUB_CACHE_DIR"
14
+ export HF_HOME="$(dirname "$HF_HUB_CACHE_DIR")"
15
+
16
+ # GPU nodes have no internet, while login nodes do.
17
+ # Auto default: offline on GPU nodes, online on login nodes.
18
+ DEFAULT_OFFLINE_MODE="1"
19
+ if ! (command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi -L >/dev/null 2>&1); then
20
+ DEFAULT_OFFLINE_MODE="0"
21
+ fi
22
+ OFFLINE_MODE="${OFFLINE_MODE:-$DEFAULT_OFFLINE_MODE}"
23
+
24
+ if [[ "$OFFLINE_MODE" == "1" ]]; then
25
+ export HF_DATASETS_OFFLINE="1"
26
+ export HF_METRICS_OFFLINE="1"
27
+ export HF_MODULES_OFFLINE="1"
28
+ export TRANSFORMERS_OFFLINE="1"
29
+ export DIFFUSERS_OFFLINE="1"
30
+ export HF_HUB_OFFLINE="1"
31
+ else
32
+ export HF_DATASETS_OFFLINE="0"
33
+ export HF_METRICS_OFFLINE="0"
34
+ export HF_MODULES_OFFLINE="0"
35
+ export TRANSFORMERS_OFFLINE="0"
36
+ export DIFFUSERS_OFFLINE="0"
37
+ export HF_HUB_OFFLINE="0"
38
+ fi
39
+
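+ # Override examples: OFFLINE_MODE=0 bash examples.sh  # force online (login node)
+ #                    OFFLINE_MODE=1 bash examples.sh  # force fully offline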
40
+ # Existing environment requested by user.
41
+ PYTHON_BIN="${PYTHON_BIN:-/g/data/rr81/aev/bin/python}"
42
+ if [[ ! -x "$PYTHON_BIN" ]]; then
43
+ echo "[examples.sh] Missing Python executable: $PYTHON_BIN" >&2
44
+ exit 1
45
+ fi
46
+
47
+ DATASET_NAME="${DATASET_NAME:-pickapic}" # coco | pickapic
48
+ GRAD_CONFIG="${GRAD_CONFIG:-one_step_rectification_config}"
49
+ MODEL_PROFILE="${MODEL_PROFILE:-sana_600m_512}" # sana_600m_512 | sana_1600m_512 | sana_sprint_0_6b_1024 | sana_sprint_1_6b_1024
50
+ MODE="${MODE:-gradient_ascent}" # gradient_ascent | baseline | both
51
+ # Empty MAX_SAMPLES means evaluate all available samples.
52
+ MAX_SAMPLES="${MAX_SAMPLES:-}"
53
+ NUM_STEPS="${NUM_STEPS:-20}"
54
+ CFG_SCALE="${CFG_SCALE:-4.5}"
55
+ DTYPE="${DTYPE:-bf16}" # auto | bf16 | fp16 | fp32
56
+ METRICS="${METRICS:-clip aesthetic pickscore hpsv2 hpsv21 imagereward}"
57
+ PREFETCH_ONLY="${PREFETCH_ONLY:-0}"
58
+
59
+ # Override this path whenever you want to swap reward weights.
60
+ # LRM_MODEL_PATH="${LRM_MODEL_PATH:-/g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep33000}"
61
+ LRM_MODEL_PATH="${LRM_MODEL_PATH:-/g/data/rr81/LPO/lrm/lrm_sana/logs/v7/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep32000}"
62
+
63
+ if [[ -z "${GPU_ID:-}" ]]; then
64
+ if command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi -L >/dev/null 2>&1; then
65
+ GPU_ID="$(nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | sort -k2 -n | head -n1 | cut -d',' -f1 | tr -d ' ')"
66
+ GPU_ID="${GPU_ID:-0}"
67
+ else
68
+ GPU_ID="0"
69
+ echo "[examples.sh] No visible NVIDIA GPU on this node. Defaulting GPU_ID=0."
70
+ echo "[examples.sh] eval.py will run on CPU if CUDA is unavailable."
71
+ fi
72
+ fi
73
+
74
+ echo "Using GPU ID: $GPU_ID"
75
+ echo "Using LRM weights: $LRM_MODEL_PATH"
76
+ echo "HF offline mode: $OFFLINE_MODE"
77
+ echo "Generation dtype: $DTYPE"
78
+
79
+ if [[ "$PREFETCH_ONLY" == "1" ]]; then
80
+ echo "[examples.sh] PREFETCH_ONLY=1 -> downloading required model files to shared cache and exiting."
81
+ export MODEL_PROFILE
82
+ export METRICS
83
+ "$PYTHON_BIN" - <<'PY'
84
+ import os
85
+ from huggingface_hub import hf_hub_download, snapshot_download
86
+
87
+ cache_dir = os.environ["HF_HUB_CACHE"]
88
+ model_profile = os.environ.get("MODEL_PROFILE", "sana_600m_512")
89
+ metrics = set(os.environ.get("METRICS", "clip aesthetic").split())
90
+
91
+ profile_to_repo = {
92
+ "sana_600m_512": "Efficient-Large-Model/Sana_600M_512px_diffusers",
93
+ "sana_1600m_512": "Efficient-Large-Model/Sana_1600M_512px_diffusers",
94
+ "sana_sprint_0_6b_1024": "Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers",
95
+ "sana_sprint_1_6b_1024": "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
96
+ }
97
+
98
+ def snap(repo_id):
99
+ print(f"[prefetch] snapshot_download: {repo_id}")
100
+ snapshot_download(repo_id=repo_id, cache_dir=cache_dir, local_files_only=False)
101
+
102
+ def one(repo_id, filename):
103
+ print(f"[prefetch] hf_hub_download: {repo_id}/{filename}")
104
+ hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir, local_files_only=False)
105
+
106
+ if model_profile not in profile_to_repo:
107
+ raise ValueError(f"Unknown MODEL_PROFILE={model_profile}")
108
+
109
+ # Base SANA model used for generation + reward backbone
110
+ snap(profile_to_repo[model_profile])
111
+
112
+ # Required for CLIP-based metrics and LRM text projection init fallback
113
+ if "clip" in metrics or "aesthetic" in metrics:
114
+ snap("openai/clip-vit-large-patch14")
115
+
116
+ if "pickscore" in metrics:
117
+ snap("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
118
+ snap("yuvalkirstain/PickScore_v1")
119
+
120
+ if "hpsv2" in metrics or "hpsv21" in metrics:
121
+ one("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", "open_clip_pytorch_model.bin")
122
+ if "hpsv2" in metrics:
123
+ one("xswu/HPSv2", "HPS_v2_compressed.pt")
124
+ if "hpsv21" in metrics:
125
+ one("xswu/HPSv2", "HPS_v2.1_compressed.pt")
126
+
127
+ if "imagereward" in metrics:
128
+ one("THUDM/ImageReward", "ImageReward.pt")
129
+ one("THUDM/ImageReward", "med_config.json")
130
+
131
+ print("[prefetch] done")
132
+ PY
133
+ exit 0
134
+ fi
135
+
136
+ read -r -a METRICS_ARR <<< "$METRICS"
137
+
138
+ CMD=(
139
+ "$PYTHON_BIN" eval.py
140
+ --model_variant "$MODEL_PROFILE"
141
+ --dataset_type "$DATASET_NAME"
142
+ --lrm_model "$LRM_MODEL_PATH"
143
+ --grad_config "$GRAD_CONFIG"
144
+ --metrics "${METRICS_ARR[@]}"
145
+ --num_steps "$NUM_STEPS"
146
+ --cfg_scale "$CFG_SCALE"
147
+ --dtype "$DTYPE"
148
+ --hf_cache_dir "$HF_HUB_CACHE_DIR"
149
+ --output_dir "RESULTS/$DATASET_NAME/${GRAD_CONFIG}_${MODEL_PROFILE}"
150
+ --cuda "$GPU_ID"
151
+ --mode "$MODE"
152
+ )
153
+
154
+ if [[ -n "$MAX_SAMPLES" ]]; then
155
+ CMD+=(--max_samples "$MAX_SAMPLES")
156
+ fi
157
+
158
+ if [[ "$OFFLINE_MODE" == "1" ]]; then
159
+ CMD+=(--offline)
160
+ fi
161
+
162
+ "${CMD[@]}"
Reward_sana_idealized/grad_ascent_configs.py ADDED
@@ -0,0 +1,67 @@
1
+ """
2
+ Configuration presets for gradient ascent optimization.
3
+
4
+ Provides pre-configured settings for various optimization strategies
5
+ including learning rate scheduling and momentum configurations.
6
+ """
7
+
8
+ from typing import Dict, Any
9
+
10
+
11
+ ONE_STEP_RECTIFICATION_CONFIG = {
12
+ "grad_timestep_range": (100, 800), # Match SDXL one-step rectification window
13
+ "num_grad_steps": 1,
14
+ "grad_step_size": 1.0,
15
+ "grad_scale": 1.0,
16
+ "lr_scheduler_type": "constant",
17
+ "use_momentum": False,
18
+ "use_nesterov": False,
19
+ "use_iso_projection": False
20
+ }
21
+
22
+ # ============================================================================
23
+ # Config Dictionary (for easy access)
24
+ # ============================================================================
25
+
26
+ CONFIGS = {
27
+ "one_step_rectification_config": ONE_STEP_RECTIFICATION_CONFIG,
28
+ }
29
+
30
+
31
+ def get_config(config_name: str) -> Dict[str, Any]:
32
+ """
33
+ Get a gradient ascent configuration by name.
34
+
35
+ Args:
36
+ config_name: Name of the configuration
37
+
38
+ Returns:
39
+ Configuration dictionary
40
+
41
+ Raises:
42
+ ValueError: If config_name is not found
43
+
44
+ Example:
45
+ config = get_config("cosine_nesterov")
46
+ pipeline.enable_gradient_ascent(**config)
47
+ """
48
+ if config_name not in CONFIGS:
49
+ available = ", ".join(sorted(CONFIGS.keys()))
50
+ raise ValueError(f"Unknown config: {config_name}. Available: {available}")
51
+
52
+ return CONFIGS[config_name].copy()
53
+
54
+
55
+ def list_configs() -> list:
56
+ """List all available configuration names."""
57
+ return sorted(CONFIGS.keys())
58
+
59
+
60
+ def print_config(config_name: str):
61
+ """Print a configuration in a readable format."""
62
+ config = get_config(config_name)
63
+ print(f"\nConfiguration: {config_name}")
64
+ print("=" * 60)
65
+ for key, value in config.items():
66
+ print(f" {key}: {value}")
67
+ print("=" * 60)
Reward_sana_idealized/gradient_ascent_utils.py ADDED
@@ -0,0 +1,391 @@
1
+ """
2
+ Gradient Ascent utilities for reward-guided diffusion generation.
3
+
4
+ This module implements gradient ascent on the LRM reward score to guide
5
+ the diffusion process toward higher preference scores.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from typing import Optional, Tuple, List, Literal
11
+ from tqdm import tqdm
12
+ from lr_scheduler import create_lr_scheduler, LRScheduler
13
+
14
+
15
+ class RewardGuidedDiffusion:
16
+ """
17
+ Implements reward-guided generation using gradient ascent.
18
+
19
+ During denoising, at specified timesteps, we:
20
+ 1. Compute the reward score for current latents
21
+ 2. Calculate gradients of reward w.r.t. latents
22
+ 3. Update latents in the direction that increases reward
23
+
24
+ This guides generation toward higher preference scores.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ reward_model,
30
+ grad_scale: float = 1.0,
31
+ grad_timestep_range: Optional[Tuple[int, int]] = None,
32
+ num_grad_steps: int = 5,
33
+ grad_step_size: float = 0.1,
34
+ gradient_checkpoint: bool = False,
35
+ # LR Scheduling
36
+ lr_scheduler_type: Literal["constant", "linear", "cosine", "exponential", "step"] = "constant",
37
+ lr_scheduler_kwargs: Optional[dict] = None,
38
+ # Momentum
39
+ use_momentum: bool = False,
40
+ momentum: float = 0.9,
41
+ use_nesterov: bool = False,
42
+ use_iso_projection: bool = False
43
+ ):
44
+ """
45
+ Initialize reward-guided diffusion.
46
+
47
+ Args:
48
+ reward_model: LRM reward model for computing preference scores
49
+ grad_scale: Scale factor for gradient updates (default: 1.0)
50
+ grad_timestep_range: Tuple of (min_t, max_t) for gradient ascent.
51
+ If None, applies to all timesteps.
52
+ num_grad_steps: Number of gradient ascent steps per timestep
53
+ grad_step_size: Step size for each gradient update (initial LR)
54
+ gradient_checkpoint: Whether to use gradient checkpointing
55
+ lr_scheduler_type: Type of LR scheduler ("constant", "linear", "cosine", "exponential", "step")
56
+ lr_scheduler_kwargs: Additional kwargs for LR scheduler (e.g., end_lr, min_lr, warmup_steps)
57
+ use_momentum: Whether to use momentum in gradient updates
58
+ momentum: Momentum coefficient (typically 0.9)
59
+ use_nesterov: Whether to use Nesterov momentum
60
+ use_iso_projection: Whether to use Iso Projection
61
+ """
62
+ self.reward_model = reward_model
63
+ self.grad_scale = grad_scale
64
+ self.grad_timestep_range = grad_timestep_range
65
+ self.num_grad_steps = num_grad_steps
66
+ self.grad_step_size = grad_step_size
67
+ self.gradient_checkpoint = gradient_checkpoint
68
+
69
+ # LR Scheduler
70
+ self.lr_scheduler_type = lr_scheduler_type
71
+ self.lr_scheduler_kwargs = lr_scheduler_kwargs or {}
72
+ self.lr_scheduler: Optional[LRScheduler] = None
73
+ self.global_lr_scheduler: Optional[LRScheduler] = None # Scheduler across denoising timesteps
74
+
75
+ # Momentum
76
+ self.use_momentum = use_momentum
77
+ self.momentum = momentum
78
+ self.use_nesterov = use_nesterov
79
+ self.velocity = None # Will be initialized per optimization
80
+
81
+ self.use_iso_projection = use_iso_projection
82
+
83
+ # Statistics
84
+ self.grad_stats = []
85
+ self.timestep_counter = 0 # Track which timestep we're on
86
+
87
+ def should_apply_gradient(self, timestep: int) -> bool:
88
+ """Check if gradient ascent should be applied at this timestep."""
89
+
90
+ if self.grad_timestep_range is None:
91
+ return False
92
+
93
+ min_t, max_t = self.grad_timestep_range
94
+ return min_t <= timestep <= max_t
95
+
96
+ @torch.enable_grad()
97
+ def compute_reward_gradient(
98
+ self,
99
+ latents: torch.Tensor,
100
+ prompt,
101
+ timestep: int,
102
+ ) -> Tuple[torch.Tensor, float]:
103
+ """
104
+ Compute gradient of reward score w.r.t. latents in FP32 to prevent underflow.
105
+ """
106
+ # 1. Cast to FP32 and ensure we are detached from previous iterations
107
+ latents_fp32 = latents.detach().to(torch.float32).clone()
108
+ latents_fp32.requires_grad_(True)
109
+
110
+ # 2. Compute reward score
111
+ # Note: Even if the model internally uses fp16/bf16, autograd will
112
+ # safely accumulate the gradient in fp32 for our leaf node.
113
+ reward_score = self.reward_model.get_reward_score(
114
+ latents_fp32,
115
+ prompt,
116
+ timestep,
117
+ enable_grad=True,
118
+ return_logits=True,
119
+ )
120
+
121
+ reward_score_mean = reward_score.mean()
122
+ if not torch.isfinite(reward_score_mean):
123
+ return torch.zeros_like(latents), 0.0
124
+
125
+ # 3. Extract gradient
126
+ # CRITICAL: retain_graph=True prevents the graph from dying across multiple
127
+ # gradient steps if your reward model relies on cached text embeddings.
128
+ grad = torch.autograd.grad(
129
+ outputs=reward_score_mean,
130
+ inputs=latents_fp32,
131
+ create_graph=False,
132
+ retain_graph=True, # Keeps the graph alive for the next step!
133
+ allow_unused=True,
134
+ )[0]
135
+
136
+ # 4. Handle None gradients and cast back to the pipeline's original dtype
137
+ if grad is None:
138
+ grad = torch.zeros_like(latents)
139
+ else:
140
+ grad = torch.nan_to_num(grad, nan=0.0, posinf=0.0, neginf=0.0)
141
+ grad = grad.to(latents.dtype)
142
+
143
+ return grad, reward_score_mean.item()
144
+
145
+ def apply_gradient_ascent(
146
+ self,
147
+ latents: torch.Tensor,
148
+ prompt,
149
+ timestep: int,
150
+ base_noise: Optional[torch.Tensor] = None, # Required for Iso-Marginal projection
151
+ verbose: bool = True,
152
+ total_denoising_steps: Optional[int] = None,
153
+ ) -> Tuple[torch.Tensor, dict]:
154
+
155
+ # 1. UPCAST TO FP32 AND SETUP OPTIMIZER (Targeting Latents)
156
+ original_latents = latents.detach().clone().to(torch.float32)
157
+ current_latents = torch.nn.Parameter(original_latents.clone())
158
+
159
+ # Initial reward tracking
160
+ with torch.no_grad():
161
+ initial_reward = self.reward_model.get_reward_score(
162
+ latents,
163
+ prompt,
164
+ timestep
165
+ )
166
+ initial_reward_val = initial_reward.item() if initial_reward.numel() == 1 else initial_reward.mean().item()
167
+
168
+ # Initialize tracking lists
169
+ grad_norms = []
170
+ reward_history = [initial_reward_val]
171
+ lr_history = []
172
+
173
+ # 2. FORWARD PASS (model precision follows eval dtype; latents stay fp32 here)
174
+ reward = self.reward_model.get_reward_score(
175
+ current_latents.to(latents.dtype),
176
+ prompt,
177
+ timestep,
178
+ enable_grad=True,
179
+ return_logits=True,
180
+ )
181
+
182
+ reward_mean = reward.mean()
183
+ if not torch.isfinite(reward_mean):
184
+ if verbose:
185
+ print("?? WARNING: Non-finite reward encountered; skipping gradient step.")
186
+ rectified_latents = original_latents.clone()
187
+ final_latents = rectified_latents.detach().to(latents.dtype)
188
+ stats = {
189
+ 'timestep': timestep,
190
+ 'initial_reward': initial_reward_val,
191
+ 'final_reward': initial_reward_val,
192
+ 'reward_improvement': 0.0,
193
+ 'grad_norms': [0.0],
194
+ 'reward_history': reward_history,
195
+ 'lr_history': [0.0],
196
+ 'latent_change': 0.0,
197
+ }
198
+ self.grad_stats.append(stats)
199
+ return final_latents, stats
200
+
201
+ loss = -reward_mean
202
+ loss.backward()
203
+
204
+ # Extract latent gradient
205
+ raw_grad = current_latents.grad
206
+ if raw_grad is not None:
207
+ raw_grad = torch.nan_to_num(raw_grad, nan=0.0, posinf=0.0, neginf=0.0)
208
+ reward_history.append(torch.sigmoid(reward_mean).item())
209
+
210
+ # 3. ISO-MARGINAL PROJECTION WITH ASYMMETRIC INCLUSION
211
+ if raw_grad is not None and base_noise is not None and self.use_iso_projection:
212
+ gamma = 1e-8
213
+ B = raw_grad.shape[0]
214
+
215
+ grad_flat = raw_grad.view(B, -1)
216
+ noise_flat = base_noise.view(B, -1).to(torch.float32)
217
+
218
+ # Compute projection scalar for raw_grad (which is -∇R)
219
+ dot_product = (grad_flat * noise_flat).sum(dim=1, keepdim=True)
220
+ noise_norm_sq = (noise_flat * noise_flat).sum(dim=1, keepdim=True)
221
+
222
+ proj_scalar = dot_product / (noise_norm_sq + gamma)
223
+ proj_scalar = proj_scalar.view(B, 1, 1, 1)
224
+
225
+ # 1. Decompose
226
+ grad_parallel = proj_scalar * base_noise.to(torch.float32)
227
+ grad_perp = raw_grad - grad_parallel
228
+
229
+ # 2. Asymmetric Inclusion
230
+ # proj_scalar > 0 means the applied step (+?R) points toward -epsilon (Denoising. GOOD.)
231
+ # proj_scalar < 0 means the applied step (+?R) points toward +epsilon (Noising. BAD.)
232
+ safe_proj_scalar = torch.clamp(proj_scalar, min=0.0)
233
+
234
+ beta = 1.0 # Retention factor for the safe parallel gradient
235
+ safe_grad_parallel = beta * (safe_proj_scalar * base_noise.to(torch.float32))
236
+
237
+ # 3. Recombine
238
+ grad_perp = grad_perp + safe_grad_parallel
239
+ else:
240
+ grad_perp = raw_grad
241
+ if base_noise is None and self.use_iso_projection:
242
+ print("?? WARNING: base_noise missing. Skipping Iso-Marginal projection.")
243
+
244
+ # 4. KINETIC RECTIFICATION (Applied to the projected latent gradient)
245
+ if grad_perp is not None:
246
+ grad_norm = grad_perp.float().norm().item()
247
+ max_abs_grad = grad_perp.float().abs().max().item()
248
+
249
+ recovered_with_fallback = False
250
+ if grad_norm <= 0 or max_abs_grad <= 0:
251
+ fallback_grad, _ = self.compute_reward_gradient(
252
+ original_latents,
253
+ prompt,
254
+ timestep,
255
+ )
256
+ fallback_grad = torch.nan_to_num(fallback_grad, nan=0.0, posinf=0.0, neginf=0.0)
257
+ fallback_grad = fallback_grad.to(dtype=original_latents.dtype)
258
+ fallback_norm = fallback_grad.float().norm().item()
259
+ fallback_max_abs = fallback_grad.float().abs().max().item()
260
+
261
+ if fallback_norm > 0 and fallback_max_abs > 0:
262
+ grad_perp = fallback_grad
263
+ grad_norm = fallback_norm
264
+ max_abs_grad = fallback_max_abs
265
+ recovered_with_fallback = True
266
+
267
+ if grad_norm > 0 and max_abs_grad > 0:
268
+ kinetic_direction = grad_perp / (grad_norm + 1e-8)
269
+
270
+ # kinetic_direction has unit L2 norm, so alpha is exactly the L2 magnitude of the applied change.
271
+ alpha = self.grad_step_size
272
+
273
+ with torch.no_grad():
274
+ rectified_latents = original_latents - (alpha * kinetic_direction)
275
+ if recovered_with_fallback:
276
+ print(
277
+ "✓ Recovered collapsed gradient using fp32 fallback "
278
+ f"(norm={grad_norm:.3e}, max_abs={max_abs_grad:.3e})"
279
+ )
280
+ else:
281
+ print(
282
+ "?? WARNING: Gradient tensor exists but magnitude collapsed to zero "
283
+ f"(norm={grad_norm:.3e}, max_abs={max_abs_grad:.3e}, dtype={grad_perp.dtype})"
284
+ )
285
+ rectified_latents = original_latents.clone()
286
+ alpha = 0.0
287
+ max_grad = grad_norm
288
+ else:
289
+ print("?? FATAL: PyTorch completely dropped the latent gradient!")
290
+ rectified_latents = original_latents.clone()
291
+ max_grad = 0.0
292
+ alpha = 0.0
293
+
294
+ if verbose:
295
+ print(f" Grad step | LR: {alpha:.6f} | Reward: {reward.mean().item():.4f} | Max Grad: {max_grad:.4f}")
296
+
297
+ # 5. DOWNCAST AND RETURN
298
+ final_latents = rectified_latents.detach().to(latents.dtype)
299
+
300
+ with torch.no_grad():
301
+ final_reward = self.reward_model.get_reward_score(
302
+ final_latents, prompt, timestep
303
+ )
304
+ final_reward_val = final_reward.item() if final_reward.numel() == 1 else final_reward.mean().item()
305
+
306
+ stats = {
307
+ 'timestep': timestep,
308
+ 'initial_reward': initial_reward_val,
309
+ 'final_reward': final_reward_val,
310
+ 'reward_improvement': final_reward_val - initial_reward_val,
311
+ 'grad_norms': [max_grad],
312
+ 'reward_history': reward_history,
313
+ 'lr_history': [alpha], # Kept for plotting logic
314
+ 'latent_change': (final_latents - original_latents.to(latents.dtype)).norm().item(),
315
+ }
316
+
317
+ self.grad_stats.append(stats)
318
+
319
+ return final_latents, stats
320
+
321
+ def get_statistics(self) -> dict:
322
+ """Get aggregated statistics across all gradient ascent applications."""
323
+ if not self.grad_stats:
324
+ return {}
325
+
326
+ total_improvement = sum(s['reward_improvement'] for s in self.grad_stats)
327
+ avg_improvement = total_improvement / len(self.grad_stats)
328
+
329
+ all_grad_norms = [n for s in self.grad_stats for n in s['grad_norms']]
330
+
331
+ return {
332
+ 'num_applications': len(self.grad_stats),
333
+ 'total_reward_improvement': total_improvement,
334
+ 'avg_reward_improvement': avg_improvement,
335
+ 'avg_grad_norm': sum(all_grad_norms) / len(all_grad_norms) if all_grad_norms else 0,
336
+ 'max_grad_norm': max(all_grad_norms) if all_grad_norms else 0,
337
+ 'detailed_stats': self.grad_stats,
338
+ }
339
+
340
+ def reset_statistics(self):
341
+ """Reset statistics and global scheduler."""
342
+ self.grad_stats = []
343
+ self.global_lr_scheduler = None
344
+ self.timestep_counter = 0
345
+
346
+
347
+ def create_reward_guided_generator(
348
+ reward_model,
349
+ grad_timestep_range: Tuple[int, int] = (500, 700),
350
+ grad_scale: float = 1.0,
351
+ num_grad_steps: int = 5,
352
+ grad_step_size: float = 0.1,
353
+ lr_scheduler_type: str = "constant",
354
+ lr_scheduler_kwargs: Optional[dict] = None,
355
+ use_momentum: bool = False,
356
+ momentum: float = 0.9,
357
+ use_nesterov: bool = False,
358
+ use_iso_projection: bool = False
359
+ ) -> RewardGuidedDiffusion:
360
+ """
361
+ Convenience function to create a reward-guided diffusion generator.
362
+
363
+ Args:
364
+ reward_model: LRM reward model
365
+ grad_timestep_range: Tuple of (min_t, max_t) for applying gradients
366
+ grad_scale: Scale factor for gradient magnitude
367
+ num_grad_steps: Number of gradient ascent iterations per timestep
368
+ grad_step_size: Step size for each gradient update (initial LR)
369
+ lr_scheduler_type: Type of LR scheduler
370
+ lr_scheduler_kwargs: Additional kwargs for LR scheduler
371
+ use_momentum: Whether to use momentum
372
+ momentum: Momentum coefficient
373
+ use_nesterov: Whether to use Nesterov momentum
374
+ use_iso_projection: Whether to use Iso Projection
375
+
376
+ Returns:
377
+ RewardGuidedDiffusion instance
378
+ """
379
+ return RewardGuidedDiffusion(
380
+ reward_model=reward_model,
381
+ grad_scale=grad_scale,
382
+ grad_timestep_range=grad_timestep_range,
383
+ num_grad_steps=num_grad_steps,
384
+ grad_step_size=grad_step_size,
385
+ lr_scheduler_type=lr_scheduler_type,
386
+ lr_scheduler_kwargs=lr_scheduler_kwargs,
387
+ use_momentum=use_momentum,
388
+ momentum=momentum,
389
+ use_nesterov=use_nesterov,
390
+ use_iso_projection=use_iso_projection,  # pass the flag through instead of hardcoding False
391
+ )
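The iso-marginal projection above is easiest to sanity-check in isolation: split the gradient into its component along the base noise and the orthogonal remainder, then keep the parallel part only when its sign is safe. A toy-tensor sketch of that decomposition (shapes and values are illustrative only):

    import torch

    B = 2
    g = torch.randn(B, 4, 8, 8)    # stand-in for raw_grad (i.e. -∇R)
    eps = torch.randn(B, 4, 8, 8)  # stand-in for base_noise

    g_flat, eps_flat = g.view(B, -1), eps.view(B, -1)
    proj = (g_flat * eps_flat).sum(1, keepdim=True) / (eps_flat.pow(2).sum(1, keepdim=True) + 1e-8)
    proj = proj.view(B, 1, 1, 1)

    g_par = proj * eps                                  # component along the noise
    g_perp = g - g_par                                  # orthogonal remainder
    g_safe = g_perp + torch.clamp(proj, min=0.0) * eps  # asymmetric inclusion

    # orthogonality check: g_perp is perpendicular to eps per sample, up to float error
    print((g_perp.view(B, -1) * eps_flat).sum(1))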
Reward_sana_idealized/hpsv2_score.py ADDED
@@ -0,0 +1,110 @@
1
+ """
2
+ Adapted from https://github.com/tgxs002/HPSv2. Originally Apache License, Version 2.0, January 2004.
3
+ """
4
+
5
+ import torch
6
+ from open_clip import create_model_and_transforms, get_tokenizer
7
+ from PIL import Image
8
+
9
+
10
+ class HPSv2Scorer():
11
+ def __init__(self, clip_pretrained_name_or_path, model_pretrained_name_or_path, device='cuda'):
12
+ self.model, _, self.preprocess_val = create_model_and_transforms(
13
+ 'ViT-H-14',
14
+ # 'laion2B-s32B-b79K',
15
+ clip_pretrained_name_or_path,
16
+ precision='amp',
17
+ device=device,
18
+ jit=False,
19
+ force_quick_gelu=False,
20
+ force_custom_text=False,
21
+ force_patch_dropout=False,
22
+ force_image_size=None,
23
+ pretrained_image=False,
24
+ image_mean=None,
25
+ image_std=None,
26
+ light_augmentation=True,
27
+ aug_cfg={},
28
+ output_dict=True,
29
+ with_score_predictor=False,
30
+ with_region_predictor=False
31
+ )
32
+ self.device = device
33
+ checkpoint = torch.load(model_pretrained_name_or_path, map_location=device)
34
+ self.model.load_state_dict(checkpoint['state_dict'])
35
+ self.tokenizer = get_tokenizer('ViT-H-14')
36
+ self.model = self.model.to(device)
37
+
38
+
39
+ def score(self, img_path, prompt):
40
+
41
+ if isinstance(img_path, list):
42
+ result = []
43
+ for one_img_path in img_path:
44
+ # Load your image and prompt
45
+ with torch.no_grad():
46
+ # Process the image
47
+ if isinstance(one_img_path, str):
48
+ image = self.preprocess_val(Image.open(one_img_path)).unsqueeze(0).to(device=self.device, non_blocking=True)
49
+ elif isinstance(one_img_path, Image.Image):
50
+ image = self.preprocess_val(one_img_path).unsqueeze(0).to(device=self.device, non_blocking=True)
51
+ else:
52
+ raise TypeError('The type of parameter img_path is illegal.')
53
+ # Process the prompt
54
+ text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
55
+ # Calculate the HPS
56
+ with torch.cuda.amp.autocast():
57
+ outputs = self.model(image, text)
58
+ image_features, text_features = outputs["image_features"], outputs["text_features"]
59
+ logits_per_image = image_features @ text_features.T
60
+
61
+ hps_score = torch.diagonal(logits_per_image).cpu().numpy()
62
+ result.append(hps_score[0])
63
+ return result
64
+ elif isinstance(img_path, str):
65
+ # Load your image and prompt
66
+ with torch.no_grad():
67
+ # Process the image
68
+ image = self.preprocess_val(Image.open(img_path)).unsqueeze(0).to(device=self.device, non_blocking=True)
69
+ # Process the prompt
70
+ text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
71
+ # Calculate the HPS
72
+ with torch.cuda.amp.autocast():
73
+ outputs = self.model(image, text)
74
+ image_features, text_features = outputs["image_features"], outputs["text_features"]
75
+ logits_per_image = image_features @ text_features.T
76
+
77
+ hps_score = torch.diagonal(logits_per_image).cpu().numpy()
78
+ return [hps_score[0]]
79
+ elif isinstance(img_path, Image.Image):
80
+ # Load your image and prompt
81
+ with torch.no_grad():
82
+ # Process the image
83
+ image = self.preprocess_val(img_path).unsqueeze(0).to(device=self.device, non_blocking=True)
84
+ # Process the prompt
85
+ text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
86
+ # Calculate the HPS
87
+ with torch.cuda.amp.autocast():
88
+ outputs = self.model(image, text)
89
+ image_features, text_features = outputs["image_features"], outputs["text_features"]
90
+ logits_per_image = image_features @ text_features.T
91
+
92
+ hps_score = torch.diagonal(logits_per_image).cpu().numpy()
93
+ return [hps_score[0]]
94
+ else:
95
+ raise TypeError('The type of parameter img_path is illegal.')
96
+
97
+
98
+ if __name__ == "__main__":
99
+ from huggingface_hub import hf_hub_download
100
+
101
+ clip_model_path = hf_hub_download(repo_id="laion/CLIP-ViT-H-14-laion2B-s32B-b79K", filename="open_clip_pytorch_model.bin")
102
+ hps_model_path = hf_hub_download(repo_id="xswu/HPSv2", filename="HPS_v2_compressed.pt")
103
+
104
+ hpsv2_scorer = HPSv2Scorer(clip_pretrained_name_or_path=clip_model_path,
105
+ model_pretrained_name_or_path=hps_model_path)
106
+ score = hpsv2_scorer.score(img_path=['./image0.png', './image1.png'],
107
+ prompt='photorealistic image of a lone painter standing in a gallery, watching an exhibition of paintings made entirely with AI. In the foreground of the image a robot looks proudly at his art')
108
+
109
+ print(score)
110
+
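Since each branch above scores exactly one image against one prompt, the HPS value reduces to the cosine similarity of the CLIP features (assuming, as open_clip's output dict provides, features that are already L2-normalized). A minimal sketch of that last step on dummy feature tensors:

    import torch
    import torch.nn.functional as F

    image_features = F.normalize(torch.randn(1, 1024), dim=-1)
    text_features = F.normalize(torch.randn(1, 1024), dim=-1)

    logits_per_image = image_features @ text_features.T
    hps_score = torch.diagonal(logits_per_image)  # shape (1,): the HPS value
    print(hps_score.item())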
Reward_sana_idealized/imagereward_score.py ADDED
@@ -0,0 +1,221 @@
1
+ """
2
+ Adapted from https://github.com/THUDM/ImageReward. Originally Apache License, Version 2.0, January 2004.
3
+ """
4
+
5
+ import os
6
+ import torch
7
+ import torch.nn as nn
8
+ from io import BytesIO
9
+ from PIL import Image
10
+ from blip.blip_pretrain import BLIP_Pretrain
11
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
12
+ from typing import Any, Union, List
13
+
14
+ try:
15
+ from torchvision.transforms import InterpolationMode
16
+ BICUBIC = InterpolationMode.BICUBIC
17
+ except ImportError:
18
+ BICUBIC = Image.BICUBIC
19
+
20
+
21
+ def open_image(image):
22
+ if isinstance(image, bytes):
23
+ image = Image.open(BytesIO(image))
24
+ elif isinstance(image, str):
25
+ image = Image.open(image)
26
+ image = image.convert("RGB")
27
+ return image
28
+
29
+
30
+ def _convert_image_to_rgb(image):
31
+ return image.convert("RGB")
32
+
33
+
34
+ def _transform(n_px):
35
+ return Compose([
36
+ Resize(n_px, interpolation=BICUBIC),
37
+ CenterCrop(n_px),
38
+ _convert_image_to_rgb,
39
+ ToTensor(),
40
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
41
+ ])
42
+
43
+
44
+ class MLP(nn.Module):
45
+ def __init__(self, input_size):
46
+ super().__init__()
47
+ self.input_size = input_size
48
+
49
+ self.layers = nn.Sequential(
50
+ nn.Linear(self.input_size, 1024),
51
+ #nn.ReLU(),
52
+ nn.Dropout(0.2),
53
+ nn.Linear(1024, 128),
54
+ #nn.ReLU(),
55
+ nn.Dropout(0.2),
56
+ nn.Linear(128, 64),
57
+ #nn.ReLU(),
58
+ nn.Dropout(0.1),
59
+ nn.Linear(64, 16),
60
+ #nn.ReLU(),
61
+ nn.Linear(16, 1)
62
+ )
63
+
64
+ # initial MLP param
65
+ for name, param in self.layers.named_parameters():
66
+ if 'weight' in name:
67
+ nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
68
+ if 'bias' in name:
69
+ nn.init.constant_(param, val=0)
70
+
71
+ def forward(self, input):
72
+ return self.layers(input)
73
+
74
+
75
+ class ImageReward(nn.Module):
76
+ def __init__(self, med_config, device='cpu'):
77
+ super().__init__()
78
+ self.device = device
79
+
80
+ self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
81
+ self.preprocess = _transform(224)
82
+ self.mlp = MLP(768)
83
+
84
+ self.mean = 0.16717362830052426
85
+ self.std = 1.0333394966054072
86
+
87
+
88
+ def score_gard(self, prompt_ids, prompt_attention_mask, image):
89
+
90
+ image_embeds = self.blip.visual_encoder(image)
91
+ # text encode cross attention with image
92
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(self.device)
93
+ text_output = self.blip.text_encoder(prompt_ids,
94
+ attention_mask = prompt_attention_mask,
95
+ encoder_hidden_states = image_embeds,
96
+ encoder_attention_mask = image_atts,
97
+ return_dict = True,
98
+ )
99
+
100
+ txt_features = text_output.last_hidden_state[:,0,:] # (feature_dim)
101
+ rewards = self.mlp(txt_features)
102
+ rewards = (rewards - self.mean) / self.std
103
+
104
+ return rewards
105
+
106
+
107
+ def score(self, prompt, image):
108
+
109
+ if (type(image).__name__=='list'):
110
+ _, rewards = self.inference_rank(prompt, image)
111
+ return rewards
112
+
113
+ # text encode
114
+ text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
115
+
116
+ # image encode
117
+ if isinstance(image, Image.Image):
118
+ pil_image = image
119
+ elif isinstance(image, str):
120
+ if os.path.isfile(image):
121
+ pil_image = Image.open(image)
122
+ else:
123
+ raise TypeError('This image parameter type is not supported yet. Please pass a PIL.Image or a file path str.')
124
+
125
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
126
+ image_embeds = self.blip.visual_encoder(image)
127
+
128
+ # text encode cross attention with image
129
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(self.device)
130
+ text_output = self.blip.text_encoder(text_input.input_ids,
131
+ attention_mask = text_input.attention_mask,
132
+ encoder_hidden_states = image_embeds,
133
+ encoder_attention_mask = image_atts,
134
+ return_dict = True,
135
+ )
136
+
137
+ txt_features = text_output.last_hidden_state[:,0,:].float() # (feature_dim)
138
+ rewards = self.mlp(txt_features)
139
+ rewards = (rewards - self.mean) / self.std
140
+
141
+ return rewards.detach().cpu().numpy().item()
142
+
143
+
144
+ def inference_rank(self, prompt, generations_list):
145
+
146
+ text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
147
+
148
+ txt_set = []
149
+ for generation in generations_list:
150
+ # image encode
151
+ if isinstance(generation, Image.Image):
152
+ pil_image = generation
153
+ elif isinstance(generation, str):
154
+ if os.path.isfile(generation):
155
+ pil_image = Image.open(generation)
156
+ else:
157
+ raise TypeError('This image parameter type is not supported yet. Please pass a PIL.Image or a file path str.')
158
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
159
+ image_embeds = self.blip.visual_encoder(image)
160
+
161
+ # text encode cross attention with image
162
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(self.device)
163
+ text_output = self.blip.text_encoder(text_input.input_ids,
164
+ attention_mask = text_input.attention_mask,
165
+ encoder_hidden_states = image_embeds,
166
+ encoder_attention_mask = image_atts,
167
+ return_dict = True,
168
+ )
169
+ txt_set.append(text_output.last_hidden_state[:,0,:])
170
+
171
+ txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
172
+ rewards = self.mlp(txt_features) # [image_num, 1]
173
+ rewards = (rewards - self.mean) / self.std
174
+ rewards = torch.squeeze(rewards)
175
+ _, rank = torch.sort(rewards, dim=0, descending=True)
176
+ _, indices = torch.sort(rank, dim=0)
177
+ indices = indices + 1
178
+
179
+ return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
180
+
181
+
182
+ def load_imagereward(model_path: str, med_config: str = None, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu"):
183
+ """Load a ImageReward model
184
+
185
+ Parameters
186
+ ----------
187
+ model_path : str
188
+ Path to a model checkpoint containing the state_dict
189
+
190
+ device : Union[str, torch.device]
191
+ The device to put the loaded model
192
+
193
+
194
+ Returns
195
+ -------
196
+ model : torch.nn.Module
197
+ The ImageReward model
198
+ """
199
+ print('load checkpoint from %s'%model_path)
200
+ state_dict = torch.load(model_path, map_location='cpu')
201
+
202
+ model = ImageReward(device=device, med_config=med_config).to(device)
203
+ msg = model.load_state_dict(state_dict, strict=False)
204
+ print("checkpoint loaded")
205
+ model.eval()
206
+
207
+ return model
208
+
209
+
210
+ if __name__ == '__main__':
211
+ from huggingface_hub import hf_hub_download
212
+
213
+ model_path = hf_hub_download(repo_id="THUDM/ImageReward", filename="ImageReward.pt")
214
+ config_path = hf_hub_download(repo_id="THUDM/ImageReward", filename="med_config.json")
215
+
216
+ image0 = open_image('./image0.png')
217
+ image1 = open_image('./image1.png')
218
+ prompt = "photorealistic image of a lone painter standing in a gallery, watching an exhibition of paintings made entirely with AI. In the foreground of the image a robot looks proudly at his art"
219
+ model = load_imagereward(model_path=model_path, med_config=config_path, device='cuda')
220
+
221
+ print(model.score(prompt, [image0, image1]))
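The ranking logic in inference_rank uses the double-sort idiom, which is worth seeing on its own: the first sort orders the rewards descending, and sorting the resulting permutation recovers each item's 1-based rank. A standalone sketch with toy values:

    import torch

    rewards = torch.tensor([0.2, 1.5, -0.3])
    _, rank = torch.sort(rewards, dim=0, descending=True)  # item indices, best first
    _, indices = torch.sort(rank, dim=0)                   # each item's position in that order
    indices = indices + 1                                  # 1-based ranks
    print(indices.tolist())  # [2, 1, 3]: the second image scores highest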
Reward_sana_idealized/lr_scheduler.py ADDED
@@ -0,0 +1,233 @@
1
+ """
2
+ Learning rate schedulers for gradient ascent optimization.
3
+
4
+ Provides various LR scheduling strategies for reward-guided gradient ascent,
5
+ including cosine annealing, linear decay, and custom schedules.
6
+ """
7
+
8
+ import math
9
+ from typing import Optional, Literal
10
+
11
+
12
+ class LRScheduler:
13
+ """Base class for learning rate schedulers."""
14
+
15
+ def __init__(self, initial_lr: float, num_steps: int):
16
+ """
17
+ Initialize LR scheduler.
18
+
19
+ Args:
20
+ initial_lr: Initial learning rate
21
+ num_steps: Total number of optimization steps
22
+ """
23
+ self.initial_lr = initial_lr
24
+ self.num_steps = num_steps
25
+ self.current_step = 0
26
+
27
+ def get_lr(self) -> float:
28
+ """Get current learning rate."""
29
+ raise NotImplementedError
30
+
31
+ def step(self):
32
+ """Update scheduler state after a step."""
33
+ self.current_step += 1
34
+
35
+ def reset(self):
36
+ """Reset scheduler state."""
37
+ self.current_step = 0
38
+
39
+
40
+ class ConstantLR(LRScheduler):
41
+ """Constant learning rate (no scheduling)."""
42
+
43
+ def get_lr(self) -> float:
44
+ return self.initial_lr
45
+
46
+
47
+ class LinearLR(LRScheduler):
48
+ """Linear learning rate decay."""
49
+
50
+ def __init__(
51
+ self,
52
+ initial_lr: float,
53
+ num_steps: int,
54
+ end_lr: float = 0.0,
55
+ start_step: int = 0,
56
+ ):
57
+ """
58
+ Initialize linear LR scheduler.
59
+
60
+ Args:
61
+ initial_lr: Starting learning rate
62
+ num_steps: Total number of steps
63
+ end_lr: Ending learning rate (default: 0.0)
64
+ start_step: Step to begin decay (default: 0)
65
+ """
66
+ super().__init__(initial_lr, num_steps)
67
+ self.end_lr = end_lr
68
+ self.start_step = start_step
69
+
70
+ def get_lr(self) -> float:
71
+ if self.current_step < self.start_step:
72
+ return self.initial_lr
73
+
74
+ progress = (self.current_step - self.start_step) / (self.num_steps - self.start_step)
75
+ progress = min(1.0, progress)
76
+
77
+ return self.initial_lr + (self.end_lr - self.initial_lr) * progress
78
+
79
+
80
+ class CosineLR(LRScheduler):
81
+ """Cosine annealing learning rate schedule."""
82
+
83
+ def __init__(
84
+ self,
85
+ initial_lr: float,
86
+ num_steps: int,
87
+ min_lr: float = 0.0,
88
+ warmup_steps: int = 0,
89
+ ):
90
+ """
91
+ Initialize cosine LR scheduler.
92
+
93
+ Args:
94
+ initial_lr: Maximum learning rate
95
+ num_steps: Total number of steps
96
+ min_lr: Minimum learning rate (default: 0.0)
97
+ warmup_steps: Number of linear warmup steps (default: 0)
98
+ """
99
+ super().__init__(initial_lr, num_steps)
100
+ self.min_lr = min_lr
101
+ self.warmup_steps = warmup_steps
102
+
103
+ def get_lr(self) -> float:
104
+ if self.current_step < self.warmup_steps:
105
+ # Linear warmup
106
+ return self.initial_lr * (self.current_step / self.warmup_steps)
107
+
108
+ # Cosine annealing
109
+ progress = (self.current_step - self.warmup_steps) / (self.num_steps - self.warmup_steps)
110
+ progress = min(1.0, progress)
111
+
112
+ cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
113
+ return self.min_lr + (self.initial_lr - self.min_lr) * cosine_decay
114
+
115
+
116
+ class ExponentialLR(LRScheduler):
117
+ """Exponential learning rate decay."""
118
+
119
+ def __init__(
120
+ self,
121
+ initial_lr: float,
122
+ num_steps: int,
123
+ gamma: float = 0.95,
124
+ ):
125
+ """
126
+ Initialize exponential LR scheduler.
127
+
128
+ Args:
129
+ initial_lr: Starting learning rate
130
+ num_steps: Total number of steps
131
+ gamma: Multiplicative decay factor per step
132
+ """
133
+ super().__init__(initial_lr, num_steps)
134
+ self.gamma = gamma
135
+
136
+ def get_lr(self) -> float:
137
+ return self.initial_lr * (self.gamma ** self.current_step)
138
+
139
+
140
+ class StepLR(LRScheduler):
141
+ """Step-wise learning rate decay."""
142
+
143
+ def __init__(
144
+ self,
145
+ initial_lr: float,
146
+ num_steps: int,
147
+ step_size: int,
148
+ gamma: float = 0.1,
149
+ ):
150
+ """
151
+ Initialize step LR scheduler.
152
+
153
+ Args:
154
+ initial_lr: Starting learning rate
155
+ num_steps: Total number of steps
156
+ step_size: Number of steps between each decay
157
+ gamma: Multiplicative decay factor
158
+ """
159
+ super().__init__(initial_lr, num_steps)
160
+ self.step_size = step_size
161
+ self.gamma = gamma
162
+
163
+ def get_lr(self) -> float:
164
+ num_decays = self.current_step // self.step_size
165
+ return self.initial_lr * (self.gamma ** num_decays)
166
+
167
+
168
+ def create_lr_scheduler(
169
+ scheduler_type: Literal["constant", "linear", "cosine", "exponential", "step"],
170
+ initial_lr: float,
171
+ num_steps: int,
172
+ **kwargs
173
+ ) -> LRScheduler:
174
+ """
175
+ Factory function to create learning rate schedulers.
176
+
177
+ Args:
178
+ scheduler_type: Type of scheduler ("constant", "linear", "cosine", "exponential", "step")
179
+ initial_lr: Initial learning rate
180
+ num_steps: Total number of optimization steps
181
+ **kwargs: Additional scheduler-specific arguments
182
+ For linear: end_lr, start_step
183
+ For cosine: min_lr, warmup_steps
184
+ For exponential: gamma
185
+ For step: step_size, gamma
186
+
187
+ Returns:
188
+ LRScheduler instance
189
+
190
+ Examples:
191
+ # Constant LR
192
+ scheduler = create_lr_scheduler("constant", initial_lr=0.1, num_steps=100)
193
+
194
+ # Linear decay
195
+ scheduler = create_lr_scheduler("linear", initial_lr=0.1, num_steps=100, end_lr=0.01)
196
+
197
+ # Cosine annealing with warmup
198
+ scheduler = create_lr_scheduler("cosine", initial_lr=0.1, num_steps=100,
199
+ min_lr=0.001, warmup_steps=10)
200
+ """
201
+ if scheduler_type == "constant":
202
+ return ConstantLR(initial_lr, num_steps)
203
+
204
+ elif scheduler_type == "linear":
205
+ return LinearLR(
206
+ initial_lr, num_steps,
207
+ end_lr=kwargs.get("end_lr", 0.0),
208
+ start_step=kwargs.get("start_step", 0),
209
+ )
210
+
211
+ elif scheduler_type == "cosine":
212
+ return CosineLR(
213
+ initial_lr, num_steps,
214
+ min_lr=kwargs.get("min_lr", 0.0),
215
+ warmup_steps=kwargs.get("warmup_steps", 0),
216
+ )
217
+
218
+ elif scheduler_type == "exponential":
219
+ return ExponentialLR(
220
+ initial_lr, num_steps,
221
+ gamma=kwargs.get("gamma", 0.95),
222
+ )
223
+
224
+ elif scheduler_type == "step":
225
+ return StepLR(
226
+ initial_lr, num_steps,
227
+ step_size=kwargs.get("step_size", 10),
228
+ gamma=kwargs.get("gamma", 0.1),
229
+ )
230
+
231
+ else:
232
+ raise ValueError(f"Unknown scheduler type: {scheduler_type}. "
233
+ f"Choose from: constant, linear, cosine, exponential, step")
Reward_sana_idealized/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (285 Bytes). View file
 
Reward_sana_idealized/open_clip/__pycache__/coca_model.cpython-311.pyc ADDED
Binary file (18.4 kB). View file
 
Reward_sana_idealized/open_clip/__pycache__/factory.cpython-311.pyc ADDED
Binary file (19.3 kB). View file
 
Reward_sana_idealized/open_clip/__pycache__/model.cpython-311.pyc ADDED
Binary file (25.1 kB). View file
 
Reward_sana_idealized/open_clip/__pycache__/modified_resnet.cpython-311.pyc ADDED
Binary file (13.2 kB). View file
 
Reward_sana_idealized/open_clip/__pycache__/pretrained.cpython-311.pyc ADDED
Binary file (18.5 kB). View file
 
Reward_sana_idealized/open_clip/__pycache__/push_to_hf_hub.cpython-311.pyc ADDED
Binary file (9.29 kB). View file
 
Reward_sana_idealized/open_clip/__pycache__/timm_model.cpython-311.pyc ADDED
Binary file (6.73 kB). View file
 
Reward_sana_idealized/open_clip/__pycache__/tokenizer.cpython-311.pyc ADDED
Binary file (16 kB). View file
 
Reward_sana_idealized/open_clip/__pycache__/transformer.cpython-311.pyc ADDED
Binary file (42.6 kB). View file
 
Reward_sana_idealized/open_clip/model_configs/convnext_xlarge.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "timm_model_name": "convnext_xlarge",
5
+ "timm_model_pretrained": false,
6
+ "timm_pool": "",
7
+ "timm_proj": "linear",
8
+ "timm_drop": 0.0,
9
+ "timm_drop_path": 0.1,
10
+ "image_size": 256
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 77,
14
+ "vocab_size": 49408,
15
+ "width": 1024,
16
+ "heads": 16,
17
+ "layers": 20
18
+ }
19
+ }
Reward_sana_idealized/pick_score.py ADDED
@@ -0,0 +1,141 @@
1
+ """
2
+ Adapted from https://github.com/yuvalkirstain/PickScore. Originally MIT License, Copyright (c) 2021.
3
+ """
4
+
5
+
6
+ from io import BytesIO
+
+ import torch
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModel
17
+
18
+
19
+ def open_image(image):
20
+ if isinstance(image, bytes):
21
+ image = Image.open(BytesIO(image))
22
+ elif isinstance(image, str):
23
+ image = Image.open(image)
24
+ image = image.convert("RGB")
25
+ return image
26
+
27
+
28
+
29
+ class PickScorer(torch.nn.Module):
30
+ def __init__(self, processor_name_or_path, model_pretrained_name_or_path, device='cuda'):
31
+ super().__init__()
32
+ self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
33
+ self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).to(device)
34
+ self.device = device
35
+ self.eval()
36
+
37
+ @torch.no_grad()
38
+ def __call__(self, prompt, images):
39
+ # preprocess
40
+ image_inputs = self.processor(
41
+ images=images,
42
+ padding=True,
43
+ truncation=True,
44
+ max_length=77,
45
+ return_tensors="pt",
46
+ ).to(self.device)
47
+
48
+ text_inputs = self.processor(
49
+ text=prompt,
50
+ padding=True,
51
+ truncation=True,
52
+ max_length=77,
53
+ return_tensors="pt",
54
+ ).to(self.device)
55
+
56
+ with torch.no_grad():
57
+ # embed
58
+ image_embs = self.model.get_image_features(**image_inputs)
59
+ image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
60
+
61
+ text_embs = self.model.get_text_features(**text_inputs)
62
+ text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
63
+
64
+ # score
65
+ scores = self.model.logit_scale.exp() * (text_embs @ image_embs.T)[0]
66
+
67
+ # get probabilities if you have multiple images to choose from
68
+ if len(scores) == 1:
69
+ probs = scores
70
+ else:
71
+ probs = torch.softmax(scores, dim=-1)
72
+
73
+ return probs.cpu().tolist()
74
+
75
+
76
+     def score(self, img_path, prompt):
+         # NOTE: fixed from a copy of HPSv2Scorer.score, which referenced
+         # self.preprocess_val / self.tokenizer that PickScorer does not have.
+         # PickScore preprocessing goes through the HF AutoProcessor, so we
+         # normalise the input to PIL images and delegate to __call__,
+         # scoring one image at a time to keep the raw (non-softmaxed) scores.
+         if isinstance(img_path, list):
+             images = [open_image(p) if isinstance(p, (str, bytes)) else p for p in img_path]
+         elif isinstance(img_path, (str, Image.Image)):
+             images = [open_image(img_path) if isinstance(img_path, str) else img_path]
+         else:
+             raise TypeError('The type of parameter img_path is illegal.')
+         if not all(isinstance(image, Image.Image) for image in images):
+             raise TypeError('The type of parameter img_path is illegal.')
+         return [self(prompt, [image])[0] for image in images]
129
+
130
+
131
+ if __name__ == "__main__":
132
+ pickscorer = PickScorer(processor_name_or_path="laion/CLIP-ViT-H-14-laion2B-s32B-b79K", model_pretrained_name_or_path="yuvalkirstain/PickScore_v1")
133
+
134
+ image0 = open_image('./image0.png')
135
+ image1 = open_image('./image1.png')
136
+ prompt = "photorealistic image of a lone painter standing in a gallery, watching an exhibition of paintings made entirely with AI. In the foreground of the image a robot looks proudly at his art"
137
+
138
+ probs = pickscorer(prompt, [image0])
139
+ probs1 = pickscorer(prompt, [image1])
140
+ print(probs)
141
+ print(probs1)
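Note that __call__ only applies the softmax when given more than one image, so the __main__ block above yields two independent raw scores. Passing both images in one call gives head-to-head preference probabilities; the final step is just a softmax over the logit-scaled similarities, sketched here on made-up scores:

    import torch

    scores = torch.tensor([21.3, 19.8])    # logit-scaled similarities for two images
    probs = torch.softmax(scores, dim=-1)  # preference probabilities, sum to 1
    print(probs.tolist())                  # ~[0.82, 0.18]: image 0 preferred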
Reward_sana_idealized/test.ipynb ADDED
@@ -0,0 +1,47 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "d8044219",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "data": {
11
+ "text/plain": [
12
+ "False"
13
+ ]
14
+ },
15
+ "execution_count": 1,
16
+ "metadata": {},
17
+ "output_type": "execute_result"
18
+ }
19
+ ],
20
+ "source": [
21
+ "import torch\n",
22
+ "torch.cuda.is_available()"
23
+ ]
24
+ }
25
+ ],
26
+ "metadata": {
27
+ "kernelspec": {
28
+ "display_name": "Python 3",
29
+ "language": "python",
30
+ "name": "python3"
31
+ },
32
+ "language_info": {
33
+ "codemirror_mode": {
34
+ "name": "ipython",
35
+ "version": 3
36
+ },
37
+ "file_extension": ".py",
38
+ "mimetype": "text/x-python",
39
+ "name": "python",
40
+ "nbconvert_exporter": "python",
41
+ "pygments_lexer": "ipython3",
42
+ "version": "3.10.18"
43
+ }
44
+ },
45
+ "nbformat": 4,
46
+ "nbformat_minor": 5
47
+ }
Reward_sana_idealized/tune_hyperparams.py ADDED
@@ -0,0 +1,514 @@
1
+ """
2
+ Hyperparameter tuning script for gradient ascent optimization.
3
+
4
+ This script performs a systematic search over hyperparameter combinations
5
+ to find the optimal configuration for maximum evaluation scores.
6
+ """
7
+
8
+ import subprocess
9
+ import json
10
+ import argparse
11
+ from pathlib import Path
12
+ from datetime import datetime
13
+ import itertools
14
+ import numpy as np
15
+ from typing import Dict, List, Any
16
+ import re
17
+
18
+
19
+ class HyperparameterTuner:
20
+ """Hyperparameter tuner for gradient ascent."""
21
+
22
+ def __init__(
23
+ self,
24
+ output_dir: str = "tuning_results",
25
+ max_samples: int = 30,
26
+ num_steps: int = 20,
27
+ dataset_type: str = "pickapic",
28
+ model_variant: str = "lpo",
29
+ cuda_id: int = 0,
30
+ metrics: List[str] = None
31
+ ):
32
+ self.output_dir = Path(output_dir)
33
+ self.output_dir.mkdir(parents=True, exist_ok=True)
34
+
35
+ self.max_samples = max_samples
36
+ self.num_steps = num_steps
37
+ self.dataset_type = dataset_type
38
+ self.model_variant = model_variant
39
+ self.cuda_id = cuda_id
40
+ self.metrics = metrics or ["clip", "aesthetic", "pickscore", "hpsv2", "imagereward"]
41
+
42
+ # Store results
43
+ self.results = []
44
+ self.baseline_results = None
45
+
46
+ def define_search_space(self) -> List[Dict[str, Any]]:
47
+ """Define the hyperparameter search space - FULL GRID SEARCH.
48
+
49
+ Tests all combinations of parameters including momentum overrides for configs that support it.
50
+ """
51
+
52
+ # Define all parameter values
53
+ cfg_scales = [3.0, 5.0, 7.5]
54
+
55
+ # All available gradient configs from grad_ascent_configs.py
56
+ grad_configs = [
57
+ # "constant",
58
+ # "linear",
59
+ "cosine_nesterov",
60
+ # "low_to_high_nesterov",
61
+ # "high_to_low_nesterov",
62
+ "low_to_high_momentum",
63
+ "high_to_low_momentum",
64
+ ]
65
+
66
+ num_grad_steps_list = [1, 2] # 5, 7, 10
67
+ grad_step_sizes = [0.001, 0.005, 0.01, 0.05]
68
+ momentums = [0.5, 0.8, 0.9]
69
+
70
+ # Generate ALL combinations using itertools.product
71
+ configs = []
72
+ for cfg, grad_cfg, num_steps, step_size, momentum in itertools.product(
73
+ cfg_scales, grad_configs, num_grad_steps_list, grad_step_sizes, momentums
74
+ ):
75
+ configs.append({
76
+ "cfg_scale": cfg,
77
+ "grad_config": grad_cfg,
78
+ "num_grad_steps": num_steps,
79
+ "grad_step_size": step_size,
80
+ "momentum": momentum,
81
+ })
82
+
83
+ print(f"\nGenerated {len(configs)} total configurations")
84
+ print(f" cfg_scales: {len(cfg_scales)}")
85
+ print(f" grad_configs: {len(grad_configs)}")
86
+ print(f" num_grad_steps: {len(num_grad_steps_list)}")
87
+ print(f" grad_step_sizes: {len(grad_step_sizes)}")
88
+ print(f" momentums: {len(momentums)}")
89
+ print(f" Total: {len(cfg_scales)} × {len(grad_configs)} × {len(num_grad_steps_list)} × {len(grad_step_sizes)} × {len(momentums)} = {len(configs)}")
90
+
91
+ return configs
92
+
93
+ def run_baseline(self) -> Dict[str, float]:
94
+ """Run baseline evaluation once."""
95
+ print("\n" + "="*80)
96
+ print("RUNNING BASELINE EVALUATION")
97
+ print("="*80)
98
+
99
+ # Use median cfg_scale for baseline
100
+ cfg_scale = 5.0
101
+
102
+ output_dir = self.output_dir / "baseline"
103
+
104
+ cmd = [
105
+ "python", "eval.py",
106
+ "--model_variant", self.model_variant,
107
+ "--dataset_type", self.dataset_type,
108
+ "--max_samples", str(self.max_samples),
109
+ "--num_steps", str(self.num_steps),
110
+ "--cfg_scale", str(cfg_scale),
111
+ "--output_dir", str(output_dir),
112
+ "--cuda", str(self.cuda_id),
113
+ "--mode", "baseline",
114
+ "--metrics", *self.metrics,
115
+ ]
116
+
117
+ print(f"Command: {' '.join(cmd)}")
118
+
119
+ try:
120
+ result = subprocess.run(cmd, capture_output=True, text=True, check=True)
121
+
122
+ # Parse results from output
123
+ metrics = self._parse_metrics(result.stdout, "baseline")
124
+
125
+ print(f"\nBaseline Results:")
126
+ for metric, value in metrics.items():
127
+ print(f" {metric}: {value:.4f}")
128
+
129
+ self.baseline_results = {
130
+ "cfg_scale": cfg_scale,
131
+ "metrics": metrics,
132
+ }
133
+
134
+ return metrics
135
+
136
+ except subprocess.CalledProcessError as e:
137
+ print(f"Error running baseline: {e}")
138
+ print(f"Stdout: {e.stdout}")
139
+ print(f"Stderr: {e.stderr}")
140
+ return {}
141
+
+     def run_experiment(self, config: Dict[str, Any]) -> Dict[str, Any]:
+         """Run a single experiment with the given hyperparameters."""
+
+         # Create an output directory for this config
+         config_name = f"cfg{config['cfg_scale']}_" \
+                       f"{config['grad_config']}_" \
+                       f"steps{config['num_grad_steps']}_" \
+                       f"lr{config['grad_step_size']}_" \
+                       f"mom{config['momentum']}"
+
+         output_dir = self.output_dir / config_name
+
+         # Build the command
+         cmd = [
+             "python", "eval.py",
+             "--model_variant", self.model_variant,
+             "--dataset_type", self.dataset_type,
+             "--grad_config", config["grad_config"],
+             "--max_samples", str(self.max_samples),
+             "--num_steps", str(self.num_steps),
+             "--cfg_scale", str(config["cfg_scale"]),
+             "--output_dir", str(output_dir),
+             "--cuda", str(self.cuda_id),
+             "--mode", "gradient_ascent",
+             "--metrics", *self.metrics,
+             # Override config parameters
+             "--override_num_grad_steps", str(config["num_grad_steps"]),
+             "--override_grad_step_size", str(config["grad_step_size"]),
+             "--override_momentum", str(config["momentum"]),
+         ]
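+         # The --override_* flags let the named grad_config preset be reused
+         # while the sweep varies its step count, step size, and momentum.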
+
+         print(f"\nRunning experiment: {config_name}")
+         print(f"Config: {config}")
+
+         try:
+             result = subprocess.run(cmd, capture_output=True, text=True, check=True)
+
+             # Parse metrics from the output
+             metrics = self._parse_metrics(result.stdout, "gradient_ascent")
+
+             # Compute improvement over baseline
+             improvements = {}
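+             # Improvements are relative percent changes vs. the baseline;
+             # metrics with a zero baseline are skipped, and for FID (where
+             # lower is better) a positive "improvement" actually means worse.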
+             if self.baseline_results:
+                 baseline_metrics = self.baseline_results["metrics"]
+                 for metric, value in metrics.items():
+                     if metric in baseline_metrics:
+                         baseline_val = baseline_metrics[metric]
+                         if baseline_val != 0:
+                             improvement = ((value - baseline_val) / abs(baseline_val)) * 100
+                             improvements[f"{metric}_improvement"] = improvement
+
+             result_dict = {
+                 "config": config,
+                 "metrics": metrics,
+                 "improvements": improvements,
+                 "output_dir": str(output_dir),
+                 "timestamp": datetime.now().isoformat(),
+             }
+
+             print("Results:")
+             for metric, value in metrics.items():
+                 print(f"  {metric}: {value:.4f}")
+             if improvements:
+                 print("Improvements over baseline:")
+                 for metric, value in improvements.items():
+                     print(f"  {metric}: {value:+.2f}%")
+
+             return result_dict
+
+         except subprocess.CalledProcessError as e:
+             print(f"Error running experiment: {e}")
+             print(f"Stderr: {e.stderr}")
+             return {
+                 "config": config,
+                 "error": str(e),
+                 "timestamp": datetime.now().isoformat(),
+             }
+
+     def _parse_metrics(self, output: str, mode: str) -> Dict[str, float]:
+         """Parse metrics from eval.py output."""
+         metrics = {}
+
+         # Look for the summary section
+         lines = output.split('\n')
+
+         # Patterns to match metric lines like "  Reward: 0.1234"
+         metric_patterns = {
+             "reward": r"Reward:\s+([-+]?\d*\.?\d+)",
+             "clip": r"CLIP Score:\s+([-+]?\d*\.?\d+)",
+             "aesthetic": r"Aesthetic Score:\s+([-+]?\d*\.?\d+)",
+             "pickscore": r"PickScore:\s+([-+]?\d*\.?\d+)",
+             "hpsv2": r"HPSv2 Score:\s+([-+]?\d*\.?\d+)",
+             "hpsv21": r"HPSv2\.1 Score:\s+([-+]?\d*\.?\d+)",
+             "imagereward": r"ImageReward:\s+([-+]?\d*\.?\d+)",
+             "fid": r"FID:\s+([-+]?\d*\.?\d+)",
+         }
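+         # These patterns must track eval.py's printed summary verbatim; if
+         # several lines match, the last occurrence wins. The `mode` argument
+         # is currently unused.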
+
+         for line in lines:
+             for metric_name, pattern in metric_patterns.items():
+                 match = re.search(pattern, line)
+                 if match:
+                     metrics[metric_name] = float(match.group(1))
+
+         return metrics
+
+     def compute_aggregate_score(self, metrics: Dict[str, float]) -> float:
+         """
+         Compute an aggregate score for ranking configurations.
+
+         Uses a weighted combination of metrics; higher is better for all
+         of them except FID, where lower is better.
+         """
+         weights = {
+             "reward": 1.0,
+             "clip": 0.8,
+             "aesthetic": 0.8,
+             "pickscore": 1.0,
+             "hpsv2": 1.0,
+             "hpsv21": 1.0,
+             "imagereward": 1.0,
+             "fid": -0.5,  # negative weight (lower FID is better)
+         }
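+         # Caveat: metrics are combined on their raw scales, so a metric with
+         # larger magnitudes (e.g. the aesthetic score) dominates the weighted
+         # mean; normalizing each metric first may give a fairer ranking.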
+
+         score = 0.0
+         total_weight = 0.0
+
+         for metric, value in metrics.items():
+             if metric in weights:
+                 score += weights[metric] * value
+                 total_weight += abs(weights[metric])
+
+         # Normalize by total weight
+         if total_weight > 0:
+             score /= total_weight
+
+         return score
+
+     def run_search(
+         self,
+         search_type: str = "grid",
+         start_idx: int = 0,
+         end_idx: int = None
+     ) -> List[Dict[str, Any]]:
+         """
+         Run the hyperparameter search.
+
+         Args:
+             search_type: Type of search ("grid" or "random")
+             start_idx: Starting index for experiments (for GPU distribution)
+             end_idx: Ending index for experiments (for GPU distribution)
+         """
+         all_configs = self.define_search_space()
+
+         print("\n" + "="*80)
+         print("HYPERPARAMETER SEARCH CONFIGURATION")
+         print("="*80)
+         print(f"Dataset: {self.dataset_type}")
+         print(f"Model: {self.model_variant}")
+         print(f"Samples: {self.max_samples}")
+         print(f"Inference steps: {self.num_steps}")
+         print(f"Metrics: {', '.join(self.metrics)}")
+
+         # Select a subset of configs if indices are provided
+         if search_type == "grid":
+             configs = all_configs
+         elif search_type == "random":
+             # Random sample from all configs
+             n_samples = min(50, len(all_configs))
+             indices = np.random.choice(len(all_configs), n_samples, replace=False)
+             configs = [all_configs[i] for i in indices]
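+             # NOTE: this sampling is unseeded, so parallel workers slicing by
+             # start_idx/end_idx would each draw a different subset; seed
+             # np.random (or use grid search) when distributing work by index.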
+         else:
+             raise ValueError(f"Unknown search type: {search_type}")
+
+         # Apply index slicing for GPU distribution
+         if end_idx is None:
+             end_idx = len(configs)
+         configs = configs[start_idx:end_idx]
+
+         print(f"\nTotal configurations: {len(all_configs)}")
+         print(f"Assigned to this worker: {len(configs)} (indices {start_idx} to {end_idx})")
+
+         # Run the baseline first
+         if self.baseline_results is None:
+             self.run_baseline()
+
+         # Run the experiments
+         print("\n" + "="*80)
+         print("RUNNING EXPERIMENTS")
+         print("="*80)
+
+         for i, config in enumerate(configs, 1):
+             print(f"\n{'='*80}")
+             print(f"Experiment {i}/{len(configs)}")
+             print(f"{'='*80}")
+
+             result = self.run_experiment(config)
+             self.results.append(result)
+
+             # Save intermediate results after every experiment
+             self._save_results()
+
+         return self.results
+
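+     # The two generators below are not called by run_search above (which
+     # builds its grid in define_search_space); they are generic helpers for
+     # custom search spaces.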
+     def _generate_grid_configs(self, search_space: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
+         """Generate all combinations for grid search."""
+         keys = list(search_space.keys())
+         values = list(search_space.values())
+
+         configs = []
+         for combination in itertools.product(*values):
+             config = dict(zip(keys, combination))
+             configs.append(config)
+
+         return configs
+
+     def _generate_random_configs(
+         self,
+         search_space: Dict[str, List[Any]],
+         n_samples: int = 20
+     ) -> List[Dict[str, Any]]:
+         """Generate random configurations for random search."""
+         configs = []
+
+         for _ in range(n_samples):
+             config = {}
+             for param, values in search_space.items():
+                 # Index with randint so values keep their native Python types
+                 # (np.random.choice returns NumPy scalars, which json.dump
+                 # cannot serialize).
+                 config[param] = values[np.random.randint(len(values))]
+             configs.append(config)
+
+         return configs
+
+     def _save_results(self):
+         """Save results to a JSON file."""
+         results_file = self.output_dir / "tuning_results.json"
+
+         data = {
+             "baseline": self.baseline_results,
+             "experiments": self.results,
+             "timestamp": datetime.now().isoformat(),
+             "config": {
+                 "max_samples": self.max_samples,
+                 "num_steps": self.num_steps,
+                 "dataset_type": self.dataset_type,
+                 "model_variant": self.model_variant,
+             }
+         }
+
+         with open(results_file, 'w') as f:
+             json.dump(data, f, indent=2)
+
+         print(f"\nResults saved to: {results_file}")
+
+     def analyze_results(self) -> Dict[str, Any]:
+         """Analyze results and find the best configuration."""
+         if not self.results:
+             print("No results to analyze!")
+             return {}
+
+         print("\n" + "="*80)
+         print("ANALYSIS: FINDING BEST CONFIGURATION")
+         print("="*80)
+
+         # Filter out failed experiments
+         successful_results = [r for r in self.results if "metrics" in r]
+
+         if not successful_results:
+             print("No successful experiments!")
+             return {}
+
+         # Compute aggregate scores
+         for result in successful_results:
+             metrics = result["metrics"]
+             result["aggregate_score"] = self.compute_aggregate_score(metrics)
+
+         # Sort by aggregate score, best first
+         successful_results.sort(key=lambda x: x["aggregate_score"], reverse=True)
+
+         # Print the top 5 configurations
+         print("\nTop 5 Configurations:")
+         print("="*80)
+
+         for i, result in enumerate(successful_results[:5], 1):
+             print(f"\n#{i} - Aggregate Score: {result['aggregate_score']:.4f}")
+             print(f"Config: {result['config']}")
+             print("Metrics:")
+             for metric, value in result['metrics'].items():
+                 print(f"  {metric}: {value:.4f}")
+             if result.get('improvements'):
+                 print("Improvements over baseline:")
+                 for metric, value in result['improvements'].items():
+                     print(f"  {metric}: {value:+.2f}%")
+
+         # Save the best config
+         best_result = successful_results[0]
+         best_config_file = self.output_dir / "best_config.json"
+
+         with open(best_config_file, 'w') as f:
+             json.dump({
+                 "config": best_result["config"],
+                 "metrics": best_result["metrics"],
+                 "aggregate_score": best_result["aggregate_score"],
+                 "improvements": best_result.get("improvements", {}),
+             }, f, indent=2)
+
+         print(f"\n✓ Best configuration saved to: {best_config_file}")
+
+         return best_result
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Hyperparameter tuning for gradient ascent")
+     parser.add_argument("--output_dir", type=str, default="tuning_results",
+                         help="Directory to save tuning results")
+     parser.add_argument("--max_samples", type=int, default=30,
+                         help="Number of samples to use for tuning")
+     parser.add_argument("--num_steps", type=int, default=20,
+                         help="Number of inference steps (fixed)")
+     parser.add_argument("--dataset_type", type=str, default="pickapic",
+                         choices=["coco", "pickapic"],
+                         help="Dataset to use")
+     parser.add_argument("--model_variant", type=str, default="lpo",
+                         choices=["origin", "spo", "diffusion_dpo", "lpo"],
+                         help="Model variant to use")
+     parser.add_argument("--cuda", type=int, default=0,
+                         help="CUDA device ID")
+     parser.add_argument("--search_type", type=str, default="grid",
+                         choices=["grid", "random"],
+                         help="Type of hyperparameter search")
+     parser.add_argument("--metrics", type=str, nargs="+",
+                         default=["clip", "aesthetic", "pickscore", "hpsv2", "imagereward"],
+                         help="Metrics to evaluate")
+     parser.add_argument("--start_idx", type=int, default=0,
+                         help="Starting index for experiments (for GPU distribution)")
+     parser.add_argument("--end_idx", type=int, default=None,
+                         help="Ending index for experiments (for GPU distribution)")
+
+     args = parser.parse_args()
+
+     # Create the tuner
+     tuner = HyperparameterTuner(
+         output_dir=args.output_dir,
+         max_samples=args.max_samples,
+         num_steps=args.num_steps,
+         dataset_type=args.dataset_type,
+         model_variant=args.model_variant,
+         cuda_id=args.cuda,
+         metrics=args.metrics,
+     )
+
+     # Run the search
+     results = tuner.run_search(
+         search_type=args.search_type,
+         start_idx=args.start_idx,
+         end_idx=args.end_idx
+     )
+
+     # Analyze the results
+     best_result = tuner.analyze_results()
+
+     print("\n" + "="*80)
+     print("TUNING COMPLETE!")
+     print("="*80)
+     print(f"Total experiments: {len(results)}")
+     print(f"Results directory: {args.output_dir}")
+
+     if best_result:
+         print("\nBest configuration:")
+         print(json.dumps(best_result["config"], indent=2))
+         print(f"\nAggregate score: {best_result['aggregate_score']:.4f}")
+
+
+ if __name__ == "__main__":
+     main()
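+
+ # Usage sketch (single worker; flags as defined above), e.g.:
+ #   python tune_hyperparams.py --dataset_type pickapic --model_variant lpo \
+ #       --max_samples 30 --num_steps 20 --cuda 0 --search_type grid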
Reward_sana_idealized/tune_parallel.sh ADDED
@@ -0,0 +1,253 @@
+ #!/bin/bash
+
+ # Parallel hyperparameter tuning across 8 GPUs
+ # This script distributes experiments evenly across all available GPUs
+
+ clear
+
+ # Activate conda environment
+ source ~/miniconda3/etc/profile.d/conda.sh
+ conda activate /home/ec2-user/aev
+
+ # Configuration
+ DATASET_TYPE="pickapic"            # "coco" or "pickapic"
+ MODEL_VARIANT="lpo"                # "origin", "spo", "diffusion_dpo", or "lpo"
+ MAX_SAMPLES=500                    # Number of samples for tuning
+ NUM_STEPS=50                       # Fixed inference steps
+ SEARCH_TYPE="grid"                 # "grid" or "random"
+ OUTPUT_DIR="RESULTS_TURNING/run_2"
+ NUM_GPUS=8                         # Number of GPUs to use
+
+ echo "=============================================="
+ echo " PARALLEL HYPERPARAMETER TUNING"
+ echo "=============================================="
+ echo ""
+ echo "Configuration:"
+ echo "  Dataset: $DATASET_TYPE"
+ echo "  Model: $MODEL_VARIANT"
+ echo "  Samples: $MAX_SAMPLES"
+ echo "  Inference Steps: $NUM_STEPS"
+ echo "  Search Type: $SEARCH_TYPE"
+ echo "  GPUs: $NUM_GPUS"
+ echo "  Output: $OUTPUT_DIR"
+ echo ""
+
+ # First, calculate the total number of experiments
+ echo "Calculating total experiments..."
+ TOTAL_CONFIGS=$(python -c "
+ from tune_hyperparams import HyperparameterTuner
+ import sys
+ tuner = HyperparameterTuner()
+ configs = tuner.define_search_space()
+ sys.stderr.write(f'Generated {len(configs)} configurations\n')
+ print(len(configs))
+ " 2>&1 | tail -1)
+
+ echo "Total configurations: $TOTAL_CONFIGS"
+ echo ""
+
+ # Calculate experiments per GPU
+ CONFIGS_PER_GPU=$((TOTAL_CONFIGS / NUM_GPUS))
+ REMAINDER=$((TOTAL_CONFIGS % NUM_GPUS))
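+ # Example: 90 configs on 8 GPUs gives CONFIGS_PER_GPU=11 and REMAINDER=2,
+ # so GPUs 0-1 run 12 configs each and GPUs 2-7 run 11 each.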
+
+ echo "Distributing work:"
+ echo "  Base configs per GPU: $CONFIGS_PER_GPU"
+ echo "  Extra configs for first GPUs: $REMAINDER"
+ echo ""
+
+ # Create output directory
+ mkdir -p "$OUTPUT_DIR"
+
+ # Array to store background process IDs
+ PIDS=()
+
+ # Launch parallel processes on each GPU
+ for GPU_ID in $(seq 0 $((NUM_GPUS - 1))); do
+     # Calculate start and end indices for this GPU
+     START_IDX=$((GPU_ID * CONFIGS_PER_GPU))
+
+     # Give the extra configs to the first GPUs
+     if [ $GPU_ID -lt $REMAINDER ]; then
+         START_IDX=$((START_IDX + GPU_ID))
+         END_IDX=$((START_IDX + CONFIGS_PER_GPU + 1))
+     else
+         START_IDX=$((START_IDX + REMAINDER))
+         END_IDX=$((START_IDX + CONFIGS_PER_GPU))
+     fi
+
+     # Create a GPU-specific output directory
+     GPU_OUTPUT_DIR="${OUTPUT_DIR}/gpu_${GPU_ID}"
+     mkdir -p "$GPU_OUTPUT_DIR"
+
+     echo "GPU $GPU_ID: configs $START_IDX to $END_IDX"
+
+     # Launch the tuning process in the background
+     nohup python tune_hyperparams.py \
+         --output_dir "$GPU_OUTPUT_DIR" \
+         --max_samples $MAX_SAMPLES \
+         --num_steps $NUM_STEPS \
+         --dataset_type "$DATASET_TYPE" \
+         --model_variant "$MODEL_VARIANT" \
+         --cuda $GPU_ID \
+         --search_type "$SEARCH_TYPE" \
+         --start_idx $START_IDX \
+         --end_idx $END_IDX \
+         --metrics clip aesthetic pickscore hpsv2 imagereward \
+         > "${GPU_OUTPUT_DIR}/tuning.log" 2>&1 &
+
+     # Store the PID
+     PIDS+=($!)
+
+     echo "  Launched with PID: ${PIDS[$GPU_ID]}"
+
+     # Small delay to avoid race conditions
+     sleep 2
+ done
+
+ echo ""
+ echo "=============================================="
+ echo " ALL PROCESSES LAUNCHED"
+ echo "=============================================="
+ echo ""
+ echo "Background processes running:"
+ for GPU_ID in $(seq 0 $((NUM_GPUS - 1))); do
+     echo "  GPU $GPU_ID: PID ${PIDS[$GPU_ID]} -> ${OUTPUT_DIR}/gpu_${GPU_ID}/tuning.log"
+ done
+ echo ""
+ echo "To monitor progress:"
+ echo "  tail -f ${OUTPUT_DIR}/gpu_0/tuning.log"
+ echo "  tail -f ${OUTPUT_DIR}/gpu_1/tuning.log"
+ echo "  ... etc"
+ echo ""
+ echo "To check all GPU processes:"
+ echo "  ps aux | grep tune_hyperparams.py"
+ echo ""
+ echo "To monitor GPU usage:"
+ echo "  watch -n 1 nvidia-smi"
+ echo ""
+ echo "To kill all processes:"
+ echo "  kill ${PIDS[@]}"
+ echo ""
+ echo "Waiting for all processes to complete..."
+ echo "(Press Ctrl+C to stop waiting; the processes will continue in the background)"
+ echo ""
+
+ # Wait for all background processes
+ for PID in "${PIDS[@]}"; do
+     wait $PID
+ done
+
+ echo ""
+ echo "=============================================="
+ echo " ALL TUNING PROCESSES COMPLETE"
+ echo "=============================================="
+ echo ""
+
+ # Merge results from all GPUs
+ echo "Merging results from all GPUs..."
+
+ # Re-activate conda environment for the merge script
+ source ~/miniconda3/etc/profile.d/conda.sh
+ conda activate /home/ec2-user/aev
+
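+ # The heredoc below is quoted ('EOF'), so shell variables do not expand
+ # inside it; OUTPUT_DIR and NUM_GPUS are passed through the environment
+ # so the merge reads from the same per-run directory used above.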
+ OUTPUT_DIR="$OUTPUT_DIR" NUM_GPUS="$NUM_GPUS" python - <<'EOF'
+ import json
+ import os
+ import sys
+ from pathlib import Path
+
+ output_dir = Path(os.environ["OUTPUT_DIR"])  # the per-run results directory
+ num_gpus = int(os.environ["NUM_GPUS"])
+ all_results = []
+ baseline_result = None
+
+ # Collect results from each GPU
+ for gpu_id in range(num_gpus):
+     gpu_dir = output_dir / f"gpu_{gpu_id}"
+     results_file = gpu_dir / "tuning_results.json"
+
+     if results_file.exists():
+         with open(results_file, 'r') as f:
+             data = json.load(f)
+
+         # Get the baseline (it should be identical from every worker)
+         if baseline_result is None and "baseline" in data:
+             baseline_result = data["baseline"]
+
+         # Collect experiments
+         if "experiments" in data:
+             all_results.extend(data["experiments"])
+
+         print(f"GPU {gpu_id}: {len(data.get('experiments', []))} results")
+
+ # Merge all results
+ merged_data = {
+     "baseline": baseline_result,
+     "experiments": all_results,
+     "num_gpus": num_gpus,
+     "total_experiments": len(all_results)
+ }
+
+ # Save merged results
+ merged_file = output_dir / "merged_results.json"
+ with open(merged_file, 'w') as f:
+     json.dump(merged_data, f, indent=2)
+
+ print(f"\nMerged {len(all_results)} total results")
+ print(f"Saved to: {merged_file}")
+
+ # Find the best configuration
+ successful = [r for r in all_results if "metrics" in r]
+ if successful:
+     # Aggregate scoring, kept in sync with
+     # HyperparameterTuner.compute_aggregate_score: normalize by the total
+     # weight of only the metrics that are actually present.
+     def compute_score(metrics):
+         weights = {
+             "reward": 1.0, "clip": 0.8, "aesthetic": 0.8,
+             "pickscore": 1.0, "hpsv2": 1.0, "hpsv21": 1.0,
+             "imagereward": 1.0, "fid": -0.5
+         }
+         score = 0.0
+         total = 0.0
+         for k, v in metrics.items():
+             if k in weights:
+                 score += weights[k] * v
+                 total += abs(weights[k])
+         return score / total if total > 0 else 0.0
+
+     for r in successful:
+         r["aggregate_score"] = compute_score(r["metrics"])
+
+     successful.sort(key=lambda x: x["aggregate_score"], reverse=True)
+
+     best = successful[0]
+     best_file = output_dir / "best_config.json"
+     with open(best_file, 'w') as f:
+         json.dump({
+             "config": best["config"],
+             "metrics": best["metrics"],
+             "aggregate_score": best["aggregate_score"],
+             "improvements": best.get("improvements", {})
+         }, f, indent=2)
+
+     print(f"\n{'='*60}")
+     print("BEST CONFIGURATION:")
+     print(f"{'='*60}")
+     print(json.dumps(best["config"], indent=2))
+     print(f"\nAggregate Score: {best['aggregate_score']:.4f}")
+     print(f"Saved to: {best_file}")
+ else:
+     print("\nNo successful experiments found!")
+     sys.exit(1)
+ EOF
+
+ if [ $? -eq 0 ]; then
+     echo ""
+     echo "=============================================="
+     echo " TUNING COMPLETE!"
+     echo "=============================================="
+     echo ""
+     echo "Results:"
+     echo "  Merged results: ${OUTPUT_DIR}/merged_results.json"
+     echo "  Best config: ${OUTPUT_DIR}/best_config.json"
+     echo ""
+     echo "View the best configuration:"
+     echo "  cat ${OUTPUT_DIR}/best_config.json"
+     echo ""
+ else
+     echo ""
+     echo "ERROR: Failed to merge results"
+     exit 1
+ fi
Reward_sdxl_idealized/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (242 Bytes). View file
 
Reward_sdxl_idealized/models/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (280 Bytes). View file
 
Reward_sdxl_idealized/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (240 Bytes). View file
 
Reward_sdxl_idealized/models/__pycache__/reward_model.cpython-39.pyc ADDED
Binary file (9.1 kB). View file
 
Reward_sdxl_idealized/models/__pycache__/reward_model_sdxl.cpython-310.pyc ADDED
Binary file (9.96 kB). View file