|
|
import argparse, yaml, numpy as np |
|
|
from utils import normalize_volume |
|
|
from io_utils import load_slices_from_folder, load_nifti, load_dicom_folder, save_mesh_ply |
|
|
from skimage import measure |
|
|
import os |
|
|
|
|
|
def simple_threshold_mask(volume, threshold=0.6):
    """Return a binary uint8 mask of voxels whose normalized intensity exceeds *threshold*.

    The volume is first passed through ``normalize_volume``. That helper lives in
    ``utils`` and presumably rescales intensities to [0, 1] (TODO confirm), so
    *threshold* is interpreted on that normalized scale rather than on raw values.
    The comparison is strict: voxels equal to *threshold* map to 0.
    """
    normalized = normalize_volume(volume)
    above = normalized > threshold
    return above.astype(np.uint8)
|
|
|
|
|
def mask_to_vertex_colors(verts, faces, mask_vol, spacing=(1.0, 1.0, 1.0)):
    """Assign an RGB color to each mesh vertex based on the mask voxel it falls in.

    Vertex coordinates are divided by *spacing* and rounded to the nearest voxel
    index. In this file *verts* come from ``marching_cubes`` called with the same
    *spacing*. A vertex whose voxel is nonzero in *mask_vol* is colored red
    ``(1, 0, 0)``. Every other vertex is colored gray ``(0.7, 0.7, 0.7)``; this
    includes vertices whose rounded index falls outside the mask.

    Args:
        verts: (N, 3) vertex coordinates, ordered to match *mask_vol*'s axes.
        faces: Accepted for interface compatibility; not used.
        mask_vol: 3-D array; nonzero entries mark anomalous voxels.
        spacing: Per-axis voxel size used to map coordinates back to indices.

    Returns:
        (N, 3) float64 array of RGB colors in [0, 1].
    """
    verts = np.asarray(verts, dtype=float)
    n_verts = len(verts)
    colors = np.empty((n_verts, 3), dtype=float)
    colors[:] = (0.7, 0.7, 0.7)
    if n_verts == 0:
        return colors

    idx = np.round(verts / np.asarray(spacing, dtype=float)).astype(int)
    shape = np.asarray(mask_vol.shape)
    # Only look up vertices whose rounded index is inside the volume; the rest stay gray.
    inside = np.all((idx >= 0) & (idx < shape), axis=1)

    hits = np.zeros(n_verts, dtype=bool)
    zi, yi, xi = idx[inside].T
    hits[inside] = np.asarray(mask_vol)[zi, yi, xi].astype(bool)
    colors[hits] = (1.0, 0.0, 0.0)
    return colors
|
|
|
|
|
def mesh_and_mask_from_volume(vol, iso=0.5, spacing=(1.0, 1.0, 1.0), mask=None):
    """Extract an isosurface mesh from *vol* and optionally color it by *mask*.

    The volume is normalized with ``normalize_volume`` and passed to
    ``skimage.measure.marching_cubes``. *iso* is therefore a level on the
    normalized scale, and *spacing* is forwarded as the voxel spacing.

    Returns:
        ``(verts, faces, normals, colors)``. ``colors`` is ``None`` when no
        *mask* is given. Otherwise it is the per-vertex RGB array from
        ``mask_to_vertex_colors``, computed with the same *spacing*.
    """
    normalized = normalize_volume(vol)
    verts, faces, normals, _values = measure.marching_cubes(
        normalized, level=iso, spacing=spacing
    )
    colors = (
        None
        if mask is None
        else mask_to_vertex_colors(verts, faces, mask, spacing=spacing)
    )
    return verts, faces, normals, colors
|
|
|
|
|
def analyze_regions(mask, min_voxels=50):
    """Label connected components of *mask* and summarize those of sufficient size.

    Components are found with ``scipy.ndimage.label`` using its default
    connectivity. Components with fewer than *min_voxels* voxels are dropped.

    Args:
        mask: Binary (or nonzero-marked) volume; nonzero voxels are foreground.
        min_voxels: Minimum component size to report (inclusive).

    Returns:
        List of dicts in ascending label order, each with keys ``'label'`` (int),
        ``'voxels'`` (int voxel count) and ``'center'`` (list of floats, the
        unweighted mean voxel index along each axis).
    """
    from scipy import ndimage as ndi

    labeled, n = ndi.label(mask)
    if n == 0:
        return []

    # One pass over the volume for all component sizes, instead of one
    # full-volume comparison per label.
    counts = np.bincount(labeled.ravel(), minlength=n + 1)
    keep = [lab for lab in range(1, n + 1) if counts[lab] >= min_voxels]
    if not keep:
        return []

    # Unweighted centroid: use a 0/1 foreground image so that non-binary mask
    # values do not bias the centers.
    foreground = (labeled > 0).astype(np.uint8)
    centers = ndi.center_of_mass(foreground, labeled, keep)

    return [
        {
            'label': lab,
            'voxels': int(counts[lab]),
            'center': [float(c) for c in center],
        }
        for lab, center in zip(keep, centers)
    ]
|
|
|
|
|
def generate_text_explanation(regions, vol_shape, spacing=(1.0, 1.0, 1.0)):
    """Build a human-readable summary with one line per detected region.

    Each region's center is bucketed into thirds of the volume:
    - z: top / middle / bottom.
    - x: left / center / right.

    The y coordinate is not reported. Positions are in array-index terms.
    NOTE(review): whether index-"left" corresponds to patient left depends on
    the loader's orientation; confirm before relying on it clinically.

    Args:
        regions: Dicts with ``'label'``, ``'voxels'`` and a 3-element ``'center'``
            ordered (z, y, x), as produced by ``analyze_regions``.
        vol_shape: 3-tuple (z, y, x) volume shape used for the thirds.
        spacing: Currently unused; accepted for interface compatibility.

    Returns:
        Newline-joined region descriptions, or a fixed message if *regions* is empty.
    """
    if not regions:
        return "No anomalous regions detected above threshold."

    depth, _height, width = vol_shape

    def bucket(value, extent, low_word, mid_word, high_word):
        # Strict comparisons: values exactly on a 33%/66% boundary fall in the middle bucket.
        if value < extent * 0.33:
            return low_word
        if value > extent * 0.66:
            return high_word
        return mid_word

    lines = []
    for region in regions:
        zc, _yc, xc = region['center']
        vertical = bucket(zc, depth, 'top', 'middle', 'bottom')
        lateral = bucket(xc, width, 'left', 'center', 'right')
        lines.append(
            f"Region {region['label']}: approx {region['voxels']} voxels, "
            f"located near the {vertical} (z~{zc:.1f}), {lateral} (x~{xc:.1f}). "
            "Suggest clinical review and consider high-resolution imaging or segmentation."
        )
    return "\n".join(lines)
|
|
|
|
|
|
|
|
def _load_config(path):
    """Load a YAML config mapping from *path*; return {} if missing, unreadable, malformed or not a mapping."""
    try:
        # Binary mode lets PyYAML detect the stream encoding itself.
        with open(path, 'rb') as fh:
            cfg = yaml.safe_load(fh)
    except (OSError, yaml.YAMLError):
        return {}
    # safe_load returns None for an empty file; treat anything non-dict as "no config".
    return cfg if isinstance(cfg, dict) else {}


def _section(cfg, name):
    """Return ``cfg[name]`` when it is a dict; otherwise (missing or null section) an empty dict."""
    value = cfg.get(name)
    return value if isinstance(value, dict) else {}


def _load_volume(source, source_type, glob_pattern):
    """Dispatch to the loader matching *source_type* ('folder', 'dicom' or anything else -> NIfTI)."""
    if source_type == 'folder':
        return load_slices_from_folder(source, glob_pattern=glob_pattern)
    if source_type == 'dicom':
        return load_dicom_folder(source)
    return load_nifti(source)


def _model_mask(vol, model_path):
    """Segment *vol* slice by slice (along axis 0) with a UNet checkpoint; return a uint8 mask shaped like the normalized volume.

    Raises whatever torch / the model code raises; the caller treats any failure
    as a reason to fall back to thresholding.
    """
    if model_path is None:
        raise ValueError("--model_path is required when --method model is used")

    import torch
    from models.unet import UNet

    model = UNet(in_channels=1, out_channels=1)
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()

    vol_n = normalize_volume(vol)
    mask = np.zeros_like(vol_n, dtype=np.uint8)
    with torch.no_grad():
        for i, s in enumerate(vol_n):
            # Per-slice min-max rescale; epsilon guards against constant slices.
            x = (s - s.min()) / (s.max() - s.min() + 1e-8)
            inp = torch.tensor(x[np.newaxis, np.newaxis, ...], dtype=torch.float32)
            out = model(inp).numpy()[0, 0]
            # NOTE(review): a 0.5 cut-off assumes the model emits probabilities
            # (e.g. ends in a sigmoid) rather than logits — confirm against UNet.
            mask[i] = (out > 0.5).astype(np.uint8)
    return mask


def main():
    """Command-line entry point.

    Steps:
    1. Load a volume from a slice folder, DICOM folder or NIfTI file.
    2. Build an anomaly mask, either by thresholding or with a UNet. Any model
       failure falls back to thresholding.
    3. Summarize connected regions of the mask.
    4. Write a mask-colored PLY mesh to ``--out`` and the text summary to
       ``--explain_out``.
    """
    parser = argparse.ArgumentParser(description=main.__doc__.splitlines()[0])
    parser.add_argument("--source", required=True,
                        help="Input folder or file.")
    parser.add_argument("--source_type", choices=['folder', 'dicom', 'nifti'], default='folder',
                        help="How to interpret --source.")
    parser.add_argument("--glob", default="*.png",
                        help="Slice filename pattern for --source_type folder.")
    parser.add_argument("--config", default="config.yaml",
                        help="YAML config; ignored if missing or unreadable.")
    parser.add_argument("--method", choices=['threshold', 'model'], default='threshold',
                        help="Mask generation method.")
    parser.add_argument("--threshold", type=float, default=None,
                        help="Overrides anomaly.threshold from the config.")
    parser.add_argument("--model_path", default=None,
                        help="UNet state_dict path for --method model.")
    parser.add_argument("--out", default="mesh_colored.ply",
                        help="Output PLY mesh path.")
    parser.add_argument("--explain_out", default="explanation.txt",
                        help="Output text explanation path.")
    args = parser.parse_args()

    cfg = _load_config(args.config)
    # `is not None` (not `or`): an explicit --threshold 0.0 must not be replaced by the config value.
    if args.threshold is not None:
        threshold = args.threshold
    else:
        threshold = _section(cfg, 'anomaly').get('threshold', 0.6)
    reconstruct = _section(cfg, 'reconstruct')
    spacing = tuple(reconstruct.get('spacing', [1.0, 1.0, 1.0]))
    iso = reconstruct.get('iso_value', 0.5)
    min_vox = _section(cfg, 'text_explainer').get('min_region_voxels', 50)

    vol = _load_volume(args.source, args.source_type, args.glob)

    if args.method == 'threshold':
        mask = simple_threshold_mask(vol, threshold=threshold)
    else:
        try:
            mask = _model_mask(vol, args.model_path)
        except Exception as e:  # deliberate best-effort: any model failure falls back to thresholding
            print("Model-based method failed:", e, "- falling back to threshold mask")
            mask = simple_threshold_mask(vol, threshold=threshold)

    regions = analyze_regions(mask, min_voxels=min_vox)
    explanation = generate_text_explanation(regions, vol.shape, spacing=spacing)

    verts, faces, normals, colors = mesh_and_mask_from_volume(vol, iso=iso, spacing=spacing, mask=mask)
    save_mesh_ply(verts, faces, args.out, normals=normals, colors=colors)

    with open(args.explain_out, 'w', encoding='utf-8') as f:
        f.write(explanation)

    print("Saved colored mesh to", args.out)
    print("Saved textual explanation to", args.explain_out)
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|