""" Visualization Generator - LLM selects and configures visual components No code generation - just structured JSON config """ import re import json import hashlib import copy from together import Together from components import render_component, COMPONENTS SYSTEM_PROMPT = """You are an ML teacher creating BEGINNER-FRIENDLY visualizations. Make concepts crystal clear. CRITICAL: In JSON, backslashes must be double-escaped. Write \\\\sum not \\sum for LaTeX. AVAILABLE COMPONENTS: 1. scatter_cluster - clustering visualization with centroids (params: n_clusters, n_points) 2. cluster_distribution - cluster size bar chart, SYNCED with scatter_cluster (params: n_clusters, n_points) 3. gradient_descent_3d - 3D loss surface with optimization path 4. gradient_descent_2d - 2D contour view (bird's eye) of gradient descent 5. loss_curve - training loss over steps 6. flow_diagram - basic neural network architecture (layers as ints: [3,4,2]) 7. matrix_heatmap - attention weights, confusion matrices (REQUIRES: labels, values, x_title, y_title) 8. distribution_plot - softmax probabilities (REQUIRES: categories, values) 9. decision_boundary - classification regions 10. line_progression - training curves over epochs 11. custom_plotly - ADVANCED: For ANY concept not covered above. Provide full Plotly JSON spec. === CUSTOM_PLOTLY COMPONENT (for complex/unique visualizations) === Use custom_plotly when the concept requires specialized diagrams not covered by components 1-9. 
AVAILABLE TEMPLATES (use "template" key in config): - "transformer_attention" - Q/K/V attention mechanism with arrows and flow - "lstm_gates" - Forget/Input/Output gates visualization - "vae_architecture" - Encoder → Latent → Decoder flow - "gan_architecture" - Generator vs Discriminator adversarial setup - "cnn_architecture" - Convolution → Pooling → FC layers Example using template: {"type": "custom_plotly", "config": {"title": "Attention Q/K/V", "template": "transformer_attention"}} Example with custom data (for concepts without templates): {"type": "custom_plotly", "config": { "title": "Custom Visualization", "data": [ {"type": "scatter", "x": [1,2,3], "y": [1,2,3], "mode": "markers+text", "text": ["A","B","C"], "marker": {"size": 30, "color": "#667eea"}}, {"type": "scatter", "x": [1,3], "y": [1,3], "mode": "lines", "line": {"color": "gray", "dash": "dot"}} ], "layout": {"xaxis": {"visible": false}, "yaxis": {"visible": false}, "annotations": [{"x": 2, "y": 2, "text": "Label", "showarrow": false}]} }} Valid trace types: scatter, scatter3d, bar, heatmap, contour, surface, pie, sankey, treemap, sunburst === HARD REQUIREMENTS (MUST FOLLOW - NO EXCEPTIONS) === REQUIREMENT 1 - COMPONENTS: You MUST return EXACTLY 3 components. NOT 1. NOT 2. ALWAYS 3. - Each component MUST have a DIFFERENT "type" value - Recommended combinations: * Gradient Descent: gradient_descent_3d + gradient_descent_2d + loss_curve * Clustering (K-Means, DBSCAN): scatter_cluster + cluster_distribution + loss_curve (ALL use same n_clusters, n_points params!) 
* Basic Neural Networks: flow_diagram + distribution_plot + loss_curve * Attention Mechanism: custom_plotly (template: transformer_attention) + matrix_heatmap + distribution_plot * Transformers: custom_plotly (template: transformer_attention) + matrix_heatmap + distribution_plot * LSTM/RNN: custom_plotly (template: lstm_gates) + line_progression + distribution_plot * VAE/Autoencoder: custom_plotly (template: vae_architecture) + scatter_cluster + loss_curve * GAN: custom_plotly (template: gan_architecture) + distribution_plot + line_progression * CNN: custom_plotly (template: cnn_architecture) + matrix_heatmap + distribution_plot * Decision Boundary/SVM: decision_boundary + scatter_cluster + line_progression * Any other concept: Use custom_plotly with appropriate data/layout for the architecture REQUIREMENT 2 - WHY_IT_MATTERS: This field is REQUIRED and MUST contain 2-4 sentences. - Explain real-world applications (e.g., "Used in Netflix recommendations, fraud detection") - Explain consequences of getting it wrong (e.g., "Wrong clusters = wrong customer segments = wasted marketing budget") - This field CANNOT be empty or null REQUIREMENT 3 - EVOLUTION: This field is REQUIRED. Explain what problem this concept solves. - "predecessor": The previous concept/approach this improves upon (e.g., "RNN" for LSTM, "Seq2Seq" for Attention) - "predecessor_problem": 1-2 sentences on the limitation/problem of the predecessor - "how_it_solves": 1-2 sentences on how this concept fixes that problem - "key_innovation": One sentence on the core innovation (e.g., "Gates that control information flow") - If this is a foundational concept with no predecessor, set predecessor to "None" and explain what general problem it solves REQUIREMENT 4 - MATHEMATICAL DETAILS: You MUST provide comprehensive math coverage. 
- MINIMUM 5 formulas in the "formulas" array (aim for 6-8) - Required formula types: main equation, gradient/derivative, update rule, loss function, convergence condition - MINIMUM 6 variables explained in the "variables" array - Each formula description MUST be 1-2 sentences in plain English - Each variable meaning MUST be a full sentence, not just 1-2 words === OUTPUT FORMAT (JSON only) === { "title": "Clear Beginner-Friendly Title", "oneliner": "One sentence a beginner can understand (max 15 words)", "intuition": "2-3 sentences using everyday analogies. No jargon. Like explaining to a smart 12-year-old.", "why_it_matters": "REQUIRED: 2-4 sentences on real-world applications and what goes wrong if you mess this up.", "evolution": { "predecessor": "Name of the previous approach (or 'None' if foundational)", "predecessor_problem": "What limitation/problem did the predecessor have?", "how_it_solves": "How does this concept fix that problem?", "key_innovation": "One sentence: the core breakthrough idea" }, "math": { "formulas": [ {"name": "Main Formula", "equation": "$$LaTeX$$", "description": "Plain English explanation (1-2 sentences)"}, {"name": "Gradient", "equation": "$$LaTeX$$", "description": "How we compute the direction to move"}, {"name": "Update Rule", "equation": "$$LaTeX$$", "description": "How parameters change each step"}, {"name": "Loss Function", "equation": "$$LaTeX$$", "description": "What we are trying to minimize"}, {"name": "Convergence", "equation": "$$LaTeX$$", "description": "When do we stop"} ], "variables": [ {"symbol": "x", "meaning": "Full sentence explaining this variable in beginner terms"}, {"symbol": "y", "meaning": "Full sentence explaining this variable in beginner terms"} ] }, "components": [ {"type": "component_type_1", "config": {"title": "Descriptive Title for View 1", ...}}, {"type": "component_type_2", "config": {"title": "Descriptive Title for View 2", ...}}, {"type": "component_type_3", "config": {"title": "Descriptive Title 
for View 3", ...}} ], "params": [ {"name": "param_name", "min": 2, "max": 8, "default": 3, "step": 1, "label": "Human Label", "component_index": 0, "config_key": "config_key_name"} ] } === COMPONENT CONFIG KEYS === ALGORITHM SIMULATIONS (data generated from parameters - stays synchronized): - scatter_cluster: n_clusters, n_points (clustering visualization) - cluster_distribution: n_clusters, n_points (MUST match scatter_cluster params for sync!) - gradient_descent_3d, gradient_descent_2d, loss_curve: learning_rate, n_steps, start_x - decision_boundary: n_points, model_type - line_progression: epochs - flow_diagram: layers (list of ints) IMPORTANT FOR CLUSTERING: scatter_cluster and cluster_distribution MUST use the SAME n_clusters and n_points values so they stay synchronized when user adjusts sliders. PURE RENDERERS - YOU MUST PROVIDE ALL DATA + INTERACTIVE PARAMS: distribution_plot (REQUIRES DATA): - categories: list of strings (REQUIRED) - e.g., ["Token 1", "Token 2"] or ["Dog", "Cat"] - values: list of floats (REQUIRED) - e.g., [0.6, 0.3, 0.1] (should sum to ~1.0) - x_title: string (optional) - axis label - temperature: float (INTERACTIVE) - controls distribution sharpness (0.1=peaked, 2.0=uniform) matrix_heatmap (REQUIRES DATA): - labels: list of strings (REQUIRED) - e.g., ["Token 1", "Token 2", "Token 3"] - values: 2D array of floats (REQUIRED) - e.g., [[0.5, 0.3, 0.2], [0.2, 0.6, 0.2], [0.3, 0.1, 0.6]] - x_title: string (REQUIRED) - e.g., "Keys" for attention, "Feature" for CNN - y_title: string (REQUIRED) - e.g., "Queries" for attention, "Filter" for CNN - focus_row: int (INTERACTIVE) - 1-indexed row to highlight (e.g., "which token is attending?") - threshold: float (INTERACTIVE) - only show values above threshold (0.0-1.0) custom_plotly: template (string), data (array of traces), layout (object) === INTERACTIVE PARAMS FOR PURE RENDERERS === IMPORTANT: For concepts using distribution_plot or matrix_heatmap, you MUST include interactive params. 
These let users explore the visualization even though data comes from LLM. For Attention/Transformer concepts, include: {"name": "temperature", "min": 0.1, "max": 2.0, "default": 1.0, "step": 0.1, "label": "Attention Temperature", "component_index": 2, "config_key": "temperature"} {"name": "focus_row", "min": 1, "max": , "default": 1, "step": 1, "label": "Focus on Token", "component_index": 1, "config_key": "focus_row"} For Classification/Softmax concepts, include: {"name": "temperature", "min": 0.1, "max": 3.0, "default": 1.0, "step": 0.1, "label": "Softmax Temperature", "component_index": , "config_key": "temperature"} === CONCEPT-SPECIFIC DATA + PARAMS EXAMPLES === For "Attention Mechanism": - distribution_plot: categories=["The", "cat", "sat", "down"], values=[0.1, 0.5, 0.3, 0.1] - matrix_heatmap: labels=["The", "cat", "sat", "down"], values=[[0.7,0.1,0.1,0.1],[0.1,0.6,0.2,0.1],[0.2,0.2,0.4,0.2],[0.1,0.1,0.2,0.6]], x_title="Keys", y_title="Queries" - params: [ {"name": "temperature", "min": 0.1, "max": 2.0, "default": 1.0, "step": 0.1, "label": "Attention Sharpness", "component_index": 2, "config_key": "temperature"}, {"name": "focus_row", "min": 1, "max": 4, "default": 2, "step": 1, "label": "Focus on Token", "component_index": 1, "config_key": "focus_row"} ] For "Image Classification" / "CNN": - distribution_plot: categories=["Dog", "Cat", "Bird", "Car"], values=[0.7, 0.15, 0.1, 0.05] - params: [{"name": "temperature", "min": 0.1, "max": 3.0, "default": 1.0, "step": 0.1, "label": "Prediction Confidence", "component_index": , "config_key": "temperature"}] For "Softmax Function": - distribution_plot: categories=["Class A", "Class B", "Class C"], values=[0.6, 0.3, 0.1] - params: [{"name": "temperature", "min": 0.1, "max": 3.0, "default": 1.0, "step": 0.1, "label": "Temperature (τ)", "component_index": , "config_key": "temperature"}] For "Confusion Matrix": - matrix_heatmap: labels=["Positive", "Negative"], values=[[45,5],[10,40]], x_title="Predicted", 
y_title="Actual" - params: [{"name": "threshold", "min": 0, "max": 50, "default": 0, "step": 5, "label": "Min Value Shown", "component_index": , "config_key": "threshold"}] CRITICAL: - If you don't provide the required data, the visualization will show an ERROR, not defaults. - If you don't provide params, users CANNOT interact with the UI. ALWAYS include params. - Never use animal names (Cat, Dog) unless the concept is specifically about image classification. === VALIDATION CHECKLIST (verify before returning) === [ ] components array has EXACTLY 3 items with DIFFERENT types [ ] params array has AT LEAST 2 interactive parameters (users MUST be able to interact!) [ ] why_it_matters has 2-4 real sentences (not empty!) [ ] evolution object has all 4 fields filled (predecessor, predecessor_problem, how_it_solves, key_innovation) [ ] formulas array has 5+ formulas with descriptions [ ] variables array has 6+ variables with full sentence meanings [ ] All config_keys match the expected keys above [ ] distribution_plot has categories + values + a temperature param in params array [ ] matrix_heatmap has labels + values + x_title + y_title + focus_row or threshold param in params array Return ONLY valid JSON. 
No markdown, no explanation, just the JSON object."""


class VisualizationGenerator:
    """Generates visualizations by having LLM configure visual components."""

    # Class-level cache shared by ALL instances: {md5(normalized topic)[:16]: config dict}.
    _cache = {}

    def __init__(self, api_key: str):
        """Create a generator backed by the Together AI chat-completions API."""
        self.client = Together(api_key=api_key)
        # Fixed code-capable instruct model used for every generation request.
        self.model = "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8"

    def _cache_key(self, topic: str) -> str:
        """Return a short, stable cache key for *topic*.

        Separators and case are normalized so e.g. "K-Means" and "k means"
        map to the same key; md5 is used only as a cache key, not for security.
        """
        normalized = re.sub(r'[-_\s]+', ' ', topic.lower().strip())
        return hashlib.md5(normalized.encode()).hexdigest()[:16]

    def _call_llm(self, prompt: str) -> str:
        """Send one chat-completion request and return the raw text reply.

        SYSTEM_PROMPT carries all output-format rules; *prompt* is the
        user-visible topic request. Low temperature (0.3) keeps the JSON
        structure stable.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            max_tokens=4000,  # Increased for comprehensive responses
            temperature=0.3,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt}
            ]
        )
        return response.choices[0].message.content

    def _parse_response(self, response: str) -> dict:
        """Parse JSON from LLM response, handling LaTeX escapes."""
        response = response.strip()
        # Remove markdown code blocks if present
        if response.startswith('```'):
            response = re.sub(r'^```json?\s*', '', response)
            response = re.sub(r'\s*```$', '', response)
        # Find JSON object
        match = re.search(r'\{[\s\S]*\}', response)
        if match:
            response = match.group()

        # Fix common LaTeX escape issues in JSON
        # LLMs often write \sum instead of \\sum in JSON strings
        def fix_latex_escapes(text):
            # Fix unescaped backslashes before LaTeX commands
            latex_commands = ['sum', 'frac', 'sqrt', 'alpha', 'beta', 'theta', 'mu', 'sigma',
                              'nabla', 'partial', 'infty', 'int', 'prod', 'log', 'exp', 'sin',
                              'cos', 'tan', 'lim', 'max', 'min', 'arg', 'text', 'mathbf',
                              'mathcal', 'hat', 'bar', 'vec', 'dot', 'cdot', 'times', 'div',
                              'pm', 'leq', 'geq', 'neq', 'approx', 'in', 'notin', 'subset',
                              'supset', 'cup', 'cap', 'forall', 'exists', 'rightarrow',
                              'leftarrow', 'Rightarrow', 'left', 'right', 'big', 'Big',
                              'bigg', 'Bigg']
            for cmd in latex_commands:
                # Replace \cmd with \\cmd (but not \\cmd which is already escaped)
                # NOTE(review): the regex below appears truncated in this copy —
                # presumably a negative lookbehind such as r'(?<!\\)\\cmd'; confirm
                # against the original file before relying on this path.
                text = re.sub(r'(?
dict: """Validate LLM response and fix common issues.""" # Fix 1: Ensure we have exactly 3 components components = parsed.get('components', []) if len(components) < 3: print(f"WARNING: LLM returned only {len(components)} component(s), adding fallback views") # Add complementary components based on what we have existing_types = {c.get('type') for c in components} fallback_additions = [ ('flow_diagram', {'title': 'Network Architecture', 'layers': [4, 8, 6, 3]}), ('loss_curve', {'title': 'Training Progress'}), ('distribution_plot', {'title': 'Output Distribution'}), ('line_progression', {'title': 'Learning Curve'}), ('scatter_cluster', {'title': 'Data Distribution', 'n_clusters': 3}), ('matrix_heatmap', {'title': 'Weight Visualization', 'size': 5}), ] for comp_type, config in fallback_additions: if comp_type not in existing_types and len(components) < 3: components.append({'type': comp_type, 'config': config}) existing_types.add(comp_type) parsed['components'] = components # Fix 2: Ensure why_it_matters is not empty why = parsed.get('why_it_matters', '') if not why or len(why.strip()) < 20: print("WARNING: why_it_matters is empty or too short, adding default") title = parsed.get('title', 'This concept') parsed['why_it_matters'] = ( f"{title} is widely used in production ML systems including recommendation engines, " f"fraud detection, and autonomous vehicles. Getting it wrong can lead to poor model performance, " f"wasted compute resources, and incorrect predictions that impact real users." 
) # Fix 3: Ensure math section has sufficient content math = parsed.get('math', {}) if not math: math = {'formulas': [], 'variables': []} parsed['math'] = math formulas = math.get('formulas', []) variables = math.get('variables', []) if len(formulas) < 3: print(f"WARNING: Only {len(formulas)} formulas, math section may be incomplete") # We don't auto-generate formulas as they're concept-specific, but we log the issue if len(variables) < 3: print(f"WARNING: Only {len(variables)} variables explained") return parsed def render_components(self, config: dict, param_overrides: dict = None) -> list: """Render all components with optional param overrides.""" # Deep copy to avoid modifying cached config components = copy.deepcopy(config.get('components', [])) params = config.get('params', []) figures = [] # Apply param overrides to component configs if param_overrides: print(f"Param overrides: {param_overrides}") print(f"Params config: {params}") for param in params: param_name = param['name'] if param_name in param_overrides: comp_idx = param.get('component_index', 0) config_key = param.get('config_key', param_name) value = param_overrides[param_name] if comp_idx < len(components): components[comp_idx]['config'][config_key] = value # Also set common aliases to ensure component finds it if 'cluster' in config_key.lower() or config_key == 'k': components[comp_idx]['config']['n_clusters'] = int(value) components[comp_idx]['config']['k'] = int(value) if 'point' in config_key.lower(): components[comp_idx]['config']['n_points'] = int(value) print(f"Set [{comp_idx}].{config_key} = {value}") # CRITICAL: Propagate clustering params to ALL components that need them # This ensures scatter_cluster and cluster_distribution stay synchronized synced_types = ['scatter_cluster', 'cluster_distribution'] # Extract n_clusters and n_points from param_overrides (check multiple aliases) n_clusters = None n_points = None for key, value in param_overrides.items(): key_lower = key.lower() if 
'cluster' in key_lower or key == 'k': n_clusters = int(value) if 'point' in key_lower: n_points = int(value) # Apply to ALL synced components for comp in components: if comp['type'] in synced_types: if n_clusters is not None: comp['config']['n_clusters'] = n_clusters comp['config']['k'] = n_clusters if n_points is not None: comp['config']['n_points'] = n_points print(f"Synced {comp['type']}: n_clusters={n_clusters}, n_points={n_points}") # Generate new seed based on param values so visualization actually changes seed = hash(str(param_overrides)) % 10000 for comp in components: comp['config']['seed'] = seed # Render each component for comp in components: comp_type = comp['type'] comp_config = comp.get('config', {}) try: fig = render_component(comp_type, comp_config) figures.append(fig) except Exception as e: print(f"Error rendering {comp_type}: {e}") return figures def generate(self, topic: str) -> dict: """Generate visualization config for topic.""" cache_key = self._cache_key(topic) if cache_key in self._cache: cached = self._cache[cache_key] figures = self.render_components(cached) return {**cached, 'figures': figures} prompt = f"Create an interactive visualization for: {topic}" response = self._call_llm(prompt) parsed = self._parse_response(response) if not parsed.get('components'): raise ValueError("LLM did not return valid components") # Cache config (without figures) self._cache[cache_key] = { 'title': parsed.get('title', topic), 'oneliner': parsed.get('oneliner', ''), 'intuition': parsed.get('intuition', ''), 'why_it_matters': parsed.get('why_it_matters', ''), 'evolution': parsed.get('evolution', {}), 'math': parsed.get('math', {}), 'components': parsed.get('components', []), 'params': parsed.get('params', []), } # Render with defaults figures = self.render_components(self._cache[cache_key]) return { **self._cache[cache_key], 'figures': figures } def update_params(self, topic: str, param_values: dict) -> list: """Update visualization with new param 
values.""" cache_key = self._cache_key(topic) if cache_key not in self._cache: return [] config = self._cache[cache_key] return self.render_components(config, param_values)