| import json |
| import tempfile |
| import os |
| import glob |
| import shutil |
| import io |
| import time |
| import threading |
| import sys |
|
|
| import gradio as gr |
| import torch |
| from huggingface_hub import hf_hub_download, scan_cache_dir, whoami |
| from safetensors import safe_open |
|
|
| |
| DEFAULT_HF_TOKEN = os.environ.get("HF_TOKEN") |
|
|
|
|
def hf_login(token: str, session_token: str):
    """Login to Hugging Face with the provided token (per-user session).

    Args:
        token: HF access token entered by the user.
        session_token: Current per-session token; returned unchanged on failure.

    Returns:
        (status_message, short_status, new_session_token) for the Gradio outputs.
    """
    if not token:
        return "❌ Please provide a token", "Not logged in", session_token

    try:
        # whoami() validates the token against the Hub and returns account info.
        user_info = whoami(token=token)
        username = user_info.get('name', 'Unknown')
        return f"✅ Successfully logged in as: {username}", f"✅ Logged in as {username}", token
    except Exception as e:
        # Keep the previous session token when validation fails.
        return f"❌ Login failed: {str(e)}", "❌ Not logged in", session_token
|
|
|
|
def hf_logout(session_token: str):
    """Logout from Hugging Face by clearing the per-session token.

    `session_token` is accepted (Gradio passes the state in) but unused;
    returning None as the third output resets the gr.State.
    """
    return "✅ Successfully logged out", "Not logged in", None
|
|
|
|
def check_hf_status(session_token: str):
    """Check the current HF login status for this session.

    Falls back to the Space-wide DEFAULT_HF_TOKEN when no session token is set.
    Returns (status_message, short_status, session_token) for the Gradio outputs.
    """
    token = session_token or DEFAULT_HF_TOKEN

    if not token:
        return "ℹ️ Not logged in", "Not logged in", session_token

    try:
        user_info = whoami(token=token)
        username = user_info.get('name', 'Unknown')
        # Show whether credentials came from this session or the shared secret.
        source = "(session)" if session_token else "(default HF_TOKEN)"
        return f"✅ Currently logged in as: {username} {source}", f"✅ Logged in as {username}", session_token
    except Exception:
        # Invalid/expired token: report as not logged in rather than raising.
        return "ℹ️ Not logged in", "Not logged in", session_token
|
|
|
|
def get_param(model_id: str, param_key: str, log_buffer: io.StringIO, progress: gr.Progress, token: str = None):
    """
    Download and return a specific parameter tensor from a Hugging Face model.

    Args:
        model_id: Hub repo id, e.g. "org/model".
        param_key: Fully-qualified tensor name, e.g. "model.norm.weight".
        log_buffer: StringIO collecting human-readable progress logs for the UI.
        progress: Gradio progress tracker, updated at each stage.
        token: Optional HF token; falls back to DEFAULT_HF_TOKEN.

    Returns:
        torch.Tensor loaded from the shard that contains `param_key`.

    Raises:
        KeyError: if `param_key` is not listed in the model's weight map.
    """
    auth_token = token or DEFAULT_HF_TOKEN

    # Redirect stderr into the log buffer so hf_hub_download's progress output
    # appears in the UI log. NOTE(review): this swaps the process-wide stderr,
    # so concurrent downloads would interleave — acceptable for a demo Space.
    original_stderr = sys.stderr
    sys.stderr = log_buffer

    try:
        try:
            log_buffer.write(f"📥 Downloading index file for {model_id}...\n")
            progress(0.1, desc="Downloading index...")

            # Sharded checkpoints publish an index mapping each tensor to its shard.
            index_path = hf_hub_download(
                model_id, "model.safetensors.index.json", token=auth_token)

            log_buffer.write(f"✓ Index file found: {index_path}\n")

            with open(index_path, "r", encoding="utf-8") as f:
                index = json.load(f)
            weight_map = index["weight_map"]
            if param_key not in weight_map:
                raise KeyError(
                    f"Parameter '{param_key}' not found in model. Available keys: {list(weight_map.keys())[:10]}..."
                )
            shard_file = weight_map[param_key]
            log_buffer.write(f"✓ Parameter found in shard: {shard_file}\n")
        except KeyError:
            # Bug fix: the KeyError message contains "not found", which the
            # generic handler below would misread as a missing index file and
            # silently retry with "model.safetensors". Propagate it instead.
            raise
        except Exception as e:
            if "404" in str(e) or "not found" in str(e).lower():
                # No index file: the model is stored as a single safetensors file.
                log_buffer.write("ℹ️ No index file, trying single model file...\n")
                shard_file = "model.safetensors"
            else:
                raise

        log_buffer.write(f"📥 Downloading shard: {shard_file}...\n")
        progress(0.3, desc=f"Downloading {shard_file}...")

        shard_path = hf_hub_download(model_id, shard_file, token=auth_token)

        log_buffer.write(f"\n✓ Shard downloaded: {shard_path}\n")
        progress(0.7, desc="Loading tensor...")

        log_buffer.write(f"🔍 Loading tensor '{param_key}'...\n")
        # safe_open reads only the requested tensor, not the whole shard.
        with safe_open(shard_path, framework="pt") as f:
            tensor = f.get_tensor(param_key)
        log_buffer.write("✓ Tensor loaded successfully\n")
        progress(0.9, desc="Finalizing...")

        return tensor
    finally:
        # Always restore the real stderr, even on error.
        sys.stderr = original_stderr
|
|
|
|
def get_available_keys(model_id: str, token: str = None):
    """Return the sorted parameter names of a model, or [] on any failure.

    Tries the sharded-checkpoint index first, then falls back to reading the
    keys of a single `model.safetensors` file.
    """
    auth_token = token or DEFAULT_HF_TOKEN

    try:
        index_path = hf_hub_download(model_id, "model.safetensors.index.json", token=auth_token)
        with open(index_path, "r", encoding="utf-8") as f:
            index = json.load(f)
        return sorted(index["weight_map"].keys())
    except Exception:
        # No index file (or download failed): try the single-file layout.
        try:
            shard_path = hf_hub_download(model_id, "model.safetensors", token=auth_token)
            with safe_open(shard_path, framework="pt") as f:
                return sorted(f.keys())
        except Exception:
            # Best-effort: the UI treats an empty list as "no keys found".
            return []
|
|
|
|
def format_tensor_info(tensor: torch.Tensor) -> str:
    """Format tensor metadata and basic statistics as a <br>-joined Markdown string."""
    info = []
    info.append(f"**Shape:** {list(tensor.shape)}")
    info.append(f"**Dtype:** {tensor.dtype}")
    info.append(f"**Device:** {tensor.device}")
    info.append(f"**Numel:** {tensor.numel():,}")

    try:
        # float8 tensors do not support min/max directly; upcast first.
        if str(tensor.dtype) in ['torch.float8_e4m3fn', 'torch.float8_e5m2']:
            stats_tensor = tensor.to(torch.float32)
        else:
            stats_tensor = tensor

        info.append(f"**Min:** {stats_tensor.min().item():.6f}")
        info.append(f"**Max:** {stats_tensor.max().item():.6f}")
        info.append(f"**Mean:** {stats_tensor.float().mean().item():.6f}")
        info.append(f"**Std:** {stats_tensor.float().std().item():.6f}")
    except Exception:
        # e.g. dtypes where reduction ops are unsupported.
        info.append("**Stats:** Unable to compute (dtype not supported)")

    return "<br>".join(info)
|
|
|
|
def fetch_param(model_id: str, param_key: str, session_token: str, progress=gr.Progress()):
    """Fetch a parameter tensor and stream progress updates to the UI.

    Generator yielding (info_markdown, preview_markdown, download_path, log_text)
    tuples so Gradio can update the outputs incrementally while the download runs.
    """
    log_buffer = io.StringIO()
    last_log_value = ""

    if not model_id or not param_key:
        yield "Please provide both model ID and parameter key.", "", None, "❌ Missing required inputs"
        return

    try:
        log_buffer.write(f"🚀 Starting download for {model_id}\n")
        log_buffer.write(f"🎯 Target parameter: {param_key}\n\n")
        progress(0, desc="Initializing...")
        yield "", "", None, log_buffer.getvalue()
        time.sleep(0.5)

        # Run the blocking download in a worker thread so this generator can
        # keep yielding log updates to the UI while it runs.
        download_complete = threading.Event()
        download_error = [None]
        result_tensor = [None]

        def download_thread():
            try:
                result_tensor[0] = get_param(model_id, param_key, log_buffer, progress, session_token)
            except Exception as e:
                download_error[0] = e
            finally:
                download_complete.set()

        thread = threading.Thread(target=download_thread, daemon=True)
        thread.start()

        # Poll the log buffer once per second, pushing new content to the UI.
        while not download_complete.is_set():
            current_log = log_buffer.getvalue()
            if current_log != last_log_value:
                yield "", "", None, current_log
                last_log_value = current_log
            time.sleep(1)

        # Flush any log lines written between the last poll and completion.
        current_log = log_buffer.getvalue()
        if current_log != last_log_value:
            yield "", "", None, current_log
            last_log_value = current_log

        # Re-raise the worker's exception so the handler below reports it.
        if download_error[0]:
            raise download_error[0]

        tensor = result_tensor[0]
        info = format_tensor_info(tensor)

        log_buffer.write("\n📊 Creating preview...\n")
        yield "", "", None, log_buffer.getvalue()

        flat = tensor.flatten()
        preview_size = min(100, flat.numel())

        # float8 values cannot be listed directly; upcast before .tolist().
        if str(tensor.dtype) in ['torch.float8_e4m3fn', 'torch.float8_e5m2']:
            preview = flat[:preview_size].to(torch.float32).tolist()
        else:
            preview = flat[:preview_size].tolist()

        # Render ten values per line, formatted according to dtype.
        preview_lines = []
        for i in range(0, len(preview), 10):
            line_values = preview[i:i+10]
            if tensor.dtype in [torch.float32, torch.float64, torch.float16, torch.bfloat16] or str(tensor.dtype) in ['torch.float8_e4m3fn', 'torch.float8_e5m2']:
                preview_lines.append(", ".join(f"{v:.6f}" for v in line_values))
            elif tensor.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8]:
                preview_lines.append(", ".join(f"{v}" for v in line_values))
            elif tensor.dtype == torch.bool:
                preview_lines.append(", ".join(f"{v}" for v in line_values))
            else:
                preview_lines.append(", ".join(str(v) for v in line_values))

        preview_str = f"**First {preview_size} values:**\n```\n" + \
            "\n".join(preview_lines) + "\n```"

        log_buffer.write("💾 Saving tensor for download...\n")
        yield info, preview_str, None, log_buffer.getvalue()

        # Save as a .pt file in the system temp dir so gr.File can serve it.
        temp_dir = tempfile.gettempdir()
        safe_param_key = param_key.replace("/", "_").replace(".", "_")
        download_path = os.path.join(temp_dir, f"{safe_param_key}.pt")
        torch.save(tensor, download_path)
        log_buffer.write(f"✓ Saved to: {download_path}\n")

        progress(1.0, desc="Complete!")
        log_buffer.write("\n✅ All operations completed successfully!\n")
        yield info, preview_str, download_path, log_buffer.getvalue()
    except Exception as e:
        log_buffer.write(f"\n❌ Error: {str(e)}\n")
        yield f"**Error:** {str(e)}", "", None, log_buffer.getvalue()
|
|
|
|
def list_keys(model_id: str, session_token: str):
    """Return every available parameter key of a model, one per line."""
    if not model_id:
        return "Please provide a model ID."

    try:
        keys = get_available_keys(model_id, session_token)
    except Exception as e:
        return f"**Error:** {str(e)}"

    if not keys:
        return "No keys found or failed to load model."
    return "\n".join(keys)
|
|
|
|
def clear_temp_files():
    """Delete every .pt file from the system temp directory.

    Returns a human-readable status message listing the deleted files.
    """
    try:
        temp_dir = tempfile.gettempdir()
        pt_files = glob.glob(os.path.join(temp_dir, "*.pt"))
        deleted_files = []
        for file in pt_files:
            try:
                os.remove(file)
                deleted_files.append(os.path.basename(file))
            except Exception:
                # Best-effort: skip files that are locked or already gone.
                pass

        if deleted_files:
            files_list = "\n".join(deleted_files)
            # Bug fix: report the number actually deleted, not the number matched
            # (some removals may have failed above).
            return f"✅ Cleared {len(deleted_files)} temporary file(s):\n\n{files_list}"
        else:
            return "✅ No temporary files to clear"
    except Exception as e:
        return f"❌ Error: {str(e)}"
|
|
|
|
def clear_hf_cache():
    """Delete the Hugging Face hub cache and report how much space was freed."""
    try:
        cache_info = scan_cache_dir()
        total_size = cache_info.size_on_disk
        total_repos = len(cache_info.repos)

        if total_repos == 0:
            return "✅ Hugging Face cache is already empty"

        # NOTE(review): hardcodes the default cache path; this ignores
        # HF_HOME / HF_HUB_CACHE overrides — confirm the Space uses defaults.
        cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)
            os.makedirs(cache_dir)
            size_mb = total_size / (1024 * 1024)
            return f"✅ Cleared Hugging Face cache: {total_repos} repo(s), {size_mb:.2f} MB freed"
        else:
            return "✅ Hugging Face cache directory not found"
    except Exception as e:
        return f"❌ Error: {str(e)}"
|
|
|
|
def get_cache_info():
    """Summarize temp-dir .pt files and the Hugging Face cache as a text report."""
    try:
        # Temp-dir .pt files written by fetch_param.
        temp_dir = tempfile.gettempdir()
        pt_files = glob.glob(os.path.join(temp_dir, "*.pt"))
        temp_size = sum(os.path.getsize(f)
                        for f in pt_files if os.path.exists(f))
        temp_size_mb = temp_size / (1024 * 1024)

        info = "📊 Cache Info:\n\n"
        info += f"─── Temp .pt files: {len(pt_files)} file(s), {temp_size_mb:.2f} MB ───\n"

        if pt_files:
            for file in pt_files:
                size = os.path.getsize(file) / (1024 * 1024)
                filename = os.path.basename(file)
                # Bug fix: previously printed the literal "(unknown)" instead
                # of the computed filename.
                info += f"  • {filename} ({size:.2f} MB)\n"
        else:
            info += "  (empty)\n"

        # Hugging Face hub cache, via huggingface_hub's scanner.
        # Initialized up front so the grand total below never needs the old
        # `'hf_size_mb' in locals()` check.
        hf_size_mb = 0.0
        info += "\n─── Hugging Face Cache ───\n"
        try:
            cache_info = scan_cache_dir()
            hf_size_mb = cache_info.size_on_disk / (1024 * 1024)
            hf_repos = len(cache_info.repos)

            info += f"Total: {hf_repos} repo(s), {hf_size_mb:.2f} MB\n\n"

            if hf_repos > 0:
                for repo in cache_info.repos:
                    repo_size = repo.size_on_disk / (1024 * 1024)
                    info += f"  📦 {repo.repo_id}\n"
                    info += f"     Size: {repo_size:.2f} MB, Revisions: {len(repo.revisions)}\n"
                    info += f"     Last accessed: {repo.last_accessed}\n"
            else:
                info += "  (empty)\n"
        except Exception as e:
            info += f"  Error reading HF cache: {str(e)}\n"

        info += f"\n─── Total: {temp_size_mb + hf_size_mb:.2f} MB ───"
        return info
    except Exception as e:
        return f"❌ Error: {str(e)}"
|
|
|
|
| |
# Global CSS injected into the Gradio app: monospace font everywhere,
# tighter row spacing, smaller tensor-preview text, and a compact
# file-download widget.
custom_css = """
* {
    font-family: Consolas, Monaco, 'Courier New', monospace !important;
}
.compact-row {
    gap: 0.5rem !important;
}
.tensor-preview pre {
    font-size: 0.75rem !important;
    line-height: 1.0 !important;
}
.compact-file {
    max-height: 80px !important;
}
.compact-file > div {
    min-height: 60px !important;
}
"""
|
|
# ---- UI layout and event wiring -----------------------------------------
# CSS is a constructor argument of gr.Blocks (launch() does not accept it).
with gr.Blocks(title="Hugging Face Model Weight Inspector", css=custom_css) as demo:
    gr.Markdown("# 🔍 Hugging Face Model Weight Inspector")

    # Per-user token state; None falls back to the Space-wide HF_TOKEN secret.
    session_token = gr.State(None)

    with gr.Accordion("🔑 Hugging Face Login (Per-User Session) [⚠️⚠️⚠️WIP, Do not use⚠️⚠️⚠️]", open=False):
        gr.Markdown("""
        **Note:** This Space uses the default `HF_TOKEN` secret for all users if no session token is provided.
        Login below with your own token for per-user authentication (affects only your session).
        """)
        with gr.Row():
            with gr.Column(scale=3):
                hf_token_input = gr.Textbox(
                    label="HF Token",
                    placeholder="hf_...",
                    type="password",
                )
            with gr.Column(scale=2):
                initial_status = "✅ Using default HF_TOKEN" if DEFAULT_HF_TOKEN else "Not logged in"
                hf_status = gr.Textbox(
                    label="Status",
                    value=initial_status,
                    interactive=False,
                )
        with gr.Row():
            login_btn = gr.Button("🔐 Login", variant="primary", scale=1)
            logout_btn = gr.Button("🚪 Logout", variant="secondary", scale=1)
            check_status_btn = gr.Button("ℹ️ Check Status", variant="secondary", scale=1)
        login_output = gr.Textbox(label="Login Status", interactive=False, lines=2)

    with gr.Row():
        with gr.Column(scale=1):
            model_id_input = gr.Textbox(
                label="Model ID",
                placeholder="e.g., meta-llama/Llama-2-7b-hf",
                value="Qwen/Qwen3-Coder-Next-FP8",
            )
            param_key_input = gr.Textbox(
                label="Parameter Key",
                placeholder="e.g., model.norm.weight",
                value="model.norm.weight",
            )
            with gr.Row():
                list_keys_btn = gr.Button(
                    "📋 List Keys", variant="secondary", scale=1)
                fetch_btn = gr.Button("🚀 Fetch", variant="primary", scale=1)

        with gr.Column(scale=1):
            keys_output = gr.Textbox(
                label="Available Parameter Keys",
                lines=5,
                max_lines=8,
            )

    with gr.Tabs():
        with gr.Tab("Results"):
            with gr.Row():
                with gr.Column(scale=3):
                    preview_output = gr.Markdown(label="Tensor Preview", elem_classes="tensor-preview")
                with gr.Column(scale=1):
                    info_output = gr.Markdown(label="Tensor Info")
                    download_output = gr.File(label="Download Tensor (.pt file)", elem_classes="compact-file")
            log_output = gr.Textbox(
                label="📋 Download Log", lines=1, interactive=False)

        with gr.Tab("Cache Management"):
            with gr.Row():
                get_info_btn = gr.Button(
                    "📊 Get Cache Info", variant="secondary", scale=1)
                clear_temp_btn = gr.Button(
                    "🗑️ Clear Temp Folder", variant="secondary", scale=1)
                clear_hf_btn = gr.Button(
                    "🗑️ Clear HF Cache", variant="secondary", scale=1)
            clear_status = gr.Textbox(
                label="Status", interactive=False, lines=6)

    # ---- Event handlers ----
    login_btn.click(
        fn=hf_login,
        inputs=[hf_token_input, session_token],
        outputs=[login_output, hf_status, session_token],
    )

    logout_btn.click(
        fn=hf_logout,
        inputs=[session_token],
        outputs=[login_output, hf_status, session_token],
    )

    check_status_btn.click(
        fn=check_hf_status,
        inputs=[session_token],
        outputs=[login_output, hf_status, session_token],
    )

    list_keys_btn.click(
        fn=list_keys,
        inputs=[model_id_input, session_token],
        outputs=[keys_output],
    )

    fetch_btn.click(
        fn=fetch_param,
        inputs=[model_id_input, param_key_input, session_token],
        outputs=[info_output, preview_output, download_output, log_output],
    )

    clear_temp_btn.click(
        fn=clear_temp_files,
        inputs=[],
        outputs=[clear_status],
    )

    clear_hf_btn.click(
        fn=clear_hf_cache,
        inputs=[],
        outputs=[clear_status],
    )

    get_info_btn.click(
        fn=get_cache_info,
        inputs=[],
        outputs=[clear_status],
    )

    # Refresh login status for each user when the page loads.
    demo.load(
        fn=check_hf_status,
        inputs=[session_token],
        outputs=[login_output, hf_status, session_token],
    )
|
|
|
|
if __name__ == "__main__":
    # Bug fix: Blocks.launch() has no `css` parameter — passing css here raises
    # a TypeError. Custom CSS must go to the gr.Blocks(...) constructor instead.
    demo.launch(server_name="0.0.0.0")
|
|