File size: 1,732 Bytes
5b023ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os, glob
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from models.unet import UNet
import torch.optim as optim

class SliceDataset(Dataset):
    """Grayscale PNG slices paired with intensity-threshold pseudo-masks.

    Each item is ``(image, mask)``: two float32 tensors of shape
    ``(1, H, W)``. The image is the slice converted to 8-bit grayscale and
    scaled to [0, 1]; the mask is 1.0 where a pixel exceeds the slice's own
    mean intensity by more than ``threshold`` and 0.0 elsewhere.

    Args:
        folder: Directory scanned (non-recursively) for ``*.png`` files;
            files are served in sorted path order.
        threshold: Offset above the per-slice mean (on the [0, 1] scale)
            that a pixel must exceed to be labeled foreground. Defaults to
            0.25, the value previously hard-coded.
    """

    def __init__(self, folder, threshold=0.25):
        self.paths = sorted(glob.glob(os.path.join(folder, "*.png")))
        self.threshold = threshold

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Context manager releases the file handle even if decoding fails.
        with Image.open(self.paths[idx]) as im:
            img = np.array(im.convert("L"), dtype=np.float32) / 255.0
        return self._make_pair(img, self.threshold)

    @staticmethod
    def _make_pair(img, threshold=0.25):
        """Build the ``(image, mask)`` tensor pair from a 2-D float array.

        Split out of ``__getitem__`` so the labeling rule can be exercised
        without touching the filesystem.
        """
        mask = (img > img.mean() + threshold).astype(np.float32)
        # Prepend a channel axis: (H, W) -> (1, H, W). from_numpy shares
        # memory with the freshly built arrays instead of copying them.
        return (
            torch.from_numpy(img[np.newaxis, ...]),
            torch.from_numpy(mask[np.newaxis, ...]),
        )

def train(folder, epochs=3, out="models/unet_best.pt", batch_size=4, lr=1e-3):
    """Train a single-channel UNet on the PNG slices in *folder* and save it.

    Targets come from ``SliceDataset``, which derives a binary mask from each
    image by intensity thresholding, so no separate label files are read.

    Args:
        folder: Directory scanned (non-recursively) for ``*.png`` slices.
        epochs: Number of full passes over the dataset.
        out: Destination path for the saved ``state_dict``. Missing parent
            directories are created; a bare filename is written to the
            current working directory.
        batch_size: Mini-batch size for the DataLoader.
        lr: Learning rate for the Adam optimizer.

    Raises:
        ValueError: If *folder* contains no ``.png`` files.
    """
    ds = SliceDataset(folder)
    if len(ds) == 0:
        # Fail with a clear message instead of DataLoader's sampler error.
        raise ValueError(f"no .png slices found in {folder!r}")
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True)
    model = UNet(in_channels=1, out_channels=1)
    opt = optim.Adam(model.parameters(), lr=lr)
    # NOTE(review): BCELoss requires inputs in [0, 1], so this assumes UNet's
    # forward ends in a sigmoid. If it returns raw logits, switch to
    # nn.BCEWithLogitsLoss. TODO confirm against models/unet.py.
    loss_fn = nn.BCELoss()
    for epoch in range(epochs):
        model.train()
        total = 0.0
        for x, y in dl:
            outp = model(x)
            loss = loss_fn(outp, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()
        # Mean of per-batch losses (the final batch may be smaller).
        print(f"Epoch {epoch+1}, loss {total/len(dl):.4f}")
    # Weights from the last epoch are saved; no "best" checkpoint selection
    # happens here despite the default filename.
    out_dir = os.path.dirname(out)
    if out_dir:  # os.makedirs("") raises, so skip when out is a bare filename
        os.makedirs(out_dir, exist_ok=True)
    torch.save(model.state_dict(), out)
    print("Saved model to", out)

if __name__ == '__main__':
    # Command-line entry point: collect options and forward them to train().
    import argparse

    cli = argparse.ArgumentParser()
    for flag, opts in (
        ('--data', {'default': 'examples/synthetic_phantom'}),
        ('--epochs', {'type': int, 'default': 3}),
        ('--out', {'default': 'models/unet_best.pt'}),
    ):
        cli.add_argument(flag, **opts)
    ns = cli.parse_args()
    train(ns.data, epochs=ns.epochs, out=ns.out)