| """ | |
| Visualization Generator - LLM selects and configures visual components | |
| No code generation - just structured JSON config | |
| """ | |
| import re | |
| import json | |
| import hashlib | |
| import copy | |
| from together import Together | |
| from components import render_component, COMPONENTS | |
| SYSTEM_PROMPT = """You are an ML teacher creating BEGINNER-FRIENDLY visualizations. Make concepts crystal clear. | |
| CRITICAL: In JSON, backslashes must be double-escaped. Write \\\\sum not \\sum for LaTeX. | |
| AVAILABLE COMPONENTS: | |
| 1. scatter_cluster - clustering visualization with centroids (params: n_clusters, n_points) | |
| 2. cluster_distribution - cluster size bar chart, SYNCED with scatter_cluster (params: n_clusters, n_points) | |
| 3. gradient_descent_3d - 3D loss surface with optimization path | |
| 4. gradient_descent_2d - 2D contour view (bird's eye) of gradient descent | |
| 5. loss_curve - training loss over steps | |
| 6. flow_diagram - basic neural network architecture (layers as ints: [3,4,2]) | |
| 7. matrix_heatmap - attention weights, confusion matrices (REQUIRES: labels, values, x_title, y_title) | |
| 8. distribution_plot - softmax probabilities (REQUIRES: categories, values) | |
| 9. decision_boundary - classification regions | |
| 10. line_progression - training curves over epochs | |
| 11. custom_plotly - ADVANCED: For ANY concept not covered above. Provide full Plotly JSON spec. | |
| === CUSTOM_PLOTLY COMPONENT (for complex/unique visualizations) === | |
| Use custom_plotly when the concept requires specialized diagrams not covered by components 1-10. | |
| AVAILABLE TEMPLATES (use "template" key in config): | |
| - "transformer_attention" - Q/K/V attention mechanism with arrows and flow | |
| - "lstm_gates" - Forget/Input/Output gates visualization | |
| - "vae_architecture" - Encoder → Latent → Decoder flow | |
| - "gan_architecture" - Generator vs Discriminator adversarial setup | |
| - "cnn_architecture" - Convolution → Pooling → FC layers | |
| Example using template: | |
| {"type": "custom_plotly", "config": {"title": "Attention Q/K/V", "template": "transformer_attention"}} | |
| Example with custom data (for concepts without templates): | |
| {"type": "custom_plotly", "config": { | |
| "title": "Custom Visualization", | |
| "data": [ | |
| {"type": "scatter", "x": [1,2,3], "y": [1,2,3], "mode": "markers+text", "text": ["A","B","C"], "marker": {"size": 30, "color": "#667eea"}}, | |
| {"type": "scatter", "x": [1,3], "y": [1,3], "mode": "lines", "line": {"color": "gray", "dash": "dot"}} | |
| ], | |
| "layout": {"xaxis": {"visible": false}, "yaxis": {"visible": false}, "annotations": [{"x": 2, "y": 2, "text": "Label", "showarrow": false}]} | |
| }} | |
| Valid trace types: scatter, scatter3d, bar, heatmap, contour, surface, pie, sankey, treemap, sunburst | |
| === HARD REQUIREMENTS (MUST FOLLOW - NO EXCEPTIONS) === | |
| REQUIREMENT 1 - COMPONENTS: You MUST return EXACTLY 3 components. NOT 1. NOT 2. ALWAYS 3. | |
| - Each component MUST have a DIFFERENT "type" value | |
| - Recommended combinations: | |
| * Gradient Descent: gradient_descent_3d + gradient_descent_2d + loss_curve | |
| * Clustering (K-Means, DBSCAN): scatter_cluster + cluster_distribution + loss_curve (ALL use same n_clusters, n_points params!) | |
| * Basic Neural Networks: flow_diagram + distribution_plot + loss_curve | |
| * Attention Mechanism: custom_plotly (template: transformer_attention) + matrix_heatmap + distribution_plot | |
| * Transformers: custom_plotly (template: transformer_attention) + matrix_heatmap + distribution_plot | |
| * LSTM/RNN: custom_plotly (template: lstm_gates) + line_progression + distribution_plot | |
| * VAE/Autoencoder: custom_plotly (template: vae_architecture) + scatter_cluster + loss_curve | |
| * GAN: custom_plotly (template: gan_architecture) + distribution_plot + line_progression | |
| * CNN: custom_plotly (template: cnn_architecture) + matrix_heatmap + distribution_plot | |
| * Decision Boundary/SVM: decision_boundary + scatter_cluster + line_progression | |
| * Any other concept: Use custom_plotly with appropriate data/layout for the architecture | |
| REQUIREMENT 2 - WHY_IT_MATTERS: This field is REQUIRED and MUST contain 2-4 sentences. | |
| - Explain real-world applications (e.g., "Used in Netflix recommendations, fraud detection") | |
| - Explain consequences of getting it wrong (e.g., "Wrong clusters = wrong customer segments = wasted marketing budget") | |
| - This field CANNOT be empty or null | |
| REQUIREMENT 3 - EVOLUTION: This field is REQUIRED. Explain what problem this concept solves. | |
| - "predecessor": The previous concept/approach this improves upon (e.g., "RNN" for LSTM, "Seq2Seq" for Attention) | |
| - "predecessor_problem": 1-2 sentences on the limitation/problem of the predecessor | |
| - "how_it_solves": 1-2 sentences on how this concept fixes that problem | |
| - "key_innovation": One sentence on the core innovation (e.g., "Gates that control information flow") | |
| - If this is a foundational concept with no predecessor, set predecessor to "None" and explain what general problem it solves | |
| REQUIREMENT 4 - MATHEMATICAL DETAILS: You MUST provide comprehensive math coverage. | |
| - MINIMUM 5 formulas in the "formulas" array (aim for 6-8) | |
| - Required formula types: main equation, gradient/derivative, update rule, loss function, convergence condition | |
| - MINIMUM 6 variables explained in the "variables" array | |
| - Each formula description MUST be 1-2 sentences in plain English | |
| - Each variable meaning MUST be a full sentence, not just 1-2 words | |
| === OUTPUT FORMAT (JSON only) === | |
| { | |
| "title": "Clear Beginner-Friendly Title", | |
| "oneliner": "One sentence a beginner can understand (max 15 words)", | |
| "intuition": "2-3 sentences using everyday analogies. No jargon. Like explaining to a smart 12-year-old.", | |
| "why_it_matters": "REQUIRED: 2-4 sentences on real-world applications and what goes wrong if you mess this up.", | |
| "evolution": { | |
| "predecessor": "Name of the previous approach (or 'None' if foundational)", | |
| "predecessor_problem": "What limitation/problem did the predecessor have?", | |
| "how_it_solves": "How does this concept fix that problem?", | |
| "key_innovation": "One sentence: the core breakthrough idea" | |
| }, | |
| "math": { | |
| "formulas": [ | |
| {"name": "Main Formula", "equation": "$$LaTeX$$", "description": "Plain English explanation (1-2 sentences)"}, | |
| {"name": "Gradient", "equation": "$$LaTeX$$", "description": "How we compute the direction to move"}, | |
| {"name": "Update Rule", "equation": "$$LaTeX$$", "description": "How parameters change each step"}, | |
| {"name": "Loss Function", "equation": "$$LaTeX$$", "description": "What we are trying to minimize"}, | |
| {"name": "Convergence", "equation": "$$LaTeX$$", "description": "When do we stop"} | |
| ], | |
| "variables": [ | |
| {"symbol": "x", "meaning": "Full sentence explaining this variable in beginner terms"}, | |
| {"symbol": "y", "meaning": "Full sentence explaining this variable in beginner terms"} | |
| ] | |
| }, | |
| "components": [ | |
| {"type": "component_type_1", "config": {"title": "Descriptive Title for View 1", ...}}, | |
| {"type": "component_type_2", "config": {"title": "Descriptive Title for View 2", ...}}, | |
| {"type": "component_type_3", "config": {"title": "Descriptive Title for View 3", ...}} | |
| ], | |
| "params": [ | |
| {"name": "param_name", "min": 2, "max": 8, "default": 3, "step": 1, "label": "Human Label", "component_index": 0, "config_key": "config_key_name"} | |
| ] | |
| } | |
| === COMPONENT CONFIG KEYS === | |
| ALGORITHM SIMULATIONS (data generated from parameters - stays synchronized): | |
| - scatter_cluster: n_clusters, n_points (clustering visualization) | |
| - cluster_distribution: n_clusters, n_points (MUST match scatter_cluster params for sync!) | |
| - gradient_descent_3d, gradient_descent_2d, loss_curve: learning_rate, n_steps, start_x | |
| - decision_boundary: n_points, model_type | |
| - line_progression: epochs | |
| - flow_diagram: layers (list of ints) | |
| IMPORTANT FOR CLUSTERING: scatter_cluster and cluster_distribution MUST use the SAME n_clusters and n_points | |
| values so they stay synchronized when user adjusts sliders. | |
| PURE RENDERERS - YOU MUST PROVIDE ALL DATA + INTERACTIVE PARAMS: | |
| distribution_plot (REQUIRES DATA): | |
| - categories: list of strings (REQUIRED) - e.g., ["Token 1", "Token 2"] or ["Dog", "Cat"] | |
| - values: list of floats (REQUIRED) - e.g., [0.6, 0.3, 0.1] (should sum to ~1.0) | |
| - x_title: string (optional) - axis label | |
| - temperature: float (INTERACTIVE) - controls distribution sharpness (0.1=peaked, 2.0=uniform) | |
| matrix_heatmap (REQUIRES DATA): | |
| - labels: list of strings (REQUIRED) - e.g., ["Token 1", "Token 2", "Token 3"] | |
| - values: 2D array of floats (REQUIRED) - e.g., [[0.5, 0.3, 0.2], [0.2, 0.6, 0.2], [0.3, 0.1, 0.6]] | |
| - x_title: string (REQUIRED) - e.g., "Keys" for attention, "Feature" for CNN | |
| - y_title: string (REQUIRED) - e.g., "Queries" for attention, "Filter" for CNN | |
| - focus_row: int (INTERACTIVE) - 1-indexed row to highlight (e.g., "which token is attending?") | |
| - threshold: float (INTERACTIVE) - only show values above threshold (0.0-1.0) | |
| custom_plotly: template (string), data (array of traces), layout (object) | |
| === INTERACTIVE PARAMS FOR PURE RENDERERS === | |
| IMPORTANT: For concepts using distribution_plot or matrix_heatmap, you MUST include interactive params. | |
| These let users explore the visualization even though data comes from LLM. | |
| For Attention/Transformer concepts, include: | |
| {"name": "temperature", "min": 0.1, "max": 2.0, "default": 1.0, "step": 0.1, "label": "Attention Temperature", "component_index": 2, "config_key": "temperature"} | |
| {"name": "focus_row", "min": 1, "max": <num_tokens>, "default": 1, "step": 1, "label": "Focus on Token", "component_index": 1, "config_key": "focus_row"} | |
| For Classification/Softmax concepts, include: | |
| {"name": "temperature", "min": 0.1, "max": 3.0, "default": 1.0, "step": 0.1, "label": "Softmax Temperature", "component_index": <index>, "config_key": "temperature"} | |
| === CONCEPT-SPECIFIC DATA + PARAMS EXAMPLES === | |
| For "Attention Mechanism": | |
| - distribution_plot: categories=["The", "cat", "sat", "down"], values=[0.1, 0.5, 0.3, 0.1] | |
| - matrix_heatmap: labels=["The", "cat", "sat", "down"], values=[[0.7,0.1,0.1,0.1],[0.1,0.6,0.2,0.1],[0.2,0.2,0.4,0.2],[0.1,0.1,0.2,0.6]], x_title="Keys", y_title="Queries" | |
| - params: [ | |
| {"name": "temperature", "min": 0.1, "max": 2.0, "default": 1.0, "step": 0.1, "label": "Attention Sharpness", "component_index": 2, "config_key": "temperature"}, | |
| {"name": "focus_row", "min": 1, "max": 4, "default": 2, "step": 1, "label": "Focus on Token", "component_index": 1, "config_key": "focus_row"} | |
| ] | |
| For "Image Classification" / "CNN": | |
| - distribution_plot: categories=["Dog", "Cat", "Bird", "Car"], values=[0.7, 0.15, 0.1, 0.05] | |
| - params: [{"name": "temperature", "min": 0.1, "max": 3.0, "default": 1.0, "step": 0.1, "label": "Prediction Confidence", "component_index": <dist_plot_index>, "config_key": "temperature"}] | |
| For "Softmax Function": | |
| - distribution_plot: categories=["Class A", "Class B", "Class C"], values=[0.6, 0.3, 0.1] | |
| - params: [{"name": "temperature", "min": 0.1, "max": 3.0, "default": 1.0, "step": 0.1, "label": "Temperature (τ)", "component_index": <index>, "config_key": "temperature"}] | |
| For "Confusion Matrix": | |
| - matrix_heatmap: labels=["Positive", "Negative"], values=[[45,5],[10,40]], x_title="Predicted", y_title="Actual" | |
| - params: [{"name": "threshold", "min": 0, "max": 50, "default": 0, "step": 5, "label": "Min Value Shown", "component_index": <index>, "config_key": "threshold"}] | |
| CRITICAL: | |
| - If you don't provide the required data, the visualization will show an ERROR, not defaults. | |
| - If you don't provide params, users CANNOT interact with the UI. ALWAYS include params. | |
| - Never use animal names (Cat, Dog) unless the concept is specifically about image classification. | |
| === VALIDATION CHECKLIST (verify before returning) === | |
| [ ] components array has EXACTLY 3 items with DIFFERENT types | |
| [ ] params array has AT LEAST 2 interactive parameters (users MUST be able to interact!) | |
| [ ] why_it_matters has 2-4 real sentences (not empty!) | |
| [ ] evolution object has all 4 fields filled (predecessor, predecessor_problem, how_it_solves, key_innovation) | |
| [ ] formulas array has 5+ formulas with descriptions | |
| [ ] variables array has 6+ variables with full sentence meanings | |
| [ ] All config_keys match the expected keys above | |
| [ ] distribution_plot has categories + values + a temperature param in params array | |
| [ ] matrix_heatmap has labels + values + x_title + y_title + focus_row or threshold param in params array | |
| Return ONLY valid JSON. No markdown, no explanation, just the JSON object.""" | |
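The validation checklist that closes the prompt can also be checked in code once a response has been parsed. The following is a minimal sketch under stated assumptions: `check_config` is a hypothetical helper that is not part of this module, and it covers only a subset of the checklist (component count and types, params, evolution fields, formula and variable counts, and the required data keys of the pure renderers).

```python
def check_config(config: dict) -> list:
    """Return a list of problems found; an empty list means every check passed.

    Hypothetical helper mirroring part of the prompt's validation checklist.
    """
    problems = []
    components = config.get("components", [])
    types = [comp.get("type") for comp in components]

    if len(components) != 3:
        problems.append(f"expected exactly 3 components, got {len(components)}")
    if len(set(types)) != len(types):
        problems.append("component types must all be different")
    if len(config.get("params", [])) < 2:
        problems.append("expected at least 2 interactive params")

    evolution = config.get("evolution", {})
    for field in ("predecessor", "predecessor_problem", "how_it_solves", "key_innovation"):
        if not evolution.get(field):
            problems.append(f"evolution.{field} is missing or empty")

    math = config.get("math", {})
    if len(math.get("formulas", [])) < 5:
        problems.append("expected at least 5 formulas")
    if len(math.get("variables", [])) < 6:
        problems.append("expected at least 6 variables")

    # Pure renderers need their data supplied in the config.
    required_keys = {
        "distribution_plot": ("categories", "values"),
        "matrix_heatmap": ("labels", "values", "x_title", "y_title"),
    }
    for comp in components:
        comp_type = comp.get("type")
        comp_config = comp.get("config", {})
        for key in required_keys.get(comp_type, ()):
            if key not in comp_config:
                problems.append(f"{comp_type} is missing required key '{key}'")
    return problems


# Usage: an empty config fails several checks at once.
for problem in check_config({"components": []}):
    print(problem)
```

Returning a list of messages instead of raising on the first failure keeps the sketch usable for logging, which matches how the class below reports issues with warnings.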
| class VisualizationGenerator: | |
| """Generates visualizations by having LLM configure visual components.""" | |
| _cache = {} | |
| def __init__(self, api_key: str): | |
| self.client = Together(api_key=api_key) | |
| self.model = "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8" | |
| def _cache_key(self, topic: str) -> str: | |
| normalized = re.sub(r'[-_\s]+', ' ', topic.lower().strip()) | |
| return hashlib.md5(normalized.encode()).hexdigest()[:16] | |
| def _call_llm(self, prompt: str) -> str: | |
| response = self.client.chat.completions.create( | |
| model=self.model, | |
| max_tokens=4000, # Increased for comprehensive responses | |
| temperature=0.3, | |
| messages=[ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": prompt} | |
| ] | |
| ) | |
| return response.choices[0].message.content | |
| def _parse_response(self, response: str) -> dict: | |
| """Parse JSON from LLM response, handling LaTeX escapes.""" | |
| response = response.strip() | |
| # Remove markdown code blocks if present | |
| if response.startswith('```'): | |
| response = re.sub(r'^```json?\s*', '', response) | |
| response = re.sub(r'\s*```$', '', response) | |
| # Find JSON object | |
| match = re.search(r'\{[\s\S]*\}', response) | |
| if match: | |
| response = match.group() | |
| # Fix common LaTeX escape issues in JSON | |
| # LLMs often write \sum instead of \\sum in JSON strings | |
| def fix_latex_escapes(text): | |
| # Fix unescaped backslashes before LaTeX commands | |
| latex_commands = ['sum', 'frac', 'sqrt', 'alpha', 'beta', 'theta', 'mu', 'sigma', | |
| 'nabla', 'partial', 'infty', 'int', 'prod', 'log', 'exp', | |
| 'sin', 'cos', 'tan', 'lim', 'max', 'min', 'arg', 'text', | |
| 'mathbf', 'mathcal', 'hat', 'bar', 'vec', 'dot', 'cdot', | |
| 'times', 'div', 'pm', 'leq', 'geq', 'neq', 'approx', | |
| 'in', 'notin', 'subset', 'supset', 'cup', 'cap', | |
| 'forall', 'exists', 'rightarrow', 'leftarrow', 'Rightarrow', | |
| 'left', 'right', 'big', 'Big', 'bigg', 'Bigg'] | |
| for cmd in latex_commands: | |
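| # Replace \cmd with \\cmd (but not \\cmd, which is already escaped). | |
| # In the replacement template, \\\\ emits two literal backslashes and \1 reinserts the command name. | |
| text = re.sub(r'(?<!\\)\\(' + cmd + r')', r'\\\\\1', text) | |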
| return text | |
| response = fix_latex_escapes(response) | |
| try: | |
| parsed = json.loads(response) | |
| return self._validate_and_fix_response(parsed) | |
| except json.JSONDecodeError as e: | |
| # Last resort: try with raw_decode to get partial result | |
| try: | |
| decoder = json.JSONDecoder() | |
| obj, _ = decoder.raw_decode(response) | |
| return self._validate_and_fix_response(obj) | |
| except Exception: | |
| pass | |
| # Return minimal valid structure instead of crashing | |
| print(f"JSON parse error: {e}") | |
| print(f"Response was: {response[:500]}...") | |
| return { | |
| 'title': 'Visualization', | |
| 'oneliner': 'Unable to parse LLM response', | |
| 'intuition': '', | |
| 'why_it_matters': 'This concept is fundamental to machine learning and is used across many real-world applications.', | |
| 'math': {}, | |
| 'components': [{'type': 'scatter_cluster', 'config': {'title': 'Default Visualization'}}], | |
| 'params': [] | |
| } | |
| def _validate_and_fix_response(self, parsed: dict) -> dict: | |
| """Validate LLM response and fix common issues.""" | |
| # Fix 1: Ensure we have at least 3 components (extra components are left untouched) | |
| components = parsed.get('components', []) | |
| if len(components) < 3: | |
| print(f"WARNING: LLM returned only {len(components)} component(s), adding fallback views") | |
| # Add complementary components based on what we have | |
| existing_types = {c.get('type') for c in components} | |
| fallback_additions = [ | |
| ('flow_diagram', {'title': 'Network Architecture', 'layers': [4, 8, 6, 3]}), | |
| ('loss_curve', {'title': 'Training Progress'}), | |
| ('distribution_plot', {'title': 'Output Distribution'}), | |
| ('line_progression', {'title': 'Learning Curve'}), | |
| ('scatter_cluster', {'title': 'Data Distribution', 'n_clusters': 3}), | |
| ('matrix_heatmap', {'title': 'Weight Visualization', 'size': 5}), | |
| ] | |
| for comp_type, config in fallback_additions: | |
| if comp_type not in existing_types and len(components) < 3: | |
| components.append({'type': comp_type, 'config': config}) | |
| existing_types.add(comp_type) | |
| parsed['components'] = components | |
| # Fix 2: Ensure why_it_matters is not empty | |
| why = parsed.get('why_it_matters', '') | |
| if not why or len(why.strip()) < 20: | |
| print("WARNING: why_it_matters is empty or too short, adding default") | |
| title = parsed.get('title', 'This concept') | |
| parsed['why_it_matters'] = ( | |
| f"{title} is widely used in production ML systems including recommendation engines, " | |
| f"fraud detection, and autonomous vehicles. Getting it wrong can lead to poor model performance, " | |
| f"wasted compute resources, and incorrect predictions that impact real users." | |
| ) | |
| # Fix 3: Ensure math section has sufficient content | |
| math = parsed.get('math', {}) | |
| if not math: | |
| math = {'formulas': [], 'variables': []} | |
| parsed['math'] = math | |
| formulas = math.get('formulas', []) | |
| variables = math.get('variables', []) | |
| if len(formulas) < 5: | |
| print(f"WARNING: Only {len(formulas)} formulas (prompt requires at least 5), math section may be incomplete") | |
| # We don't auto-generate formulas as they're concept-specific, but we log the issue | |
| if len(variables) < 6: | |
| print(f"WARNING: Only {len(variables)} variables explained (prompt requires at least 6)") | |
| return parsed | |
| def render_components(self, config: dict, param_overrides: dict = None) -> list: | |
| """Render all components with optional param overrides.""" | |
| # Deep copy to avoid modifying cached config | |
| components = copy.deepcopy(config.get('components', [])) | |
| params = config.get('params', []) | |
| figures = [] | |
| # Apply param overrides to component configs | |
| if param_overrides: | |
| print(f"Param overrides: {param_overrides}") | |
| print(f"Params config: {params}") | |
| for param in params: | |
| param_name = param['name'] | |
| if param_name in param_overrides: | |
| comp_idx = param.get('component_index', 0) | |
| config_key = param.get('config_key', param_name) | |
| value = param_overrides[param_name] | |
| if isinstance(comp_idx, int) and 0 <= comp_idx < len(components): | |
| components[comp_idx].setdefault('config', {})[config_key] = value | |
| # Also set common aliases to ensure component finds it | |
| if 'cluster' in config_key.lower() or config_key == 'k': | |
| components[comp_idx]['config']['n_clusters'] = int(value) | |
| components[comp_idx]['config']['k'] = int(value) | |
| if 'point' in config_key.lower(): | |
| components[comp_idx]['config']['n_points'] = int(value) | |
| print(f"Set [{comp_idx}].{config_key} = {value}") | |
| # CRITICAL: Propagate clustering params to ALL components that need them | |
| # This ensures scatter_cluster and cluster_distribution stay synchronized | |
| synced_types = ['scatter_cluster', 'cluster_distribution'] | |
| # Extract n_clusters and n_points from param_overrides (check multiple aliases) | |
| n_clusters = None | |
| n_points = None | |
| for key, value in param_overrides.items(): | |
| key_lower = key.lower() | |
| if 'cluster' in key_lower or key == 'k': | |
| n_clusters = int(value) | |
| if 'point' in key_lower: | |
| n_points = int(value) | |
| # Apply to ALL synced components | |
| for comp in components: | |
| if comp['type'] in synced_types: | |
| if n_clusters is not None: | |
| comp['config']['n_clusters'] = n_clusters | |
| comp['config']['k'] = n_clusters | |
| if n_points is not None: | |
| comp['config']['n_points'] = n_points | |
| print(f"Synced {comp['type']}: n_clusters={n_clusters}, n_points={n_points}") | |
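| # Derive a seed from the param values so the visualization actually changes with them. | |
| # hashlib is used instead of hash(), whose result for strings varies between Python processes. | |
| seed_source = json.dumps(param_overrides, sort_keys=True, default=str) | |
| seed = int(hashlib.md5(seed_source.encode()).hexdigest(), 16) % 10000 | |
| for comp in components: | |
| comp.setdefault('config', {})['seed'] = seed | |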
| # Render each component | |
| for comp in components: | |
| comp_type = comp['type'] | |
| comp_config = comp.get('config', {}) | |
| try: | |
| fig = render_component(comp_type, comp_config) | |
| figures.append(fig) | |
| except Exception as e: | |
| print(f"Error rendering {comp_type}: {e}") | |
| return figures | |
| def generate(self, topic: str) -> dict: | |
| """Generate visualization config for topic.""" | |
| cache_key = self._cache_key(topic) | |
| if cache_key in self._cache: | |
| cached = self._cache[cache_key] | |
| figures = self.render_components(cached) | |
| return {**cached, 'figures': figures} | |
| prompt = f"Create an interactive visualization for: {topic}" | |
| response = self._call_llm(prompt) | |
| parsed = self._parse_response(response) | |
| if not parsed.get('components'): | |
| raise ValueError("LLM did not return valid components") | |
| # Cache config (without figures) | |
| self._cache[cache_key] = { | |
| 'title': parsed.get('title', topic), | |
| 'oneliner': parsed.get('oneliner', ''), | |
| 'intuition': parsed.get('intuition', ''), | |
| 'why_it_matters': parsed.get('why_it_matters', ''), | |
| 'evolution': parsed.get('evolution', {}), | |
| 'math': parsed.get('math', {}), | |
| 'components': parsed.get('components', []), | |
| 'params': parsed.get('params', []), | |
| } | |
| # Render with defaults | |
| figures = self.render_components(self._cache[cache_key]) | |
| return { | |
| **self._cache[cache_key], | |
| 'figures': figures | |
| } | |
| def update_params(self, topic: str, param_values: dict) -> list: | |
| """Update visualization with new param values.""" | |
| cache_key = self._cache_key(topic) | |
| if cache_key not in self._cache: | |
| return [] | |
| config = self._cache[cache_key] | |
| return self.render_components(config, param_values) | |
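
The LaTeX escape repair in `_parse_response` can be tried in isolation. The sketch below is a standalone illustration, not the module's own code: it assumes a shortened command list (`LATEX_COMMANDS` is a hypothetical name) and uses a callable replacement, which sidesteps counting backslashes in a replacement template.

```python
import json
import re

# Hypothetical short list for illustration; the listing above uses a longer one.
LATEX_COMMANDS = ["sum", "frac", "theta", "nabla", "text"]


def fix_latex_escapes(text: str) -> str:
    """Double any single backslash that directly precedes a known LaTeX command."""
    for cmd in LATEX_COMMANDS:
        # (?<!\\) skips backslashes that are already escaped.
        # The callable returns two literal backslashes followed by the command name.
        text = re.sub(
            r"(?<!\\)\\(" + cmd + r")",
            lambda m: "\\\\" + m.group(1),
            text,
        )
    return text


# A response as an LLM might write it: single backslashes inside JSON strings.
raw = r'{"equation": "$$\sum_i \frac{x_i}{n}$$", "note": "\theta is a parameter"}'

parsed = json.loads(fix_latex_escapes(raw))
print(parsed["equation"])
print(parsed["note"])
```

Without the repair, `json.loads(raw)` rejects `\s` as an invalid escape, while `\f` and `\t` would be silently decoded as a form feed and a tab, which is why the repair runs before parsing.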