Upload 16 files
Browse files- .gitattributes +2 -34
- LICENSE +9 -0
- README.md +58 -3
- app.py +52 -0
- config.yaml +13 -0
- examples/demo_commands.txt +13 -0
- examples/example_image.png +0 -0
- examples/generate_synthetic_phantom.py +15 -0
- infer_anomaly.py +129 -0
- io_utils.py +62 -0
- models/unet.py +54 -0
- notebooks/MEDIVIEW-3D_Demo.ipynb +1 -0
- reconstruct_3d.py +45 -0
- requirements.txt +13 -0
- train_unet.py +48 -0
- utils.py +13 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,3 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.
|
| 24 |
-
*.
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.nii filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.nii.gz filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LICENSE
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License 2.0
|
| 2 |
+
|
| 3 |
+
Copyright 2025 hmnshudhmn24
|
| 4 |
+
|
| 5 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
you may not use this file except in compliance with the License.
|
| 7 |
+
You may obtain a copy of the License at
|
| 8 |
+
|
| 9 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
README.md
CHANGED
|
@@ -1,3 +1,58 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- en
|
| 4 |
+
license: apache-2.0
|
| 5 |
+
tags:
|
| 6 |
+
- image-to-3d
|
| 7 |
+
- medical-imaging
|
| 8 |
+
- reconstruction
|
| 9 |
+
- segmentation
|
| 10 |
+
- explainability
|
| 11 |
+
pipeline_tag: image-to-3d
|
| 12 |
+
library_name: python
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# MEDIVIEW-3D (Advanced)
|
| 16 |
+
|
| 17 |
+
**MEDIVIEW-3D** converts 2D medical image slices into a 3D reconstruction, localizes anomalous regions, and generates **textual explanations** describing the detected regions (size, approximate location, and suggested next steps).
|
| 18 |
+
|
| 19 |
+
**Important:** This is a research/demo tool and **not** a medical device. Do not use for clinical decisions.
|
| 20 |
+
|
| 21 |
+
## Quickstart (demo)
|
| 22 |
+
|
| 23 |
+
1. Install:
|
| 24 |
+
```bash
|
| 25 |
+
pip install -r requirements.txt
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
2. Generate synthetic phantom slices:
|
| 29 |
+
```bash
|
| 30 |
+
python examples/generate_synthetic_phantom.py
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
3. Run inference with thresholding and get a mesh + explanation:
|
| 34 |
+
```bash
|
| 35 |
+
python infer_anomaly.py --source examples/synthetic_phantom --method threshold --out demo_mesh_threshold.ply --explain_out explanation.txt
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
4. (Optional) Train small UNet and run model-based inference:
|
| 39 |
+
```bash
|
| 40 |
+
python train_unet.py --data examples/synthetic_phantom --epochs 3 --out models/unet_demo.pt
|
| 41 |
+
python infer_anomaly.py --source examples/synthetic_phantom --method model --model_path models/unet_demo.pt --out demo_mesh_model.ply --explain_out explanation_model.txt
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
5. Run Streamlit demo:
|
| 45 |
+
```bash
|
| 46 |
+
streamlit run app.py
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
## What you get
|
| 50 |
+
- 3D mesh `.ply` with anomaly regions colored red
|
| 51 |
+
- `explanation.txt` with human-friendly descriptions of detected regions
|
| 52 |
+
- Example synthetic phantom (no patient data)
|
| 53 |
+
- Small UNet implementation for demo training
|
| 54 |
+
|
| 55 |
+
## Safety & Limitations
|
| 56 |
+
- Demo-only; not clinically validated.
|
| 57 |
+
- Do not upload identifiable patient data to public repos.
|
| 58 |
+
- For real medical use, integrate robust preprocessing and obtain regulatory approvals.
|
app.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tempfile, os, zipfile
|
| 5 |
+
from io_utils import load_slices_from_folder
|
| 6 |
+
from infer_anomaly import simple_threshold_mask, mesh_and_mask_from_volume
|
| 7 |
+
from io_utils import save_mesh_ply
|
| 8 |
+
|
| 9 |
+
st.set_page_config(page_title="MEDIVIEW-3D", layout="wide")
|
| 10 |
+
st.title("MEDIVIEW-3D — 2D → 3D Reconstruction + Explanation")
|
| 11 |
+
|
| 12 |
+
uploaded = st.file_uploader("Upload a zip of PNG slices (or skip to use example)", type=["zip"])
|
| 13 |
+
use_example = st.button("Use example synthetic slices")
|
| 14 |
+
thresh = st.slider("Threshold (for simple anomaly)", 0.0, 1.0, 0.6)
|
| 15 |
+
out_mesh = st.text_input("Output mesh path", value="mesh_demo.ply")
|
| 16 |
+
|
| 17 |
+
def unpack_and_read(zipf):
|
| 18 |
+
tmpdir = tempfile.mkdtemp()
|
| 19 |
+
with zipfile.ZipFile(zipf) as z:
|
| 20 |
+
z.extractall(tmpdir)
|
| 21 |
+
vol = load_slices_from_folder(tmpdir, glob_pattern="*.png")
|
| 22 |
+
return vol, tmpdir
|
| 23 |
+
|
| 24 |
+
vol = None
|
| 25 |
+
if uploaded:
|
| 26 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp:
|
| 27 |
+
tmp.write(uploaded.getvalue())
|
| 28 |
+
tmp.flush()
|
| 29 |
+
vol, tmpdir = unpack_and_read(tmp.name)
|
| 30 |
+
elif use_example:
|
| 31 |
+
if os.path.isdir("examples/synthetic_phantom"):
|
| 32 |
+
vol = load_slices_from_folder("examples/synthetic_phantom", glob_pattern="*.png")
|
| 33 |
+
else:
|
| 34 |
+
st.error("Example slices not found in examples/synthetic_phantom")
|
| 35 |
+
|
| 36 |
+
if vol is not None:
|
| 37 |
+
st.write("Volume shape (z,y,x):", vol.shape)
|
| 38 |
+
mid = vol.shape[0]//2
|
| 39 |
+
st.image(Image.fromarray(vol[mid].astype('uint8')), caption=f"Middle slice (index {mid})")
|
| 40 |
+
if st.button("Run reconstruction & anomaly detection"):
|
| 41 |
+
mask = simple_threshold_mask(vol, threshold=thresh)
|
| 42 |
+
verts, faces, normals, colors = mesh_and_mask_from_volume(vol, iso=0.5, spacing=(1.0,1.0,1.0), mask=mask)
|
| 43 |
+
save_mesh_ply(verts, faces, out_mesh, normals=normals, colors=colors)
|
| 44 |
+
st.success(f"Saved colored mesh to {out_mesh}")
|
| 45 |
+
# Generate simple textual explanation
|
| 46 |
+
from infer_anomaly import analyze_regions, generate_text_explanation
|
| 47 |
+
regions = analyze_regions(mask, min_voxels=20)
|
| 48 |
+
expl = generate_text_explanation(regions, vol.shape)
|
| 49 |
+
st.subheader('Detected regions explanation')
|
| 50 |
+
st.text(expl)
|
| 51 |
+
else:
|
| 52 |
+
st.info("Upload a slice zip or click 'Use example synthetic slices' to proceed.")
|
config.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
io:
|
| 2 |
+
slice_glob: "examples/synthetic_phantom/*.png"
|
| 3 |
+
reconstruct:
|
| 4 |
+
iso_value: 0.5
|
| 5 |
+
spacing: [1.0, 1.0, 1.0]
|
| 6 |
+
anomaly:
|
| 7 |
+
method: "threshold"
|
| 8 |
+
threshold: 0.6
|
| 9 |
+
model:
|
| 10 |
+
path: "models/unet_best.pt"
|
| 11 |
+
text_explainer:
|
| 12 |
+
enabled: true
|
| 13 |
+
min_region_voxels: 50
|
examples/demo_commands.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Demo commands
|
| 2 |
+
|
| 3 |
+
# 1) Generate synthetic phantom slices
|
| 4 |
+
python examples/generate_synthetic_phantom.py
|
| 5 |
+
|
| 6 |
+
# 2) Reconstruct mesh + explanation (thresholding)
|
| 7 |
+
python infer_anomaly.py --source examples/synthetic_phantom --method threshold --out demo_mesh_threshold.ply --explain_out explanation.txt
|
| 8 |
+
|
| 9 |
+
# 3) Train small UNet (optional)
|
| 10 |
+
python train_unet.py --data examples/synthetic_phantom --epochs 3 --out models/unet_demo.pt
|
| 11 |
+
|
| 12 |
+
# 4) Run model-based inference (after training)
|
| 13 |
+
python infer_anomaly.py --source examples/synthetic_phantom --method model --model_path models/unet_demo.pt --out demo_mesh_model.ply --explain_out explanation_model.txt
|
examples/example_image.png
ADDED
|
examples/generate_synthetic_phantom.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, numpy as np
|
| 2 |
+
from PIL import Image
|
| 3 |
+
os.makedirs("examples/synthetic_phantom", exist_ok=True)
|
| 4 |
+
z = 64
|
| 5 |
+
y = 128
|
| 6 |
+
x = 128
|
| 7 |
+
for i in range(z):
|
| 8 |
+
Y, X = np.meshgrid(np.linspace(-1,1,x), np.linspace(-1,1,y))
|
| 9 |
+
cx = 0.2*np.sin(2*np.pi*i/z)
|
| 10 |
+
cy = 0.2*np.cos(2*np.pi*i/z)
|
| 11 |
+
r = np.sqrt((X-cx)**2 + (Y-cy)**2)
|
| 12 |
+
img = (np.exp(-(r**2)/(2*(0.2**2))) * 255).astype(np.uint8)
|
| 13 |
+
img = img + np.random.randint(0,20,size=img.shape).astype(np.uint8)
|
| 14 |
+
Image.fromarray(img).save(f"examples/synthetic_phantom/slice_{i:03d}.png")
|
| 15 |
+
print("Synthetic phantom saved to examples/synthetic_phantom/")
|
infer_anomaly.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse, yaml, numpy as np
|
| 2 |
+
from utils import normalize_volume
|
| 3 |
+
from io_utils import load_slices_from_folder, load_nifti, load_dicom_folder, save_mesh_ply
|
| 4 |
+
from skimage import measure
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
def simple_threshold_mask(volume, threshold=0.6):
|
| 8 |
+
vol_n = normalize_volume(volume)
|
| 9 |
+
return (vol_n > threshold).astype(np.uint8)
|
| 10 |
+
|
| 11 |
+
def mask_to_vertex_colors(verts, faces, mask_vol, spacing=(1.0,1.0,1.0)):
|
| 12 |
+
coords = verts / np.array(spacing)
|
| 13 |
+
coords_idx = np.round(coords).astype(int)
|
| 14 |
+
colors = np.zeros((len(verts), 3), dtype=float)
|
| 15 |
+
zmax, ymax, xmax = mask_vol.shape
|
| 16 |
+
for i, (z,y,x) in enumerate(coords_idx):
|
| 17 |
+
if 0 <= z < zmax and 0 <= y < ymax and 0 <= x < xmax:
|
| 18 |
+
if mask_vol[z,y,x]:
|
| 19 |
+
colors[i] = np.array([1.0, 0.0, 0.0])
|
| 20 |
+
else:
|
| 21 |
+
colors[i] = np.array([0.7, 0.7, 0.7])
|
| 22 |
+
else:
|
| 23 |
+
colors[i] = np.array([0.7, 0.7, 0.7])
|
| 24 |
+
return colors
|
| 25 |
+
|
| 26 |
+
def mesh_and_mask_from_volume(vol, iso=0.5, spacing=(1.0,1.0,1.0), mask=None):
|
| 27 |
+
vol_n = normalize_volume(vol)
|
| 28 |
+
verts, faces, normals, values = measure.marching_cubes(vol_n, level=iso, spacing=spacing)
|
| 29 |
+
colors = None
|
| 30 |
+
if mask is not None:
|
| 31 |
+
colors = mask_to_vertex_colors(verts, faces, mask, spacing=spacing)
|
| 32 |
+
return verts, faces, normals, colors
|
| 33 |
+
|
| 34 |
+
def analyze_regions(mask, min_voxels=50):
|
| 35 |
+
# find connected components and simple stats
|
| 36 |
+
from scipy import ndimage as ndi
|
| 37 |
+
labeled, n = ndi.label(mask)
|
| 38 |
+
regions = []
|
| 39 |
+
for lab in range(1, n+1):
|
| 40 |
+
coords = np.argwhere(labeled==lab)
|
| 41 |
+
voxels = coords.shape[0]
|
| 42 |
+
z_mean, y_mean, x_mean = coords.mean(axis=0).tolist()
|
| 43 |
+
regions.append({'label': lab, 'voxels': int(voxels), 'center': [float(z_mean), float(y_mean), float(x_mean)]})
|
| 44 |
+
# filter small
|
| 45 |
+
regions = [r for r in regions if r['voxels'] >= min_voxels]
|
| 46 |
+
return regions
|
| 47 |
+
|
| 48 |
+
def generate_text_explanation(regions, vol_shape, spacing=(1.0,1.0,1.0)):
|
| 49 |
+
if not regions:
|
| 50 |
+
return "No anomalous regions detected above threshold."
|
| 51 |
+
|
| 52 |
+
texts = []
|
| 53 |
+
|
| 54 |
+
zdim, ydim, xdim = vol_shape
|
| 55 |
+
for r in regions:
|
| 56 |
+
zc, yc, xc = r['center']
|
| 57 |
+
# approximate location as top/middle/bottom and left/center/right
|
| 58 |
+
zpos = 'top' if zc < zdim*0.33 else ('bottom' if zc > zdim*0.66 else 'middle')
|
| 59 |
+
ypos = 'left' if xc < xdim*0.33 else ('right' if xc > xdim*0.66 else 'center')
|
| 60 |
+
texts.append(f"Region {r['label']}: approx {r['voxels']} voxels, located near the {zpos} (z~{zc:.1f}), {ypos} (x~{xc:.1f}). Suggest clinical review and consider high-resolution imaging or segmentation.")
|
| 61 |
+
return "\n".join(texts)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def main():
|
| 65 |
+
parser = argparse.ArgumentParser()
|
| 66 |
+
parser.add_argument("--source", required=True)
|
| 67 |
+
parser.add_argument("--source_type", choices=['folder','dicom','nifti'], default='folder')
|
| 68 |
+
parser.add_argument("--glob", default="*.png")
|
| 69 |
+
parser.add_argument("--config", default="config.yaml")
|
| 70 |
+
parser.add_argument("--method", choices=['threshold','model'], default='threshold')
|
| 71 |
+
parser.add_argument("--threshold", type=float, default=None)
|
| 72 |
+
parser.add_argument("--model_path", default=None)
|
| 73 |
+
parser.add_argument("--out", default="mesh_colored.ply")
|
| 74 |
+
parser.add_argument("--explain_out", default="explanation.txt")
|
| 75 |
+
args = parser.parse_args()
|
| 76 |
+
|
| 77 |
+
try:
|
| 78 |
+
cfg = yaml.safe_load(open(args.config))
|
| 79 |
+
except Exception:
|
| 80 |
+
cfg = {}
|
| 81 |
+
cfg_thresh = cfg.get('anomaly', {}).get('threshold', 0.6)
|
| 82 |
+
threshold = args.threshold or cfg_thresh
|
| 83 |
+
spacing = tuple(cfg.get('reconstruct', {}).get('spacing', [1.0,1.0,1.0]))
|
| 84 |
+
iso = cfg.get('reconstruct',{}).get('iso_value', 0.5)
|
| 85 |
+
min_vox = cfg.get('text_explainer',{}).get('min_region_voxels', 50)
|
| 86 |
+
|
| 87 |
+
# load volume
|
| 88 |
+
if args.source_type == 'folder':
|
| 89 |
+
vol = load_slices_from_folder(args.source, glob_pattern=args.glob)
|
| 90 |
+
elif args.source_type == 'dicom':
|
| 91 |
+
vol = load_dicom_folder(args.source)
|
| 92 |
+
else:
|
| 93 |
+
vol = load_nifti(args.source)
|
| 94 |
+
|
| 95 |
+
if args.method == 'threshold':
|
| 96 |
+
mask = simple_threshold_mask(vol, threshold=threshold)
|
| 97 |
+
else:
|
| 98 |
+
# model-based per-slice segmentation (if model provided)
|
| 99 |
+
try:
|
| 100 |
+
import torch
|
| 101 |
+
from models.unet import UNet
|
| 102 |
+
model = UNet(in_channels=1, out_channels=1)
|
| 103 |
+
model.load_state_dict(torch.load(args.model_path, map_location='cpu'))
|
| 104 |
+
model.eval()
|
| 105 |
+
vol_n = normalize_volume(vol)
|
| 106 |
+
mask = np.zeros_like(vol_n, dtype=np.uint8)
|
| 107 |
+
for i in range(vol_n.shape[0]):
|
| 108 |
+
s = vol_n[i]
|
| 109 |
+
x = (s - s.min())/(s.max()-s.min()+1e-8)
|
| 110 |
+
import torch
|
| 111 |
+
inp = torch.tensor(x[np.newaxis, np.newaxis, ...], dtype=torch.float32)
|
| 112 |
+
with torch.no_grad():
|
| 113 |
+
out = model(inp).numpy()[0,0]
|
| 114 |
+
mask[i] = (out > 0.5).astype(np.uint8)
|
| 115 |
+
except Exception as e:
|
| 116 |
+
print("Model-based method failed:", e)
|
| 117 |
+
mask = simple_threshold_mask(vol, threshold=threshold)
|
| 118 |
+
|
| 119 |
+
regions = analyze_regions(mask, min_voxels=min_vox)
|
| 120 |
+
explanation = generate_text_explanation(regions, vol.shape, spacing=spacing)
|
| 121 |
+
verts, faces, normals, colors = mesh_and_mask_from_volume(vol, iso=iso, spacing=spacing, mask=mask)
|
| 122 |
+
save_mesh_ply(verts, faces, args.out, normals=normals, colors=colors)
|
| 123 |
+
with open(args.explain_out, 'w') as f:
|
| 124 |
+
f.write(explanation)
|
| 125 |
+
print("Saved colored mesh to", args.out)
|
| 126 |
+
print("Saved textual explanation to", args.explain_out)
|
| 127 |
+
|
| 128 |
+
if __name__ == '__main__':
|
| 129 |
+
main()
|
io_utils.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, glob
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import pydicom
|
| 5 |
+
import nibabel as nib
|
| 6 |
+
|
| 7 |
+
def load_slices_from_folder(folder, glob_pattern="*.png", sort_numerically=True):
|
| 8 |
+
paths = sorted(glob.glob(os.path.join(folder, glob_pattern)))
|
| 9 |
+
if sort_numerically:
|
| 10 |
+
try:
|
| 11 |
+
paths = sorted(paths, key=lambda p: int(''.join(filter(str.isdigit, os.path.basename(p))) or 0))
|
| 12 |
+
except Exception:
|
| 13 |
+
pass
|
| 14 |
+
slices = [np.array(Image.open(p).convert("L")) for p in paths]
|
| 15 |
+
vol = np.stack(slices, axis=0) # z, y, x
|
| 16 |
+
return vol.astype(np.float32)
|
| 17 |
+
|
| 18 |
+
def load_dicom_folder(folder):
|
| 19 |
+
files = [p for p in glob.glob(os.path.join(folder, "**"), recursive=False) if os.path.isfile(p)]
|
| 20 |
+
ds = []
|
| 21 |
+
for f in files:
|
| 22 |
+
try:
|
| 23 |
+
d = pydicom.dcmread(f)
|
| 24 |
+
ds.append((getattr(d, "InstanceNumber", 0), d))
|
| 25 |
+
except Exception:
|
| 26 |
+
continue
|
| 27 |
+
ds = sorted(ds, key=lambda x: x[0])
|
| 28 |
+
imgs = [d.pixel_array for _, d in ds]
|
| 29 |
+
vol = np.stack(imgs, axis=0).astype(np.float32)
|
| 30 |
+
return vol
|
| 31 |
+
|
| 32 |
+
def load_nifti(path):
|
| 33 |
+
nii = nib.load(path)
|
| 34 |
+
arr = nii.get_fdata().astype(np.float32)
|
| 35 |
+
if arr.ndim == 3:
|
| 36 |
+
return np.transpose(arr, (2,1,0))
|
| 37 |
+
return arr
|
| 38 |
+
|
| 39 |
+
def save_mesh_ply(verts, faces, out_path, normals=None, colors=None):
|
| 40 |
+
import numpy as np
|
| 41 |
+
with open(out_path, 'w') as f:
|
| 42 |
+
f.write("ply\nformat ascii 1.0\n")
|
| 43 |
+
f.write(f"element vertex {len(verts)}\n")
|
| 44 |
+
f.write("property float x\nproperty float y\nproperty float z\n")
|
| 45 |
+
if colors is not None:
|
| 46 |
+
f.write("property uchar red\nproperty uchar green\nproperty uchar blue\n")
|
| 47 |
+
if normals is not None:
|
| 48 |
+
f.write("property float nx\nproperty float ny\nproperty nz\n")
|
| 49 |
+
f.write(f"element face {len(faces)}\n")
|
| 50 |
+
f.write("property list uchar int vertex_indices\n")
|
| 51 |
+
f.write("end_header\n")
|
| 52 |
+
for i, v in enumerate(verts):
|
| 53 |
+
line = f"{v[0]} {v[1]} {v[2]}"
|
| 54 |
+
if colors is not None:
|
| 55 |
+
c = (np.clip(colors[i], 0, 1) * 255).astype(int)
|
| 56 |
+
line += f" {c[0]} {c[1]} {c[2]}"
|
| 57 |
+
if normals is not None:
|
| 58 |
+
n = normals[i]
|
| 59 |
+
line += f" {n[0]} {n[1]} {n[2]}"
|
| 60 |
+
f.write(line + "\n")
|
| 61 |
+
for face in faces:
|
| 62 |
+
f.write(f"3 {face[0]} {face[1]} {face[2]}\n")
|
models/unet.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
class DoubleConv(nn.Module):
|
| 7 |
+
def __init__(self, in_ch, out_ch):
|
| 8 |
+
super().__init__()
|
| 9 |
+
self.net = nn.Sequential(
|
| 10 |
+
nn.Conv2d(in_ch, out_ch, 3, padding=1),
|
| 11 |
+
nn.BatchNorm2d(out_ch),
|
| 12 |
+
nn.ReLU(inplace=True),
|
| 13 |
+
nn.Conv2d(out_ch, out_ch, 3, padding=1),
|
| 14 |
+
nn.BatchNorm2d(out_ch),
|
| 15 |
+
nn.ReLU(inplace=True),
|
| 16 |
+
)
|
| 17 |
+
def forward(self, x): return self.net(x)
|
| 18 |
+
|
| 19 |
+
class Down(nn.Module):
|
| 20 |
+
def __init__(self, in_ch, out_ch):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.net = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_ch, out_ch))
|
| 23 |
+
def forward(self, x): return self.net(x)
|
| 24 |
+
|
| 25 |
+
class Up(nn.Module):
|
| 26 |
+
def __init__(self, in_ch, out_ch):
|
| 27 |
+
super().__init__()
|
| 28 |
+
self.up = nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)
|
| 29 |
+
self.conv = DoubleConv(in_ch, out_ch)
|
| 30 |
+
def forward(self, x1, x2):
|
| 31 |
+
x1 = self.up(x1)
|
| 32 |
+
diffY = x2.size()[2] - x1.size()[2]
|
| 33 |
+
diffX = x2.size()[3] - x1.size()[3]
|
| 34 |
+
x1 = F.pad(x1, [diffX//2, diffX - diffX//2, diffY//2, diffY - diffY//2])
|
| 35 |
+
x = torch.cat([x2, x1], dim=1)
|
| 36 |
+
return self.conv(x)
|
| 37 |
+
|
| 38 |
+
class UNet(nn.Module):
|
| 39 |
+
def __init__(self, in_channels=1, out_channels=1, base_c=32):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.inc = DoubleConv(in_channels, base_c)
|
| 42 |
+
self.down1 = Down(base_c, base_c*2)
|
| 43 |
+
self.down2 = Down(base_c*2, base_c*4)
|
| 44 |
+
self.up1 = Up(base_c*4, base_c*2)
|
| 45 |
+
self.up2 = Up(base_c*2, base_c)
|
| 46 |
+
self.outc = nn.Conv2d(base_c, out_channels, 1)
|
| 47 |
+
def forward(self, x):
|
| 48 |
+
x1 = self.inc(x)
|
| 49 |
+
x2 = self.down1(x1)
|
| 50 |
+
x3 = self.down2(x2)
|
| 51 |
+
x = self.up1(x3, x2)
|
| 52 |
+
x = self.up2(x, x1)
|
| 53 |
+
x = self.outc(x)
|
| 54 |
+
return torch.sigmoid(x)
|
notebooks/MEDIVIEW-3D_Demo.ipynb
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"cells": [{"cell_type": "markdown", "metadata": {}, "source": ["# MEDIVIEW-3D Demo Notebook\\n", "Run the demo scripts locally."]}, {"cell_type": "code", "metadata": {}, "source": ["!python examples/generate_synthetic_phantom.py\\n", "!python infer_anomaly.py --source examples/synthetic_phantom --method threshold --out demo_mesh_threshold.ply --explain_out explanation.txt\\n"]}], "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"name": "python"}}, "nbformat": 4, "nbformat_minor": 5}
|
reconstruct_3d.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse, yaml
|
| 2 |
+
import numpy as np
|
| 3 |
+
from skimage import measure
|
| 4 |
+
from utils import normalize_volume
|
| 5 |
+
from io_utils import load_slices_from_folder, load_dicom_folder, load_nifti, save_mesh_ply
|
| 6 |
+
|
| 7 |
+
def build_volume(source, source_type='folder', glob_pattern="*.png"):
|
| 8 |
+
if source_type == 'folder':
|
| 9 |
+
vol = load_slices_from_folder(source, glob_pattern=glob_pattern)
|
| 10 |
+
elif source_type == 'dicom':
|
| 11 |
+
vol = load_dicom_folder(source)
|
| 12 |
+
elif source_type == 'nifti':
|
| 13 |
+
vol = load_nifti(source)
|
| 14 |
+
else:
|
| 15 |
+
raise ValueError("Unsupported source_type")
|
| 16 |
+
return vol
|
| 17 |
+
|
| 18 |
+
def mesh_from_volume(vol, iso=0.5, spacing=(1.0,1.0,1.0)):
|
| 19 |
+
vol_n = normalize_volume(vol)
|
| 20 |
+
verts, faces, normals, values = measure.marching_cubes(vol_n, level=iso, spacing=spacing)
|
| 21 |
+
return verts, faces, normals
|
| 22 |
+
|
| 23 |
+
def main():
|
| 24 |
+
parser = argparse.ArgumentParser()
|
| 25 |
+
parser.add_argument("--source", required=True)
|
| 26 |
+
parser.add_argument("--source_type", choices=['folder','dicom','nifti'], default='folder')
|
| 27 |
+
parser.add_argument("--glob", default="*.png")
|
| 28 |
+
parser.add_argument("--config", default="config.yaml")
|
| 29 |
+
parser.add_argument("--out", default="mesh.ply")
|
| 30 |
+
parser.add_argument("--iso", type=float, default=None)
|
| 31 |
+
args = parser.parse_args()
|
| 32 |
+
|
| 33 |
+
try:
|
| 34 |
+
cfg = yaml.safe_load(open(args.config))
|
| 35 |
+
except Exception:
|
| 36 |
+
cfg = {}
|
| 37 |
+
iso = args.iso or (cfg.get('reconstruct',{}).get('iso_value', 0.5))
|
| 38 |
+
spacing = cfg.get('reconstruct',{}).get('spacing', [1.0,1.0,1.0])
|
| 39 |
+
vol = build_volume(args.source, source_type=args.source_type, glob_pattern=args.glob)
|
| 40 |
+
verts, faces, normals = mesh_from_volume(vol, iso=iso, spacing=tuple(spacing))
|
| 41 |
+
save_mesh_ply(verts, faces, args.out, normals=normals)
|
| 42 |
+
print(f"Saved mesh to {args.out} (verts={len(verts)}, faces={len(faces)})")
|
| 43 |
+
|
| 44 |
+
if __name__ == '__main__':
|
| 45 |
+
main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy>=1.21
|
| 2 |
+
scipy
|
| 3 |
+
scikit-image
|
| 4 |
+
matplotlib
|
| 5 |
+
pillow
|
| 6 |
+
imageio
|
| 7 |
+
pydicom
|
| 8 |
+
nibabel
|
| 9 |
+
PyYAML
|
| 10 |
+
trimesh
|
| 11 |
+
torch>=1.10
|
| 12 |
+
tqdm
|
| 13 |
+
streamlit
|
train_unet.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, glob
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from models.unet import UNet
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
|
| 10 |
+
class SliceDataset(Dataset):
|
| 11 |
+
def __init__(self, folder):
|
| 12 |
+
self.paths = sorted(glob.glob(os.path.join(folder, "*.png")))
|
| 13 |
+
def __len__(self): return len(self.paths)
|
| 14 |
+
def __getitem__(self, idx):
|
| 15 |
+
p = self.paths[idx]
|
| 16 |
+
img = np.array(Image.open(p).convert("L"), dtype=np.float32)/255.0
|
| 17 |
+
mask = (img > img.mean() + 0.25).astype(np.float32)
|
| 18 |
+
img = img[np.newaxis,...]
|
| 19 |
+
mask = mask[np.newaxis,...]
|
| 20 |
+
return torch.tensor(img), torch.tensor(mask)
|
| 21 |
+
|
| 22 |
+
def train(folder, epochs=3, out="models/unet_best.pt"):
|
| 23 |
+
ds = SliceDataset(folder)
|
| 24 |
+
dl = DataLoader(ds, batch_size=4, shuffle=True)
|
| 25 |
+
model = UNet(in_channels=1, out_channels=1)
|
| 26 |
+
opt = optim.Adam(model.parameters(), lr=1e-3)
|
| 27 |
+
loss_fn = nn.BCELoss()
|
| 28 |
+
for epoch in range(epochs):
|
| 29 |
+
total=0
|
| 30 |
+
model.train()
|
| 31 |
+
for x,y in dl:
|
| 32 |
+
outp = model(x)
|
| 33 |
+
loss = loss_fn(outp, y)
|
| 34 |
+
opt.zero_grad(); loss.backward(); opt.step()
|
| 35 |
+
total += loss.item()
|
| 36 |
+
print(f"Epoch {epoch+1}, loss {total/len(dl):.4f}")
|
| 37 |
+
os.makedirs(os.path.dirname(out), exist_ok=True)
|
| 38 |
+
torch.save(model.state_dict(), out)
|
| 39 |
+
print("Saved model to", out)
|
| 40 |
+
|
| 41 |
+
if __name__ == '__main__':
|
| 42 |
+
import argparse
|
| 43 |
+
p = argparse.ArgumentParser()
|
| 44 |
+
p.add_argument('--data', default='examples/synthetic_phantom')
|
| 45 |
+
p.add_argument('--epochs', type=int, default=3)
|
| 46 |
+
p.add_argument('--out', default='models/unet_best.pt')
|
| 47 |
+
args = p.parse_args()
|
| 48 |
+
train(args.data, epochs=args.epochs, out=args.out)
|
utils.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
def normalize_volume(vol):
|
| 4 |
+
vmin = np.nanmin(vol)
|
| 5 |
+
vmax = np.nanmax(vol)
|
| 6 |
+
if vmax - vmin < 1e-8:
|
| 7 |
+
return np.zeros_like(vol)
|
| 8 |
+
return (vol - vmin) / (vmax - vmin)
|
| 9 |
+
|
| 10 |
+
def ensure_3d(volume):
|
| 11 |
+
if volume.ndim == 2:
|
| 12 |
+
return volume[np.newaxis, ...]
|
| 13 |
+
return volume
|