# ============================================================
# Chest X-ray Doctor Interface with Transformer Attention Maps
# ============================================================
import os, cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms, models
from transformers import ViTModel
import gradio as gr
import google.generativeai as genai
from datetime import datetime
import re
import json
# ------------------------------------------------------------
# Device
# ------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# ------------------------------------------------------------
# Google Gemini Configuration
# ------------------------------------------------------------
# SECURITY FIX: never commit API keys to source. The key is now read from
# the environment; the else-branch below already handled the missing case.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
GEMINI_MODEL_NAME = 'gemini-3-pro-preview'  # Pro model for better visual analysis
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
    gemini_model = genai.GenerativeModel(GEMINI_MODEL_NAME)
    # Report the model actually configured (the old message named a different one).
    print(f"Google Gemini API configured successfully ({GEMINI_MODEL_NAME} with vision)")
else:
    gemini_model = None
    print("Warning: GEMINI_API_KEY not found. Gemini analysis will be disabled.")
    print("Set it using: export GEMINI_API_KEY='your-api-key' (Linux/Mac)")
    print("Or: set GEMINI_API_KEY=your-api-key (Windows)")
# ------------------------------------------------------------
# Labels (Primary Diseases to Consider)
# ------------------------------------------------------------
# Ordered list of the disease classes the classifiers were trained on.
label_columns = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Effusion',
    'Emphysema',
    'Fibrosis',
    'Hernia',
    'Infiltration',
]
# Note: Co-morbidities and conditions outside this list can also occur
# ------------------------------------------------------------
# Bio-ViL Model (for predictions)
# ------------------------------------------------------------
class BioViLChestXRayModel(nn.Module):
    """Bio-ViL vision transformer with an MLP head for chest X-ray classification."""

    def __init__(self, num_classes=14, dropout_rate=0.3, model_dir="model"):
        super().__init__()
        print(f"Loading Bio-ViL vision model from: {model_dir}")
        self.vision_model = ViTModel.from_pretrained(
            model_dir,
            torch_dtype=torch.float32,
            ignore_mismatched_sizes=True
        )
        # Transformer embedding width; ViT-base defaults to 768.
        self.feature_dim = getattr(self.vision_model.config, 'hidden_size', 768)
        print(f"Model feature dimension: {self.feature_dim}")
        # Three-stage MLP classification head with batch norm and dropout.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate), nn.Linear(self.feature_dim, 512),
            nn.BatchNorm1d(512), nn.ReLU(),
            nn.Dropout(dropout_rate), nn.Linear(512, 256),
            nn.BatchNorm1d(256), nn.ReLU(),
            nn.Dropout(dropout_rate), nn.Linear(256, num_classes)
        )
        # Self-attention over the pooled feature (disease-specific weighting).
        self.disease_attention = nn.MultiheadAttention(
            embed_dim=self.feature_dim,
            num_heads=8,
            dropout=dropout_rate,
            batch_first=True
        )

    def forward(self, pixel_values):
        """Return classification logits for a (B, 3, H, W) or single (3, H, W) tensor."""
        # Promote a single image to a batch of one.
        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)
        outputs = self.vision_model(pixel_values, output_hidden_states=True)
        pooled = getattr(outputs, 'pooler_output', None)
        # Fall back to the CLS token when the backbone exposes no pooled output.
        features = pooled if pooled is not None else outputs.last_hidden_state[:, 0, :]
        # Single-token self-attention pass before classification.
        token = features.unsqueeze(1)
        attended, _ = self.disease_attention(token, token, token)
        return self.classifier(attended.squeeze(1))
# ------------------------------------------------------------
# Hybrid CNN + ViT Model (for attention visualization)
# ------------------------------------------------------------
class CNNViTChestXRayModel(nn.Module):
    """Hybrid ResNet-50 + ViT classifier whose transformer attention maps
    are used for visualization (the logits are unused by this interface)."""

    def __init__(self, num_classes=14, dropout_rate=0.3):
        super().__init__()
        # Initialize ResNet-50 for CNN features
        resnet = models.resnet50(pretrained=True)
        self.cnn = nn.Sequential(*list(resnet.children())[:-2]) # Take layers up to the last pooling
        # Initialize ViT model from HuggingFace
        # We enable output_attentions=True to retrieve attention maps
        self.vit = ViTModel.from_pretrained(
            "google/vit-base-patch16-224-in21k",
            torch_dtype=torch.float32,
            output_attentions=True, # <-- enable attention maps
            ignore_mismatched_sizes=True
        )
        self.feature_dim = self.vit.config.hidden_size # 768 for ViT-base
        # 1x1 Conv to project ResNet's 2048 channels to ViT's feature dimension (768)
        self.proj = nn.Conv2d(2048, self.feature_dim, kernel_size=1)
        # Multi-head attention layer for disease-specific feature aggregation
        self.disease_attention = nn.MultiheadAttention(
            embed_dim=self.feature_dim, num_heads=8,
            dropout=dropout_rate, batch_first=True
        )
        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(self.feature_dim, 512), nn.BatchNorm1d(512), nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes)
        )
        # Human-readable architecture name (used for display/logging elsewhere).
        self.model_name = "ResNet50 + ViT Hybrid"

    def forward(self, pixel_values):
        """Return (logits, attentions, projected_cnn_features) for the input batch."""
        # Handle single image input by adding batch dimension
        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)
        # 1. CNN Feature Extraction (ResNet-50)
        cnn_feats = self.cnn(pixel_values) # Output: B, 2048, H/32, W/32
        # 2. Projection (1x1 Conv)
        feats = self.proj(cnn_feats) # Output: B, 768, H/32, W/32 (e.g., B, 768, 7, 7 for 224x224 input)
        B, C, H, W = feats.shape
        # 3. Flatten and Transpose to create ViT-like tokens (excluding CLS token)
        tokens = feats.flatten(2).transpose(1, 2) # Output: B, 49, 768 (49 tokens for 7x7 map)
        # 4. Add CLS Token
        cls_token = self.vit.embeddings.cls_token.expand(B, -1, -1)
        tokens = torch.cat((cls_token, tokens), dim=1) # Output: B, 50, 768
        # 5. Add Position Embeddings
        # NOTE(review): reuses the pretrained ViT position embeddings truncated to
        # the CNN token count — assumes token count <= the ViT's embedding length
        # (holds for 224x224 inputs, which yield 50 tokens); confirm for other sizes.
        pos_emb = self.vit.embeddings.position_embeddings[:, :tokens.size(1), :].to(tokens.device)
        tokens = tokens + pos_emb
        # 6. ViT Encoder Pass
        vit_out = self.vit.encoder(tokens, output_attentions=True)
        hidden_states = vit_out.last_hidden_state
        attentions = vit_out.attentions # list of attention maps from all layers
        # 7. Extract CLS token feature
        features = hidden_states[:, 0, :] # Output: B, 768
        # 8. Disease-specific Attention (Self-Attention on the CLS token)
        features_expanded = features.unsqueeze(1) # B, 1, 768 (Query/Key/Value)
        attended, _ = self.disease_attention(features_expanded, features_expanded, features_expanded)
        features = attended.squeeze(1) # Output: B, 768 (Final aggregated feature)
        # 9. Classification
        logits = self.classifier(features)
        # Return projected CNN features (feats) for CNN activation dashboard
        return logits, attentions, feats
# ------------------------------------------------------------
# Image Preprocessing
# ------------------------------------------------------------
def preprocess_medical_image(img):
    """Enhance a chest X-ray: CLAHE contrast, normalization, light denoising.

    Accepts an RGB or grayscale uint8 array and returns a 3-channel RGB
    uint8 array suitable for the downstream torchvision transforms.
    """
    # Collapse RGB input to a single luminance channel first.
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # Contrast-limited adaptive histogram equalization.
    enhancer = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    img = enhancer.apply(img)
    # Stretch intensities to cover the full 0-255 range.
    img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX)
    # Mild Gaussian smoothing to suppress sensor noise.
    img = cv2.GaussianBlur(img, (3, 3), 0.5)
    # The model expects three channels, so replicate the gray plane.
    return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
# ------------------------------------------------------------
# Transforms
# ------------------------------------------------------------
# ImageNet normalization statistics (both backbones are ImageNet-pretrained).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Evaluation pipeline: fixed 224x224 resize, tensor conversion, normalization.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
# ------------------------------------------------------------
# Visualization Helper: ViT Attention (from original code)
# ------------------------------------------------------------
def get_vit_blended_image(image_tensor, attention_maps):
    """
    Blend the input image with a heatmap built from the ViT CLS token's
    attention over the patch tokens.

    Args:
        image_tensor: normalized image tensor of shape (3, H, W).
        attention_maps: per-layer attention tensors, each of shape
            (B, heads, tokens, tokens); the last layer is used.

    Returns:
        H x W x 3 uint8 RGB array of the image blended with the heatmap.
    """
    # Denormalize image for display (inverse of the ImageNet normalization).
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = image_tensor.permute(1, 2, 0).cpu().numpy()  # C, H, W -> H, W, C
    img = img * std + mean
    # Guard the denominator: a constant image would otherwise divide by zero.
    img = (img - img.min()) / max(img.max() - img.min(), 1e-8)
    img = (img * 255).astype(np.uint8)
    # CLS token's attention to the patch tokens, last layer, averaged over heads.
    cls_to_patches_attn = attention_maps[-1].mean(dim=1)[0, 0, 1:].detach().cpu().numpy()
    # Reshape to the square patch grid (7x7 = 49 patches for a 224x224 input).
    # Generalized: any patch count works by truncating to the largest square.
    side = int(np.sqrt(cls_to_patches_attn.size))
    attn = cls_to_patches_attn[:side * side].reshape(side, side)
    # Upsample the coarse attention grid to the image resolution.
    attn = cv2.resize(attn, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LINEAR)
    # Normalize to [0, 1]; guard against a perfectly uniform attention map.
    attn = (attn - attn.min()) / max(attn.max() - attn.min(), 1e-8)
    attn_scaled = np.uint8(255 * attn)
    # Create heatmap using VIRIDIS colormap for contrast
    heatmap = cv2.applyColorMap(attn_scaled, cv2.COLORMAP_VIRIDIS)
    # Blend image and heatmap (40% image, 60% heatmap), then back to RGB.
    img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    blended = cv2.addWeighted(img_bgr, 0.4, heatmap, 0.6, 0)
    return cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)
# ------------------------------------------------------------
# Load Models
# ------------------------------------------------------------
BIOVIL_MODEL_PATH = "best_biovil_hf_model.pth" # Bio-ViL model for predictions
# Try to find the ViT model directory - common locations
BIOVIL_MODEL_DIR = None
for possible_dir in ["model", "vit-model", "microsoft/BiomedVLP-CXR-BERT-specialized"]:
if os.path.exists(possible_dir) or "/" in possible_dir:
BIOVIL_MODEL_DIR = possible_dir
break
# Fallback to Hugging Face model if no local model found
if BIOVIL_MODEL_DIR is None:
BIOVIL_MODEL_DIR = "google/vit-base-patch16-224-in21k"
ATTENTION_MODEL_PATH = "best_cnn_vit.pth" # CNN+ViT model for attention visualization
biovil_model = None
attention_model = None
def load_biovil_model():
    """Lazily construct the Bio-ViL prediction model and load its weights.

    The instance is cached in the module-level ``biovil_model`` global, so
    repeated calls return the same object.
    """
    global biovil_model
    if biovil_model is not None:
        return biovil_model
    print(f"Loading Bio-ViL model with ViT from: {BIOVIL_MODEL_DIR}")
    model = BioViLChestXRayModel(
        num_classes=len(label_columns),
        model_dir=BIOVIL_MODEL_DIR
    ).to(device)
    biovil_model = model
    if not os.path.exists(BIOVIL_MODEL_PATH):
        print(f"Warning: Bio-ViL checkpoint not found at {BIOVIL_MODEL_PATH}")
    else:
        try:
            checkpoint = torch.load(
                BIOVIL_MODEL_PATH,
                weights_only=False,
                map_location=device
            )
            # Checkpoints may wrap the weights in a training-state dict.
            if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
                state_dict = checkpoint["model_state_dict"]
            else:
                state_dict = checkpoint
            model.load_state_dict(state_dict, strict=False)
            print("Bio-ViL model weights loaded successfully.")
        except Exception as e:
            print(f"Warning: Could not load Bio-ViL model weights. Error: {e}")
    model.eval()
    return model
def load_attention_model():
    """Lazily construct the CNN+ViT attention model, loading weights if present.

    Cached in the module-level ``attention_model`` global after first use.
    """
    global attention_model
    if attention_model is None:
        print("Loading attention visualization model...")
        model = CNNViTChestXRayModel(num_classes=len(label_columns)).to(device)
        if os.path.exists(ATTENTION_MODEL_PATH):
            try:
                ckpt = torch.load(
                    ATTENTION_MODEL_PATH,
                    weights_only=False,
                    map_location=device
                )
                # Accept both raw state dicts and {"state_dict": ...} wrappers.
                model.load_state_dict(ckpt.get("state_dict", ckpt), strict=False)
                print("Attention model weights loaded successfully.")
            except Exception as e:
                print(f"Warning: Could not load attention model weights. Error: {e}")
        else:
            print(f"Warning: Attention model checkpoint not found at {ATTENTION_MODEL_PATH}")
        model.eval()
        attention_model = model
    return attention_model
# ------------------------------------------------------------
# Prediction Functions
# ------------------------------------------------------------
# NOTE: Bio-ViL model predictions are NO LONGER USED
# The deep learning model only generates attention maps
# All diagnosis is performed by Gemini AI from visual analysis
def get_attention_visualization(image):
    """Run the CNN+ViT model on an image and return its attention overlay.

    Args:
        image: PIL Image or numpy array; ``None`` yields ``None``.

    Returns:
        RGB numpy array blending the X-ray with the ViT attention heatmap,
        or ``None`` when no image was supplied.
    """
    if image is None:
        return None
    model = load_attention_model()
    # Normalize the input to a numpy array for preprocessing.
    if isinstance(image, Image.Image):
        image = np.array(image)
    # CLAHE enhancement first, then the standard eval transforms.
    enhanced = preprocess_medical_image(image)
    tensor = val_transforms(Image.fromarray(enhanced))
    batch = tensor.unsqueeze(0).to(device)
    # Inference only; gradients are not needed for visualization.
    with torch.no_grad():
        _, attentions, _ = model(batch)
    return get_vit_blended_image(batch[0].cpu(), list(attentions))
def generate_gemini_analysis(raw_image, attention_image):
    """
    Generate a medical diagnosis with Google Gemini from direct visual analysis.

    NO DEEP LEARNING MODEL PREDICTIONS - pure visual interpretation by the LLM.

    Args:
        raw_image: Original X-ray image (PIL Image)
        attention_image: Attention map visualization (numpy array or PIL Image)
            showing the vision model's regions of interest

    Returns:
        dict of parsed diagnosis sections on success (and on JSON-parse failure,
        a fallback dict with the raw response embedded).
        NOTE(review): the missing-API-key path and the outer exception handler
        return a plain string instead of a dict; callers that use .get() only
        survive this via their own try/except — consider unifying on dicts.
    """
    # Bail out early when the Gemini client was never configured.
    if gemini_model is None:
        return """
β οΈ Gemini Analysis Unavailable
Google Gemini API key not configured. Set GEMINI_API_KEY environment variable to enable AI analysis.
"""
    try:
        # Convert attention map to PIL Image if it's a numpy array
        if isinstance(attention_image, np.ndarray):
            attention_pil = Image.fromarray(attention_image)
        else:
            attention_pil = attention_image
        # Disease list to consider (rendered as a markdown bullet list)
        disease_list = "\n".join([f"- {disease}" for disease in label_columns])
        # Create detailed prompt for Gemini with structured JSON output
        prompt = f"""You are a board-certified radiologist with 20+ years of experience in thoracic imaging. You are conducting a PRIMARY DIAGNOSTIC EVALUATION of a chest X-ray examination.
**EXAMINATION MATERIALS PROVIDED:**
1. **Original Chest X-ray Image** (first image) - The primary diagnostic image for your interpretation
2. **AI Attention Heatmap** (second image) - Visualization showing regions highlighted by a computer vision model (for reference only)
**PRIMARY DISEASES TO CONSIDER:**
{disease_list}
**Important Note:** Patients may have co-morbidities or conditions outside this list. Consider all findings you observe.
**YOUR DIAGNOSTIC TASK:**
YOU are the primary diagnostician. There are NO AI predictions provided. Perform a complete diagnostic evaluation based solely on what YOU observe in the chest X-ray image.
**CRITICAL: YOU MUST RESPOND WITH VALID JSON ONLY**
Provide your complete diagnostic evaluation as a JSON object with the following structure:
```json
{{
"patient_summary": "A clear, compassionate explanation for the patient in simple, non-technical language (2-3 paragraphs). IMPORTANT: Clearly state the suspected/anticipated condition or disease (e.g., 'This appearance is concerning for possible lung cancer' or 'These findings suggest pneumonia'). Do NOT be vague - patients need to understand what condition is being considered. Explain what was found, what it likely means, and what happens next. Use simple terms but be direct about the suspected diagnosis while acknowledging uncertainty where appropriate.",
"clinical_thought_process": "Comprehensive technical analysis for physicians. Include systematic anatomical review, differential diagnosis reasoning, specific findings with anatomical locations, measurements, confidence levels, and clinical significance. Use precise medical terminology.",
"findings": [
{{
"finding": "Name of finding (e.g., 'Right lower lobe consolidation')",
"location": "Specific anatomical location",
"severity": "mild/moderate/severe",
"confidence": "high/moderate/low",
"description": "Detailed technical description"
}}
],
"progression_assessment": "Analysis of disease stage, acuity (acute/chronic), progression risk, expected course, and prognosis. Include timeline expectations and monitoring recommendations.",
"further_tests": [
{{
"test": "Name of recommended test/study",
"urgency": "immediate/urgent/routine/optional",
"rationale": "Why this test is needed",
"expected_findings": "What this test would help clarify"
}}
],
"differential_diagnosis": [
{{
"diagnosis": "Name of condition",
"probability": "high/moderate/low",
"supporting_evidence": "Visual findings that support this diagnosis",
"against_evidence": "Findings that argue against this diagnosis"
}}
],
"attention_map_analysis": "Brief assessment of whether the AI attention map highlighted relevant abnormal regions you identified. Did it miss anything important?",
"clinical_recommendations": [
"Specific actionable recommendation 1",
"Specific actionable recommendation 2"
],
"limitations": "Any limitations in the current imaging that affect diagnostic certainty (e.g., single view, image quality issues, need for comparison studies)"
}}
```
---
**1. SYSTEMATIC IMAGE QUALITY & TECHNICAL ASSESSMENT**
Evaluate the radiograph's diagnostic quality:
- **Projection & Positioning**: PA/AP view, rotation, inspiratory effort
- **Penetration & Exposure**: Under/over-penetrated, optimal contrast
- **Coverage**: Adequate visualization of lung apices, costophrenic angles, and lateral chest wall
- **Artifacts or Limitations**: Foreign bodies, patient motion, technical factors affecting interpretation
- **Overall Image Quality**: Diagnostic quality rating (excellent/good/adequate/limited)
**2. DETAILED ANATOMICAL REVIEW**
Systematically examine each anatomical region:
**A. CARDIAC & MEDIASTINAL STRUCTURES:**
- **Heart Size**: Cardiothoracic ratio (estimate percentage), cardiomegaly present?
- **Cardiac Contours**: Right atrium, left ventricle, aortic knob visibility
- **Mediastinal Width**: Normal/widened, mass effect, contour abnormalities
- **Trachea**: Midline/deviated, caliber changes
- **Hilar Structures**: Size, density, prominence (lymphadenopathy?)
**B. PULMONARY PARENCHYMA:**
- **Right Lung Field**: Upper/middle/lower zone opacities, consolidations, masses, nodules
- **Left Lung Field**: Upper/middle/lower zone opacities, consolidations, masses, nodules
- **Lung Volumes**: Hyperinflation/atelectasis patterns
- **Interstitial Patterns**: Reticular, nodular, reticulonodular patterns
- **Air Bronchograms**: Present/absent, location
**C. PLEURAL SPACES & DIAPHRAGM:**
- **Pleural Effusions**: Blunted costophrenic angles (right/left), meniscus sign, quantity estimate
- **Pneumothorax**: Visceral pleural line visible, extent if present
- **Pleural Thickening**: Focal/diffuse, calcifications
- **Diaphragm**: Height, contour irregularities, elevation/flattening
**D. BONY THORAX & SOFT TISSUES:**
- **Ribs**: Fractures, lesions, bone density
- **Clavicles & Scapulae**: Asymmetry, fractures
- **Spine**: Alignment, degenerative changes visible
- **Soft Tissues**: Subcutaneous emphysema, breast shadows, chest wall masses
**3. AI ATTENTION MAP CRITICAL ANALYSIS**
Detailed interpretation of the model's focus areas:
- **Primary Focus Regions**: Identify specific anatomical areas highlighted (e.g., "right lower lobe", "left hilum", "cardiac apex")
- **Correlation with Visual Findings**: Do highlighted areas correspond to visible pathology you identified?
- **Model Performance Assessment**:
- True Positives: Model correctly identified abnormal areas
- False Positives: Model flagged normal areas as abnormal
- False Negatives: Model missed visible abnormalities
- Overall Model Reliability: Based on attention-finding correlation
- **Clinical Validity**: Does the AI's attention pattern match what an experienced radiologist would focus on?
**4. COMPREHENSIVE CLINICAL ASSESSMENT**
Synthesize findings into clinical diagnoses:
**A. PRIMARY FINDINGS** (Findings with highest clinical significance):
- List 2-3 most significant abnormalities
- Severity assessment for each (mild/moderate/severe)
- Clinical implications and pathophysiology
- Supporting visual evidence from the X-ray
**B. SECONDARY FINDINGS** (Additional observations):
- Less critical but notable findings
- Incidental findings that warrant mention
- Chronic vs acute appearance
**C. DIAGNOSTIC CONFIDENCE ASSESSMENT**:
- **High Confidence Findings**: Clearly evident abnormalities with definitive visual features
- **Moderate Confidence Findings**: Findings present but require clinical correlation or additional imaging
- **Low Confidence/Uncertain Findings**: Subtle changes that may represent pathology or normal variation
- **Attention Map Utility**: Did the attention heatmap highlight relevant abnormal regions you identified?
**5. DIFFERENTIAL DIAGNOSIS**
Provide ranked differential diagnoses:
- **Most Likely Diagnosis**: Primary diagnostic impression with justification
- **Alternative Diagnoses**: 2-3 other possibilities to consider
- **Discriminating Features**: What findings support one diagnosis over another
- **Clinical Context Needed**: What patient history would help narrow diagnosis
**6. DETAILED RECOMMENDATIONS**
**A. URGENCY STRATIFICATION**:
- **Critical (Immediate action within hours)**: Life-threatening findings
- **Urgent (Action within 24-48 hours)**: Significant findings requiring prompt attention
- **Semi-Urgent (Action within 1 week)**: Important but stable findings
- **Routine (Follow-up as scheduled)**: Chronic or minor findings
**B. ADDITIONAL IMAGING**:
- Specific studies recommended (CT chest, lateral view, etc.)
- Rationale for each recommendation
- Urgency of additional imaging
- What specific questions the additional imaging should answer
**C. CLINICAL CORRELATION**:
- Specific clinical signs/symptoms to correlate
- Laboratory tests that would be helpful (CBC, BNP, D-dimer, etc.)
- Physical examination findings to document
**D. FOLLOW-UP PLAN**:
- Recommended follow-up imaging timeline
- Monitoring parameters
- When to escalate care
**7. CRITICAL REASONING & LIMITATIONS**
**A. DIAGNOSTIC REASONING**:
- How visual findings integrate with model predictions
- Why you prioritized certain findings over others
- Pattern recognition applied (e.g., "bilateral perihilar opacities suggest pulmonary edema")
- Evidence-based support for conclusions
**B. UNCERTAINTY & LIMITATIONS**:
- Areas of diagnostic uncertainty
- Limitations of single-view radiograph
- Findings requiring additional views/studies for confirmation
- Technical factors limiting interpretation
**C. QUALITY ASSURANCE**:
- Self-check of systematic review (anything missed?)
- Alternative interpretations considered
- Peer consultation recommendations if complex
**8. PATIENT COMMUNICATION SUMMARY**
Brief, clear summary appropriate for patient discussion:
- Main findings in non-technical language
- Severity and implications
- Next steps explained simply
---
**IMPORTANT GUIDELINES:**
- **YOU ARE THE PRIMARY DIAGNOSTICIAN** - There are no AI predictions to validate
- Be SPECIFIC with anatomical locations (e.g., "right lower lobe lateral segment" not just "right lung")
- Quantify when possible (e.g., "approximately 300ml pleural effusion" not just "effusion present")
- Reference specific visual features you observe in the X-ray
- Provide measurements or estimates (cardiothoracic ratio, effusion size, etc.)
- Assign probability/confidence to each finding
- Acknowledge uncertainty where appropriate - radiology involves clinical judgment
- This is AI-generated analysis - MUST NOT replace board-certified radiologist review
- Focus on what you ACTUALLY SEE in the images
**RESPONSE FORMAT:**
- YOU MUST respond with ONLY the JSON object - no additional text before or after
- Ensure all JSON fields are properly escaped (quotes, newlines, etc.)
- If a section doesn't apply (e.g., no further tests needed), use empty arrays [] or explain why in the text
- Be thorough, detailed, and clinically comprehensive in each section
Respond ONLY with the JSON object now:"""
        # Generate analysis with images (prompt + both images in one request)
        response = gemini_model.generate_content([
            prompt,
            raw_image,
            attention_pil
        ])
        analysis_text = response.text
        # Try to parse JSON from response
        try:
            # Extract JSON if wrapped in code blocks
            if "```json" in analysis_text:
                json_start = analysis_text.find("```json") + 7
                # NOTE(review): if the closing fence is missing, find() returns -1
                # and the slice silently drops the final character of the text.
                json_end = analysis_text.find("```", json_start)
                json_str = analysis_text[json_start:json_end].strip()
            elif "```" in analysis_text:
                json_start = analysis_text.find("{")
                json_end = analysis_text.rfind("}") + 1
                json_str = analysis_text[json_start:json_end].strip()
            else:
                # Try to find JSON object directly
                json_start = analysis_text.find("{")
                json_end = analysis_text.rfind("}") + 1
                json_str = analysis_text[json_start:json_end].strip()
            analysis_data = json.loads(json_str)
            # Return structured data for tabbed interface
            return analysis_data
        except (json.JSONDecodeError, ValueError) as e:
            # Fallback: keep the tab structure but surface the raw response
            return {
                "patient_summary": "Error parsing AI response. Please try again.",
                "clinical_thought_process": f"JSON parsing error: {str(e)}\n\nRaw response:\n{analysis_text}",
                "findings": [],
                "progression_assessment": "Unable to assess due to parsing error.",
                "further_tests": [],
                "differential_diagnosis": [],
                "attention_map_analysis": "N/A",
                "clinical_recommendations": [],
                "limitations": "Response parsing failed. Please regenerate analysis."
            }
    except Exception as e:
        # API/transport failure: report the error text (string, not dict — see docstring note).
        error_html = f"""
β Gemini Analysis Error
Failed to generate analysis: {str(e)}
"""
        return error_html
def predict_xray_quick(image):
    """
    Quick analysis: generate the attention map only (the deep-learning model
    produces no diagnoses in this LLM-primary architecture).

    Args:
        image: PIL Image or numpy array, or None.

    Returns:
        Tuple of four values:
            status_html: loading/error message shown in the summary tab
            attention_img: attention overlay image (None on failure)
            raw_image: pristine copy of the upload for Gemini (None on failure)
            attention_img_for_gemini: attention overlay passed to Gemini (None on failure)
    """
    if image is None:
        return "No image provided\n", None, None, None
    try:
        # Keep an untouched copy of the upload for the Gemini visual analysis.
        if isinstance(image, Image.Image):
            raw_image = image.copy()
        else:
            raw_image = Image.fromarray(image)
        # Attention overlay only — no classifier predictions are made.
        attention_img = get_attention_visualization(image)
        # Placeholder shown in the summary tab while the full analysis runs.
        loading_html = """
π¬ Analyzing X-ray...
RadFusion is performing comprehensive visual diagnosis
"""
        return loading_html, attention_img, raw_image, attention_img
    except Exception as e:
        # BUG FIX: the exception was previously discarded via an empty
        # f-string, leaving the UI with no indication of what went wrong.
        error_msg = f"Error during analysis: {e}"
        return error_msg, None, None, None
def generate_full_analysis(raw_image, attention_image):
    """
    Run the Gemini diagnosis and format the result for the tabbed display.

    Args:
        raw_image: Original X-ray image (PIL Image).
        attention_image: Attention map visualization (numpy array).

    Returns:
        Tuple of seven formatted sections, one per tab:
        (summary, thought_process, findings, progression,
         further_tests, differential, recommendations)
    """
    if raw_image is None:
        error_msg = "No image available for analysis\n"
        return (error_msg,) * 7
    try:
        analysis_data = generate_gemini_analysis(raw_image, attention_image)
        # Patient summary (simple language); newlines become HTML line breaks
        # since the output is rendered by gr.HTML.
        summary_html = f"""
π Patient Summary
{analysis_data.get('patient_summary', 'No summary available').replace(chr(10), '<br>')}
βΉοΈ For Patients
This section is written in simple language to help you understand the findings.
"""
        # Clinical thought process (technical language).
        thought_process_html = f"""
π¬ Clinical Thought Process
{analysis_data.get('clinical_thought_process', 'No clinical analysis available')}
π©Ί For Healthcare Professionals
Comprehensive technical analysis with medical terminology.
"""
        # Structured findings list.
        findings_html = "\nπ Detailed Findings\n"
        findings_list = analysis_data.get('findings', [])
        if findings_list:
            for idx, finding in enumerate(findings_list, 1):
                severity = finding.get('severity', 'unknown')
                # Color code kept for styling severity badges in the markup.
                severity_color = {'mild': '#27ae60', 'moderate': '#f39c12', 'severe': '#e74c3c'}.get(severity.lower(), '#95a5a6')
                confidence = finding.get('confidence', 'unknown')
                findings_html += f"""
{idx}. {finding.get('finding', 'Unknown finding')}
{severity.upper()}
Location: {finding.get('location', 'Not specified')}
{finding.get('description', 'No description provided')}
Confidence: {confidence}
"""
        else:
            findings_html += "\nNo specific findings documented.\n"
        findings_html += "\n"
        # Progression assessment.
        progression_html = f"""
π Progression Assessment
{analysis_data.get('progression_assessment', 'No progression assessment available')}
"""
        # Recommended further tests.
        further_tests_html = "\nπ§ͺ Further Tests Recommended\n"
        tests_list = analysis_data.get('further_tests', [])
        if tests_list:
            for idx, test in enumerate(tests_list, 1):
                urgency = test.get('urgency', 'routine')
                urgency_color = {'immediate': '#e74c3c', 'urgent': '#f39c12', 'routine': '#3498db', 'optional': '#95a5a6'}.get(urgency.lower(), '#95a5a6')
                further_tests_html += f"""
{idx}. {test.get('test', 'Unknown test')}
{urgency.upper()}
Rationale: {test.get('rationale', 'Not specified')}
Expected Findings: {test.get('expected_findings', 'Not specified')}
"""
        else:
            further_tests_html += "\nNo additional tests recommended at this time.\n"
        further_tests_html += "\n"
        # Ranked differential diagnoses.
        differential_html = "\nπ― Differential Diagnosis\n"
        diff_list = analysis_data.get('differential_diagnosis', [])
        if diff_list:
            for idx, diagnosis in enumerate(diff_list, 1):
                probability = diagnosis.get('probability', 'unknown')
                prob_color = {'high': '#e74c3c', 'moderate': '#f39c12', 'low': '#27ae60'}.get(probability.lower(), '#95a5a6')
                differential_html += f"""
{idx}. {diagnosis.get('diagnosis', 'Unknown diagnosis')}
{probability.upper()} PROBABILITY
β Supporting Evidence: {diagnosis.get('supporting_evidence', 'Not specified')}
β Against: {diagnosis.get('against_evidence', 'Not specified')}
"""
        else:
            differential_html += "\nNo differential diagnoses documented.\n"
        differential_html += "\n"
        # Clinical recommendations plus attention-map notes and limitations.
        recommendations_html = "\nπ‘ Clinical Recommendations\n"
        recommendations_list = analysis_data.get('clinical_recommendations', [])
        if recommendations_list:
            recommendations_html += "\n"
            for rec in recommendations_list:
                recommendations_html += f"- {rec}\n"
            recommendations_html += "\n"
        else:
            recommendations_html += "\nNo specific recommendations at this time.\n"
        recommendations_html += f"""
πΊοΈ AI Attention Map Analysis
{analysis_data.get('attention_map_analysis', 'Not available')}
"""
        recommendations_html += f"""
β οΈ Limitations
{analysis_data.get('limitations', 'Not specified')}
"""
        recommendations_html += f"""
π
Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
"""
        recommendations_html += "\n"
        return (summary_html, thought_process_html, findings_html, progression_html,
                further_tests_html, differential_html, recommendations_html)
    except Exception as e:
        # BUG FIX: previously the exception was discarded via an empty
        # f-string, leaving every tab blank with no explanation.
        error_msg = f"Error generating analysis: {e}"
        return (error_msg,) * 7
# ------------------------------------------------------------
# Gradio Interface
# ------------------------------------------------------------
def create_interface():
    """Create the Gradio Blocks interface with progressive result display.

    Layout: left column = upload widget + analyze button; right column =
    tabbed diagnosis panels plus the attention-map image. Clicking the button
    first renders the attention map (fast), then fills every tab from the
    Gemini analysis (slow).
    """
    with gr.Blocks(title="Chest X-ray Analysis for Doctors") as demo:
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                gr.Markdown("## π€ Upload X-ray Image")
                image_input = gr.Image(
                    label="Upload Chest X-ray",
                    type="pil",
                    height=400
                )
                analyze_btn = gr.Button(
                    "π¬ Analyze X-ray",
                    variant="primary",
                    size="lg"
                )
                # Help accordion intentionally disabled; kept for reference.
                # with gr.Accordion("βΉοΈ System Information", open=False):
                #     gr.Markdown(
                #         """
                #         ### How It Works:
                #         1. **Upload** a chest X-ray image (PNG, JPG, JPEG)
                #         2. **Click** "Analyze X-ray" button
                #         3. **View** comprehensive AI diagnosis and attention visualization
                #         ### Technology:
                #         - **Google Gemini AI**: Performs complete visual diagnostic evaluation
                #         - **Attention Map**: Highlights regions of interest in the X-ray
                #         ### Primary Diseases Evaluated:
                #         Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion,
                #         Emphysema, Fibrosis, Hernia, Infiltration
                #         *Note: Co-morbidities and other conditions may also be identified*
                #         """
                #     )
            with gr.Column(scale=1):
                gr.Markdown("## π€ RadFusion Diagnosis")
                with gr.Tabs():
                    with gr.Tab("π Summary"):
                        summary_output = gr.HTML(
                            value="""
π
Patient Summary
Simple language explanation for patients
"""
                        )
                    with gr.Tab("π¬ Clinical Analysis"):
                        thought_process_output = gr.HTML(
                            value="""
π¬
Clinical Thought Process
Comprehensive technical analysis for physicians
"""
                        )
                    with gr.Tab("π Findings"):
                        findings_output = gr.HTML(
                            value="""
π
Detailed Findings
Structured list of all findings
"""
                        )
                    with gr.Tab("π Progression"):
                        progression_output = gr.HTML(
                            value="""
π
Progression Assessment
Disease stage and prognosis
"""
                        )
                    with gr.Tab("π§ͺ Further Tests"):
                        tests_output = gr.HTML(
                            value="""
π§ͺ
Recommended Tests
Additional studies needed
"""
                        )
                    with gr.Tab("π― Differential Dx"):
                        differential_output = gr.HTML(
                            value="""
π―
Differential Diagnosis
Possible diagnoses ranked
"""
                        )
                    with gr.Tab("π‘ Recommendations"):
                        recommendations_output = gr.HTML(
                            value="""
π‘
Clinical Recommendations
Action items and guidance
"""
                        )
                # Attention overlay displayed below the diagnosis tabs.
                attention_output = gr.Image(
                    label="Visual Focus Areas",
                    height=400
                )
        # Hidden states to store data for Gemini (no predictions needed)
        raw_image_state = gr.State(value=None)
        attention_image_state = gr.State(value=None)
        # Step 1: generate the attention map and show a loading message.
        quick_analysis = analyze_btn.click(
            fn=predict_xray_quick,
            inputs=image_input,
            outputs=[summary_output, attention_output, raw_image_state, attention_image_state]
        )
        # Step 2: run the Gemini diagnosis and populate every tab.
        quick_analysis.then(
            fn=generate_full_analysis,
            inputs=[raw_image_state, attention_image_state],
            outputs=[summary_output, thought_process_output, findings_output,
                     progression_output, tests_output, differential_output, recommendations_output]
        )
    return demo
# ------------------------------------------------------------
# Main
# ------------------------------------------------------------
if __name__ == "__main__":
    banner = "=" * 80
    print(banner)
    print("Starting Chest X-ray Analysis System")
    print(banner)
    # Warm up the attention model so the first request is fast.
    print("\nLoading attention visualization model...")
    print("Note: Bio-ViL predictions disabled - using LLM-primary architecture")
    try:
        load_attention_model()
        print("Attention model loaded successfully!")
    except Exception as e:
        print(f"Warning: Could not load attention model: {e}")
    print("\n" + banner)
    print("Starting Gradio interface...")
    print(banner)
    # Build and serve the Gradio app.
    demo = create_interface()
    demo.launch()