# mediview-3d-vision / infer_anomaly.py
# Hugging Face page header (not part of the program):
#   uploader: hmnshudhmn24 -- "Upload 16 files" -- commit 5b023ff (verified)
import argparse, yaml, numpy as np
from utils import normalize_volume
from io_utils import load_slices_from_folder, load_nifti, load_dicom_folder, save_mesh_ply
from skimage import measure
import os
def simple_threshold_mask(volume, threshold=0.6):
    """Return a binary anomaly mask by thresholding the normalized volume.

    The volume is first passed through ``normalize_volume`` (project helper;
    presumably rescales intensities to a common range -- confirm in
    ``utils``). Voxels whose normalized value is strictly greater than
    ``threshold`` become 1, all others 0.

    Args:
        volume: Input volume accepted by ``normalize_volume``.
        threshold: Cut-off applied to the normalized intensities.

    Returns:
        ``np.uint8`` array of 0/1 flags with the shape of the normalized
        volume.
    """
    normalized = normalize_volume(volume)
    above_cutoff = normalized > threshold
    return above_cutoff.astype(np.uint8)
def mask_to_vertex_colors(verts, faces, mask_vol, spacing=(1.0,1.0,1.0)):
    """Assign an RGB color to every mesh vertex based on a voxel mask.

    Each vertex position is divided by ``spacing`` and rounded to the nearest
    integer to obtain a voxel index into ``mask_vol``. A vertex is colored
    red ``(1, 0, 0)`` when that index lies inside the mask volume and the
    mask value there is truthy; every other vertex (masked-out or outside the
    volume) is colored light gray ``(0.7, 0.7, 0.7)``.

    Args:
        verts: ``(N, 3)`` array of vertex coordinates, in the same axis order
            as ``mask_vol``.
        faces: Unused; kept for interface compatibility.
        mask_vol: 3-D array; non-zero entries mark flagged voxels.
        spacing: Per-axis scale used to convert coordinates to voxel indices.

    Returns:
        ``(N, 3)`` float array of RGB colors in ``[0, 1]``.
    """
    voxel_idx = np.round(verts / np.array(spacing)).astype(int)
    zmax, ymax, xmax = mask_vol.shape

    # Start with everything gray, then paint the flagged vertices red.
    colors = np.full((len(verts), 3), 0.7, dtype=float)

    upper = np.array([zmax, ymax, xmax])
    inside = np.all((voxel_idx >= 0) & (voxel_idx < upper), axis=1)

    # Only look up the mask for in-bounds vertices to avoid wrap-around
    # from negative indices and IndexError from too-large ones.
    hits = voxel_idx[inside]
    flagged = np.zeros(len(verts), dtype=bool)
    flagged[inside] = mask_vol[hits[:, 0], hits[:, 1], hits[:, 2]].astype(bool)

    colors[flagged] = (1.0, 0.0, 0.0)
    return colors
def mesh_and_mask_from_volume(vol, iso=0.5, spacing=(1.0,1.0,1.0), mask=None):
    """Extract an iso-surface mesh from a volume, optionally colored by a mask.

    The volume is normalized with ``normalize_volume`` and passed to
    ``skimage.measure.marching_cubes`` at level ``iso``. When ``mask`` is
    given, per-vertex colors are computed with ``mask_to_vertex_colors``.

    Args:
        vol: Input volume accepted by ``normalize_volume``.
        iso: Iso-value (on the normalized volume) at which to extract the
            surface.
        spacing: Voxel spacing forwarded to marching cubes and to the
            vertex-coloring step.
        mask: Optional voxel mask used to color the vertices.

    Returns:
        Tuple ``(verts, faces, normals, colors)``; ``colors`` is ``None``
        when no mask was supplied.
    """
    normalized = normalize_volume(vol)
    # marching_cubes also returns per-vertex values, which are not needed here.
    verts, faces, normals, _values = measure.marching_cubes(
        normalized, level=iso, spacing=spacing
    )
    if mask is None:
        return verts, faces, normals, None
    vertex_colors = mask_to_vertex_colors(verts, faces, mask, spacing=spacing)
    return verts, faces, normals, vertex_colors
def analyze_regions(mask, min_voxels=50):
    """Find connected components in a mask and report size and centroid.

    Components are labeled with ``scipy.ndimage.label`` (default
    connectivity). Components with fewer than ``min_voxels`` voxels are
    dropped.

    Args:
        mask: Array whose non-zero entries form the regions to analyze.
        min_voxels: Minimum voxel count for a region to be reported.

    Returns:
        List of dicts, ordered by label, each with keys ``'label'`` (int
        label assigned by ``ndimage.label``), ``'voxels'`` (int voxel count)
        and ``'center'`` (list of per-axis mean indices as floats, in the
        mask's axis order).
    """
    from scipy import ndimage as ndi

    labeled, n = ndi.label(mask)
    if n == 0:
        return []

    # One pass over the volume gives every component's size, so small
    # components can be skipped before any coordinate work is done.
    sizes = np.bincount(labeled.ravel(), minlength=n + 1)
    # Bounding boxes let the centroid search scan only the region's own
    # sub-volume instead of the full volume once per label.
    boxes = ndi.find_objects(labeled)

    regions = []
    for lab in range(1, n + 1):
        voxels = int(sizes[lab])
        if voxels < min_voxels:
            continue
        box = boxes[lab - 1]
        # Shift sub-volume indices back to full-volume indices *before*
        # averaging so the mean is computed on the exact same integers.
        offset = np.array([axis_slice.start for axis_slice in box])
        coords = np.argwhere(labeled[box] == lab) + offset
        center = [float(c) for c in coords.mean(axis=0)]
        regions.append({'label': lab, 'voxels': voxels, 'center': center})
    return regions
def generate_text_explanation(regions, vol_shape, spacing=(1.0,1.0,1.0)):
    """Build a human-readable summary line for each detected region.

    Each region's center is bucketed along the first axis (z) into
    top/middle/bottom and along the last axis (x) into left/center/right,
    using 33% and 66% of the corresponding volume extent as cut-offs.

    Args:
        regions: List of dicts with keys ``'label'``, ``'voxels'`` and
            ``'center'`` (a 3-element ``[z, y, x]`` sequence).
        vol_shape: 3-element ``(z, y, x)`` shape of the volume.
        spacing: Unused; kept for interface compatibility.

    Returns:
        A fixed "nothing found" sentence when ``regions`` is empty,
        otherwise one line per region joined with newlines.
    """
    if not regions:
        return "No anomalous regions detected above threshold."

    def bucket(value, extent, low_name, mid_name, high_name):
        # Lower third / upper third / everything in between.
        if value < extent*0.33:
            return low_name
        if value > extent*0.66:
            return high_name
        return mid_name

    zdim, ydim, xdim = vol_shape
    lines = []
    for region in regions:
        zc, yc, xc = region['center']
        vertical = bucket(zc, zdim, 'top', 'middle', 'bottom')
        horizontal = bucket(xc, xdim, 'left', 'center', 'right')
        lines.append(f"Region {region['label']}: approx {region['voxels']} voxels, located near the {vertical} (z~{zc:.1f}), {horizontal} (x~{xc:.1f}). Suggest clinical review and consider high-resolution imaging or segmentation.")
    return "\n".join(lines)
def main():
    """Command-line entry point: detect anomalies and export a colored mesh.

    Steps:
      1. Parse CLI arguments and (best-effort) a YAML config file.
      2. Load the volume from an image folder, a DICOM folder or a NIfTI file.
      3. Build an anomaly mask, either by simple thresholding or with a
         per-slice UNet model (falling back to thresholding if the model
         path fails for any reason).
      4. Analyze connected regions, write a textual explanation, and save a
         mesh whose vertices are colored by the mask.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--source", required=True)
    parser.add_argument("--source_type", choices=['folder','dicom','nifti'], default='folder')
    parser.add_argument("--glob", default="*.png")
    parser.add_argument("--config", default="config.yaml")
    parser.add_argument("--method", choices=['threshold','model'], default='threshold')
    parser.add_argument("--threshold", type=float, default=None)
    parser.add_argument("--model_path", default=None)
    parser.add_argument("--out", default="mesh_colored.ply")
    parser.add_argument("--explain_out", default="explanation.txt")
    args = parser.parse_args()

    # The config is optional: a missing or unreadable file silently falls
    # back to built-in defaults. The file is opened in a context manager so
    # the handle is always closed.
    try:
        with open(args.config) as cfg_file:
            cfg = yaml.safe_load(cfg_file)
    except Exception:
        cfg = {}
    # safe_load returns None for an empty document and may return a
    # non-mapping for other content; treat both as "no config".
    if not isinstance(cfg, dict):
        cfg = {}

    def section(name):
        # A section that is missing or not a mapping behaves as empty.
        value = cfg.get(name)
        return value if isinstance(value, dict) else {}

    cfg_thresh = section('anomaly').get('threshold', 0.6)
    # Compare against None explicitly so that an explicit "--threshold 0"
    # is honored instead of being replaced by the config value.
    threshold = args.threshold if args.threshold is not None else cfg_thresh
    reconstruct_cfg = section('reconstruct')
    spacing = tuple(reconstruct_cfg.get('spacing', [1.0,1.0,1.0]))
    iso = reconstruct_cfg.get('iso_value', 0.5)
    min_vox = section('text_explainer').get('min_region_voxels', 50)

    # load volume
    if args.source_type == 'folder':
        vol = load_slices_from_folder(args.source, glob_pattern=args.glob)
    elif args.source_type == 'dicom':
        vol = load_dicom_folder(args.source)
    else:
        vol = load_nifti(args.source)

    if args.method == 'threshold':
        mask = simple_threshold_mask(vol, threshold=threshold)
    else:
        # model-based per-slice segmentation (if model provided)
        # Deliberately best-effort: any failure (torch missing, bad model
        # path, shape mismatch, ...) is reported and the threshold method
        # is used instead.
        try:
            import torch
            from models.unet import UNet
            model = UNet(in_channels=1, out_channels=1)
            model.load_state_dict(torch.load(args.model_path, map_location='cpu'))
            model.eval()
            vol_n = normalize_volume(vol)
            mask = np.zeros_like(vol_n, dtype=np.uint8)
            for i in range(vol_n.shape[0]):
                s = vol_n[i]
                # Per-slice min-max rescale; epsilon guards constant slices.
                x = (s - s.min())/(s.max()-s.min()+1e-8)
                # Add batch and channel axes: (H, W) -> (1, 1, H, W).
                inp = torch.tensor(x[np.newaxis, np.newaxis, ...], dtype=torch.float32)
                with torch.no_grad():
                    out = model(inp).numpy()[0,0]
                mask[i] = (out > 0.5).astype(np.uint8)
        except Exception as e:
            print("Model-based method failed:", e)
            mask = simple_threshold_mask(vol, threshold=threshold)

    regions = analyze_regions(mask, min_voxels=min_vox)
    explanation = generate_text_explanation(regions, vol.shape, spacing=spacing)
    verts, faces, normals, colors = mesh_and_mask_from_volume(vol, iso=iso, spacing=spacing, mask=mask)
    save_mesh_ply(verts, faces, args.out, normals=normals, colors=colors)
    with open(args.explain_out, 'w') as f:
        f.write(explanation)
    print("Saved colored mesh to", args.out)
    print("Saved textual explanation to", args.explain_out)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()