"""
Visual Component Library - Beginner-Friendly ML Visualizations
Multiple synchronized views with educational annotations
"""
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Shared color palette used by every figure builder in this module.
# 'gradient' is cycled per-cluster/per-category; 'heatmap' is a Plotly colorscale name.
COLORS = {
    'primary': '#667eea',
    'secondary': '#764ba2',
    'accent': '#f093fb',
    'success': '#43e97b',
    'warning': '#fa709a',
    'info': '#4facfe',
    'gradient': ['#667eea', '#764ba2', '#f093fb', '#4facfe', '#43e97b', '#fa709a'],
    'heatmap': 'Viridis',
}

# Baseline layout applied to every figure via **LAYOUT_DEFAULTS.
# Dark theme with semi-transparent background so figures blend into the host page.
LAYOUT_DEFAULTS = dict(
    template='plotly_dark',
    paper_bgcolor='rgba(26,26,46,0.9)',
    plot_bgcolor='rgba(26,26,46,0.9)',
    font=dict(family='Inter, sans-serif', color='#e0e0e0', size=12),
    margin=dict(l=60, r=60, t=80, b=60),
    hoverlabel=dict(bgcolor='#1a1a2e', font_size=14),
)
def create_placeholder_figure(title: str, message: str) -> go.Figure:
    """
    Create a placeholder figure shown when required data is missing.

    Displays a clear error message in the plot area instead of rendering
    misleading default data.

    Args:
        title: Short error title (also shown in the figure title).
        message: Longer explanation; may contain Plotly ``<br>`` line breaks.

    Returns:
        A go.Figure containing only the error annotation (axes hidden).
    """
    fig = go.Figure()
    fig.add_annotation(
        x=0.5, y=0.5,
        xref='paper', yref='paper',
        # Plotly uses <br> (not \n) for line breaks inside annotation text.
        text=f"{title}<br>{message}",
        showarrow=False,
        font=dict(size=16, color='#ff6b6b'),
        align='center',
        bgcolor='rgba(255,107,107,0.1)',
        bordercolor='#ff6b6b',
        borderwidth=2,
        borderpad=20,
    )
    fig.update_layout(
        title=dict(text=f"⚠️ {title}", font=dict(size=16, color='#ff6b6b')),
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        **LAYOUT_DEFAULTS,
    )
    return fig
def scatter_cluster(config: dict) -> go.Figure:
    """
    K-Means-style clustering scatter plot with educational annotations.

    Config keys (all optional):
        title: Figure title.
        n_clusters / k / num_clusters: Number of clusters (default 3).
        n_points / data_points: Total points requested (default 150).
        show_centroids: Whether to draw centroid markers (default True).
        seed: RNG seed for reproducible data (default 42).
    """
    title = config.get('title', 'K-Means Clustering')
    n_clusters = int(config.get('n_clusters') or config.get('k') or config.get('num_clusters') or 3)
    n_points = int(config.get('n_points') or config.get('data_points') or 150)
    show_centroids = config.get('show_centroids', True)
    seed = config.get('seed', 42)
    np.random.seed(seed)

    # Place cluster centers evenly around a circle so clusters are clearly
    # separated — easier for beginners to see the structure.
    angle_step = 2 * np.pi / n_clusters
    radius = 4
    centers = np.array([
        [radius * np.cos(i * angle_step), radius * np.sin(i * angle_step)]
        for i in range(n_clusters)
    ])

    # Gaussian blob of points around each center.
    points_per_cluster = n_points // n_clusters
    points = []
    labels = []
    for i, center in enumerate(centers):
        cluster_points = center + np.random.randn(points_per_cluster, 2) * 1.2
        points.extend(cluster_points)
        labels.extend([i] * points_per_cluster)
    points = np.array(points)
    labels = np.array(labels)
    # Actual plotted count; can be less than n_points when n_points is not
    # divisible by n_clusters (the remainder is dropped).
    total_points = len(points)

    fig = go.Figure()
    # One trace per cluster so the legend shows each cluster separately.
    for i in range(n_clusters):
        mask = labels == i
        color = COLORS['gradient'][i % len(COLORS['gradient'])]
        fig.add_trace(go.Scatter(
            x=points[mask, 0], y=points[mask, 1],
            mode='markers',
            marker=dict(size=10, color=color, opacity=0.7, line=dict(color='white', width=1)),
            name=f'Cluster {i+1}',
            # %{{...}} escapes the braces inside an f-string; <br> is Plotly's line break.
            hovertemplate=f'Cluster {i+1}<br>Position: (%{{x:.1f}}, %{{y:.1f}})',
        ))

    # Centroids with educational annotation.
    if show_centroids:
        fig.add_trace(go.Scatter(
            x=centers[:, 0], y=centers[:, 1],
            mode='markers+text',
            marker=dict(size=25, color='white', symbol='x', line=dict(color='#333', width=3)),
            text=[f'C{i+1}' for i in range(n_clusters)],
            textposition='top center',
            textfont=dict(size=14, color='white'),
            name='Centroids (Cluster Centers)',
            hovertemplate='Centroid %{text}<br>This is the "center" of the cluster<br>Position: (%{x:.1f}, %{y:.1f})',
        ))

    # Point at the first centroid and explain what a centroid is.
    fig.add_annotation(
        x=centers[0, 0], y=centers[0, 1] + 2,
        text="← Centroid: The algorithm tries to<br>minimize distance from points to here",
        showarrow=True, arrowhead=2, arrowcolor=COLORS['info'],
        font=dict(size=11, color=COLORS['info']),
        ax=80, ay=-30,
    )
    # Summary box; use the actual plotted count, not the requested n_points.
    fig.add_annotation(
        x=0.02, y=0.98, xref='paper', yref='paper',
        text=f"K = {n_clusters} clusters<br>{total_points} data points total",
        showarrow=False,
        font=dict(size=13, color='white'),
        align='left',
        bgcolor='rgba(102, 126, 234, 0.8)',
        borderpad=8,
    )
    fig.update_layout(
        title=dict(text=f"{title}<br>Each color = one cluster, X = centroid", font=dict(size=18)),
        xaxis=dict(title='Feature 1', zeroline=False),
        yaxis=dict(title='Feature 2', zeroline=False),
        legend=dict(orientation='h', y=1.12),
        **LAYOUT_DEFAULTS,
    )
    return fig
def cluster_distribution(config: dict) -> go.Figure:
    """
    Cluster-size distribution bar chart, companion view to scatter_cluster.

    Reads the same params as scatter_cluster (n_clusters, n_points, seed) so
    both views describe the same configuration. Each bar keeps the color of
    its cluster index from COLORS['gradient'], so bar colors match the
    scatter view even after sorting by size.

    NOTE(review): a small random variation is added to each size "for
    realism", so the exact counts do not match scatter_cluster's equal-sized
    blobs — confirm whether exact synchronization is intended.
    """
    title = config.get('title', 'Cluster Size Distribution')
    n_clusters = int(config.get('n_clusters') or config.get('k') or config.get('num_clusters') or 3)
    n_points = int(config.get('n_points') or config.get('data_points') or 150)
    seed = config.get('seed', 42)
    np.random.seed(seed)

    points_per_cluster = n_points // n_clusters
    remainder = n_points % n_clusters

    # Build (name, color, size) per cluster; color is tied to the cluster
    # index so it stays consistent with scatter_cluster's trace colors.
    cluster_names = []
    cluster_colors = []
    cluster_sizes = []
    for i in range(n_clusters):
        base_size = points_per_cluster + (1 if i < remainder else 0)
        # Slight variation so bars are not perfectly equal (skipped for tiny clusters).
        variation = np.random.randint(-2, 3) if base_size > 5 else 0
        cluster_sizes.append(max(1, base_size + variation))
        cluster_names.append(f'Cluster {i+1}')
        cluster_colors.append(COLORS['gradient'][i % len(COLORS['gradient'])])

    total = sum(cluster_sizes)
    proportions = [s / total for s in cluster_sizes]

    # Sort bars largest-first, carrying each cluster's color along with it
    # (fixes the desync where colors were assigned positionally after sorting).
    sorted_rows = sorted(
        zip(cluster_names, proportions, cluster_sizes, cluster_colors),
        key=lambda row: row[1], reverse=True,
    )
    cluster_names = [row[0] for row in sorted_rows]
    proportions = [row[1] for row in sorted_rows]
    cluster_sizes = [row[2] for row in sorted_rows]
    cluster_colors = [row[3] for row in sorted_rows]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=cluster_names, y=proportions,
        marker=dict(color=cluster_colors, line=dict(color='white', width=1)),
        text=[f'{p:.1%}<br>({s} pts)' for p, s in zip(proportions, cluster_sizes)],
        textposition='outside',
        textfont=dict(size=12, color='white'),
        hovertemplate='%{x}<br>Size: %{text}',
    ))
    # Call out the largest cluster (first after the descending sort).
    fig.add_annotation(
        x=cluster_names[0], y=proportions[0] + 0.05,
        text="🏆 Largest",
        showarrow=False, font=dict(size=12, color='lime'),
    )
    fig.update_layout(
        title=dict(text=f"{title}<br>Points per cluster (K={n_clusters})"),
        xaxis=dict(title='Cluster'),
        yaxis=dict(title='Proportion of Points', tickformat='.0%', range=[0, max(proportions) * 1.3]),
        **LAYOUT_DEFAULTS,
    )
    return fig
def gradient_descent_3d(config: dict) -> go.Figure:
    """
    3D gradient-descent visualization over a quadratic loss surface.

    Runs plain gradient descent on J(θ) = θ₁² + θ₂² and draws the surface,
    the optimization path, start/end markers, and a status line describing
    whether the learning rate produced good convergence.

    Config keys (all optional):
        title, learning_rate/lr/alpha, start_x/start_position, start_y,
        n_steps/steps/iterations, seed.
    """
    title = config.get('title', 'Gradient Descent Optimization')
    lr = float(config.get('learning_rate') or config.get('lr') or config.get('alpha') or 0.1)
    start_x = float(config.get('start_x') or config.get('start_position') or 2.0)
    start_y = float(config.get('start_y') or 2.0)
    n_steps = int(config.get('n_steps') or config.get('steps') or config.get('iterations') or 30)
    seed = config.get('seed', 42)
    np.random.seed(seed)  # kept for interface consistency; descent itself is deterministic

    # Loss surface: simple convex bowl, easy to understand.
    x = np.linspace(-3, 3, 60)
    y = np.linspace(-3, 3, 60)
    X, Y = np.meshgrid(x, y)
    Z = X**2 + Y**2

    # Run gradient descent, recording the full path.
    path_x, path_y, path_z = [start_x], [start_y], [start_x**2 + start_y**2]
    px, py = start_x, start_y
    for _ in range(n_steps):
        # ∇(x² + y²) = (2x, 2y)
        gx, gy = 2 * px, 2 * py
        px = px - lr * gx
        py = py - lr * gy
        pz = px**2 + py**2
        path_x.append(px)
        path_y.append(py)
        path_z.append(pz)
        # Early stop once effectively at the minimum.
        if pz < 0.001:
            break

    fig = go.Figure()
    fig.add_trace(go.Surface(
        x=X, y=Y, z=Z,
        colorscale='Viridis',
        opacity=0.85,
        showscale=False,
        name='Loss Surface',
        hovertemplate='Loss at (%{x:.1f}, %{y:.1f}): %{z:.2f}',
    ))
    fig.add_trace(go.Scatter3d(
        x=path_x, y=path_y, z=path_z,
        mode='lines+markers',
        marker=dict(size=6, color=COLORS['warning'], symbol='circle'),
        line=dict(color=COLORS['warning'], width=5),
        name='Optimization Path',
        hovertemplate='Step %{pointNumber}<br>Position: (%{x:.2f}, %{y:.2f})<br>Loss: %{z:.3f}',
    ))
    fig.add_trace(go.Scatter3d(
        x=[path_x[0]], y=[path_y[0]], z=[path_z[0]],
        mode='markers+text',
        marker=dict(size=12, color='red', symbol='diamond'),
        text=['START'],
        textposition='top center',
        name='Starting Point',
    ))
    fig.add_trace(go.Scatter3d(
        x=[path_x[-1]], y=[path_y[-1]], z=[path_z[-1]],
        mode='markers+text',
        marker=dict(size=12, color='lime', symbol='diamond'),
        text=['END'],
        textposition='top center',
        name='Final Point',
    ))

    # Human-readable verdict on the run (thresholds are heuristic for this bowl).
    final_loss = path_z[-1]
    if lr > 0.9:
        status = "⚠️ Learning rate too HIGH - overshooting!"
    elif lr < 0.05:
        status = "🐌 Learning rate too LOW - very slow progress"
    elif final_loss < 0.1:
        status = "✅ Good convergence!"
    else:
        status = f"📉 Loss: {final_loss:.3f} (keep optimizing)"

    fig.update_layout(
        title=dict(
            text=f"{title}<br>Learning Rate (α) = {lr} | Steps = {len(path_x)-1} | {status}",
            font=dict(size=16),
        ),
        scene=dict(
            xaxis_title='Parameter θ₁',
            yaxis_title='Parameter θ₂',
            zaxis_title='Loss J(θ)',
            camera=dict(eye=dict(x=1.8, y=1.8, z=1.2)),
            annotations=[
                dict(
                    x=0, y=0, z=0,
                    text="← Global Minimum<br>(What we want to find!)",
                    showarrow=True,
                    arrowhead=2,
                    font=dict(size=12, color='lime'),
                )
            ],
        ),
        **LAYOUT_DEFAULTS,
        height=550,
    )
    return fig
def gradient_descent_2d(config: dict) -> go.Figure:
    """
    2D contour ("bird's eye") view of gradient descent on J = θ₁² + θ₂².

    Easier to read than the 3D view: contour rings are levels of equal loss
    and arrows show the per-step direction of descent.

    Config keys (all optional): title, learning_rate/lr, start_x, start_y,
    n_steps/iterations, seed.
    """
    title = config.get('title', 'Gradient Descent (Top View)')
    lr = float(config.get('learning_rate') or config.get('lr') or 0.1)
    start_x = float(config.get('start_x') or 2.0)
    start_y = float(config.get('start_y') or 2.0)
    n_steps = int(config.get('n_steps') or config.get('iterations') or 30)
    seed = config.get('seed', 42)
    np.random.seed(seed)  # descent is deterministic; kept for interface consistency

    # Loss contour grid.
    x = np.linspace(-3, 3, 100)
    y = np.linspace(-3, 3, 100)
    X, Y = np.meshgrid(x, y)
    Z = X**2 + Y**2

    # Run gradient descent (same rule as the 3D view).
    path_x, path_y = [start_x], [start_y]
    px, py = start_x, start_y
    for _ in range(n_steps):
        gx, gy = 2 * px, 2 * py
        px = px - lr * gx
        py = py - lr * gy
        path_x.append(px)
        path_y.append(py)
        if px**2 + py**2 < 0.001:  # early stop near the minimum
            break

    fig = go.Figure()
    fig.add_trace(go.Contour(
        x=x, y=y, z=Z,
        colorscale='Viridis',
        contours=dict(showlabels=True, labelfont=dict(size=10, color='white')),
        name='Loss Contours',
        hovertemplate='Loss: %{z:.2f}',
    ))
    fig.add_trace(go.Scatter(
        x=path_x, y=path_y,
        mode='lines+markers',
        marker=dict(size=8, color=COLORS['warning']),
        line=dict(color=COLORS['warning'], width=3, dash='solid'),
        name='Optimization Path',
        hovertemplate='Step %{pointNumber}<br>(%{x:.2f}, %{y:.2f})',
    ))
    # Direction arrows on ~5 evenly spaced steps along the path.
    for i in range(0, len(path_x) - 1, max(1, len(path_x) // 5)):
        fig.add_annotation(
            x=path_x[i + 1], y=path_y[i + 1],
            ax=path_x[i], ay=path_y[i],
            xref='x', yref='y', axref='x', ayref='y',
            showarrow=True, arrowhead=2, arrowsize=1.5,
            arrowcolor=COLORS['warning'],
        )
    # Start / end / true-minimum markers.
    fig.add_trace(go.Scatter(
        x=[path_x[0]], y=[path_y[0]],
        mode='markers+text',
        marker=dict(size=15, color='red', symbol='star'),
        text=['START'], textposition='top right',
        name='Start',
    ))
    fig.add_trace(go.Scatter(
        x=[path_x[-1]], y=[path_y[-1]],
        mode='markers+text',
        marker=dict(size=15, color='lime', symbol='star'),
        text=['END'], textposition='top right',
        name='End',
    ))
    fig.add_trace(go.Scatter(
        x=[0], y=[0],
        mode='markers+text',
        marker=dict(size=20, color='white', symbol='x'),
        text=['MINIMUM'], textposition='bottom center',
        name='Global Minimum',
    ))
    fig.update_layout(
        title=dict(text=f"{title}<br>Bird's eye view - contour lines show equal loss"),
        # scaleanchor keeps x and y at equal scale so circles look circular.
        xaxis=dict(title='θ₁', scaleanchor='y'),
        yaxis=dict(title='θ₂'),
        **LAYOUT_DEFAULTS,
        height=500,
    )
    return fig
def loss_curve(config: dict) -> go.Figure:
    """
    Loss-vs-step curve for the same gradient-descent run as the other
    gradient_descent_* views (J = x² + y², gradient (2x, 2y)).

    Config keys (all optional): title, learning_rate/lr, start_x, start_y,
    n_steps/iterations, seed.
    """
    title = config.get('title', 'Loss Over Training Steps')
    lr = float(config.get('learning_rate') or config.get('lr') or 0.1)
    start_x = float(config.get('start_x') or 2.0)
    start_y = float(config.get('start_y') or 2.0)
    n_steps = int(config.get('n_steps') or config.get('iterations') or 30)
    seed = config.get('seed', 42)
    np.random.seed(seed)  # deterministic computation; kept for interface consistency

    # Recompute the loss at every step (no early stop here so the full
    # requested horizon is shown).
    px, py = start_x, start_y
    losses = [px**2 + py**2]
    for _ in range(n_steps):
        gx, gy = 2 * px, 2 * py
        px = px - lr * gx
        py = py - lr * gy
        losses.append(px**2 + py**2)
    steps = list(range(len(losses)))

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=steps, y=losses,
        mode='lines+markers',
        marker=dict(size=8, color=COLORS['primary']),
        line=dict(color=COLORS['primary'], width=3),
        fill='tozeroy',
        fillcolor='rgba(102, 126, 234, 0.2)',
        name='Training Loss',
        hovertemplate='Step %{x}<br>Loss: %{y:.4f}',
    ))
    # Annotate the first and last loss values.
    fig.add_annotation(
        x=0, y=losses[0],
        text=f"Starting Loss: {losses[0]:.2f}",
        showarrow=True, arrowhead=2,
        font=dict(size=11), ax=50, ay=-30,
    )
    fig.add_annotation(
        x=len(losses) - 1, y=losses[-1],
        text=f"Final Loss: {losses[-1]:.4f}",
        showarrow=True, arrowhead=2,
        font=dict(size=11, color='lime'), ax=-50, ay=-30,
    )
    # Verdict: converged vs. barely improved (heuristic thresholds).
    if losses[-1] < 0.01:
        fig.add_hline(y=0.01, line_dash="dash", line_color="lime",
                      annotation_text="✅ Converged!", annotation_position="right")
    elif losses[-1] > losses[0] * 0.9:
        fig.add_annotation(
            x=0.5, y=0.5, xref='paper', yref='paper',
            text="⚠️ Not converging well - try adjusting learning rate",
            font=dict(size=14, color='orange'),
            showarrow=False, bgcolor='rgba(0,0,0,0.7)', borderpad=10,
        )
    fig.update_layout(
        title=dict(text=f"{title}<br>Watch the loss decrease as we optimize"),
        xaxis=dict(title='Training Step'),
        # Log scale when the loss spans large magnitudes (e.g. divergent runs).
        yaxis=dict(title='Loss', type='log' if max(losses) > 100 else 'linear'),
        **LAYOUT_DEFAULTS,
        height=400,
    )
    return fig
def flow_diagram(config: dict) -> go.Figure:
    """
    Neural-network architecture diagram with educational annotations.

    Config keys (all optional):
        title: Figure title.
        layers: list where each entry is either an int (neuron count; names
            are auto-assigned Input/Hidden i/Output) or a dict with
            'name', 'nodes', and optional 'color'. Defaults to [3, 4, 4, 2].

    Note: connection line weights are random and unseeded — purely decorative.
    """
    title = config.get('title', 'Neural Network Architecture')
    raw_layers = config.get('layers', [3, 4, 4, 2])

    # Normalize every entry to {'name', 'nodes', 'color'}; unknown entry
    # types are skipped silently.
    layer_colors = [COLORS['info'], COLORS['primary'], COLORS['secondary'], COLORS['accent'], COLORS['success']]
    layers = []
    for i, layer in enumerate(raw_layers):
        if isinstance(layer, int):
            name = 'Input' if i == 0 else ('Output' if i == len(raw_layers) - 1 else f'Hidden {i}')
            layers.append({'name': name, 'nodes': layer, 'color': layer_colors[i % len(layer_colors)]})
        elif isinstance(layer, dict):
            layers.append({
                'name': layer.get('name', f'Layer {i+1}'),
                'nodes': layer.get('nodes', 3),
                'color': layer.get('color', layer_colors[i % len(layer_colors)]),
            })

    fig = go.Figure()
    layer_x = np.linspace(0.1, 0.9, len(layers))

    # Draw all-to-all connections between adjacent layers; line opacity/width
    # encode a random decorative "weight".
    for i in range(len(layers) - 1):
        n1, n2 = layers[i]['nodes'], layers[i + 1]['nodes']
        y1 = np.linspace(0.2, 0.8, n1)
        y2 = np.linspace(0.2, 0.8, n2)
        for y1_pos in y1:
            for y2_pos in y2:
                weight = np.random.uniform(0.2, 1.0)
                fig.add_trace(go.Scatter(
                    x=[layer_x[i], layer_x[i + 1]], y=[y1_pos, y2_pos],
                    mode='lines',
                    line=dict(color=f'rgba(102, 126, 234, {weight * 0.5})', width=weight * 2),
                    hoverinfo='skip', showlegend=False,
                ))

    # Draw neurons and a label under each layer.
    for layer, x_pos in zip(layers, layer_x):
        n = layer['nodes']
        y_positions = np.linspace(0.2, 0.8, n)
        fig.add_trace(go.Scatter(
            x=[x_pos] * n, y=y_positions,
            mode='markers+text',
            marker=dict(size=35, color=layer['color'], line=dict(color='white', width=2)),
            text=[str(j + 1) for j in range(n)],
            textposition='middle center',
            textfont=dict(color='white', size=11),
            name=layer['name'],
            hovertemplate=f"{layer['name']}<br>Neuron %{{text}}",
        ))
        fig.add_annotation(
            x=x_pos, y=-0.05, text=f"{layer['name']}<br>({n} neurons)",
            showarrow=False, font=dict(size=12, color=layer['color']),
        )

    # Educational callouts above the input, output, and middle layers.
    fig.add_annotation(
        x=layer_x[0], y=0.95,
        text="📥 Input<br>Your data goes here",
        showarrow=False, font=dict(size=10, color=COLORS['info']),
    )
    fig.add_annotation(
        x=layer_x[-1], y=0.95,
        text="📤 Output<br>Predictions come out here",
        showarrow=False, font=dict(size=10, color=COLORS['success']),
    )
    if len(layers) > 2:
        mid = len(layers) // 2
        fig.add_annotation(
            x=layer_x[mid], y=0.95,
            text="🧠 Hidden Layers<br>Learn patterns",
            showarrow=False, font=dict(size=10, color=COLORS['primary']),
        )
    fig.update_layout(
        title=dict(text=f"{title}<br>Data flows left → right through connected neurons"),
        xaxis=dict(visible=False, range=[-0.05, 1.05]),
        yaxis=dict(visible=False, range=[-0.15, 1.05]),
        **LAYOUT_DEFAULTS,
        height=500,
        showlegend=False,
    )
    return fig
def matrix_heatmap(config: dict) -> go.Figure:
    """
    PURE RENDERER - Matrix heatmap for attention weights, feature maps,
    confusion matrices.

    REQUIRES from the LLM config:
        labels: list of row/column labels (same list is used for both axes,
            truncated to each dimension of `values`).
        values: 2D array of numbers.
        x_title / y_title: axis titles (e.g. "Keys" / "Queries" for attention).

    INTERACTIVE PARAMS:
        focus_row: 1-indexed row to highlight (other rows are dimmed).
        threshold: hide (zero out) values below this cutoff.

    Does NOT generate data internally — missing/invalid data yields a
    placeholder error figure instead of misleading defaults.
    """
    title = config.get('title', 'Matrix Visualization')
    subtitle = config.get('subtitle', 'Brighter = higher value')
    colorbar_title = config.get('colorbar_title', 'Value')

    labels = config.get('labels')
    values = config.get('values')
    x_title = config.get('x_title')
    y_title = config.get('y_title')

    # Required data missing → explicit error figure.
    if not labels or not values or not x_title or not y_title:
        return create_placeholder_figure(
            "Missing Data for Matrix Heatmap",
            "LLM must provide: labels, values (2D array), x_title, y_title.<br>"
            "Example for Attention:<br>"
            "labels=['The', 'cat', 'sat'], x_title='Keys', y_title='Queries'<br>"
            "values=[[0.5, 0.3, 0.2], [0.1, 0.7, 0.2], [0.2, 0.2, 0.6]]"
        )

    # Coerce to a 2D float array; anything else is an error figure.
    try:
        data = np.array(values, dtype=float)
        if data.ndim != 2:
            raise ValueError("values must be 2D")
    except (TypeError, ValueError) as e:
        return create_placeholder_figure(
            "Invalid Data",
            f"'values' must be a 2D array of numbers.<br>Error: {e}"
        )

    # INTERACTIVE: dim every row except the focused one (1-indexed input).
    focus_row = config.get('focus_row')
    if focus_row is not None:
        focus_idx = int(focus_row) - 1
        if 0 <= focus_idx < data.shape[0]:
            mask = np.ones_like(data) * 0.3
            mask[focus_idx, :] = 1.0
            data = data * mask
            subtitle = f'Focusing on row {focus_row}: "{labels[focus_idx]}"'

    # INTERACTIVE: zero out values below the threshold.
    threshold = config.get('threshold')
    if threshold is not None:
        threshold = float(threshold)
        data = np.where(data >= threshold, data, 0)
        subtitle = f'Showing values ≥ {threshold:.2f}'

    hover_template = config.get('hover_template', '%{y} → %{x}<br>Value: %{z:.3f}')
    fig = go.Figure()
    fig.add_trace(go.Heatmap(
        # One label list serves both axes, truncated to each dimension.
        z=data, x=labels[:data.shape[1]], y=labels[:data.shape[0]],
        colorscale='Viridis',
        hovertemplate=hover_template,
        colorbar=dict(title=dict(text=colorbar_title, side='right'), thickness=15),
    ))

    # Red border around the focused row.
    if focus_row is not None and 0 <= int(focus_row) - 1 < data.shape[0]:
        focus_idx = int(focus_row) - 1
        fig.add_shape(
            type='rect',
            x0=-0.5, x1=data.shape[1] - 0.5,
            y0=focus_idx - 0.5, y1=focus_idx + 0.5,
            line=dict(color='#ff6b6b', width=3),
        )

    # Mark strong diagonal cells on square matrices (only when not focusing).
    if focus_row is None and data.shape[0] == data.shape[1]:
        for i in range(min(data.shape[0], data.shape[1], len(labels))):
            fig.add_annotation(
                x=i, y=i,
                text="✓" if data[i, i] > 0.2 else "",
                showarrow=False, font=dict(size=8, color='red'),
            )

    fig.update_layout(
        title=dict(text=f"{title}<br>{subtitle}"),
        xaxis=dict(title=x_title, tickangle=45),
        # Reversed y so row 0 appears at the top, matrix-style.
        yaxis=dict(title=y_title, autorange='reversed'),
        **LAYOUT_DEFAULTS,
    )
    return fig
def distribution_plot(config: dict) -> go.Figure:
    """
    PURE RENDERER - Probability distribution bar chart.

    REQUIRES from the LLM config:
        categories: list of labels.
        values: list of probabilities (should sum to ~1.0).

    INTERACTIVE PARAMS:
        temperature: softmax temperature re-scaling (<1 sharper, >1 smoother).

    Does NOT generate data internally — missing/invalid data yields a
    placeholder error figure instead of misleading defaults.
    """
    title = config.get('title', 'Probability Distribution')
    subtitle = config.get('subtitle', 'Output probabilities')
    x_title = config.get('x_title', 'Category')

    categories = config.get('categories')
    values = config.get('values')
    if not categories or not values:
        return create_placeholder_figure(
            "Missing Data for Distribution Plot",
            "LLM must provide 'categories' and 'values'.<br>"
            "Example: categories=['Token 1', 'Token 2'], values=[0.6, 0.4]"
        )
    try:
        values = [float(v) for v in values]
    except (TypeError, ValueError):
        return create_placeholder_figure(
            "Invalid Data",
            f"'values' must be a list of numbers, got: {type(values)}"
        )

    # INTERACTIVE: re-sharpen/flatten the distribution with a softmax at the
    # given temperature (log → scale → softmax, numerically stabilized).
    temperature = float(config.get('temperature', 1.0))
    if temperature != 1.0 and temperature > 0:
        log_values = np.log(np.array(values) + 1e-10)  # epsilon guards log(0)
        scaled = log_values / temperature
        exp_values = np.exp(scaled - np.max(scaled))  # subtract max for stability
        values = (exp_values / exp_values.sum()).tolist()
        subtitle = f'Temperature = {temperature:.1f} ({"sharper" if temperature < 1 else "smoother"})'

    # Sort bars by probability, largest first.
    sorted_pairs = sorted(zip(categories, values), key=lambda pair: pair[1], reverse=True)
    categories = [p[0] for p in sorted_pairs]
    values = [p[1] for p in sorted_pairs]
    colors = [COLORS['gradient'][i % len(COLORS['gradient'])] for i in range(len(categories))]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=categories, y=values,
        marker=dict(color=colors, line=dict(color='white', width=1)),
        text=[f'{v:.1%}' for v in values],
        textposition='outside',
        textfont=dict(size=14, color='white'),
        hovertemplate='%{x}<br>Probability: %{y:.2%}',
    ))
    # Call out the most likely category (first after the descending sort).
    if values:
        fig.add_annotation(
            x=categories[0], y=values[0] + 0.05,
            text="🏆 Highest",
            showarrow=False, font=dict(size=12, color='lime'),
        )
    fig.update_layout(
        title=dict(text=f"{title}<br>{subtitle}"),
        xaxis=dict(title=x_title),
        yaxis=dict(title='Probability', tickformat='.0%', range=[0, max(values) * 1.3] if values else [0, 1]),
        **LAYOUT_DEFAULTS,
    )
    return fig
def decision_boundary(config: dict) -> go.Figure:
    """
    Classification decision-boundary plot for a synthetic 2-class dataset.

    Config keys (all optional):
        title: Figure title.
        model_type: 'linear' (default) or 'circular' boundary.
        n_points: total number of data points (default 200).
        seed: RNG seed (default 42).
    """
    title = config.get('title', 'Decision Boundary')
    model_type = config.get('model_type', 'linear')
    n_points = int(config.get('n_points', 200))
    seed = config.get('seed', 42)
    np.random.seed(seed)

    # Generate synthetic data matching the chosen boundary shape.
    if model_type == 'circular':
        # Inner and outer rings. Split sizes sum to exactly n_points so the
        # radii array matches theta's length (odd n_points previously crashed
        # on a shape mismatch).
        half = n_points // 2
        r1 = np.random.randn(half) * 0.5 + 1
        r2 = np.random.randn(n_points - half) * 0.5 + 3
        theta = np.random.rand(n_points) * 2 * np.pi
        r = np.concatenate([r1, r2])
        X = np.column_stack([r * np.cos(theta), r * np.sin(theta)])
        y = (r < 2).astype(int)
    else:  # linear
        X = np.random.randn(n_points, 2) * 2
        y = (X[:, 0] + X[:, 1] > 0).astype(int)

    fig = go.Figure()

    # Shade the two decision regions on a dense grid.
    xx, yy = np.meshgrid(np.linspace(X[:, 0].min() - 1, X[:, 0].max() + 1, 100),
                         np.linspace(X[:, 1].min() - 1, X[:, 1].max() + 1, 100))
    if model_type == 'circular':
        Z = (np.sqrt(xx**2 + yy**2) < 2).astype(float)
    else:
        Z = (xx + yy > 0).astype(float)
    fig.add_trace(go.Contour(
        x=np.linspace(X[:, 0].min() - 1, X[:, 0].max() + 1, 100),
        y=np.linspace(X[:, 1].min() - 1, X[:, 1].max() + 1, 100),
        z=Z,
        colorscale=[[0, 'rgba(250,112,154,0.3)'], [1, 'rgba(79,172,254,0.3)']],
        showscale=False, contours=dict(showlines=True, coloring='fill'),
        hoverinfo='skip',
    ))

    # Scatter the two classes on top of the shaded regions.
    for label, color, name in [(0, COLORS['warning'], 'Class A'), (1, COLORS['info'], 'Class B')]:
        mask = y == label
        fig.add_trace(go.Scatter(
            x=X[mask, 0], y=X[mask, 1],
            mode='markers',
            marker=dict(size=10, color=color, line=dict(color='white', width=1)),
            name=name,
            hovertemplate=f'{name}<br>(%{{x:.1f}}, %{{y:.1f}})',
        ))

    # Explain what the boundary means.
    fig.add_annotation(
        x=0, y=0,
        text="← Decision Boundary<br>Model classifies differently<br>on each side",
        showarrow=True, arrowhead=2,
        font=dict(size=11, color='white'),
        ax=100, ay=-50,
    )
    fig.update_layout(
        title=dict(text=f"{title}<br>Shaded regions show model's classification"),
        xaxis=dict(title='Feature 1'),
        yaxis=dict(title='Feature 2'),
        legend=dict(orientation='h', y=1.1),
        **LAYOUT_DEFAULTS,
    )
    return fig
def line_progression(config: dict) -> go.Figure:
    """
    Synthetic training/validation loss curves with an overfitting callout.

    Config keys (all optional): title, epochs (default 50), seed (default 42).
    """
    title = config.get('title', 'Training Progress')
    epochs = int(config.get('epochs', 50))
    seed = config.get('seed', 42)
    np.random.seed(seed)

    x = np.arange(1, epochs + 1)
    # Exponentially decaying curves with noise; val decays slower and sits
    # slightly above train, as is typical.
    train_loss = 2.0 * np.exp(-0.08 * x) + 0.1 + np.random.randn(epochs) * 0.03
    val_loss = 2.2 * np.exp(-0.06 * x) + 0.15 + np.random.randn(epochs) * 0.05

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=x, y=train_loss, mode='lines',
        name='Training Loss', line=dict(color=COLORS['primary'], width=3),
        hovertemplate='Epoch %{x}<br>Train Loss: %{y:.4f}',
    ))
    fig.add_trace(go.Scatter(
        x=x, y=val_loss, mode='lines',
        name='Validation Loss', line=dict(color=COLORS['warning'], width=3, dash='dash'),
        hovertemplate='Epoch %{x}<br>Val Loss: %{y:.4f}',
    ))
    # Flag the back half as an "overfitting zone" when val loss ends higher
    # than it was at the midpoint.
    if val_loss[-1] > val_loss[epochs // 2]:
        overfit_start = epochs // 2
        fig.add_vrect(x0=overfit_start, x1=epochs, fillcolor='red', opacity=0.1)
        fig.add_annotation(
            x=overfit_start + 10, y=val_loss.max(),
            text="⚠️ Overfitting zone<br>Val loss increasing",
            showarrow=False, font=dict(color='red', size=11),
        )
    fig.update_layout(
        title=dict(text=f"{title}<br>Lower is better - watch for val loss going up"),
        xaxis=dict(title='Epoch'),
        yaxis=dict(title='Loss'),
        legend=dict(orientation='h', y=1.1),
        hovermode='x unified',
        **LAYOUT_DEFAULTS,
    )
    return fig
def comparison_bars(config: dict) -> go.Figure:
    """
    Grouped bar chart comparing models across metrics.

    Config keys (all optional):
        title: Figure title.
        categories: metric labels (default Accuracy/Precision/Recall/F1).
        groups: list of {'name', 'values', optional 'color'} dicts; falls
            back to two illustrative models when absent (previously the
            groups were hard-coded and not configurable).
        seed: accepted for interface consistency (no randomness used).
    """
    title = config.get('title', 'Model Comparison')
    categories = config.get('categories', ['Accuracy', 'Precision', 'Recall', 'F1'])
    seed = config.get('seed', 42)
    np.random.seed(seed)

    default_groups = [
        {'name': 'Model A', 'values': [0.92, 0.89, 0.94, 0.91], 'color': COLORS['primary']},
        {'name': 'Model B', 'values': [0.88, 0.91, 0.85, 0.88], 'color': COLORS['warning']},
    ]
    groups = config.get('groups') or default_groups

    fig = go.Figure()
    for i, group in enumerate(groups):
        # Fall back to the gradient palette when a caller-supplied group
        # omits its color.
        color = group.get('color', COLORS['gradient'][i % len(COLORS['gradient'])])
        fig.add_trace(go.Bar(
            x=categories, y=group['values'],
            name=group['name'],
            marker=dict(color=color, line=dict(color='white', width=1)),
            text=[f'{v:.0%}' for v in group['values']],
            textposition='outside',
            hovertemplate=f"{group['name']}<br>%{{x}}: %{{y:.1%}}",
        ))
    fig.update_layout(
        title=dict(text=f"{title}"),
        barmode='group',
        yaxis=dict(tickformat='.0%', range=[0, 1.15]),
        legend=dict(orientation='h', y=1.1),
        **LAYOUT_DEFAULTS,
    )
    return fig
# Trace types accepted by validate_trace(); mirrors Plotly's built-in set.
VALID_TRACE_TYPES = {
    'scatter', 'scatter3d', 'scattergl', 'scatterpolar', 'scattergeo',
    'bar', 'histogram', 'histogram2d', 'box', 'violin',
    'heatmap', 'contour', 'surface', 'mesh3d',
    'pie', 'sunburst', 'treemap', 'sankey', 'funnel',
    'indicator', 'table', 'carpet', 'cone', 'streamtube',
    'isosurface', 'volume', 'image', 'candlestick', 'ohlc'
}

def validate_trace(trace: dict) -> bool:
    """Return True when *trace* is a dict whose 'type' (default 'scatter')
    names a known Plotly trace type."""
    if not isinstance(trace, dict):
        return False
    return trace.get('type', 'scatter') in VALID_TRACE_TYPES
def custom_plotly(config: dict) -> go.Figure:
    """
    Render an arbitrary Plotly visualization from a JSON spec.

    Lets the LLM generate any visualization by supplying the full Plotly
    JSON specification. Invalid or missing data falls back to a default
    scatter_cluster figure rather than raising.

    Config keys:
        data: list of trace dicts (required; each validated via validate_trace).
        layout: layout dict (optional; only whitelisted keys are merged).
        title: override title (optional).
        template: name of a pre-defined template to merge in (optional;
            loaded from the project-local `templates` module).

    Security: this does NOT execute code — it only parses JSON into Plotly
    objects.
    """
    title = config.get('title', 'Visualization')
    data = config.get('data', [])
    layout = config.get('layout', {})
    template_name = config.get('template')
    # Load template if specified; template data/annotations are prepended,
    # while the caller's layout keys win over template layout keys.
    if template_name:
        try:
            from templates import get_template  # project-local module; may be absent
            template = get_template(template_name)
            if template:
                # Merge template data with provided data
                data = template.get('base_data', []) + data
                # Merge template layout
                template_layout = template.get('layout', {})
                layout = {**template_layout, **layout}
                # Add template annotations
                if 'annotations' in template:
                    layout['annotations'] = layout.get('annotations', []) + template['annotations']
        except ImportError:
            # Best-effort: proceed without the template rather than failing.
            print(f"Templates module not found, skipping template: {template_name}")
    # Validate data structure; fall back to a default figure on bad input.
    if not data or not isinstance(data, list):
        print("custom_plotly: No valid data provided, falling back to default")
        return scatter_cluster({'title': title, 'n_clusters': 3})
    # Keep only traces whose 'type' is a known Plotly trace type.
    valid_traces = []
    for i, trace in enumerate(data):
        if validate_trace(trace):
            valid_traces.append(trace)
        else:
            print(f"custom_plotly: Skipping invalid trace at index {i}: {type(trace)}")
    if not valid_traces:
        print("custom_plotly: No valid traces found, falling back to default")
        return scatter_cluster({'title': title, 'n_clusters': 3})
    # Limit trace count to prevent memory issues
    MAX_TRACES = 50
    if len(valid_traces) > MAX_TRACES:
        print(f"custom_plotly: Limiting traces from {len(valid_traces)} to {MAX_TRACES}")
        valid_traces = valid_traces[:MAX_TRACES]
    # Create figure from validated data; Plotly may still reject malformed
    # trace attributes, hence the broad fallback.
    try:
        fig = go.Figure(data=valid_traces)
    except Exception as e:
        print(f"custom_plotly: Error creating figure: {e}")
        return scatter_cluster({'title': title, 'n_clusters': 3})
    # Merge layout with defaults; only whitelisted keys from the caller's
    # layout are applied, so defaults like template/colors stay intact.
    merged_layout = {**LAYOUT_DEFAULTS}
    safe_layout_keys = [
        'title', 'xaxis', 'yaxis', 'zaxis', 'showlegend', 'legend',
        'annotations', 'shapes', 'images', 'height', 'width',
        'scene', 'geo', 'mapbox', 'polar', 'ternary',
        'coloraxis', 'hovermode', 'dragmode', 'barmode', 'bargap'
    ]
    for key in safe_layout_keys:
        if key in layout:
            merged_layout[key] = layout[key]
    # NOTE(review): `title` defaults to 'Visualization', so this branch is
    # always taken and overrides any layout['title'] merged above — confirm
    # whether spec-provided titles should win instead.
    if title:
        merged_layout['title'] = dict(
            text=f"{title}",
            font=dict(size=18)
        )
    fig.update_layout(**merged_layout)
    return fig
# Registry mapping component-type names (as received by render_component)
# to their figure-builder functions. Each builder takes a config dict and
# returns a go.Figure.
COMPONENTS = {
    'scatter_cluster': scatter_cluster,
    'cluster_distribution': cluster_distribution,
    'gradient_descent_3d': gradient_descent_3d,
    'gradient_descent_2d': gradient_descent_2d,
    'loss_curve': loss_curve,
    'flow_diagram': flow_diagram,
    'matrix_heatmap': matrix_heatmap,
    'distribution_plot': distribution_plot,
    'decision_boundary': decision_boundary,
    'line_progression': line_progression,
    'comparison_bars': comparison_bars,
    'custom_plotly': custom_plotly,
}
def render_component(component_type: str, config: dict) -> go.Figure:
    """Render a component by type, falling back to scatter_cluster for
    unknown component types."""
    renderer = COMPONENTS.get(component_type)
    if renderer is None:
        # Fallback to scatter_cluster for unknown types
        print(f"Unknown component {component_type}, using scatter_cluster")
        return scatter_cluster(config)
    return renderer(config)