hmnshudhmn24 commited on
Commit
5b023ff
·
verified ·
1 Parent(s): f14713b

Upload 16 files

Browse files
.gitattributes CHANGED
@@ -1,35 +1,3 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.pt filter=lfs diff=lfs merge=lfs -text
2
+ *.nii filter=lfs diff=lfs merge=lfs -text
3
+ *.nii.gz filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
LICENSE ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ Apache License 2.0
2
+
3
+ Copyright 2025 hmnshudhmn24
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
README.md CHANGED
@@ -1,3 +1,58 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ license: apache-2.0
5
+ tags:
6
+ - image-to-3d
7
+ - medical-imaging
8
+ - reconstruction
9
+ - segmentation
10
+ - explainability
11
+ pipeline_tag: image-to-3d
12
+ library_name: python
13
+ ---
14
+
15
+ # MEDIVIEW-3D (Advanced)
16
+
17
+ **MEDIVIEW-3D** converts 2D medical image slices into a 3D reconstruction, localizes anomalous regions, and generates **textual explanations** describing the detected regions (size, approximate location, and suggested next steps).
18
+
19
+ **Important:** This is a research/demo tool and **not** a medical device. Do not use for clinical decisions.
20
+
21
+ ## Quickstart (demo)
22
+
23
+ 1. Install:
24
+ ```bash
25
+ pip install -r requirements.txt
26
+ ```
27
+
28
+ 2. Generate synthetic phantom slices:
29
+ ```bash
30
+ python examples/generate_synthetic_phantom.py
31
+ ```
32
+
33
+ 3. Run inference with thresholding and get a mesh + explanation:
34
+ ```bash
35
+ python infer_anomaly.py --source examples/synthetic_phantom --method threshold --out demo_mesh_threshold.ply --explain_out explanation.txt
36
+ ```
37
+
38
+ 4. (Optional) Train small UNet and run model-based inference:
39
+ ```bash
40
+ python train_unet.py --data examples/synthetic_phantom --epochs 3 --out models/unet_demo.pt
41
+ python infer_anomaly.py --source examples/synthetic_phantom --method model --model_path models/unet_demo.pt --out demo_mesh_model.ply --explain_out explanation_model.txt
42
+ ```
43
+
44
+ 5. Run Streamlit demo:
45
+ ```bash
46
+ streamlit run app.py
47
+ ```
48
+
49
+ ## What you get
50
+ - 3D mesh `.ply` with anomaly regions colored red
51
+ - `explanation.txt` with human-friendly descriptions of detected regions
52
+ - Example synthetic phantom (no patient data)
53
+ - Small UNet implementation for demo training
54
+
55
+ ## Safety & Limitations
56
+ - Demo-only; not clinically validated.
57
+ - Do not upload identifiable patient data to public repos.
58
+ - For real medical use, integrate robust preprocessing and obtain regulatory approvals.
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import numpy as np
4
+ import tempfile, os, zipfile
5
+ from io_utils import load_slices_from_folder
6
+ from infer_anomaly import simple_threshold_mask, mesh_and_mask_from_volume
7
+ from io_utils import save_mesh_ply
8
+
9
+ st.set_page_config(page_title="MEDIVIEW-3D", layout="wide")
10
+ st.title("MEDIVIEW-3D — 2D → 3D Reconstruction + Explanation")
11
+
12
+ uploaded = st.file_uploader("Upload a zip of PNG slices (or skip to use example)", type=["zip"])
13
+ use_example = st.button("Use example synthetic slices")
14
+ thresh = st.slider("Threshold (for simple anomaly)", 0.0, 1.0, 0.6)
15
+ out_mesh = st.text_input("Output mesh path", value="mesh_demo.ply")
16
+
17
+ def unpack_and_read(zipf):
18
+ tmpdir = tempfile.mkdtemp()
19
+ with zipfile.ZipFile(zipf) as z:
20
+ z.extractall(tmpdir)
21
+ vol = load_slices_from_folder(tmpdir, glob_pattern="*.png")
22
+ return vol, tmpdir
23
+
24
+ vol = None
25
+ if uploaded:
26
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp:
27
+ tmp.write(uploaded.getvalue())
28
+ tmp.flush()
29
+ vol, tmpdir = unpack_and_read(tmp.name)
30
+ elif use_example:
31
+ if os.path.isdir("examples/synthetic_phantom"):
32
+ vol = load_slices_from_folder("examples/synthetic_phantom", glob_pattern="*.png")
33
+ else:
34
+ st.error("Example slices not found in examples/synthetic_phantom")
35
+
36
+ if vol is not None:
37
+ st.write("Volume shape (z,y,x):", vol.shape)
38
+ mid = vol.shape[0]//2
39
+ st.image(Image.fromarray(vol[mid].astype('uint8')), caption=f"Middle slice (index {mid})")
40
+ if st.button("Run reconstruction & anomaly detection"):
41
+ mask = simple_threshold_mask(vol, threshold=thresh)
42
+ verts, faces, normals, colors = mesh_and_mask_from_volume(vol, iso=0.5, spacing=(1.0,1.0,1.0), mask=mask)
43
+ save_mesh_ply(verts, faces, out_mesh, normals=normals, colors=colors)
44
+ st.success(f"Saved colored mesh to {out_mesh}")
45
+ # Generate simple textual explanation
46
+ from infer_anomaly import analyze_regions, generate_text_explanation
47
+ regions = analyze_regions(mask, min_voxels=20)
48
+ expl = generate_text_explanation(regions, vol.shape)
49
+ st.subheader('Detected regions explanation')
50
+ st.text(expl)
51
+ else:
52
+ st.info("Upload a slice zip or click 'Use example synthetic slices' to proceed.")
config.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ io:
2
+ slice_glob: "examples/synthetic_phantom/*.png"
3
+ reconstruct:
4
+ iso_value: 0.5
5
+ spacing: [1.0, 1.0, 1.0]
6
+ anomaly:
7
+ method: "threshold"
8
+ threshold: 0.6
9
+ model:
10
+ path: "models/unet_best.pt"
11
+ text_explainer:
12
+ enabled: true
13
+ min_region_voxels: 50
examples/demo_commands.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Demo commands
2
+
3
+ # 1) Generate synthetic phantom slices
4
+ python examples/generate_synthetic_phantom.py
5
+
6
+ # 2) Reconstruct mesh + explanation (thresholding)
7
+ python infer_anomaly.py --source examples/synthetic_phantom --method threshold --out demo_mesh_threshold.ply --explain_out explanation.txt
8
+
9
+ # 3) Train small UNet (optional)
10
+ python train_unet.py --data examples/synthetic_phantom --epochs 3 --out models/unet_demo.pt
11
+
12
+ # 4) Run model-based inference (after training)
13
+ python infer_anomaly.py --source examples/synthetic_phantom --method model --model_path models/unet_demo.pt --out demo_mesh_model.ply --explain_out explanation_model.txt
examples/example_image.png ADDED
examples/generate_synthetic_phantom.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, numpy as np
2
+ from PIL import Image
3
+ os.makedirs("examples/synthetic_phantom", exist_ok=True)
4
+ z = 64
5
+ y = 128
6
+ x = 128
7
+ for i in range(z):
8
+ Y, X = np.meshgrid(np.linspace(-1,1,x), np.linspace(-1,1,y))
9
+ cx = 0.2*np.sin(2*np.pi*i/z)
10
+ cy = 0.2*np.cos(2*np.pi*i/z)
11
+ r = np.sqrt((X-cx)**2 + (Y-cy)**2)
12
+ img = (np.exp(-(r**2)/(2*(0.2**2))) * 255).astype(np.uint8)
13
+ img = img + np.random.randint(0,20,size=img.shape).astype(np.uint8)
14
+ Image.fromarray(img).save(f"examples/synthetic_phantom/slice_{i:03d}.png")
15
+ print("Synthetic phantom saved to examples/synthetic_phantom/")
infer_anomaly.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, yaml, numpy as np
2
+ from utils import normalize_volume
3
+ from io_utils import load_slices_from_folder, load_nifti, load_dicom_folder, save_mesh_ply
4
+ from skimage import measure
5
+ import os
6
+
7
+ def simple_threshold_mask(volume, threshold=0.6):
8
+ vol_n = normalize_volume(volume)
9
+ return (vol_n > threshold).astype(np.uint8)
10
+
11
+ def mask_to_vertex_colors(verts, faces, mask_vol, spacing=(1.0,1.0,1.0)):
12
+ coords = verts / np.array(spacing)
13
+ coords_idx = np.round(coords).astype(int)
14
+ colors = np.zeros((len(verts), 3), dtype=float)
15
+ zmax, ymax, xmax = mask_vol.shape
16
+ for i, (z,y,x) in enumerate(coords_idx):
17
+ if 0 <= z < zmax and 0 <= y < ymax and 0 <= x < xmax:
18
+ if mask_vol[z,y,x]:
19
+ colors[i] = np.array([1.0, 0.0, 0.0])
20
+ else:
21
+ colors[i] = np.array([0.7, 0.7, 0.7])
22
+ else:
23
+ colors[i] = np.array([0.7, 0.7, 0.7])
24
+ return colors
25
+
26
+ def mesh_and_mask_from_volume(vol, iso=0.5, spacing=(1.0,1.0,1.0), mask=None):
27
+ vol_n = normalize_volume(vol)
28
+ verts, faces, normals, values = measure.marching_cubes(vol_n, level=iso, spacing=spacing)
29
+ colors = None
30
+ if mask is not None:
31
+ colors = mask_to_vertex_colors(verts, faces, mask, spacing=spacing)
32
+ return verts, faces, normals, colors
33
+
34
+ def analyze_regions(mask, min_voxels=50):
35
+ # find connected components and simple stats
36
+ from scipy import ndimage as ndi
37
+ labeled, n = ndi.label(mask)
38
+ regions = []
39
+ for lab in range(1, n+1):
40
+ coords = np.argwhere(labeled==lab)
41
+ voxels = coords.shape[0]
42
+ z_mean, y_mean, x_mean = coords.mean(axis=0).tolist()
43
+ regions.append({'label': lab, 'voxels': int(voxels), 'center': [float(z_mean), float(y_mean), float(x_mean)]})
44
+ # filter small
45
+ regions = [r for r in regions if r['voxels'] >= min_voxels]
46
+ return regions
47
+
48
+ def generate_text_explanation(regions, vol_shape, spacing=(1.0,1.0,1.0)):
49
+ if not regions:
50
+ return "No anomalous regions detected above threshold."
51
+
52
+ texts = []
53
+
54
+ zdim, ydim, xdim = vol_shape
55
+ for r in regions:
56
+ zc, yc, xc = r['center']
57
+ # approximate location as top/middle/bottom and left/center/right
58
+ zpos = 'top' if zc < zdim*0.33 else ('bottom' if zc > zdim*0.66 else 'middle')
59
+ ypos = 'left' if xc < xdim*0.33 else ('right' if xc > xdim*0.66 else 'center')
60
+ texts.append(f"Region {r['label']}: approx {r['voxels']} voxels, located near the {zpos} (z~{zc:.1f}), {ypos} (x~{xc:.1f}). Suggest clinical review and consider high-resolution imaging or segmentation.")
61
+ return "\n".join(texts)
62
+
63
+
64
+ def main():
65
+ parser = argparse.ArgumentParser()
66
+ parser.add_argument("--source", required=True)
67
+ parser.add_argument("--source_type", choices=['folder','dicom','nifti'], default='folder')
68
+ parser.add_argument("--glob", default="*.png")
69
+ parser.add_argument("--config", default="config.yaml")
70
+ parser.add_argument("--method", choices=['threshold','model'], default='threshold')
71
+ parser.add_argument("--threshold", type=float, default=None)
72
+ parser.add_argument("--model_path", default=None)
73
+ parser.add_argument("--out", default="mesh_colored.ply")
74
+ parser.add_argument("--explain_out", default="explanation.txt")
75
+ args = parser.parse_args()
76
+
77
+ try:
78
+ cfg = yaml.safe_load(open(args.config))
79
+ except Exception:
80
+ cfg = {}
81
+ cfg_thresh = cfg.get('anomaly', {}).get('threshold', 0.6)
82
+ threshold = args.threshold or cfg_thresh
83
+ spacing = tuple(cfg.get('reconstruct', {}).get('spacing', [1.0,1.0,1.0]))
84
+ iso = cfg.get('reconstruct',{}).get('iso_value', 0.5)
85
+ min_vox = cfg.get('text_explainer',{}).get('min_region_voxels', 50)
86
+
87
+ # load volume
88
+ if args.source_type == 'folder':
89
+ vol = load_slices_from_folder(args.source, glob_pattern=args.glob)
90
+ elif args.source_type == 'dicom':
91
+ vol = load_dicom_folder(args.source)
92
+ else:
93
+ vol = load_nifti(args.source)
94
+
95
+ if args.method == 'threshold':
96
+ mask = simple_threshold_mask(vol, threshold=threshold)
97
+ else:
98
+ # model-based per-slice segmentation (if model provided)
99
+ try:
100
+ import torch
101
+ from models.unet import UNet
102
+ model = UNet(in_channels=1, out_channels=1)
103
+ model.load_state_dict(torch.load(args.model_path, map_location='cpu'))
104
+ model.eval()
105
+ vol_n = normalize_volume(vol)
106
+ mask = np.zeros_like(vol_n, dtype=np.uint8)
107
+ for i in range(vol_n.shape[0]):
108
+ s = vol_n[i]
109
+ x = (s - s.min())/(s.max()-s.min()+1e-8)
110
+ import torch
111
+ inp = torch.tensor(x[np.newaxis, np.newaxis, ...], dtype=torch.float32)
112
+ with torch.no_grad():
113
+ out = model(inp).numpy()[0,0]
114
+ mask[i] = (out > 0.5).astype(np.uint8)
115
+ except Exception as e:
116
+ print("Model-based method failed:", e)
117
+ mask = simple_threshold_mask(vol, threshold=threshold)
118
+
119
+ regions = analyze_regions(mask, min_voxels=min_vox)
120
+ explanation = generate_text_explanation(regions, vol.shape, spacing=spacing)
121
+ verts, faces, normals, colors = mesh_and_mask_from_volume(vol, iso=iso, spacing=spacing, mask=mask)
122
+ save_mesh_ply(verts, faces, args.out, normals=normals, colors=colors)
123
+ with open(args.explain_out, 'w') as f:
124
+ f.write(explanation)
125
+ print("Saved colored mesh to", args.out)
126
+ print("Saved textual explanation to", args.explain_out)
127
+
128
+ if __name__ == '__main__':
129
+ main()
io_utils.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, glob
2
+ import numpy as np
3
+ from PIL import Image
4
+ import pydicom
5
+ import nibabel as nib
6
+
7
+ def load_slices_from_folder(folder, glob_pattern="*.png", sort_numerically=True):
8
+ paths = sorted(glob.glob(os.path.join(folder, glob_pattern)))
9
+ if sort_numerically:
10
+ try:
11
+ paths = sorted(paths, key=lambda p: int(''.join(filter(str.isdigit, os.path.basename(p))) or 0))
12
+ except Exception:
13
+ pass
14
+ slices = [np.array(Image.open(p).convert("L")) for p in paths]
15
+ vol = np.stack(slices, axis=0) # z, y, x
16
+ return vol.astype(np.float32)
17
+
18
+ def load_dicom_folder(folder):
19
+ files = [p for p in glob.glob(os.path.join(folder, "**"), recursive=False) if os.path.isfile(p)]
20
+ ds = []
21
+ for f in files:
22
+ try:
23
+ d = pydicom.dcmread(f)
24
+ ds.append((getattr(d, "InstanceNumber", 0), d))
25
+ except Exception:
26
+ continue
27
+ ds = sorted(ds, key=lambda x: x[0])
28
+ imgs = [d.pixel_array for _, d in ds]
29
+ vol = np.stack(imgs, axis=0).astype(np.float32)
30
+ return vol
31
+
32
+ def load_nifti(path):
33
+ nii = nib.load(path)
34
+ arr = nii.get_fdata().astype(np.float32)
35
+ if arr.ndim == 3:
36
+ return np.transpose(arr, (2,1,0))
37
+ return arr
38
+
39
+ def save_mesh_ply(verts, faces, out_path, normals=None, colors=None):
40
+ import numpy as np
41
+ with open(out_path, 'w') as f:
42
+ f.write("ply\nformat ascii 1.0\n")
43
+ f.write(f"element vertex {len(verts)}\n")
44
+ f.write("property float x\nproperty float y\nproperty float z\n")
45
+ if colors is not None:
46
+ f.write("property uchar red\nproperty uchar green\nproperty uchar blue\n")
47
+ if normals is not None:
48
+ f.write("property float nx\nproperty float ny\nproperty nz\n")
49
+ f.write(f"element face {len(faces)}\n")
50
+ f.write("property list uchar int vertex_indices\n")
51
+ f.write("end_header\n")
52
+ for i, v in enumerate(verts):
53
+ line = f"{v[0]} {v[1]} {v[2]}"
54
+ if colors is not None:
55
+ c = (np.clip(colors[i], 0, 1) * 255).astype(int)
56
+ line += f" {c[0]} {c[1]} {c[2]}"
57
+ if normals is not None:
58
+ n = normals[i]
59
+ line += f" {n[0]} {n[1]} {n[2]}"
60
+ f.write(line + "\n")
61
+ for face in faces:
62
+ f.write(f"3 {face[0]} {face[1]} {face[2]}\n")
models/unet.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ class DoubleConv(nn.Module):
7
+ def __init__(self, in_ch, out_ch):
8
+ super().__init__()
9
+ self.net = nn.Sequential(
10
+ nn.Conv2d(in_ch, out_ch, 3, padding=1),
11
+ nn.BatchNorm2d(out_ch),
12
+ nn.ReLU(inplace=True),
13
+ nn.Conv2d(out_ch, out_ch, 3, padding=1),
14
+ nn.BatchNorm2d(out_ch),
15
+ nn.ReLU(inplace=True),
16
+ )
17
+ def forward(self, x): return self.net(x)
18
+
19
+ class Down(nn.Module):
20
+ def __init__(self, in_ch, out_ch):
21
+ super().__init__()
22
+ self.net = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_ch, out_ch))
23
+ def forward(self, x): return self.net(x)
24
+
25
+ class Up(nn.Module):
26
+ def __init__(self, in_ch, out_ch):
27
+ super().__init__()
28
+ self.up = nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)
29
+ self.conv = DoubleConv(in_ch, out_ch)
30
+ def forward(self, x1, x2):
31
+ x1 = self.up(x1)
32
+ diffY = x2.size()[2] - x1.size()[2]
33
+ diffX = x2.size()[3] - x1.size()[3]
34
+ x1 = F.pad(x1, [diffX//2, diffX - diffX//2, diffY//2, diffY - diffY//2])
35
+ x = torch.cat([x2, x1], dim=1)
36
+ return self.conv(x)
37
+
38
+ class UNet(nn.Module):
39
+ def __init__(self, in_channels=1, out_channels=1, base_c=32):
40
+ super().__init__()
41
+ self.inc = DoubleConv(in_channels, base_c)
42
+ self.down1 = Down(base_c, base_c*2)
43
+ self.down2 = Down(base_c*2, base_c*4)
44
+ self.up1 = Up(base_c*4, base_c*2)
45
+ self.up2 = Up(base_c*2, base_c)
46
+ self.outc = nn.Conv2d(base_c, out_channels, 1)
47
+ def forward(self, x):
48
+ x1 = self.inc(x)
49
+ x2 = self.down1(x1)
50
+ x3 = self.down2(x2)
51
+ x = self.up1(x3, x2)
52
+ x = self.up2(x, x1)
53
+ x = self.outc(x)
54
+ return torch.sigmoid(x)
notebooks/MEDIVIEW-3D_Demo.ipynb ADDED
@@ -0,0 +1 @@
 
 
1
+ {"cells": [{"cell_type": "markdown", "metadata": {}, "source": ["# MEDIVIEW-3D Demo Notebook\\n", "Run the demo scripts locally."]}, {"cell_type": "code", "metadata": {}, "source": ["!python examples/generate_synthetic_phantom.py\\n", "!python infer_anomaly.py --source examples/synthetic_phantom --method threshold --out demo_mesh_threshold.ply --explain_out explanation.txt\\n"]}], "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"name": "python"}}, "nbformat": 4, "nbformat_minor": 5}
reconstruct_3d.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, yaml
2
+ import numpy as np
3
+ from skimage import measure
4
+ from utils import normalize_volume
5
+ from io_utils import load_slices_from_folder, load_dicom_folder, load_nifti, save_mesh_ply
6
+
7
+ def build_volume(source, source_type='folder', glob_pattern="*.png"):
8
+ if source_type == 'folder':
9
+ vol = load_slices_from_folder(source, glob_pattern=glob_pattern)
10
+ elif source_type == 'dicom':
11
+ vol = load_dicom_folder(source)
12
+ elif source_type == 'nifti':
13
+ vol = load_nifti(source)
14
+ else:
15
+ raise ValueError("Unsupported source_type")
16
+ return vol
17
+
18
+ def mesh_from_volume(vol, iso=0.5, spacing=(1.0,1.0,1.0)):
19
+ vol_n = normalize_volume(vol)
20
+ verts, faces, normals, values = measure.marching_cubes(vol_n, level=iso, spacing=spacing)
21
+ return verts, faces, normals
22
+
23
+ def main():
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument("--source", required=True)
26
+ parser.add_argument("--source_type", choices=['folder','dicom','nifti'], default='folder')
27
+ parser.add_argument("--glob", default="*.png")
28
+ parser.add_argument("--config", default="config.yaml")
29
+ parser.add_argument("--out", default="mesh.ply")
30
+ parser.add_argument("--iso", type=float, default=None)
31
+ args = parser.parse_args()
32
+
33
+ try:
34
+ cfg = yaml.safe_load(open(args.config))
35
+ except Exception:
36
+ cfg = {}
37
+ iso = args.iso or (cfg.get('reconstruct',{}).get('iso_value', 0.5))
38
+ spacing = cfg.get('reconstruct',{}).get('spacing', [1.0,1.0,1.0])
39
+ vol = build_volume(args.source, source_type=args.source_type, glob_pattern=args.glob)
40
+ verts, faces, normals = mesh_from_volume(vol, iso=iso, spacing=tuple(spacing))
41
+ save_mesh_ply(verts, faces, args.out, normals=normals)
42
+ print(f"Saved mesh to {args.out} (verts={len(verts)}, faces={len(faces)})")
43
+
44
+ if __name__ == '__main__':
45
+ main()
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy>=1.21
2
+ scipy
3
+ scikit-image
4
+ matplotlib
5
+ pillow
6
+ imageio
7
+ pydicom
8
+ nibabel
9
+ PyYAML
10
+ trimesh
11
+ torch>=1.10
12
+ tqdm
13
+ streamlit
train_unet.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, glob
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torch
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import torch.nn as nn
7
+ from models.unet import UNet
8
+ import torch.optim as optim
9
+
10
+ class SliceDataset(Dataset):
11
+ def __init__(self, folder):
12
+ self.paths = sorted(glob.glob(os.path.join(folder, "*.png")))
13
+ def __len__(self): return len(self.paths)
14
+ def __getitem__(self, idx):
15
+ p = self.paths[idx]
16
+ img = np.array(Image.open(p).convert("L"), dtype=np.float32)/255.0
17
+ mask = (img > img.mean() + 0.25).astype(np.float32)
18
+ img = img[np.newaxis,...]
19
+ mask = mask[np.newaxis,...]
20
+ return torch.tensor(img), torch.tensor(mask)
21
+
22
+ def train(folder, epochs=3, out="models/unet_best.pt"):
23
+ ds = SliceDataset(folder)
24
+ dl = DataLoader(ds, batch_size=4, shuffle=True)
25
+ model = UNet(in_channels=1, out_channels=1)
26
+ opt = optim.Adam(model.parameters(), lr=1e-3)
27
+ loss_fn = nn.BCELoss()
28
+ for epoch in range(epochs):
29
+ total=0
30
+ model.train()
31
+ for x,y in dl:
32
+ outp = model(x)
33
+ loss = loss_fn(outp, y)
34
+ opt.zero_grad(); loss.backward(); opt.step()
35
+ total += loss.item()
36
+ print(f"Epoch {epoch+1}, loss {total/len(dl):.4f}")
37
+ os.makedirs(os.path.dirname(out), exist_ok=True)
38
+ torch.save(model.state_dict(), out)
39
+ print("Saved model to", out)
40
+
41
+ if __name__ == '__main__':
42
+ import argparse
43
+ p = argparse.ArgumentParser()
44
+ p.add_argument('--data', default='examples/synthetic_phantom')
45
+ p.add_argument('--epochs', type=int, default=3)
46
+ p.add_argument('--out', default='models/unet_best.pt')
47
+ args = p.parse_args()
48
+ train(args.data, epochs=args.epochs, out=args.out)
utils.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ def normalize_volume(vol):
4
+ vmin = np.nanmin(vol)
5
+ vmax = np.nanmax(vol)
6
+ if vmax - vmin < 1e-8:
7
+ return np.zeros_like(vol)
8
+ return (vol - vmin) / (vmax - vmin)
9
+
10
+ def ensure_3d(volume):
11
+ if volume.ndim == 2:
12
+ return volume[np.newaxis, ...]
13
+ return volume