# AI-Learning-Playground / generator.py
"""
Visualization Generator - LLM selects and configures visual components
No code generation - just structured JSON config
"""
import re
import json
import hashlib
import copy
from together import Together
from components import render_component, COMPONENTS
SYSTEM_PROMPT = """You are an ML teacher creating BEGINNER-FRIENDLY visualizations. Make concepts crystal clear.
CRITICAL: In JSON, backslashes must be double-escaped. Write \\\\sum not \\sum for LaTeX.
AVAILABLE COMPONENTS:
1. scatter_cluster - clustering visualization with centroids (params: n_clusters, n_points)
2. cluster_distribution - cluster size bar chart, SYNCED with scatter_cluster (params: n_clusters, n_points)
3. gradient_descent_3d - 3D loss surface with optimization path
4. gradient_descent_2d - 2D contour view (bird's eye) of gradient descent
5. loss_curve - training loss over steps
6. flow_diagram - basic neural network architecture (layers as ints: [3,4,2])
7. matrix_heatmap - attention weights, confusion matrices (REQUIRES: labels, values, x_title, y_title)
8. distribution_plot - softmax probabilities (REQUIRES: categories, values)
9. decision_boundary - classification regions
10. line_progression - training curves over epochs
11. custom_plotly - ADVANCED: For ANY concept not covered above. Provide full Plotly JSON spec.
=== CUSTOM_PLOTLY COMPONENT (for complex/unique visualizations) ===
Use custom_plotly when the concept requires specialized diagrams not covered by components 1-10.
AVAILABLE TEMPLATES (use "template" key in config):
- "transformer_attention" - Q/K/V attention mechanism with arrows and flow
- "lstm_gates" - Forget/Input/Output gates visualization
- "vae_architecture" - Encoder → Latent → Decoder flow
- "gan_architecture" - Generator vs Discriminator adversarial setup
- "cnn_architecture" - Convolution → Pooling → FC layers
Example using template:
{"type": "custom_plotly", "config": {"title": "Attention Q/K/V", "template": "transformer_attention"}}
Example with custom data (for concepts without templates):
{"type": "custom_plotly", "config": {
"title": "Custom Visualization",
"data": [
{"type": "scatter", "x": [1,2,3], "y": [1,2,3], "mode": "markers+text", "text": ["A","B","C"], "marker": {"size": 30, "color": "#667eea"}},
{"type": "scatter", "x": [1,3], "y": [1,3], "mode": "lines", "line": {"color": "gray", "dash": "dot"}}
],
"layout": {"xaxis": {"visible": false}, "yaxis": {"visible": false}, "annotations": [{"x": 2, "y": 2, "text": "Label", "showarrow": false}]}
}}
Valid trace types: scatter, scatter3d, bar, heatmap, contour, surface, pie, sankey, treemap, sunburst
=== HARD REQUIREMENTS (MUST FOLLOW - NO EXCEPTIONS) ===
REQUIREMENT 1 - COMPONENTS: You MUST return EXACTLY 3 components. NOT 1. NOT 2. ALWAYS 3.
- Each component MUST have a DIFFERENT "type" value
- Recommended combinations:
* Gradient Descent: gradient_descent_3d + gradient_descent_2d + loss_curve
* Clustering (K-Means, DBSCAN): scatter_cluster + cluster_distribution + loss_curve (ALL use same n_clusters, n_points params!)
* Basic Neural Networks: flow_diagram + distribution_plot + loss_curve
* Attention Mechanism: custom_plotly (template: transformer_attention) + matrix_heatmap + distribution_plot
* Transformers: custom_plotly (template: transformer_attention) + matrix_heatmap + distribution_plot
* LSTM/RNN: custom_plotly (template: lstm_gates) + line_progression + distribution_plot
* VAE/Autoencoder: custom_plotly (template: vae_architecture) + scatter_cluster + loss_curve
* GAN: custom_plotly (template: gan_architecture) + distribution_plot + line_progression
* CNN: custom_plotly (template: cnn_architecture) + matrix_heatmap + distribution_plot
* Decision Boundary/SVM: decision_boundary + scatter_cluster + line_progression
* Any other concept: Use custom_plotly with appropriate data/layout for the architecture
REQUIREMENT 2 - WHY_IT_MATTERS: This field is REQUIRED and MUST contain 2-4 sentences.
- Explain real-world applications (e.g., "Used in Netflix recommendations, fraud detection")
- Explain consequences of getting it wrong (e.g., "Wrong clusters = wrong customer segments = wasted marketing budget")
- This field CANNOT be empty or null
REQUIREMENT 3 - EVOLUTION: This field is REQUIRED. Explain what problem this concept solves.
- "predecessor": The previous concept/approach this improves upon (e.g., "RNN" for LSTM, "Seq2Seq" for Attention)
- "predecessor_problem": 1-2 sentences on the limitation/problem of the predecessor
- "how_it_solves": 1-2 sentences on how this concept fixes that problem
- "key_innovation": One sentence on the core innovation (e.g., "Gates that control information flow")
- If this is a foundational concept with no predecessor, set predecessor to "None" and explain what general problem it solves
REQUIREMENT 4 - MATHEMATICAL DETAILS: You MUST provide comprehensive math coverage.
- MINIMUM 5 formulas in the "formulas" array (aim for 6-8)
- Required formula types: main equation, gradient/derivative, update rule, loss function, convergence condition
- MINIMUM 6 variables explained in the "variables" array
- Each formula description MUST be 1-2 sentences in plain English
- Each variable meaning MUST be a full sentence, not just 1-2 words
=== OUTPUT FORMAT (JSON only) ===
{
"title": "Clear Beginner-Friendly Title",
"oneliner": "One sentence a beginner can understand (max 15 words)",
"intuition": "2-3 sentences using everyday analogies. No jargon. Like explaining to a smart 12-year-old.",
"why_it_matters": "REQUIRED: 2-4 sentences on real-world applications and what goes wrong if you mess this up.",
"evolution": {
"predecessor": "Name of the previous approach (or 'None' if foundational)",
"predecessor_problem": "What limitation/problem did the predecessor have?",
"how_it_solves": "How does this concept fix that problem?",
"key_innovation": "One sentence: the core breakthrough idea"
},
"math": {
"formulas": [
{"name": "Main Formula", "equation": "$$LaTeX$$", "description": "Plain English explanation (1-2 sentences)"},
{"name": "Gradient", "equation": "$$LaTeX$$", "description": "How we compute the direction to move"},
{"name": "Update Rule", "equation": "$$LaTeX$$", "description": "How parameters change each step"},
{"name": "Loss Function", "equation": "$$LaTeX$$", "description": "What we are trying to minimize"},
{"name": "Convergence", "equation": "$$LaTeX$$", "description": "When do we stop"}
],
"variables": [
{"symbol": "x", "meaning": "Full sentence explaining this variable in beginner terms"},
{"symbol": "y", "meaning": "Full sentence explaining this variable in beginner terms"}
]
},
"components": [
{"type": "component_type_1", "config": {"title": "Descriptive Title for View 1", ...}},
{"type": "component_type_2", "config": {"title": "Descriptive Title for View 2", ...}},
{"type": "component_type_3", "config": {"title": "Descriptive Title for View 3", ...}}
],
"params": [
{"name": "param_name", "min": 2, "max": 8, "default": 3, "step": 1, "label": "Human Label", "component_index": 0, "config_key": "config_key_name"}
]
}
=== COMPONENT CONFIG KEYS ===
ALGORITHM SIMULATIONS (data generated from parameters - stays synchronized):
- scatter_cluster: n_clusters, n_points (clustering visualization)
- cluster_distribution: n_clusters, n_points (MUST match scatter_cluster params for sync!)
- gradient_descent_3d, gradient_descent_2d, loss_curve: learning_rate, n_steps, start_x
- decision_boundary: n_points, model_type
- line_progression: epochs
- flow_diagram: layers (list of ints)
IMPORTANT FOR CLUSTERING: scatter_cluster and cluster_distribution MUST use the SAME n_clusters and n_points
values so they stay synchronized when user adjusts sliders.
PURE RENDERERS - YOU MUST PROVIDE ALL DATA + INTERACTIVE PARAMS:
distribution_plot (REQUIRES DATA):
- categories: list of strings (REQUIRED) - e.g., ["Token 1", "Token 2", "Token 3"] or ["Dog", "Cat", "Bird"]
- values: list of floats (REQUIRED) - e.g., [0.6, 0.3, 0.1] (should sum to ~1.0)
- x_title: string (optional) - axis label
- temperature: float (INTERACTIVE) - controls distribution sharpness (0.1=peaked, 2.0=uniform)
matrix_heatmap (REQUIRES DATA):
- labels: list of strings (REQUIRED) - e.g., ["Token 1", "Token 2", "Token 3"]
- values: 2D array of floats (REQUIRED) - e.g., [[0.5, 0.3, 0.2], [0.2, 0.6, 0.2], [0.3, 0.1, 0.6]]
- x_title: string (REQUIRED) - e.g., "Keys" for attention, "Feature" for CNN
- y_title: string (REQUIRED) - e.g., "Queries" for attention, "Filter" for CNN
- focus_row: int (INTERACTIVE) - 1-indexed row to highlight (e.g., "which token is attending?")
- threshold: float (INTERACTIVE) - only show values above threshold (0.0-1.0)
custom_plotly: template (string), data (array of traces), layout (object)
=== INTERACTIVE PARAMS FOR PURE RENDERERS ===
IMPORTANT: For concepts using distribution_plot or matrix_heatmap, you MUST include interactive params.
These let users explore the visualization even though data comes from LLM.
For Attention/Transformer concepts, include:
{"name": "temperature", "min": 0.1, "max": 2.0, "default": 1.0, "step": 0.1, "label": "Attention Temperature", "component_index": 2, "config_key": "temperature"}
{"name": "focus_row", "min": 1, "max": <num_tokens>, "default": 1, "step": 1, "label": "Focus on Token", "component_index": 1, "config_key": "focus_row"}
For Classification/Softmax concepts, include:
{"name": "temperature", "min": 0.1, "max": 3.0, "default": 1.0, "step": 0.1, "label": "Softmax Temperature", "component_index": <index>, "config_key": "temperature"}
=== CONCEPT-SPECIFIC DATA + PARAMS EXAMPLES ===
For "Attention Mechanism":
- distribution_plot: categories=["The", "cat", "sat", "down"], values=[0.1, 0.5, 0.3, 0.1]
- matrix_heatmap: labels=["The", "cat", "sat", "down"], values=[[0.7,0.1,0.1,0.1],[0.1,0.6,0.2,0.1],[0.2,0.2,0.4,0.2],[0.1,0.1,0.2,0.6]], x_title="Keys", y_title="Queries"
- params: [
{"name": "temperature", "min": 0.1, "max": 2.0, "default": 1.0, "step": 0.1, "label": "Attention Sharpness", "component_index": 2, "config_key": "temperature"},
{"name": "focus_row", "min": 1, "max": 4, "default": 2, "step": 1, "label": "Focus on Token", "component_index": 1, "config_key": "focus_row"}
]
For "Image Classification" / "CNN":
- distribution_plot: categories=["Dog", "Cat", "Bird", "Car"], values=[0.7, 0.15, 0.1, 0.05]
- params: [{"name": "temperature", "min": 0.1, "max": 3.0, "default": 1.0, "step": 0.1, "label": "Prediction Confidence", "component_index": <dist_plot_index>, "config_key": "temperature"}]
For "Softmax Function":
- distribution_plot: categories=["Class A", "Class B", "Class C"], values=[0.6, 0.3, 0.1]
- params: [{"name": "temperature", "min": 0.1, "max": 3.0, "default": 1.0, "step": 0.1, "label": "Temperature (τ)", "component_index": <index>, "config_key": "temperature"}]
For "Confusion Matrix":
- matrix_heatmap: labels=["Positive", "Negative"], values=[[45,5],[10,40]], x_title="Predicted", y_title="Actual"
- params: [{"name": "threshold", "min": 0, "max": 50, "default": 0, "step": 5, "label": "Min Value Shown", "component_index": <index>, "config_key": "threshold"}]
CRITICAL:
- If you don't provide the required data, the visualization will show an ERROR, not defaults.
- If you don't provide params, users CANNOT interact with the UI. ALWAYS include params.
- Never use animal names (Cat, Dog) unless the concept is specifically about image classification.
=== VALIDATION CHECKLIST (verify before returning) ===
[ ] components array has EXACTLY 3 items with DIFFERENT types
[ ] params array has AT LEAST 2 interactive parameters (users MUST be able to interact!)
[ ] why_it_matters has 2-4 real sentences (not empty!)
[ ] evolution object has all 4 fields filled (predecessor, predecessor_problem, how_it_solves, key_innovation)
[ ] formulas array has 5+ formulas with descriptions
[ ] variables array has 6+ variables with full sentence meanings
[ ] All config_keys match the expected keys above
[ ] distribution_plot has categories + values + a temperature param in params array
[ ] matrix_heatmap has labels + values + x_title + y_title + focus_row or threshold param in params array
Return ONLY valid JSON. No markdown, no explanation, just the JSON object."""
class VisualizationGenerator:
"""Generates visualizations by having LLM configure visual components."""
    # Class-level cache: configs are shared across all generator instances in this process
    _cache = {}
def __init__(self, api_key: str):
self.client = Together(api_key=api_key)
self.model = "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8"
def _cache_key(self, topic: str) -> str:
normalized = re.sub(r'[-_\s]+', ' ', topic.lower().strip())
return hashlib.md5(normalized.encode()).hexdigest()[:16]
def _call_llm(self, prompt: str) -> str:
response = self.client.chat.completions.create(
model=self.model,
max_tokens=4000, # Increased for comprehensive responses
temperature=0.3,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt}
]
)
return response.choices[0].message.content
def _parse_response(self, response: str) -> dict:
"""Parse JSON from LLM response, handling LaTeX escapes."""
response = response.strip()
# Remove markdown code blocks if present
if response.startswith('```'):
            response = re.sub(r'^```(?:json)?\s*', '', response)
response = re.sub(r'\s*```$', '', response)
# Find JSON object
match = re.search(r'\{[\s\S]*\}', response)
if match:
response = match.group()
# Fix common LaTeX escape issues in JSON
# LLMs often write \sum instead of \\sum in JSON strings
def fix_latex_escapes(text):
# Fix unescaped backslashes before LaTeX commands
latex_commands = ['sum', 'frac', 'sqrt', 'alpha', 'beta', 'theta', 'mu', 'sigma',
'nabla', 'partial', 'infty', 'int', 'prod', 'log', 'exp',
'sin', 'cos', 'tan', 'lim', 'max', 'min', 'arg', 'text',
'mathbf', 'mathcal', 'hat', 'bar', 'vec', 'dot', 'cdot',
'times', 'div', 'pm', 'leq', 'geq', 'neq', 'approx',
'in', 'notin', 'subset', 'supset', 'cup', 'cap',
'forall', 'exists', 'rightarrow', 'leftarrow', 'Rightarrow',
'left', 'right', 'big', 'Big', 'bigg', 'Bigg']
            for cmd in latex_commands:
                # Replace \cmd with \\cmd (skip \\cmd, which is already escaped).
                # Replacement template: two literal backslashes, then the captured command.
                text = re.sub(r'(?<!\\)\\(' + cmd + r')', r'\\\\\1', text)
            return text
response = fix_latex_escapes(response)
try:
parsed = json.loads(response)
return self._validate_and_fix_response(parsed)
except json.JSONDecodeError as e:
# Last resort: try with raw_decode to get partial result
try:
decoder = json.JSONDecoder()
obj, _ = decoder.raw_decode(response)
return self._validate_and_fix_response(obj)
            except Exception:
                pass
# Return minimal valid structure instead of crashing
print(f"JSON parse error: {e}")
print(f"Response was: {response[:500]}...")
return {
'title': 'Visualization',
'oneliner': 'Unable to parse LLM response',
'intuition': '',
'why_it_matters': 'This concept is fundamental to machine learning and is used across many real-world applications.',
'math': {},
'components': [{'type': 'scatter_cluster', 'config': {'title': 'Default Visualization'}}],
'params': []
}
def _validate_and_fix_response(self, parsed: dict) -> dict:
"""Validate LLM response and fix common issues."""
# Fix 1: Ensure we have exactly 3 components
components = parsed.get('components', [])
if len(components) < 3:
print(f"WARNING: LLM returned only {len(components)} component(s), adding fallback views")
# Add complementary components based on what we have
existing_types = {c.get('type') for c in components}
fallback_additions = [
('flow_diagram', {'title': 'Network Architecture', 'layers': [4, 8, 6, 3]}),
('loss_curve', {'title': 'Training Progress'}),
('distribution_plot', {'title': 'Output Distribution'}),
('line_progression', {'title': 'Learning Curve'}),
('scatter_cluster', {'title': 'Data Distribution', 'n_clusters': 3}),
                # matrix_heatmap is a pure renderer and errors without explicit data
                ('matrix_heatmap', {'title': 'Weight Visualization', 'labels': ['A', 'B', 'C'], 'values': [[0.5, 0.3, 0.2], [0.2, 0.6, 0.2], [0.3, 0.1, 0.6]], 'x_title': 'Input', 'y_title': 'Output'}),
]
for comp_type, config in fallback_additions:
if comp_type not in existing_types and len(components) < 3:
components.append({'type': comp_type, 'config': config})
existing_types.add(comp_type)
parsed['components'] = components
# Fix 2: Ensure why_it_matters is not empty
why = parsed.get('why_it_matters', '')
if not why or len(why.strip()) < 20:
print("WARNING: why_it_matters is empty or too short, adding default")
title = parsed.get('title', 'This concept')
parsed['why_it_matters'] = (
f"{title} is widely used in production ML systems including recommendation engines, "
f"fraud detection, and autonomous vehicles. Getting it wrong can lead to poor model performance, "
f"wasted compute resources, and incorrect predictions that impact real users."
)
# Fix 3: Ensure math section has sufficient content
math = parsed.get('math', {})
if not math:
math = {'formulas': [], 'variables': []}
parsed['math'] = math
formulas = math.get('formulas', [])
variables = math.get('variables', [])
if len(formulas) < 3:
print(f"WARNING: Only {len(formulas)} formulas, math section may be incomplete")
# We don't auto-generate formulas as they're concept-specific, but we log the issue
if len(variables) < 3:
print(f"WARNING: Only {len(variables)} variables explained")
return parsed
    def render_components(self, config: dict, param_overrides: dict | None = None) -> list:
"""Render all components with optional param overrides."""
# Deep copy to avoid modifying cached config
components = copy.deepcopy(config.get('components', []))
params = config.get('params', [])
figures = []
# Apply param overrides to component configs
if param_overrides:
print(f"Param overrides: {param_overrides}")
print(f"Params config: {params}")
for param in params:
param_name = param['name']
if param_name in param_overrides:
comp_idx = param.get('component_index', 0)
config_key = param.get('config_key', param_name)
value = param_overrides[param_name]
if comp_idx < len(components):
components[comp_idx]['config'][config_key] = value
# Also set common aliases to ensure component finds it
if 'cluster' in config_key.lower() or config_key == 'k':
components[comp_idx]['config']['n_clusters'] = int(value)
components[comp_idx]['config']['k'] = int(value)
if 'point' in config_key.lower():
components[comp_idx]['config']['n_points'] = int(value)
print(f"Set [{comp_idx}].{config_key} = {value}")
# CRITICAL: Propagate clustering params to ALL components that need them
# This ensures scatter_cluster and cluster_distribution stay synchronized
synced_types = ['scatter_cluster', 'cluster_distribution']
# Extract n_clusters and n_points from param_overrides (check multiple aliases)
n_clusters = None
n_points = None
for key, value in param_overrides.items():
key_lower = key.lower()
if 'cluster' in key_lower or key == 'k':
n_clusters = int(value)
if 'point' in key_lower:
n_points = int(value)
# Apply to ALL synced components
for comp in components:
if comp['type'] in synced_types:
if n_clusters is not None:
comp['config']['n_clusters'] = n_clusters
comp['config']['k'] = n_clusters
if n_points is not None:
comp['config']['n_points'] = n_points
print(f"Synced {comp['type']}: n_clusters={n_clusters}, n_points={n_points}")
            # Derive a deterministic seed from the param values so the visualization
            # changes when sliders move, yet stays reproducible across runs
            # (built-in hash() of strings is randomized per process)
            param_repr = json.dumps(param_overrides, sort_keys=True, default=str)
            seed = int(hashlib.md5(param_repr.encode()).hexdigest(), 16) % 10000
for comp in components:
comp['config']['seed'] = seed
# Render each component
for comp in components:
comp_type = comp['type']
comp_config = comp.get('config', {})
try:
fig = render_component(comp_type, comp_config)
figures.append(fig)
except Exception as e:
print(f"Error rendering {comp_type}: {e}")
return figures
def generate(self, topic: str) -> dict:
"""Generate visualization config for topic."""
cache_key = self._cache_key(topic)
if cache_key in self._cache:
cached = self._cache[cache_key]
figures = self.render_components(cached)
return {**cached, 'figures': figures}
prompt = f"Create an interactive visualization for: {topic}"
response = self._call_llm(prompt)
parsed = self._parse_response(response)
if not parsed.get('components'):
raise ValueError("LLM did not return valid components")
# Cache config (without figures)
self._cache[cache_key] = {
'title': parsed.get('title', topic),
'oneliner': parsed.get('oneliner', ''),
'intuition': parsed.get('intuition', ''),
'why_it_matters': parsed.get('why_it_matters', ''),
'evolution': parsed.get('evolution', {}),
'math': parsed.get('math', {}),
'components': parsed.get('components', []),
'params': parsed.get('params', []),
}
# Render with defaults
figures = self.render_components(self._cache[cache_key])
return {
**self._cache[cache_key],
'figures': figures
}
def update_params(self, topic: str, param_values: dict) -> list:
"""Update visualization with new param values."""
cache_key = self._cache_key(topic)
if cache_key not in self._cache:
return []
config = self._cache[cache_key]
return self.render_components(config, param_values)
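A minimal, self-contained sketch of the LaTeX-escape repair that `_parse_response` performs (the sample string below is illustrative, not from a real model response). LLMs often emit a single backslash before LaTeX commands inside JSON strings; `json.loads` rejects that as an invalid escape, so the parser doubles any unescaped backslash before a known command first.

```python
import json
import re

# A response as an LLM might emit it: "\sum" is an invalid JSON escape,
# so json.loads would raise JSONDecodeError on the raw string.
raw = '{"equation": "$$\\sum_i x_i$$"}'

# Double any backslash that is not already escaped before the command.
fixed = re.sub(r'(?<!\\)\\(sum)', r'\\\\\1', raw)

print(json.loads(fixed)["equation"])  # $$\sum_i x_i$$
```

The negative lookbehind `(?<!\\)` is what keeps an already well-escaped `\\sum` from being doubled again on a second pass.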