# ============================================================
# Chest X-ray Doctor Interface with Transformer Attention Maps
# ============================================================
import os
import re
import json
from datetime import datetime

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms, models
from transformers import ViTModel
import gradio as gr
import google.generativeai as genai

# ------------------------------------------------------------
# Device
# ------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ------------------------------------------------------------
# Google Gemini Configuration
# ------------------------------------------------------------
# SECURITY: never hard-code API keys in source control — read from the environment.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")

if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
    gemini_model = genai.GenerativeModel('gemini-3-pro-preview')  # Pro model for better visual analysis
    print("Google Gemini API configured successfully (gemini-3-pro-preview with vision)")
else:
    gemini_model = None
    print("Warning: GEMINI_API_KEY not found. Gemini analysis will be disabled.")
    print("Set it using: export GEMINI_API_KEY='your-api-key' (Linux/Mac)")
    print("Or: set GEMINI_API_KEY=your-api-key (Windows)")

# ------------------------------------------------------------
# Labels (Primary Diseases to Consider)
# ------------------------------------------------------------
label_columns = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Effusion',
    'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration'
]
# Note: Co-morbidities and conditions outside this list can also occur


# ------------------------------------------------------------
# Bio-ViL Model (for predictions)
# ------------------------------------------------------------
class BioViLChestXRayModel(nn.Module):
    """Bio-ViL model for chest X-ray classification.

    Wraps a pretrained ViT vision backbone with a disease-specific
    multi-head self-attention layer and an MLP classification head.
    """

    def __init__(self, num_classes=14, dropout_rate=0.3, model_dir="model"):
        super().__init__()
        print(f"Loading Bio-ViL vision model from: {model_dir}")
        self.vision_model = ViTModel.from_pretrained(
            model_dir,
            torch_dtype=torch.float32,
            ignore_mismatched_sizes=True,
        )

        # Feature dimension (768 for ViT-base); fall back if the config lacks it.
        if hasattr(self.vision_model.config, 'hidden_size'):
            self.feature_dim = self.vision_model.config.hidden_size
        else:
            self.feature_dim = 768
        print(f"Model feature dimension: {self.feature_dim}")

        # Classification head: 768 -> 512 -> 256 -> num_classes with BN + dropout.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(self.feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes),
        )

        # Disease-specific attention over the pooled feature vector.
        self.disease_attention = nn.MultiheadAttention(
            embed_dim=self.feature_dim,
            num_heads=8,
            dropout=dropout_rate,
            batch_first=True,
        )

    def forward(self, pixel_values):
        """Return multi-label logits of shape (B, num_classes)."""
        if pixel_values.dim() == 3:  # single image -> add batch dim
            pixel_values = pixel_values.unsqueeze(0)

        outputs = self.vision_model(pixel_values, output_hidden_states=True)
        # Prefer the pooled output; otherwise use the CLS token embedding.
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            features = outputs.pooler_output
        else:
            features = outputs.last_hidden_state[:, 0, :]

        # Self-attention on the (single-token) pooled feature.
        features_expanded = features.unsqueeze(1)
        attended_features, _ = self.disease_attention(
            features_expanded, features_expanded, features_expanded
        )
        features = attended_features.squeeze(1)

        return self.classifier(features)


# ------------------------------------------------------------
# Hybrid CNN + ViT Model (for attention visualization)
# ------------------------------------------------------------
class CNNViTChestXRayModel(nn.Module):
    """ResNet-50 + ViT hybrid used to produce attention maps for display."""

    def __init__(self, num_classes=14, dropout_rate=0.3):
        super().__init__()
        # ResNet-50 backbone up to (excluding) the global pool / fc layers.
        resnet = models.resnet50(pretrained=True)
        self.cnn = nn.Sequential(*list(resnet.children())[:-2])

        # ViT encoder from HuggingFace; output_attentions=True so we can
        # retrieve per-layer attention maps for visualization.
        self.vit = ViTModel.from_pretrained(
            "google/vit-base-patch16-224-in21k",
            torch_dtype=torch.float32,
            output_attentions=True,
            ignore_mismatched_sizes=True,
        )
        self.feature_dim = self.vit.config.hidden_size  # 768 for ViT-base

        # 1x1 conv projecting ResNet's 2048 channels down to the ViT width.
        self.proj = nn.Conv2d(2048, self.feature_dim, kernel_size=1)

        # Multi-head attention for disease-specific feature aggregation.
        self.disease_attention = nn.MultiheadAttention(
            embed_dim=self.feature_dim,
            num_heads=8,
            dropout=dropout_rate,
            batch_first=True,
        )

        # Classification head (same shape as the Bio-ViL model's head).
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(self.feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes),
        )

        self.model_name = "ResNet50 + ViT Hybrid"

    def forward(self, pixel_values):
        """Return (logits, attentions, projected CNN feature map)."""
        if pixel_values.dim() == 3:  # single image -> add batch dim
            pixel_values = pixel_values.unsqueeze(0)

        # 1. CNN feature extraction: B x 2048 x H/32 x W/32.
        cnn_feats = self.cnn(pixel_values)

        # 2. Projection to ViT width: B x 768 x 7 x 7 for 224x224 input.
        feats = self.proj(cnn_feats)
        B, C, H, W = feats.shape

        # 3. Flatten spatial map into ViT-like tokens: B x (H*W) x 768.
        tokens = feats.flatten(2).transpose(1, 2)

        # 4. Prepend the ViT CLS token: B x (H*W + 1) x 768.
        cls_token = self.vit.embeddings.cls_token.expand(B, -1, -1)
        tokens = torch.cat((cls_token, tokens), dim=1)

        # 5. Add (truncated) position embeddings.
        pos_emb = self.vit.embeddings.position_embeddings[:, :tokens.size(1), :].to(tokens.device)
        tokens = tokens + pos_emb

        # 6. ViT encoder pass, keeping per-layer attention maps.
        vit_out = self.vit.encoder(tokens, output_attentions=True)
        hidden_states = vit_out.last_hidden_state
        attentions = vit_out.attentions  # tuple of attention maps, one per layer

        # 7. CLS token feature: B x 768.
        features = hidden_states[:, 0, :]

        # 8. Disease-specific self-attention on the CLS feature.
        features_expanded = features.unsqueeze(1)
        attended, _ = self.disease_attention(
            features_expanded, features_expanded, features_expanded
        )
        features = attended.squeeze(1)

        # 9. Classification. `feats` is also returned for CNN-activation display.
        logits = self.classifier(features)
        return logits, attentions, feats


# ------------------------------------------------------------
# Image Preprocessing
# ------------------------------------------------------------
def preprocess_medical_image(img):
    """Apply CLAHE contrast enhancement and denoising to a medical image.

    Accepts an RGB or grayscale uint8 array; returns a 3-channel RGB uint8
    array suitable for the torchvision transform pipeline.
    """
    # CLAHE for local contrast enhancement.
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))

    # Convert to grayscale if needed.
    if len(img.shape) == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    img = clahe.apply(img)

    # Normalize to the full 0-255 range.
    img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX)

    # Slight Gaussian blur for noise reduction.
    img = cv2.GaussianBlur(img, (3, 3), 0.5)

    # Back to 3-channel RGB for the model.
    return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)


# ------------------------------------------------------------
# Transforms
# ------------------------------------------------------------
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


# ------------------------------------------------------------
# Visualization Helper: ViT Attention
# ------------------------------------------------------------
def get_vit_blended_image(image_tensor, attention_maps):
    """Blend the input image with a heatmap of the ViT CLS token's attention.

    Args:
        image_tensor: normalized image tensor (C, H, W) on CPU.
        attention_maps: sequence of per-layer attention tensors from the ViT.
    Returns:
        RGB uint8 array of the image overlaid with a VIRIDIS heatmap.
    """
    # Denormalize image for display.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = image_tensor.permute(1, 2, 0).cpu().numpy()  # C,H,W -> H,W,C
    img = img * std + mean
    # Guard against a constant image (max == min) to avoid division by zero.
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    img = (img * 255).astype(np.uint8)

    # CLS token's attention to patches, last layer, averaged across heads.
    cls_to_patches_attn = attention_maps[-1].mean(dim=1)[0, 0, 1:].detach().cpu().numpy()

    # Reshape attention scores (49 patches for 224x224 input -> 7x7 grid).
    expected_size = 49
    if cls_to_patches_attn.size != expected_size:
        # Handle inputs whose size yields a different (square) patch count.
        side = int(np.sqrt(cls_to_patches_attn.size))
        attn = cls_to_patches_attn.reshape(side, side)
    else:
        attn = cls_to_patches_attn.reshape(7, 7)

    # Upsample the low-resolution attention map to the image size.
    attn = cv2.resize(attn, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LINEAR)

    # Scale to 0-255 with an epsilon guard for constant attention maps.
    attn = (attn - attn.min()) / (attn.max() - attn.min() + 1e-8)
    attn_scaled = np.uint8(255 * attn)

    # VIRIDIS colormap heatmap, blended 40/60 with the image.
    heatmap = cv2.applyColorMap(attn_scaled, cv2.COLORMAP_VIRIDIS)
    img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    blended = cv2.addWeighted(img_bgr, 0.4, heatmap, 0.6, 0)

    # Back to RGB for display.
    return cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)


# ------------------------------------------------------------
# Load Models
# ------------------------------------------------------------
BIOVIL_MODEL_PATH = "best_biovil_hf_model.pth"  # Bio-ViL model for predictions

# Try to find the ViT model directory - common locations.
# NOTE: entries containing "/" are accepted even if the path does not exist
# locally, because they are treated as HuggingFace Hub model ids.
BIOVIL_MODEL_DIR = None
for possible_dir in ["model", "vit-model", "microsoft/BiomedVLP-CXR-BERT-specialized"]:
    if os.path.exists(possible_dir) or "/" in possible_dir:
        BIOVIL_MODEL_DIR = possible_dir
        break

# Fallback to a HuggingFace model if no local model found.
if BIOVIL_MODEL_DIR is None:
    BIOVIL_MODEL_DIR = "google/vit-base-patch16-224-in21k"

ATTENTION_MODEL_PATH = "best_cnn_vit.pth"  # CNN+ViT model for attention visualization

biovil_model = None
attention_model = None


def load_biovil_model():
    """Load (and cache) the Bio-ViL model for predictions."""
    global biovil_model
    if biovil_model is not None:
        return biovil_model

    print(f"Loading Bio-ViL model with ViT from: {BIOVIL_MODEL_DIR}")
    biovil_model = BioViLChestXRayModel(
        num_classes=len(label_columns),
        model_dir=BIOVIL_MODEL_DIR,
    ).to(device)

    if os.path.exists(BIOVIL_MODEL_PATH):
        try:
            checkpoint = torch.load(
                BIOVIL_MODEL_PATH,
                weights_only=False,
                map_location=device,
            )
            # Handle both raw state dicts and wrapped checkpoint dicts.
            if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
                state_dict = checkpoint["model_state_dict"]
            else:
                state_dict = checkpoint
            biovil_model.load_state_dict(state_dict, strict=False)
            print("Bio-ViL model weights loaded successfully.")
        except Exception as e:
            print(f"Warning: Could not load Bio-ViL model weights. Error: {e}")
    else:
        print(f"Warning: Bio-ViL checkpoint not found at {BIOVIL_MODEL_PATH}")

    biovil_model.eval()
    return biovil_model


def load_attention_model():
    """Load (and cache) the CNN+ViT model for attention visualization."""
    global attention_model
    if attention_model is not None:
        return attention_model

    print("Loading attention visualization model...")
    attention_model = CNNViTChestXRayModel(num_classes=len(label_columns)).to(device)

    if os.path.exists(ATTENTION_MODEL_PATH):
        try:
            checkpoint = torch.load(
                ATTENTION_MODEL_PATH,
                weights_only=False,
                map_location=device,
            )
            state_dict = checkpoint.get("state_dict", checkpoint)
            attention_model.load_state_dict(state_dict, strict=False)
            print("Attention model weights loaded successfully.")
        except Exception as e:
            print(f"Warning: Could not load attention model weights. Error: {e}")
    else:
        print(f"Warning: Attention model checkpoint not found at {ATTENTION_MODEL_PATH}")

    attention_model.eval()
    return attention_model


# ------------------------------------------------------------
# Prediction Functions
# ------------------------------------------------------------
# NOTE: Bio-ViL model predictions are NO LONGER USED.
# The deep learning model only generates attention maps;
# all diagnosis is performed by Gemini AI from visual analysis.

def get_attention_visualization(image):
    """Generate an attention-map overlay from the CNN+ViT model.

    Args:
        image: PIL Image or numpy array.
    Returns:
        Attention map overlay image (RGB uint8 array), or None for no input.
    """
    if image is None:
        return None

    model = load_attention_model()

    if isinstance(image, Image.Image):
        image = np.array(image)

    processed_img = preprocess_medical_image(image)
    pil_img = Image.fromarray(processed_img)

    img_tensor = val_transforms(pil_img)
    img_tensor = img_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        _, attentions, _ = model(img_tensor)

    return get_vit_blended_image(img_tensor[0].cpu(), list(attentions))


def _error_analysis(summary, detail):
    """Build the structured analysis dict used for error / unavailable states.

    Keeping the error shape identical to a successful parse means downstream
    formatting code never has to special-case failures.
    """
    return {
        "patient_summary": summary,
        "clinical_thought_process": detail,
        "findings": [],
        "progression_assessment": "Not available.",
        "further_tests": [],
        "differential_diagnosis": [],
        "attention_map_analysis": "N/A",
        "clinical_recommendations": [],
        "limitations": detail,
    }


def generate_gemini_analysis(raw_image, attention_image):
    """Generate a medical diagnosis using Google Gemini AI with direct visual analysis.

    NO DEEP LEARNING MODEL PREDICTIONS - pure visual interpretation by the LLM.

    Args:
        raw_image: original X-ray image (PIL Image).
        attention_image: attention map visualization (numpy array) showing
            regions of interest.
    Returns:
        dict with the structured diagnosis fields (always the same shape,
        including on error, so callers can format it uniformly).
    """
    if gemini_model is None:
        return _error_analysis(
            "Gemini analysis is unavailable.",
            "Google Gemini API key not configured. Set the GEMINI_API_KEY "
            "environment variable to enable AI analysis.",
        )

    try:
        # Convert the attention map to a PIL Image if it's a numpy array.
        if isinstance(attention_image, np.ndarray):
            attention_pil = Image.fromarray(attention_image)
        else:
            attention_pil = attention_image

        disease_list = "\n".join(f"- {disease}" for disease in label_columns)

        # Detailed prompt for Gemini with a structured JSON output contract.
        prompt = f"""You are a board-certified radiologist with 20+ years of experience in thoracic imaging. You are conducting a PRIMARY DIAGNOSTIC EVALUATION of a chest X-ray examination.

**EXAMINATION MATERIALS PROVIDED:**
1. **Original Chest X-ray Image** (first image) - The primary diagnostic image for your interpretation
2. **AI Attention Heatmap** (second image) - Visualization showing regions highlighted by a computer vision model (for reference only)

**PRIMARY DISEASES TO CONSIDER:**
{disease_list}

**Important Note:** Patients may have co-morbidities or conditions outside this list. Consider all findings you observe.

**YOUR DIAGNOSTIC TASK:**
YOU are the primary diagnostician. There are NO AI predictions provided. Perform a complete diagnostic evaluation based solely on what YOU observe in the chest X-ray image.

**CRITICAL: YOU MUST RESPOND WITH VALID JSON ONLY**

Provide your complete diagnostic evaluation as a JSON object with the following structure:

```json
{{
  "patient_summary": "A clear, compassionate explanation for the patient in simple, non-technical language (2-3 paragraphs). IMPORTANT: Clearly state the suspected/anticipated condition or disease (e.g., 'This appearance is concerning for possible lung cancer' or 'These findings suggest pneumonia'). Do NOT be vague - patients need to understand what condition is being considered. Explain what was found, what it likely means, and what happens next. Use simple terms but be direct about the suspected diagnosis while acknowledging uncertainty where appropriate.",
  "clinical_thought_process": "Comprehensive technical analysis for physicians. Include systematic anatomical review, differential diagnosis reasoning, specific findings with anatomical locations, measurements, confidence levels, and clinical significance. Use precise medical terminology.",
  "findings": [
    {{
      "finding": "Name of finding (e.g., 'Right lower lobe consolidation')",
      "location": "Specific anatomical location",
      "severity": "mild/moderate/severe",
      "confidence": "high/moderate/low",
      "description": "Detailed technical description"
    }}
  ],
  "progression_assessment": "Analysis of disease stage, acuity (acute/chronic), progression risk, expected course, and prognosis. Include timeline expectations and monitoring recommendations.",
  "further_tests": [
    {{
      "test": "Name of recommended test/study",
      "urgency": "immediate/urgent/routine/optional",
      "rationale": "Why this test is needed",
      "expected_findings": "What this test would help clarify"
    }}
  ],
  "differential_diagnosis": [
    {{
      "diagnosis": "Name of condition",
      "probability": "high/moderate/low",
      "supporting_evidence": "Visual findings that support this diagnosis",
      "against_evidence": "Findings that argue against this diagnosis"
    }}
  ],
  "attention_map_analysis": "Brief assessment of whether the AI attention map highlighted relevant abnormal regions you identified. Did it miss anything important?",
  "clinical_recommendations": [
    "Specific actionable recommendation 1",
    "Specific actionable recommendation 2"
  ],
  "limitations": "Any limitations in the current imaging that affect diagnostic certainty (e.g., single view, image quality issues, need for comparison studies)"
}}
```

---

**1. SYSTEMATIC IMAGE QUALITY & TECHNICAL ASSESSMENT**
Evaluate the radiograph's diagnostic quality:
- **Projection & Positioning**: PA/AP view, rotation, inspiratory effort
- **Penetration & Exposure**: Under/over-penetrated, optimal contrast
- **Coverage**: Adequate visualization of lung apices, costophrenic angles, and lateral chest wall
- **Artifacts or Limitations**: Foreign bodies, patient motion, technical factors affecting interpretation
- **Overall Image Quality**: Diagnostic quality rating (excellent/good/adequate/limited)

**2. DETAILED ANATOMICAL REVIEW**
Systematically examine each anatomical region:

**A. CARDIAC & MEDIASTINAL STRUCTURES:**
- **Heart Size**: Cardiothoracic ratio (estimate percentage), cardiomegaly present?
- **Cardiac Contours**: Right atrium, left ventricle, aortic knob visibility
- **Mediastinal Width**: Normal/widened, mass effect, contour abnormalities
- **Trachea**: Midline/deviated, caliber changes
- **Hilar Structures**: Size, density, prominence (lymphadenopathy?)

**B. PULMONARY PARENCHYMA:**
- **Right Lung Field**: Upper/middle/lower zone opacities, consolidations, masses, nodules
- **Left Lung Field**: Upper/middle/lower zone opacities, consolidations, masses, nodules
- **Lung Volumes**: Hyperinflation/atelectasis patterns
- **Interstitial Patterns**: Reticular, nodular, reticulonodular patterns
- **Air Bronchograms**: Present/absent, location

**C. PLEURAL SPACES & DIAPHRAGM:**
- **Pleural Effusions**: Blunted costophrenic angles (right/left), meniscus sign, quantity estimate
- **Pneumothorax**: Visceral pleural line visible, extent if present
- **Pleural Thickening**: Focal/diffuse, calcifications
- **Diaphragm**: Height, contour irregularities, elevation/flattening

**D. BONY THORAX & SOFT TISSUES:**
- **Ribs**: Fractures, lesions, bone density
- **Clavicles & Scapulae**: Asymmetry, fractures
- **Spine**: Alignment, degenerative changes visible
- **Soft Tissues**: Subcutaneous emphysema, breast shadows, chest wall masses

**3. AI ATTENTION MAP CRITICAL ANALYSIS**
Detailed interpretation of the model's focus areas:
- **Primary Focus Regions**: Identify specific anatomical areas highlighted (e.g., "right lower lobe", "left hilum", "cardiac apex")
- **Correlation with Visual Findings**: Do highlighted areas correspond to visible pathology you identified?
- **Model Performance Assessment**:
  - True Positives: Model correctly identified abnormal areas
  - False Positives: Model flagged normal areas as abnormal
  - False Negatives: Model missed visible abnormalities
  - Overall Model Reliability: Based on attention-finding correlation
- **Clinical Validity**: Does the AI's attention pattern match what an experienced radiologist would focus on?

**4. COMPREHENSIVE CLINICAL ASSESSMENT**
Synthesize findings into clinical diagnoses:

**A. PRIMARY FINDINGS** (Findings with highest clinical significance):
- List 2-3 most significant abnormalities
- Severity assessment for each (mild/moderate/severe)
- Clinical implications and pathophysiology
- Supporting visual evidence from the X-ray

**B. SECONDARY FINDINGS** (Additional observations):
- Less critical but notable findings
- Incidental findings that warrant mention
- Chronic vs acute appearance

**C. DIAGNOSTIC CONFIDENCE ASSESSMENT**:
- **High Confidence Findings**: Clearly evident abnormalities with definitive visual features
- **Moderate Confidence Findings**: Findings present but require clinical correlation or additional imaging
- **Low Confidence/Uncertain Findings**: Subtle changes that may represent pathology or normal variation
- **Attention Map Utility**: Did the attention heatmap highlight relevant abnormal regions you identified?

**5. DIFFERENTIAL DIAGNOSIS**
Provide ranked differential diagnoses:
- **Most Likely Diagnosis**: Primary diagnostic impression with justification
- **Alternative Diagnoses**: 2-3 other possibilities to consider
- **Discriminating Features**: What findings support one diagnosis over another
- **Clinical Context Needed**: What patient history would help narrow diagnosis

**6. DETAILED RECOMMENDATIONS**

**A. URGENCY STRATIFICATION**:
- **Critical (Immediate action within hours)**: Life-threatening findings
- **Urgent (Action within 24-48 hours)**: Significant findings requiring prompt attention
- **Semi-Urgent (Action within 1 week)**: Important but stable findings
- **Routine (Follow-up as scheduled)**: Chronic or minor findings

**B. ADDITIONAL IMAGING**:
- Specific studies recommended (CT chest, lateral view, etc.)
- Rationale for each recommendation
- Urgency of additional imaging
- What specific questions the additional imaging should answer

**C. CLINICAL CORRELATION**:
- Specific clinical signs/symptoms to correlate
- Laboratory tests that would be helpful (CBC, BNP, D-dimer, etc.)
- Physical examination findings to document

**D. FOLLOW-UP PLAN**:
- Recommended follow-up imaging timeline
- Monitoring parameters
- When to escalate care

**7. CRITICAL REASONING & LIMITATIONS**

**A. DIAGNOSTIC REASONING**:
- How visual findings integrate with model predictions
- Why you prioritized certain findings over others
- Pattern recognition applied (e.g., "bilateral perihilar opacities suggest pulmonary edema")
- Evidence-based support for conclusions

**B. UNCERTAINTY & LIMITATIONS**:
- Areas of diagnostic uncertainty
- Limitations of single-view radiograph
- Findings requiring additional views/studies for confirmation
- Technical factors limiting interpretation

**C. QUALITY ASSURANCE**:
- Self-check of systematic review (anything missed?)
- Alternative interpretations considered
- Peer consultation recommendations if complex

**8. PATIENT COMMUNICATION SUMMARY**
Brief, clear summary appropriate for patient discussion:
- Main findings in non-technical language
- Severity and implications
- Next steps explained simply

---

**IMPORTANT GUIDELINES:**
- **YOU ARE THE PRIMARY DIAGNOSTICIAN** - There are no AI predictions to validate
- Be SPECIFIC with anatomical locations (e.g., "right lower lobe lateral segment" not just "right lung")
- Quantify when possible (e.g., "approximately 300ml pleural effusion" not just "effusion present")
- Reference specific visual features you observe in the X-ray
- Provide measurements or estimates (cardiothoracic ratio, effusion size, etc.)
- Assign probability/confidence to each finding
- Acknowledge uncertainty where appropriate - radiology involves clinical judgment
- This is AI-generated analysis - MUST NOT replace board-certified radiologist review
- Focus on what you ACTUALLY SEE in the images

**RESPONSE FORMAT:**
- YOU MUST respond with ONLY the JSON object - no additional text before or after
- Ensure all JSON fields are properly escaped (quotes, newlines, etc.)
- If a section doesn't apply (e.g., no further tests needed), use empty arrays [] or explain why in the text
- Be thorough, detailed, and clinically comprehensive in each section

Respond ONLY with the JSON object now:"""

        # Generate analysis with both images attached.
        response = gemini_model.generate_content([prompt, raw_image, attention_pil])
        analysis_text = response.text

        # Try to parse JSON from the response.
        try:
            if "```json" in analysis_text:
                # JSON wrapped in a fenced ```json block.
                json_start = analysis_text.find("```json") + 7
                json_end = analysis_text.find("```", json_start)
                json_str = analysis_text[json_start:json_end].strip()
            else:
                # Plain fenced block or bare JSON: take the outermost braces.
                json_start = analysis_text.find("{")
                json_end = analysis_text.rfind("}") + 1
                json_str = analysis_text[json_start:json_end].strip()

            # Structured data for the tabbed interface.
            return json.loads(json_str)

        except (json.JSONDecodeError, ValueError) as e:
            # Fallback: same structure, carrying the raw text for debugging.
            fallback = _error_analysis(
                "Error parsing AI response. Please try again.",
                f"JSON parsing error: {e}\n\nRaw response:\n{analysis_text}",
            )
            fallback["progression_assessment"] = "Unable to assess due to parsing error."
            fallback["limitations"] = "Response parsing failed. Please regenerate analysis."
            return fallback

    except Exception as e:
        # API/transport failure: return the uniform error structure
        # (previously this returned an HTML string, which crashed the
        # dict-formatting code downstream).
        return _error_analysis(
            f"Failed to generate analysis: {e}",
            f"Gemini API error: {e}",
        )


def predict_xray_quick(image):
    """Quick analysis: generate the attention map only (no DL predictions).

    Args:
        image: PIL Image or numpy array.
    Returns:
        Tuple of (loading/status HTML, attention overlay image,
        raw image for Gemini, attention image for Gemini).
    """
    if image is None:
        return "<div style='color:#e74c3c;'>No image provided</div>", None, None, None

    try:
        # Keep the original image for Gemini.
        if isinstance(image, Image.Image):
            raw_image = image.copy()
        else:
            raw_image = Image.fromarray(image)

        # Attention visualization ONLY (no predictions).
        attention_img = get_attention_visualization(image)

        # Loading message shown in the summary tab while analysis runs.
        loading_html = """
        <div style='text-align:center; padding:40px;'>
            <h3>&#128300; Analyzing X-ray...</h3>
            <p>RadFusion is performing comprehensive visual diagnosis</p>
        </div>
        """
        return loading_html, attention_img, raw_image, attention_img

    except Exception as e:
        error_msg = f"<div style='color:#e74c3c;'><b>Error:</b> {e}</div>"
        return error_msg, None, None, None


def generate_full_analysis(raw_image, attention_image):
    """Generate the Gemini diagnosis and format it for the tabbed display.

    Args:
        raw_image: original X-ray image (PIL Image).
        attention_image: attention map visualization (numpy array).
    Returns:
        Tuple of HTML formatted sections for each tab:
        (summary, thought_process, findings, progression,
         further_tests, differential, recommendations).
    """
    if raw_image is None:
        error_msg = "<div style='color:#e74c3c;'>No image available for analysis</div>"
        return (error_msg,) * 7

    try:
        analysis_data = generate_gemini_analysis(raw_image, attention_image)

        # --- Patient Summary (simple language) ---
        summary_text = analysis_data.get('patient_summary', 'No summary available')
        summary_html = f"""
        <div style='padding:16px;'>
            <h3>&#128203; Patient Summary</h3>
            <p>{summary_text.replace(chr(10), '<br>')}</p>
            <div style='background:#eaf4fb; padding:10px; border-radius:6px; margin-top:12px;'>
                <b>&#8505;&#65039; For Patients</b><br>
                This section is written in simple language to help you understand the findings.
            </div>
        </div>
        """

        # --- Clinical Thought Process (technical language) ---
        thought_process_html = f"""
        <div style='padding:16px;'>
            <h3>&#128300; Clinical Thought Process</h3>
            <p>{analysis_data.get('clinical_thought_process', 'No clinical analysis available')}</p>
            <div style='background:#eaf4fb; padding:10px; border-radius:6px; margin-top:12px;'>
                <b>&#129658; For Healthcare Professionals</b><br>
                Comprehensive technical analysis with medical terminology.
            </div>
        </div>
        """

        # --- Findings ---
        findings_html = "<div style='padding:16px;'>"
        findings_html += "<h3>&#128269; Detailed Findings</h3>"
        findings_list = analysis_data.get('findings', [])
        if findings_list:
            for idx, finding in enumerate(findings_list, 1):
                severity = finding.get('severity', 'unknown')
                severity_color = {'mild': '#27ae60', 'moderate': '#f39c12',
                                  'severe': '#e74c3c'}.get(severity.lower(), '#95a5a6')
                confidence = finding.get('confidence', 'unknown')
                findings_html += f"""
                <div style='border-left:4px solid {severity_color}; padding:10px; margin:10px 0;'>
                    <b>{idx}. {finding.get('finding', 'Unknown finding')}</b>
                    <span style='color:{severity_color}; font-weight:bold;'> {severity.upper()}</span><br>
                    <b>Location:</b> {finding.get('location', 'Not specified')}<br>
                    {finding.get('description', 'No description provided')}<br>
                    <i>Confidence: {confidence}</i>
                </div>
                """
        else:
            findings_html += "<p>No specific findings documented.</p>"
        findings_html += "</div>"

        # --- Progression Assessment ---
        progression_html = f"""
        <div style='padding:16px;'>
            <h3>&#128200; Progression Assessment</h3>
            <p>{analysis_data.get('progression_assessment', 'No progression assessment available')}</p>
        </div>
        """

        # --- Further Tests ---
        further_tests_html = "<div style='padding:16px;'>"
        further_tests_html += "<h3>&#129514; Further Tests Recommended</h3>"
        tests_list = analysis_data.get('further_tests', [])
        if tests_list:
            for idx, test in enumerate(tests_list, 1):
                urgency = test.get('urgency', 'routine')
                urgency_color = {'immediate': '#e74c3c', 'urgent': '#f39c12',
                                 'routine': '#3498db', 'optional': '#95a5a6'}.get(urgency.lower(), '#95a5a6')
                further_tests_html += f"""
                <div style='border-left:4px solid {urgency_color}; padding:10px; margin:10px 0;'>
                    <b>{idx}. {test.get('test', 'Unknown test')}</b>
                    <span style='color:{urgency_color}; font-weight:bold;'> {urgency.upper()}</span><br>
                    <b>Rationale:</b> {test.get('rationale', 'Not specified')}<br>
                    <b>Expected Findings:</b> {test.get('expected_findings', 'Not specified')}
                </div>
                """
        else:
            further_tests_html += "<p>No additional tests recommended at this time.</p>"
        further_tests_html += "</div>"

        # --- Differential Diagnosis ---
        differential_html = "<div style='padding:16px;'>"
        differential_html += "<h3>&#127919; Differential Diagnosis</h3>"
        diff_list = analysis_data.get('differential_diagnosis', [])
        if diff_list:
            for idx, diagnosis in enumerate(diff_list, 1):
                probability = diagnosis.get('probability', 'unknown')
                prob_color = {'high': '#e74c3c', 'moderate': '#f39c12',
                              'low': '#27ae60'}.get(probability.lower(), '#95a5a6')
                differential_html += f"""
                <div style='border-left:4px solid {prob_color}; padding:10px; margin:10px 0;'>
                    <b>{idx}. {diagnosis.get('diagnosis', 'Unknown diagnosis')}</b>
                    <span style='color:{prob_color}; font-weight:bold;'> {probability.upper()} PROBABILITY</span><br>
                    &#10003; <b>Supporting Evidence:</b> {diagnosis.get('supporting_evidence', 'Not specified')}<br>
                    &#10007; <b>Against:</b> {diagnosis.get('against_evidence', 'Not specified')}
                </div>
                """
        else:
            differential_html += "<p>No differential diagnoses documented.</p>"
        differential_html += "</div>"

        # --- Recommendations & Additional Info ---
        recommendations_html = "<div style='padding:16px;'>"
        recommendations_html += "<h3>&#128161; Clinical Recommendations</h3>"
        recommendations_list = analysis_data.get('clinical_recommendations', [])
        if recommendations_list:
            recommendations_html += "<ul>"
            for rec in recommendations_list:
                recommendations_html += f"<li>{rec}</li>"
            recommendations_html += "</ul>"
        else:
            recommendations_html += "<p>No specific recommendations at this time.</p>"

        # Attention map assessment from the model.
        recommendations_html += f"""
        <h4>&#128506;&#65039; AI Attention Map Analysis</h4>
        <p>{analysis_data.get('attention_map_analysis', 'Not available')}</p>
        """

        # Limitations of the current study.
        recommendations_html += f"""
        <h4>&#9888;&#65039; Limitations</h4>
        <p>{analysis_data.get('limitations', 'Not specified')}</p>
        """

        # Timestamp for the generated report.
        recommendations_html += f"""
        <p style='color:#7f8c8d; font-size:0.9em;'>
            &#128197; Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
        </p>
        """
        recommendations_html += "</div>"

        return (summary_html, thought_process_html, findings_html, progression_html,
                further_tests_html, differential_html, recommendations_html)

    except Exception as e:
        error_msg = f"<div style='color:#e74c3c;'><b>Error:</b> {e}</div>"
        return (error_msg,) * 7


# ------------------------------------------------------------
# Gradio Interface
# ------------------------------------------------------------
def create_interface():
    """Create the Gradio interface with progressive (two-stage) display."""

    def placeholder(icon, title, subtitle):
        # Shared placeholder HTML for the empty result tabs.
        return f"""
        <div style='text-align:center; padding:40px; color:#7f8c8d;'>
            <div style='font-size:2em;'>{icon}</div>
            <h4>{title}</h4>
            <p>{subtitle}</p>
        </div>
        """

    with gr.Blocks(title="Chest X-ray Analysis for Doctors") as demo:
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                gr.Markdown("## \U0001F4E4 Upload X-ray Image")
                image_input = gr.Image(
                    label="Upload Chest X-ray",
                    type="pil",
                    height=400,
                )
                analyze_btn = gr.Button(
                    "\U0001F52C Analyze X-ray",
                    variant="primary",
                    size="lg",
                )

            with gr.Column(scale=1):
                gr.Markdown("## \U0001F916 RadFusion Diagnosis")
                with gr.Tabs():
                    with gr.Tab("\U0001F4CB Summary"):
                        summary_output = gr.HTML(
                            value=placeholder("\U0001F4CB", "Patient Summary",
                                              "Simple language explanation for patients")
                        )
                    with gr.Tab("\U0001F52C Clinical Analysis"):
                        thought_process_output = gr.HTML(
                            value=placeholder("\U0001F52C", "Clinical Thought Process",
                                              "Comprehensive technical analysis for physicians")
                        )
                    with gr.Tab("\U0001F50D Findings"):
                        findings_output = gr.HTML(
                            value=placeholder("\U0001F50D", "Detailed Findings",
                                              "Structured list of all findings")
                        )
                    with gr.Tab("\U0001F4C8 Progression"):
                        progression_output = gr.HTML(
                            value=placeholder("\U0001F4C8", "Progression Assessment",
                                              "Disease stage and prognosis")
                        )
                    with gr.Tab("\U0001F9EA Further Tests"):
                        tests_output = gr.HTML(
                            value=placeholder("\U0001F9EA", "Recommended Tests",
                                              "Additional studies needed")
                        )
                    with gr.Tab("\U0001F3AF Differential Dx"):
                        differential_output = gr.HTML(
                            value=placeholder("\U0001F3AF", "Differential Diagnosis",
                                              "Possible diagnoses ranked")
                        )
                    with gr.Tab("\U0001F4A1 Recommendations"):
                        recommendations_output = gr.HTML(
                            value=placeholder("\U0001F4A1", "Clinical Recommendations",
                                              "Action items and guidance")
                        )

                attention_output = gr.Image(
                    label="Visual Focus Areas",
                    height=400,
                )

        # Hidden states to carry data to the Gemini step (no predictions needed).
        raw_image_state = gr.State(value=None)
        attention_image_state = gr.State(value=None)

        # Stage 1: generate the attention map (fast) and show a loading banner.
        quick_analysis = analyze_btn.click(
            fn=predict_xray_quick,
            inputs=image_input,
            outputs=[summary_output, attention_output,
                     raw_image_state, attention_image_state],
        )

        # Stage 2: run the Gemini diagnosis and fill all tabs.
        quick_analysis.then(
            fn=generate_full_analysis,
            inputs=[raw_image_state, attention_image_state],
            outputs=[summary_output, thought_process_output, findings_output,
                     progression_output, tests_output, differential_output,
                     recommendations_output],
        )

    return demo


# ------------------------------------------------------------
# Main
# ------------------------------------------------------------
if __name__ == "__main__":
    print("=" * 80)
    print("Starting Chest X-ray Analysis System")
    print("=" * 80)

    # Preload the attention model (Bio-ViL not needed for predictions).
    print("\nLoading attention visualization model...")
    print("Note: Bio-ViL predictions disabled - using LLM-primary architecture")
    try:
        load_attention_model()
        print("Attention model loaded successfully!")
    except Exception as e:
        print(f"Warning: Could not load attention model: {e}")

    print("\n" + "=" * 80)
    print("Starting Gradio interface...")
    print("=" * 80)

    # Create and launch the interface.
    demo = create_interface()
    demo.launch()