"""
BrainGemma3D + LIME Interpretability
==================================================
Usage:
    python braingemma3d_interpretability.py \\
        --model_dir ./final_model \\
        --mri_path /path/to/scan.nii.gz \\
        --report "The brain shows a mass in the left frontal lobe..." \\
        --output_dir ./lime_output

If --report is not provided, the script will generate it first.
"""

import argparse
import importlib.util
import json
import os
import random
import sys
from pathlib import Path

import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")  # headless backend: figures are only written to files
import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)

from lime import lime_image  # noqa: E402
from scipy.ndimage import binary_closing, binary_erosion, binary_fill_holes, binary_opening  # noqa: E402
from skimage.measure import label as cc_label  # noqa: E402
from skimage.morphology import ball, remove_small_objects  # noqa: E402
from skimage.segmentation import slic  # noqa: E402


def set_seed(seed: int = 42):
    """Seed every random number generator this script relies on.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global RNG, PyTorch's CPU generator and all CUDA generators.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


def import_architecture_from_model_dir(model_dir):
    """Dynamically import ``braingemma3d_architecture.py`` from a model folder.

    The module is loaded from ``<model_dir>/braingemma3d_architecture.py`` and
    registered in ``sys.modules`` under ``"braingemma3d_architecture"`` before
    it is executed. This follows the importlib recipe for importing a source
    file directly. Registration matters for code inside the module that looks
    itself up by module name (e.g. ``dataclasses``, ``typing.get_type_hints``,
    pickling).

    Args:
        model_dir: Folder that contains ``braingemma3d_architecture.py``.

    Returns:
        The executed module object.

    Raises:
        FileNotFoundError: If the architecture file does not exist.
        ImportError: If no import spec/loader can be created for the file.
    """
    module_name = "braingemma3d_architecture"
    arch_path = os.path.join(model_dir, f"{module_name}.py")
    if not os.path.isfile(arch_path):
        raise FileNotFoundError(f"Architecture file not found: {arch_path}")

    spec = importlib.util.spec_from_file_location(module_name, arch_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {arch_path}")

    module = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(module_name)
    # Register before executing so lookups by module name work during import.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except Exception:
        # Do not leave a half-initialised module registered.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module


def load_full_model(model_dir, device):
    """Build BrainGemma3D from ``model_dir`` and restore the trained projector.

    Steps:
      1. import ``braingemma3d_architecture.py`` from the folder;
      2. read ``model_config.json`` (sub-model folders, ``depth``,
         ``num_vision_tokens``);
      3. build the model with both backbones frozen;
      4. load ``projector_vis_scale.pt``: the ``vision_projector`` state dict and
         an optional ``vis_scale`` (tensor, plain number, or None);
      5. switch the model to eval mode.

    Args:
        model_dir: Folder holding the architecture file, config and checkpoint.
        device: ``"cuda"`` passes ``device_map={"": 0}`` to the constructor;
            any other value passes ``device_map=None``. Also used as
            ``map_location`` when loading the projector checkpoint.

    Returns:
        ``(model, load_nifti_volume, CANONICAL_PROMPT)``. The last two come from
        the dynamically imported architecture module.
    """
    arch = import_architecture_from_model_dir(model_dir)
    model_cls = arch.BrainGemma3D
    load_nifti_volume = arch.load_nifti_volume
    canonical_prompt = arch.CANONICAL_PROMPT

    with open(os.path.join(model_dir, "model_config.json")) as cfg_file:
        cfg = json.load(cfg_file)

    model = model_cls(
        vision_model_dir=os.path.join(model_dir, cfg["vision_model_dir"]),
        language_model_dir=os.path.join(model_dir, cfg["language_model_dir"]),
        depth=cfg["depth"],
        num_vision_tokens=cfg["num_vision_tokens"],
        freeze_vision=True,
        freeze_language=True,
        device_map={"": 0} if device == "cuda" else None,
    )

    checkpoint = torch.load(os.path.join(model_dir, "projector_vis_scale.pt"), map_location=device)
    model.vision_projector.load_state_dict(checkpoint["vision_projector"])

    vis_scale = checkpoint["vis_scale"]
    if isinstance(vis_scale, torch.Tensor):
        model.vis_scale.data = vis_scale.to(device)
    elif vis_scale is not None:
        # Plain Python number: write it into the existing parameter in place.
        model.vis_scale.data.fill_(vis_scale)

    model.eval()
    return model, load_nifti_volume, canonical_prompt


@torch.no_grad()
def lime_score_report_nll(volumes, model, prompt: str, report_ref: str, batch_size: int = 1):
    """Score perturbed volumes by how well they support a reference report.

    Each volume is fed to the model as ``[vision tokens | prompt | report]``.
    The function then computes the mean negative log-likelihood (NLL) of the
    report tokens. Lower NLL means the model is more confident in the reference
    report, i.e. the (perturbed) volume supports it better. LIME maximizes the
    score, so ``-NLL`` is returned.

    The NLL is computed per sample from the logits, using shifted next-token
    cross-entropy over the report positions only. Every volume therefore gets
    its own score whatever ``batch_size`` is. The language model's built-in
    ``loss`` is averaged over the whole batch, so with ``batch_size > 1`` it
    would give one score per batch instead of one per volume. For
    ``batch_size == 1`` both computations coincide.

    Args:
        volumes: Array-like of volumes, ``(N, D, H, W)`` or ``(N, 1, D, H, W)``;
            a channel axis is inserted for 4-D input. Cast to bfloat16 per batch.
        model: Object exposing ``lm_device``, ``tokenizer``, ``encode_volume``
            and ``language_model`` (whose output has ``.logits``).
        prompt: Prompt text, tokenized with special tokens.
        report_ref: Reference report, tokenized without special tokens.
        batch_size: Number of volumes per forward pass (>= 1).

    Returns:
        float32 ``np.ndarray`` of shape ``(N, 1)`` holding ``-NLL`` per volume.

    Raises:
        ValueError: If ``batch_size < 1`` or ``report_ref`` tokenizes to no
            tokens (its NLL would be undefined).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    device = model.lm_device

    prompt_ids = model.tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(device)
    report_ids = model.tokenizer(report_ref, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
    n_report = report_ids.size(1)
    if n_report == 0:
        raise ValueError("report_ref produced no tokens; cannot compute its NLL.")

    # The text part does not depend on the volume, so embed it once.
    text_ids = torch.cat([prompt_ids, report_ids], dim=1)
    text_embeds = model.language_model.get_input_embeddings()(text_ids)

    vols = torch.from_numpy(np.asarray(volumes)).to(device)
    if vols.ndim == 4:
        vols = vols.unsqueeze(1)

    scores = []
    for start in range(0, vols.shape[0], batch_size):
        v = vols[start:start + batch_size].to(dtype=torch.bfloat16)
        b = v.size(0)

        vision_tokens = model.encode_volume(v)
        inputs_embeds = torch.cat([vision_tokens, text_embeds.expand(b, -1, -1)], dim=1)

        logits = model.language_model(inputs_embeds=inputs_embeds).logits
        # Position t predicts token t + 1: the last n_report tokens (the report)
        # are predicted by the n_report positions immediately before them.
        report_logits = logits[:, -(n_report + 1):-1, :].float()
        token_nll = torch.nn.functional.cross_entropy(
            report_logits.reshape(-1, report_logits.size(-1)),
            report_ids.expand(b, -1).reshape(-1),
            reduction="none",
        ).view(b, n_report)
        scores.append((-token_nll.mean(dim=1)).cpu())

    if not scores:
        return np.zeros((0, 1), dtype=np.float32)
    return torch.cat(scores).numpy().reshape(-1, 1)


def quick_brain_mask(
    vol_zyx: np.ndarray,
    p_thresh: float = 25,
    min_cc_vox: int = 2000
) -> np.ndarray:
    """Estimate a rough brain mask for a 3D volume via thresholding and morphology.

    Pipeline:
      1. keep voxels strictly above the ``p_thresh``-th intensity percentile;
      2. binary opening with ``ball(1)``, then closing with ``ball(2)``;
      3. fill interior holes;
      4. drop connected components smaller than ``min_cc_vox`` voxels;
      5. if several components remain, keep only the largest.

    Returns:
        Boolean array with the same shape as ``vol_zyx``.
    """
    volume = vol_zyx.astype(np.float32)
    mask = volume > np.percentile(volume, p_thresh)
    mask = binary_opening(mask, structure=ball(1))
    mask = binary_closing(mask, structure=ball(2))
    mask = binary_fill_holes(mask)
    mask = remove_small_objects(mask, min_size=min_cc_vox)

    components = cc_label(mask)
    if components.max() > 1:
        # Voxel count per label; background (label 0) is excluded from the race.
        # argmax returns the lowest label among ties.
        counts = np.bincount(components.ravel())
        counts[0] = 0
        mask = components == int(np.argmax(counts))

    return mask.astype(bool)


def big_supervoxels_brain_only(
    vol_zyx: np.ndarray,
    n_segments: int = 20,
    compactness: float = 0.05,
    sigma: float = 1.0,
    p_thresh: float = 25,
    min_cc_vox: int = 2000,
):
    """Segment ONLY brain tissue into SLIC supervoxels using a brain mask.

    Returns segments with 0-based contiguous labels:
      - 0 = background (not brain)
      - 1, 2, ..., N = brain supervoxels

    This labeling is CRITICAL for LIME 0.2.0.1, which uses feature indices
    directly as segment labels: ``mask[segments == feature_idx]``. With 0-based
    contiguous labels, feature i maps exactly to segment i. Background (0) adds
    one harmless noise feature to LIME's regression.

    NOTE(review): if the brain mask covers every voxel there is no label 0
    before relabeling. The lowest brain supervoxel is then renumbered to 0, and
    downstream code treats it as background.

    Args:
        vol_zyx: 3D volume (Z, Y, X).
        n_segments: Approximate number of SLIC supervoxels requested.
        compactness, sigma: Passed through to ``skimage.segmentation.slic``.
        p_thresh, min_cc_vox: Passed through to ``quick_brain_mask``.

    Returns:
        ``(seg, brain)``: integer label volume and the boolean brain mask.
    """
    brain = quick_brain_mask(vol_zyx, p_thresh=p_thresh, min_cc_vox=min_cc_vox)

    seg = slic(
        vol_zyx,
        n_segments=n_segments,
        compactness=compactness,
        sigma=sigma,
        channel_axis=None,
        start_label=1,
        mask=brain,
    )

    # Defensive: fold any negative labels (presumably masked-out voxels in some
    # scikit-image versions) into background 0.
    seg[seg < 0] = 0

    labels = np.unique(seg)
    if not np.array_equal(labels, np.arange(len(labels))):
        # np.unique is sorted, so searchsorted maps every old label to its rank
        # in one vectorized pass (instead of one full-volume scan per label).
        seg = np.searchsorted(labels, seg).astype(seg.dtype, copy=False)
        print(f"ℹ️ Relabeled segments to contiguous 0..{len(labels)-1}", flush=True)

    n_brain_segs = len(np.unique(seg)) - 1
    print(f"🧩 Brain-only SLIC: {n_brain_segs} brain supervoxels "
          f"(requested {n_segments}), brain covers {100*brain.sum()/brain.size:.1f}% of volume",
          flush=True)

    return seg, brain


def make_segmentation_fn(cached_segments: np.ndarray):
    """Wrap pre-computed segments in the callable form LIME expects.

    The returned function ignores its argument and always hands back the very
    same ``cached_segments`` object. LIME therefore works on one fixed
    supervoxel map.
    """
    return lambda vol: cached_segments


def save_slice_png(volume_zyx: np.ndarray, out_path: str, axis: int = 0, idx: int | None = None, rot_k: int = 0):
    """Save one grayscale slice of a (Z, Y, X) volume to a PNG file.

    Args:
        volume_zyx: 3D array indexed as (Z, Y, X).
        out_path: Destination PNG path; missing parent folders are created.
        axis: 0 = axial (Z), 1 = coronal (Y), any other value = sagittal (X).
        idx: Slice index along ``axis``; the middle slice when omitted.
        rot_k: Number of 90-degree ``np.rot90`` turns applied before plotting.
    """
    if idx is None:
        idx = volume_zyx.shape[axis] // 2

    if axis == 0:
        plane, view_name = volume_zyx[idx, :, :], "Axial (Z)"
    elif axis == 1:
        plane, view_name = volume_zyx[:, idx, :], "Coronal (Y)"
    else:
        plane, view_name = volume_zyx[:, :, idx], "Sagittal (X)"
    plane = np.rot90(plane, k=rot_k)

    plt.figure(figsize=(6, 6))
    plt.imshow(plane, cmap="gray", origin="lower")
    plt.title(f"{view_name} slice {idx}")
    plt.axis("off")
    parent_dir = os.path.dirname(out_path) or "."
    os.makedirs(parent_dir, exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()


def save_overlay_png(
    volume_zyx: np.ndarray,
    heat_zyx: np.ndarray,
    out_path: str,
    axis: int = 0,
    idx: int | None = None,
    alpha: float = 0.45,
    clip_q: float = 0.99,
    rot_k: int = 0,
):
    """Save one slice with a signed heat map overlaid in the ``bwr`` colormap.

    The colour limit is symmetric, +/- the ``clip_q`` quantile of ``|heat|`` on
    this slice, floored at 1e-8. A colorbar is added.

    Args:
        volume_zyx: Background volume (Z, Y, X).
        heat_zyx: Heat map with exactly the same shape as ``volume_zyx``.
        out_path: Destination PNG; missing parent folders are created.
        axis: 0 = axial (Z), 1 = coronal (Y), any other value = sagittal (X).
        idx: Slice index along ``axis``; the middle slice when omitted.
        alpha: Overlay opacity.
        clip_q: Quantile of ``|heat|`` used as the colour limit.
        rot_k: Number of 90-degree rotations applied to both slices.
    """
    assert volume_zyx.shape == heat_zyx.shape

    if idx is None:
        idx = volume_zyx.shape[axis] // 2

    if axis == 0:
        pick, view_name = (lambda arr: arr[idx, :, :]), "Axial (Z)"
    elif axis == 1:
        pick, view_name = (lambda arr: arr[:, idx, :]), "Coronal (Y)"
    else:
        pick, view_name = (lambda arr: arr[:, :, idx]), "Sagittal (X)"

    background = np.rot90(pick(volume_zyx), k=rot_k)
    heat = np.rot90(pick(heat_zyx), k=rot_k)

    limit = float(max(np.quantile(np.abs(heat), clip_q), 1e-8))
    clipped = np.clip(heat, -limit, limit)

    plt.figure(figsize=(6, 6))
    plt.imshow(background, cmap="gray", origin="lower")
    layer = plt.imshow(clipped, cmap="bwr", alpha=alpha, origin="lower", vmin=-limit, vmax=limit)
    plt.title(f"{view_name} overlay slice {idx}")
    plt.axis("off")
    plt.colorbar(layer, fraction=0.046, pad=0.04)
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()


def save_overlay_grid_png(
    volume_zyx: np.ndarray,
    heat_zyx: np.ndarray,
    out_path: str,
    axis: int = 0,
    idxs=None,
    n_cols: int = 6,
    n_slices: int = 36,
    alpha: float = 0.45,
    clip_q: float = 0.99,
    rot_k: int = 0,
    figsize_per_cell: float = 2.2,
    add_colorbar: bool = False,
    suptitle: str | None = None,
):
    """Save a grid of slices, each with a signed heat map overlaid (``bwr``).

    All panels share one symmetric colour limit: +/- the ``clip_q`` quantile
    of ``|heat|`` over the whole volume (floored at 1e-8).

    Args:
        volume_zyx: Background volume (Z, Y, X).
        heat_zyx: Heat map with the same shape as ``volume_zyx``.
        out_path: Destination image; missing parent folders are created.
        axis: Slicing axis, 0 (Z), 1 (Y) or 2 (X).
        idxs: Explicit slice indices. When None, ``n_slices`` evenly spaced
            indices between 10% and 90% of the axis are used.
        n_cols: Panels per row.
        n_slices: Number of automatic slices when ``idxs`` is None.
        alpha: Overlay opacity.
        clip_q: Quantile of ``|heat|`` used as the colour limit.
        rot_k: Number of 90-degree rotations applied to every slice.
        figsize_per_cell: Inches per panel side.
        add_colorbar: Add one shared colorbar.
        suptitle: Figure title; a descriptive default is built when None.

    Raises:
        AssertionError: If the shapes differ or ``axis`` is not 0, 1 or 2.
        ValueError: If there are no slices to plot.
    """
    assert volume_zyx.shape == heat_zyx.shape
    assert axis in (0, 1, 2)

    dim = volume_zyx.shape[axis]
    if idxs is None:
        lo = int(0.10 * (dim - 1))
        hi = int(0.90 * (dim - 1))
        if hi <= lo:
            lo, hi = 0, dim - 1
        idxs = np.linspace(lo, hi, n_slices, dtype=int).tolist()
    else:
        idxs = list(map(int, idxs))

    n = len(idxs)
    if n == 0:
        # plt.subplots(0, ...) would fail with an unhelpful message.
        raise ValueError("save_overlay_grid_png: no slices to plot (empty idxs or n_slices <= 0).")
    n_rows = int(np.ceil(n / n_cols))

    # One colour range for the whole grid keeps panels comparable.
    m = float(max(np.quantile(np.abs(heat_zyx), clip_q), 1e-8))

    fig_w = n_cols * figsize_per_cell
    fig_h = n_rows * figsize_per_cell
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_w, fig_h))
    axes = np.array(axes).reshape(-1)

    def get_slice(arr, ax, i):
        if ax == 0:
            s = arr[i, :, :]
        elif ax == 1:
            s = arr[:, i, :]
        else:
            s = arr[:, :, i]
        return np.rot90(s, k=rot_k)

    im_for_cbar = None
    for j, idx in enumerate(idxs):
        axp = axes[j]
        img = get_slice(volume_zyx, axis, idx)
        h_vis = np.clip(get_slice(heat_zyx, axis, idx), -m, m)

        axp.imshow(img, cmap="gray", origin="lower")
        im_for_cbar = axp.imshow(h_vis, cmap="bwr", alpha=alpha, origin="lower", vmin=-m, vmax=m)
        axp.set_title(f"{idx}", fontsize=9)
        axp.axis("off")

    for k in range(n, n_rows * n_cols):
        axes[k].axis("off")

    if suptitle is None:
        name = "Axial (Z)" if axis == 0 else ("Coronal (Y)" if axis == 1 else "Sagittal (X)")
        suptitle = f"{name} | rot {rot_k*90}° | clip_q={clip_q} | alpha={alpha}"
    fig.suptitle(suptitle, y=0.98, fontsize=12)

    if add_colorbar and im_for_cbar is not None:
        cbar = fig.colorbar(im_for_cbar, ax=axes[:n], fraction=0.02, pad=0.01)
        cbar.set_label("LIME weight (clipped)", rotation=90)

    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    fig.tight_layout()
    fig.savefig(out_path, dpi=160)
    plt.close(fig)


def create_overlay_from_segments(segments_2d: np.ndarray, weights: dict, alpha=0.5) -> np.ndarray:
    """Paint a 2D RGBA overlay that colours every supervoxel by its LIME weight.

    Colour coding: red (1, 0, 0) for positive, supportive weights; blue
    (0, 0.4, 1) for non-positive, contradictory weights. A weight of exactly 0
    gets blue with opacity 0, i.e. it is invisible.

    Opacity is ``alpha * |w| / max|w|``, where the maximum is taken over
    non-background ids. Supervoxel outline pixels get twice that opacity,
    capped at 1. Segment id 0 (background) is never painted.

    Args:
        segments_2d: (H, W) integer label slice.
        weights: Mapping segment id (int or int-like) -> weight.
        alpha: Base opacity for the strongest weight.

    Returns:
        float32 array of shape (H, W, 4). It stays all zeros when there are no
        brain weights or all of them are ~0 (max |w| < 1e-8).
    """
    height, width = segments_2d.shape
    overlay = np.zeros((height, width, 4), dtype=np.float32)

    brain_weights = [float(w) for key, w in weights.items() if int(key) != 0]
    if not brain_weights:
        return overlay
    scale = max(abs(w) for w in brain_weights)
    if scale < 1e-8:
        return overlay

    supportive_rgb = (1.0, 0.0, 0.0)
    contradictory_rgb = (0.0, 0.4, 1.0)
    for key, weight in weights.items():
        seg_id = int(key)
        if seg_id == 0:
            continue
        region = segments_2d == seg_id
        if not region.any():
            continue

        strength = abs(weight / scale)
        outline = region & (~binary_erosion(region))
        overlay[region, :3] = supportive_rgb if weight > 0 else contradictory_rgb
        overlay[region, 3] = alpha * strength
        overlay[outline, 3] = min(1.0, alpha * strength * 2.0)

    return overlay


def get_top_positive_supervoxel_id(weights: dict, ignore_ids=(0,)) -> int:
    """Return the id of the most supportive ("most red") supervoxel.

    Ids in ``ignore_ids`` (background 0 by default) are skipped. Among the
    rest, the highest strictly positive weight wins. If no weight is positive,
    the highest weight overall is returned. Ties go to the first id in
    ``weights`` order.

    Raises:
        ValueError: If ``weights`` is empty or contains only ignored ids.
    """
    candidates = [
        (int(seg_id), float(w)) for seg_id, w in weights.items() if int(seg_id) not in ignore_ids
    ]
    if not candidates:
        raise ValueError("weights vuoto o contiene solo segmenti ignorati.")

    positives = [cand for cand in candidates if cand[1] > 0]
    best_id, _ = max(positives or candidates, key=lambda cand: cand[1])
    return best_id


def get_top_negative_supervoxel_id(weights: dict, ignore_ids=(0,)) -> int:
    """Return the id of the segment with the most negative weight ("most blue").

    Ids in ``ignore_ids`` (background 0 by default) are skipped. If no negative
    weight exists, the minimum weight is still returned, even if positive.
    Ties go to the first id in ``weights`` order.

    Raises:
        ValueError: If ``weights`` is empty or contains only ignored ids.
    """
    candidates = [
        (int(seg_id), float(w)) for seg_id, w in weights.items() if int(seg_id) not in ignore_ids
    ]
    if not candidates:
        raise ValueError("weights vuoto o contiene solo segmenti ignorati.")

    negatives = [cand for cand in candidates if cand[1] < 0]
    worst_id, _ = min(negatives or candidates, key=lambda cand: cand[1])
    return worst_id


def _rgba_overlay_from_mask(mask2d: np.ndarray, rgba=(1.0, 0.0, 0.0), alpha=0.45) -> np.ndarray:
    """Build a flat-colour RGBA layer whose opacity follows ``mask2d``.

    Args:
        mask2d: (H, W) bool/float array, 1 where the colour should be drawn.
            Fractional float values scale the opacity proportionally.
        rgba: Despite its name, an (R, G, B) triple in [0, 1]; only entries
            0-2 are read.
        alpha: Opacity where ``mask2d`` equals 1.

    Returns:
        float32 array of shape (H, W, 4).
    """
    coverage = mask2d.astype(np.float32)
    layer = np.zeros((coverage.shape[0], coverage.shape[1], 4), dtype=np.float32)
    layer[..., :3] = (float(rgba[0]), float(rgba[1]), float(rgba[2]))
    layer[..., 3] = float(alpha) * coverage
    return layer


def _rgba_edge_from_mask(mask2d: np.ndarray, rgba=(1.0, 0.0, 0.0), edge_alpha=1.0) -> np.ndarray:
    """Build an RGBA layer that draws only the outline of ``mask2d``.

    Outline pixels are mask pixels removed by one binary erosion (SciPy
    default structuring element). They get opacity ``edge_alpha``; every
    other pixel is transparent.

    Args:
        mask2d: (H, W) array; nonzero entries form the region.
        rgba: Despite its name, an (R, G, B) triple in [0, 1].
        edge_alpha: Opacity of outline pixels.

    Returns:
        float32 array of shape (H, W, 4).
    """
    region = mask2d.astype(bool)
    outline = np.logical_and(region, np.logical_not(binary_erosion(region)))
    layer = np.zeros((region.shape[0], region.shape[1], 4), dtype=np.float32)
    layer[..., :3] = (float(rgba[0]), float(rgba[1]), float(rgba[2]))
    layer[..., 3] = float(edge_alpha) * outline.astype(np.float32)
    return layer


def save_overlay_single_supervoxel_png(
    volume_zyx: np.ndarray,
    segments_zyx: np.ndarray,
    weights: dict,
    out_path: str,
    axis: int = 0,
    idx: int | None = None,
    rot_k: int = 0,
    alpha: float = 0.45,
    origin: str = "lower",
    edge_alpha: float = 1.0,
):
    """Save one slice that highlights the two most extreme LIME supervoxels.

    Two supervoxels are drawn:
      - the most 'red' one (highest positive weight), as a bright red fill
        plus outline;
      - the most 'blue' one (most negative weight), as a bright blue fill
        plus outline.
    Blue is drawn first, so red stays on top where they overlap.

    Args:
        volume_zyx: Background volume (Z, Y, X).
        segments_zyx: Supervoxel labels with the same shape as the volume.
        weights: Mapping segment id -> LIME weight (background 0 is ignored).
        out_path: Destination PNG; missing parent folders are created.
        axis: 0 = axial (Z), 1 = coronal (Y), any other value = sagittal (X).
        idx: Slice index along ``axis``; the middle slice when omitted.
        rot_k: Number of 90-degree rotations applied to every layer.
        alpha: Fill opacity.
        origin: ``imshow`` origin for all layers.
        edge_alpha: Outline opacity.

    Returns:
        ``(best_red_id, best_blue_id)``.
    """
    best_red_id = get_top_positive_supervoxel_id(weights, ignore_ids=(0,))
    best_blue_id = get_top_negative_supervoxel_id(weights, ignore_ids=(0,))

    red_3d = (segments_zyx == best_red_id).astype(np.float32)
    blue_3d = (segments_zyx == best_blue_id).astype(np.float32)

    if idx is None:
        idx = volume_zyx.shape[axis] // 2

    # Any axis other than 0 or 1 is treated as the last (X) axis.
    slice_axis = axis if axis in (0, 1) else 2
    view_name = ("Axial(Z)", "Coronal(Y)", "Sagittal(X)")[slice_axis]
    img, m_red, m_blue = (
        np.rot90(np.take(arr, idx, axis=slice_axis), k=rot_k)
        for arr in (volume_zyx, red_3d, blue_3d)
    )
    title = f"{view_name} slice {idx} | red={best_red_id} | blue={best_blue_id}"

    blue_rgb = (0.0, 0.4, 1.0)
    red_rgb = (1.0, 0.0, 0.0)

    plt.figure(figsize=(6, 6))
    plt.imshow(img, cmap="gray", origin=origin)
    for layer_mask, rgb in ((m_blue, blue_rgb), (m_red, red_rgb)):
        plt.imshow(_rgba_overlay_from_mask(layer_mask, rgba=rgb, alpha=alpha), origin=origin)
        plt.imshow(_rgba_edge_from_mask(layer_mask, rgba=rgb, edge_alpha=edge_alpha), origin=origin)

    plt.title(title)
    plt.axis("off")

    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()

    return best_red_id, best_blue_id


def save_overlay_grid_single_supervoxel_png(
    volume_zyx: np.ndarray,
    segments_zyx: np.ndarray,
    weights: dict,
    out_path: str,
    axis: int = 0,
    n_cols: int = 8,
    rot_k: int = 0,
    alpha: float = 0.45,
    origin: str = "lower",
    suptitle: str | None = None,
    edge_alpha: float = 1.0,
):
    """Save a grid of ALL slices along ``axis`` highlighting two supervoxels.

    The layout matches ``save_flair_grid_all``: a black background and one
    panel per slice. On each panel:
      - the most 'red' supervoxel (highest positive LIME weight) is drawn in
        bright red;
      - the most 'blue' supervoxel (most negative LIME weight) is drawn in
        bright blue.
    Blue is drawn first, so red stays on top where they overlap.

    Args:
        volume_zyx: Background volume (Z, Y, X).
        segments_zyx: Supervoxel labels with the same shape as the volume.
        weights: Mapping segment id -> LIME weight (background 0 is ignored).
        out_path: Destination image; missing parent folders are created.
        axis: 0 = axial (Z), 1 = coronal (Y), any other value = sagittal (X).
        n_cols: Panels per row.
        rot_k: Number of 90-degree rotations applied to every panel.
        alpha: Fill opacity.
        origin: ``imshow`` origin for all layers.
        suptitle: Figure title; a descriptive default is built when None.
        edge_alpha: Outline opacity.

    Returns:
        ``(best_red_id, best_blue_id)``.
    """
    best_red_id = get_top_positive_supervoxel_id(weights, ignore_ids=(0,))
    best_blue_id = get_top_negative_supervoxel_id(weights, ignore_ids=(0,))

    mask_red_3d = (segments_zyx == best_red_id).astype(np.float32)
    mask_blue_3d = (segments_zyx == best_blue_id).astype(np.float32)

    dim = volume_zyx.shape[axis]
    n_rows = int(np.ceil(dim / n_cols))

    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(n_cols * 2, n_rows * 2),
        facecolor="black",
    )
    axes = np.array(axes).reshape(-1)

    def get_slice(arr, ax, i):
        if ax == 0:
            s = arr[i, :, :]
        elif ax == 1:
            s = arr[:, i, :]
        else:
            s = arr[:, :, i]
        return np.rot90(s, k=rot_k)

    # Label panels with the coordinate actually being sliced (was always "z").
    coord = "z" if axis == 0 else ("y" if axis == 1 else "x")
    blue_rgb = (0.0, 0.4, 1.0)
    red_rgb = (1.0, 0.0, 0.0)

    for i in range(dim):
        img = get_slice(volume_zyx, axis, i)
        m_red = get_slice(mask_red_3d, axis, i)
        m_blue = get_slice(mask_blue_3d, axis, i)

        axes[i].imshow(img, cmap="gray", origin=origin)
        for layer_mask, rgb in ((m_blue, blue_rgb), (m_red, red_rgb)):
            axes[i].imshow(_rgba_overlay_from_mask(layer_mask, rgba=rgb, alpha=alpha), origin=origin)
            axes[i].imshow(_rgba_edge_from_mask(layer_mask, rgba=rgb, edge_alpha=edge_alpha), origin=origin)

        axes[i].set_title(f"{coord}={i}", color="cyan", fontsize=9, fontweight="bold")
        axes[i].axis("off")

    # Hide any unused panels in the last row.
    for i in range(dim, len(axes)):
        axes[i].axis("off")

    if suptitle is None:
        name = "Axial(Z)" if axis == 0 else ("Coronal(Y)" if axis == 1 else "Sagittal(X)")
        suptitle = f"{name} | red={best_red_id} | blue={best_blue_id} | rot {rot_k*90}°"
    # White text: matplotlib's default black title is invisible on the black figure.
    fig.suptitle(suptitle, color="white")

    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    fig.tight_layout()
    fig.savefig(out_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)

    return best_red_id, best_blue_id


def save_volume_slices_overlay(
    vol: torch.Tensor,
    heat: np.ndarray,
    save_path: str,
    title: str = "Volume overlay",
    ncols: int = 8,
    is_healthy: bool = False,
    alpha: float = 0.45,
    clip_q: float = 0.99,
    rot_k: int = 0,
    brain_mask: np.ndarray | None = None,
):
    """Save every axial slice of a volume with a signed heat map overlaid.

    Panels are laid out ``ncols`` per row on a black background. Each panel
    gets a ``bwr`` overlay clipped to +/- the ``clip_q`` quantile of ``|heat|``
    over the whole volume (floored at 1e-8), so all panels share one scale.

    Args:
        vol: Volume shaped (D, H, W), (1, D, H, W) or (1, 1, D, H, W). Either a
            torch tensor (any floating dtype, including bfloat16) or a NumPy
            array.
        heat: Heat map with the same (D, H, W) shape as the squeezed volume.
        save_path: Output image path; missing parent folders are created.
        title: Figure title; suffixed with "(Healthy)" or "(Pathological)".
        ncols: Panels per row.
        is_healthy: Selects the title suffix.
        alpha: Overlay opacity.
        clip_q: Quantile of ``|heat|`` used as the symmetric colour limit.
        rot_k: Number of 90-degree rotations applied to every slice.
        brain_mask: Optional boolean (D, H, W) mask. Voxels outside it are
            masked out of both the image and the overlay.

    Raises:
        ValueError: If the volume, heat map or brain mask shapes disagree.
    """
    # Drop leading batch/channel axes: (1, 1, D, H, W) / (1, D, H, W) -> (D, H, W).
    if vol.ndim == 5:
        vol = vol[0, 0]
    elif vol.ndim == 4:
        vol = vol[0]

    if isinstance(vol, torch.Tensor):
        # .float() first: NumPy cannot represent bfloat16 tensors.
        vol_np = vol.detach().float().cpu().numpy()
    else:
        vol_np = np.asarray(vol, dtype=np.float32)
    heat_np = np.asarray(heat, dtype=np.float32)

    if vol_np.shape != heat_np.shape:
        raise ValueError(f"Shape mismatch: vol {vol_np.shape} vs heat {heat_np.shape}")

    if brain_mask is not None:
        if brain_mask.shape != vol_np.shape:
            raise ValueError(f"Brain mask shape mismatch: {brain_mask.shape} vs {vol_np.shape}")
        brain_np = brain_mask.astype(bool)
    else:
        brain_np = None

    D, H, W = vol_np.shape
    nrows = int(np.ceil(D / ncols))

    # One colour range for the whole volume keeps panels comparable.
    m = float(max(np.quantile(np.abs(heat_np), clip_q), 1e-8))

    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 2, nrows * 2), facecolor="black")
    # plt.subplots returns a bare Axes for a 1x1 grid; normalise to a flat array.
    axes = np.array(axes).reshape(-1)

    for i in range(D):
        img = np.rot90(vol_np[i], k=rot_k)
        h_vis = np.clip(np.rot90(heat_np[i], k=rot_k), -m, m)
        b = np.rot90(brain_np[i], k=rot_k) if brain_np is not None else None

        ax = axes[i]
        ax.set_facecolor("black")

        if b is not None:
            # Hide everything outside the brain (shows the black background).
            ax.imshow(np.ma.array(img, mask=~b), cmap="gray", origin="lower")
            ax.imshow(np.ma.array(h_vis, mask=~b), cmap="bwr", alpha=alpha, vmin=-m, vmax=m, origin="lower")
        else:
            ax.imshow(img, cmap="gray", origin="lower")
            ax.imshow(h_vis, cmap="bwr", alpha=alpha, vmin=-m, vmax=m, origin="lower")

        ax.set_title(f"z={i}", color="cyan", fontsize=9, fontweight="bold")
        ax.axis("off")

    for i in range(D, len(axes)):
        axes[i].set_facecolor("black")
        axes[i].axis("off")

    fig.suptitle(f"{title} {'(Healthy)' if is_healthy else '(Pathological)'}", color="white")

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)


def save_flair_grid_all(nifti_path: str, save_path: str, load_nifti_volume_fn, ncols: int = 8):
    """Save a black-background grid showing every axial slice of a NIfTI file.

    Args:
        nifti_path: Path to the .nii/.nii.gz scan.
        save_path: Output image path; missing parent folders are created.
        load_nifti_volume_fn: Loader returning a torch tensor. Up to two leading
            singleton (batch/channel) axes are squeezed away. Obtain it from
            ``import_architecture_from_model_dir``.
        ncols: Panels per grid row.
    """
    volume = load_nifti_volume_fn(nifti_path).squeeze(0).squeeze(0).detach().cpu().numpy()
    depth = volume.shape[0]
    nrows = int(np.ceil(depth / ncols))

    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 2, nrows * 2), facecolor="black")

    # Slice panels get an image and a title; surplus panels are only hidden.
    for z, panel in enumerate(axes.flatten()):
        if z < depth:
            panel.imshow(volume[z], cmap="gray", origin="lower")
            panel.set_title(f"z={z}", color="cyan", fontsize=9, fontweight="bold")
        panel.axis("off")

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def run_interpretability(
    model,
    load_nifti_volume,
    CANONICAL_PROMPT,
    mri_path: str,
    report: str,
    output_dir: str,
    lime_samples: int = 100,
    n_segments: int = 20,
    hide_color: float = 0.0,
    alpha: float = 0.45,
    clip_q: float = 0.99,
    seed: int = 42,
):
    """Run LIME interpretability on a single MRI scan.

    Pipeline:
      1. load the scan resized to (64, 128, 128);
      2. split the brain into about ``n_segments`` SLIC supervoxels
         (label 0 = background);
      3. LIME hides random subsets of supervoxels (filled with ``hide_color``)
         and scores each perturbed volume with ``-NLL`` of ``report``;
      4. the fitted per-supervoxel weights are painted into a weight volume,
         and figures plus raw outputs are written to ``output_dir``.

    Files written to ``output_dir``: overlay_slices.png,
    lime_top_supervoxels_grid.png, lime_2x3_grid.png, lime_report.txt,
    lime_weights.json, lime_wvol.npy, lime_segments.npy.

    Returns:
        ``(weights, wvol)``:
          - ``weights`` maps supervoxel id -> LIME weight and may include
            background id 0;
          - ``wvol`` is the float32 weight volume, zero outside the brain mask.
    """
    set_seed(seed)
    device = next(model.parameters()).device
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*60}")
    print("BrainGemma3D LIME Interpretability")
    print(f"{'='*60}")
    print(f"MRI: {mri_path}")
    print(f"Report: {report[:100]}...")
    print(f"Output: {output_dir}")
    print(f"{'='*60}\n")

    print("Loading MRI volume...")
    volume = load_nifti_volume(mri_path, target_size=(64, 128, 128)).to(device)
    if volume.ndim == 4:
        volume = volume.unsqueeze(0)
    vol_t = volume.squeeze().cpu()
    if vol_t.dtype == torch.bfloat16:
        # NumPy has no bfloat16; upcast only in that case.
        vol_t = vol_t.float()
    vol_np = vol_t.numpy()
    print(f" Shape: {vol_np.shape}")

    print(f"\n🧩 Creating {n_segments} brain supervoxels...")
    segments, brain_mask = big_supervoxels_brain_only(vol_np, n_segments=n_segments)

    print(f"\nRunning LIME with {lime_samples} samples...")
    segmentation_fn = make_segmentation_fn(segments)
    explainer = lime_image.LimeImageExplainer()

    def predict_fn(vols_4d):
        """Score LIME-perturbed volumes.

        vols_4d: (n_samples, D, H, W) perturbed copies of ``vol_np``.
        Returns an (n_samples, 1) array of -NLL scores.
        """
        return lime_score_report_nll(
            vols_4d[:, np.newaxis, :, :, :],  # add channel axis -> (n, 1, D, H, W)
            model,
            prompt=CANONICAL_PROMPT,
            report_ref=report,
            batch_size=1,
        )

    explanation = explainer.explain_instance(
        vol_np,
        predict_fn,
        top_labels=1,
        hide_color=hide_color,
        num_samples=lime_samples,
        segmentation_fn=segmentation_fn,
    )

    label = explanation.top_labels[0]
    weights = dict(explanation.local_exp[label])

    print("\n✅ LIME completed!")
    print(f" Supervoxel weights (sample): {list(weights.items())[:5]}")

    print("\nBuilding weight volume...")
    wvol = _build_weight_volume(segments, weights, brain_mask)

    print("\nSaving visualizations...")
    save_volume_slices_overlay(
        volume,
        wvol,
        str(out_dir / "overlay_slices.png"),
        title="LIME Interpretability",
        ncols=8,
        is_healthy=False,
        alpha=alpha,
        clip_q=clip_q,
        rot_k=0,
        brain_mask=brain_mask,
    )

    save_overlay_grid_single_supervoxel_png(
        vol_np, segments, weights,
        out_path=str(out_dir / "lime_top_supervoxels_grid.png"),
        axis=0, n_cols=8, alpha=0.55,
        suptitle="Top Supportive (Red) and Contradicting (Blue) Supervoxels",
    )

    print("\nCreating 2x3 grid figure (original + LIME overlay)...")
    selected_slices = _save_lime_2x3_grid(vol_np, segments, weights, out_dir / "lime_2x3_grid.png")
    print(f"✅ Saved 2x3 grid (slices {selected_slices})")

    _save_lime_artifacts(out_dir, report, weights, wvol, segments)

    print(f"\n{'='*60}")
    print("✅ Interpretability analysis completed!")
    print(f" Results saved to: {output_dir}")
    print(f"{'='*60}\n")

    return weights, wvol


def _build_weight_volume(segments, weights, brain_mask):
    """Paint each brain supervoxel's LIME weight into a float32 volume, zero outside the brain."""
    wvol = np.zeros(segments.shape, dtype=np.float32)
    for seg_id, w in weights.items():
        seg_id = int(seg_id)
        if seg_id == 0:
            continue  # background carries no spatial meaning
        wvol[segments == seg_id] = float(w)
    wvol[~brain_mask] = 0.0
    return wvol


def _save_lime_2x3_grid(vol_np, segments, weights, out_path):
    """Save original vs LIME-overlay panels for 3 axial slices between 30% and 70% of depth; return their indices."""
    depth = vol_np.shape[0]
    selected_slices = np.linspace(int(0.30 * depth), int(0.70 * depth), 3, dtype=int).tolist()
    n_slices = len(selected_slices)

    fig, axes = plt.subplots(2, n_slices, figsize=(n_slices * 4, 2 * 4))
    for col, slice_idx in enumerate(selected_slices):
        img_slice = vol_np[slice_idx, :, :]
        seg_slice = segments[slice_idx, :, :]

        # Top row: original slice.
        axes[0, col].imshow(img_slice, cmap="gray", origin="lower", interpolation="bilinear")
        axes[0, col].set_title(f"Slice {slice_idx}", fontsize=12, fontweight="bold")
        axes[0, col].axis("off")

        # Bottom row: same slice with the per-supervoxel LIME overlay.
        axes[1, col].imshow(img_slice, cmap="gray", origin="lower", interpolation="bilinear")
        overlay = create_overlay_from_segments(seg_slice, weights, alpha=0.5)
        axes[1, col].imshow(overlay, origin="lower", interpolation="nearest")
        axes[1, col].axis("off")

    axes[0, 0].text(-0.15, 0.5, "Original", transform=axes[0, 0].transAxes,
                    fontsize=14, fontweight="bold", va="center", rotation=90)
    axes[1, 0].text(-0.15, 0.5, "LIME Overlay", transform=axes[1, 0].transAxes,
                    fontsize=14, fontweight="bold", va="center", rotation=90)

    fig.tight_layout()
    fig.savefig(str(out_path), dpi=300, bbox_inches="tight", facecolor="white")
    plt.close(fig)
    return selected_slices


def _save_lime_artifacts(out_dir, report, weights, wvol, segments):
    """Write the text report, JSON weights and NumPy arrays of one LIME run into out_dir."""
    brain_weights = {int(k): float(v) for k, v in weights.items() if int(k) != 0}

    # UTF-8 explicitly: the report is free text and the platform default may not encode it.
    with open(out_dir / "lime_report.txt", "w", encoding="utf-8") as f:
        f.write(f"Reference Report:\n{report}\n\n")
        f.write("LIME Supervoxel Weights (top 20):\n")
        # Rank brain supervoxels only, so background 0 cannot take one of the 20 slots.
        ranked = sorted(brain_weights.items(), key=lambda kv: abs(kv[1]), reverse=True)
        for seg_id, weight in ranked[:20]:
            f.write(f" Segment {seg_id}: {weight:.4f}\n")

    with open(out_dir / "lime_weights.json", "w", encoding="utf-8") as f:
        json.dump(brain_weights, f, indent=2)
    print(f"Saved lime_weights.json ({len(brain_weights)} brain supervoxels)", flush=True)

    np.save(str(out_dir / "lime_wvol.npy"), wvol)
    np.save(str(out_dir / "lime_segments.npy"), segments)
    print("✅ Saved wvol/segments arrays", flush=True)
    print(f" wvol stats: shape={wvol.shape} min={wvol.min():.4g} max={wvol.max():.4g}", flush=True)


def main():
    """Command-line entry point.

    Parses arguments, loads BrainGemma3D, and generates a reference report
    when none (or only whitespace) is given. It then runs LIME
    interpretability on one scan.
    """
    parser = argparse.ArgumentParser(description="BrainGemma3D LIME Interpretability")

    # Required inputs
    parser.add_argument("--model_dir", required=True, help="Path to BrainGemma3D model folder")
    parser.add_argument("--mri_path", required=True, help="Path to .nii/.nii.gz MRI scan")

    # Report and output
    parser.add_argument("--report", default=None, help="Reference report text. If not provided, will generate it first.")
    parser.add_argument("--output_dir", default="./lime_output", help="Output directory for results")

    # Generation settings (used only when the report has to be generated)
    parser.add_argument("--max_new_tokens", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.1)
    parser.add_argument("--top_p", type=float, default=0.9)

    # LIME settings
    parser.add_argument("--lime_samples", type=int, default=100, help="Number of LIME samples")
    parser.add_argument("--n_segments", type=int, default=20, help="Number of supervoxels")
    parser.add_argument("--hide_color", type=float, default=0.0, help="Hide color for LIME perturbations")

    # Visualization
    parser.add_argument("--alpha", type=float, default=0.45, help="Overlay transparency")
    parser.add_argument("--clip_q", type=float, default=0.99, help="Heatmap clipping quantile")

    # Reproducibility
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()

    # Fail fast with a usage error instead of a traceback deep inside model loading.
    if not os.path.isdir(args.model_dir):
        parser.error(f"--model_dir does not exist or is not a directory: {args.model_dir}")
    if not os.path.isfile(args.mri_path):
        parser.error(f"--mri_path does not exist: {args.mri_path}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading BrainGemma3D model from {args.model_dir}...")
    model, load_nifti_volume, CANONICAL_PROMPT = load_full_model(args.model_dir, device)
    print("✅ Model loaded successfully!")

    report = args.report
    if report is None or not report.strip():
        # An empty reference report leaves nothing for LIME to score.
        print("\nNo report provided, generating one...")
        set_seed(args.seed)  # sampling-based generation: make it reproducible
        volume = load_nifti_volume(args.mri_path, target_size=(64, 128, 128)).to(device)
        if volume.ndim == 4:
            volume = volume.unsqueeze(0)

        with torch.no_grad():
            report = model.generate_report(
                volume,
                prompt=CANONICAL_PROMPT,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
            )
        print(f"✅ Generated report: {report}")

    run_interpretability(
        model=model,
        load_nifti_volume=load_nifti_volume,
        CANONICAL_PROMPT=CANONICAL_PROMPT,
        mri_path=args.mri_path,
        report=report,
        output_dir=args.output_dir,
        lime_samples=args.lime_samples,
        n_segments=args.n_segments,
        hide_color=args.hide_color,
        alpha=args.alpha,
        clip_q=args.clip_q,
        seed=args.seed,
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()