Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Reward_sana_idealized/README.md +41 -0
- Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_1/evaluation_results.txt +4 -0
- Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_1/log.log +203 -0
- Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/evaluation_results.txt +4 -0
- Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/log.log +258 -0
- Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/lr_curve.png +0 -0
- Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/evaluation_results.txt +4 -0
- Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/log.log +218 -0
- Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/lr_curve.png +0 -0
- Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/evaluation_results.txt +4 -0
- Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/log.log +218 -0
- Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/lr_curve.png +0 -0
- Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/rewards_curve.png +0 -0
- Reward_sana_idealized/__pycache__/eval.cpython-311.pyc +0 -0
- Reward_sana_idealized/__pycache__/gradient_ascent_utils.cpython-311.pyc +0 -0
- Reward_sana_idealized/blip/__init__.py +1 -0
- Reward_sana_idealized/blip/__pycache__/__init__.cpython-311.pyc +0 -0
- Reward_sana_idealized/blip/__pycache__/blip.cpython-311.pyc +0 -0
- Reward_sana_idealized/blip/__pycache__/blip_pretrain.cpython-311.pyc +0 -0
- Reward_sana_idealized/blip/__pycache__/med.cpython-311.pyc +0 -0
- Reward_sana_idealized/blip/blip.py +70 -0
- Reward_sana_idealized/blip/blip_pretrain.py +43 -0
- Reward_sana_idealized/config_analysis_tuning.ipynb +218 -0
- Reward_sana_idealized/eval.py +1447 -0
- Reward_sana_idealized/examples.sh +162 -0
- Reward_sana_idealized/grad_ascent_configs.py +67 -0
- Reward_sana_idealized/gradient_ascent_utils.py +391 -0
- Reward_sana_idealized/hpsv2_score.py +110 -0
- Reward_sana_idealized/imagereward_score.py +221 -0
- Reward_sana_idealized/lr_scheduler.py +233 -0
- Reward_sana_idealized/models/__pycache__/__init__.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/coca_model.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/factory.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/model.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/modified_resnet.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/pretrained.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/push_to_hf_hub.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/timm_model.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/tokenizer.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/transformer.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/model_configs/convnext_xlarge.json +19 -0
- Reward_sana_idealized/pick_score.py +141 -0
- Reward_sana_idealized/test.ipynb +47 -0
- Reward_sana_idealized/tune_hyperparams.py +514 -0
- Reward_sana_idealized/tune_parallel.sh +253 -0
- Reward_sdxl_idealized/models/__pycache__/__init__.cpython-310.pyc +0 -0
- Reward_sdxl_idealized/models/__pycache__/__init__.cpython-313.pyc +0 -0
- Reward_sdxl_idealized/models/__pycache__/__init__.cpython-39.pyc +0 -0
- Reward_sdxl_idealized/models/__pycache__/reward_model.cpython-39.pyc +0 -0
- Reward_sdxl_idealized/models/__pycache__/reward_model_sdxl.cpython-310.pyc +0 -0
Reward_sana_idealized/README.md
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Reward SANA Idealized
|
| 2 |
+
|
| 3 |
+
This folder is a SANA-only reward-guided inference package.
|
| 4 |
+
|
| 5 |
+
## What is inside
|
| 6 |
+
|
| 7 |
+
- `models/reward_model.py`
|
| 8 |
+
- Local SANA reward wrapper (no trainer import from other directories).
|
| 9 |
+
- Loads base SANA diffusers modules and local reward checkpoint weights.
|
| 10 |
+
- `pipelines/sana_reward_pipeline.py`
|
| 11 |
+
- SANA pipeline with per-step reward tracking.
|
| 12 |
+
- `pipelines/sana_gradient_ascent_pipeline.py`
|
| 13 |
+
- SANA pipeline with gradient-ascent latent updates.
|
| 14 |
+
- `eval.py`
|
| 15 |
+
- End-to-end evaluation script.
|
| 16 |
+
- `examples.sh`
|
| 17 |
+
- Cluster entrypoint for prefetch and evaluation.
|
| 18 |
+
|
| 19 |
+
## Default checkpoint
|
| 20 |
+
|
| 21 |
+
`examples.sh` defaults to:
|
| 22 |
+
|
| 23 |
+
`/g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep76000`
|
| 24 |
+
|
| 25 |
+
Override with:
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
LRM_MODEL_PATH=/path/to/checkpoint-dir-or-model.safetensors
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Run (10-sample smoke test)
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
cd /g/data/rr81/LPO/Reward_sana_idealized
|
| 35 |
+
OFFLINE_MODE=1 MAX_SAMPLES=10 MODE=gradient_ascent MODEL_PROFILE=sana_600m_512 ./examples.sh
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
## Notes
|
| 39 |
+
|
| 40 |
+
- Uses existing Python env: `/g/data/rr81/aev/bin/python`.
|
| 41 |
+
- GPU nodes should run with offline HF cache.
|
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_1/evaluation_results.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mode: baseline
|
| 2 |
+
metrics: ['clip', 'aesthetic', 'pickscore', 'hpsv2', 'hpsv21', 'imagereward']
|
| 3 |
+
config: {'num_samples': 500, 'num_steps': 20, 'cfg_scale': 4.5, 'grad_range': [0, 700], 'grad_steps': 5, 'grad_step_size': 0.1}
|
| 4 |
+
baseline: {'avg_reward': np.float64(0.6854833755493164), 'clip_score': np.float64(26.60960610508919), 'aesthetic_score': np.float64(5.930574191093445), 'pickscore': np.float64(21.89574451446533), 'hpsv2_score': np.float16(0.2805), 'hpsv21_score': np.float16(0.292), 'imagereward_score': np.float64(1.001599932681769)}
|
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_1/log.log
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
======================================================================
|
| 2 |
+
FID EVALUATION: BASELINE vs GRADIENT ASCENT
|
| 3 |
+
======================================================================
|
| 4 |
+
|
| 5 |
+
Logging to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_1/log.log
|
| 6 |
+
|
| 7 |
+
Device: cuda:0
|
| 8 |
+
Dataset: PICKAPIC
|
| 9 |
+
Data directory: ./data
|
| 10 |
+
Base model: Efficient-Large-Model/Sana_600M_512px_diffusers
|
| 11 |
+
Model variant: sana_600m_512
|
| 12 |
+
LRM model: /g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep76000
|
| 13 |
+
HF cache dir: /scratch/rr81/ma5430/.cache/huggingface/hub
|
| 14 |
+
HF offline mode: True
|
| 15 |
+
Inference steps: 20
|
| 16 |
+
CFG scale: 4.5
|
| 17 |
+
Batch size: 1
|
| 18 |
+
Max samples: All
|
| 19 |
+
Output directory: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_1
|
| 20 |
+
Save images: False
|
| 21 |
+
Evaluation mode: baseline
|
| 22 |
+
Metrics to evaluate: CLIP, AESTHETIC, PICKSCORE, HPSV2, HPSV21, IMAGEREWARD
|
| 23 |
+
Gradient ascent config: one_step_rectification_config
|
| 24 |
+
|
| 25 |
+
======================================================================
|
| 26 |
+
1. LOADING VALIDATION DATA
|
| 27 |
+
======================================================================
|
| 28 |
+
Loading Pick-a-Pic validation prompts...
|
| 29 |
+
Loading cached Pick-a-Pic split 'validation_unique' from 1 parquet shards
|
| 30 |
+
cache=/scratch/rr81/ma5430/.cache/huggingface/hub/datasets--pickapic-anonymous--pickapic_v1
|
| 31 |
+
Loaded 500 Pick-a-Pic validation samples
|
| 32 |
+
|
| 33 |
+
======================================================================
|
| 34 |
+
2. LOADING REWARD MODEL
|
| 35 |
+
======================================================================
|
| 36 |
+
Loading SANA base reward backbone from Efficient-Large-Model/Sana_600M_512px_diffusers...
|
| 37 |
+
Loading SANA reward checkpoint from /g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep76000/model.safetensors...
|
| 38 |
+
✓ Loaded checkpoint keys: 1214
|
| 39 |
+
✓ Missing keys: 0 | Unexpected keys: 0
|
| 40 |
+
✓ SANA LRM Reward Model initialized successfully!
|
| 41 |
+
✓ Reward model loaded
|
| 42 |
+
|
| 43 |
+
======================================================================
|
| 44 |
+
3. LOADING PIPELINE
|
| 45 |
+
======================================================================
|
| 46 |
+
✓ Loaded SANA base model: Efficient-Large-Model/Sana_600M_512px_diffusers
|
| 47 |
+
✓ Reward model attached to SANA pipeline
|
| 48 |
+
✓ Pipeline loaded
|
| 49 |
+
GPU memory before scorer load: 126.00 GB free / 140.06 GB total
|
| 50 |
+
Scorer device: cuda:0
|
| 51 |
+
|
| 52 |
+
======================================================================
|
| 53 |
+
3.5. LOADING CLIP AND AESTHETIC SCORERS
|
| 54 |
+
======================================================================
|
| 55 |
+
✓ CLIP scorer loaded
|
| 56 |
+
✓ Aesthetic scorer loaded
|
| 57 |
+
✓ PickScore scorer loaded
|
| 58 |
+
✓ HPSv2 scorer loaded
|
| 59 |
+
✓ HPSv2.1 scorer loaded
|
| 60 |
+
load checkpoint from /scratch/rr81/ma5430/.cache/huggingface/hub/models--THUDM--ImageReward/snapshots/5736be03b2652728fb87788c9797b0570450ab72/ImageReward.pt
|
| 61 |
+
checkpoint loaded
|
| 62 |
+
✓ ImageReward scorer loaded
|
| 63 |
+
|
| 64 |
+
======================================================================
|
| 65 |
+
4. CONFIGURING GRADIENT ASCENT
|
| 66 |
+
======================================================================
|
| 67 |
+
Loading gradient ascent config: one_step_rectification_config
|
| 68 |
+
Config loaded: {'grad_timestep_range': (200, 800), 'num_grad_steps': 1, 'grad_step_size': 1.0, 'grad_scale': 1.0, 'lr_scheduler_type': 'constant', 'use_momentum': False, 'use_nesterov': False, 'use_iso_projection': False}
|
| 69 |
+
Gradient timestep range: (200, 800)
|
| 70 |
+
Gradient steps: 1
|
| 71 |
+
Gradient step size (initial LR): 1.0
|
| 72 |
+
LR Scheduler: constant
|
| 73 |
+
✓ Gradient ascent enabled for timesteps (200, 800)
|
| 74 |
+
|
| 75 |
+
======================================================================
|
| 76 |
+
5. EVALUATING BASELINE
|
| 77 |
+
======================================================================
|
| 78 |
+
|
| 79 |
+
Generating images with baseline mode...
|
| 80 |
+
|
| 81 |
+
[baseline] Batch 10/500 | Samples: 10/500 | Reward (t=136.0): 0.9946 | Reward (Avg): 0.7891 | CLIP: 25.9197 | Aesthetic: 6.1225 | PickScore: 21.9931 | HPSv2: 0.2854 | HPSv2.1: 0.3103 | ImageReward: 1.2588
|
| 82 |
+
|
| 83 |
+
[baseline] Batch 20/500 | Samples: 20/500 | Reward (t=136.0): 0.1036 | Reward (Avg): 0.7397 | CLIP: 26.1207 | Aesthetic: 6.0740 | PickScore: 22.1701 | HPSv2: 0.2849 | HPSv2.1: 0.3103 | ImageReward: 1.1364
|
| 84 |
+
|
| 85 |
+
[baseline] Batch 30/500 | Samples: 30/500 | Reward (t=136.0): 0.9990 | Reward (Avg): 0.7536 | CLIP: 26.2919 | Aesthetic: 6.0006 | PickScore: 22.3621 | HPSv2: 0.2859 | HPSv2.1: 0.3066 | ImageReward: 1.1024
|
| 86 |
+
|
| 87 |
+
[baseline] Batch 40/500 | Samples: 40/500 | Reward (t=136.0): 0.4502 | Reward (Avg): 0.7308 | CLIP: 26.6222 | Aesthetic: 6.0754 | PickScore: 22.3074 | HPSv2: 0.2844 | HPSv2.1: 0.3040 | ImageReward: 0.9845
|
| 88 |
+
|
| 89 |
+
[baseline] Batch 50/500 | Samples: 50/500 | Reward (t=136.0): 0.2013 | Reward (Avg): 0.7104 | CLIP: 26.4461 | Aesthetic: 6.0134 | PickScore: 22.1421 | HPSv2: 0.2832 | HPSv2.1: 0.3013 | ImageReward: 1.0774
|
| 90 |
+
|
| 91 |
+
[baseline] Batch 60/500 | Samples: 60/500 | Reward (t=136.0): 0.8906 | Reward (Avg): 0.7145 | CLIP: 26.3397 | Aesthetic: 6.0189 | PickScore: 22.1258 | HPSv2: 0.2837 | HPSv2.1: 0.3013 | ImageReward: 1.0667
|
| 92 |
+
|
| 93 |
+
[baseline] Batch 70/500 | Samples: 70/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.6977 | CLIP: 26.5825 | Aesthetic: 5.9899 | PickScore: 22.1627 | HPSv2: 0.2839 | HPSv2.1: 0.3013 | ImageReward: 1.0167
|
| 94 |
+
|
| 95 |
+
[baseline] Batch 80/500 | Samples: 80/500 | Reward (t=136.0): 0.9014 | Reward (Avg): 0.6814 | CLIP: 26.4912 | Aesthetic: 5.9676 | PickScore: 22.0616 | HPSv2: 0.2832 | HPSv2.1: 0.2974 | ImageReward: 0.9867
|
| 96 |
+
|
| 97 |
+
[baseline] Batch 90/500 | Samples: 90/500 | Reward (t=136.0): 0.9990 | Reward (Avg): 0.6827 | CLIP: 26.6833 | Aesthetic: 5.9604 | PickScore: 22.0380 | HPSv2: 0.2830 | HPSv2.1: 0.2976 | ImageReward: 1.0014
|
| 98 |
+
|
| 99 |
+
[baseline] Batch 100/500 | Samples: 100/500 | Reward (t=136.0): 0.0362 | Reward (Avg): 0.6914 | CLIP: 26.9976 | Aesthetic: 5.9972 | PickScore: 21.9799 | HPSv2: 0.2822 | HPSv2.1: 0.2957 | ImageReward: 1.0027
|
| 100 |
+
|
| 101 |
+
[baseline] Batch 110/500 | Samples: 110/500 | Reward (t=136.0): 0.5342 | Reward (Avg): 0.6817 | CLIP: 27.1451 | Aesthetic: 5.9820 | PickScore: 21.9669 | HPSv2: 0.2825 | HPSv2.1: 0.2959 | ImageReward: 1.0009
|
| 102 |
+
|
| 103 |
+
[baseline] Batch 120/500 | Samples: 120/500 | Reward (t=136.0): 0.0418 | Reward (Avg): 0.6660 | CLIP: 27.0372 | Aesthetic: 5.9733 | PickScore: 21.9549 | HPSv2: 0.2825 | HPSv2.1: 0.2964 | ImageReward: 1.0380
|
| 104 |
+
|
| 105 |
+
[baseline] Batch 130/500 | Samples: 130/500 | Reward (t=136.0): 0.9771 | Reward (Avg): 0.6797 | CLIP: 27.0679 | Aesthetic: 5.9902 | PickScore: 22.0210 | HPSv2: 0.2830 | HPSv2.1: 0.2974 | ImageReward: 1.0360
|
| 106 |
+
|
| 107 |
+
[baseline] Batch 140/500 | Samples: 140/500 | Reward (t=136.0): 0.9722 | Reward (Avg): 0.6796 | CLIP: 27.2356 | Aesthetic: 5.9616 | PickScore: 22.0095 | HPSv2: 0.2830 | HPSv2.1: 0.2961 | ImageReward: 1.0353
|
| 108 |
+
|
| 109 |
+
[baseline] Batch 150/500 | Samples: 150/500 | Reward (t=136.0): 0.9639 | Reward (Avg): 0.6779 | CLIP: 27.0927 | Aesthetic: 5.9419 | PickScore: 21.9896 | HPSv2: 0.2825 | HPSv2.1: 0.2952 | ImageReward: 1.0313
|
| 110 |
+
|
| 111 |
+
[baseline] Batch 160/500 | Samples: 160/500 | Reward (t=136.0): 0.8735 | Reward (Avg): 0.6787 | CLIP: 27.1935 | Aesthetic: 5.9386 | PickScore: 22.0361 | HPSv2: 0.2827 | HPSv2.1: 0.2954 | ImageReward: 1.0422
|
| 112 |
+
|
| 113 |
+
[baseline] Batch 170/500 | Samples: 170/500 | Reward (t=136.0): 0.8418 | Reward (Avg): 0.6797 | CLIP: 26.9886 | Aesthetic: 5.9230 | PickScore: 21.9763 | HPSv2: 0.2820 | HPSv2.1: 0.2939 | ImageReward: 1.0346
|
| 114 |
+
|
| 115 |
+
[baseline] Batch 180/500 | Samples: 180/500 | Reward (t=136.0): 0.3572 | Reward (Avg): 0.6742 | CLIP: 27.0903 | Aesthetic: 5.9289 | PickScore: 21.9842 | HPSv2: 0.2825 | HPSv2.1: 0.2947 | ImageReward: 1.0590
|
| 116 |
+
|
| 117 |
+
[baseline] Batch 190/500 | Samples: 190/500 | Reward (t=136.0): 0.3916 | Reward (Avg): 0.6795 | CLIP: 27.0595 | Aesthetic: 5.9303 | PickScore: 21.9667 | HPSv2: 0.2817 | HPSv2.1: 0.2937 | ImageReward: 1.0438
|
| 118 |
+
|
| 119 |
+
[baseline] Batch 200/500 | Samples: 200/500 | Reward (t=136.0): 0.3701 | Reward (Avg): 0.6735 | CLIP: 27.0636 | Aesthetic: 5.9264 | PickScore: 21.9664 | HPSv2: 0.2820 | HPSv2.1: 0.2944 | ImageReward: 1.0526
|
| 120 |
+
|
| 121 |
+
[baseline] Batch 210/500 | Samples: 210/500 | Reward (t=136.0): 0.7476 | Reward (Avg): 0.6796 | CLIP: 27.0620 | Aesthetic: 5.9352 | PickScore: 21.9646 | HPSv2: 0.2820 | HPSv2.1: 0.2949 | ImageReward: 1.0607
|
| 122 |
+
|
| 123 |
+
[baseline] Batch 220/500 | Samples: 220/500 | Reward (t=136.0): 0.9932 | Reward (Avg): 0.6813 | CLIP: 27.1314 | Aesthetic: 5.9390 | PickScore: 21.9501 | HPSv2: 0.2820 | HPSv2.1: 0.2947 | ImageReward: 1.0481
|
| 124 |
+
|
| 125 |
+
[baseline] Batch 230/500 | Samples: 230/500 | Reward (t=136.0): 0.4731 | Reward (Avg): 0.6819 | CLIP: 27.1906 | Aesthetic: 5.9441 | PickScore: 21.9595 | HPSv2: 0.2820 | HPSv2.1: 0.2944 | ImageReward: 1.0323
|
| 126 |
+
|
| 127 |
+
[baseline] Batch 240/500 | Samples: 240/500 | Reward (t=136.0): 0.2905 | Reward (Avg): 0.6844 | CLIP: 27.0801 | Aesthetic: 5.9538 | PickScore: 21.9540 | HPSv2: 0.2815 | HPSv2.1: 0.2937 | ImageReward: 1.0183
|
| 128 |
+
|
| 129 |
+
[baseline] Batch 250/500 | Samples: 250/500 | Reward (t=136.0): 0.9868 | Reward (Avg): 0.6836 | CLIP: 27.0973 | Aesthetic: 5.9652 | PickScore: 21.9579 | HPSv2: 0.2817 | HPSv2.1: 0.2939 | ImageReward: 1.0174
|
| 130 |
+
|
| 131 |
+
[baseline] Batch 260/500 | Samples: 260/500 | Reward (t=136.0): 0.6987 | Reward (Avg): 0.6825 | CLIP: 27.0369 | Aesthetic: 5.9730 | PickScore: 21.9534 | HPSv2: 0.2817 | HPSv2.1: 0.2939 | ImageReward: 1.0270
|
| 132 |
+
|
| 133 |
+
[baseline] Batch 270/500 | Samples: 270/500 | Reward (t=136.0): 1.0000 | Reward (Avg): 0.6849 | CLIP: 27.0198 | Aesthetic: 5.9743 | PickScore: 21.9475 | HPSv2: 0.2820 | HPSv2.1: 0.2939 | ImageReward: 1.0323
|
| 134 |
+
|
| 135 |
+
[baseline] Batch 280/500 | Samples: 280/500 | Reward (t=136.0): 0.9316 | Reward (Avg): 0.6886 | CLIP: 27.0667 | Aesthetic: 5.9771 | PickScore: 21.9809 | HPSv2: 0.2822 | HPSv2.1: 0.2949 | ImageReward: 1.0492
|
| 136 |
+
|
| 137 |
+
[baseline] Batch 290/500 | Samples: 290/500 | Reward (t=136.0): 0.8652 | Reward (Avg): 0.6863 | CLIP: 26.9701 | Aesthetic: 5.9660 | PickScore: 21.9452 | HPSv2: 0.2820 | HPSv2.1: 0.2939 | ImageReward: 1.0336
|
| 138 |
+
|
| 139 |
+
[baseline] Batch 300/500 | Samples: 300/500 | Reward (t=136.0): 0.9995 | Reward (Avg): 0.6888 | CLIP: 26.9522 | Aesthetic: 5.9680 | PickScore: 21.9559 | HPSv2: 0.2822 | HPSv2.1: 0.2947 | ImageReward: 1.0375
|
| 140 |
+
|
| 141 |
+
[baseline] Batch 310/500 | Samples: 310/500 | Reward (t=136.0): 0.9971 | Reward (Avg): 0.6894 | CLIP: 27.0262 | Aesthetic: 5.9725 | PickScore: 21.9700 | HPSv2: 0.2825 | HPSv2.1: 0.2949 | ImageReward: 1.0583
|
| 142 |
+
|
| 143 |
+
[baseline] Batch 320/500 | Samples: 320/500 | Reward (t=136.0): 0.8667 | Reward (Avg): 0.6890 | CLIP: 27.0493 | Aesthetic: 5.9674 | PickScore: 21.9830 | HPSv2: 0.2825 | HPSv2.1: 0.2947 | ImageReward: 1.0534
|
| 144 |
+
|
| 145 |
+
[baseline] Batch 330/500 | Samples: 330/500 | Reward (t=136.0): 0.9683 | Reward (Avg): 0.6922 | CLIP: 27.0304 | Aesthetic: 5.9733 | PickScore: 21.9784 | HPSv2: 0.2822 | HPSv2.1: 0.2947 | ImageReward: 1.0603
|
| 146 |
+
|
| 147 |
+
[baseline] Batch 340/500 | Samples: 340/500 | Reward (t=136.0): 0.8975 | Reward (Avg): 0.6945 | CLIP: 27.0049 | Aesthetic: 5.9703 | PickScore: 21.9790 | HPSv2: 0.2822 | HPSv2.1: 0.2949 | ImageReward: 1.0708
|
| 148 |
+
|
| 149 |
+
[baseline] Batch 350/500 | Samples: 350/500 | Reward (t=136.0): 0.0694 | Reward (Avg): 0.6900 | CLIP: 27.0073 | Aesthetic: 5.9688 | PickScore: 21.9897 | HPSv2: 0.2822 | HPSv2.1: 0.2949 | ImageReward: 1.0626
|
| 150 |
+
|
| 151 |
+
[baseline] Batch 360/500 | Samples: 360/500 | Reward (t=136.0): 0.9307 | Reward (Avg): 0.6921 | CLIP: 27.0431 | Aesthetic: 5.9667 | PickScore: 21.9860 | HPSv2: 0.2825 | HPSv2.1: 0.2952 | ImageReward: 1.0531
|
| 152 |
+
|
| 153 |
+
[baseline] Batch 370/500 | Samples: 370/500 | Reward (t=136.0): 0.9175 | Reward (Avg): 0.6917 | CLIP: 26.9788 | Aesthetic: 5.9615 | PickScore: 21.9723 | HPSv2: 0.2822 | HPSv2.1: 0.2949 | ImageReward: 1.0486
|
| 154 |
+
|
| 155 |
+
[baseline] Batch 380/500 | Samples: 380/500 | Reward (t=136.0): 0.3616 | Reward (Avg): 0.6916 | CLIP: 27.0571 | Aesthetic: 5.9690 | PickScore: 21.9754 | HPSv2: 0.2822 | HPSv2.1: 0.2952 | ImageReward: 1.0540
|
| 156 |
+
|
| 157 |
+
[baseline] Batch 390/500 | Samples: 390/500 | Reward (t=136.0): 0.9912 | Reward (Avg): 0.6914 | CLIP: 26.9386 | Aesthetic: 5.9658 | PickScore: 21.9601 | HPSv2: 0.2820 | HPSv2.1: 0.2944 | ImageReward: 1.0402
|
| 158 |
+
|
| 159 |
+
[baseline] Batch 400/500 | Samples: 400/500 | Reward (t=136.0): 0.0252 | Reward (Avg): 0.6910 | CLIP: 26.8978 | Aesthetic: 5.9574 | PickScore: 21.9578 | HPSv2: 0.2820 | HPSv2.1: 0.2942 | ImageReward: 1.0424
|
| 160 |
+
|
| 161 |
+
[baseline] Batch 410/500 | Samples: 410/500 | Reward (t=136.0): 0.2007 | Reward (Avg): 0.6909 | CLIP: 26.8640 | Aesthetic: 5.9528 | PickScore: 21.9488 | HPSv2: 0.2815 | HPSv2.1: 0.2937 | ImageReward: 1.0236
|
| 162 |
+
|
| 163 |
+
[baseline] Batch 420/500 | Samples: 420/500 | Reward (t=136.0): 0.9917 | Reward (Avg): 0.6889 | CLIP: 26.8175 | Aesthetic: 5.9515 | PickScore: 21.9370 | HPSv2: 0.2815 | HPSv2.1: 0.2935 | ImageReward: 1.0200
|
| 164 |
+
|
| 165 |
+
[baseline] Batch 430/500 | Samples: 430/500 | Reward (t=136.0): 0.2085 | Reward (Avg): 0.6878 | CLIP: 26.8390 | Aesthetic: 5.9464 | PickScore: 21.9401 | HPSv2: 0.2815 | HPSv2.1: 0.2935 | ImageReward: 1.0237
|
| 166 |
+
|
| 167 |
+
[baseline] Batch 440/500 | Samples: 440/500 | Reward (t=136.0): 0.7144 | Reward (Avg): 0.6867 | CLIP: 26.7663 | Aesthetic: 5.9488 | PickScore: 21.9390 | HPSv2: 0.2815 | HPSv2.1: 0.2935 | ImageReward: 1.0194
|
| 168 |
+
|
| 169 |
+
[baseline] Batch 450/500 | Samples: 450/500 | Reward (t=136.0): 0.8086 | Reward (Avg): 0.6877 | CLIP: 26.7831 | Aesthetic: 5.9422 | PickScore: 21.9401 | HPSv2: 0.2812 | HPSv2.1: 0.2932 | ImageReward: 1.0113
|
| 170 |
+
|
| 171 |
+
[baseline] Batch 460/500 | Samples: 460/500 | Reward (t=136.0): 0.6558 | Reward (Avg): 0.6879 | CLIP: 26.7559 | Aesthetic: 5.9414 | PickScore: 21.9357 | HPSv2: 0.2810 | HPSv2.1: 0.2927 | ImageReward: 1.0081
|
| 172 |
+
|
| 173 |
+
[baseline] Batch 470/500 | Samples: 470/500 | Reward (t=136.0): 0.6001 | Reward (Avg): 0.6873 | CLIP: 26.6761 | Aesthetic: 5.9320 | PickScore: 21.9162 | HPSv2: 0.2808 | HPSv2.1: 0.2920 | ImageReward: 0.9957
|
| 174 |
+
|
| 175 |
+
[baseline] Batch 480/500 | Samples: 480/500 | Reward (t=136.0): 0.7988 | Reward (Avg): 0.6846 | CLIP: 26.6003 | Aesthetic: 5.9285 | PickScore: 21.9059 | HPSv2: 0.2805 | HPSv2.1: 0.2920 | ImageReward: 0.9969
|
| 176 |
+
|
| 177 |
+
[baseline] Batch 490/500 | Samples: 490/500 | Reward (t=136.0): 0.8433 | Reward (Avg): 0.6850 | CLIP: 26.6023 | Aesthetic: 5.9283 | PickScore: 21.8971 | HPSv2: 0.2805 | HPSv2.1: 0.2915 | ImageReward: 0.9980
|
| 178 |
+
|
| 179 |
+
[baseline] Batch 500/500 | Samples: 500/500 | Reward (t=136.0): 0.9902 | Reward (Avg): 0.6855 | CLIP: 26.6096 | Aesthetic: 5.9306 | PickScore: 21.8957 | HPSv2: 0.2805 | HPSv2.1: 0.2920 | ImageReward: 1.0016
|
| 180 |
+
✓ Baseline Avg Reward: 0.6855
|
| 181 |
+
✓ Baseline Avg CLIP Score: 26.6096
|
| 182 |
+
✓ Baseline Avg Aesthetic Score: 5.9306
|
| 183 |
+
✓ Baseline Avg PickScore: 21.8957
|
| 184 |
+
✓ Baseline Avg HPSv2 Score: 0.2805
|
| 185 |
+
✓ Baseline Avg HPSv2.1 Score: 0.2920
|
| 186 |
+
✓ Baseline Avg ImageReward: 1.0016
|
| 187 |
+
|
| 188 |
+
======================================================================
|
| 189 |
+
FINAL RESULTS
|
| 190 |
+
======================================================================
|
| 191 |
+
|
| 192 |
+
Baseline:
|
| 193 |
+
Avg Reward: 0.6855
|
| 194 |
+
Avg CLIP Score: 26.6096
|
| 195 |
+
Avg Aesthetic: 5.9306
|
| 196 |
+
Avg PickScore: 21.8957
|
| 197 |
+
Avg HPSv2: 0.2805
|
| 198 |
+
Avg HPSv2.1: 0.2920
|
| 199 |
+
Avg ImageReward: 1.0016
|
| 200 |
+
|
| 201 |
+
✓ Results saved to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_1/evaluation_results.txt
|
| 202 |
+
|
| 203 |
+
======================================================================
|
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/evaluation_results.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mode: gradient_ascent
|
| 2 |
+
metrics: ['clip', 'aesthetic', 'pickscore', 'hpsv2', 'hpsv21', 'imagereward']
|
| 3 |
+
config: {'num_samples': 500, 'num_steps': 20, 'cfg_scale': 4.5, 'grad_range': [0, 700], 'grad_steps': 5, 'grad_step_size': 0.1}
|
| 4 |
+
gradient_ascent: {'avg_reward': np.float64(0.910064338684082), 'clip_score': np.float64(26.665978475391864), 'aesthetic_score': np.float64(5.9369088306427), 'pickscore': np.float64(21.87682523727417), 'hpsv2_score': np.float16(0.28), 'hpsv21_score': np.float16(0.2903), 'imagereward_score': np.float64(0.9915356585062655), 'stats': {'num_applications': 10, 'total_reward_improvement': 0.00537109375, 'avg_reward_improvement': 0.000537109375, 'avg_grad_norm': 0.016338474582880735, 'max_grad_norm': 0.022640923038125038, 'detailed_stats': [{'timestep': 785, 'initial_reward': 0.986328125, 'final_reward': 0.9873046875, 'reward_improvement': 0.0009765625, 'grad_norms': [0.022640923038125038], 'reward_history': [0.986328125, 0.986328125], 'lr_history': [1.0], 'latent_change': 0.9999995827674866}, {'timestep': 749, 'initial_reward': 0.98779296875, 'final_reward': 0.98828125, 'reward_improvement': 0.00048828125, 'grad_norms': [0.021537061780691147], 'reward_history': [0.98779296875, 0.98779296875], 'lr_history': [1.0], 'latent_change': 0.9999995827674866}, {'timestep': 710, 'initial_reward': 0.98828125, 'final_reward': 0.98876953125, 'reward_improvement': 0.00048828125, 'grad_norms': [0.018791837617754936], 'reward_history': [0.98828125, 0.98828125], 'lr_history': [1.0], 'latent_change': 0.9999994039535522}, {'timestep': 666, 'initial_reward': 0.98876953125, 'final_reward': 0.990234375, 'reward_improvement': 0.00146484375, 'grad_norms': [0.020797280594706535], 'reward_history': [0.98876953125, 0.98876953125], 'lr_history': [1.0], 'latent_change': 0.9999995231628418}, {'timestep': 617, 'initial_reward': 0.98974609375, 'final_reward': 0.98974609375, 'reward_improvement': 0.0, 'grad_norms': [0.021422632038593292], 'reward_history': [0.98974609375, 0.98974609375], 'lr_history': [1.0], 'latent_change': 0.9999995231628418}, {'timestep': 562, 'initial_reward': 0.98974609375, 'final_reward': 0.990234375, 'reward_improvement': 0.00048828125, 'grad_norms': [0.014982366934418678], 
'reward_history': [0.98974609375, 0.98974609375], 'lr_history': [1.0], 'latent_change': 0.9999992251396179}, {'timestep': 499, 'initial_reward': 0.990234375, 'final_reward': 0.99072265625, 'reward_improvement': 0.00048828125, 'grad_norms': [0.012734953314065933], 'reward_history': [0.990234375, 0.990234375], 'lr_history': [1.0], 'latent_change': 0.9999991655349731}, {'timestep': 428, 'initial_reward': 0.99072265625, 'final_reward': 0.9912109375, 'reward_improvement': 0.00048828125, 'grad_norms': [0.011904634535312653], 'reward_history': [0.99072265625, 0.99072265625], 'lr_history': [1.0], 'latent_change': 0.9999989867210388}, {'timestep': 345, 'initial_reward': 0.99169921875, 'final_reward': 0.99169921875, 'reward_improvement': 0.0, 'grad_norms': [0.009561818093061447], 'reward_history': [0.99169921875, 0.99169921875], 'lr_history': [1.0], 'latent_change': 0.9999989867210388}, {'timestep': 249, 'initial_reward': 0.99169921875, 'final_reward': 0.9921875, 'reward_improvement': 0.00048828125, 'grad_norms': [0.009011237882077694], 'reward_history': [0.99169921875, 0.99169921875], 'lr_history': [1.0], 'latent_change': 0.9999988675117493}]}}
|
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/log.log
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
======================================================================
|
| 2 |
+
FID EVALUATION: BASELINE vs GRADIENT ASCENT
|
| 3 |
+
======================================================================
|
| 4 |
+
|
| 5 |
+
Logging to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/log.log
|
| 6 |
+
|
| 7 |
+
Device: cuda:0
|
| 8 |
+
Dataset: PICKAPIC
|
| 9 |
+
Data directory: ./data
|
| 10 |
+
Base model: Efficient-Large-Model/Sana_600M_512px_diffusers
|
| 11 |
+
Model variant: sana_600m_512
|
| 12 |
+
LRM model: /g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep76000
|
| 13 |
+
HF cache dir: /scratch/rr81/ma5430/.cache/huggingface/hub
|
| 14 |
+
HF offline mode: True
|
| 15 |
+
Inference steps: 20
|
| 16 |
+
CFG scale: 4.5
|
| 17 |
+
Batch size: 1
|
| 18 |
+
Max samples: All
|
| 19 |
+
Output directory: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2
|
| 20 |
+
Save images: False
|
| 21 |
+
Evaluation mode: gradient_ascent
|
| 22 |
+
Metrics to evaluate: CLIP, AESTHETIC, PICKSCORE, HPSV2, HPSV21, IMAGEREWARD
|
| 23 |
+
Gradient ascent config: one_step_rectification_config
|
| 24 |
+
|
| 25 |
+
======================================================================
|
| 26 |
+
1. LOADING VALIDATION DATA
|
| 27 |
+
======================================================================
|
| 28 |
+
Loading Pick-a-Pic validation prompts...
|
| 29 |
+
Loading cached Pick-a-Pic split 'validation_unique' from 1 parquet shards
|
| 30 |
+
cache=/scratch/rr81/ma5430/.cache/huggingface/hub/datasets--pickapic-anonymous--pickapic_v1
|
| 31 |
+
Loaded 500 Pick-a-Pic validation samples
|
| 32 |
+
|
| 33 |
+
======================================================================
|
| 34 |
+
2. LOADING REWARD MODEL
|
| 35 |
+
======================================================================
|
| 36 |
+
Loading SANA base reward backbone from Efficient-Large-Model/Sana_600M_512px_diffusers...
|
| 37 |
+
Loading SANA reward checkpoint from /g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep76000/model.safetensors...
|
| 38 |
+
✓ Loaded checkpoint keys: 1214
|
| 39 |
+
✓ Missing keys: 0 | Unexpected keys: 0
|
| 40 |
+
✓ SANA LRM Reward Model initialized successfully!
|
| 41 |
+
✓ Reward model loaded
|
| 42 |
+
|
| 43 |
+
======================================================================
|
| 44 |
+
3. LOADING PIPELINE
|
| 45 |
+
======================================================================
|
| 46 |
+
✓ Loaded SANA base model: Efficient-Large-Model/Sana_600M_512px_diffusers
|
| 47 |
+
✓ Reward model attached to SANA pipeline
|
| 48 |
+
✓ Pipeline loaded
|
| 49 |
+
GPU memory before scorer load: 93.09 GB free / 140.06 GB total
|
| 50 |
+
Scorer device: cuda:0
|
| 51 |
+
|
| 52 |
+
======================================================================
|
| 53 |
+
3.5. LOADING CLIP AND AESTHETIC SCORERS
|
| 54 |
+
======================================================================
|
| 55 |
+
✓ CLIP scorer loaded
|
| 56 |
+
✓ Aesthetic scorer loaded
|
| 57 |
+
✓ PickScore scorer loaded
|
| 58 |
+
✓ HPSv2 scorer loaded
|
| 59 |
+
✓ HPSv2.1 scorer loaded
|
| 60 |
+
load checkpoint from /scratch/rr81/ma5430/.cache/huggingface/hub/models--THUDM--ImageReward/snapshots/5736be03b2652728fb87788c9797b0570450ab72/ImageReward.pt
|
| 61 |
+
checkpoint loaded
|
| 62 |
+
✓ ImageReward scorer loaded
|
| 63 |
+
|
| 64 |
+
======================================================================
|
| 65 |
+
4. CONFIGURING GRADIENT ASCENT
|
| 66 |
+
======================================================================
|
| 67 |
+
Loading gradient ascent config: one_step_rectification_config
|
| 68 |
+
Config loaded: {'grad_timestep_range': (200, 800), 'num_grad_steps': 1, 'grad_step_size': 1.0, 'grad_scale': 1.0, 'lr_scheduler_type': 'constant', 'use_momentum': False, 'use_nesterov': False, 'use_iso_projection': False}
|
| 69 |
+
Gradient timestep range: (200, 800)
|
| 70 |
+
Gradient steps: 1
|
| 71 |
+
Gradient step size (initial LR): 1.0
|
| 72 |
+
LR Scheduler: constant
|
| 73 |
+
✓ Gradient ascent enabled for timesteps (200, 800)
|
| 74 |
+
|
| 75 |
+
======================================================================
|
| 76 |
+
6. EVALUATING GRADIENT ASCENT
|
| 77 |
+
======================================================================
|
| 78 |
+
|
| 79 |
+
Generating images with gradient_ascent mode...
|
| 80 |
+
|
| 81 |
+
[gradient_ascent] Batch 10/500 | Samples: 10/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9693 | CLIP: 26.4394 | Aesthetic: 6.0884 | PickScore: 21.9325 | HPSv2: 0.2861 | HPSv2.1: 0.3071 | ImageReward: 1.2097
|
| 82 |
+
|
| 83 |
+
[gradient_ascent] Batch 20/500 | Samples: 20/500 | Reward (t=136.0): 0.4675 | Reward (Avg): 0.9257 | CLIP: 26.3258 | Aesthetic: 6.0671 | PickScore: 22.1625 | HPSv2: 0.2861 | HPSv2.1: 0.3086 | ImageReward: 1.1082
|
| 84 |
+
|
| 85 |
+
[gradient_ascent] Batch 30/500 | Samples: 30/500 | Reward (t=136.0): 0.9980 | Reward (Avg): 0.9233 | CLIP: 26.4155 | Aesthetic: 5.9805 | PickScore: 22.3315 | HPSv2: 0.2861 | HPSv2.1: 0.3057 | ImageReward: 1.0953
|
| 86 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 87 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 88 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 89 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 90 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 91 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 92 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 93 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 94 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 95 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 96 |
+
|
| 97 |
+
[gradient_ascent] Batch 40/500 | Samples: 40/500 | Reward (t=136.0): 0.9722 | Reward (Avg): 0.9284 | CLIP: 26.8101 | Aesthetic: 6.0611 | PickScore: 22.3091 | HPSv2: 0.2849 | HPSv2.1: 0.3037 | ImageReward: 0.9902
|
| 98 |
+
|
| 99 |
+
[gradient_ascent] Batch 50/500 | Samples: 50/500 | Reward (t=136.0): 0.6968 | Reward (Avg): 0.9302 | CLIP: 26.6150 | Aesthetic: 5.9952 | PickScore: 22.1516 | HPSv2: 0.2832 | HPSv2.1: 0.3005 | ImageReward: 1.0679
|
| 100 |
+
|
| 101 |
+
[gradient_ascent] Batch 60/500 | Samples: 60/500 | Reward (t=136.0): 0.9868 | Reward (Avg): 0.9303 | CLIP: 26.4753 | Aesthetic: 6.0003 | PickScore: 22.1196 | HPSv2: 0.2837 | HPSv2.1: 0.3000 | ImageReward: 1.0506
|
| 102 |
+
|
| 103 |
+
[gradient_ascent] Batch 70/500 | Samples: 70/500 | Reward (t=136.0): 0.9966 | Reward (Avg): 0.9293 | CLIP: 26.7216 | Aesthetic: 5.9899 | PickScore: 22.1551 | HPSv2: 0.2842 | HPSv2.1: 0.3005 | ImageReward: 0.9863
|
| 104 |
+
|
| 105 |
+
[gradient_ascent] Batch 80/500 | Samples: 80/500 | Reward (t=136.0): 0.9819 | Reward (Avg): 0.9147 | CLIP: 26.6154 | Aesthetic: 5.9622 | PickScore: 22.0316 | HPSv2: 0.2830 | HPSv2.1: 0.2959 | ImageReward: 0.9452
|
| 106 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 107 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 108 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 109 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 110 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 111 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 112 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 113 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 114 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 115 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 116 |
+
|
| 117 |
+
[gradient_ascent] Batch 90/500 | Samples: 90/500 | Reward (t=136.0): 0.9990 | Reward (Avg): 0.9137 | CLIP: 26.7968 | Aesthetic: 5.9589 | PickScore: 22.0089 | HPSv2: 0.2825 | HPSv2.1: 0.2959 | ImageReward: 0.9702
|
| 118 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 119 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 120 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 121 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 122 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 123 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 124 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 125 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 126 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 127 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 128 |
+
|
| 129 |
+
[gradient_ascent] Batch 100/500 | Samples: 100/500 | Reward (t=136.0): 0.7866 | Reward (Avg): 0.9190 | CLIP: 27.0602 | Aesthetic: 6.0023 | PickScore: 21.9608 | HPSv2: 0.2817 | HPSv2.1: 0.2942 | ImageReward: 0.9681
|
| 130 |
+
|
| 131 |
+
[gradient_ascent] Batch 110/500 | Samples: 110/500 | Reward (t=136.0): 0.9692 | Reward (Avg): 0.9170 | CLIP: 27.2418 | Aesthetic: 5.9877 | PickScore: 21.9353 | HPSv2: 0.2820 | HPSv2.1: 0.2939 | ImageReward: 0.9614
|
| 132 |
+
|
| 133 |
+
[gradient_ascent] Batch 120/500 | Samples: 120/500 | Reward (t=136.0): 0.8379 | Reward (Avg): 0.9146 | CLIP: 27.1626 | Aesthetic: 5.9762 | PickScore: 21.9197 | HPSv2: 0.2820 | HPSv2.1: 0.2942 | ImageReward: 0.9982
|
| 134 |
+
|
| 135 |
+
[gradient_ascent] Batch 130/500 | Samples: 130/500 | Reward (t=136.0): 0.9858 | Reward (Avg): 0.9196 | CLIP: 27.2047 | Aesthetic: 5.9938 | PickScore: 21.9912 | HPSv2: 0.2825 | HPSv2.1: 0.2952 | ImageReward: 0.9948
|
| 136 |
+
|
| 137 |
+
[gradient_ascent] Batch 140/500 | Samples: 140/500 | Reward (t=136.0): 0.9912 | Reward (Avg): 0.9157 | CLIP: 27.3311 | Aesthetic: 5.9660 | PickScore: 21.9761 | HPSv2: 0.2822 | HPSv2.1: 0.2944 | ImageReward: 0.9935
|
| 138 |
+
|
| 139 |
+
[gradient_ascent] Batch 150/500 | Samples: 150/500 | Reward (t=136.0): 0.9932 | Reward (Avg): 0.9114 | CLIP: 27.1722 | Aesthetic: 5.9463 | PickScore: 21.9579 | HPSv2: 0.2817 | HPSv2.1: 0.2930 | ImageReward: 0.9898
|
| 140 |
+
|
| 141 |
+
[gradient_ascent] Batch 160/500 | Samples: 160/500 | Reward (t=136.0): 0.9814 | Reward (Avg): 0.9088 | CLIP: 27.2961 | Aesthetic: 5.9422 | PickScore: 21.9990 | HPSv2: 0.2820 | HPSv2.1: 0.2932 | ImageReward: 1.0014
|
| 142 |
+
|
| 143 |
+
[gradient_ascent] Batch 170/500 | Samples: 170/500 | Reward (t=136.0): 0.9531 | Reward (Avg): 0.9071 | CLIP: 27.0807 | Aesthetic: 5.9318 | PickScore: 21.9435 | HPSv2: 0.2815 | HPSv2.1: 0.2917 | ImageReward: 0.9945
|
| 144 |
+
|
| 145 |
+
[gradient_ascent] Batch 180/500 | Samples: 180/500 | Reward (t=136.0): 0.8799 | Reward (Avg): 0.9030 | CLIP: 27.1614 | Aesthetic: 5.9334 | PickScore: 21.9554 | HPSv2: 0.2817 | HPSv2.1: 0.2927 | ImageReward: 1.0219
|
| 146 |
+
|
| 147 |
+
[gradient_ascent] Batch 190/500 | Samples: 190/500 | Reward (t=136.0): 0.8799 | Reward (Avg): 0.9062 | CLIP: 27.1369 | Aesthetic: 5.9368 | PickScore: 21.9370 | HPSv2: 0.2812 | HPSv2.1: 0.2920 | ImageReward: 1.0060
|
| 148 |
+
|
| 149 |
+
[gradient_ascent] Batch 200/500 | Samples: 200/500 | Reward (t=136.0): 0.9653 | Reward (Avg): 0.9037 | CLIP: 27.1264 | Aesthetic: 5.9329 | PickScore: 21.9405 | HPSv2: 0.2815 | HPSv2.1: 0.2925 | ImageReward: 1.0175
|
| 150 |
+
|
| 151 |
+
[gradient_ascent] Batch 210/500 | Samples: 210/500 | Reward (t=136.0): 0.9897 | Reward (Avg): 0.9063 | CLIP: 27.1528 | Aesthetic: 5.9437 | PickScore: 21.9455 | HPSv2: 0.2815 | HPSv2.1: 0.2932 | ImageReward: 1.0271
|
| 152 |
+
|
| 153 |
+
[gradient_ascent] Batch 220/500 | Samples: 220/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9069 | CLIP: 27.2116 | Aesthetic: 5.9480 | PickScore: 21.9343 | HPSv2: 0.2815 | HPSv2.1: 0.2932 | ImageReward: 1.0189
|
| 154 |
+
|
| 155 |
+
[gradient_ascent] Batch 230/500 | Samples: 230/500 | Reward (t=136.0): 0.9409 | Reward (Avg): 0.9084 | CLIP: 27.2825 | Aesthetic: 5.9535 | PickScore: 21.9417 | HPSv2: 0.2815 | HPSv2.1: 0.2927 | ImageReward: 1.0007
|
| 156 |
+
|
| 157 |
+
[gradient_ascent] Batch 240/500 | Samples: 240/500 | Reward (t=136.0): 0.8716 | Reward (Avg): 0.9101 | CLIP: 27.1896 | Aesthetic: 5.9654 | PickScore: 21.9351 | HPSv2: 0.2810 | HPSv2.1: 0.2920 | ImageReward: 0.9896
|
| 158 |
+
|
| 159 |
+
[gradient_ascent] Batch 250/500 | Samples: 250/500 | Reward (t=136.0): 0.9966 | Reward (Avg): 0.9120 | CLIP: 27.2003 | Aesthetic: 5.9738 | PickScore: 21.9403 | HPSv2: 0.2812 | HPSv2.1: 0.2922 | ImageReward: 0.9901
|
| 160 |
+
|
| 161 |
+
[gradient_ascent] Batch 260/500 | Samples: 260/500 | Reward (t=136.0): 0.8730 | Reward (Avg): 0.9100 | CLIP: 27.1515 | Aesthetic: 5.9809 | PickScore: 21.9393 | HPSv2: 0.2812 | HPSv2.1: 0.2922 | ImageReward: 1.0035
|
| 162 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 163 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 164 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 165 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 166 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 167 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 168 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 169 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 170 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 171 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 172 |
+
|
| 173 |
+
[gradient_ascent] Batch 270/500 | Samples: 270/500 | Reward (t=136.0): 1.0000 | Reward (Avg): 0.9094 | CLIP: 27.1001 | Aesthetic: 5.9807 | PickScore: 21.9312 | HPSv2: 0.2812 | HPSv2.1: 0.2922 | ImageReward: 1.0090
|
| 174 |
+
|
| 175 |
+
[gradient_ascent] Batch 280/500 | Samples: 280/500 | Reward (t=136.0): 0.9683 | Reward (Avg): 0.9115 | CLIP: 27.1456 | Aesthetic: 5.9814 | PickScore: 21.9610 | HPSv2: 0.2817 | HPSv2.1: 0.2932 | ImageReward: 1.0251
|
| 176 |
+
|
| 177 |
+
[gradient_ascent] Batch 290/500 | Samples: 290/500 | Reward (t=136.0): 0.9956 | Reward (Avg): 0.9111 | CLIP: 27.0581 | Aesthetic: 5.9709 | PickScore: 21.9298 | HPSv2: 0.2812 | HPSv2.1: 0.2925 | ImageReward: 1.0123
|
| 178 |
+
?? WARNING: Gradient exists but max value is 0.0
|
| 179 |
+
|
| 180 |
+
[gradient_ascent] Batch 300/500 | Samples: 300/500 | Reward (t=136.0): 0.9995 | Reward (Avg): 0.9110 | CLIP: 27.0435 | Aesthetic: 5.9736 | PickScore: 21.9395 | HPSv2: 0.2817 | HPSv2.1: 0.2930 | ImageReward: 1.0170
|
| 181 |
+
|
| 182 |
+
[gradient_ascent] Batch 310/500 | Samples: 310/500 | Reward (t=136.0): 0.9980 | Reward (Avg): 0.9117 | CLIP: 27.0912 | Aesthetic: 5.9763 | PickScore: 21.9508 | HPSv2: 0.2817 | HPSv2.1: 0.2932 | ImageReward: 1.0371
|
| 183 |
+
|
| 184 |
+
[gradient_ascent] Batch 320/500 | Samples: 320/500 | Reward (t=136.0): 0.9775 | Reward (Avg): 0.9116 | CLIP: 27.1126 | Aesthetic: 5.9707 | PickScore: 21.9669 | HPSv2: 0.2817 | HPSv2.1: 0.2930 | ImageReward: 1.0342
|
| 185 |
+
|
| 186 |
+
[gradient_ascent] Batch 330/500 | Samples: 330/500 | Reward (t=136.0): 0.9941 | Reward (Avg): 0.9121 | CLIP: 27.1159 | Aesthetic: 5.9762 | PickScore: 21.9632 | HPSv2: 0.2817 | HPSv2.1: 0.2932 | ImageReward: 1.0420
|
| 187 |
+
|
| 188 |
+
[gradient_ascent] Batch 340/500 | Samples: 340/500 | Reward (t=136.0): 0.9785 | Reward (Avg): 0.9129 | CLIP: 27.1023 | Aesthetic: 5.9739 | PickScore: 21.9643 | HPSv2: 0.2817 | HPSv2.1: 0.2935 | ImageReward: 1.0503
|
| 189 |
+
|
| 190 |
+
[gradient_ascent] Batch 350/500 | Samples: 350/500 | Reward (t=136.0): 0.5396 | Reward (Avg): 0.9110 | CLIP: 27.1031 | Aesthetic: 5.9720 | PickScore: 21.9747 | HPSv2: 0.2817 | HPSv2.1: 0.2935 | ImageReward: 1.0437
|
| 191 |
+
|
| 192 |
+
[gradient_ascent] Batch 360/500 | Samples: 360/500 | Reward (t=136.0): 0.9805 | Reward (Avg): 0.9119 | CLIP: 27.1421 | Aesthetic: 5.9696 | PickScore: 21.9725 | HPSv2: 0.2817 | HPSv2.1: 0.2937 | ImageReward: 1.0372
|
| 193 |
+
|
| 194 |
+
[gradient_ascent] Batch 370/500 | Samples: 370/500 | Reward (t=136.0): 0.9692 | Reward (Avg): 0.9117 | CLIP: 27.0698 | Aesthetic: 5.9642 | PickScore: 21.9583 | HPSv2: 0.2817 | HPSv2.1: 0.2935 | ImageReward: 1.0346
|
| 195 |
+
|
| 196 |
+
[gradient_ascent] Batch 380/500 | Samples: 380/500 | Reward (t=136.0): 0.7930 | Reward (Avg): 0.9110 | CLIP: 27.1452 | Aesthetic: 5.9731 | PickScore: 21.9609 | HPSv2: 0.2817 | HPSv2.1: 0.2937 | ImageReward: 1.0401
|
| 197 |
+
|
| 198 |
+
[gradient_ascent] Batch 390/500 | Samples: 390/500 | Reward (t=136.0): 0.9932 | Reward (Avg): 0.9114 | CLIP: 27.0228 | Aesthetic: 5.9694 | PickScore: 21.9430 | HPSv2: 0.2812 | HPSv2.1: 0.2930 | ImageReward: 1.0260
|
| 199 |
+
|
| 200 |
+
[gradient_ascent] Batch 400/500 | Samples: 400/500 | Reward (t=136.0): 0.4897 | Reward (Avg): 0.9117 | CLIP: 26.9789 | Aesthetic: 5.9603 | PickScore: 21.9398 | HPSv2: 0.2812 | HPSv2.1: 0.2927 | ImageReward: 1.0275
|
| 201 |
+
|
| 202 |
+
[gradient_ascent] Batch 410/500 | Samples: 410/500 | Reward (t=136.0): 0.8018 | Reward (Avg): 0.9123 | CLIP: 26.9305 | Aesthetic: 5.9569 | PickScore: 21.9279 | HPSv2: 0.2810 | HPSv2.1: 0.2922 | ImageReward: 1.0088
|
| 203 |
+
|
| 204 |
+
[gradient_ascent] Batch 420/500 | Samples: 420/500 | Reward (t=136.0): 0.9966 | Reward (Avg): 0.9113 | CLIP: 26.8775 | Aesthetic: 5.9550 | PickScore: 21.9134 | HPSv2: 0.2808 | HPSv2.1: 0.2920 | ImageReward: 1.0062
|
| 205 |
+
|
| 206 |
+
[gradient_ascent] Batch 430/500 | Samples: 430/500 | Reward (t=136.0): 0.7876 | Reward (Avg): 0.9102 | CLIP: 26.8827 | Aesthetic: 5.9510 | PickScore: 21.9192 | HPSv2: 0.2808 | HPSv2.1: 0.2920 | ImageReward: 1.0116
|
| 207 |
+
|
| 208 |
+
[gradient_ascent] Batch 440/500 | Samples: 440/500 | Reward (t=136.0): 0.9561 | Reward (Avg): 0.9096 | CLIP: 26.8088 | Aesthetic: 5.9552 | PickScore: 21.9176 | HPSv2: 0.2808 | HPSv2.1: 0.2920 | ImageReward: 1.0110
|
| 209 |
+
|
| 210 |
+
[gradient_ascent] Batch 450/500 | Samples: 450/500 | Reward (t=136.0): 0.9443 | Reward (Avg): 0.9086 | CLIP: 26.8014 | Aesthetic: 5.9476 | PickScore: 21.9168 | HPSv2: 0.2805 | HPSv2.1: 0.2915 | ImageReward: 1.0002
|
| 211 |
+
|
| 212 |
+
[gradient_ascent] Batch 460/500 | Samples: 460/500 | Reward (t=136.0): 0.9888 | Reward (Avg): 0.9095 | CLIP: 26.7852 | Aesthetic: 5.9451 | PickScore: 21.9099 | HPSv2: 0.2805 | HPSv2.1: 0.2913 | ImageReward: 0.9956
|
| 213 |
+
|
| 214 |
+
[gradient_ascent] Batch 470/500 | Samples: 470/500 | Reward (t=136.0): 0.9800 | Reward (Avg): 0.9107 | CLIP: 26.7099 | Aesthetic: 5.9361 | PickScore: 21.8909 | HPSv2: 0.2800 | HPSv2.1: 0.2905 | ImageReward: 0.9830
|
| 215 |
+
|
| 216 |
+
[gradient_ascent] Batch 480/500 | Samples: 480/500 | Reward (t=136.0): 0.9692 | Reward (Avg): 0.9096 | CLIP: 26.6490 | Aesthetic: 5.9346 | PickScore: 21.8834 | HPSv2: 0.2800 | HPSv2.1: 0.2903 | ImageReward: 0.9868
|
| 217 |
+
|
| 218 |
+
[gradient_ascent] Batch 490/500 | Samples: 490/500 | Reward (t=136.0): 0.9629 | Reward (Avg): 0.9101 | CLIP: 26.6586 | Aesthetic: 5.9349 | PickScore: 21.8766 | HPSv2: 0.2800 | HPSv2.1: 0.2900 | ImageReward: 0.9869
|
| 219 |
+
|
| 220 |
+
[gradient_ascent] Batch 500/500 | Samples: 500/500 | Reward (t=136.0): 0.9927 | Reward (Avg): 0.9101 | CLIP: 26.6660 | Aesthetic: 5.9369 | PickScore: 21.8768 | HPSv2: 0.2800 | HPSv2.1: 0.2903 | ImageReward: 0.9915
|
| 221 |
+
✓ Gradient Ascent Avg Reward: 0.9101
|
| 222 |
+
✓ Gradient Ascent Avg CLIP Score: 26.6660
|
| 223 |
+
✓ Gradient Ascent Avg Aesthetic Score: 5.9369
|
| 224 |
+
✓ Gradient Ascent Avg PickScore: 21.8768
|
| 225 |
+
✓ Gradient Ascent Avg HPSv2 Score: 0.2800
|
| 226 |
+
✓ Gradient Ascent Avg HPSv2.1 Score: 0.2903
|
| 227 |
+
✓ Gradient Ascent Avg ImageReward: 0.9915
|
| 228 |
+
|
| 229 |
+
Gradient Ascent Statistics:
|
| 230 |
+
Applications: 10
|
| 231 |
+
Total reward improvement: +0.0054
|
| 232 |
+
Avg reward improvement: +0.0005
|
| 233 |
+
|
| 234 |
+
✓ Saved LR curve plot to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/lr_curve.png
|
| 235 |
+
Total gradient steps: 10
|
| 236 |
+
LR range: 1.000000 → 1.000000
|
| 237 |
+
|
| 238 |
+
✓ Saved Rewards curve plot to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/rewards_curve.png
|
| 239 |
+
Total gradient steps: 20
|
| 240 |
+
Reward range: 0.9971 → 0.9990
|
| 241 |
+
Total improvement: +0.0020
|
| 242 |
+
|
| 243 |
+
======================================================================
|
| 244 |
+
FINAL RESULTS
|
| 245 |
+
======================================================================
|
| 246 |
+
|
| 247 |
+
Gradient Ascent:
|
| 248 |
+
Avg Reward: 0.9101
|
| 249 |
+
Avg CLIP Score: 26.6660
|
| 250 |
+
Avg Aesthetic: 5.9369
|
| 251 |
+
Avg PickScore: 21.8768
|
| 252 |
+
Avg HPSv2: 0.2800
|
| 253 |
+
Avg HPSv2.1: 0.2903
|
| 254 |
+
Avg ImageReward: 0.9915
|
| 255 |
+
|
| 256 |
+
✓ Results saved to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/evaluation_results.txt
|
| 257 |
+
|
| 258 |
+
======================================================================
|
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_2/lr_curve.png
ADDED
|
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/evaluation_results.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mode: gradient_ascent
|
| 2 |
+
metrics: ['clip', 'aesthetic', 'pickscore', 'hpsv2', 'hpsv21', 'imagereward']
|
| 3 |
+
config: {'num_samples': 500, 'num_steps': 20, 'cfg_scale': 4.5, 'grad_range': [0, 700], 'grad_steps': 5, 'grad_step_size': 0.1}
|
| 4 |
+
gradient_ascent: {'avg_reward': np.float64(0.92294921875), 'clip_score': np.float64(26.616577001571656), 'aesthetic_score': np.float64(5.955972215652466), 'pickscore': np.float64(21.881759201049803), 'hpsv2_score': np.float16(0.2798), 'hpsv21_score': np.float16(0.2905), 'imagereward_score': np.float64(0.9889077876545489), 'stats': {'num_applications': 11, 'total_reward_improvement': 0.078125, 'avg_reward_improvement': 0.007102272727272727, 'avg_grad_norm': 0.21211126311258835, 'max_grad_norm': 0.277950644493103, 'detailed_stats': [{'timestep': 785, 'initial_reward': 0.9140625, 'final_reward': 0.93359375, 'reward_improvement': 0.01953125, 'grad_norms': [0.26926133036613464], 'reward_history': [0.9140625, 0.9140625], 'lr_history': [1.0], 'latent_change': 1.0}, {'timestep': 749, 'initial_reward': 0.9375, 'final_reward': 0.94921875, 'reward_improvement': 0.01171875, 'grad_norms': [0.277950644493103], 'reward_history': [0.9375, 0.9375], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}, {'timestep': 710, 'initial_reward': 0.94921875, 'final_reward': 0.9609375, 'reward_improvement': 0.01171875, 'grad_norms': [0.27435681223869324], 'reward_history': [0.94921875, 0.94921875], 'lr_history': [1.0], 'latent_change': 1.0000001192092896}, {'timestep': 666, 'initial_reward': 0.9609375, 'final_reward': 0.96875, 'reward_improvement': 0.0078125, 'grad_norms': [0.24942916631698608], 'reward_history': [0.9609375, 0.9609375], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 617, 'initial_reward': 0.96875, 'final_reward': 0.97265625, 'reward_improvement': 0.00390625, 'grad_norms': [0.21902038156986237], 'reward_history': [0.96875, 0.96875], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}, {'timestep': 562, 'initial_reward': 0.96875, 'final_reward': 0.9765625, 'reward_improvement': 0.0078125, 'grad_norms': [0.19418643414974213], 'reward_history': [0.96875, 0.96875], 'lr_history': [1.0], 'latent_change': 1.0}, {'timestep': 499, 'initial_reward': 
0.97265625, 'final_reward': 0.9765625, 'reward_improvement': 0.00390625, 'grad_norms': [0.18358778953552246], 'reward_history': [0.97265625, 0.97265625], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 428, 'initial_reward': 0.9765625, 'final_reward': 0.98046875, 'reward_improvement': 0.00390625, 'grad_norms': [0.1757386326789856], 'reward_history': [0.9765625, 0.9765625], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}, {'timestep': 345, 'initial_reward': 0.9765625, 'final_reward': 0.98046875, 'reward_improvement': 0.00390625, 'grad_norms': [0.16876810789108276], 'reward_history': [0.9765625, 0.9765625], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 249, 'initial_reward': 0.98046875, 'final_reward': 0.98046875, 'reward_improvement': 0.0, 'grad_norms': [0.16214041411876678], 'reward_history': [0.98046875, 0.98046875], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 136, 'initial_reward': 0.9765625, 'final_reward': 0.98046875, 'reward_improvement': 0.00390625, 'grad_norms': [0.1587841808795929], 'reward_history': [0.9765625, 0.9765625], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}]}}
|
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/log.log
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
======================================================================
|
| 2 |
+
FID EVALUATION: BASELINE vs GRADIENT ASCENT
|
| 3 |
+
======================================================================
|
| 4 |
+
|
| 5 |
+
Logging to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/log.log
|
| 6 |
+
|
| 7 |
+
Device: cuda:0
|
| 8 |
+
Dataset: PICKAPIC
|
| 9 |
+
Data directory: ./data
|
| 10 |
+
Base model: Efficient-Large-Model/Sana_600M_512px_diffusers
|
| 11 |
+
Model variant: sana_600m_512
|
| 12 |
+
LRM model: /g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep33000
|
| 13 |
+
HF cache dir: /scratch/rr81/ma5430/.cache/huggingface/hub
|
| 14 |
+
HF offline mode: True
|
| 15 |
+
Inference steps: 20
|
| 16 |
+
CFG scale: 4.5
|
| 17 |
+
Batch size: 1
|
| 18 |
+
Max samples: All
|
| 19 |
+
Generation dtype: bf16
|
| 20 |
+
Output directory: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3
|
| 21 |
+
Save images: False
|
| 22 |
+
Evaluation mode: gradient_ascent
|
| 23 |
+
Metrics to evaluate: CLIP, AESTHETIC, PICKSCORE, HPSV2, HPSV21, IMAGEREWARD
|
| 24 |
+
Gradient ascent config: one_step_rectification_config
|
| 25 |
+
|
| 26 |
+
======================================================================
|
| 27 |
+
1. LOADING VALIDATION DATA
|
| 28 |
+
======================================================================
|
| 29 |
+
Loading Pick-a-Pic validation prompts...
|
| 30 |
+
Loading cached Pick-a-Pic split 'validation_unique' from 1 parquet shards
|
| 31 |
+
cache=/scratch/rr81/ma5430/.cache/huggingface/hub/datasets--pickapic-anonymous--pickapic_v1
|
| 32 |
+
Loaded 500 Pick-a-Pic validation samples
|
| 33 |
+
|
| 34 |
+
======================================================================
|
| 35 |
+
2. LOADING REWARD MODEL
|
| 36 |
+
======================================================================
|
| 37 |
+
Loading SANA base reward backbone from Efficient-Large-Model/Sana_600M_512px_diffusers...
|
| 38 |
+
Loading SANA reward checkpoint from /g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep33000/model.safetensors...
|
| 39 |
+
✓ Loaded checkpoint keys: 1214
|
| 40 |
+
✓ Missing keys: 0 | Unexpected keys: 0
|
| 41 |
+
✓ SANA LRM Reward Model initialized successfully!
|
| 42 |
+
✓ Reward model loaded
|
| 43 |
+
|
| 44 |
+
======================================================================
|
| 45 |
+
3. LOADING PIPELINE
|
| 46 |
+
======================================================================
|
| 47 |
+
✓ Loaded SANA base model: Efficient-Large-Model/Sana_600M_512px_diffusers
|
| 48 |
+
✓ Reward model attached to SANA pipeline
|
| 49 |
+
✓ Pipeline loaded
|
| 50 |
+
GPU memory before scorer load: 125.86 GB free / 140.06 GB total
|
| 51 |
+
Scorer device: cuda:0
|
| 52 |
+
|
| 53 |
+
======================================================================
|
| 54 |
+
3.5. LOADING CLIP AND AESTHETIC SCORERS
|
| 55 |
+
======================================================================
|
| 56 |
+
✓ CLIP scorer loaded
|
| 57 |
+
✓ Aesthetic scorer loaded
|
| 58 |
+
✓ PickScore scorer loaded
|
| 59 |
+
✓ HPSv2 scorer loaded
|
| 60 |
+
✓ HPSv2.1 scorer loaded
|
| 61 |
+
load checkpoint from /scratch/rr81/ma5430/.cache/huggingface/hub/models--THUDM--ImageReward/snapshots/5736be03b2652728fb87788c9797b0570450ab72/ImageReward.pt
|
| 62 |
+
checkpoint loaded
|
| 63 |
+
✓ ImageReward scorer loaded
|
| 64 |
+
|
| 65 |
+
======================================================================
|
| 66 |
+
4. CONFIGURING GRADIENT ASCENT
|
| 67 |
+
======================================================================
|
| 68 |
+
Loading gradient ascent config: one_step_rectification_config
|
| 69 |
+
Config loaded: {'grad_timestep_range': (100, 800), 'num_grad_steps': 1, 'grad_step_size': 1.0, 'grad_scale': 1.0, 'lr_scheduler_type': 'constant', 'use_momentum': False, 'use_nesterov': False, 'use_iso_projection': False}
|
| 70 |
+
Gradient timestep range: (100, 800)
|
| 71 |
+
Gradient steps: 1
|
| 72 |
+
Gradient step size (initial LR): 1.0
|
| 73 |
+
LR Scheduler: constant
|
| 74 |
+
✓ Gradient ascent enabled for timesteps (100, 800)
|
| 75 |
+
|
| 76 |
+
======================================================================
|
| 77 |
+
6. EVALUATING GRADIENT ASCENT
|
| 78 |
+
======================================================================
|
| 79 |
+
|
| 80 |
+
Generating images with gradient_ascent mode...
|
| 81 |
+
|
| 82 |
+
[gradient_ascent] Batch 10/500 | Samples: 10/500 | Reward (t=136.0): 1.0000 | Reward (Avg): 0.9625 | CLIP: 26.4663 | Aesthetic: 6.1413 | PickScore: 22.1120 | HPSv2: 0.2876 | HPSv2.1: 0.3157 | ImageReward: 1.2690
|
| 83 |
+
|
| 84 |
+
[gradient_ascent] Batch 20/500 | Samples: 20/500 | Reward (t=136.0): 0.7266 | Reward (Avg): 0.9354 | CLIP: 26.3375 | Aesthetic: 6.1225 | PickScore: 22.1887 | HPSv2: 0.2869 | HPSv2.1: 0.3118 | ImageReward: 1.0984
|
| 85 |
+
|
| 86 |
+
[gradient_ascent] Batch 30/500 | Samples: 30/500 | Reward (t=136.0): 1.0000 | Reward (Avg): 0.9434 | CLIP: 26.4766 | Aesthetic: 6.0533 | PickScore: 22.3414 | HPSv2: 0.2864 | HPSv2.1: 0.3079 | ImageReward: 1.0883
|
| 87 |
+
|
| 88 |
+
[gradient_ascent] Batch 40/500 | Samples: 40/500 | Reward (t=136.0): 0.9609 | Reward (Avg): 0.9415 | CLIP: 26.6829 | Aesthetic: 6.0924 | PickScore: 22.2873 | HPSv2: 0.2844 | HPSv2.1: 0.3047 | ImageReward: 0.9498
|
| 89 |
+
|
| 90 |
+
[gradient_ascent] Batch 50/500 | Samples: 50/500 | Reward (t=136.0): 0.8711 | Reward (Avg): 0.9413 | CLIP: 26.3142 | Aesthetic: 6.0252 | PickScore: 22.1119 | HPSv2: 0.2830 | HPSv2.1: 0.3010 | ImageReward: 1.0441
|
| 91 |
+
|
| 92 |
+
[gradient_ascent] Batch 60/500 | Samples: 60/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9441 | CLIP: 26.1909 | Aesthetic: 6.0327 | PickScore: 22.1030 | HPSv2: 0.2837 | HPSv2.1: 0.3013 | ImageReward: 1.0347
|
| 93 |
+
|
| 94 |
+
[gradient_ascent] Batch 70/500 | Samples: 70/500 | Reward (t=136.0): 1.0000 | Reward (Avg): 0.9421 | CLIP: 26.4568 | Aesthetic: 6.0299 | PickScore: 22.1226 | HPSv2: 0.2839 | HPSv2.1: 0.3013 | ImageReward: 1.0243
|
| 95 |
+
|
| 96 |
+
[gradient_ascent] Batch 80/500 | Samples: 80/500 | Reward (t=136.0): 0.9922 | Reward (Avg): 0.9356 | CLIP: 26.2856 | Aesthetic: 6.0066 | PickScore: 22.0058 | HPSv2: 0.2830 | HPSv2.1: 0.2964 | ImageReward: 0.9735
|
| 97 |
+
|
| 98 |
+
[gradient_ascent] Batch 90/500 | Samples: 90/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9348 | CLIP: 26.5472 | Aesthetic: 5.9998 | PickScore: 21.9813 | HPSv2: 0.2825 | HPSv2.1: 0.2964 | ImageReward: 0.9939
|
| 99 |
+
|
| 100 |
+
[gradient_ascent] Batch 100/500 | Samples: 100/500 | Reward (t=136.0): 0.9844 | Reward (Avg): 0.9393 | CLIP: 26.8104 | Aesthetic: 6.0321 | PickScore: 21.9362 | HPSv2: 0.2817 | HPSv2.1: 0.2947 | ImageReward: 0.9892
|
| 101 |
+
|
| 102 |
+
[gradient_ascent] Batch 110/500 | Samples: 110/500 | Reward (t=136.0): 0.9844 | Reward (Avg): 0.9385 | CLIP: 26.9949 | Aesthetic: 6.0235 | PickScore: 21.9147 | HPSv2: 0.2817 | HPSv2.1: 0.2942 | ImageReward: 0.9906
|
| 103 |
+
|
| 104 |
+
[gradient_ascent] Batch 120/500 | Samples: 120/500 | Reward (t=136.0): 0.7188 | Reward (Avg): 0.9285 | CLIP: 26.8847 | Aesthetic: 6.0040 | PickScore: 21.8963 | HPSv2: 0.2815 | HPSv2.1: 0.2942 | ImageReward: 1.0242
|
| 105 |
+
|
| 106 |
+
[gradient_ascent] Batch 130/500 | Samples: 130/500 | Reward (t=136.0): 0.9883 | Reward (Avg): 0.9322 | CLIP: 26.9108 | Aesthetic: 6.0127 | PickScore: 21.9706 | HPSv2: 0.2822 | HPSv2.1: 0.2949 | ImageReward: 1.0272
|
| 107 |
+
|
| 108 |
+
[gradient_ascent] Batch 140/500 | Samples: 140/500 | Reward (t=136.0): 0.9727 | Reward (Avg): 0.9334 | CLIP: 27.1204 | Aesthetic: 5.9798 | PickScore: 21.9565 | HPSv2: 0.2820 | HPSv2.1: 0.2939 | ImageReward: 1.0218
|
| 109 |
+
|
| 110 |
+
[gradient_ascent] Batch 150/500 | Samples: 150/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9266 | CLIP: 27.0076 | Aesthetic: 5.9507 | PickScore: 21.9421 | HPSv2: 0.2815 | HPSv2.1: 0.2927 | ImageReward: 1.0071
|
| 111 |
+
|
| 112 |
+
[gradient_ascent] Batch 160/500 | Samples: 160/500 | Reward (t=136.0): 0.9375 | Reward (Avg): 0.9261 | CLIP: 27.1794 | Aesthetic: 5.9462 | PickScore: 21.9868 | HPSv2: 0.2815 | HPSv2.1: 0.2927 | ImageReward: 1.0162
|
| 113 |
+
|
| 114 |
+
[gradient_ascent] Batch 170/500 | Samples: 170/500 | Reward (t=136.0): 0.8789 | Reward (Avg): 0.9254 | CLIP: 27.0033 | Aesthetic: 5.9378 | PickScore: 21.9386 | HPSv2: 0.2810 | HPSv2.1: 0.2915 | ImageReward: 1.0123
|
| 115 |
+
|
| 116 |
+
[gradient_ascent] Batch 180/500 | Samples: 180/500 | Reward (t=136.0): 0.9570 | Reward (Avg): 0.9212 | CLIP: 27.1150 | Aesthetic: 5.9410 | PickScore: 21.9510 | HPSv2: 0.2815 | HPSv2.1: 0.2922 | ImageReward: 1.0373
|
| 117 |
+
|
| 118 |
+
[gradient_ascent] Batch 190/500 | Samples: 190/500 | Reward (t=136.0): 0.7188 | Reward (Avg): 0.9218 | CLIP: 27.0857 | Aesthetic: 5.9454 | PickScore: 21.9378 | HPSv2: 0.2810 | HPSv2.1: 0.2915 | ImageReward: 1.0217
|
| 119 |
+
|
| 120 |
+
[gradient_ascent] Batch 200/500 | Samples: 200/500 | Reward (t=136.0): 0.9609 | Reward (Avg): 0.9220 | CLIP: 27.0541 | Aesthetic: 5.9446 | PickScore: 21.9444 | HPSv2: 0.2812 | HPSv2.1: 0.2922 | ImageReward: 1.0335
|
| 121 |
+
|
| 122 |
+
[gradient_ascent] Batch 210/500 | Samples: 210/500 | Reward (t=136.0): 0.9922 | Reward (Avg): 0.9241 | CLIP: 27.0499 | Aesthetic: 5.9568 | PickScore: 21.9480 | HPSv2: 0.2812 | HPSv2.1: 0.2930 | ImageReward: 1.0426
|
| 123 |
+
|
| 124 |
+
[gradient_ascent] Batch 220/500 | Samples: 220/500 | Reward (t=136.0): 0.9922 | Reward (Avg): 0.9244 | CLIP: 27.1198 | Aesthetic: 5.9586 | PickScore: 21.9354 | HPSv2: 0.2812 | HPSv2.1: 0.2932 | ImageReward: 1.0333
|
| 125 |
+
|
| 126 |
+
[gradient_ascent] Batch 230/500 | Samples: 230/500 | Reward (t=136.0): 0.8281 | Reward (Avg): 0.9250 | CLIP: 27.1823 | Aesthetic: 5.9633 | PickScore: 21.9462 | HPSv2: 0.2812 | HPSv2.1: 0.2927 | ImageReward: 1.0163
|
| 127 |
+
|
| 128 |
+
[gradient_ascent] Batch 240/500 | Samples: 240/500 | Reward (t=136.0): 0.9453 | Reward (Avg): 0.9262 | CLIP: 27.0908 | Aesthetic: 5.9750 | PickScore: 21.9335 | HPSv2: 0.2808 | HPSv2.1: 0.2920 | ImageReward: 1.0059
|
| 129 |
+
|
| 130 |
+
[gradient_ascent] Batch 250/500 | Samples: 250/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9277 | CLIP: 27.0939 | Aesthetic: 5.9814 | PickScore: 21.9393 | HPSv2: 0.2810 | HPSv2.1: 0.2922 | ImageReward: 1.0046
|
| 131 |
+
|
| 132 |
+
[gradient_ascent] Batch 260/500 | Samples: 260/500 | Reward (t=136.0): 0.9648 | Reward (Avg): 0.9273 | CLIP: 27.0651 | Aesthetic: 5.9886 | PickScore: 21.9368 | HPSv2: 0.2810 | HPSv2.1: 0.2922 | ImageReward: 1.0170
|
| 133 |
+
|
| 134 |
+
[gradient_ascent] Batch 270/500 | Samples: 270/500 | Reward (t=136.0): 1.0000 | Reward (Avg): 0.9268 | CLIP: 26.9957 | Aesthetic: 5.9890 | PickScore: 21.9255 | HPSv2: 0.2810 | HPSv2.1: 0.2922 | ImageReward: 1.0180
|
| 135 |
+
|
| 136 |
+
[gradient_ascent] Batch 280/500 | Samples: 280/500 | Reward (t=136.0): 0.9688 | Reward (Avg): 0.9280 | CLIP: 27.0482 | Aesthetic: 5.9911 | PickScore: 21.9605 | HPSv2: 0.2812 | HPSv2.1: 0.2930 | ImageReward: 1.0319
|
| 137 |
+
|
| 138 |
+
[gradient_ascent] Batch 290/500 | Samples: 290/500 | Reward (t=136.0): 0.9688 | Reward (Avg): 0.9285 | CLIP: 26.9579 | Aesthetic: 5.9804 | PickScore: 21.9239 | HPSv2: 0.2810 | HPSv2.1: 0.2922 | ImageReward: 1.0166
|
| 139 |
+
|
| 140 |
+
[gradient_ascent] Batch 300/500 | Samples: 300/500 | Reward (t=136.0): 1.0000 | Reward (Avg): 0.9295 | CLIP: 26.9421 | Aesthetic: 5.9830 | PickScore: 21.9336 | HPSv2: 0.2812 | HPSv2.1: 0.2927 | ImageReward: 1.0217
|
| 141 |
+
|
| 142 |
+
[gradient_ascent] Batch 310/500 | Samples: 310/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9297 | CLIP: 27.0328 | Aesthetic: 5.9872 | PickScore: 21.9477 | HPSv2: 0.2815 | HPSv2.1: 0.2932 | ImageReward: 1.0419
|
| 143 |
+
|
| 144 |
+
[gradient_ascent] Batch 320/500 | Samples: 320/500 | Reward (t=136.0): 0.9492 | Reward (Avg): 0.9282 | CLIP: 27.0486 | Aesthetic: 5.9830 | PickScore: 21.9641 | HPSv2: 0.2812 | HPSv2.1: 0.2930 | ImageReward: 1.0378
|
| 145 |
+
|
| 146 |
+
[gradient_ascent] Batch 330/500 | Samples: 330/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.9291 | CLIP: 27.0375 | Aesthetic: 5.9877 | PickScore: 21.9561 | HPSv2: 0.2812 | HPSv2.1: 0.2930 | ImageReward: 1.0415
|
| 147 |
+
|
| 148 |
+
[gradient_ascent] Batch 340/500 | Samples: 340/500 | Reward (t=136.0): 0.9688 | Reward (Avg): 0.9295 | CLIP: 27.0379 | Aesthetic: 5.9848 | PickScore: 21.9602 | HPSv2: 0.2815 | HPSv2.1: 0.2935 | ImageReward: 1.0553
|
| 149 |
+
|
| 150 |
+
[gradient_ascent] Batch 350/500 | Samples: 350/500 | Reward (t=136.0): 0.7539 | Reward (Avg): 0.9292 | CLIP: 27.0467 | Aesthetic: 5.9841 | PickScore: 21.9707 | HPSv2: 0.2812 | HPSv2.1: 0.2935 | ImageReward: 1.0468
|
| 151 |
+
|
| 152 |
+
[gradient_ascent] Batch 360/500 | Samples: 360/500 | Reward (t=136.0): 0.9609 | Reward (Avg): 0.9277 | CLIP: 27.0976 | Aesthetic: 5.9827 | PickScore: 21.9736 | HPSv2: 0.2815 | HPSv2.1: 0.2939 | ImageReward: 1.0415
|
| 153 |
+
|
| 154 |
+
[gradient_ascent] Batch 370/500 | Samples: 370/500 | Reward (t=136.0): 0.9453 | Reward (Avg): 0.9267 | CLIP: 27.0320 | Aesthetic: 5.9800 | PickScore: 21.9612 | HPSv2: 0.2815 | HPSv2.1: 0.2935 | ImageReward: 1.0372
|
| 155 |
+
|
| 156 |
+
[gradient_ascent] Batch 380/500 | Samples: 380/500 | Reward (t=136.0): 0.7891 | Reward (Avg): 0.9272 | CLIP: 27.1124 | Aesthetic: 5.9873 | PickScore: 21.9642 | HPSv2: 0.2815 | HPSv2.1: 0.2939 | ImageReward: 1.0436
|
| 157 |
+
|
| 158 |
+
[gradient_ascent] Batch 390/500 | Samples: 390/500 | Reward (t=136.0): 0.9922 | Reward (Avg): 0.9274 | CLIP: 26.9883 | Aesthetic: 5.9833 | PickScore: 21.9461 | HPSv2: 0.2810 | HPSv2.1: 0.2930 | ImageReward: 1.0313
|
| 159 |
+
|
| 160 |
+
[gradient_ascent] Batch 400/500 | Samples: 400/500 | Reward (t=136.0): 0.8711 | Reward (Avg): 0.9275 | CLIP: 26.9427 | Aesthetic: 5.9776 | PickScore: 21.9448 | HPSv2: 0.2810 | HPSv2.1: 0.2930 | ImageReward: 1.0291
|
| 161 |
+
|
| 162 |
+
[gradient_ascent] Batch 410/500 | Samples: 410/500 | Reward (t=136.0): 0.3535 | Reward (Avg): 0.9265 | CLIP: 26.8892 | Aesthetic: 5.9751 | PickScore: 21.9345 | HPSv2: 0.2808 | HPSv2.1: 0.2925 | ImageReward: 1.0099
|
| 163 |
+
|
| 164 |
+
[gradient_ascent] Batch 420/500 | Samples: 420/500 | Reward (t=136.0): 0.9844 | Reward (Avg): 0.9249 | CLIP: 26.8305 | Aesthetic: 5.9748 | PickScore: 21.9225 | HPSv2: 0.2805 | HPSv2.1: 0.2922 | ImageReward: 1.0087
|
| 165 |
+
|
| 166 |
+
[gradient_ascent] Batch 430/500 | Samples: 430/500 | Reward (t=136.0): 0.9609 | Reward (Avg): 0.9245 | CLIP: 26.8481 | Aesthetic: 5.9711 | PickScore: 21.9269 | HPSv2: 0.2805 | HPSv2.1: 0.2922 | ImageReward: 1.0131
|
| 167 |
+
|
| 168 |
+
[gradient_ascent] Batch 440/500 | Samples: 440/500 | Reward (t=136.0): 0.8984 | Reward (Avg): 0.9250 | CLIP: 26.7605 | Aesthetic: 5.9752 | PickScore: 21.9258 | HPSv2: 0.2805 | HPSv2.1: 0.2922 | ImageReward: 1.0092
|
| 169 |
+
|
| 170 |
+
[gradient_ascent] Batch 450/500 | Samples: 450/500 | Reward (t=136.0): 0.9531 | Reward (Avg): 0.9243 | CLIP: 26.7746 | Aesthetic: 5.9678 | PickScore: 21.9262 | HPSv2: 0.2803 | HPSv2.1: 0.2917 | ImageReward: 1.0003
|
| 171 |
+
|
| 172 |
+
[gradient_ascent] Batch 460/500 | Samples: 460/500 | Reward (t=136.0): 0.9805 | Reward (Avg): 0.9246 | CLIP: 26.7613 | Aesthetic: 5.9665 | PickScore: 21.9229 | HPSv2: 0.2803 | HPSv2.1: 0.2915 | ImageReward: 0.9966
|
| 173 |
+
|
| 174 |
+
[gradient_ascent] Batch 470/500 | Samples: 470/500 | Reward (t=136.0): 0.9766 | Reward (Avg): 0.9251 | CLIP: 26.6828 | Aesthetic: 5.9578 | PickScore: 21.9037 | HPSv2: 0.2800 | HPSv2.1: 0.2908 | ImageReward: 0.9823
|
| 175 |
+
|
| 176 |
+
[gradient_ascent] Batch 480/500 | Samples: 480/500 | Reward (t=136.0): 0.9688 | Reward (Avg): 0.9234 | CLIP: 26.6169 | Aesthetic: 5.9549 | PickScore: 21.8937 | HPSv2: 0.2798 | HPSv2.1: 0.2905 | ImageReward: 0.9846
|
| 177 |
+
|
| 178 |
+
[gradient_ascent] Batch 490/500 | Samples: 490/500 | Reward (t=136.0): 0.9688 | Reward (Avg): 0.9227 | CLIP: 26.6094 | Aesthetic: 5.9548 | PickScore: 21.8840 | HPSv2: 0.2798 | HPSv2.1: 0.2903 | ImageReward: 0.9847
|
| 179 |
+
|
| 180 |
+
[gradient_ascent] Batch 500/500 | Samples: 500/500 | Reward (t=136.0): 0.9805 | Reward (Avg): 0.9229 | CLIP: 26.6166 | Aesthetic: 5.9560 | PickScore: 21.8818 | HPSv2: 0.2798 | HPSv2.1: 0.2905 | ImageReward: 0.9889
|
| 181 |
+
✓ Gradient Ascent Avg Reward: 0.9229
|
| 182 |
+
✓ Gradient Ascent Avg CLIP Score: 26.6166
|
| 183 |
+
✓ Gradient Ascent Avg Aesthetic Score: 5.9560
|
| 184 |
+
✓ Gradient Ascent Avg PickScore: 21.8818
|
| 185 |
+
✓ Gradient Ascent Avg HPSv2 Score: 0.2798
|
| 186 |
+
✓ Gradient Ascent Avg HPSv2.1 Score: 0.2905
|
| 187 |
+
✓ Gradient Ascent Avg ImageReward: 0.9889
|
| 188 |
+
|
| 189 |
+
Gradient Ascent Statistics:
|
| 190 |
+
Applications: 11
|
| 191 |
+
Total reward improvement: +0.0781
|
| 192 |
+
Avg reward improvement: +0.0071
|
| 193 |
+
|
| 194 |
+
✓ Saved LR curve plot to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/lr_curve.png
|
| 195 |
+
Total gradient steps: 11
|
| 196 |
+
LR range: 1.000000 → 1.000000
|
| 197 |
+
|
| 198 |
+
✓ Saved Rewards curve plot to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/rewards_curve.png
|
| 199 |
+
Total gradient steps: 22
|
| 200 |
+
Reward range: 0.9922 → 1.0000
|
| 201 |
+
Total improvement: +0.0078
|
| 202 |
+
|
| 203 |
+
======================================================================
|
| 204 |
+
FINAL RESULTS
|
| 205 |
+
======================================================================
|
| 206 |
+
|
| 207 |
+
Gradient Ascent:
|
| 208 |
+
Avg Reward: 0.9229
|
| 209 |
+
Avg CLIP Score: 26.6166
|
| 210 |
+
Avg Aesthetic: 5.9560
|
| 211 |
+
Avg PickScore: 21.8818
|
| 212 |
+
Avg HPSv2: 0.2798
|
| 213 |
+
Avg HPSv2.1: 0.2905
|
| 214 |
+
Avg ImageReward: 0.9889
|
| 215 |
+
|
| 216 |
+
✓ Results saved to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/evaluation_results.txt
|
| 217 |
+
|
| 218 |
+
======================================================================
|
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_3/lr_curve.png
ADDED
|
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/evaluation_results.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mode: gradient_ascent
|
| 2 |
+
metrics: ['clip', 'aesthetic', 'pickscore', 'hpsv2', 'hpsv21', 'imagereward']
|
| 3 |
+
config: {'num_samples': 500, 'num_steps': 20, 'cfg_scale': 4.5, 'grad_range': [0, 700], 'grad_steps': 5, 'grad_step_size': 0.1}
|
| 4 |
+
gradient_ascent: {'avg_reward': np.float64(0.86123828125), 'clip_score': np.float64(26.688936717987062), 'aesthetic_score': np.float64(5.964459408760071), 'pickscore': np.float64(21.888245372772218), 'hpsv2_score': np.float16(0.2798), 'hpsv21_score': np.float16(0.2896), 'imagereward_score': np.float64(0.9574673203025014), 'stats': {'num_applications': 11, 'total_reward_improvement': 0.3046875, 'avg_reward_improvement': 0.027698863636363636, 'avg_grad_norm': 0.20624772933396426, 'max_grad_norm': 0.24269965291023254, 'detailed_stats': [{'timestep': 785, 'initial_reward': 0.65234375, 'final_reward': 0.703125, 'reward_improvement': 0.05078125, 'grad_norms': [0.237853541970253], 'reward_history': [0.65234375, 0.65234375], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}, {'timestep': 749, 'initial_reward': 0.71484375, 'final_reward': 0.76171875, 'reward_improvement': 0.046875, 'grad_norms': [0.24269965291023254], 'reward_history': [0.71484375, 0.71484375], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}, {'timestep': 710, 'initial_reward': 0.76171875, 'final_reward': 0.80078125, 'reward_improvement': 0.0390625, 'grad_norms': [0.23808617889881134], 'reward_history': [0.76171875, 0.76171875], 'lr_history': [1.0], 'latent_change': 1.0}, {'timestep': 666, 'initial_reward': 0.80078125, 'final_reward': 0.83203125, 'reward_improvement': 0.03125, 'grad_norms': [0.2161739617586136], 'reward_history': [0.80078125, 0.80078125], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 617, 'initial_reward': 0.828125, 'final_reward': 0.85546875, 'reward_improvement': 0.02734375, 'grad_norms': [0.20163142681121826], 'reward_history': [0.828125, 0.828125], 'lr_history': [1.0], 'latent_change': 1.0}, {'timestep': 562, 'initial_reward': 0.8515625, 'final_reward': 0.87109375, 'reward_improvement': 0.01953125, 'grad_norms': [0.1897481232881546], 'reward_history': [0.8515625, 0.8515625], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}, 
{'timestep': 499, 'initial_reward': 0.8671875, 'final_reward': 0.88671875, 'reward_improvement': 0.01953125, 'grad_norms': [0.181509330868721], 'reward_history': [0.8671875, 0.8671875], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 428, 'initial_reward': 0.875, 'final_reward': 0.89453125, 'reward_improvement': 0.01953125, 'grad_norms': [0.18168200552463531], 'reward_history': [0.875, 0.875], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 345, 'initial_reward': 0.88671875, 'final_reward': 0.90234375, 'reward_improvement': 0.015625, 'grad_norms': [0.185561865568161], 'reward_history': [0.88671875, 0.88671875], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}, {'timestep': 249, 'initial_reward': 0.890625, 'final_reward': 0.90625, 'reward_improvement': 0.015625, 'grad_norms': [0.19334888458251953], 'reward_history': [0.890625, 0.890625], 'lr_history': [1.0], 'latent_change': 0.9999998807907104}, {'timestep': 136, 'initial_reward': 0.890625, 'final_reward': 0.91015625, 'reward_improvement': 0.01953125, 'grad_norms': [0.20043005049228668], 'reward_history': [0.890625, 0.890625], 'lr_history': [1.0], 'latent_change': 0.9999999403953552}]}}
|
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/log.log
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
======================================================================
|
| 2 |
+
FID EVALUATION: BASELINE vs GRADIENT ASCENT
|
| 3 |
+
======================================================================
|
| 4 |
+
|
| 5 |
+
Logging to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/log.log
|
| 6 |
+
|
| 7 |
+
Device: cuda:0
|
| 8 |
+
Dataset: PICKAPIC
|
| 9 |
+
Data directory: ./data
|
| 10 |
+
Base model: Efficient-Large-Model/Sana_600M_512px_diffusers
|
| 11 |
+
Model variant: sana_600m_512
|
| 12 |
+
LRM model: /g/data/rr81/LPO/lrm/lrm_sana/logs/v7/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep32000
|
| 13 |
+
HF cache dir: /scratch/rr81/ma5430/.cache/huggingface/hub
|
| 14 |
+
HF offline mode: True
|
| 15 |
+
Inference steps: 20
|
| 16 |
+
CFG scale: 4.5
|
| 17 |
+
Batch size: 1
|
| 18 |
+
Max samples: All
|
| 19 |
+
Generation dtype: bf16
|
| 20 |
+
Output directory: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4
|
| 21 |
+
Save images: False
|
| 22 |
+
Evaluation mode: gradient_ascent
|
| 23 |
+
Metrics to evaluate: CLIP, AESTHETIC, PICKSCORE, HPSV2, HPSV21, IMAGEREWARD
|
| 24 |
+
Gradient ascent config: one_step_rectification_config
|
| 25 |
+
|
| 26 |
+
======================================================================
|
| 27 |
+
1. LOADING VALIDATION DATA
|
| 28 |
+
======================================================================
|
| 29 |
+
Loading Pick-a-Pic validation prompts...
|
| 30 |
+
Loading cached Pick-a-Pic split 'validation_unique' from 1 parquet shards
|
| 31 |
+
cache=/scratch/rr81/ma5430/.cache/huggingface/hub/datasets--pickapic-anonymous--pickapic_v1
|
| 32 |
+
Loaded 500 Pick-a-Pic validation samples
|
| 33 |
+
|
| 34 |
+
======================================================================
|
| 35 |
+
2. LOADING REWARD MODEL
|
| 36 |
+
======================================================================
|
| 37 |
+
Loading SANA base reward backbone from Efficient-Large-Model/Sana_600M_512px_diffusers...
|
| 38 |
+
Loading SANA reward checkpoint from /g/data/rr81/LPO/lrm/lrm_sana/logs/v7/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep32000/model.safetensors...
|
| 39 |
+
✓ Loaded checkpoint keys: 1214
|
| 40 |
+
✓ Missing keys: 0 | Unexpected keys: 0
|
| 41 |
+
✓ SANA LRM Reward Model initialized successfully!
|
| 42 |
+
✓ Reward model loaded
|
| 43 |
+
|
| 44 |
+
======================================================================
|
| 45 |
+
3. LOADING PIPELINE
|
| 46 |
+
======================================================================
|
| 47 |
+
✓ Loaded SANA base model: Efficient-Large-Model/Sana_600M_512px_diffusers
|
| 48 |
+
✓ Reward model attached to SANA pipeline
|
| 49 |
+
✓ Pipeline loaded
|
| 50 |
+
GPU memory before scorer load: 85.91 GB free / 140.06 GB total
|
| 51 |
+
Scorer device: cuda:0
|
| 52 |
+
|
| 53 |
+
======================================================================
|
| 54 |
+
3.5. LOADING CLIP AND AESTHETIC SCORERS
|
| 55 |
+
======================================================================
|
| 56 |
+
✓ CLIP scorer loaded
|
| 57 |
+
✓ Aesthetic scorer loaded
|
| 58 |
+
✓ PickScore scorer loaded
|
| 59 |
+
✓ HPSv2 scorer loaded
|
| 60 |
+
✓ HPSv2.1 scorer loaded
|
| 61 |
+
load checkpoint from /scratch/rr81/ma5430/.cache/huggingface/hub/models--THUDM--ImageReward/snapshots/5736be03b2652728fb87788c9797b0570450ab72/ImageReward.pt
|
| 62 |
+
checkpoint loaded
|
| 63 |
+
✓ ImageReward scorer loaded
|
| 64 |
+
|
| 65 |
+
======================================================================
|
| 66 |
+
4. CONFIGURING GRADIENT ASCENT
|
| 67 |
+
======================================================================
|
| 68 |
+
Loading gradient ascent config: one_step_rectification_config
|
| 69 |
+
Config loaded: {'grad_timestep_range': (100, 800), 'num_grad_steps': 1, 'grad_step_size': 1.0, 'grad_scale': 1.0, 'lr_scheduler_type': 'constant', 'use_momentum': False, 'use_nesterov': False, 'use_iso_projection': False}
|
| 70 |
+
Gradient timestep range: (100, 800)
|
| 71 |
+
Gradient steps: 1
|
| 72 |
+
Gradient step size (initial LR): 1.0
|
| 73 |
+
LR Scheduler: constant
|
| 74 |
+
✓ Gradient ascent enabled for timesteps (100, 800)
|
| 75 |
+
|
| 76 |
+
======================================================================
|
| 77 |
+
6. EVALUATING GRADIENT ASCENT
|
| 78 |
+
======================================================================
|
| 79 |
+
|
| 80 |
+
Generating images with gradient_ascent mode...
|
| 81 |
+
|
| 82 |
+
[gradient_ascent] Batch 10/500 | Samples: 10/500 | Reward (t=136.0): 0.9883 | Reward (Avg): 0.9230 | CLIP: 27.2524 | Aesthetic: 6.1769 | PickScore: 22.0294 | HPSv2: 0.2856 | HPSv2.1: 0.3088 | ImageReward: 1.2917
|
| 83 |
+
|
| 84 |
+
[gradient_ascent] Batch 20/500 | Samples: 20/500 | Reward (t=136.0): 0.8672 | Reward (Avg): 0.9189 | CLIP: 26.8319 | Aesthetic: 6.1308 | PickScore: 22.1532 | HPSv2: 0.2854 | HPSv2.1: 0.3091 | ImageReward: 1.1375
|
| 85 |
+
|
| 86 |
+
[gradient_ascent] Batch 30/500 | Samples: 30/500 | Reward (t=136.0): 0.9922 | Reward (Avg): 0.9047 | CLIP: 26.7728 | Aesthetic: 6.0477 | PickScore: 22.3591 | HPSv2: 0.2854 | HPSv2.1: 0.3052 | ImageReward: 1.1096
|
| 87 |
+
|
| 88 |
+
[gradient_ascent] Batch 40/500 | Samples: 40/500 | Reward (t=136.0): 0.8672 | Reward (Avg): 0.9026 | CLIP: 26.8790 | Aesthetic: 6.1062 | PickScore: 22.3062 | HPSv2: 0.2839 | HPSv2.1: 0.3013 | ImageReward: 0.9804
|
| 89 |
+
|
| 90 |
+
[gradient_ascent] Batch 50/500 | Samples: 50/500 | Reward (t=136.0): 0.6523 | Reward (Avg): 0.8916 | CLIP: 26.6755 | Aesthetic: 6.0488 | PickScore: 22.1421 | HPSv2: 0.2827 | HPSv2.1: 0.2991 | ImageReward: 1.0652
|
| 91 |
+
|
| 92 |
+
[gradient_ascent] Batch 60/500 | Samples: 60/500 | Reward (t=136.0): 0.9648 | Reward (Avg): 0.8891 | CLIP: 26.5230 | Aesthetic: 6.0566 | PickScore: 22.1336 | HPSv2: 0.2832 | HPSv2.1: 0.2991 | ImageReward: 1.0473
|
| 93 |
+
|
| 94 |
+
[gradient_ascent] Batch 70/500 | Samples: 70/500 | Reward (t=136.0): 0.9805 | Reward (Avg): 0.8851 | CLIP: 26.6211 | Aesthetic: 6.0387 | PickScore: 22.1567 | HPSv2: 0.2834 | HPSv2.1: 0.2996 | ImageReward: 1.0103
|
| 95 |
+
|
| 96 |
+
[gradient_ascent] Batch 80/500 | Samples: 80/500 | Reward (t=136.0): 0.9141 | Reward (Avg): 0.8869 | CLIP: 26.5639 | Aesthetic: 6.0182 | PickScore: 22.0309 | HPSv2: 0.2825 | HPSv2.1: 0.2947 | ImageReward: 0.9646
|
| 97 |
+
|
| 98 |
+
[gradient_ascent] Batch 90/500 | Samples: 90/500 | Reward (t=136.0): 0.6836 | Reward (Avg): 0.8837 | CLIP: 26.7986 | Aesthetic: 6.0132 | PickScore: 22.0176 | HPSv2: 0.2825 | HPSv2.1: 0.2952 | ImageReward: 0.9975
|
| 99 |
+
|
| 100 |
+
[gradient_ascent] Batch 100/500 | Samples: 100/500 | Reward (t=136.0): 0.5234 | Reward (Avg): 0.8824 | CLIP: 27.0476 | Aesthetic: 6.0459 | PickScore: 21.9703 | HPSv2: 0.2815 | HPSv2.1: 0.2932 | ImageReward: 0.9856
|
| 101 |
+
|
| 102 |
+
[gradient_ascent] Batch 110/500 | Samples: 110/500 | Reward (t=136.0): 0.8086 | Reward (Avg): 0.8786 | CLIP: 27.2317 | Aesthetic: 6.0274 | PickScore: 21.9682 | HPSv2: 0.2820 | HPSv2.1: 0.2935 | ImageReward: 0.9922
|
| 103 |
+
|
| 104 |
+
[gradient_ascent] Batch 120/500 | Samples: 120/500 | Reward (t=136.0): 0.5312 | Reward (Avg): 0.8745 | CLIP: 27.0959 | Aesthetic: 6.0128 | PickScore: 21.9585 | HPSv2: 0.2817 | HPSv2.1: 0.2935 | ImageReward: 1.0282
|
| 105 |
+
|
| 106 |
+
[gradient_ascent] Batch 130/500 | Samples: 130/500 | Reward (t=136.0): 0.8984 | Reward (Avg): 0.8748 | CLIP: 27.1302 | Aesthetic: 6.0250 | PickScore: 22.0331 | HPSv2: 0.2825 | HPSv2.1: 0.2944 | ImageReward: 1.0238
|
| 107 |
+
|
| 108 |
+
[gradient_ascent] Batch 140/500 | Samples: 140/500 | Reward (t=136.0): 0.8438 | Reward (Avg): 0.8711 | CLIP: 27.3238 | Aesthetic: 5.9973 | PickScore: 22.0136 | HPSv2: 0.2822 | HPSv2.1: 0.2935 | ImageReward: 1.0193
|
| 109 |
+
|
| 110 |
+
[gradient_ascent] Batch 150/500 | Samples: 150/500 | Reward (t=136.0): 0.9844 | Reward (Avg): 0.8711 | CLIP: 27.2450 | Aesthetic: 5.9757 | PickScore: 21.9954 | HPSv2: 0.2817 | HPSv2.1: 0.2922 | ImageReward: 1.0077
|
| 111 |
+
|
| 112 |
+
[gradient_ascent] Batch 160/500 | Samples: 160/500 | Reward (t=136.0): 0.8828 | Reward (Avg): 0.8731 | CLIP: 27.3797 | Aesthetic: 5.9691 | PickScore: 22.0329 | HPSv2: 0.2817 | HPSv2.1: 0.2922 | ImageReward: 1.0169
|
| 113 |
+
|
| 114 |
+
[gradient_ascent] Batch 170/500 | Samples: 170/500 | Reward (t=136.0): 0.8477 | Reward (Avg): 0.8723 | CLIP: 27.1880 | Aesthetic: 5.9651 | PickScore: 21.9791 | HPSv2: 0.2812 | HPSv2.1: 0.2910 | ImageReward: 1.0030
|
| 115 |
+
|
| 116 |
+
[gradient_ascent] Batch 180/500 | Samples: 180/500 | Reward (t=136.0): 0.8906 | Reward (Avg): 0.8697 | CLIP: 27.2816 | Aesthetic: 5.9708 | PickScore: 21.9840 | HPSv2: 0.2815 | HPSv2.1: 0.2917 | ImageReward: 1.0276
|
| 117 |
+
|
| 118 |
+
[gradient_ascent] Batch 190/500 | Samples: 190/500 | Reward (t=136.0): 0.8672 | Reward (Avg): 0.8692 | CLIP: 27.2303 | Aesthetic: 5.9694 | PickScore: 21.9639 | HPSv2: 0.2810 | HPSv2.1: 0.2910 | ImageReward: 1.0080
|
| 119 |
+
|
| 120 |
+
[gradient_ascent] Batch 200/500 | Samples: 200/500 | Reward (t=136.0): 0.8281 | Reward (Avg): 0.8684 | CLIP: 27.2174 | Aesthetic: 5.9641 | PickScore: 21.9716 | HPSv2: 0.2815 | HPSv2.1: 0.2917 | ImageReward: 1.0208
|
| 121 |
+
|
| 122 |
+
[gradient_ascent] Batch 210/500 | Samples: 210/500 | Reward (t=136.0): 0.8789 | Reward (Avg): 0.8703 | CLIP: 27.1992 | Aesthetic: 5.9754 | PickScore: 21.9724 | HPSv2: 0.2812 | HPSv2.1: 0.2925 | ImageReward: 1.0306
|
| 123 |
+
|
| 124 |
+
[gradient_ascent] Batch 220/500 | Samples: 220/500 | Reward (t=136.0): 0.9375 | Reward (Avg): 0.8702 | CLIP: 27.2434 | Aesthetic: 5.9779 | PickScore: 21.9538 | HPSv2: 0.2815 | HPSv2.1: 0.2922 | ImageReward: 1.0132
|
| 125 |
+
|
| 126 |
+
[gradient_ascent] Batch 230/500 | Samples: 230/500 | Reward (t=136.0): 0.6250 | Reward (Avg): 0.8715 | CLIP: 27.2932 | Aesthetic: 5.9839 | PickScore: 21.9659 | HPSv2: 0.2812 | HPSv2.1: 0.2920 | ImageReward: 0.9930
|
| 127 |
+
|
| 128 |
+
[gradient_ascent] Batch 240/500 | Samples: 240/500 | Reward (t=136.0): 0.8594 | Reward (Avg): 0.8721 | CLIP: 27.1944 | Aesthetic: 5.9938 | PickScore: 21.9538 | HPSv2: 0.2808 | HPSv2.1: 0.2913 | ImageReward: 0.9821
|
| 129 |
+
|
| 130 |
+
[gradient_ascent] Batch 250/500 | Samples: 250/500 | Reward (t=136.0): 0.8047 | Reward (Avg): 0.8723 | CLIP: 27.1940 | Aesthetic: 5.9992 | PickScore: 21.9601 | HPSv2: 0.2810 | HPSv2.1: 0.2913 | ImageReward: 0.9797
|
| 131 |
+
|
| 132 |
+
[gradient_ascent] Batch 260/500 | Samples: 260/500 | Reward (t=136.0): 0.9297 | Reward (Avg): 0.8711 | CLIP: 27.1640 | Aesthetic: 6.0070 | PickScore: 21.9547 | HPSv2: 0.2810 | HPSv2.1: 0.2915 | ImageReward: 0.9918
|
| 133 |
+
|
| 134 |
+
[gradient_ascent] Batch 270/500 | Samples: 270/500 | Reward (t=136.0): 0.9961 | Reward (Avg): 0.8714 | CLIP: 27.1096 | Aesthetic: 6.0077 | PickScore: 21.9464 | HPSv2: 0.2810 | HPSv2.1: 0.2913 | ImageReward: 0.9900
|
| 135 |
+
|
| 136 |
+
[gradient_ascent] Batch 280/500 | Samples: 280/500 | Reward (t=136.0): 0.9258 | Reward (Avg): 0.8715 | CLIP: 27.1485 | Aesthetic: 6.0107 | PickScore: 21.9772 | HPSv2: 0.2812 | HPSv2.1: 0.2920 | ImageReward: 1.0051
|
| 137 |
+
|
| 138 |
+
[gradient_ascent] Batch 290/500 | Samples: 290/500 | Reward (t=136.0): 0.9062 | Reward (Avg): 0.8722 | CLIP: 27.0509 | Aesthetic: 5.9991 | PickScore: 21.9401 | HPSv2: 0.2810 | HPSv2.1: 0.2913 | ImageReward: 0.9888
|
| 139 |
+
|
| 140 |
+
[gradient_ascent] Batch 300/500 | Samples: 300/500 | Reward (t=136.0): 0.9844 | Reward (Avg): 0.8718 | CLIP: 27.0241 | Aesthetic: 6.0017 | PickScore: 21.9507 | HPSv2: 0.2812 | HPSv2.1: 0.2920 | ImageReward: 0.9941
|
| 141 |
+
|
| 142 |
+
[gradient_ascent] Batch 310/500 | Samples: 310/500 | Reward (t=136.0): 0.9375 | Reward (Avg): 0.8709 | CLIP: 27.0985 | Aesthetic: 6.0055 | PickScore: 21.9663 | HPSv2: 0.2815 | HPSv2.1: 0.2925 | ImageReward: 1.0157
|
| 143 |
+
|
| 144 |
+
[gradient_ascent] Batch 320/500 | Samples: 320/500 | Reward (t=136.0): 0.9492 | Reward (Avg): 0.8703 | CLIP: 27.1293 | Aesthetic: 6.0013 | PickScore: 21.9831 | HPSv2: 0.2815 | HPSv2.1: 0.2922 | ImageReward: 1.0116
|
| 145 |
+
|
| 146 |
+
[gradient_ascent] Batch 330/500 | Samples: 330/500 | Reward (t=136.0): 0.9609 | Reward (Avg): 0.8704 | CLIP: 27.1128 | Aesthetic: 6.0068 | PickScore: 21.9762 | HPSv2: 0.2815 | HPSv2.1: 0.2925 | ImageReward: 1.0155
|
| 147 |
+
|
| 148 |
+
[gradient_ascent] Batch 340/500 | Samples: 340/500 | Reward (t=136.0): 0.9648 | Reward (Avg): 0.8709 | CLIP: 27.1100 | Aesthetic: 6.0010 | PickScore: 21.9779 | HPSv2: 0.2812 | HPSv2.1: 0.2927 | ImageReward: 1.0264
|
| 149 |
+
|
| 150 |
+
[gradient_ascent] Batch 350/500 | Samples: 350/500 | Reward (t=136.0): 0.7969 | Reward (Avg): 0.8694 | CLIP: 27.1092 | Aesthetic: 5.9996 | PickScore: 21.9887 | HPSv2: 0.2812 | HPSv2.1: 0.2927 | ImageReward: 1.0177
|
| 151 |
+
|
| 152 |
+
[gradient_ascent] Batch 360/500 | Samples: 360/500 | Reward (t=136.0): 0.7461 | Reward (Avg): 0.8684 | CLIP: 27.1495 | Aesthetic: 5.9991 | PickScore: 21.9878 | HPSv2: 0.2815 | HPSv2.1: 0.2930 | ImageReward: 1.0115
|
| 153 |
+
|
| 154 |
+
[gradient_ascent] Batch 370/500 | Samples: 370/500 | Reward (t=136.0): 0.9336 | Reward (Avg): 0.8669 | CLIP: 27.0887 | Aesthetic: 5.9964 | PickScore: 21.9767 | HPSv2: 0.2815 | HPSv2.1: 0.2925 | ImageReward: 1.0073
|
| 155 |
+
|
| 156 |
+
[gradient_ascent] Batch 380/500 | Samples: 380/500 | Reward (t=136.0): 0.6836 | Reward (Avg): 0.8673 | CLIP: 27.1722 | Aesthetic: 6.0026 | PickScore: 21.9784 | HPSv2: 0.2812 | HPSv2.1: 0.2927 | ImageReward: 1.0130
|
| 157 |
+
|
| 158 |
+
[gradient_ascent] Batch 390/500 | Samples: 390/500 | Reward (t=136.0): 0.9688 | Reward (Avg): 0.8669 | CLIP: 27.0503 | Aesthetic: 5.9985 | PickScore: 21.9617 | HPSv2: 0.2810 | HPSv2.1: 0.2920 | ImageReward: 1.0016
|
| 159 |
+
|
| 160 |
+
[gradient_ascent] Batch 400/500 | Samples: 400/500 | Reward (t=136.0): 0.6484 | Reward (Avg): 0.8664 | CLIP: 27.0083 | Aesthetic: 5.9924 | PickScore: 21.9558 | HPSv2: 0.2810 | HPSv2.1: 0.2920 | ImageReward: 0.9993
|
| 161 |
+
|
| 162 |
+
[gradient_ascent] Batch 410/500 | Samples: 410/500 | Reward (t=136.0): 0.4082 | Reward (Avg): 0.8651 | CLIP: 26.9581 | Aesthetic: 5.9885 | PickScore: 21.9445 | HPSv2: 0.2808 | HPSv2.1: 0.2913 | ImageReward: 0.9801
|
| 163 |
+
|
| 164 |
+
[gradient_ascent] Batch 420/500 | Samples: 420/500 | Reward (t=136.0): 0.6602 | Reward (Avg): 0.8630 | CLIP: 26.9090 | Aesthetic: 5.9864 | PickScore: 21.9316 | HPSv2: 0.2805 | HPSv2.1: 0.2910 | ImageReward: 0.9757
|
| 165 |
+
|
| 166 |
+
[gradient_ascent] Batch 430/500 | Samples: 430/500 | Reward (t=136.0): 0.8164 | Reward (Avg): 0.8628 | CLIP: 26.9218 | Aesthetic: 5.9813 | PickScore: 21.9366 | HPSv2: 0.2805 | HPSv2.1: 0.2910 | ImageReward: 0.9807
|
| 167 |
+
|
| 168 |
+
[gradient_ascent] Batch 440/500 | Samples: 440/500 | Reward (t=136.0): 0.8516 | Reward (Avg): 0.8632 | CLIP: 26.8428 | Aesthetic: 5.9854 | PickScore: 21.9359 | HPSv2: 0.2805 | HPSv2.1: 0.2910 | ImageReward: 0.9783
|
| 169 |
+
|
| 170 |
+
[gradient_ascent] Batch 450/500 | Samples: 450/500 | Reward (t=136.0): 0.7812 | Reward (Avg): 0.8623 | CLIP: 26.8449 | Aesthetic: 5.9793 | PickScore: 21.9350 | HPSv2: 0.2803 | HPSv2.1: 0.2908 | ImageReward: 0.9707
|
| 171 |
+
|
| 172 |
+
[gradient_ascent] Batch 460/500 | Samples: 460/500 | Reward (t=136.0): 0.8086 | Reward (Avg): 0.8627 | CLIP: 26.8162 | Aesthetic: 5.9783 | PickScore: 21.9268 | HPSv2: 0.2803 | HPSv2.1: 0.2903 | ImageReward: 0.9679
|
| 173 |
+
|
| 174 |
+
[gradient_ascent] Batch 470/500 | Samples: 470/500 | Reward (t=136.0): 0.6836 | Reward (Avg): 0.8619 | CLIP: 26.7554 | Aesthetic: 5.9671 | PickScore: 21.9094 | HPSv2: 0.2798 | HPSv2.1: 0.2898 | ImageReward: 0.9538
|
| 175 |
+
|
| 176 |
+
[gradient_ascent] Batch 480/500 | Samples: 480/500 | Reward (t=136.0): 0.9141 | Reward (Avg): 0.8607 | CLIP: 26.6819 | Aesthetic: 5.9631 | PickScore: 21.8992 | HPSv2: 0.2798 | HPSv2.1: 0.2896 | ImageReward: 0.9561
|
| 177 |
+
|
| 178 |
+
[gradient_ascent] Batch 490/500 | Samples: 490/500 | Reward (t=136.0): 0.8906 | Reward (Avg): 0.8607 | CLIP: 26.6874 | Aesthetic: 5.9629 | PickScore: 21.8907 | HPSv2: 0.2795 | HPSv2.1: 0.2893 | ImageReward: 0.9554
|
| 179 |
+
|
| 180 |
+
[gradient_ascent] Batch 500/500 | Samples: 500/500 | Reward (t=136.0): 0.9102 | Reward (Avg): 0.8612 | CLIP: 26.6889 | Aesthetic: 5.9645 | PickScore: 21.8882 | HPSv2: 0.2798 | HPSv2.1: 0.2896 | ImageReward: 0.9575
|
| 181 |
+
✓ Gradient Ascent Avg Reward: 0.8612
|
| 182 |
+
✓ Gradient Ascent Avg CLIP Score: 26.6889
|
| 183 |
+
✓ Gradient Ascent Avg Aesthetic Score: 5.9645
|
| 184 |
+
✓ Gradient Ascent Avg PickScore: 21.8882
|
| 185 |
+
✓ Gradient Ascent Avg HPSv2 Score: 0.2798
|
| 186 |
+
✓ Gradient Ascent Avg HPSv2.1 Score: 0.2896
|
| 187 |
+
✓ Gradient Ascent Avg ImageReward: 0.9575
|
| 188 |
+
|
| 189 |
+
Gradient Ascent Statistics:
|
| 190 |
+
Applications: 11
|
| 191 |
+
Total reward improvement: +0.3047
|
| 192 |
+
Avg reward improvement: +0.0277
|
| 193 |
+
|
| 194 |
+
✓ Saved LR curve plot to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/lr_curve.png
|
| 195 |
+
Total gradient steps: 11
|
| 196 |
+
LR range: 1.000000 → 1.000000
|
| 197 |
+
|
| 198 |
+
✓ Saved Rewards curve plot to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/rewards_curve.png
|
| 199 |
+
Total gradient steps: 22
|
| 200 |
+
Reward range: 0.9727 → 0.9961
|
| 201 |
+
Total improvement: +0.0234
|
| 202 |
+
|
| 203 |
+
======================================================================
|
| 204 |
+
FINAL RESULTS
|
| 205 |
+
======================================================================
|
| 206 |
+
|
| 207 |
+
Gradient Ascent:
|
| 208 |
+
Avg Reward: 0.8612
|
| 209 |
+
Avg CLIP Score: 26.6889
|
| 210 |
+
Avg Aesthetic: 5.9645
|
| 211 |
+
Avg PickScore: 21.8882
|
| 212 |
+
Avg HPSv2: 0.2798
|
| 213 |
+
Avg HPSv2.1: 0.2896
|
| 214 |
+
Avg ImageReward: 0.9575
|
| 215 |
+
|
| 216 |
+
✓ Results saved to: RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/evaluation_results.txt
|
| 217 |
+
|
| 218 |
+
======================================================================
|
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/lr_curve.png
ADDED
|
Reward_sana_idealized/RESULTS/pickapic/one_step_rectification_config_sana_600m_512/run_4/rewards_curve.png
ADDED
|
Reward_sana_idealized/__pycache__/eval.cpython-311.pyc
ADDED
|
Binary file (75.8 kB). View file
|
|
|
Reward_sana_idealized/__pycache__/gradient_ascent_utils.cpython-311.pyc
ADDED
|
Binary file (17.1 kB). View file
|
|
|
Reward_sana_idealized/blip/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .blip_pretrain import *
|
Reward_sana_idealized/blip/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (202 Bytes). View file
|
|
|
Reward_sana_idealized/blip/__pycache__/blip.cpython-311.pyc
ADDED
|
Binary file (4.03 kB). View file
|
|
|
Reward_sana_idealized/blip/__pycache__/blip_pretrain.cpython-311.pyc
ADDED
|
Binary file (2.35 kB). View file
|
|
|
Reward_sana_idealized/blip/__pycache__/med.cpython-311.pyc
ADDED
|
Binary file (46.9 kB). View file
|
|
|
Reward_sana_idealized/blip/blip.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
| 3 |
+
'''
|
| 4 |
+
|
| 5 |
+
import warnings
|
| 6 |
+
warnings.filterwarnings("ignore")
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import os
|
| 10 |
+
from urllib.parse import urlparse
|
| 11 |
+
from timm.models.hub import download_cached_file
|
| 12 |
+
from transformers import BertTokenizer
|
| 13 |
+
from .vit import VisionTransformer, interpolate_pos_embed
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def init_tokenizer():
    """Create the BERT tokenizer used by BLIP, extended with its special tokens.

    Loads 'bert-base-uncased' (downloading it on first use), registers the
    [DEC] token as the decoder bos token and [ENC] as an additional special
    token, and stores the [ENC] token id on the tokenizer as ``enc_token_id``.

    Returns:
        The configured ``BertTokenizer`` instance.
    """
    tok = BertTokenizer.from_pretrained('bert-base-uncased')
    # BLIP uses [DEC] to start decoder sequences and [ENC] for encoder inputs.
    tok.add_special_tokens({'bos_token':'[DEC]'})
    tok.add_special_tokens({'additional_special_tokens':['[ENC]']})
    tok.enc_token_id = tok.additional_special_tokens_ids[0]
    return tok
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
    """Build the BLIP vision-transformer backbone.

    Args:
        vit: model size, either 'base' or 'large'.
        image_size: square input image resolution.
        use_grad_checkpointing: enable gradient checkpointing in the encoder.
        ckpt_layer: first layer index that gradient checkpointing applies to.
        drop_path_rate: stochastic-depth rate. Only honored for 'base';
            see the NOTE in the 'large' branch.

    Returns:
        (visual_encoder, vision_width): the encoder module and its embed dim.

    Raises:
        AssertionError: if ``vit`` is not 'base' or 'large'.
    """
    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit == 'base':
        vision_width = 768
        # Upstream BLIP wrote `drop_path_rate=0 or drop_path_rate`; since
        # `0 or x` always evaluates to `x`, the simplified form below is
        # behaviorally identical.
        visual_encoder = VisionTransformer(
            img_size=image_size, patch_size=16, embed_dim=vision_width,
            depth=12, num_heads=12,
            use_grad_checkpointing=use_grad_checkpointing,
            ckpt_layer=ckpt_layer,
            drop_path_rate=drop_path_rate,
        )
    elif vit == 'large':
        vision_width = 1024
        # NOTE(review): upstream wrote `drop_path_rate=0.1 or drop_path_rate`,
        # which always evaluates to 0.1 — the caller's drop_path_rate is
        # silently ignored for 'large'. Preserved as-is to stay compatible
        # with pretrained BLIP checkpoints.
        visual_encoder = VisionTransformer(
            img_size=image_size, patch_size=16, embed_dim=vision_width,
            depth=24, num_heads=16,
            use_grad_checkpointing=use_grad_checkpointing,
            ckpt_layer=ckpt_layer,
            drop_path_rate=0.1,
        )
    return visual_encoder, vision_width
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def is_url(url_or_filename):
    """Return True when the argument parses as an http(s) URL."""
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
|
| 45 |
+
|
| 46 |
+
def load_checkpoint(model, url_or_filename):
    """Load a BLIP checkpoint into ``model`` from a URL or a local file.

    Position embeddings are interpolated to match the model's resolution,
    and any remaining shape-mismatched tensors are printed and dropped
    before a non-strict ``load_state_dict``.

    Args:
        model: the BLIP module to receive the weights (mutated in place).
        url_or_filename: http(s) URL or local path of the checkpoint.

    Returns:
        (model, msg): the model and the ``load_state_dict`` result listing
        missing/unexpected keys.

    Raises:
        RuntimeError: if the argument is neither a URL nor an existing file.
    """
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    # Resize checkpoint position embeddings to the model's patch grid.
    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(
        state_dict['visual_encoder.pos_embed'], model.visual_encoder)
    # Snapshot the model state once; the original rebuilt it on every loop
    # iteration, which is O(params) per key.
    model_state = model.state_dict()
    if 'visual_encoder_m.pos_embed' in model_state:
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(
            state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m)

    # Drop checkpoint tensors whose shapes no longer match the model.
    for key in model_state.keys():
        if key in state_dict and state_dict[key].shape != model_state[key].shape:
            print(key, ": ", state_dict[key].shape, ', ', model_state[key].shape)
            del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg
|
| 70 |
+
|
Reward_sana_idealized/blip/blip_pretrain.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
| 3 |
+
'''
|
| 4 |
+
|
| 5 |
+
import transformers
|
| 6 |
+
transformers.logging.set_verbosity_error()
|
| 7 |
+
|
| 8 |
+
from torch import nn
|
| 9 |
+
import os
|
| 10 |
+
from .med import BertConfig, BertModel
|
| 11 |
+
from .blip import create_vit, init_tokenizer
|
| 12 |
+
|
| 13 |
+
class BLIP_Pretrain(nn.Module):
    # Vision-language pretraining model: a ViT image encoder plus a BERT-based
    # text encoder, each followed by a linear projection into a shared
    # embedding space of size `embed_dim`.
    def __init__(self,
                 med_config = "med_config.json",
                 image_size = 224,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 embed_dim = 256,
                 queue_size = 57600,
                 momentum = 0.995,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        # NOTE(review): queue_size and momentum are accepted but not used in
        # this constructor — presumably kept for signature compatibility with
        # the original BLIP pretraining code; confirm before removing.
        super().__init__()

        # Build the ViT backbone; final arg 0 is the drop_path_rate.
        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)

        self.tokenizer = init_tokenizer()
        # Text encoder reads the BERT config from `med_config` and is told the
        # vision feature width so cross-attention dimensions line up.
        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size

        # Project both modalities into the shared contrastive embedding space.
        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)
|
| 43 |
+
|
Reward_sana_idealized/config_analysis_tuning.ipynb
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "a24d02a2",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import json\n",
|
| 11 |
+
"import pandas as pd\n",
|
| 12 |
+
"import numpy as np\n",
|
| 13 |
+
"from pathlib import Path\n",
|
| 14 |
+
"from datetime import datetime\n",
|
| 15 |
+
"import warnings\n",
|
| 16 |
+
"warnings.filterwarnings('ignore')\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"# ============================================================================\n",
|
| 19 |
+
"# SECTION 1: Load and Parse Results from GPU Tuning Runs\n",
|
| 20 |
+
"# ============================================================================\n",
|
| 21 |
+
"print(\"=\" * 80)\n",
|
| 22 |
+
"print(\"LOADING TUNING RESULTS FROM GPU RUNS\")\n",
|
| 23 |
+
"print(\"=\" * 80)\n",
|
| 24 |
+
"\n",
|
| 25 |
+
"results_dir = Path(\"RESULTS_TURNING/run_2\")\n",
|
| 26 |
+
"all_experiments = []\n",
|
| 27 |
+
"baseline_metrics = None\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"# Collect results from all GPU runs\n",
|
| 30 |
+
"for gpu_id in range(8):\n",
|
| 31 |
+
" gpu_dir = results_dir / f\"gpu_{gpu_id}\"\n",
|
| 32 |
+
" results_file = gpu_dir / \"tuning_results.json\"\n",
|
| 33 |
+
" \n",
|
| 34 |
+
" if results_file.exists():\n",
|
| 35 |
+
" with open(results_file, 'r') as f:\n",
|
| 36 |
+
" data = json.load(f)\n",
|
| 37 |
+
" \n",
|
| 38 |
+
" # Extract baseline (same across all GPUs)\n",
|
| 39 |
+
" if baseline_metrics is None and \"baseline\" in data:\n",
|
| 40 |
+
" baseline_metrics = data[\"baseline\"][\"metrics\"]\n",
|
| 41 |
+
" print(f\"\\n📊 Baseline Metrics (cfg_scale=5.0):\")\n",
|
| 42 |
+
" for metric, value in baseline_metrics.items():\n",
|
| 43 |
+
" print(f\" {metric:15s}: {value:.6f}\")\n",
|
| 44 |
+
" \n",
|
| 45 |
+
" # Collect all experiments\n",
|
| 46 |
+
" if \"experiments\" in data:\n",
|
| 47 |
+
" all_experiments.extend(data[\"experiments\"])\n",
|
| 48 |
+
" print(f\"✓ GPU {gpu_id}: {len(data['experiments'])} results loaded\")\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"print(f\"\\n✓ Total experiments loaded: {len(all_experiments)}\")\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"# ============================================================================\n",
|
| 53 |
+
"# SECTION 2: Filter Top Configs with Improvements Across All Metrics\n",
|
| 54 |
+
"# ============================================================================\n",
|
| 55 |
+
"print(\"\\n\" + \"=\" * 80)\n",
|
| 56 |
+
"print(\"FILTERING CONFIGURATIONS WITH IMPROVEMENTS IN ALL METRICS\")\n",
|
| 57 |
+
"print(\"=\" * 80)\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"# Define improvement metrics to track (using ImageReward instead of Reward)\n",
|
| 60 |
+
"improvement_metrics = [\n",
|
| 61 |
+
" \"aesthetic_improvement\", \n",
|
| 62 |
+
" \"imagereward_improvement\", \n",
|
| 63 |
+
" \"clip_improvement\", \n",
|
| 64 |
+
" \"pickscore_improvement\", \n",
|
| 65 |
+
" \"hpsv2_improvement\"\n",
|
| 66 |
+
" ]\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"# Filter experiments with improvements in ALL metrics\n",
|
| 69 |
+
"top_configs = []\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"for exp in all_experiments:\n",
|
| 72 |
+
" if \"improvements\" not in exp or \"config\" not in exp or \"metrics\" not in exp:\n",
|
| 73 |
+
" continue\n",
|
| 74 |
+
" \n",
|
| 75 |
+
" improvements = exp[\"improvements\"]\n",
|
| 76 |
+
" config = exp[\"config\"]\n",
|
| 77 |
+
" metrics = exp[\"metrics\"]\n",
|
| 78 |
+
" \n",
|
| 79 |
+
" # Check if ALL improvements are positive (>0)\n",
|
| 80 |
+
" all_positive = all(improvements.get(metric, -1) > 0 for metric in improvement_metrics)\n",
|
| 81 |
+
" \n",
|
| 82 |
+
" if all_positive:\n",
|
| 83 |
+
" # Calculate aggregate improvement score\n",
|
| 84 |
+
" avg_improvement = np.mean([improvements.get(metric, 0) for metric in improvement_metrics])\n",
|
| 85 |
+
" \n",
|
| 86 |
+
" top_configs.append({\n",
|
| 87 |
+
" \"config\": config,\n",
|
| 88 |
+
" \"metrics\": metrics,\n",
|
| 89 |
+
" \"improvements\": improvements,\n",
|
| 90 |
+
" \"avg_improvement\": avg_improvement\n",
|
| 91 |
+
" })\n",
|
| 92 |
+
"\n",
|
| 93 |
+
"print(f\"✓ Found {len(top_configs)} configurations with improvements in ALL metrics\")\n",
|
| 94 |
+
"\n",
|
| 95 |
+
"# Sort by average improvement\n",
|
| 96 |
+
"top_configs.sort(key=lambda x: x[\"avg_improvement\"], reverse=True)\n",
|
| 97 |
+
"\n",
|
| 98 |
+
"# Get top 10\n",
|
| 99 |
+
"top_10 = top_configs[:10]\n",
|
| 100 |
+
"print(f\"✓ Extracted top 10 best performing configurations\")\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"# ============================================================================\n",
|
| 103 |
+
"# SECTION 3: Create Comprehensive Results Table\n",
|
| 104 |
+
"# ============================================================================\n",
|
| 105 |
+
"print(\"\\n\" + \"=\" * 80)\n",
|
| 106 |
+
"print(\"CREATING COMPREHENSIVE RESULTS TABLE\")\n",
|
| 107 |
+
"print(\"=\" * 80)\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"# Build detailed table data\n",
|
| 110 |
+
"table_data = []\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"for rank, result in enumerate(top_10, 1):\n",
|
| 113 |
+
" cfg = result[\"config\"]\n",
|
| 114 |
+
" metrics = result[\"metrics\"]\n",
|
| 115 |
+
" improvements = result[\"improvements\"]\n",
|
| 116 |
+
" \n",
|
| 117 |
+
" row = {\n",
|
| 118 |
+
" \"Rank\": rank,\n",
|
| 119 |
+
" \"CFG Scale\": cfg.get(\"cfg_scale\", \"N/A\"),\n",
|
| 120 |
+
" \"Grad Config\": cfg.get(\"grad_config\", \"N/A\"),\n",
|
| 121 |
+
" \"Steps\": cfg.get(\"num_grad_steps\", \"N/A\"),\n",
|
| 122 |
+
" \"LR\": cfg.get(\"grad_step_size\", \"N/A\"),\n",
|
| 123 |
+
" \"Momentum\": cfg.get(\"momentum\", \"N/A\"),\n",
|
| 124 |
+
" \"ImageReward\": f\"{metrics.get('imagereward', 0):.6f}\",\n",
|
| 125 |
+
" \"ImageReward ↑\": f\"{improvements.get('imagereward_improvement', 0):+.2f}%\",\n",
|
| 126 |
+
" \"CLIP\": f\"{metrics.get('clip', 0):.4f}\",\n",
|
| 127 |
+
" \"CLIP ↑\": f\"{improvements.get('clip_improvement', 0):+.2f}%\",\n",
|
| 128 |
+
" \"Aesthetic\": f\"{metrics.get('aesthetic', 0):.4f}\",\n",
|
| 129 |
+
" \"Aesthetic ↑\": f\"{improvements.get('aesthetic_improvement', 0):+.2f}%\",\n",
|
| 130 |
+
" \"PickScore\": f\"{metrics.get('pickscore', 0):.4f}\",\n",
|
| 131 |
+
" \"PickScore ↑\": f\"{improvements.get('pickscore_improvement', 0):+.2f}%\",\n",
|
| 132 |
+
" \"HPSv2\": f\"{metrics.get('hpsv2', 0):.4f}\",\n",
|
| 133 |
+
" \"HPSv2 ↑\": f\"{improvements.get('hpsv2_improvement', 0):+.2f}%\",\n",
|
| 134 |
+
" \"Avg Improvement\": f\"{result['avg_improvement']:+.2f}%\",\n",
|
| 135 |
+
" }\n",
|
| 136 |
+
" \n",
|
| 137 |
+
" table_data.append(row)\n",
|
| 138 |
+
"\n",
|
| 139 |
+
"df_top_10 = pd.DataFrame(table_data)\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"print(\"\\n📋 TOP 10 CONFIGURATIONS WITH IMPROVEMENTS IN ALL METRICS:\")\n",
|
| 142 |
+
"print(\"=\" * 180)\n",
|
| 143 |
+
"print(df_top_10.to_string(index=False))\n",
|
| 144 |
+
"print(\"=\" * 180)\n",
|
| 145 |
+
"\n",
|
| 146 |
+
"# ============================================================================\n",
|
| 147 |
+
"# SECTION 4: Visualize and Summary Statistics\n",
|
| 148 |
+
"# ============================================================================\n",
|
| 149 |
+
"print(\"\\n\" + \"=\" * 80)\n",
|
| 150 |
+
"print(\"SUMMARY STATISTICS\")\n",
|
| 151 |
+
"print(\"=\" * 80)\n",
|
| 152 |
+
"\n",
|
| 153 |
+
"# Extract numeric improvement values for analysis\n",
|
| 154 |
+
"improvement_summary = []\n",
|
| 155 |
+
"for result in top_10:\n",
|
| 156 |
+
" improvements = result[\"improvements\"]\n",
|
| 157 |
+
" for metric in [\"imagereward_improvement\", \"clip_improvement\", \"aesthetic_improvement\", \n",
|
| 158 |
+
" \"pickscore_improvement\", \"hpsv2_improvement\"]:\n",
|
| 159 |
+
" metric_name = metric.replace(\"_improvement\", \"\").upper()\n",
|
| 160 |
+
" improvement_summary.append({\n",
|
| 161 |
+
" \"Metric\": metric_name,\n",
|
| 162 |
+
" \"Improvement %\": improvements.get(metric, 0)\n",
|
| 163 |
+
" })\n",
|
| 164 |
+
"\n",
|
| 165 |
+
"df_summary = pd.DataFrame(improvement_summary)\n",
|
| 166 |
+
"\n",
|
| 167 |
+
"print(\"\\n📊 Average Improvements by Metric (Top 10):\")\n",
|
| 168 |
+
"metric_stats = df_summary.groupby(\"Metric\")[\"Improvement %\"].agg([\"mean\", \"std\", \"min\", \"max\"])\n",
|
| 169 |
+
"print(metric_stats.round(2))\n",
|
| 170 |
+
"\n",
|
| 171 |
+
"print(\"\\n📈 Best Configuration Details:\")\n",
|
| 172 |
+
"best = top_10[0]\n",
|
| 173 |
+
"best_cfg = best[\"config\"]\n",
|
| 174 |
+
"best_metrics = best[\"metrics\"]\n",
|
| 175 |
+
"best_improvements = best[\"improvements\"]\n",
|
| 176 |
+
"\n",
|
| 177 |
+
"print(f\"\\n✓ RANK #1 - Best Performing Configuration:\")\n",
|
| 178 |
+
"print(f\" Configuration:\")\n",
|
| 179 |
+
"print(f\" • CFG Scale: {best_cfg.get('cfg_scale')}\")\n",
|
| 180 |
+
"print(f\" • Gradient Config: {best_cfg.get('grad_config')}\")\n",
|
| 181 |
+
"print(f\" • Gradient Steps: {best_cfg.get('num_grad_steps')}\")\n",
|
| 182 |
+
"print(f\" • Step Size: {best_cfg.get('grad_step_size')}\")\n",
|
| 183 |
+
"print(f\" • Momentum: {best_cfg.get('momentum')}\")\n",
|
| 184 |
+
"print(f\"\\n Metrics:\")\n",
|
| 185 |
+
"for metric in [\"imagereward\", \"clip\", \"aesthetic\", \"pickscore\", \"hpsv2\"]:\n",
|
| 186 |
+
" baseline_val = baseline_metrics.get(metric, 0)\n",
|
| 187 |
+
" current_val = best_metrics.get(metric, 0)\n",
|
| 188 |
+
" improvement = best_improvements.get(f\"{metric}_improvement\", 0)\n",
|
| 189 |
+
" print(f\" • {metric:12s}: {current_val:8.6f} (baseline: {baseline_val:8.6f}) ↑ {improvement:+6.2f}%\")\n",
|
| 190 |
+
"\n",
|
| 191 |
+
"print(\"\\n\" + \"=\" * 80)\n",
|
| 192 |
+
"print(\"✓ ANALYSIS COMPLETE - TOP 10 CONFIGURATIONS IDENTIFIED\")\n",
|
| 193 |
+
"print(\"=\" * 80)"
|
| 194 |
+
]
|
| 195 |
+
}
|
| 196 |
+
],
|
| 197 |
+
"metadata": {
|
| 198 |
+
"kernelspec": {
|
| 199 |
+
"display_name": "Python 3",
|
| 200 |
+
"language": "python",
|
| 201 |
+
"name": "python3"
|
| 202 |
+
},
|
| 203 |
+
"language_info": {
|
| 204 |
+
"codemirror_mode": {
|
| 205 |
+
"name": "ipython",
|
| 206 |
+
"version": 3
|
| 207 |
+
},
|
| 208 |
+
"file_extension": ".py",
|
| 209 |
+
"mimetype": "text/x-python",
|
| 210 |
+
"name": "python",
|
| 211 |
+
"nbconvert_exporter": "python",
|
| 212 |
+
"pygments_lexer": "ipython3",
|
| 213 |
+
"version": "3.10.18"
|
| 214 |
+
}
|
| 215 |
+
},
|
| 216 |
+
"nbformat": 4,
|
| 217 |
+
"nbformat_minor": 5
|
| 218 |
+
}
|
Reward_sana_idealized/eval.py
ADDED
|
@@ -0,0 +1,1447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation script for comparing baseline and gradient ascent pipelines using multiple metrics.
|
| 3 |
+
|
| 4 |
+
This script evaluates both pipelines on COCO or Pick-a-Pic validation sets and computes
|
| 5 |
+
various preference and quality metrics.
|
| 6 |
+
"""
|
| 7 |
+
import warnings
|
| 8 |
+
warnings.filterwarnings("ignore")
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import logging
|
| 15 |
+
from glob import glob
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from PIL import Image
|
| 18 |
+
from diffusers import SanaPipeline
|
| 19 |
+
from models import LRMRewardModel
|
| 20 |
+
from pipelines.sana_gradient_ascent_pipeline import SanaGradientAscentPipeline
|
| 21 |
+
from torchmetrics.image.fid import FrechetInceptionDistance
|
| 22 |
+
from torchmetrics.multimodal import CLIPScore
|
| 23 |
+
from transformers import CLIPModel, CLIPProcessor
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
import numpy as np
|
| 26 |
+
import argparse
|
| 27 |
+
from datasets import load_dataset
|
| 28 |
+
from grad_ascent_configs import get_config, list_configs
|
| 29 |
+
import matplotlib.pyplot as plt
|
| 30 |
+
import matplotlib
|
| 31 |
+
matplotlib.use('Agg') # Use non-interactive backend
|
| 32 |
+
|
| 33 |
+
from huggingface_hub import hf_hub_download
|
| 34 |
+
|
| 35 |
+
import random
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Maps short SANA profile names (as used in configs/CLI) to their
# Hugging Face diffusers model repository ids.
SANA_PROFILE_TO_MODEL_ID = {
    "sana_600m_512": "Efficient-Large-Model/Sana_600M_512px_diffusers",
    "sana_1600m_512": "Efficient-Large-Model/Sana_1600M_512px_diffusers",
    "sana_sprint_0_6b_1024": "Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers",
    "sana_sprint_1_6b_1024": "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def configure_hf_runtime(hf_cache_dir=None, force_offline=False):
    """Set Hugging Face cache/offline environment for cluster-safe execution."""
    # Explicit argument wins; otherwise fall back to whichever cache env var is set.
    cache_dir = hf_cache_dir or os.getenv("HF_HUB_CACHE") or os.getenv("HUGGINGFACE_HUB_CACHE")
    if cache_dir:
        for var in ("HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE"):
            os.environ[var] = cache_dir
        # HF_HOME is conventionally the parent of the hub cache directory.
        os.environ["HF_HOME"] = os.path.dirname(cache_dir)

    # Offline mode can come from the caller or from the HF_HUB_OFFLINE env var.
    flag = os.getenv("HF_HUB_OFFLINE", "0").strip().lower()
    offline_enabled = bool(force_offline or flag in {"1", "true", "yes", "on"})
    if offline_enabled:
        # Force every HF subsystem into offline mode for air-gapped clusters.
        for var in ("HF_DATASETS_OFFLINE", "HF_METRICS_OFFLINE",
                    "HF_MODULES_OFFLINE", "TRANSFORMERS_OFFLINE",
                    "DIFFUSERS_OFFLINE", "HF_HUB_OFFLINE"):
            os.environ[var] = "1"

    return cache_dir, offline_enabled
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def resolve_default_lrm_model():
    """Prefer the local SANA reward checkpoint when available."""
    project_root = Path(__file__).resolve().parents[1]
    # Fixed on-disk location of the v8 SANA reward-model checkpoint.
    candidate = project_root.joinpath(
        "lrm",
        "lrm_sana",
        "logs",
        "v8",
        "reward_model",
        "step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951",
        "checkpoint-gstep76000",
    )
    # Fall back to an empty string when the checkpoint is not present.
    return str(candidate) if candidate.exists() else ""
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def load_pickapic_prompts(max_samples=None, cache_dir=None, offline=False):
    """Load Pick-a-Pic prompts with robust offline fallback to cached parquet shards.

    Tries online streaming first (unless ``offline``), then scans candidate
    HF hub cache roots for a cached ``pickapic_v1`` snapshot and reads its
    parquet shards directly.

    Args:
        max_samples: Optional cap on the number of prompts returned.
        cache_dir: Optional HF hub cache root checked before the env/default roots.
        offline: When True, skip the online streaming path entirely.

    Returns:
        List of caption strings.

    Raises:
        RuntimeError: If no cached parquet shards can be found in offline mode.
    """
    split = "validation_unique"

    if not offline:
        try:
            ds = load_dataset("pickapic-anonymous/pickapic_v1", split=split, streaming=True)
            prompts = []
            for i, sample in enumerate(ds):
                prompts.append(sample["caption"])
                if max_samples and i + 1 >= max_samples:
                    break
            return prompts
        except Exception as e:
            print(f"Warning: online streaming load failed ({e}). Trying cached offline parquet shards.")

    # Candidate cache roots, most specific first; de-duplicated preserving order.
    cache_candidates = []
    for p in [
        cache_dir,
        os.getenv("HF_HUB_CACHE"),
        os.getenv("HUGGINGFACE_HUB_CACHE"),
        (os.path.join(os.getenv("HF_HOME"), "hub") if os.getenv("HF_HOME") else None),
        os.path.expanduser("~/.cache/huggingface/hub"),
        "/scratch/rr81/ma5430/.cache/huggingface/hub",
    ]:
        if p and p not in cache_candidates:
            cache_candidates.append(p)

    for cache_root in cache_candidates:
        repo_cache = os.path.join(cache_root, "datasets--pickapic-anonymous--pickapic_v1")
        if not os.path.isdir(repo_cache):
            continue

        # Resolve the snapshot directory: prefer the revision pinned by refs/main.
        snapshot_dir = None
        ref_main = os.path.join(repo_cache, "refs", "main")
        if os.path.isfile(ref_main):
            # Context manager fixes the previously-leaked file handle.
            with open(ref_main, "r", encoding="utf-8") as ref_file:
                revision = ref_file.read().strip()
            candidate = os.path.join(repo_cache, "snapshots", revision)
            if os.path.isdir(candidate):
                snapshot_dir = candidate

        # Fall back to the lexicographically latest snapshot directory.
        if snapshot_dir is None:
            snapshots = sorted(glob(os.path.join(repo_cache, "snapshots", "*")))
            if snapshots:
                snapshot_dir = snapshots[-1]

        if snapshot_dir is None:
            continue

        data_dir = os.path.join(snapshot_dir, "data")
        if not os.path.isdir(data_dir):
            continue

        # Prefer the requested split; fall back to test splits if absent.
        selected_split = split
        parquet_files = sorted(glob(os.path.join(data_dir, f"{selected_split}-*.parquet")))
        if not parquet_files:
            for alt_split in ("test_unique", "test"):
                alt_files = sorted(glob(os.path.join(data_dir, f"{alt_split}-*.parquet")))
                if alt_files:
                    selected_split = alt_split
                    parquet_files = alt_files
                    print(f"Offline cache missing split '{split}', falling back to '{selected_split}'.")
                    break

        if not parquet_files:
            continue

        print(
            f"Loading cached Pick-a-Pic split '{selected_split}' from {len(parquet_files)} parquet shards\n"
            f"cache={repo_cache}"
        )
        ds = load_dataset("parquet", data_files=parquet_files, split="train")
        prompts = ds["caption"]
        if max_samples:
            prompts = prompts[:max_samples]
        return list(prompts)

    raise RuntimeError(
        "Could not load pickapic prompts in offline mode. "
        "Set --hf_cache_dir to a cache that contains datasets--pickapic-anonymous--pickapic_v1."
    )
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def resolve_scorer_device(requested_device, generation_device, min_free_gb_for_gpu=14.0):
    """Decide where the metric scorers run so they cannot destabilize generation.

    Args:
        requested_device: "cpu", "cuda", or an auto mode (anything else).
        generation_device: Device the diffusion pipeline is loaded on.
        min_free_gb_for_gpu: VRAM headroom required to co-locate scorers on GPU.

    Returns:
        Device string: ``generation_device`` or ``"cpu"``.
    """
    if requested_device == "cpu":
        return "cpu"

    # No CUDA at all, or generation itself is not on a CUDA device -> CPU scorers.
    cuda_generation = torch.cuda.is_available() and str(generation_device).startswith("cuda")
    if not cuda_generation:
        return "cpu"

    if requested_device == "cuda":
        # Explicit opt-in: put scorers on the generation GPU unconditionally.
        return generation_device

    # Auto mode: keep scorers on GPU only if enough VRAM headroom remains
    # after the generation models have been loaded.
    try:
        free_bytes, total_bytes = torch.cuda.mem_get_info(torch.device(generation_device))
        gib = 1024 ** 3
        free_gb = free_bytes / gib
        total_gb = total_bytes / gib
        print(f"GPU memory before scorer load: {free_gb:.2f} GB free / {total_gb:.2f} GB total")
        if free_gb >= min_free_gb_for_gpu:
            return generation_device
        print(
            f"⚠ Low free VRAM ({free_gb:.2f} GB). Running scorers on CPU to keep diffusion stable. "
            f"Use --scorer_device cuda to force GPU scorers."
        )
        return "cpu"
    except Exception as e:
        print(f"Warning: could not inspect CUDA free memory ({e}). Falling back to CPU scorers.")
        return "cpu"
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def configure_cudnn_safely(device):
    """Disable cuDNN when the current GPU or runtime cannot initialize it safely."""
    # Nothing to do on CPU-only runs or non-CUDA devices.
    on_cuda = torch.cuda.is_available() and str(device).startswith("cuda")
    if not on_cuda:
        return

    try:
        major, minor = torch.cuda.get_device_capability(torch.device(device))
        if (major, minor) < (7, 5):
            # Older architectures are known to trip cuDNN init in this setup.
            print(
                f"⚠ Detected compute capability sm_{major}{minor} (< 75). "
                "Disabling cuDNN to prevent runtime initialization failures."
            )
            torch.backends.cudnn.enabled = False
            return

        # Probe cuDNN once at startup so any init failure is handled here,
        # not in the middle of a generation step.
        _ = torch.backends.cudnn.version()
    except Exception as e:
        print(f"⚠ cuDNN init probe failed ({e}). Disabling cuDNN for this run.")
        torch.backends.cudnn.enabled = False
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def resolve_generation_dtype(requested_dtype, device):
    """Pick a safe generation dtype for the current device.

    Args:
        requested_dtype: "fp32", "fp16", "bf16", or "auto" (case-insensitive).
        device: Target device string; half precision is only used on CUDA.

    Returns:
        A ``torch.dtype`` honoring the request where the hardware allows it.
    """
    req = str(requested_dtype).strip().lower()

    if req == "fp32":
        return torch.float32

    if not str(device).startswith("cuda"):
        # Half precision is only trusted on CUDA in this pipeline.
        if req in {"fp16", "bf16", "auto"}:
            print("⚠ Non-CUDA device detected. Falling back to fp32.")
        return torch.float32

    if req == "fp16":
        return torch.float16

    bf16_ok = torch.cuda.is_bf16_supported()
    if req == "bf16":
        if bf16_ok:
            return torch.bfloat16
        print("⚠ bf16 requested but not supported on this GPU. Falling back to fp16.")
        return torch.float16

    # Auto: prefer bf16 on supported GPUs to avoid fp16 underflow in tiny gradients.
    return torch.bfloat16 if bf16_ok else torch.float16
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def dtype_to_name(dtype: torch.dtype) -> str:
    """Map a torch dtype to the short name used in CLI flags and logs."""
    short_names = {torch.float16: "fp16", torch.bfloat16: "bf16"}
    # Anything that is not an explicit half type is reported as fp32.
    return short_names.get(dtype, "fp32")
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def seed_everything(seed: int):
    """Lock down every RNG this script relies on for exact reproducibility."""
    # Python stdlib and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch CPU RNG and, when CUDA is present, every GPU device.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # multi-GPU safety

    # Deterministic cuDNN kernels are crucial for consistent gradients.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # If run-to-run variance persists, torch.use_deterministic_algorithms(True)
    # is the heavier hammer — left off because it slows generation slightly.
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class MLP(nn.Module):
    """Regression head mapping a 768-d CLIP embedding to one aesthetic score."""

    def __init__(self):
        super().__init__()
        # Funnel: 768 -> 1024 -> 128 -> 64 -> 16 -> 1, with dropout between
        # the wide layers. Built in the same order as the trained checkpoint
        # so state_dict keys (layers.0 ... layers.7) match exactly.
        stages = []
        for d_in, d_out, p_drop in ((768, 1024, 0.2), (1024, 128, 0.2), (128, 64, 0.1)):
            stages.append(nn.Linear(d_in, d_out))
            stages.append(nn.Dropout(p_drop))
        stages.append(nn.Linear(64, 16))
        stages.append(nn.Linear(16, 1))
        self.layers = nn.Sequential(*stages)

    @torch.no_grad()
    def forward(self, embed):
        # Scoring is inference-only; gradients are never needed here.
        return self.layers(embed)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class AestheticScorer(torch.nn.Module):
    """Aesthetic scorer: CLIP image embeddings scored by a linear-MSE MLP head.

    Loads a CLIP model/processor from ``clip_name_or_path`` and an MLP head
    whose weights are read from ``aesthetic_path`` when that file exists
    (otherwise the head stays randomly initialized and a warning is printed).
    """
    def __init__(self, dtype, device, clip_name_or_path="openai/clip-vit-large-patch14",
                 aesthetic_path="./sac+logos+ava1-l14-linearMSE.pth"):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(clip_name_or_path)
        self.processor = CLIPProcessor.from_pretrained(clip_name_or_path)
        self.mlp = MLP()

        # Load aesthetic weights
        if os.path.exists(aesthetic_path):
            # NOTE(review): torch.load without weights_only=True executes
            # arbitrary pickles — acceptable only for a trusted local checkpoint.
            state_dict = torch.load(aesthetic_path, map_location='cpu')
            self.mlp.load_state_dict(state_dict)
        else:
            print(f"Warning: Aesthetic weights not found at {aesthetic_path}")

        # dtype used to cast processor outputs in __call__ below.
        self.dtype = dtype
        self.to(device)
        self.eval()

    @torch.no_grad()
    def __call__(self, images):
        """Return one aesthetic score per input image (1-D tensor).

        NOTE(review): overriding ``__call__`` directly bypasses nn.Module's
        hook machinery; presumably intentional for a pure inference scorer.
        """
        device = next(self.parameters()).device
        inputs = self.processor(images=images, return_tensors="pt")
        # Cast processor outputs to the scorer dtype, then move to its device.
        inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
        embed = self.clip.get_image_features(**inputs)
        # normalize embedding to unit length before the MLP head
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
        # MLP outputs shape (B, 1); squeeze to (B,).
        return self.mlp(embed).squeeze(1)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class TeeLogger:
    """File-like object that mirrors every write to the console and a log file."""

    def __init__(self, log_file):
        # Remember the real stdout so console output keeps flowing.
        self.terminal = sys.stdout
        # Explicit UTF-8 so the log contents do not depend on the platform
        # locale (the previous code used the locale-default encoding).
        self.log = open(log_file, 'w', encoding='utf-8')

    def write(self, message):
        """Write to both sinks; the file is flushed immediately so a crash
        cannot lose buffered log lines."""
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()

    def flush(self):
        """Flush both streams (required for file-like compatibility)."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; the console stream is left open."""
        self.log.close()
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def setup_logging(output_dir):
    """Create ``output_dir``, open ``log.log`` inside it, and tee stdout into it.

    Returns:
        Tuple ``(tee, log_file)``: the TeeLogger now installed as sys.stdout
        and the path of the log file it writes to.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    log_file = out / "log.log"

    tee = TeeLogger(log_file)
    # From this point on, every print() also lands in log.log.
    sys.stdout = tee

    return tee, log_file
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def load_validation_data(data_dir, max_samples=None, dataset_type="coco", hf_cache_dir=None, offline=False):
    """Load validation prompts and, for COCO, the matching real-image paths.

    Args:
        data_dir: Path to the data directory (used by the COCO branch).
        max_samples: Optional cap on the number of samples returned.
        dataset_type: "coco" or "pickapic".
        hf_cache_dir: HF cache root forwarded to the Pick-a-Pic loader.
        offline: Skip online dataset access for Pick-a-Pic.

    Returns:
        Tuple ``(prompts, image_paths)``; ``image_paths`` is None for Pick-a-Pic
        since that dataset supplies no reference images.

    Raises:
        FileNotFoundError: COCO annotation file or image directory is missing.
        ValueError: ``dataset_type`` is not recognized.
    """
    if dataset_type == "pickapic":
        print("Loading Pick-a-Pic validation prompts...")
        prompts = load_pickapic_prompts(max_samples=max_samples, cache_dir=hf_cache_dir, offline=offline)
        print(f"Loaded {len(prompts)} Pick-a-Pic validation samples")
        return prompts, None  # No reference images for Pick-a-Pic

    if dataset_type != "coco":
        raise ValueError(f"Unknown dataset type: {dataset_type}. Choose 'coco' or 'pickapic'.")

    root = Path(data_dir)
    val_json = root / "coco" / "caption_val.json"
    if not val_json.exists():
        raise FileNotFoundError(f"Validation JSON not found: {val_json}")

    with open(val_json, 'r') as f:
        annotations = json.load(f)

    # Fail fast if the image folder itself is absent.
    val_img_dir = root / "coco" / "images" / "val"
    if not val_img_dir.exists():
        raise FileNotFoundError(f"Validation image directory not found: {val_img_dir}")

    # Keep only caption/image pairs whose image actually exists on disk.
    prompts, image_paths = [], []
    for img_path, caption in annotations.items():
        full_path = root / "coco" / img_path
        if full_path.exists():
            prompts.append(caption)
            image_paths.append(str(full_path))
        else:
            print(f"Warning: Image not found: {full_path}")

    if max_samples:
        prompts = prompts[:max_samples]
        image_paths = image_paths[:max_samples]

    print(f"Loaded {len(prompts)} COCO validation samples")
    return prompts, image_paths
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def generate_and_evaluate(
    pipeline,
    prompts,
    image_paths,
    device,
    dtype,
    num_inference_steps=20,
    guidance_scale=7.5,
    seed=42,
    batch_size=1,
    apply_gradient_ascent=False,
    mode_name="baseline",
    log_interval=10,
    output_dir=None,
    save_images=False,
    clip_scorer=None,
    aesthetic_scorer=None,
    pick_scorer=None,
    hpsv2_scorer=None,
    hpsv21_scorer=None,
    imagereward_scorer=None,
    compute_fid=True,
    capture_trajectory=False
):
    """Generate images for all prompts and accumulate evaluation metrics.

    Runs the pipeline batch by batch, optionally applying gradient-ascent
    guidance, and feeds the results to whichever scorers are provided
    (each ``*_scorer=None`` disables that metric).

    Args:
        pipeline: Diffusion pipeline; expected to accept track_rewards /
            apply_gradient_ascent kwargs and expose reward_history and
            grad_guidance attributes (project-specific extension).
        prompts: Text prompts to generate from.
        image_paths: Real-image paths aligned with prompts, or None
            (e.g. Pick-a-Pic). FID requires real images, so it is skipped
            when this is None.
        device, dtype: Generation device and dtype.
        seed: Base seed; each batch uses seed + batch start index so batches
            are reproducible independently.
        capture_trajectory: Store the denoising latents of the FIRST batch
            only (moved to CPU to avoid VRAM growth).

    Returns:
        Tuple of (avg_reward, fid_metric, avg_clip_score, avg_aesthetic_score,
        avg_pick_score, avg_hpsv2_score, avg_hpsv21_score,
        avg_imagereward_score, lr_history_first_image, trajectory_first_image).
        NOTE(review): fid_metric is whatever the LAST loop iteration produced —
        FID is re-created per batch, so the returned value reflects only the
        final batch; confirm this is intended.
    """
    pipeline.to(device)

    print(f"\nGenerating images with {mode_name} mode...")

    # Per-metric accumulators, one entry per image (or per batch for rewards).
    all_rewards = []
    all_clip_scores = []
    all_aesthetic_scores = []
    all_pick_scores = []
    all_hpsv2_scores = []
    all_hpsv21_scores = []
    all_imagereward_scores = []
    lr_history_first_image = None  # Store LR history for first image
    trajectory_first_image = []
    num_batches = (len(prompts) + batch_size - 1) // batch_size

    # Create output directory if saving images
    if save_images and output_dir:
        mode_output_dir = Path(output_dir) / mode_name
        mode_output_dir.mkdir(parents=True, exist_ok=True)

    # Disable internal progress bars
    pipeline.set_progress_bar_config(disable=True)

    for idx, i in enumerate(tqdm(range(0, len(prompts), batch_size), desc=f"Generating {mode_name}")):
        batch_prompts = prompts[i:i+batch_size]
        batch_real_paths = image_paths[i:i+batch_size] if image_paths is not None else None
        batch_num = idx + 1

        # Initialize FID metric if needed
        # NOTE(review): a fresh FrechetInceptionDistance is built every batch,
        # so statistics never accumulate across batches — verify intent.
        fid_metric = None
        real_images_tensor = None

        if compute_fid and batch_real_paths is not None:
            fid_metric = FrechetInceptionDistance().to(device)

            # Load and update FID with real images for this batch
            real_images = []
            for path in batch_real_paths:
                img = Image.open(path).convert("RGB")
                img = img.resize((512, 512))  # Inception v3 input size
                img_array = np.array(img)
                real_images.append(img_array)

            # Convert to tensor [B, H, W, C] -> [B, C, H, W]
            real_images_tensor = torch.from_numpy(np.stack(real_images)).permute(0, 3, 1, 2).float()
            real_images_tensor = real_images_tensor.to(device)

        # Generate images; per-batch seed keeps each batch reproducible.
        generator = torch.Generator(device=device).manual_seed(seed + i)

        # Only capture trajectory for the very first batch to save RAM
        def trajectory_callback(step, timestep, latents):
            if idx == 0 and capture_trajectory:
                # Detach and move to CPU immediately to prevent VRAM OOM
                trajectory_first_image.append(latents.detach().cpu().clone())

        with torch.no_grad():
            result = pipeline(
                prompt=batch_prompts,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
                track_rewards=True,
                print_rewards=False,
                apply_gradient_ascent=apply_gradient_ascent,
                verbose_grad=False,
                callback=trajectory_callback if capture_trajectory else None,
                callback_steps=1
            )

        # Process generated images
        images = result.images

        # Update FID metric if computing it
        if compute_fid and fid_metric is not None:
            image_tensors = []

            for img in images:
                img_resized = img.resize((512, 512))  # Inception v3 input size
                img_array = np.array(img_resized)
                image_tensors.append(img_array)

            # Convert to tensor and update FID
            images_tensor = torch.from_numpy(np.stack(image_tensors)).permute(0, 3, 1, 2).float()
            images_tensor = images_tensor.to(device)

            # FID needs >= 2 samples per distribution; duplicate when batch is 1.
            if batch_size == 1:
                real_images_tensor = torch.cat([real_images_tensor, real_images_tensor], dim=0).to(dtype=torch.uint8)
                images_tensor = torch.cat([images_tensor, images_tensor], dim=0).to(dtype=torch.uint8)
            fid_metric.update(real_images_tensor, real=True)
            fid_metric.update(images_tensor, real=False)

        # Track rewards - get the final timestep reward (t=0)
        current_batch_final_reward = None
        current_batch_final_timestep = None
        if hasattr(pipeline, 'reward_history') and pipeline.reward_history:
            # For each image, get the reward from the last denoising step (t=0 or closest to 0)
            num_steps_per_image = num_inference_steps

            # Get the last entry which corresponds to the final timestep of the last image in batch
            final_entry = pipeline.reward_history[-1]
            current_batch_final_reward = final_entry['reward_score']
            current_batch_final_timestep = final_entry['timestep']
            all_rewards.append(current_batch_final_reward)

        # Capture LR history from first image if gradient ascent is enabled
        if apply_gradient_ascent and idx == 0 and lr_history_first_image is None:
            if hasattr(pipeline, 'grad_guidance') and pipeline.grad_guidance:
                grad_stats = pipeline.grad_guidance.get_statistics()
                if grad_stats and 'detailed_stats' in grad_stats:
                    # Extract LR history from the gradient ascent statistics
                    lr_history_first_image = {
                        'prompt': batch_prompts[0],
                        'timesteps': [],
                        'learning_rates': [],  # All LR values from all gradient steps
                        'rewards': []
                    }
                    for stat in grad_stats['detailed_stats']:
                        lr_history_first_image['timesteps'].append(stat['timestep'])
                        if 'lr_history' in stat:
                            # Extend with all LR values from this timestep's gradient steps
                            lr_history_first_image['learning_rates'].extend(stat['lr_history'])
                        # Collect all rewards from reward_history for each gradient step
                        if 'reward_history' in stat:
                            lr_history_first_image['rewards'].extend(stat['reward_history'])

        # Compute CLIP score
        if clip_scorer is not None:
            clip_device = next(clip_scorer.parameters()).device
            # Convert PIL images to tensor format for CLIP score [C, H, W] in range [0, 1]
            for img, prompt in zip(images, batch_prompts):
                img_array = np.array(img).astype(np.float32)
                img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).to(clip_device)
                clip_score = clip_scorer(img_tensor, [prompt]).item()
                all_clip_scores.append(clip_score)

        # Compute aesthetic score
        if aesthetic_scorer is not None:
            aesthetic_scores = aesthetic_scorer(images)
            if isinstance(aesthetic_scores, torch.Tensor):
                aesthetic_scores = aesthetic_scores.cpu().numpy()
            # Scalar result (single image) is wrapped so extend() works below.
            if aesthetic_scores.ndim == 0:
                aesthetic_scores = [aesthetic_scores.item()]
            all_aesthetic_scores.extend(aesthetic_scores.tolist() if hasattr(aesthetic_scores, 'tolist') else [aesthetic_scores])

        # Compute PickScore
        if pick_scorer is not None:
            for img, prompt in zip(images, batch_prompts):
                pick_score = pick_scorer(prompt, [img])[0]
                all_pick_scores.append(pick_score)

        # Compute HPSv2 score
        if hpsv2_scorer is not None:
            for img, prompt in zip(images, batch_prompts):
                hpsv2_score = hpsv2_scorer.score(img, prompt)[0]
                all_hpsv2_scores.append(hpsv2_score)

        # Compute HPSv2.1 score
        if hpsv21_scorer is not None:
            for img, prompt in zip(images, batch_prompts):
                hpsv21_score = hpsv21_scorer.score(img, prompt)[0]
                all_hpsv21_scores.append(hpsv21_score)

        # Compute ImageReward score
        if imagereward_scorer is not None:
            for img, prompt in zip(images, batch_prompts):
                imagereward_score = imagereward_scorer.score(prompt, img)
                all_imagereward_scores.append(imagereward_score)

        # Save generated images if requested
        if save_images and output_dir:
            for img_idx, img in enumerate(images):
                global_idx = i + img_idx
                img_path = mode_output_dir / f"sample_{global_idx:05d}.png"
                img.save(img_path)

        # Log intermediate FID and metrics every log_interval batches
        if batch_num % log_interval == 0 or batch_num == num_batches:
            num_samples_processed = min(i + batch_size, len(prompts))
            log_msg = f"\n[{mode_name}] Batch {batch_num}/{num_batches} | Samples: {num_samples_processed}/{len(prompts)}"

            # Add FID if computing
            if compute_fid and fid_metric is not None:
                try:
                    current_fid = fid_metric.compute().item()
                    log_msg += f" | FID: {current_fid:.4f}"
                except Exception as e:
                    # compute() raises until both distributions have samples.
                    log_msg += f" | FID: Computing..."

            # Add reward - show both final timestep reward and average
            if all_rewards:
                avg_reward = np.mean(all_rewards)
                if current_batch_final_reward is not None:
                    log_msg += f" | Reward (t={current_batch_final_timestep}): {current_batch_final_reward:.4f}"
                    log_msg += f" | Reward (Avg): {avg_reward:.4f}"
                else:
                    log_msg += f" | Reward (Avg): {avg_reward:.4f}"

            # Add CLIP if computing
            if clip_scorer is not None and all_clip_scores:
                log_msg += f" | CLIP: {np.mean(all_clip_scores):.4f}"

            # Add aesthetic if computing
            if aesthetic_scorer is not None and all_aesthetic_scores:
                log_msg += f" | Aesthetic: {np.mean(all_aesthetic_scores):.4f}"

            # Add PickScore
            if pick_scorer is not None and all_pick_scores:
                log_msg += f" | PickScore: {np.mean(all_pick_scores):.4f}"

            # Add HPSv2
            if hpsv2_scorer is not None and all_hpsv2_scores:
                log_msg += f" | HPSv2: {np.mean(all_hpsv2_scores):.4f}"

            # Add HPSv2.1
            if hpsv21_scorer is not None and all_hpsv21_scores:
                log_msg += f" | HPSv2.1: {np.mean(all_hpsv21_scores):.4f}"

            # Add ImageReward
            if imagereward_scorer is not None and all_imagereward_scores:
                log_msg += f" | ImageReward: {np.mean(all_imagereward_scores):.4f}"

            print(log_msg)

    # Re-enable progress bars
    pipeline.set_progress_bar_config(disable=False)

    # Averages default to 0.0 when a metric was disabled or produced nothing.
    avg_reward = np.mean(all_rewards) if all_rewards else 0.0
    avg_clip_score = np.mean(all_clip_scores) if all_clip_scores else 0.0
    avg_aesthetic_score = np.mean(all_aesthetic_scores) if all_aesthetic_scores else 0.0
    avg_pick_score = np.mean(all_pick_scores) if all_pick_scores else 0.0
    avg_hpsv2_score = np.mean(all_hpsv2_scores) if all_hpsv2_scores else 0.0
    avg_hpsv21_score = np.mean(all_hpsv21_scores) if all_hpsv21_scores else 0.0
    avg_imagereward_score = np.mean(all_imagereward_scores) if all_imagereward_scores else 0.0

    return avg_reward, fid_metric, avg_clip_score, avg_aesthetic_score, avg_pick_score, avg_hpsv2_score, avg_hpsv21_score, avg_imagereward_score, lr_history_first_image, trajectory_first_image
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
def auto_increment_path(base_path):
    """Return the first non-existing ``base_path/run_<i>`` path (i = 1, 2, ...).

    Only ``base_path`` itself is created; the returned run directory is left
    for the caller to create.
    """
    base = Path(base_path)
    base.mkdir(parents=True, exist_ok=True)

    run_idx = 1
    # Walk forward until we find an unused run number.
    while (base / f"run_{run_idx}").exists():
        run_idx += 1
    return base / f"run_{run_idx}"
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
def main():
|
| 695 |
+
parser = argparse.ArgumentParser(description="Evaluate baseline and gradient ascent pipelines")
|
| 696 |
+
parser.add_argument("--data_dir", type=str, default="./data", help="Path to data directory")
|
| 697 |
+
parser.add_argument("--dataset_type", type=str, default="coco", choices=["coco", "pickapic"],
|
| 698 |
+
help="Dataset to use for evaluation: coco or pickapic (default: coco)")
|
| 699 |
+
parser.add_argument("--base_model", type=str, default=None, help="Override SANA base model repo id")
|
| 700 |
+
parser.add_argument(
|
| 701 |
+
"--model_variant",
|
| 702 |
+
type=str,
|
| 703 |
+
default="sana_600m_512",
|
| 704 |
+
choices=list(SANA_PROFILE_TO_MODEL_ID.keys()),
|
| 705 |
+
help="SANA model profile to use",
|
| 706 |
+
)
|
| 707 |
+
parser.add_argument("--lrm_model", type=str, default=None, help="SANA reward checkpoint path (directory or model.safetensors).")
|
| 708 |
+
parser.add_argument("--hf_cache_dir", type=str, default="/scratch/rr81/ma5430/.cache/huggingface/hub", help="Shared HF cache directory")
|
| 709 |
+
parser.add_argument("--offline", action="store_true", help="Force fully offline mode (recommended on GPU nodes)")
|
| 710 |
+
parser.add_argument("--num_steps", type=int, default=50, help="Number of inference steps")
|
| 711 |
+
parser.add_argument("--cfg_scale", type=float, default=4.5, help="Classifier-free guidance scale")
|
| 712 |
+
parser.add_argument("--dtype", type=str, default="bf16", choices=["auto", "bf16", "fp16", "fp32"],
|
| 713 |
+
help="Generation/reward dtype. bf16 is recommended for tiny gradient stability.")
|
| 714 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
| 715 |
+
parser.add_argument("--max_samples", type=int, default=None, help="Max samples to evaluate (None for all)")
|
| 716 |
+
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generation (use 1 for reward model compatibility)")
|
| 717 |
+
parser.add_argument("--fid_batch_size", type=int, default=32, help="Batch size for FID computation")
|
| 718 |
+
parser.add_argument("--log_interval", type=int, default=10, help="Log FID and metrics every N batches")
|
| 719 |
+
parser.add_argument("--output_dir", type=str, default="eval_outputs", help="Directory to save generated images and results")
|
| 720 |
+
parser.add_argument("--save_images", action="store_true", help="Save all generated images to output directory")
|
| 721 |
+
parser.add_argument("--mode", type=str, default="both", choices=["baseline", "gradient_ascent", "both"],
|
| 722 |
+
help="Which evaluation to run: baseline, gradient_ascent, or both (default: both)")
|
| 723 |
+
|
| 724 |
+
# Metrics selection
|
| 725 |
+
parser.add_argument("--metrics", type=str, nargs="+", default=["clip", "aesthetic"],
|
| 726 |
+
choices=["fid", "clip", "aesthetic", "pickscore", "hpsv2", "hpsv21", "imagereward"],
|
| 727 |
+
help="Which metrics to evaluate (default: clip aesthetic)")
|
| 728 |
+
parser.add_argument("--scorer_device", type=str, default="auto", choices=["auto", "cpu", "cuda"],
|
| 729 |
+
help="Device for metric scorers. auto keeps scorers on GPU only when enough VRAM is free.")
|
| 730 |
+
|
| 731 |
+
# Gradient ascent config
|
| 732 |
+
parser.add_argument("--grad_config", type=str, default=None,
|
| 733 |
+
help=f"Gradient ascent config preset (available: {', '.join(list_configs())}). "
|
| 734 |
+
"If provided, overrides individual grad_* arguments.")
|
| 735 |
+
parser.add_argument("--grad_range_start", type=int, default=0, help="Gradient timestep range start")
|
| 736 |
+
parser.add_argument("--grad_range_end", type=int, default=700, help="Gradient timestep range end")
|
| 737 |
+
parser.add_argument("--grad_steps", type=int, default=5, help="Number of gradient steps per timestep (use 5 for better reward improvement)")
|
| 738 |
+
parser.add_argument("--grad_step_size", type=float, default=0.1, help="Gradient step size (initial LR)")
|
| 739 |
+
|
| 740 |
+
# Config overrides (these override values from grad_config if specified)
|
| 741 |
+
parser.add_argument("--override_momentum", type=float, default=None, help="Override momentum value from grad_config")
|
| 742 |
+
parser.add_argument("--override_num_grad_steps", type=int, default=None, help="Override num_grad_steps from grad_config")
|
| 743 |
+
parser.add_argument("--override_grad_step_size", type=float, default=None, help="Override grad_step_size from grad_config")
|
| 744 |
+
|
| 745 |
+
# Cuda
|
| 746 |
+
parser.add_argument("--cuda", type=int, default=0, help="Use CUDA device id")
|
| 747 |
+
|
| 748 |
+
args = parser.parse_args()
|
| 749 |
+
|
| 750 |
+
hf_cache_dir, offline_enabled = configure_hf_runtime(args.hf_cache_dir, force_offline=args.offline)
|
| 751 |
+
if args.base_model is None:
|
| 752 |
+
args.base_model = SANA_PROFILE_TO_MODEL_ID[args.model_variant]
|
| 753 |
+
if args.lrm_model is None:
|
| 754 |
+
args.lrm_model = resolve_default_lrm_model()
|
| 755 |
+
if not args.lrm_model:
|
| 756 |
+
raise ValueError(
|
| 757 |
+
"No SANA reward checkpoint found. Provide --lrm_model pointing to checkpoint dir or model.safetensors"
|
| 758 |
+
)
|
| 759 |
+
|
| 760 |
+
seed_everything(args.seed)
|
| 761 |
+
|
| 762 |
+
# Configuration
|
| 763 |
+
device = f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"
|
| 764 |
+
dtype = resolve_generation_dtype(args.dtype, device)
|
| 765 |
+
os.environ["ACCELERATE_MIXED_PRECISION"] = dtype_to_name(dtype)
|
| 766 |
+
configure_cudnn_safely(device)
|
| 767 |
+
|
| 768 |
+
# Create auto-incremented output directory
|
| 769 |
+
args.output_dir = auto_increment_path(args.output_dir)
|
| 770 |
+
|
| 771 |
+
# Setup logging to file
|
| 772 |
+
tee_logger, log_file = setup_logging(args.output_dir)
|
| 773 |
+
|
| 774 |
+
print("="*70)
|
| 775 |
+
print("FID EVALUATION: BASELINE vs GRADIENT ASCENT")
|
| 776 |
+
print("="*70)
|
| 777 |
+
print(f"\nLogging to: {log_file}")
|
| 778 |
+
print(f"\nDevice: {device}")
|
| 779 |
+
print(f"Dataset: {args.dataset_type.upper()}")
|
| 780 |
+
print(f"Data directory: {args.data_dir}")
|
| 781 |
+
print(f"Base model: {args.base_model}")
|
| 782 |
+
print(f"Model variant: {args.model_variant}")
|
| 783 |
+
print(f"LRM model: {args.lrm_model}")
|
| 784 |
+
print(f"HF cache dir: {hf_cache_dir or 'default'}")
|
| 785 |
+
print(f"HF offline mode: {offline_enabled}")
|
| 786 |
+
print(f"Inference steps: {args.num_steps}")
|
| 787 |
+
print(f"CFG scale: {args.cfg_scale}")
|
| 788 |
+
print(f"Batch size: {args.batch_size}")
|
| 789 |
+
print(f"Max samples: {args.max_samples or 'All'}")
|
| 790 |
+
print(f"Generation dtype: {dtype_to_name(dtype)}")
|
| 791 |
+
print(f"Output directory: {args.output_dir}")
|
| 792 |
+
print(f"Save images: {args.save_images}")
|
| 793 |
+
print(f"Evaluation mode: {args.mode}")
|
| 794 |
+
print(f"Metrics to evaluate: {', '.join(args.metrics).upper()}")
|
| 795 |
+
if args.grad_config:
|
| 796 |
+
print(f"Gradient ascent config: {args.grad_config}")
|
| 797 |
+
|
| 798 |
+
# Load validation data
|
| 799 |
+
print("\n" + "="*70)
|
| 800 |
+
print("1. LOADING VALIDATION DATA")
|
| 801 |
+
print("="*70)
|
| 802 |
+
prompts, image_paths = load_validation_data(
|
| 803 |
+
args.data_dir,
|
| 804 |
+
args.max_samples,
|
| 805 |
+
args.dataset_type,
|
| 806 |
+
hf_cache_dir=hf_cache_dir,
|
| 807 |
+
offline=offline_enabled,
|
| 808 |
+
)
|
| 809 |
+
|
| 810 |
+
# Automatically disable FID if no reference images available (e.g., Pick-a-Pic dataset)
|
| 811 |
+
can_compute_fid = image_paths is not None
|
| 812 |
+
if not can_compute_fid and "fid" in args.metrics:
|
| 813 |
+
print("\n⚠ Warning: FID metric requested but no reference images available. FID will be skipped.")
|
| 814 |
+
args.metrics = [m for m in args.metrics if m != "fid"]
|
| 815 |
+
|
| 816 |
+
# Load reward model
|
| 817 |
+
print("\n" + "="*70)
|
| 818 |
+
print("2. LOADING REWARD MODEL")
|
| 819 |
+
print("="*70)
|
| 820 |
+
reward_model = LRMRewardModel(
|
| 821 |
+
pretrained_model_name_or_path=args.base_model,
|
| 822 |
+
lrm_model_path=args.lrm_model,
|
| 823 |
+
model_profile=args.model_variant,
|
| 824 |
+
guidance_scale=args.cfg_scale,
|
| 825 |
+
device=device
|
| 826 |
+
)
|
| 827 |
+
if dtype == torch.float16:
|
| 828 |
+
reward_model = reward_model.half()
|
| 829 |
+
elif dtype == torch.bfloat16:
|
| 830 |
+
reward_model = reward_model.to(dtype=torch.bfloat16)
|
| 831 |
+
else:
|
| 832 |
+
reward_model = reward_model.to(dtype=torch.float32)
|
| 833 |
+
reward_model.eval()
|
| 834 |
+
print("✓ Reward model loaded")
|
| 835 |
+
|
| 836 |
+
# Load pipeline
|
| 837 |
+
print("\n" + "="*70)
|
| 838 |
+
print("3. LOADING PIPELINE")
|
| 839 |
+
print("="*70)
|
| 840 |
+
|
| 841 |
+
pretrained_kwargs = {"local_files_only": offline_enabled}
|
| 842 |
+
if hf_cache_dir:
|
| 843 |
+
pretrained_kwargs["cache_dir"] = hf_cache_dir
|
| 844 |
+
|
| 845 |
+
base_pipeline = SanaPipeline.from_pretrained(
|
| 846 |
+
args.base_model,
|
| 847 |
+
torch_dtype=dtype,
|
| 848 |
+
**pretrained_kwargs,
|
| 849 |
+
)
|
| 850 |
+
print(f"✓ Loaded SANA base model: {args.base_model}")
|
| 851 |
+
|
| 852 |
+
pipeline = SanaGradientAscentPipeline(**base_pipeline.components)
|
| 853 |
+
pipeline = pipeline.to(device)
|
| 854 |
+
pipeline.set_reward_model(reward_model)
|
| 855 |
+
print("✓ Pipeline loaded")
|
| 856 |
+
|
| 857 |
+
scorer_device = resolve_scorer_device(args.scorer_device, device)
|
| 858 |
+
scorer_dtype = dtype if str(scorer_device).startswith("cuda") else torch.float32
|
| 859 |
+
print(f"Scorer device: {scorer_device}")
|
| 860 |
+
|
| 861 |
+
if torch.cuda.is_available():
|
| 862 |
+
torch.cuda.empty_cache()
|
| 863 |
+
|
| 864 |
+
# Load CLIP scorer
|
| 865 |
+
print("\n" + "="*70)
|
| 866 |
+
print("3.5. LOADING CLIP AND AESTHETIC SCORERS")
|
| 867 |
+
print("="*70)
|
| 868 |
+
|
| 869 |
+
# Only load scorers for requested metrics
|
| 870 |
+
clip_scorer = None
|
| 871 |
+
aesthetic_scorer = None
|
| 872 |
+
pick_scorer = None
|
| 873 |
+
hpsv2_scorer = None
|
| 874 |
+
hpsv21_scorer = None
|
| 875 |
+
imagereward_scorer = None
|
| 876 |
+
|
| 877 |
+
if "clip" in args.metrics:
|
| 878 |
+
try:
|
| 879 |
+
clip_scorer = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to(scorer_device)
|
| 880 |
+
print("✓ CLIP scorer loaded")
|
| 881 |
+
except Exception as e:
|
| 882 |
+
print(f"Warning: Could not load CLIP scorer: {e}")
|
| 883 |
+
clip_scorer = None
|
| 884 |
+
else:
|
| 885 |
+
print("⊘ CLIP scorer skipped (not in selected metrics)")
|
| 886 |
+
|
| 887 |
+
if "aesthetic" in args.metrics:
|
| 888 |
+
try:
|
| 889 |
+
aesthetic_scorer = AestheticScorer(dtype=scorer_dtype, device=scorer_device)
|
| 890 |
+
print("✓ Aesthetic scorer loaded")
|
| 891 |
+
except Exception as e:
|
| 892 |
+
print(f"Warning: Could not load Aesthetic scorer: {e}")
|
| 893 |
+
aesthetic_scorer = None
|
| 894 |
+
else:
|
| 895 |
+
print("⊘ Aesthetic scorer skipped (not in selected metrics)")
|
| 896 |
+
|
| 897 |
+
if "pickscore" in args.metrics:
|
| 898 |
+
try:
|
| 899 |
+
from pick_score import PickScorer
|
| 900 |
+
pick_scorer = PickScorer(
|
| 901 |
+
processor_name_or_path="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
|
| 902 |
+
model_pretrained_name_or_path="yuvalkirstain/PickScore_v1",
|
| 903 |
+
device=scorer_device
|
| 904 |
+
)
|
| 905 |
+
print("✓ PickScore scorer loaded")
|
| 906 |
+
except Exception as e:
|
| 907 |
+
print(f"Warning: Could not load PickScore scorer: {e}")
|
| 908 |
+
pick_scorer = None
|
| 909 |
+
else:
|
| 910 |
+
print("⊘ PickScore scorer skipped (not in selected metrics)")
|
| 911 |
+
|
| 912 |
+
if "hpsv2" in args.metrics:
|
| 913 |
+
try:
|
| 914 |
+
from hpsv2_score import HPSv2Scorer
|
| 915 |
+
hf_dl_kwargs = {"local_files_only": offline_enabled}
|
| 916 |
+
if hf_cache_dir:
|
| 917 |
+
hf_dl_kwargs["cache_dir"] = hf_cache_dir
|
| 918 |
+
hpsv2_scorer = HPSv2Scorer(
|
| 919 |
+
clip_pretrained_name_or_path=hf_hub_download(
|
| 920 |
+
repo_id="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
|
| 921 |
+
filename="open_clip_pytorch_model.bin",
|
| 922 |
+
**hf_dl_kwargs,
|
| 923 |
+
),
|
| 924 |
+
model_pretrained_name_or_path=hf_hub_download(
|
| 925 |
+
repo_id="xswu/HPSv2",
|
| 926 |
+
filename="HPS_v2_compressed.pt",
|
| 927 |
+
**hf_dl_kwargs,
|
| 928 |
+
),
|
| 929 |
+
device=scorer_device
|
| 930 |
+
)
|
| 931 |
+
print("✓ HPSv2 scorer loaded")
|
| 932 |
+
except Exception as e:
|
| 933 |
+
print(f"Warning: Could not load HPSv2 scorer: {e}")
|
| 934 |
+
hpsv2_scorer = None
|
| 935 |
+
else:
|
| 936 |
+
print("⊘ HPSv2 scorer skipped (not in selected metrics)")
|
| 937 |
+
|
| 938 |
+
if "hpsv21" in args.metrics:
|
| 939 |
+
try:
|
| 940 |
+
from hpsv2_score import HPSv2Scorer
|
| 941 |
+
hf_dl_kwargs = {"local_files_only": offline_enabled}
|
| 942 |
+
if hf_cache_dir:
|
| 943 |
+
hf_dl_kwargs["cache_dir"] = hf_cache_dir
|
| 944 |
+
hpsv21_scorer = HPSv2Scorer(
|
| 945 |
+
clip_pretrained_name_or_path=hf_hub_download(
|
| 946 |
+
repo_id="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
|
| 947 |
+
filename="open_clip_pytorch_model.bin",
|
| 948 |
+
**hf_dl_kwargs,
|
| 949 |
+
),
|
| 950 |
+
model_pretrained_name_or_path=hf_hub_download(
|
| 951 |
+
repo_id="xswu/HPSv2",
|
| 952 |
+
filename="HPS_v2.1_compressed.pt",
|
| 953 |
+
**hf_dl_kwargs,
|
| 954 |
+
),
|
| 955 |
+
device=scorer_device
|
| 956 |
+
)
|
| 957 |
+
print("✓ HPSv2.1 scorer loaded")
|
| 958 |
+
except Exception as e:
|
| 959 |
+
print(f"Warning: Could not load HPSv2.1 scorer: {e}")
|
| 960 |
+
hpsv21_scorer = None
|
| 961 |
+
else:
|
| 962 |
+
print("⊘ HPSv2.1 scorer skipped (not in selected metrics)")
|
| 963 |
+
|
| 964 |
+
if "imagereward" in args.metrics:
|
| 965 |
+
try:
|
| 966 |
+
from imagereward_score import load_imagereward
|
| 967 |
+
hf_dl_kwargs = {"local_files_only": offline_enabled}
|
| 968 |
+
if hf_cache_dir:
|
| 969 |
+
hf_dl_kwargs["cache_dir"] = hf_cache_dir
|
| 970 |
+
imagereward_scorer = load_imagereward(
|
| 971 |
+
model_path=hf_hub_download(repo_id="THUDM/ImageReward", filename="ImageReward.pt", **hf_dl_kwargs),
|
| 972 |
+
med_config=hf_hub_download(repo_id="THUDM/ImageReward", filename="med_config.json", **hf_dl_kwargs),
|
| 973 |
+
device=scorer_device
|
| 974 |
+
)
|
| 975 |
+
print("✓ ImageReward scorer loaded")
|
| 976 |
+
except Exception as e:
|
| 977 |
+
print(f"Warning: Could not load ImageReward scorer: {e}")
|
| 978 |
+
imagereward_scorer = None
|
| 979 |
+
else:
|
| 980 |
+
print("⊘ ImageReward scorer skipped (not in selected metrics)")
|
| 981 |
+
|
| 982 |
+
# Configure gradient ascent
|
| 983 |
+
print("\n" + "="*70)
|
| 984 |
+
print("4. CONFIGURING GRADIENT ASCENT")
|
| 985 |
+
print("="*70)
|
| 986 |
+
|
| 987 |
+
# Use config preset if provided, otherwise use individual args
|
| 988 |
+
if args.grad_config:
|
| 989 |
+
print(f"Loading gradient ascent config: {args.grad_config}")
|
| 990 |
+
grad_config = get_config(args.grad_config)
|
| 991 |
+
print(f"Config loaded: {grad_config}")
|
| 992 |
+
|
| 993 |
+
# Apply overrides if specified
|
| 994 |
+
if args.override_momentum is not None:
|
| 995 |
+
grad_config['momentum'] = args.override_momentum
|
| 996 |
+
print(f" Overriding momentum: {args.override_momentum}")
|
| 997 |
+
if args.override_num_grad_steps is not None:
|
| 998 |
+
grad_config['num_grad_steps'] = args.override_num_grad_steps
|
| 999 |
+
print(f" Overriding num_grad_steps: {args.override_num_grad_steps}")
|
| 1000 |
+
if args.override_grad_step_size is not None:
|
| 1001 |
+
grad_config['grad_step_size'] = args.override_grad_step_size
|
| 1002 |
+
print(f" Overriding grad_step_size: {args.override_grad_step_size}")
|
| 1003 |
+
else:
|
| 1004 |
+
grad_config = {
|
| 1005 |
+
"grad_timestep_range": (args.grad_range_start, args.grad_range_end),
|
| 1006 |
+
"num_grad_steps": args.grad_steps,
|
| 1007 |
+
"grad_step_size": args.grad_step_size,
|
| 1008 |
+
}
|
| 1009 |
+
print(f"Using manual gradient ascent configuration")
|
| 1010 |
+
|
| 1011 |
+
print(f"Gradient timestep range: {grad_config.get('grad_timestep_range', (args.grad_range_start, args.grad_range_end))}")
|
| 1012 |
+
print(f"Gradient steps: {grad_config.get('num_grad_steps', args.grad_steps)}")
|
| 1013 |
+
print(f"Gradient step size (initial LR): {grad_config.get('grad_step_size', args.grad_step_size)}")
|
| 1014 |
+
if grad_config.get('lr_scheduler_type'):
|
| 1015 |
+
print(f"LR Scheduler: {grad_config['lr_scheduler_type']}")
|
| 1016 |
+
if grad_config.get('use_momentum'):
|
| 1017 |
+
print(f"Momentum: {grad_config.get('momentum', 0.9)} (Nesterov: {grad_config.get('use_nesterov', False)})")
|
| 1018 |
+
|
| 1019 |
+
pipeline.enable_gradient_ascent(**grad_config)
|
| 1020 |
+
|
| 1021 |
+
# Initialize result variables
|
| 1022 |
+
fid_score_baseline = None
|
| 1023 |
+
avg_reward_baseline = None
|
| 1024 |
+
clip_score_baseline = None
|
| 1025 |
+
aesthetic_score_baseline = None
|
| 1026 |
+
pick_score_baseline = None
|
| 1027 |
+
hpsv2_score_baseline = None
|
| 1028 |
+
hpsv21_score_baseline = None
|
| 1029 |
+
imagereward_score_baseline = None
|
| 1030 |
+
fid_score_grad = None
|
| 1031 |
+
avg_reward_grad = None
|
| 1032 |
+
clip_score_grad = None
|
| 1033 |
+
aesthetic_score_grad = None
|
| 1034 |
+
pick_score_grad = None
|
| 1035 |
+
hpsv2_score_grad = None
|
| 1036 |
+
hpsv21_score_grad = None
|
| 1037 |
+
imagereward_score_grad = None
|
| 1038 |
+
grad_stats = None
|
| 1039 |
+
|
| 1040 |
+
# ========== BASELINE EVALUATION ==========
|
| 1041 |
+
if args.mode in ["baseline", "both"]:
|
| 1042 |
+
print("\n" + "="*70)
|
| 1043 |
+
print("5. EVALUATING BASELINE")
|
| 1044 |
+
print("="*70)
|
| 1045 |
+
|
| 1046 |
+
# Generate and evaluate baseline
|
| 1047 |
+
avg_reward_baseline, fid_baseline, clip_score_baseline, aesthetic_score_baseline, pick_score_baseline, hpsv2_score_baseline, hpsv21_score_baseline, imagereward_score_baseline, _, baseline_trajectory = generate_and_evaluate(
|
| 1048 |
+
pipeline=pipeline,
|
| 1049 |
+
prompts=prompts,
|
| 1050 |
+
image_paths=image_paths,
|
| 1051 |
+
device=device,
|
| 1052 |
+
dtype=dtype,
|
| 1053 |
+
num_inference_steps=args.num_steps,
|
| 1054 |
+
guidance_scale=args.cfg_scale,
|
| 1055 |
+
seed=args.seed,
|
| 1056 |
+
batch_size=args.batch_size,
|
| 1057 |
+
apply_gradient_ascent=False,
|
| 1058 |
+
mode_name="baseline",
|
| 1059 |
+
log_interval=args.log_interval,
|
| 1060 |
+
output_dir=args.output_dir,
|
| 1061 |
+
save_images=args.save_images,
|
| 1062 |
+
clip_scorer=clip_scorer,
|
| 1063 |
+
aesthetic_scorer=aesthetic_scorer,
|
| 1064 |
+
pick_scorer=pick_scorer,
|
| 1065 |
+
hpsv2_scorer=hpsv2_scorer,
|
| 1066 |
+
hpsv21_scorer=hpsv21_scorer,
|
| 1067 |
+
imagereward_scorer=imagereward_scorer,
|
| 1068 |
+
compute_fid=("fid" in args.metrics and can_compute_fid),
|
| 1069 |
+
capture_trajectory=True
|
| 1070 |
+
)
|
| 1071 |
+
|
| 1072 |
+
# Compute FID for baseline if requested
|
| 1073 |
+
if "fid" in args.metrics and fid_baseline is not None:
|
| 1074 |
+
fid_score_baseline = fid_baseline.compute().item()
|
| 1075 |
+
print(f"\n✓ Baseline FID: {fid_score_baseline:.4f}")
|
| 1076 |
+
print(f"✓ Baseline Avg Reward: {avg_reward_baseline:.4f}")
|
| 1077 |
+
if "clip" in args.metrics:
|
| 1078 |
+
print(f"✓ Baseline Avg CLIP Score: {clip_score_baseline:.4f}")
|
| 1079 |
+
if "aesthetic" in args.metrics:
|
| 1080 |
+
print(f"✓ Baseline Avg Aesthetic Score: {aesthetic_score_baseline:.4f}")
|
| 1081 |
+
if "pickscore" in args.metrics and pick_score_baseline is not None:
|
| 1082 |
+
print(f"✓ Baseline Avg PickScore: {pick_score_baseline:.4f}")
|
| 1083 |
+
if "hpsv2" in args.metrics and hpsv2_score_baseline is not None:
|
| 1084 |
+
print(f"✓ Baseline Avg HPSv2 Score: {hpsv2_score_baseline:.4f}")
|
| 1085 |
+
if "hpsv21" in args.metrics and hpsv21_score_baseline is not None:
|
| 1086 |
+
print(f"✓ Baseline Avg HPSv2.1 Score: {hpsv21_score_baseline:.4f}")
|
| 1087 |
+
if "imagereward" in args.metrics and imagereward_score_baseline is not None:
|
| 1088 |
+
print(f"✓ Baseline Avg ImageReward: {imagereward_score_baseline:.4f}")
|
| 1089 |
+
|
| 1090 |
+
# ========== GRADIENT ASCENT EVALUATION ==========
|
| 1091 |
+
if args.mode in ["gradient_ascent", "both"]:
|
| 1092 |
+
print("\n" + "="*70)
|
| 1093 |
+
print("6. EVALUATING GRADIENT ASCENT")
|
| 1094 |
+
print("="*70)
|
| 1095 |
+
|
| 1096 |
+
# Generate and evaluate with gradient ascent
|
| 1097 |
+
avg_reward_grad, fid_grad, clip_score_grad, aesthetic_score_grad, pick_score_grad, hpsv2_score_grad, hpsv21_score_grad, imagereward_score_grad, lr_history, guided_trajectory = generate_and_evaluate(
|
| 1098 |
+
pipeline=pipeline,
|
| 1099 |
+
prompts=prompts,
|
| 1100 |
+
image_paths=image_paths,
|
| 1101 |
+
device=device,
|
| 1102 |
+
dtype=dtype,
|
| 1103 |
+
num_inference_steps=args.num_steps,
|
| 1104 |
+
guidance_scale=args.cfg_scale,
|
| 1105 |
+
seed=args.seed,
|
| 1106 |
+
batch_size=args.batch_size,
|
| 1107 |
+
apply_gradient_ascent=True,
|
| 1108 |
+
mode_name="gradient_ascent",
|
| 1109 |
+
log_interval=args.log_interval,
|
| 1110 |
+
output_dir=args.output_dir,
|
| 1111 |
+
save_images=args.save_images,
|
| 1112 |
+
clip_scorer=clip_scorer,
|
| 1113 |
+
aesthetic_scorer=aesthetic_scorer,
|
| 1114 |
+
pick_scorer=pick_scorer,
|
| 1115 |
+
hpsv2_scorer=hpsv2_scorer,
|
| 1116 |
+
hpsv21_scorer=hpsv21_scorer,
|
| 1117 |
+
imagereward_scorer=imagereward_scorer,
|
| 1118 |
+
compute_fid=("fid" in args.metrics and can_compute_fid),
|
| 1119 |
+
capture_trajectory=True
|
| 1120 |
+
)
|
| 1121 |
+
|
| 1122 |
+
# Compute FID for gradient ascent if requested
|
| 1123 |
+
if "fid" in args.metrics and fid_grad is not None:
|
| 1124 |
+
fid_score_grad = fid_grad.compute().item()
|
| 1125 |
+
print(f"\n✓ Gradient Ascent FID: {fid_score_grad:.4f}")
|
| 1126 |
+
print(f"✓ Gradient Ascent Avg Reward: {avg_reward_grad:.4f}")
|
| 1127 |
+
if "clip" in args.metrics:
|
| 1128 |
+
print(f"✓ Gradient Ascent Avg CLIP Score: {clip_score_grad:.4f}")
|
| 1129 |
+
if "aesthetic" in args.metrics:
|
| 1130 |
+
print(f"✓ Gradient Ascent Avg Aesthetic Score: {aesthetic_score_grad:.4f}")
|
| 1131 |
+
if "pickscore" in args.metrics and pick_score_grad is not None:
|
| 1132 |
+
print(f"✓ Gradient Ascent Avg PickScore: {pick_score_grad:.4f}")
|
| 1133 |
+
if "hpsv2" in args.metrics and hpsv2_score_grad is not None:
|
| 1134 |
+
print(f"✓ Gradient Ascent Avg HPSv2 Score: {hpsv2_score_grad:.4f}")
|
| 1135 |
+
if "hpsv21" in args.metrics and hpsv21_score_grad is not None:
|
| 1136 |
+
print(f"✓ Gradient Ascent Avg HPSv2.1 Score: {hpsv21_score_grad:.4f}")
|
| 1137 |
+
if "imagereward" in args.metrics and imagereward_score_grad is not None:
|
| 1138 |
+
print(f"✓ Gradient Ascent Avg ImageReward: {imagereward_score_grad:.4f}")
|
| 1139 |
+
|
| 1140 |
+
# Get gradient stats
|
| 1141 |
+
grad_stats = pipeline.grad_guidance.get_statistics()
|
| 1142 |
+
if grad_stats:
|
| 1143 |
+
print(f"\nGradient Ascent Statistics:")
|
| 1144 |
+
print(f" Applications: {grad_stats['num_applications']}")
|
| 1145 |
+
print(f" Total reward improvement: {grad_stats['total_reward_improvement']:+.4f}")
|
| 1146 |
+
print(f" Avg reward improvement: {grad_stats['avg_reward_improvement']:+.4f}")
|
| 1147 |
+
|
| 1148 |
+
# Plot LR curve if we captured it
|
| 1149 |
+
if lr_history is not None and lr_history['learning_rates']:
|
| 1150 |
+
plot_path = Path(args.output_dir) / "lr_curve.png"
|
| 1151 |
+
|
| 1152 |
+
# LR values are now continuous across all gradient steps
|
| 1153 |
+
lrs = lr_history['learning_rates']
|
| 1154 |
+
steps = list(range(len(lrs))) # Step indices (0 to total_steps-1)
|
| 1155 |
+
|
| 1156 |
+
plt.figure(figsize=(12, 6))
|
| 1157 |
+
plt.plot(steps, lrs, linewidth=2, color='blue', alpha=0.8)
|
| 1158 |
+
|
| 1159 |
+
# Mark the first step with a star
|
| 1160 |
+
plt.plot(steps[0], lrs[0], marker='*', markersize=20, color='gold',
|
| 1161 |
+
markeredgecolor='darkgoldenrod', markeredgewidth=2, zorder=5)
|
| 1162 |
+
|
| 1163 |
+
# Mark timestep boundaries
|
| 1164 |
+
num_timesteps = len(lr_history['timesteps'])
|
| 1165 |
+
num_grad_steps_per_timestep = len(lrs) // num_timesteps if num_timesteps > 0 else 0
|
| 1166 |
+
if num_grad_steps_per_timestep > 0:
|
| 1167 |
+
for i in range(num_timesteps + 1):
|
| 1168 |
+
step_idx = i * num_grad_steps_per_timestep
|
| 1169 |
+
if step_idx <= len(lrs):
|
| 1170 |
+
plt.axvline(x=step_idx, color='red', linestyle='--', alpha=0.3, linewidth=1)
|
| 1171 |
+
if i < num_timesteps:
|
| 1172 |
+
plt.text(step_idx, plt.ylim()[1] * 0.95, f't={lr_history["timesteps"][i]}',
|
| 1173 |
+
fontsize=8, color='red', alpha=0.7, ha='left')
|
| 1174 |
+
|
| 1175 |
+
plt.xlabel('Global Gradient Step', fontsize=12)
|
| 1176 |
+
plt.ylabel('Learning Rate', fontsize=12)
|
| 1177 |
+
plt.title(f'Learning Rate Evolution Across All Gradient Steps\\nPrompt: "{lr_history["prompt"][:60]}..."',
|
| 1178 |
+
fontsize=12, fontweight='bold')
|
| 1179 |
+
plt.grid(True, alpha=0.3)
|
| 1180 |
+
|
| 1181 |
+
# Add info text
|
| 1182 |
+
num_timesteps = len(lr_history['timesteps'])
|
| 1183 |
+
num_grad_steps_per_timestep = len(lrs) // num_timesteps if num_timesteps > 0 else 0
|
| 1184 |
+
plt.text(0.02, 0.98,
|
| 1185 |
+
f'Total timesteps: {num_timesteps}\\nGrad steps/timestep: {num_grad_steps_per_timestep}\\nTotal grad steps: {len(lrs)}',
|
| 1186 |
+
transform=plt.gca().transAxes, fontsize=10, verticalalignment='top',
|
| 1187 |
+
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
|
| 1188 |
+
|
| 1189 |
+
plt.tight_layout()
|
| 1190 |
+
plt.savefig(plot_path, dpi=150, bbox_inches='tight')
|
| 1191 |
+
plt.close()
|
| 1192 |
+
print(f"\n✓ Saved LR curve plot to: {plot_path}")
|
| 1193 |
+
print(f" Total gradient steps: {len(lrs)}")
|
| 1194 |
+
print(f" LR range: {min(lrs):.6f} → {max(lrs):.6f}")
|
| 1195 |
+
|
| 1196 |
+
# Plot Rewards curve if we captured it
|
| 1197 |
+
if lr_history is not None and lr_history['rewards']:
|
| 1198 |
+
plot_path = Path(args.output_dir) / "rewards_curve.png"
|
| 1199 |
+
|
| 1200 |
+
# Reward values are now continuous across all gradient steps
|
| 1201 |
+
rewards = lr_history['rewards']
|
| 1202 |
+
steps = list(range(len(rewards))) # Step indices (0 to total_steps-1)
|
| 1203 |
+
|
| 1204 |
+
plt.figure(figsize=(12, 6))
|
| 1205 |
+
plt.plot(steps, rewards, linewidth=2, color='green', alpha=0.8)
|
| 1206 |
+
|
| 1207 |
+
# Mark the first step with a star
|
| 1208 |
+
plt.plot(steps[0], rewards[0], marker='*', markersize=20, color='gold',
|
| 1209 |
+
markeredgecolor='darkgoldenrod', markeredgewidth=2, zorder=5)
|
| 1210 |
+
|
| 1211 |
+
# Mark timestep boundaries
|
| 1212 |
+
num_timesteps = len(lr_history['timesteps'])
|
| 1213 |
+
# rewards has one extra value at the start (initial) compared to gradient steps
|
| 1214 |
+
num_grad_steps_per_timestep = (len(rewards) - num_timesteps) // num_timesteps if num_timesteps > 0 else 0
|
| 1215 |
+
if num_grad_steps_per_timestep > 0:
|
| 1216 |
+
for i in range(num_timesteps + 1):
|
| 1217 |
+
step_idx = i * (num_grad_steps_per_timestep + 1) # +1 because reward_history includes initial
|
| 1218 |
+
if step_idx <= len(rewards):
|
| 1219 |
+
plt.axvline(x=step_idx, color='red', linestyle='--', alpha=0.3, linewidth=1)
|
| 1220 |
+
if i < num_timesteps:
|
| 1221 |
+
plt.text(step_idx, plt.ylim()[1] * 0.95, f't={lr_history["timesteps"][i]}',
|
| 1222 |
+
fontsize=8, color='red', alpha=0.7, ha='left')
|
| 1223 |
+
|
| 1224 |
+
plt.xlabel('Global Gradient Step', fontsize=12)
|
| 1225 |
+
plt.ylabel('Reward Score', fontsize=12)
|
| 1226 |
+
plt.title(f'Reward Evolution Across All Gradient Steps\nPrompt: "{lr_history["prompt"][:60]}..."',
|
| 1227 |
+
fontsize=12, fontweight='bold')
|
| 1228 |
+
plt.grid(True, alpha=0.3)
|
| 1229 |
+
|
| 1230 |
+
# Add info text
|
| 1231 |
+
num_timesteps = len(lr_history['timesteps'])
|
| 1232 |
+
reward_improvement = rewards[-1] - rewards[0] if len(rewards) > 1 else 0
|
| 1233 |
+
plt.text(0.02, 0.98,
|
| 1234 |
+
f'Total timesteps: {num_timesteps}\nTotal grad steps: {len(rewards)}\n'
|
| 1235 |
+
f'Initial reward: {rewards[0]:.4f}\nFinal reward: {rewards[-1]:.4f}\n'
|
| 1236 |
+
f'Improvement: {reward_improvement:+.4f}',
|
| 1237 |
+
transform=plt.gca().transAxes, fontsize=10, verticalalignment='top',
|
| 1238 |
+
bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.5))
|
| 1239 |
+
|
| 1240 |
+
plt.tight_layout()
|
| 1241 |
+
plt.savefig(plot_path, dpi=150, bbox_inches='tight')
|
| 1242 |
+
plt.close()
|
| 1243 |
+
print(f"\n✓ Saved Rewards curve plot to: {plot_path}")
|
| 1244 |
+
print(f" Total gradient steps: {len(rewards)}")
|
| 1245 |
+
print(f" Reward range: {min(rewards):.4f} → {max(rewards):.4f}")
|
| 1246 |
+
print(f" Total improvement: {reward_improvement:+.4f}")
|
| 1247 |
+
|
| 1248 |
+
# ---> NEW: PLOT TRAJECTORY DIVERGENCE (MANIFOLD DRIFT) <---
|
| 1249 |
+
if args.mode == "both" and 'baseline_trajectory' in locals() and 'guided_trajectory' in locals():
|
| 1250 |
+
if len(baseline_trajectory) == len(guided_trajectory) and len(baseline_trajectory) > 0:
|
| 1251 |
+
print("\n" + "="*70)
|
| 1252 |
+
print("7. CALCULATING TRAJECTORY DIVERGENCE (THEOREM 1 & 2)")
|
| 1253 |
+
print("="*70)
|
| 1254 |
+
|
| 1255 |
+
drift_path = Path(args.output_dir) / "trajectory_drift.png"
|
| 1256 |
+
|
| 1257 |
+
l2_distances = []
|
| 1258 |
+
# Calculate L2 norm ||z_t_guided - z_t_base||_2 for each step
|
| 1259 |
+
for b_lat, g_lat in zip(baseline_trajectory, guided_trajectory):
|
| 1260 |
+
dist = torch.norm(g_lat.float() - b_lat.float(), p=2).item()
|
| 1261 |
+
l2_distances.append(dist)
|
| 1262 |
+
|
| 1263 |
+
steps = list(range(len(l2_distances)))
|
| 1264 |
+
|
| 1265 |
+
plt.figure(figsize=(10, 6))
|
| 1266 |
+
plt.plot(steps, l2_distances, linewidth=2.5, color='purple', marker='o', markersize=4)
|
| 1267 |
+
|
| 1268 |
+
plt.xlabel('Denoising Step', fontsize=12)
|
| 1269 |
+
plt.ylabel('L2 Distance: ||z_guided - z_base||_2', fontsize=12)
|
| 1270 |
+
plt.title('Latent Trajectory Divergence (Manifold Drift)', fontsize=14, fontweight='bold')
|
| 1271 |
+
plt.grid(True, alpha=0.3)
|
| 1272 |
+
|
| 1273 |
+
# Add interpretation text based on your theory
|
| 1274 |
+
max_drift = max(l2_distances)
|
| 1275 |
+
plt.text(0.02, 0.98,
|
| 1276 |
+
f'Max Drift: {max_drift:.4f}\n'
|
| 1277 |
+
f'Final Drift: {l2_distances[-1]:.4f}\n'
|
| 1278 |
+
f'(Matches bounded drift from Thm 1\n'
|
| 1279 |
+
f'or ODE stiffness collapse from Thm 2)',
|
| 1280 |
+
transform=plt.gca().transAxes, fontsize=10, verticalalignment='top',
|
| 1281 |
+
bbox=dict(boxstyle='round', facecolor='thistle', alpha=0.5))
|
| 1282 |
+
|
| 1283 |
+
plt.tight_layout()
|
| 1284 |
+
plt.savefig(drift_path, dpi=150, bbox_inches='tight')
|
| 1285 |
+
plt.close()
|
| 1286 |
+
print(f"? Saved Manifold Drift curve to: {drift_path}")
|
| 1287 |
+
print(f" Max L2 Distance from baseline: {max_drift:.4f}")
|
| 1288 |
+
|
| 1289 |
+
# ========== FINAL RESULTS ==========
|
| 1290 |
+
print("\n" + "="*70)
|
| 1291 |
+
print("FINAL RESULTS")
|
| 1292 |
+
print("="*70)
|
| 1293 |
+
|
| 1294 |
+
if avg_reward_baseline is not None:
|
| 1295 |
+
print(f"\nBaseline:")
|
| 1296 |
+
if fid_score_baseline is not None:
|
| 1297 |
+
print(f" FID Score: {fid_score_baseline:.4f}")
|
| 1298 |
+
print(f" Avg Reward: {avg_reward_baseline:.4f}")
|
| 1299 |
+
if "clip" in args.metrics and clip_score_baseline is not None:
|
| 1300 |
+
print(f" Avg CLIP Score: {clip_score_baseline:.4f}")
|
| 1301 |
+
if "aesthetic" in args.metrics and aesthetic_score_baseline is not None:
|
| 1302 |
+
print(f" Avg Aesthetic: {aesthetic_score_baseline:.4f}")
|
| 1303 |
+
if "pickscore" in args.metrics and pick_score_baseline is not None:
|
| 1304 |
+
print(f" Avg PickScore: {pick_score_baseline:.4f}")
|
| 1305 |
+
if "hpsv2" in args.metrics and hpsv2_score_baseline is not None:
|
| 1306 |
+
print(f" Avg HPSv2: {hpsv2_score_baseline:.4f}")
|
| 1307 |
+
if "hpsv21" in args.metrics and hpsv21_score_baseline is not None:
|
| 1308 |
+
print(f" Avg HPSv2.1: {hpsv21_score_baseline:.4f}")
|
| 1309 |
+
if "imagereward" in args.metrics and imagereward_score_baseline is not None:
|
| 1310 |
+
print(f" Avg ImageReward: {imagereward_score_baseline:.4f}")
|
| 1311 |
+
|
| 1312 |
+
if avg_reward_grad is not None:
|
| 1313 |
+
print(f"\nGradient Ascent:")
|
| 1314 |
+
if fid_score_grad is not None:
|
| 1315 |
+
print(f" FID Score: {fid_score_grad:.4f}")
|
| 1316 |
+
print(f" Avg Reward: {avg_reward_grad:.4f}")
|
| 1317 |
+
if "clip" in args.metrics and clip_score_grad is not None:
|
| 1318 |
+
print(f" Avg CLIP Score: {clip_score_grad:.4f}")
|
| 1319 |
+
if "aesthetic" in args.metrics and aesthetic_score_grad is not None:
|
| 1320 |
+
print(f" Avg Aesthetic: {aesthetic_score_grad:.4f}")
|
| 1321 |
+
if "pickscore" in args.metrics and pick_score_grad is not None:
|
| 1322 |
+
print(f" Avg PickScore: {pick_score_grad:.4f}")
|
| 1323 |
+
if "hpsv2" in args.metrics and hpsv2_score_grad is not None:
|
| 1324 |
+
print(f" Avg HPSv2: {hpsv2_score_grad:.4f}")
|
| 1325 |
+
if "hpsv21" in args.metrics and hpsv21_score_grad is not None:
|
| 1326 |
+
print(f" Avg HPSv2.1: {hpsv21_score_grad:.4f}")
|
| 1327 |
+
if "imagereward" in args.metrics and imagereward_score_grad is not None:
|
| 1328 |
+
print(f" Avg ImageReward: {imagereward_score_grad:.4f}")
|
| 1329 |
+
|
| 1330 |
+
if avg_reward_baseline is not None and avg_reward_grad is not None:
|
| 1331 |
+
print(f"\nComparison:")
|
| 1332 |
+
if fid_score_baseline is not None and fid_score_grad is not None:
|
| 1333 |
+
fid_diff = fid_score_grad - fid_score_baseline
|
| 1334 |
+
print(f" FID Change: {fid_diff:+.4f} ({'worse' if fid_diff > 0 else 'better'}, lower is better)")
|
| 1335 |
+
reward_diff = avg_reward_grad - avg_reward_baseline
|
| 1336 |
+
print(f" Reward Change: {reward_diff:+.4f} ({'better' if reward_diff > 0 else 'worse'}, higher is better)")
|
| 1337 |
+
if "clip" in args.metrics and clip_score_baseline is not None and clip_score_grad is not None:
|
| 1338 |
+
clip_diff = clip_score_grad - clip_score_baseline
|
| 1339 |
+
print(f" CLIP Change: {clip_diff:+.4f} ({'better' if clip_diff > 0 else 'worse'}, higher is better)")
|
| 1340 |
+
if "aesthetic" in args.metrics and aesthetic_score_baseline is not None and aesthetic_score_grad is not None:
|
| 1341 |
+
aesthetic_diff = aesthetic_score_grad - aesthetic_score_baseline
|
| 1342 |
+
print(f" Aesthetic Change: {aesthetic_diff:+.4f} ({'better' if aesthetic_diff > 0 else 'worse'}, higher is better)")
|
| 1343 |
+
if "pickscore" in args.metrics and pick_score_baseline is not None and pick_score_grad is not None:
|
| 1344 |
+
pick_diff = pick_score_grad - pick_score_baseline
|
| 1345 |
+
print(f" PickScore Change: {pick_diff:+.4f} ({'better' if pick_diff > 0 else 'worse'}, higher is better)")
|
| 1346 |
+
if "hpsv2" in args.metrics and hpsv2_score_baseline is not None and hpsv2_score_grad is not None:
|
| 1347 |
+
hpsv2_diff = hpsv2_score_grad - hpsv2_score_baseline
|
| 1348 |
+
print(f" HPSv2 Change: {hpsv2_diff:+.4f} ({'better' if hpsv2_diff > 0 else 'worse'}, higher is better)")
|
| 1349 |
+
if "hpsv21" in args.metrics and hpsv21_score_baseline is not None and hpsv21_score_grad is not None:
|
| 1350 |
+
hpsv21_diff = hpsv21_score_grad - hpsv21_score_baseline
|
| 1351 |
+
print(f" HPSv2.1 Change: {hpsv21_diff:+.4f} ({'better' if hpsv21_diff > 0 else 'worse'}, higher is better)")
|
| 1352 |
+
if "imagereward" in args.metrics and imagereward_score_baseline is not None and imagereward_score_grad is not None:
|
| 1353 |
+
imagereward_diff = imagereward_score_grad - imagereward_score_baseline
|
| 1354 |
+
print(f" ImageReward Chg: {imagereward_diff:+.4f} ({'better' if imagereward_diff > 0 else 'worse'}, higher is better)")
|
| 1355 |
+
|
| 1356 |
+
# Save results to file
|
| 1357 |
+
results = {
|
| 1358 |
+
"mode": args.mode,
|
| 1359 |
+
"metrics": args.metrics,
|
| 1360 |
+
"config": {
|
| 1361 |
+
"num_samples": len(prompts),
|
| 1362 |
+
"num_steps": args.num_steps,
|
| 1363 |
+
"cfg_scale": args.cfg_scale,
|
| 1364 |
+
"grad_range": [args.grad_range_start, args.grad_range_end],
|
| 1365 |
+
"grad_steps": args.grad_steps,
|
| 1366 |
+
"grad_step_size": args.grad_step_size
|
| 1367 |
+
}
|
| 1368 |
+
}
|
| 1369 |
+
|
| 1370 |
+
if avg_reward_baseline is not None:
|
| 1371 |
+
results["baseline"] = {"avg_reward": avg_reward_baseline}
|
| 1372 |
+
if fid_score_baseline is not None:
|
| 1373 |
+
results["baseline"]["fid"] = fid_score_baseline
|
| 1374 |
+
if "clip" in args.metrics and clip_score_baseline is not None:
|
| 1375 |
+
results["baseline"]["clip_score"] = clip_score_baseline
|
| 1376 |
+
if "aesthetic" in args.metrics and aesthetic_score_baseline is not None:
|
| 1377 |
+
results["baseline"]["aesthetic_score"] = aesthetic_score_baseline
|
| 1378 |
+
if "pickscore" in args.metrics and pick_score_baseline is not None:
|
| 1379 |
+
results["baseline"]["pickscore"] = pick_score_baseline
|
| 1380 |
+
if "hpsv2" in args.metrics and hpsv2_score_baseline is not None:
|
| 1381 |
+
results["baseline"]["hpsv2_score"] = hpsv2_score_baseline
|
| 1382 |
+
if "hpsv21" in args.metrics and hpsv21_score_baseline is not None:
|
| 1383 |
+
results["baseline"]["hpsv21_score"] = hpsv21_score_baseline
|
| 1384 |
+
if "imagereward" in args.metrics and imagereward_score_baseline is not None:
|
| 1385 |
+
results["baseline"]["imagereward_score"] = imagereward_score_baseline
|
| 1386 |
+
|
| 1387 |
+
if avg_reward_grad is not None:
|
| 1388 |
+
results["gradient_ascent"] = {"avg_reward": avg_reward_grad}
|
| 1389 |
+
if fid_score_grad is not None:
|
| 1390 |
+
results["gradient_ascent"]["fid"] = fid_score_grad
|
| 1391 |
+
if "clip" in args.metrics and clip_score_grad is not None:
|
| 1392 |
+
results["gradient_ascent"]["clip_score"] = clip_score_grad
|
| 1393 |
+
if "aesthetic" in args.metrics and aesthetic_score_grad is not None:
|
| 1394 |
+
results["gradient_ascent"]["aesthetic_score"] = aesthetic_score_grad
|
| 1395 |
+
if "pickscore" in args.metrics and pick_score_grad is not None:
|
| 1396 |
+
results["gradient_ascent"]["pickscore"] = pick_score_grad
|
| 1397 |
+
if "hpsv2" in args.metrics and hpsv2_score_grad is not None:
|
| 1398 |
+
results["gradient_ascent"]["hpsv2_score"] = hpsv2_score_grad
|
| 1399 |
+
if "hpsv21" in args.metrics and hpsv21_score_grad is not None:
|
| 1400 |
+
results["gradient_ascent"]["hpsv21_score"] = hpsv21_score_grad
|
| 1401 |
+
if "imagereward" in args.metrics and imagereward_score_grad is not None:
|
| 1402 |
+
results["gradient_ascent"]["imagereward_score"] = imagereward_score_grad
|
| 1403 |
+
if grad_stats:
|
| 1404 |
+
results["gradient_ascent"]["stats"] = grad_stats
|
| 1405 |
+
|
| 1406 |
+
if avg_reward_baseline is not None and avg_reward_grad is not None:
|
| 1407 |
+
results["comparison"] = {
|
| 1408 |
+
"reward_difference": avg_reward_grad - avg_reward_baseline
|
| 1409 |
+
}
|
| 1410 |
+
if fid_score_baseline is not None and fid_score_grad is not None:
|
| 1411 |
+
results["comparison"]["fid_difference"] = fid_score_grad - fid_score_baseline
|
| 1412 |
+
if "clip" in args.metrics and clip_score_baseline is not None and clip_score_grad is not None:
|
| 1413 |
+
results["comparison"]["clip_difference"] = clip_score_grad - clip_score_baseline
|
| 1414 |
+
if "aesthetic" in args.metrics and aesthetic_score_baseline is not None and aesthetic_score_grad is not None:
|
| 1415 |
+
results["comparison"]["aesthetic_difference"] = aesthetic_score_grad - aesthetic_score_baseline
|
| 1416 |
+
if "pickscore" in args.metrics and pick_score_baseline is not None and pick_score_grad is not None:
|
| 1417 |
+
results["comparison"]["pickscore_difference"] = pick_score_grad - pick_score_baseline
|
| 1418 |
+
if "hpsv2" in args.metrics and hpsv2_score_baseline is not None and hpsv2_score_grad is not None:
|
| 1419 |
+
results["comparison"]["hpsv2_difference"] = hpsv2_score_grad - hpsv2_score_baseline
|
| 1420 |
+
if "hpsv21" in args.metrics and hpsv21_score_baseline is not None and hpsv21_score_grad is not None:
|
| 1421 |
+
results["comparison"]["hpsv21_difference"] = hpsv21_score_grad - hpsv21_score_baseline
|
| 1422 |
+
if "imagereward" in args.metrics and imagereward_score_baseline is not None and imagereward_score_grad is not None:
|
| 1423 |
+
results["comparison"]["imagereward_difference"] = imagereward_score_grad - imagereward_score_baseline
|
| 1424 |
+
|
| 1425 |
+
# Save results to output directory
|
| 1426 |
+
output_path = Path(args.output_dir)
|
| 1427 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 1428 |
+
results_path = output_path / "evaluation_results.txt"
|
| 1429 |
+
|
| 1430 |
+
with open(results_path, "w") as f:
|
| 1431 |
+
for k, v in results.items():
|
| 1432 |
+
f.write(f"{k}: {v}\n")
|
| 1433 |
+
|
| 1434 |
+
|
| 1435 |
+
print(f"\n✓ Results saved to: {results_path}")
|
| 1436 |
+
if args.save_images:
|
| 1437 |
+
print(f"✓ Generated images saved to: {output_path}/baseline/ and {output_path}/gradient_ascent/")
|
| 1438 |
+
print("\n" + "="*70)
|
| 1439 |
+
|
| 1440 |
+
# Close logger
|
| 1441 |
+
tee_logger.close()
|
| 1442 |
+
sys.stdout = tee_logger.terminal
|
| 1443 |
+
|
| 1444 |
+
|
| 1445 |
+
# Script entry point: run the full evaluation pipeline (generation, metric
# computation, plotting, and result logging) defined in main() above.
if __name__ == "__main__":
    main()
|
| 1447 |
+
|
Reward_sana_idealized/examples.sh
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Driver script for reward-guided SANA evaluation on the cluster.
# Configures the shared HF cache, offline mode, GPU selection, and either
# prefetches model weights (PREFETCH_ONLY=1) or launches eval.py.
set -euo pipefail
# bash examples.sh
if [[ -n "${TERM:-}" ]]; then
  clear
fi
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"

# Shared HF cache used on this cluster.
HF_HUB_CACHE_DIR="${HF_HUB_CACHE_DIR:-/scratch/rr81/ma5430/.cache/huggingface/hub}"
export HF_HUB_CACHE="$HF_HUB_CACHE_DIR"
export HUGGINGFACE_HUB_CACHE="$HF_HUB_CACHE_DIR"
export HF_HOME="$(dirname "$HF_HUB_CACHE_DIR")"

# GPU nodes have no internet, while login nodes do.
# Auto default: offline on GPU nodes, online on login nodes.
DEFAULT_OFFLINE_MODE="1"
if ! (command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi -L >/dev/null 2>&1); then
  DEFAULT_OFFLINE_MODE="0"
fi
OFFLINE_MODE="${OFFLINE_MODE:-$DEFAULT_OFFLINE_MODE}"

if [[ "$OFFLINE_MODE" == "1" ]]; then
  export HF_DATASETS_OFFLINE="1"
  export HF_METRICS_OFFLINE="1"
  export HF_MODULES_OFFLINE="1"
  export TRANSFORMERS_OFFLINE="1"
  export DIFFUSERS_OFFLINE="1"
  export HF_HUB_OFFLINE="1"
else
  export HF_DATASETS_OFFLINE="0"
  export HF_METRICS_OFFLINE="0"
  export HF_MODULES_OFFLINE="0"
  export TRANSFORMERS_OFFLINE="0"
  export DIFFUSERS_OFFLINE="0"
  export HF_HUB_OFFLINE="0"
fi

# Existing environment requested by user.
PYTHON_BIN="${PYTHON_BIN:-/g/data/rr81/aev/bin/python}"
if [[ ! -x "$PYTHON_BIN" ]]; then
  echo "[examples.sh] Missing Python executable: $PYTHON_BIN" >&2
  exit 1
fi

DATASET_NAME="${DATASET_NAME:-pickapic}" # coco | pickapic
GRAD_CONFIG="${GRAD_CONFIG:-one_step_rectification_config}"
MODEL_PROFILE="${MODEL_PROFILE:-sana_600m_512}" # sana_600m_512 | sana_1600m_512 | sana_sprint_0_6b_1024 | sana_sprint_1_6b_1024
MODE="${MODE:-gradient_ascent}" # gradient_ascent | baseline | both
# Empty MAX_SAMPLES means evaluate all available samples.
MAX_SAMPLES="${MAX_SAMPLES:-}"
NUM_STEPS="${NUM_STEPS:-20}"
CFG_SCALE="${CFG_SCALE:-4.5}"
DTYPE="${DTYPE:-bf16}" # auto | bf16 | fp16 | fp32
METRICS="${METRICS:-clip aesthetic pickscore hpsv2 hpsv21 imagereward}"
PREFETCH_ONLY="${PREFETCH_ONLY:-0}"

# Override this path whenever you want to swap reward weights.
# LRM_MODEL_PATH="${LRM_MODEL_PATH:-/g/data/rr81/LPO/lrm/lrm_sana/logs/v8/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep33000}"
LRM_MODEL_PATH="${LRM_MODEL_PATH:-/g/data/rr81/LPO/lrm/lrm_sana/logs/v7/reward_model/step_sana_sana_600m_512_variable-t_lr1e-5_step-8000_filter2_time951/checkpoint-gstep32000}"

# Pick the least-loaded GPU unless the caller pinned one via GPU_ID.
if [[ -z "${GPU_ID:-}" ]]; then
  if command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi -L >/dev/null 2>&1; then
    GPU_ID="$(nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | sort -k2 -n | head -n1 | cut -d',' -f1 | tr -d ' ')"
    GPU_ID="${GPU_ID:-0}"
  else
    GPU_ID="0"
    echo "[examples.sh] No visible NVIDIA GPU on this node. Defaulting GPU_ID=0."
    echo "[examples.sh] eval.py will run on CPU if CUDA is unavailable."
  fi
fi

echo "Using GPU ID: $GPU_ID"
echo "Using LRM weights: $LRM_MODEL_PATH"
echo "HF offline mode: $OFFLINE_MODE"
echo "Generation dtype: $DTYPE"

if [[ "$PREFETCH_ONLY" == "1" ]]; then
  echo "[examples.sh] PREFETCH_ONLY=1 -> downloading required model files to shared cache and exiting."
  export MODEL_PROFILE
  export METRICS
  "$PYTHON_BIN" - <<'PY'
import os
from huggingface_hub import hf_hub_download, snapshot_download

cache_dir = os.environ["HF_HUB_CACHE"]
model_profile = os.environ.get("MODEL_PROFILE", "sana_600m_512")
metrics = set(os.environ.get("METRICS", "clip aesthetic").split())

profile_to_repo = {
    "sana_600m_512": "Efficient-Large-Model/Sana_600M_512px_diffusers",
    "sana_1600m_512": "Efficient-Large-Model/Sana_1600M_512px_diffusers",
    "sana_sprint_0_6b_1024": "Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers",
    "sana_sprint_1_6b_1024": "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
}

def snap(repo_id):
    print(f"[prefetch] snapshot_download: {repo_id}")
    snapshot_download(repo_id=repo_id, cache_dir=cache_dir, local_files_only=False)

def one(repo_id, filename):
    # FIX: log the actual file being fetched (was a corrupted placeholder).
    print(f"[prefetch] hf_hub_download: {repo_id}/{filename}")
    hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir, local_files_only=False)

if model_profile not in profile_to_repo:
    raise ValueError(f"Unknown MODEL_PROFILE={model_profile}")

# Base SANA model used for generation + reward backbone
snap(profile_to_repo[model_profile])

# Required for CLIP-based metrics and LRM text projection init fallback
if "clip" in metrics or "aesthetic" in metrics:
    snap("openai/clip-vit-large-patch14")

if "pickscore" in metrics:
    snap("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    snap("yuvalkirstain/PickScore_v1")

if "hpsv2" in metrics or "hpsv21" in metrics:
    one("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", "open_clip_pytorch_model.bin")
    if "hpsv2" in metrics:
        one("xswu/HPSv2", "HPS_v2_compressed.pt")
    if "hpsv21" in metrics:
        one("xswu/HPSv2", "HPS_v2.1_compressed.pt")

if "imagereward" in metrics:
    one("THUDM/ImageReward", "ImageReward.pt")
    one("THUDM/ImageReward", "med_config.json")

print("[prefetch] done")
PY
  exit 0
fi

# Split the space-separated METRICS string into an argv-safe array.
read -r -a METRICS_ARR <<< "$METRICS"

CMD=(
  "$PYTHON_BIN" eval.py
  --model_variant "$MODEL_PROFILE"
  --dataset_type "$DATASET_NAME"
  --lrm_model "$LRM_MODEL_PATH"
  --grad_config "$GRAD_CONFIG"
  --metrics "${METRICS_ARR[@]}"
  --num_steps "$NUM_STEPS"
  --cfg_scale "$CFG_SCALE"
  --dtype "$DTYPE"
  --hf_cache_dir "$HF_HUB_CACHE_DIR"
  --output_dir "RESULTS/$DATASET_NAME/${GRAD_CONFIG}_${MODEL_PROFILE}"
  --cuda "$GPU_ID"
  --mode "$MODE"
)

if [[ -n "$MAX_SAMPLES" ]]; then
  CMD+=(--max_samples "$MAX_SAMPLES")
fi

if [[ "$OFFLINE_MODE" == "1" ]]; then
  CMD+=(--offline)
fi

"${CMD[@]}"
|
Reward_sana_idealized/grad_ascent_configs.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration presets for gradient ascent optimization.
|
| 3 |
+
|
| 4 |
+
Provides pre-configured settings for various optimization strategies
|
| 5 |
+
including learning rate scheduling and momentum configurations.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Dict, Any
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
ONE_STEP_RECTIFICATION_CONFIG = {
|
| 12 |
+
"grad_timestep_range": (100, 800), # Match SDXL one-step rectification window
|
| 13 |
+
"num_grad_steps": 1,
|
| 14 |
+
"grad_step_size": 1.0,
|
| 15 |
+
"grad_scale": 1.0,
|
| 16 |
+
"lr_scheduler_type": "constant",
|
| 17 |
+
"use_momentum": False,
|
| 18 |
+
"use_nesterov": False,
|
| 19 |
+
"use_iso_projection": False
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
# ============================================================================
|
| 23 |
+
# Config Dictionary (for easy access)
|
| 24 |
+
# ============================================================================
|
| 25 |
+
|
| 26 |
+
CONFIGS = {
|
| 27 |
+
"one_step_rectification_config": ONE_STEP_RECTIFICATION_CONFIG,
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_config(config_name: str) -> Dict[str, Any]:
    """
    Get a gradient ascent configuration by name.

    Args:
        config_name: Name of the configuration (one of list_configs())

    Returns:
        A shallow copy of the configuration dictionary, so callers may
        mutate it without altering the registered preset

    Raises:
        ValueError: If config_name is not found

    Example:
        config = get_config("one_step_rectification_config")
        pipeline.enable_gradient_ascent(**config)
    """
    if config_name not in CONFIGS:
        available = ", ".join(sorted(CONFIGS.keys()))
        raise ValueError(f"Unknown config: {config_name}. Available: {available}")

    return CONFIGS[config_name].copy()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def list_configs() -> list:
    """Return the names of all registered presets, sorted alphabetically."""
    return sorted(CONFIGS)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def print_config(config_name: str):
    """Pretty-print the named preset's key/value pairs to stdout."""
    cfg = get_config(config_name)
    separator = "=" * 60
    print(f"\nConfiguration: {config_name}")
    print(separator)
    for key, value in cfg.items():
        print(f"  {key}: {value}")
    print(separator)
|
Reward_sana_idealized/gradient_ascent_utils.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradient Ascent utilities for reward-guided diffusion generation.
|
| 3 |
+
|
| 4 |
+
This module implements gradient ascent on the LRM reward score to guide
|
| 5 |
+
the diffusion process toward higher preference scores.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from typing import Optional, Tuple, List, Literal
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from lr_scheduler import create_lr_scheduler, LRScheduler
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RewardGuidedDiffusion:
|
| 16 |
+
"""
|
| 17 |
+
Implements reward-guided generation using gradient ascent.
|
| 18 |
+
|
| 19 |
+
During denoising, at specified timesteps, we:
|
| 20 |
+
1. Compute the reward score for current latents
|
| 21 |
+
2. Calculate gradients of reward w.r.t. latents
|
| 22 |
+
3. Update latents in the direction that increases reward
|
| 23 |
+
|
| 24 |
+
This guides generation toward higher preference scores.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
    def __init__(
        self,
        reward_model,
        grad_scale: float = 1.0,
        grad_timestep_range: Optional[Tuple[int, int]] = None,
        num_grad_steps: int = 5,
        grad_step_size: float = 0.1,
        gradient_checkpoint: bool = False,
        # LR Scheduling
        lr_scheduler_type: Literal["constant", "linear", "cosine", "exponential", "step"] = "constant",
        lr_scheduler_kwargs: Optional[dict] = None,
        # Momentum
        use_momentum: bool = False,
        momentum: float = 0.9,
        use_nesterov: bool = False,
        use_iso_projection: bool = False
    ):
        """
        Initialize reward-guided diffusion.

        Args:
            reward_model: LRM reward model for computing preference scores
            grad_scale: Scale factor for gradient updates (default: 1.0)
            grad_timestep_range: Tuple of (min_t, max_t) for gradient ascent.
                If None, gradient ascent is disabled entirely:
                should_apply_gradient() returns False for every timestep.
            num_grad_steps: Number of gradient ascent steps per timestep
            grad_step_size: Step size for each gradient update (initial LR)
            gradient_checkpoint: Whether to use gradient checkpointing
            lr_scheduler_type: Type of LR scheduler ("constant", "linear", "cosine", "exponential", "step")
            lr_scheduler_kwargs: Additional kwargs for LR scheduler (e.g., end_lr, min_lr, warmup_steps)
            use_momentum: Whether to use momentum in gradient updates
            momentum: Momentum coefficient (typically 0.9)
            use_nesterov: Whether to use Nesterov momentum
            use_iso_projection: Whether to use Iso Projection
        """
        self.reward_model = reward_model
        self.grad_scale = grad_scale
        self.grad_timestep_range = grad_timestep_range
        self.num_grad_steps = num_grad_steps
        self.grad_step_size = grad_step_size
        self.gradient_checkpoint = gradient_checkpoint

        # LR Scheduler (instances are created lazily, not here)
        self.lr_scheduler_type = lr_scheduler_type
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or {}
        self.lr_scheduler: Optional[LRScheduler] = None
        self.global_lr_scheduler: Optional[LRScheduler] = None  # Scheduler across denoising timesteps

        # Momentum
        self.use_momentum = use_momentum
        self.momentum = momentum
        self.use_nesterov = use_nesterov
        self.velocity = None  # Will be initialized per optimization

        self.use_iso_projection = use_iso_projection

        # Statistics
        self.grad_stats = []  # accumulated per-timestep optimization stats
        self.timestep_counter = 0  # Track which timestep we're on
|
| 86 |
+
|
| 87 |
+
def should_apply_gradient(self, timestep: int) -> bool:
|
| 88 |
+
"""Check if gradient ascent should be applied at this timestep."""
|
| 89 |
+
|
| 90 |
+
if self.grad_timestep_range is None:
|
| 91 |
+
return False
|
| 92 |
+
|
| 93 |
+
min_t, max_t = self.grad_timestep_range
|
| 94 |
+
return min_t <= timestep <= max_t
|
| 95 |
+
|
| 96 |
+
    @torch.enable_grad()
    def compute_reward_gradient(
        self,
        latents: torch.Tensor,
        prompt,
        timestep: int,
    ) -> Tuple[torch.Tensor, float]:
        """
        Compute gradient of reward score w.r.t. latents in FP32 to prevent underflow.

        Args:
            latents: Current latents; upcast to fp32 internally, the returned
                gradient is cast back to ``latents.dtype``.
            prompt: Passed through verbatim to the reward model.
            timestep: Passed through verbatim to the reward model.

        Returns:
            Tuple of (gradient tensor with the same shape/dtype as ``latents``,
            scalar mean reward). If the reward is non-finite, returns a zero
            gradient and 0.0 instead of propagating NaN/Inf.
        """
        # 1. Cast to FP32 and ensure we are detached from previous iterations
        latents_fp32 = latents.detach().to(torch.float32).clone()
        latents_fp32.requires_grad_(True)

        # 2. Compute reward score
        # Note: Even if the model internally uses fp16/bf16, autograd will
        # safely accumulate the gradient in fp32 for our leaf node.
        reward_score = self.reward_model.get_reward_score(
            latents_fp32,
            prompt,
            timestep,
            enable_grad=True,
            return_logits=True,
        )

        reward_score_mean = reward_score.mean()
        if not torch.isfinite(reward_score_mean):
            # Skip the update entirely rather than stepping on a NaN/Inf reward.
            return torch.zeros_like(latents), 0.0

        # 3. Extract gradient
        # CRITICAL: retain_graph=True prevents the graph from dying across multiple
        # gradient steps if your reward model relies on cached text embeddings.
        grad = torch.autograd.grad(
            outputs=reward_score_mean,
            inputs=latents_fp32,
            create_graph=False,
            retain_graph=True,  # Keeps the graph alive for the next step!
            allow_unused=True,
        )[0]

        # 4. Handle None gradients and cast back to the pipeline's original dtype
        if grad is None:
            # allow_unused=True can yield None when latents don't affect the score.
            grad = torch.zeros_like(latents)
        else:
            grad = torch.nan_to_num(grad, nan=0.0, posinf=0.0, neginf=0.0)
            grad = grad.to(latents.dtype)

        return grad, reward_score_mean.item()
|
| 144 |
+
|
| 145 |
+
def apply_gradient_ascent(
|
| 146 |
+
self,
|
| 147 |
+
latents: torch.Tensor,
|
| 148 |
+
prompt,
|
| 149 |
+
timestep: int,
|
| 150 |
+
base_noise: Optional[torch.Tensor] = None, # Required for Iso-Marginal projection
|
| 151 |
+
verbose: bool = True,
|
| 152 |
+
total_denoising_steps: Optional[int] = None,
|
| 153 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 154 |
+
|
| 155 |
+
# 1. UPCAST TO FP32 AND SETUP OPTIMIZER (Targeting Latents)
|
| 156 |
+
original_latents = latents.detach().clone().to(torch.float32)
|
| 157 |
+
current_latents = torch.nn.Parameter(original_latents.clone())
|
| 158 |
+
|
| 159 |
+
# Initial reward tracking
|
| 160 |
+
with torch.no_grad():
|
| 161 |
+
initial_reward = self.reward_model.get_reward_score(
|
| 162 |
+
latents,
|
| 163 |
+
prompt,
|
| 164 |
+
timestep
|
| 165 |
+
)
|
| 166 |
+
initial_reward_val = initial_reward.item() if initial_reward.numel() == 1 else initial_reward.mean().item()
|
| 167 |
+
|
| 168 |
+
# Initialize tracking lists
|
| 169 |
+
grad_norms = []
|
| 170 |
+
reward_history = [initial_reward_val]
|
| 171 |
+
lr_history = []
|
| 172 |
+
|
| 173 |
+
# 2. FORWARD PASS (model precision follows eval dtype; latents stay fp32 here)
|
| 174 |
+
reward = self.reward_model.get_reward_score(
|
| 175 |
+
current_latents.to(latents.dtype),
|
| 176 |
+
prompt,
|
| 177 |
+
timestep,
|
| 178 |
+
enable_grad=True,
|
| 179 |
+
return_logits=True,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
reward_mean = reward.mean()
|
| 183 |
+
if not torch.isfinite(reward_mean):
|
| 184 |
+
if verbose:
|
| 185 |
+
print("?? WARNING: Non-finite reward encountered; skipping gradient step.")
|
| 186 |
+
rectified_latents = original_latents.clone()
|
| 187 |
+
final_latents = rectified_latents.detach().to(latents.dtype)
|
| 188 |
+
stats = {
|
| 189 |
+
'timestep': timestep,
|
| 190 |
+
'initial_reward': initial_reward_val,
|
| 191 |
+
'final_reward': initial_reward_val,
|
| 192 |
+
'reward_improvement': 0.0,
|
| 193 |
+
'grad_norms': [0.0],
|
| 194 |
+
'reward_history': reward_history,
|
| 195 |
+
'lr_history': [0.0],
|
| 196 |
+
'latent_change': 0.0,
|
| 197 |
+
}
|
| 198 |
+
self.grad_stats.append(stats)
|
| 199 |
+
return final_latents, stats
|
| 200 |
+
|
| 201 |
+
loss = -reward_mean
|
| 202 |
+
loss.backward()
|
| 203 |
+
|
| 204 |
+
# Extract latent gradient
|
| 205 |
+
raw_grad = current_latents.grad
|
| 206 |
+
if raw_grad is not None:
|
| 207 |
+
raw_grad = torch.nan_to_num(raw_grad, nan=0.0, posinf=0.0, neginf=0.0)
|
| 208 |
+
reward_history.append(torch.sigmoid(reward_mean).item())
|
| 209 |
+
|
| 210 |
+
# 3. ISO-MARGINAL PROJECTION WITH ASYMMETRIC INCLUSION
|
| 211 |
+
if raw_grad is not None and base_noise is not None and self.use_iso_projection:
|
| 212 |
+
gamma = 1e-8
|
| 213 |
+
B = raw_grad.shape[0]
|
| 214 |
+
|
| 215 |
+
grad_flat = raw_grad.view(B, -1)
|
| 216 |
+
noise_flat = base_noise.view(B, -1).to(torch.float32)
|
| 217 |
+
|
| 218 |
+
# Compute projection scalar for raw_grad (which is -?R)
|
| 219 |
+
dot_product = (grad_flat * noise_flat).sum(dim=1, keepdim=True)
|
| 220 |
+
noise_norm_sq = (noise_flat * noise_flat).sum(dim=1, keepdim=True)
|
| 221 |
+
|
| 222 |
+
proj_scalar = dot_product / (noise_norm_sq + gamma)
|
| 223 |
+
proj_scalar = proj_scalar.view(B, 1, 1, 1)
|
| 224 |
+
|
| 225 |
+
# 1. Decompose
|
| 226 |
+
grad_parallel = proj_scalar * base_noise.to(torch.float32)
|
| 227 |
+
grad_perp = raw_grad - grad_parallel
|
| 228 |
+
|
| 229 |
+
# 2. Asymmetric Inclusion
|
| 230 |
+
# proj_scalar > 0 means the applied step (+?R) points toward -epsilon (Denoising. GOOD.)
|
| 231 |
+
# proj_scalar < 0 means the applied step (+?R) points toward +epsilon (Noising. BAD.)
|
| 232 |
+
safe_proj_scalar = torch.clamp(proj_scalar, min=0.0)
|
| 233 |
+
|
| 234 |
+
beta = 1.0 # Retention factor for the safe parallel gradient
|
| 235 |
+
safe_grad_parallel = beta * (safe_proj_scalar * base_noise.to(torch.float32))
|
| 236 |
+
|
| 237 |
+
# 3. Recombine
|
| 238 |
+
grad_perp = grad_perp + safe_grad_parallel
|
| 239 |
+
else:
|
| 240 |
+
grad_perp = raw_grad
|
| 241 |
+
if base_noise is None and self.use_iso_projection:
|
| 242 |
+
print("?? WARNING: base_noise missing. Skipping Iso-Marginal projection.")
|
| 243 |
+
|
| 244 |
+
# 4. KINETIC RECTIFICATION (Applied to the projected latent gradient)
|
| 245 |
+
if grad_perp is not None:
|
| 246 |
+
grad_norm = grad_perp.float().norm().item()
|
| 247 |
+
max_abs_grad = grad_perp.float().abs().max().item()
|
| 248 |
+
|
| 249 |
+
recovered_with_fallback = False
|
| 250 |
+
if grad_norm <= 0 or max_abs_grad <= 0:
|
| 251 |
+
fallback_grad, _ = self.compute_reward_gradient(
|
| 252 |
+
original_latents,
|
| 253 |
+
prompt,
|
| 254 |
+
timestep,
|
| 255 |
+
)
|
| 256 |
+
fallback_grad = torch.nan_to_num(fallback_grad, nan=0.0, posinf=0.0, neginf=0.0)
|
| 257 |
+
fallback_grad = fallback_grad.to(dtype=original_latents.dtype)
|
| 258 |
+
fallback_norm = fallback_grad.float().norm().item()
|
| 259 |
+
fallback_max_abs = fallback_grad.float().abs().max().item()
|
| 260 |
+
|
| 261 |
+
if fallback_norm > 0 and fallback_max_abs > 0:
|
| 262 |
+
grad_perp = fallback_grad
|
| 263 |
+
grad_norm = fallback_norm
|
| 264 |
+
max_abs_grad = fallback_max_abs
|
| 265 |
+
recovered_with_fallback = True
|
| 266 |
+
|
| 267 |
+
if grad_norm > 0 and max_abs_grad > 0:
|
| 268 |
+
kinetic_direction = grad_perp / (grad_norm + 1e-8)
|
| 269 |
+
|
| 270 |
+
# Because the max element is 1.0, alpha is the EXACT float32 change applied.
|
| 271 |
+
alpha = self.grad_step_size
|
| 272 |
+
|
| 273 |
+
with torch.no_grad():
|
| 274 |
+
rectified_latents = original_latents - (alpha * kinetic_direction)
|
| 275 |
+
if recovered_with_fallback:
|
| 276 |
+
print(
|
| 277 |
+
"✓ Recovered collapsed gradient using fp32 fallback "
|
| 278 |
+
f"(norm={grad_norm:.3e}, max_abs={max_abs_grad:.3e})"
|
| 279 |
+
)
|
| 280 |
+
else:
|
| 281 |
+
print(
|
| 282 |
+
"?? WARNING: Gradient tensor exists but magnitude collapsed to zero "
|
| 283 |
+
f"(norm={grad_norm:.3e}, max_abs={max_abs_grad:.3e}, dtype={grad_perp.dtype})"
|
| 284 |
+
)
|
| 285 |
+
rectified_latents = original_latents.clone()
|
| 286 |
+
alpha = 0.0
|
| 287 |
+
max_grad = grad_norm
|
| 288 |
+
else:
|
| 289 |
+
print("?? FATAL: PyTorch completely dropped the latent gradient!")
|
| 290 |
+
rectified_latents = original_latents.clone()
|
| 291 |
+
max_grad = 0.0
|
| 292 |
+
alpha = 0.0
|
| 293 |
+
|
| 294 |
+
if verbose:
|
| 295 |
+
print(f" Grad step | LR: {alpha:.6f} | Reward: {reward.mean().item():.4f} | Max Grad: {max_grad:.4f}")
|
| 296 |
+
|
| 297 |
+
# 5. DOWNCAST AND RETURN
|
| 298 |
+
final_latents = rectified_latents.detach().to(latents.dtype)
|
| 299 |
+
|
| 300 |
+
with torch.no_grad():
|
| 301 |
+
final_reward = self.reward_model.get_reward_score(
|
| 302 |
+
final_latents, prompt, timestep
|
| 303 |
+
)
|
| 304 |
+
final_reward_val = final_reward.item() if final_reward.numel() == 1 else final_reward.mean().item()
|
| 305 |
+
|
| 306 |
+
stats = {
|
| 307 |
+
'timestep': timestep,
|
| 308 |
+
'initial_reward': initial_reward_val,
|
| 309 |
+
'final_reward': final_reward_val,
|
| 310 |
+
'reward_improvement': final_reward_val - initial_reward_val,
|
| 311 |
+
'grad_norms': [max_grad],
|
| 312 |
+
'reward_history': reward_history,
|
| 313 |
+
'lr_history': [alpha], # Kept for plotting logic
|
| 314 |
+
'latent_change': (final_latents - original_latents.to(latents.dtype)).norm().item(),
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
self.grad_stats.append(stats)
|
| 318 |
+
|
| 319 |
+
return final_latents, stats
|
| 320 |
+
|
| 321 |
+
def get_statistics(self) -> dict:
|
| 322 |
+
"""Get aggregated statistics across all gradient ascent applications."""
|
| 323 |
+
if not self.grad_stats:
|
| 324 |
+
return {}
|
| 325 |
+
|
| 326 |
+
total_improvement = sum(s['reward_improvement'] for s in self.grad_stats)
|
| 327 |
+
avg_improvement = total_improvement / len(self.grad_stats)
|
| 328 |
+
|
| 329 |
+
all_grad_norms = [n for s in self.grad_stats for n in s['grad_norms']]
|
| 330 |
+
|
| 331 |
+
return {
|
| 332 |
+
'num_applications': len(self.grad_stats),
|
| 333 |
+
'total_reward_improvement': total_improvement,
|
| 334 |
+
'avg_reward_improvement': avg_improvement,
|
| 335 |
+
'avg_grad_norm': sum(all_grad_norms) / len(all_grad_norms) if all_grad_norms else 0,
|
| 336 |
+
'max_grad_norm': max(all_grad_norms) if all_grad_norms else 0,
|
| 337 |
+
'detailed_stats': self.grad_stats,
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
def reset_statistics(self):
|
| 341 |
+
"""Reset statistics and global scheduler."""
|
| 342 |
+
self.grad_stats = []
|
| 343 |
+
self.global_lr_scheduler = None
|
| 344 |
+
self.timestep_counter = 0
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def create_reward_guided_generator(
    reward_model,
    grad_timestep_range: Tuple[int, int] = (500, 700),
    grad_scale: float = 1.0,
    num_grad_steps: int = 5,
    grad_step_size: float = 0.1,
    lr_scheduler_type: str = "constant",
    lr_scheduler_kwargs: Optional[dict] = None,
    use_momentum: bool = False,
    momentum: float = 0.9,
    use_nesterov: bool = False,
    use_iso_projection: bool = False
) -> RewardGuidedDiffusion:
    """
    Convenience function to create a reward-guided diffusion generator.

    Args:
        reward_model: LRM reward model
        grad_timestep_range: Tuple of (min_t, max_t) for applying gradients
        grad_scale: Scale factor for gradient magnitude
        num_grad_steps: Number of gradient ascent iterations per timestep
        grad_step_size: Step size for each gradient update (initial LR)
        lr_scheduler_type: Type of LR scheduler
        lr_scheduler_kwargs: Additional kwargs for LR scheduler
        use_momentum: Whether to use momentum
        momentum: Momentum coefficient
        use_nesterov: Whether to use Nesterov momentum
        use_iso_projection: Whether to use Iso-Marginal projection

    Returns:
        RewardGuidedDiffusion instance
    """
    return RewardGuidedDiffusion(
        reward_model=reward_model,
        grad_scale=grad_scale,
        grad_timestep_range=grad_timestep_range,
        num_grad_steps=num_grad_steps,
        grad_step_size=grad_step_size,
        lr_scheduler_type=lr_scheduler_type,
        lr_scheduler_kwargs=lr_scheduler_kwargs,
        use_momentum=use_momentum,
        momentum=momentum,
        use_nesterov=use_nesterov,
        # BUG FIX: the original hard-coded False here, silently ignoring the
        # use_iso_projection argument passed by the caller.
        use_iso_projection=use_iso_projection,
    )
|
Reward_sana_idealized/hpsv2_score.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from https://github.com/tgxs002/HPSv2. Originally Apache License, Version 2.0, January 2004.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from open_clip import create_model_and_transforms, get_tokenizer
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class HPSv2Scorer():
    """Human Preference Score v2 scorer built on an open_clip ViT-H-14 backbone.

    The original ``score`` method duplicated the same ~20-line scoring routine
    across three type branches; that routine now lives in ``_score_one``.
    """

    def __init__(self, clip_pretrained_name_or_path, model_pretrained_name_or_path, device='cuda'):
        """
        Args:
            clip_pretrained_name_or_path: open_clip pretrained tag or checkpoint path.
            model_pretrained_name_or_path: HPSv2 checkpoint (state_dict under 'state_dict').
            device: Torch device string the model runs on.
        """
        self.model, _, self.preprocess_val = create_model_and_transforms(
            'ViT-H-14',
            # 'laion2B-s32B-b79K',
            clip_pretrained_name_or_path,
            precision='amp',
            device=device,
            jit=False,
            force_quick_gelu=False,
            force_custom_text=False,
            force_patch_dropout=False,
            force_image_size=None,
            pretrained_image=False,
            image_mean=None,
            image_std=None,
            light_augmentation=True,
            aug_cfg={},
            output_dict=True,
            with_score_predictor=False,
            with_region_predictor=False
        )
        self.device = device
        checkpoint = torch.load(model_pretrained_name_or_path, map_location=device)
        self.model.load_state_dict(checkpoint['state_dict'])
        self.tokenizer = get_tokenizer('ViT-H-14')
        self.model = self.model.to(device)

    def _score_one(self, image_input, prompt):
        """Score a single image (file path or PIL.Image) against *prompt*.

        Returns the scalar CLIP-similarity score as a numpy value.
        """
        with torch.no_grad():
            # Process the image
            if isinstance(image_input, str):
                pil_image = Image.open(image_input)
            elif isinstance(image_input, Image.Image):
                pil_image = image_input
            else:
                raise TypeError('The type of parameter img_path is illegal.')
            image = self.preprocess_val(pil_image).unsqueeze(0).to(device=self.device, non_blocking=True)
            # Process the prompt
            text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
            # Calculate the HPS
            with torch.cuda.amp.autocast():
                outputs = self.model(image, text)
                image_features, text_features = outputs["image_features"], outputs["text_features"]
                logits_per_image = image_features @ text_features.T

            hps_score = torch.diagonal(logits_per_image).cpu().numpy()
        return hps_score[0]

    def score(self, img_path, prompt):
        """Score one image or a list of images against *prompt*.

        Args:
            img_path: A file path, a PIL.Image, or a list mixing either.
            prompt: The text prompt to score against.

        Returns:
            A list of scores (always a list, even for a single image).

        Raises:
            TypeError: If img_path (or a list element) is neither str nor PIL.Image.
        """
        if isinstance(img_path, list):
            return [self._score_one(one_img_path, prompt) for one_img_path in img_path]
        if isinstance(img_path, (str, Image.Image)):
            return [self._score_one(img_path, prompt)]
        raise TypeError('The type of parameter img_path is illegal.')
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if __name__ == "__main__":
    from huggingface_hub import hf_hub_download

    # Fetch the CLIP backbone and the HPSv2 head from the Hub.
    clip_model_path = hf_hub_download(
        repo_id="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
        filename="open_clip_pytorch_model.bin",
    )
    hps_model_path = hf_hub_download(repo_id="xswu/HPSv2", filename="HPS_v2_compressed.pt")

    # Score two demo images against a fixed prompt and print the result.
    hpsv2_scorer = HPSv2Scorer(
        clip_pretrained_name_or_path=clip_model_path,
        model_pretrained_name_or_path=hps_model_path,
    )
    score = hpsv2_scorer.score(
        img_path=['./image0.png', './image1.png'],
        prompt='photorealistic image of a lone painter standing in a gallery, watching an exhibition of paintings made entirely with AI. In the foreground of the image a robot looks proudly at his art',
    )
    print(score)
|
| 110 |
+
|
Reward_sana_idealized/imagereward_score.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from https://github.com/THUDM/ImageReward. Originally Apache License, Version 2.0, January 2004.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from io import BytesIO
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from blip.blip_pretrain import BLIP_Pretrain
|
| 11 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
| 12 |
+
from typing import Any, Union, List
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from torchvision.transforms import InterpolationMode
|
| 16 |
+
BICUBIC = InterpolationMode.BICUBIC
|
| 17 |
+
except ImportError:
|
| 18 |
+
BICUBIC = Image.BICUBIC
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def open_image(image):
    """Return *image* as an RGB PIL image.

    Accepts raw bytes, a file path, or an already-open PIL image.
    """
    if isinstance(image, bytes):
        pil = Image.open(BytesIO(image))
    elif isinstance(image, str):
        pil = Image.open(image)
    else:
        pil = image
    return pil.convert("RGB")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _convert_image_to_rgb(image):
    """Force a PIL image into RGB mode."""
    rgb = image.convert("RGB")
    return rgb
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _transform(n_px):
    """CLIP-style eval preprocessing: resize, center-crop to *n_px*, RGB, tensor, normalize."""
    # CLIP normalization statistics.
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(clip_mean, clip_std),
    ]
    return Compose(steps)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class MLP(nn.Module):
    """Reward head mapping a BLIP text feature vector to a scalar.

    The stack is purely linear with dropout between layers — the ReLU
    activations are commented out in the reference implementation.
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size

        modules = [
            nn.Linear(self.input_size, 1024),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            # nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            # nn.ReLU(),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*modules)

        # Reference init: small normal weights, zero biases.
        for name, param in self.layers.named_parameters():
            if 'weight' in name:
                nn.init.normal_(param, mean=0.0, std=1.0 / (self.input_size + 1))
            if 'bias' in name:
                nn.init.constant_(param, val=0)

    def forward(self, input):
        """Run the linear stack; output has a trailing dimension of 1."""
        return self.layers(input)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class ImageReward(nn.Module):
    """BLIP-based reward model scoring how well an image matches a prompt.

    Rewards are standardized with the mean/std constants shipped with the
    original ImageReward release. The duplicated image-loading and
    cross-attention code from ``score``/``inference_rank`` now lives in the
    private helpers ``_to_pil`` and ``_text_feature``.
    """

    def __init__(self, med_config, device='cpu'):
        """
        Args:
            med_config: Path to the BLIP med_config json.
            device: Device the encoders and MLP head run on.
        """
        super().__init__()
        self.device = device

        self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
        self.preprocess = _transform(224)
        self.mlp = MLP(768)

        # Standardization constants from the original ImageReward training data.
        self.mean = 0.16717362830052426
        self.std = 1.0333394966054072

    def _to_pil(self, image):
        """Normalize *image* (PIL.Image or existing file path) to a PIL image.

        FIX: the original raised a confusing NameError for inputs that were
        neither str nor PIL.Image; this now raises TypeError consistently.
        """
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, str) and os.path.isfile(image):
            return Image.open(image)
        raise TypeError(r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.')

    def _text_feature(self, text_input, image):
        """Cross-attended [CLS] text feature for one preprocessed image batch."""
        image_embeds = self.blip.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
        text_output = self.blip.text_encoder(text_input.input_ids,
                                             attention_mask=text_input.attention_mask,
                                             encoder_hidden_states=image_embeds,
                                             encoder_attention_mask=image_atts,
                                             return_dict=True,
                                             )
        return text_output.last_hidden_state[:, 0, :]

    def score_gard(self, prompt_ids, prompt_attention_mask, image):
        """Differentiable scoring from pre-tokenized prompts.

        (Name kept — including the 'gard' typo — for API compatibility.)
        """
        image_embeds = self.blip.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
        text_output = self.blip.text_encoder(prompt_ids,
                                             attention_mask=prompt_attention_mask,
                                             encoder_hidden_states=image_embeds,
                                             encoder_attention_mask=image_atts,
                                             return_dict=True,
                                             )
        txt_features = text_output.last_hidden_state[:, 0, :]
        rewards = self.mlp(txt_features)
        rewards = (rewards - self.mean) / self.std

        return rewards

    def score(self, prompt, image):
        """Score one image (PIL or path) against *prompt*.

        Lists delegate to inference_rank and return the list of raw rewards.
        """
        if isinstance(image, list):
            _, rewards = self.inference_rank(prompt, image)
            return rewards

        # text encode
        text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)

        # image encode
        pil_image = self._to_pil(image)
        image = self.preprocess(pil_image).unsqueeze(0).to(self.device)

        txt_features = self._text_feature(text_input, image).float()
        rewards = self.mlp(txt_features)
        rewards = (rewards - self.mean) / self.std

        return rewards.detach().cpu().numpy().item()

    def inference_rank(self, prompt, generations_list):
        """Score and rank a list of images for one prompt.

        Returns:
            (ranks, rewards): 1-based rank per image (1 = best) and the raw
            standardized rewards, both as plain Python lists.
        """
        text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)

        txt_set = []
        for generation in generations_list:
            pil_image = self._to_pil(generation)
            image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
            txt_set.append(self._text_feature(text_input, image))

        txt_features = torch.cat(txt_set, 0).float()  # [image_num, feature_dim]
        rewards = self.mlp(txt_features)              # [image_num, 1]
        rewards = (rewards - self.mean) / self.std
        rewards = torch.squeeze(rewards)
        # Double argsort turns the descending sort order into per-item ranks.
        _, rank = torch.sort(rewards, dim=0, descending=True)
        _, indices = torch.sort(rank, dim=0)
        indices = indices + 1

        return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def load_imagereward(model_path: str, med_config: str = None, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu"):
    """Load an ImageReward model from a checkpoint file.

    Parameters
    ----------
    model_path : str
        Path to a checkpoint containing the model state_dict.
    med_config : str
        Path to the BLIP med_config json.
    device : Union[str, torch.device]
        The device to put the loaded model.

    Returns
    -------
    model : torch.nn.Module
        The ImageReward model, in eval mode.
    """
    print('load checkpoint from %s'%model_path)
    state_dict = torch.load(model_path, map_location='cpu')

    reward_model = ImageReward(device=device, med_config=med_config).to(device)
    # strict=False: tolerate keys present in the checkpoint but absent here.
    reward_model.load_state_dict(state_dict, strict=False)
    print("checkpoint loaded")
    reward_model.eval()

    return reward_model
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
if __name__ == '__main__':
    from huggingface_hub import hf_hub_download

    # Pull the checkpoint and BLIP config from the Hub, then score two demo images.
    model_path = hf_hub_download(repo_id="THUDM/ImageReward", filename="ImageReward.pt")
    config_path = hf_hub_download(repo_id="THUDM/ImageReward", filename="med_config.json")

    demo_images = [open_image('./image0.png'), open_image('./image1.png')]
    prompt = "photorealistic image of a lone painter standing in a gallery, watching an exhibition of paintings made entirely with AI. In the foreground of the image a robot looks proudly at his art"
    model = load_imagereward(model_path=model_path, med_config=config_path, device='cuda')

    print(model.score(prompt, demo_images))
|
Reward_sana_idealized/lr_scheduler.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Learning rate schedulers for gradient ascent optimization.
|
| 3 |
+
|
| 4 |
+
Provides various LR scheduling strategies for reward-guided gradient ascent,
|
| 5 |
+
including cosine annealing, linear decay, and custom schedules.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
from typing import Optional, Literal
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class LRScheduler:
    """Base class for learning rate schedulers.

    Subclasses implement ``get_lr``; this base only tracks the step counter.
    """

    def __init__(self, initial_lr: float, num_steps: int):
        """
        Args:
            initial_lr: Initial learning rate
            num_steps: Total number of optimization steps
        """
        self.initial_lr = initial_lr
        self.num_steps = num_steps
        self.current_step = 0

    def get_lr(self) -> float:
        """Return the learning rate for the current step."""
        raise NotImplementedError

    def step(self):
        """Advance the internal step counter by one."""
        self.current_step += 1

    def reset(self):
        """Rewind the scheduler back to step zero."""
        self.current_step = 0
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ConstantLR(LRScheduler):
    """Degenerate scheduler: the learning rate never changes."""

    def get_lr(self) -> float:
        """Always return ``initial_lr`` regardless of the current step."""
        return self.initial_lr
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class LinearLR(LRScheduler):
    """Linear learning rate decay from ``initial_lr`` to ``end_lr``."""

    def __init__(
        self,
        initial_lr: float,
        num_steps: int,
        end_lr: float = 0.0,
        start_step: int = 0,
    ):
        """
        Initialize linear LR scheduler.

        Args:
            initial_lr: Starting learning rate
            num_steps: Total number of steps
            end_lr: Ending learning rate (default: 0.0)
            start_step: Step to begin decay (default: 0)
        """
        super().__init__(initial_lr, num_steps)
        self.end_lr = end_lr
        self.start_step = start_step

    def get_lr(self) -> float:
        """Linearly interpolate between initial_lr and end_lr after start_step."""
        if self.current_step < self.start_step:
            return self.initial_lr

        decay_span = self.num_steps - self.start_step
        # FIX: the original divided by zero when num_steps <= start_step;
        # treat a degenerate window as fully decayed instead.
        if decay_span <= 0:
            return self.end_lr

        progress = (self.current_step - self.start_step) / decay_span
        progress = min(1.0, progress)

        return self.initial_lr + (self.end_lr - self.initial_lr) * progress
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class CosineLR(LRScheduler):
    """Cosine annealing learning rate schedule with optional linear warmup."""

    def __init__(
        self,
        initial_lr: float,
        num_steps: int,
        min_lr: float = 0.0,
        warmup_steps: int = 0,
    ):
        """
        Initialize cosine LR scheduler.

        Args:
            initial_lr: Maximum learning rate
            num_steps: Total number of steps
            min_lr: Minimum learning rate (default: 0.0)
            warmup_steps: Number of linear warmup steps (default: 0)
        """
        super().__init__(initial_lr, num_steps)
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps

    def get_lr(self) -> float:
        """Linear warmup, then cosine-anneal from initial_lr down to min_lr."""
        if self.current_step < self.warmup_steps:
            # Linear warmup (warmup_steps >= 1 is implied by the condition).
            return self.initial_lr * (self.current_step / self.warmup_steps)

        anneal_span = self.num_steps - self.warmup_steps
        # FIX: the original divided by zero when num_steps <= warmup_steps;
        # treat a degenerate annealing window as fully annealed.
        if anneal_span <= 0:
            return self.min_lr

        # Cosine annealing
        progress = (self.current_step - self.warmup_steps) / anneal_span
        progress = min(1.0, progress)

        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return self.min_lr + (self.initial_lr - self.min_lr) * cosine_decay
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class ExponentialLR(LRScheduler):
|
| 117 |
+
"""Exponential learning rate decay."""
|
| 118 |
+
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
initial_lr: float,
|
| 122 |
+
num_steps: int,
|
| 123 |
+
gamma: float = 0.95,
|
| 124 |
+
):
|
| 125 |
+
"""
|
| 126 |
+
Initialize exponential LR scheduler.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
initial_lr: Starting learning rate
|
| 130 |
+
num_steps: Total number of steps
|
| 131 |
+
gamma: Multiplicative decay factor per step
|
| 132 |
+
"""
|
| 133 |
+
super().__init__(initial_lr, num_steps)
|
| 134 |
+
self.gamma = gamma
|
| 135 |
+
|
| 136 |
+
def get_lr(self) -> float:
|
| 137 |
+
return self.initial_lr * (self.gamma ** self.current_step)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class StepLR(LRScheduler):
|
| 141 |
+
"""Step-wise learning rate decay."""
|
| 142 |
+
|
| 143 |
+
def __init__(
|
| 144 |
+
self,
|
| 145 |
+
initial_lr: float,
|
| 146 |
+
num_steps: int,
|
| 147 |
+
step_size: int,
|
| 148 |
+
gamma: float = 0.1,
|
| 149 |
+
):
|
| 150 |
+
"""
|
| 151 |
+
Initialize step LR scheduler.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
initial_lr: Starting learning rate
|
| 155 |
+
num_steps: Total number of steps
|
| 156 |
+
step_size: Number of steps between each decay
|
| 157 |
+
gamma: Multiplicative decay factor
|
| 158 |
+
"""
|
| 159 |
+
super().__init__(initial_lr, num_steps)
|
| 160 |
+
self.step_size = step_size
|
| 161 |
+
self.gamma = gamma
|
| 162 |
+
|
| 163 |
+
def get_lr(self) -> float:
|
| 164 |
+
num_decays = self.current_step // self.step_size
|
| 165 |
+
return self.initial_lr * (self.gamma ** num_decays)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def create_lr_scheduler(
|
| 169 |
+
scheduler_type: Literal["constant", "linear", "cosine", "exponential", "step"],
|
| 170 |
+
initial_lr: float,
|
| 171 |
+
num_steps: int,
|
| 172 |
+
**kwargs
|
| 173 |
+
) -> LRScheduler:
|
| 174 |
+
"""
|
| 175 |
+
Factory function to create learning rate schedulers.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
scheduler_type: Type of scheduler ("constant", "linear", "cosine", "exponential", "step")
|
| 179 |
+
initial_lr: Initial learning rate
|
| 180 |
+
num_steps: Total number of optimization steps
|
| 181 |
+
**kwargs: Additional scheduler-specific arguments
|
| 182 |
+
For linear: end_lr, start_step
|
| 183 |
+
For cosine: min_lr, warmup_steps
|
| 184 |
+
For exponential: gamma
|
| 185 |
+
For step: step_size, gamma
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
LRScheduler instance
|
| 189 |
+
|
| 190 |
+
Examples:
|
| 191 |
+
# Constant LR
|
| 192 |
+
scheduler = create_lr_scheduler("constant", initial_lr=0.1, num_steps=100)
|
| 193 |
+
|
| 194 |
+
# Linear decay
|
| 195 |
+
scheduler = create_lr_scheduler("linear", initial_lr=0.1, num_steps=100, end_lr=0.01)
|
| 196 |
+
|
| 197 |
+
# Cosine annealing with warmup
|
| 198 |
+
scheduler = create_lr_scheduler("cosine", initial_lr=0.1, num_steps=100,
|
| 199 |
+
min_lr=0.001, warmup_steps=10)
|
| 200 |
+
"""
|
| 201 |
+
if scheduler_type == "constant":
|
| 202 |
+
return ConstantLR(initial_lr, num_steps)
|
| 203 |
+
|
| 204 |
+
elif scheduler_type == "linear":
|
| 205 |
+
return LinearLR(
|
| 206 |
+
initial_lr, num_steps,
|
| 207 |
+
end_lr=kwargs.get("end_lr", 0.0),
|
| 208 |
+
start_step=kwargs.get("start_step", 0),
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
elif scheduler_type == "cosine":
|
| 212 |
+
return CosineLR(
|
| 213 |
+
initial_lr, num_steps,
|
| 214 |
+
min_lr=kwargs.get("min_lr", 0.0),
|
| 215 |
+
warmup_steps=kwargs.get("warmup_steps", 0),
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
elif scheduler_type == "exponential":
|
| 219 |
+
return ExponentialLR(
|
| 220 |
+
initial_lr, num_steps,
|
| 221 |
+
gamma=kwargs.get("gamma", 0.95),
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
elif scheduler_type == "step":
|
| 225 |
+
return StepLR(
|
| 226 |
+
initial_lr, num_steps,
|
| 227 |
+
step_size=kwargs.get("step_size", 10),
|
| 228 |
+
gamma=kwargs.get("gamma", 0.1),
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
else:
|
| 232 |
+
raise ValueError(f"Unknown scheduler type: {scheduler_type}. "
|
| 233 |
+
f"Choose from: constant, linear, cosine, exponential, step")
|
Reward_sana_idealized/models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (285 Bytes). View file
|
|
|
Reward_sana_idealized/open_clip/__pycache__/coca_model.cpython-311.pyc
ADDED
|
Binary file (18.4 kB). View file
|
|
|
Reward_sana_idealized/open_clip/__pycache__/factory.cpython-311.pyc
ADDED
|
Binary file (19.3 kB). View file
|
|
|
Reward_sana_idealized/open_clip/__pycache__/model.cpython-311.pyc
ADDED
|
Binary file (25.1 kB). View file
|
|
|
Reward_sana_idealized/open_clip/__pycache__/modified_resnet.cpython-311.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
Reward_sana_idealized/open_clip/__pycache__/pretrained.cpython-311.pyc
ADDED
|
Binary file (18.5 kB). View file
|
|
|
Reward_sana_idealized/open_clip/__pycache__/push_to_hf_hub.cpython-311.pyc
ADDED
|
Binary file (9.29 kB). View file
|
|
|
Reward_sana_idealized/open_clip/__pycache__/timm_model.cpython-311.pyc
ADDED
|
Binary file (6.73 kB). View file
|
|
|
Reward_sana_idealized/open_clip/__pycache__/tokenizer.cpython-311.pyc
ADDED
|
Binary file (16 kB). View file
|
|
|
Reward_sana_idealized/open_clip/__pycache__/transformer.cpython-311.pyc
ADDED
|
Binary file (42.6 kB). View file
|
|
|
Reward_sana_idealized/open_clip/model_configs/convnext_xlarge.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"embed_dim": 1024,
|
| 3 |
+
"vision_cfg": {
|
| 4 |
+
"timm_model_name": "convnext_xlarge",
|
| 5 |
+
"timm_model_pretrained": false,
|
| 6 |
+
"timm_pool": "",
|
| 7 |
+
"timm_proj": "linear",
|
| 8 |
+
"timm_drop": 0.0,
|
| 9 |
+
"timm_drop_path": 0.1,
|
| 10 |
+
"image_size": 256
|
| 11 |
+
},
|
| 12 |
+
"text_cfg": {
|
| 13 |
+
"context_length": 77,
|
| 14 |
+
"vocab_size": 49408,
|
| 15 |
+
"width": 1024,
|
| 16 |
+
"heads": 16,
|
| 17 |
+
"layers": 20
|
| 18 |
+
}
|
| 19 |
+
}
|
Reward_sana_idealized/pick_score.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from https://github.com/yuvalkirstain/PickScore. Originally MIT License, Copyright (c) 2021.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from transformers import AutoProcessor, AutoModel
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
from transformers import AutoProcessor, AutoModel
|
| 11 |
+
from datasets import load_from_disk, load_dataset
|
| 12 |
+
import torch
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from io import BytesIO
|
| 15 |
+
from tqdm.auto import tqdm
|
| 16 |
+
import sys
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def open_image(image):
|
| 20 |
+
if isinstance(image, bytes):
|
| 21 |
+
image = Image.open(BytesIO(image))
|
| 22 |
+
elif isinstance(image, str):
|
| 23 |
+
image = Image.open(image)
|
| 24 |
+
image = image.convert("RGB")
|
| 25 |
+
return image
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class PickScorer(torch.nn.Module):
|
| 30 |
+
def __init__(self, processor_name_or_path, model_pretrained_name_or_path, device='cuda'):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
|
| 33 |
+
self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).to(device)
|
| 34 |
+
self.device = device
|
| 35 |
+
self.eval()
|
| 36 |
+
|
| 37 |
+
@torch.no_grad()
|
| 38 |
+
def __call__(self, prompt, images):
|
| 39 |
+
# preprocess
|
| 40 |
+
image_inputs = self.processor(
|
| 41 |
+
images=images,
|
| 42 |
+
padding=True,
|
| 43 |
+
truncation=True,
|
| 44 |
+
max_length=77,
|
| 45 |
+
return_tensors="pt",
|
| 46 |
+
).to(self.device)
|
| 47 |
+
|
| 48 |
+
text_inputs = self.processor(
|
| 49 |
+
text=prompt,
|
| 50 |
+
padding=True,
|
| 51 |
+
truncation=True,
|
| 52 |
+
max_length=77,
|
| 53 |
+
return_tensors="pt",
|
| 54 |
+
).to(self.device)
|
| 55 |
+
|
| 56 |
+
with torch.no_grad():
|
| 57 |
+
# embed
|
| 58 |
+
image_embs = self.model.get_image_features(**image_inputs)
|
| 59 |
+
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
|
| 60 |
+
|
| 61 |
+
text_embs = self.model.get_text_features(**text_inputs)
|
| 62 |
+
text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
|
| 63 |
+
|
| 64 |
+
# score
|
| 65 |
+
scores = self.model.logit_scale.exp() * (text_embs @ image_embs.T)[0]
|
| 66 |
+
|
| 67 |
+
# get probabilities if you have multiple images to choose from
|
| 68 |
+
if len(scores) == 1:
|
| 69 |
+
probs = scores
|
| 70 |
+
else:
|
| 71 |
+
probs = torch.softmax(scores, dim=-1)
|
| 72 |
+
|
| 73 |
+
return probs.cpu().tolist()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def score(self, img_path, prompt):
|
| 77 |
+
if isinstance(img_path, list):
|
| 78 |
+
result = []
|
| 79 |
+
for one_img_path in img_path:
|
| 80 |
+
# Load your image and prompt
|
| 81 |
+
with torch.no_grad():
|
| 82 |
+
# Process the image
|
| 83 |
+
if isinstance(one_img_path, str):
|
| 84 |
+
image = self.preprocess_val(Image.open(one_img_path)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
| 85 |
+
elif isinstance(one_img_path, Image.Image):
|
| 86 |
+
image = self.preprocess_val(one_img_path).unsqueeze(0).to(device=self.device, non_blocking=True)
|
| 87 |
+
else:
|
| 88 |
+
raise TypeError('The type of parameter img_path is illegal.')
|
| 89 |
+
# Process the prompt
|
| 90 |
+
text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
|
| 91 |
+
with torch.cuda.amp.autocast():
|
| 92 |
+
outputs = self.model(image, text)
|
| 93 |
+
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
| 94 |
+
logits_per_image = image_features @ text_features.T
|
| 95 |
+
|
| 96 |
+
pick_score = self.model.logit_scale.exp() * torch.diagonal(logits_per_image).cpu().numpy()
|
| 97 |
+
result.append(pick_score[0])
|
| 98 |
+
return result
|
| 99 |
+
elif isinstance(img_path, str):
|
| 100 |
+
# Load your image and prompt
|
| 101 |
+
with torch.no_grad():
|
| 102 |
+
# Process the image
|
| 103 |
+
image = self.preprocess_val(Image.open(img_path)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
| 104 |
+
# Process the prompt
|
| 105 |
+
text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
|
| 106 |
+
with torch.cuda.amp.autocast():
|
| 107 |
+
outputs = self.model(image, text)
|
| 108 |
+
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
| 109 |
+
logits_per_image = image_features @ text_features.T
|
| 110 |
+
|
| 111 |
+
pick_score = self.model.logit_scale.exp() * torch.diagonal(logits_per_image).cpu().numpy()
|
| 112 |
+
return [pick_score[0]]
|
| 113 |
+
elif isinstance(img_path, Image.Image):
|
| 114 |
+
# Load your image and prompt
|
| 115 |
+
with torch.no_grad():
|
| 116 |
+
# Process the image
|
| 117 |
+
image = self.preprocess_val(img_path).unsqueeze(0).to(device=self.device, non_blocking=True)
|
| 118 |
+
# Process the prompt
|
| 119 |
+
text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
|
| 120 |
+
with torch.cuda.amp.autocast():
|
| 121 |
+
outputs = self.model(image, text)
|
| 122 |
+
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
| 123 |
+
logits_per_image = image_features @ text_features.T
|
| 124 |
+
|
| 125 |
+
pick_score = self.model.logit_scale.exp() * torch.diagonal(logits_per_image).cpu().numpy()
|
| 126 |
+
return [pick_score[0]]
|
| 127 |
+
else:
|
| 128 |
+
raise TypeError('The type of parameter img_path is illegal.')
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
if __name__ == "__main__":
|
| 132 |
+
pickscorer = PickScorer(processor_name_or_path="laion/CLIP-ViT-H-14-laion2B-s32B-b79K", model_pretrained_name_or_path="yuvalkirstain/PickScore_v1")
|
| 133 |
+
|
| 134 |
+
image0 = open_image('./image0.png')
|
| 135 |
+
image1 = open_image('./image1.png')
|
| 136 |
+
prompt = "photorealistic image of a lone painter standing in a gallery, watching an exhibition of paintings made entirely with AI. In the foreground of the image a robot looks proudly at his art"
|
| 137 |
+
|
| 138 |
+
probs = pickscorer(prompt, [image0])
|
| 139 |
+
probs1 = pickscorer(prompt, [image1])
|
| 140 |
+
print(probs)
|
| 141 |
+
print(probs1)
|
Reward_sana_idealized/test.ipynb
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "d8044219",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"data": {
|
| 11 |
+
"text/plain": [
|
| 12 |
+
"False"
|
| 13 |
+
]
|
| 14 |
+
},
|
| 15 |
+
"execution_count": 1,
|
| 16 |
+
"metadata": {},
|
| 17 |
+
"output_type": "execute_result"
|
| 18 |
+
}
|
| 19 |
+
],
|
| 20 |
+
"source": [
|
| 21 |
+
"import torch\n",
|
| 22 |
+
"torch.cuda.is_available()"
|
| 23 |
+
]
|
| 24 |
+
}
|
| 25 |
+
],
|
| 26 |
+
"metadata": {
|
| 27 |
+
"kernelspec": {
|
| 28 |
+
"display_name": "Python 3",
|
| 29 |
+
"language": "python",
|
| 30 |
+
"name": "python3"
|
| 31 |
+
},
|
| 32 |
+
"language_info": {
|
| 33 |
+
"codemirror_mode": {
|
| 34 |
+
"name": "ipython",
|
| 35 |
+
"version": 3
|
| 36 |
+
},
|
| 37 |
+
"file_extension": ".py",
|
| 38 |
+
"mimetype": "text/x-python",
|
| 39 |
+
"name": "python",
|
| 40 |
+
"nbconvert_exporter": "python",
|
| 41 |
+
"pygments_lexer": "ipython3",
|
| 42 |
+
"version": "3.10.18"
|
| 43 |
+
}
|
| 44 |
+
},
|
| 45 |
+
"nbformat": 4,
|
| 46 |
+
"nbformat_minor": 5
|
| 47 |
+
}
|
Reward_sana_idealized/tune_hyperparams.py
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hyperparameter tuning script for gradient ascent optimization.
|
| 3 |
+
|
| 4 |
+
This script performs a systematic search over hyperparameter combinations
|
| 5 |
+
to find the optimal configuration for maximum evaluation scores.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import subprocess
|
| 9 |
+
import json
|
| 10 |
+
import argparse
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
import itertools
|
| 14 |
+
import numpy as np
|
| 15 |
+
from typing import Dict, List, Any
|
| 16 |
+
import re
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class HyperparameterTuner:
|
| 20 |
+
"""Hyperparameter tuner for gradient ascent."""
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
output_dir: str = "tuning_results",
|
| 25 |
+
max_samples: int = 30,
|
| 26 |
+
num_steps: int = 20,
|
| 27 |
+
dataset_type: str = "pickapic",
|
| 28 |
+
model_variant: str = "lpo",
|
| 29 |
+
cuda_id: int = 0,
|
| 30 |
+
metrics: List[str] = None
|
| 31 |
+
):
|
| 32 |
+
self.output_dir = Path(output_dir)
|
| 33 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 34 |
+
|
| 35 |
+
self.max_samples = max_samples
|
| 36 |
+
self.num_steps = num_steps
|
| 37 |
+
self.dataset_type = dataset_type
|
| 38 |
+
self.model_variant = model_variant
|
| 39 |
+
self.cuda_id = cuda_id
|
| 40 |
+
self.metrics = metrics or ["clip", "aesthetic", "pickscore", "hpsv2", "imagereward"]
|
| 41 |
+
|
| 42 |
+
# Store results
|
| 43 |
+
self.results = []
|
| 44 |
+
self.baseline_results = None
|
| 45 |
+
|
| 46 |
+
def define_search_space(self) -> List[Dict[str, Any]]:
|
| 47 |
+
"""Define the hyperparameter search space - FULL GRID SEARCH.
|
| 48 |
+
|
| 49 |
+
Tests all combinations of parameters including momentum overrides for configs that support it.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
# Define all parameter values
|
| 53 |
+
cfg_scales = [3.0, 5.0, 7.5] #
|
| 54 |
+
|
| 55 |
+
# All available gradient configs from grad_ascent_configs.py
|
| 56 |
+
grad_configs = [
|
| 57 |
+
# "constant",
|
| 58 |
+
# "linear",
|
| 59 |
+
"cosine_nesterov",
|
| 60 |
+
# "low_to_high_nesterov",
|
| 61 |
+
# "high_to_low_nesterov",
|
| 62 |
+
"low_to_high_momentum",
|
| 63 |
+
"high_to_low_momentum",
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
num_grad_steps_list = [1, 2] # 5, 7, 10
|
| 67 |
+
grad_step_sizes = [0.001, 0.005, 0.01, 0.05] #
|
| 68 |
+
momentums = [0.5, 0.8, 0.9] #
|
| 69 |
+
|
| 70 |
+
# Generate ALL combinations using itertools.product
|
| 71 |
+
configs = []
|
| 72 |
+
for cfg, grad_cfg, num_steps, step_size, momentum in itertools.product(
|
| 73 |
+
cfg_scales, grad_configs, num_grad_steps_list, grad_step_sizes, momentums
|
| 74 |
+
):
|
| 75 |
+
configs.append({
|
| 76 |
+
"cfg_scale": cfg,
|
| 77 |
+
"grad_config": grad_cfg,
|
| 78 |
+
"num_grad_steps": num_steps,
|
| 79 |
+
"grad_step_size": step_size,
|
| 80 |
+
"momentum": momentum,
|
| 81 |
+
})
|
| 82 |
+
|
| 83 |
+
print(f"\nGenerated {len(configs)} total configurations")
|
| 84 |
+
print(f" cfg_scales: {len(cfg_scales)}")
|
| 85 |
+
print(f" grad_configs: {len(grad_configs)}")
|
| 86 |
+
print(f" num_grad_steps: {len(num_grad_steps_list)}")
|
| 87 |
+
print(f" grad_step_sizes: {len(grad_step_sizes)}")
|
| 88 |
+
print(f" momentums: {len(momentums)}")
|
| 89 |
+
print(f" Total: {len(cfg_scales)} × {len(grad_configs)} × {len(num_grad_steps_list)} × {len(grad_step_sizes)} × {len(momentums)} = {len(configs)}")
|
| 90 |
+
|
| 91 |
+
return configs
|
| 92 |
+
|
| 93 |
+
def run_baseline(self) -> Dict[str, float]:
|
| 94 |
+
"""Run baseline evaluation once."""
|
| 95 |
+
print("\n" + "="*80)
|
| 96 |
+
print("RUNNING BASELINE EVALUATION")
|
| 97 |
+
print("="*80)
|
| 98 |
+
|
| 99 |
+
# Use median cfg_scale for baseline
|
| 100 |
+
cfg_scale = 5.0
|
| 101 |
+
|
| 102 |
+
output_dir = self.output_dir / "baseline"
|
| 103 |
+
|
| 104 |
+
cmd = [
|
| 105 |
+
"python", "eval.py",
|
| 106 |
+
"--model_variant", self.model_variant,
|
| 107 |
+
"--dataset_type", self.dataset_type,
|
| 108 |
+
"--max_samples", str(self.max_samples),
|
| 109 |
+
"--num_steps", str(self.num_steps),
|
| 110 |
+
"--cfg_scale", str(cfg_scale),
|
| 111 |
+
"--output_dir", str(output_dir),
|
| 112 |
+
"--cuda", str(self.cuda_id),
|
| 113 |
+
"--mode", "baseline",
|
| 114 |
+
"--metrics", *self.metrics,
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
print(f"Command: {' '.join(cmd)}")
|
| 118 |
+
|
| 119 |
+
try:
|
| 120 |
+
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
| 121 |
+
|
| 122 |
+
# Parse results from output
|
| 123 |
+
metrics = self._parse_metrics(result.stdout, "baseline")
|
| 124 |
+
|
| 125 |
+
print(f"\nBaseline Results:")
|
| 126 |
+
for metric, value in metrics.items():
|
| 127 |
+
print(f" {metric}: {value:.4f}")
|
| 128 |
+
|
| 129 |
+
self.baseline_results = {
|
| 130 |
+
"cfg_scale": cfg_scale,
|
| 131 |
+
"metrics": metrics,
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
return metrics
|
| 135 |
+
|
| 136 |
+
except subprocess.CalledProcessError as e:
|
| 137 |
+
print(f"Error running baseline: {e}")
|
| 138 |
+
print(f"Stdout: {e.stdout}")
|
| 139 |
+
print(f"Stderr: {e.stderr}")
|
| 140 |
+
return {}
|
| 141 |
+
|
| 142 |
+
def run_experiment(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
| 143 |
+
"""Run a single experiment with given hyperparameters."""
|
| 144 |
+
|
| 145 |
+
# Create output directory for this config
|
| 146 |
+
config_name = f"cfg{config['cfg_scale']}_" \
|
| 147 |
+
f"{config['grad_config']}_" \
|
| 148 |
+
f"steps{config['num_grad_steps']}_" \
|
| 149 |
+
f"lr{config['grad_step_size']}_" \
|
| 150 |
+
f"mom{config['momentum']}"
|
| 151 |
+
|
| 152 |
+
output_dir = self.output_dir / config_name
|
| 153 |
+
|
| 154 |
+
# Build command
|
| 155 |
+
cmd = [
|
| 156 |
+
"python", "eval.py",
|
| 157 |
+
"--model_variant", self.model_variant,
|
| 158 |
+
"--dataset_type", self.dataset_type,
|
| 159 |
+
"--grad_config", config["grad_config"],
|
| 160 |
+
"--max_samples", str(self.max_samples),
|
| 161 |
+
"--num_steps", str(self.num_steps),
|
| 162 |
+
"--cfg_scale", str(config["cfg_scale"]),
|
| 163 |
+
"--output_dir", str(output_dir),
|
| 164 |
+
"--cuda", str(self.cuda_id),
|
| 165 |
+
"--mode", "gradient_ascent",
|
| 166 |
+
"--metrics", *self.metrics,
|
| 167 |
+
# Override config parameters
|
| 168 |
+
"--override_num_grad_steps", str(config["num_grad_steps"]),
|
| 169 |
+
"--override_grad_step_size", str(config["grad_step_size"]),
|
| 170 |
+
"--override_momentum", str(config["momentum"]),
|
| 171 |
+
]
|
| 172 |
+
|
| 173 |
+
print(f"\nRunning experiment: {config_name}")
|
| 174 |
+
print(f"Config: {config}")
|
| 175 |
+
|
| 176 |
+
try:
|
| 177 |
+
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
| 178 |
+
|
| 179 |
+
# Parse metrics from output
|
| 180 |
+
metrics = self._parse_metrics(result.stdout, "gradient_ascent")
|
| 181 |
+
|
| 182 |
+
# Compute improvement over baseline
|
| 183 |
+
improvements = {}
|
| 184 |
+
if self.baseline_results:
|
| 185 |
+
baseline_metrics = self.baseline_results["metrics"]
|
| 186 |
+
for metric, value in metrics.items():
|
| 187 |
+
if metric in baseline_metrics:
|
| 188 |
+
baseline_val = baseline_metrics[metric]
|
| 189 |
+
if baseline_val != 0:
|
| 190 |
+
improvement = ((value - baseline_val) / abs(baseline_val)) * 100
|
| 191 |
+
improvements[f"{metric}_improvement"] = improvement
|
| 192 |
+
|
| 193 |
+
result_dict = {
|
| 194 |
+
"config": config,
|
| 195 |
+
"metrics": metrics,
|
| 196 |
+
"improvements": improvements,
|
| 197 |
+
"output_dir": str(output_dir),
|
| 198 |
+
"timestamp": datetime.now().isoformat(),
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
print(f"Results:")
|
| 202 |
+
for metric, value in metrics.items():
|
| 203 |
+
print(f" {metric}: {value:.4f}")
|
| 204 |
+
if improvements:
|
| 205 |
+
print(f"Improvements over baseline:")
|
| 206 |
+
for metric, value in improvements.items():
|
| 207 |
+
print(f" {metric}: {value:+.2f}%")
|
| 208 |
+
|
| 209 |
+
return result_dict
|
| 210 |
+
|
| 211 |
+
except subprocess.CalledProcessError as e:
|
| 212 |
+
print(f"Error running experiment: {e}")
|
| 213 |
+
print(f"Stderr: {e.stderr}")
|
| 214 |
+
return {
|
| 215 |
+
"config": config,
|
| 216 |
+
"error": str(e),
|
| 217 |
+
"timestamp": datetime.now().isoformat(),
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
def _parse_metrics(self, output: str, mode: str) -> Dict[str, float]:
|
| 221 |
+
"""Parse metrics from eval.py output."""
|
| 222 |
+
metrics = {}
|
| 223 |
+
|
| 224 |
+
# Look for the summary section
|
| 225 |
+
lines = output.split('\n')
|
| 226 |
+
|
| 227 |
+
# Pattern to match metric lines like " Reward: 0.1234"
|
| 228 |
+
metric_patterns = {
|
| 229 |
+
"reward": r"Reward:\s+([-+]?\d*\.?\d+)",
|
| 230 |
+
"clip": r"CLIP Score:\s+([-+]?\d*\.?\d+)",
|
| 231 |
+
"aesthetic": r"Aesthetic Score:\s+([-+]?\d*\.?\d+)",
|
| 232 |
+
"pickscore": r"PickScore:\s+([-+]?\d*\.?\d+)",
|
| 233 |
+
"hpsv2": r"HPSv2 Score:\s+([-+]?\d*\.?\d+)",
|
| 234 |
+
"hpsv21": r"HPSv2\.1 Score:\s+([-+]?\d*\.?\d+)",
|
| 235 |
+
"imagereward": r"ImageReward:\s+([-+]?\d*\.?\d+)",
|
| 236 |
+
"fid": r"FID:\s+([-+]?\d*\.?\d+)",
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
for line in lines:
|
| 240 |
+
for metric_name, pattern in metric_patterns.items():
|
| 241 |
+
match = re.search(pattern, line)
|
| 242 |
+
if match:
|
| 243 |
+
metrics[metric_name] = float(match.group(1))
|
| 244 |
+
|
| 245 |
+
return metrics
|
| 246 |
+
|
| 247 |
+
def compute_aggregate_score(self, metrics: Dict[str, float]) -> float:
|
| 248 |
+
"""
|
| 249 |
+
Compute aggregate score for ranking configurations.
|
| 250 |
+
|
| 251 |
+
Uses weighted combination of metrics (higher is better for most,
|
| 252 |
+
except FID which is lower is better).
|
| 253 |
+
"""
|
| 254 |
+
weights = {
|
| 255 |
+
"reward": 1.0,
|
| 256 |
+
"clip": 0.8,
|
| 257 |
+
"aesthetic": 0.8,
|
| 258 |
+
"pickscore": 1.0,
|
| 259 |
+
"hpsv2": 1.0,
|
| 260 |
+
"hpsv21": 1.0,
|
| 261 |
+
"imagereward": 1.0,
|
| 262 |
+
"fid": -0.5, # Negative weight (lower FID is better)
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
score = 0.0
|
| 266 |
+
total_weight = 0.0
|
| 267 |
+
|
| 268 |
+
for metric, value in metrics.items():
|
| 269 |
+
if metric in weights:
|
| 270 |
+
score += weights[metric] * value
|
| 271 |
+
total_weight += abs(weights[metric])
|
| 272 |
+
|
| 273 |
+
# Normalize by total weight
|
| 274 |
+
if total_weight > 0:
|
| 275 |
+
score /= total_weight
|
| 276 |
+
|
| 277 |
+
return score
|
| 278 |
+
|
| 279 |
+
def run_search(
|
| 280 |
+
self,
|
| 281 |
+
search_type: str = "grid",
|
| 282 |
+
start_idx: int = 0,
|
| 283 |
+
end_idx: int = None
|
| 284 |
+
) -> List[Dict[str, Any]]:
|
| 285 |
+
"""
|
| 286 |
+
Run hyperparameter search.
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
search_type: Type of search ("grid" or "random")
|
| 290 |
+
start_idx: Starting index for experiments (for GPU distribution)
|
| 291 |
+
end_idx: Ending index for experiments (for GPU distribution)
|
| 292 |
+
"""
|
| 293 |
+
all_configs = self.define_search_space()
|
| 294 |
+
|
| 295 |
+
print("\n" + "="*80)
|
| 296 |
+
print("HYPERPARAMETER SEARCH CONFIGURATION")
|
| 297 |
+
print("="*80)
|
| 298 |
+
print(f"Dataset: {self.dataset_type}")
|
| 299 |
+
print(f"Model: {self.model_variant}")
|
| 300 |
+
print(f"Samples: {self.max_samples}")
|
| 301 |
+
print(f"Inference steps: {self.num_steps}")
|
| 302 |
+
print(f"Metrics: {', '.join(self.metrics)}")
|
| 303 |
+
|
| 304 |
+
# Select subset of configs if indices provided
|
| 305 |
+
if search_type == "grid":
|
| 306 |
+
configs = all_configs
|
| 307 |
+
elif search_type == "random":
|
| 308 |
+
# Random sample from all configs
|
| 309 |
+
n_samples = min(50, len(all_configs))
|
| 310 |
+
indices = np.random.choice(len(all_configs), n_samples, replace=False)
|
| 311 |
+
configs = [all_configs[i] for i in indices]
|
| 312 |
+
else:
|
| 313 |
+
raise ValueError(f"Unknown search type: {search_type}")
|
| 314 |
+
|
| 315 |
+
# Apply index slicing for GPU distribution
|
| 316 |
+
if end_idx is None:
|
| 317 |
+
end_idx = len(configs)
|
| 318 |
+
configs = configs[start_idx:end_idx]
|
| 319 |
+
|
| 320 |
+
print(f"\nTotal configurations: {len(all_configs)}")
|
| 321 |
+
print(f"Assigned to this worker: {len(configs)} (indices {start_idx} to {end_idx})")
|
| 322 |
+
|
| 323 |
+
# Run baseline first
|
| 324 |
+
if self.baseline_results is None:
|
| 325 |
+
self.run_baseline()
|
| 326 |
+
|
| 327 |
+
# Run experiments
|
| 328 |
+
print("\n" + "="*80)
|
| 329 |
+
print("RUNNING EXPERIMENTS")
|
| 330 |
+
print("="*80)
|
| 331 |
+
|
| 332 |
+
for i, config in enumerate(configs, 1):
|
| 333 |
+
print(f"\n{'='*80}")
|
| 334 |
+
print(f"Experiment {i}/{len(configs)}")
|
| 335 |
+
print(f"{'='*80}")
|
| 336 |
+
|
| 337 |
+
result = self.run_experiment(config)
|
| 338 |
+
self.results.append(result)
|
| 339 |
+
|
| 340 |
+
# Save intermediate results
|
| 341 |
+
self._save_results()
|
| 342 |
+
|
| 343 |
+
return self.results
|
| 344 |
+
|
| 345 |
+
def _generate_grid_configs(self, search_space: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
|
| 346 |
+
"""Generate all combinations for grid search."""
|
| 347 |
+
keys = list(search_space.keys())
|
| 348 |
+
values = list(search_space.values())
|
| 349 |
+
|
| 350 |
+
configs = []
|
| 351 |
+
for combination in itertools.product(*values):
|
| 352 |
+
config = dict(zip(keys, combination))
|
| 353 |
+
configs.append(config)
|
| 354 |
+
|
| 355 |
+
return configs
|
| 356 |
+
|
| 357 |
+
def _generate_random_configs(
|
| 358 |
+
self,
|
| 359 |
+
search_space: Dict[str, List[Any]],
|
| 360 |
+
n_samples: int = 20
|
| 361 |
+
) -> List[Dict[str, Any]]:
|
| 362 |
+
"""Generate random configurations for random search."""
|
| 363 |
+
configs = []
|
| 364 |
+
|
| 365 |
+
for _ in range(n_samples):
|
| 366 |
+
config = {}
|
| 367 |
+
for param, values in search_space.items():
|
| 368 |
+
config[param] = np.random.choice(values)
|
| 369 |
+
configs.append(config)
|
| 370 |
+
|
| 371 |
+
return configs
|
| 372 |
+
|
| 373 |
+
def _save_results(self):
|
| 374 |
+
"""Save results to JSON file."""
|
| 375 |
+
results_file = self.output_dir / "tuning_results.json"
|
| 376 |
+
|
| 377 |
+
data = {
|
| 378 |
+
"baseline": self.baseline_results,
|
| 379 |
+
"experiments": self.results,
|
| 380 |
+
"timestamp": datetime.now().isoformat(),
|
| 381 |
+
"config": {
|
| 382 |
+
"max_samples": self.max_samples,
|
| 383 |
+
"num_steps": self.num_steps,
|
| 384 |
+
"dataset_type": self.dataset_type,
|
| 385 |
+
"model_variant": self.model_variant,
|
| 386 |
+
}
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
with open(results_file, 'w') as f:
|
| 390 |
+
json.dump(data, f, indent=2)
|
| 391 |
+
|
| 392 |
+
print(f"\nResults saved to: {results_file}")
|
| 393 |
+
|
| 394 |
+
def analyze_results(self) -> Dict[str, Any]:
    """Rank successful experiments, report the top 5, and persist the best.

    Adds an ``aggregate_score`` to each successful experiment dict
    (in place — these dicts are shared with ``self.results``), writes
    the winner to ``best_config.json`` under ``self.output_dir``, and
    returns it. Returns an empty dict when there is nothing to analyze
    or no experiment produced metrics.
    """
    if not self.results:
        print("No results to analyze!")
        return {}

    print("\n" + "="*80)
    print("ANALYSIS: FINDING BEST CONFIGURATION")
    print("="*80)

    # Keep only experiments that produced metrics (i.e. did not fail).
    scored = [r for r in self.results if "metrics" in r]

    if not scored:
        print("No successful experiments!")
        return {}

    # Score each experiment, then rank best-first.
    for entry in scored:
        entry["aggregate_score"] = self.compute_aggregate_score(entry["metrics"])
    scored.sort(key=lambda e: e["aggregate_score"], reverse=True)

    print("\nTop 5 Configurations:")
    print("="*80)

    for rank, entry in enumerate(scored[:5], 1):
        print(f"\n#{rank} - Aggregate Score: {entry['aggregate_score']:.4f}")
        print(f"Config: {entry['config']}")
        print(f"Metrics:")
        for metric, value in entry['metrics'].items():
            print(f"  {metric}: {value:.4f}")
        if entry.get('improvements'):
            print(f"Improvements over baseline:")
            for metric, value in entry['improvements'].items():
                print(f"  {metric}: {value:+.2f}%")

    # Persist the winning configuration for downstream use.
    winner = scored[0]
    best_config_file = self.output_dir / "best_config.json"

    with open(best_config_file, 'w') as fh:
        json.dump({
            "config": winner["config"],
            "metrics": winner["metrics"],
            "aggregate_score": winner["aggregate_score"],
            "improvements": winner.get("improvements", {}),
        }, fh, indent=2)

    print(f"\n✓ Best configuration saved to: {best_config_file}")

    return winner
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def main():
    """Command-line entry point: parse args, run the search, report the best config."""
    parser = argparse.ArgumentParser(description="Hyperparameter tuning for gradient ascent")

    # (flag, kwargs) pairs keep the CLI definition compact; registration
    # order is preserved because argparse lists options in that order in --help.
    cli_options = [
        ("--output_dir", dict(type=str, default="tuning_results",
                              help="Directory to save tuning results")),
        ("--max_samples", dict(type=int, default=30,
                               help="Number of samples to use for tuning")),
        ("--num_steps", dict(type=int, default=20,
                             help="Number of inference steps (fixed)")),
        ("--dataset_type", dict(type=str, default="pickapic",
                                choices=["coco", "pickapic"],
                                help="Dataset to use")),
        ("--model_variant", dict(type=str, default="lpo",
                                 choices=["origin", "spo", "diffusion_dpo", "lpo"],
                                 help="Model variant to use")),
        ("--cuda", dict(type=int, default=0,
                        help="CUDA device ID")),
        ("--search_type", dict(type=str, default="grid",
                               choices=["grid", "random"],
                               help="Type of hyperparameter search")),
        ("--metrics", dict(type=str, nargs="+",
                           default=["clip", "aesthetic", "pickscore", "hpsv2", "imagereward"],
                           help="Metrics to evaluate")),
        ("--start_idx", dict(type=int, default=0,
                             help="Starting index for experiments (for GPU distribution)")),
        ("--end_idx", dict(type=int, default=None,
                           help="Ending index for experiments (for GPU distribution)")),
    ]
    for flag, kwargs in cli_options:
        parser.add_argument(flag, **kwargs)

    args = parser.parse_args()

    # Build the tuner with the CLI-selected settings.
    tuner = HyperparameterTuner(
        output_dir=args.output_dir,
        max_samples=args.max_samples,
        num_steps=args.num_steps,
        dataset_type=args.dataset_type,
        model_variant=args.model_variant,
        cuda_id=args.cuda,
        metrics=args.metrics,
    )

    # Run the (possibly sliced) search; start/end indices allow splitting
    # the config list across multiple worker processes.
    results = tuner.run_search(
        search_type=args.search_type,
        start_idx=args.start_idx,
        end_idx=args.end_idx,
    )

    # Summarize and report the best configuration found.
    best_result = tuner.analyze_results()

    print("\n" + "="*80)
    print("TUNING COMPLETE!")
    print("="*80)
    print(f"Total experiments: {len(results)}")
    print(f"Results directory: {args.output_dir}")

    if best_result:
        print(f"\nBest configuration:")
        print(json.dumps(best_result["config"], indent=2))
        print(f"\nAggregate score: {best_result['aggregate_score']:.4f}")
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
# Run the CLI only when executed as a script; the module is also imported
# elsewhere (e.g. to enumerate configurations without running them).
if __name__ == "__main__":
    main()
|
Reward_sana_idealized/tune_parallel.sh
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash

# Parallel hyperparameter tuning across 8 GPUs
# This script distributes experiments evenly across all available GPUs,
# runs one tuning worker per GPU in the background, waits for completion,
# then merges the per-GPU result files and reports the best configuration.
#
# BUG FIX: the merge step previously hard-coded Path("RESULTS_TURNING")
# and range(8) inside the here-doc, so with OUTPUT_DIR="RESULTS_TURNING/run_2"
# it scanned the wrong directory (finding no results) and wrote the merged
# files somewhere other than where the final messages claimed. OUTPUT_DIR
# and NUM_GPUS are now passed into the Python merge script via environment
# variables so both steps agree.

clear

# Activate conda environment
source ~/miniconda3/etc/profile.d/conda.sh
conda activate /home/ec2-user/aev

# Configuration
DATASET_TYPE="pickapic" # "coco" or "pickapic"
MODEL_VARIANT="lpo" # "origin", "spo", "diffusion_dpo", or "lpo"
MAX_SAMPLES=500 # Number of samples for tuning
NUM_STEPS=50 # Fixed inference steps
SEARCH_TYPE="grid" # "grid" or "random"
OUTPUT_DIR="RESULTS_TURNING/run_2"
NUM_GPUS=8 # Number of GPUs to use

echo "=============================================="
echo " PARALLEL HYPERPARAMETER TUNING"
echo "=============================================="
echo ""
echo "Configuration:"
echo " Dataset: $DATASET_TYPE"
echo " Model: $MODEL_VARIANT"
echo " Samples: $MAX_SAMPLES"
echo " Inference Steps: $NUM_STEPS"
echo " Search Type: $SEARCH_TYPE"
echo " GPUs: $NUM_GPUS"
echo " Output: $OUTPUT_DIR"
echo ""

# First, calculate total number of experiments
echo "Calculating total experiments..."
TOTAL_CONFIGS=$(python -c "
from tune_hyperparams import HyperparameterTuner
import sys
tuner = HyperparameterTuner()
configs = tuner.define_search_space()
sys.stderr.write(f'Generated {len(configs)} configurations\n')
print(len(configs))
" 2>&1 | tail -1)

echo "Total configurations: $TOTAL_CONFIGS"
echo ""

# Calculate experiments per GPU (integer division; REMAINDER extras go
# one each to the first GPUs so every config is covered exactly once)
CONFIGS_PER_GPU=$((TOTAL_CONFIGS / NUM_GPUS))
REMAINDER=$((TOTAL_CONFIGS % NUM_GPUS))

echo "Distributing work:"
echo " Base configs per GPU: $CONFIGS_PER_GPU"
echo " Extra configs for first GPUs: $REMAINDER"
echo ""

# Create output directory
mkdir -p "$OUTPUT_DIR"

# Array to store background process IDs
PIDS=()

# Launch parallel processes on each GPU
for GPU_ID in $(seq 0 $((NUM_GPUS - 1))); do
    # Calculate start and end indices for this GPU
    START_IDX=$((GPU_ID * CONFIGS_PER_GPU))

    # Give extra configs to first GPUs
    if [ $GPU_ID -lt $REMAINDER ]; then
        START_IDX=$((START_IDX + GPU_ID))
        END_IDX=$((START_IDX + CONFIGS_PER_GPU + 1))
    else
        START_IDX=$((START_IDX + REMAINDER))
        END_IDX=$((START_IDX + CONFIGS_PER_GPU))
    fi

    # Create GPU-specific output directory
    GPU_OUTPUT_DIR="${OUTPUT_DIR}/gpu_${GPU_ID}"
    mkdir -p "$GPU_OUTPUT_DIR"

    echo "GPU $GPU_ID: configs $START_IDX to $END_IDX"

    # Launch tuning process in background
    nohup python tune_hyperparams.py \
        --output_dir "$GPU_OUTPUT_DIR" \
        --max_samples $MAX_SAMPLES \
        --num_steps $NUM_STEPS \
        --dataset_type "$DATASET_TYPE" \
        --model_variant "$MODEL_VARIANT" \
        --cuda $GPU_ID \
        --search_type "$SEARCH_TYPE" \
        --start_idx $START_IDX \
        --end_idx $END_IDX \
        --metrics clip aesthetic pickscore hpsv2 imagereward \
        > "${GPU_OUTPUT_DIR}/tuning.log" 2>&1 &

    # Store PID
    PIDS+=($!)

    echo " Launched with PID: ${PIDS[$GPU_ID]}"

    # Small delay to avoid race conditions
    sleep 2
done

echo ""
echo "=============================================="
echo " ALL PROCESSES LAUNCHED"
echo "=============================================="
echo ""
echo "Background processes running:"
for GPU_ID in $(seq 0 $((NUM_GPUS - 1))); do
    echo " GPU $GPU_ID: PID ${PIDS[$GPU_ID]} -> ${OUTPUT_DIR}/gpu_${GPU_ID}/tuning.log"
done
echo ""
echo "To monitor progress:"
echo " tail -f ${OUTPUT_DIR}/gpu_0/tuning.log"
echo " tail -f ${OUTPUT_DIR}/gpu_1/tuning.log"
echo " ... etc"
echo ""
echo "To check all GPU processes:"
echo " ps aux | grep tune_hyperparams.py"
echo ""
echo "To monitor GPU usage:"
echo " watch -n 1 nvidia-smi"
echo ""
echo "To kill all processes:"
echo " kill ${PIDS[@]}"
echo ""
echo "Waiting for all processes to complete..."
echo "(Press Ctrl+C to stop waiting, processes will continue in background)"
echo ""

# Wait for all background processes
for PID in "${PIDS[@]}"; do
    wait $PID
done

echo ""
echo "=============================================="
echo " ALL TUNING PROCESSES COMPLETE"
echo "=============================================="
echo ""

# Merge results from all GPUs
echo "Merging results from all GPUs..."

# Activate conda environment for Python script
source ~/miniconda3/etc/profile.d/conda.sh
conda activate /home/ec2-user/aev

# Pass the run configuration into the merge script via the environment;
# the quoted here-doc ('EOF') deliberately prevents shell interpolation.
TUNE_OUTPUT_DIR="$OUTPUT_DIR" TUNE_NUM_GPUS="$NUM_GPUS" python - <<'EOF'
import json
import os
from pathlib import Path
import sys

# Read the directory and GPU count the workers actually used.
output_dir = Path(os.environ["TUNE_OUTPUT_DIR"])
num_gpus = int(os.environ["TUNE_NUM_GPUS"])
all_results = []
baseline_result = None

# Collect results from each GPU
for gpu_id in range(num_gpus):
    gpu_dir = output_dir / f"gpu_{gpu_id}"
    results_file = gpu_dir / "tuning_results.json"

    if results_file.exists():
        with open(results_file, 'r') as f:
            data = json.load(f)

        # Get baseline (should be same from all)
        if baseline_result is None and "baseline" in data:
            baseline_result = data["baseline"]

        # Collect experiments
        if "experiments" in data:
            all_results.extend(data["experiments"])

        print(f"GPU {gpu_id}: {len(data.get('experiments', []))} results")

# Merge all results
merged_data = {
    "baseline": baseline_result,
    "experiments": all_results,
    "num_gpus": num_gpus,
    "total_experiments": len(all_results)
}

# Save merged results
merged_file = output_dir / "merged_results.json"
with open(merged_file, 'w') as f:
    json.dump(merged_data, f, indent=2)

print(f"\nMerged {len(all_results)} total results")
print(f"Saved to: {merged_file}")

# Find best configuration
successful = [r for r in all_results if "metrics" in r]
if successful:
    # Compute aggregate scores (weights mirror HyperparameterTuner's;
    # negative weight penalizes FID, where lower is better)
    def compute_score(metrics):
        weights = {
            "reward": 1.0, "clip": 0.8, "aesthetic": 0.8,
            "pickscore": 1.0, "hpsv2": 1.0, "imagereward": 1.0,
            "fid": -0.5
        }
        score = sum(weights.get(k, 0) * v for k, v in metrics.items())
        return score / sum(abs(w) for w in weights.values())

    for r in successful:
        r["aggregate_score"] = compute_score(r["metrics"])

    successful.sort(key=lambda x: x["aggregate_score"], reverse=True)

    best = successful[0]
    best_file = output_dir / "best_config.json"
    with open(best_file, 'w') as f:
        json.dump({
            "config": best["config"],
            "metrics": best["metrics"],
            "aggregate_score": best["aggregate_score"],
            "improvements": best.get("improvements", {})
        }, f, indent=2)

    print(f"\n{'='*60}")
    print("BEST CONFIGURATION:")
    print(f"{'='*60}")
    print(json.dumps(best["config"], indent=2))
    print(f"\nAggregate Score: {best['aggregate_score']:.4f}")
    print(f"Saved to: {best_file}")
else:
    print("\nNo successful experiments found!")
    sys.exit(1)
EOF

if [ $? -eq 0 ]; then
    echo ""
    echo "=============================================="
    echo " TUNING COMPLETE!"
    echo "=============================================="
    echo ""
    echo "Results:"
    echo " Merged results: ${OUTPUT_DIR}/merged_results.json"
    echo " Best config: ${OUTPUT_DIR}/best_config.json"
    echo ""
    echo "View best configuration:"
    echo " cat ${OUTPUT_DIR}/best_config.json"
    echo ""
else
    echo ""
    echo "ERROR: Failed to merge results"
    exit 1
fi
|
Reward_sdxl_idealized/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (242 Bytes). View file
|
|
|
Reward_sdxl_idealized/models/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (280 Bytes). View file
|
|
|
Reward_sdxl_idealized/models/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (240 Bytes). View file
|
|
|
Reward_sdxl_idealized/models/__pycache__/reward_model.cpython-39.pyc
ADDED
|
Binary file (9.1 kB). View file
|
|
|
Reward_sdxl_idealized/models/__pycache__/reward_model_sdxl.cpython-310.pyc
ADDED
|
Binary file (9.96 kB). View file
|
|
|