Commit 0e231b3 (parent: 4827b95), committed by ns
Commit message: containers

Files changed:
- containers/etl/Dockerfile +37 -0
- containers/etl/__init__.py +0 -0
- containers/etl/common.py +119 -0
- containers/etl/requirements.txt +2 -0
- containers/etl/run.py +38 -0
- containers/jupyter/Dockerfile +28 -0
- containers/jupyter/requirements.txt +1 -0
- containers/physionet/Dockerfile +11 -0
- containers/physionet/entrypoint.sh +11 -0
- containers/physionet/run.sh +6 -0
- containers/prerad +1 -0
- containers/streamlit/example.txt +16 -0
- containers/train/Dockerfile +31 -0
- containers/train/requirements.txt +3 -0
- containers/train/run.py +83 -0
containers/etl/Dockerfile
ADDED
@@ -0,0 +1,37 @@
+FROM python:3.9-buster
+
+RUN \
+    apt-get update && \
+    apt-get -y upgrade && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
+
+RUN useradd --create-home app
+WORKDIR /home/app
+
+COPY requirements.txt /home/app/
+COPY __init__.py /home/app/
+COPY common.py /home/app/
+COPY run.py /home/app/
+
+RUN \
+    chown app:app /home/app/requirements.txt && \
+    chmod 0755 /home/app/requirements.txt && \
+    chown app:app /home/app/__init__.py && \
+    chmod 0755 /home/app/__init__.py && \
+    chown app:app /home/app/run.py && \
+    chmod 0755 /home/app/run.py && \
+    chown app:app /home/app/common.py && \
+    chmod 0755 /home/app/common.py
+
+USER app
+
+ENV VIRTUAL_ENV=/home/app/venv
+RUN python3 -m venv $VIRTUAL_ENV
+ENV PATH="$VIRTUAL_ENV/bin:$PATH"
+
+RUN \
+    pip install --upgrade pip && \
+    pip install -r requirements.txt
+
+CMD ["python", "run.py", "worker", "-l", "info"]
containers/etl/__init__.py
ADDED
File without changes
containers/etl/common.py
ADDED
@@ -0,0 +1,119 @@
+import re
+import pandas as pd
+import os
+import random
+import jsonlines
+
+def flatten_json(data: dict) -> dict:
+    """recursively flatten JSON elements, from https://www.geeksforgeeks.org/flattening-json-objects-in-python/"""
+    out = {}
+
+    def flatten(x, name=""):
+        # if the nested key-value
+        # pair is of dict type
+        if type(x) is dict:
+            for a in x:
+                flatten(x[a], name + a + "_")
+
+        # if the nested key-value
+        # pair is of list type
+        elif type(x) is list:
+            i = 0
+
+            for a in x:
+                flatten(a, name + str(i) + "_")
+                i += 1
+        else:
+            out[name[:-1]] = x
+
+    flatten(data)
+    return out
+
+
+def construct_report(string: str) -> tuple:
+
+    # collect the section headers
+    keywords = [x.replace(":", "").lower() for x in re.findall(r"[A-Z0-9][A-Z0-9. ]*:", string)]
+
+    # normalize sections
+    paragraphs = re.findall(r"(\w+)*: *(.*?)(?=\s*(?:\w+:|$))", string.lower())
+    sections = []
+    for header, paragraph in paragraphs:
+        if header in [x.replace(" ", "_").replace("/", "_") for x in keywords]:
+            sections.append(":".join([header, ". ".join([x.strip() for x in paragraph.split(". ") if x])]))
+        else:
+            sections.append(" - ".join([header, ". ".join([x.strip() for x in paragraph.split(". ") if x])]))
+    sections = list(map(lambda a: a + "." if a[-1] != "." else a, sections))
+    paragraphs = re.findall(r"(\w+) *: *(.*?)(?=\s*(?:\w+:|$))", " ".join(sections))
+
+    report = {}
+    for header, paragraph in paragraphs:
+        sentence = paragraph.replace(" ", ". ").replace("..", ".").replace(" - .", " - ")
+        sentence = re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", sentence)
+        sentence = [x.strip() for x in sentence if len(x) > 2]
+        report[header.replace("_", " ")] = [x.replace("_", " ") for x in sentence]
+    report = flatten_json(report)
+    topic = [x.split("_")[0] for x in report.keys()]
+    body = [x for x in report.values()]
+    report = pd.DataFrame(list(zip(topic, body)))
+    try:
+        report.columns = ["paragraph", "sentence"]
+        report["ranking"] = report.index
+        report["screen"] = report["sentence"].apply(lambda x: 1 if 'interval change' in x or 'compar' in x or 'prior' in x or 'improved from' in x else 0)
+        reason = re.sub(" +", " ", " ".join([": ".join([key, value]) for (key, value) in collapse_report(report).items() if key in ['indication', 'history']]))
+        text = re.sub(" +", " ", " ".join([": ".join([key, value]) for (key, value) in collapse_report(report[report.screen == 0]).items() if key in ['findings', 'impression']]))
+        if 'findings' in text and 'impression' in text:
+            return reason, text
+        else:
+            return None, None
+    except ValueError:
+        return None, None
+
+# take a report dataframe and return a dictionary of the paragraphs
+def collapse_report(report: pd.DataFrame) -> dict:
+    """take a report dataframe and return the paragraphs in each section as key:value pairs"""
+    out = pd.merge(
+        report['paragraph'].drop_duplicates(),
+        report.groupby(['paragraph'])['sentence'].transform(lambda x: ' '.join(x)).drop_duplicates(),
+        left_index=True,
+        right_index=True
+    )
+    structure = dict()
+    for index, row in out.iterrows():
+        structure[row['paragraph']] = row['sentence']
+    return structure
+
+
+def extract_transform(row: dict) -> None:
+
+    report_root = "./physionet.org/files/mimic-cxr/2.0.0/files"
+    image_root = "./physionet.org/files/mimic-cxr-jpg/2.0.0/files"
+
+    try:
+        scans = os.listdir(os.path.join(image_root, row["part"], row["patient"]))
+        scans = [x for x in scans if 'txt' not in x]
+        for scan in scans:
+            report = os.path.join(report_root, row["part"], row["patient"], scan + ".txt")
+            if os.path.exists(report):
+                with open(report, "r") as f:
+                    original = f.read()
+                transformed = re.sub(" +", " ", original.replace("FINAL REPORT", "").strip().replace("\n \n", ".").replace("\n", " ")).replace(" . ", " ").replace("..", ".").replace("CHEST RADIOGRAPHS.", " ").strip()
+                if len(transformed) > 0:
+                    reason, text = construct_report(transformed)
+                    images = [os.path.join(image_root, row["part"], row["patient"], scan, x) for x in os.listdir(os.path.join(image_root, row["part"], row["patient"], scan))]
+                    images = [x for x in images if os.path.exists(x)]
+                    random.shuffle(images)  # shuffle so we can reasonably sample 1 image per study
+                    with jsonlines.open("dataset.jsonl", "a") as writer:
+                        for image in images:
+                            writer.write({
+                                "fold": row["patient"][0:3],
+                                "image": image,
+                                "study": image.split("/")[-2],
+                                "original": transformed,
+                                "report": report,
+                                "patient": row["patient"],
+                                "reason": reason,
+                                "text": " ".join([reason, text]) if reason is not None and text is not None else None
+                            })
    except FileNotFoundError:
        pass
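
To make the flattening step concrete, here is a minimal usage sketch of flatten_json; the input record is invented for illustration:

    # minimal sketch: flatten_json on an invented nested record
    # (assumes common.py from this commit is importable)
    from common import flatten_json

    record = {"impression": ["no effusion", "clear lungs"], "meta": {"view": "AP"}}
    print(flatten_json(record))
    # {'impression_0': 'no effusion', 'impression_1': 'clear lungs', 'meta_view': 'AP'}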
containers/etl/requirements.txt
ADDED
@@ -0,0 +1,2 @@
+jsonlines>=3.1.0,<3.2
+pandas>=1.5.3,<1.6
containers/etl/run.py
ADDED
@@ -0,0 +1,38 @@
+import os
+import jsonlines
+from concurrent.futures import ThreadPoolExecutor
+from common import extract_transform
+
+def run():
+
+    # remove output from previous executions
+    if os.path.exists("/opt/physionet/dataset.jsonl"):
+        os.remove("/opt/physionet/dataset.jsonl")
+
+    if os.path.exists("/opt/physionet/control.jsonl"):
+        os.remove("/opt/physionet/control.jsonl")
+
+    # create a control dictionary
+    root = "/opt/physionet/physionet.org/files/mimic-cxr/2.0.0/files"
+    with jsonlines.open("/opt/physionet/control.jsonl", "w") as writer:
+        parts = os.listdir(root)
+        for part in parts:
+            patients = os.listdir(os.path.join(root, part))
+            for patient in patients:
+                scan = [x for x in os.listdir(os.path.join(root, part, patient)) if x.endswith('.txt')]
+                writer.write({"part": part, "patient": patient, "scan": scan})
+
+
+    # parse each record
+    with ThreadPoolExecutor(max_workers=4) as executor:
+        with jsonlines.open("/opt/physionet/control.jsonl", "r") as reader:
+            executor.map(extract_transform, reader)
+
+
+# only run if the files have been downloaded
+if __name__ == "__main__":
+    try:
+        if len(os.listdir('/opt/physionet/physionet.org/files/mimic-cxr/2.0.0/files')) > 0:
+            run()
+    except OSError:
+        print("not downloaded yet")
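
For reference, a short sketch of consuming the dataset.jsonl this worker produces; the field names come from extract_transform in common.py:

    # sketch: iterate over the JSON Lines records written by extract_transform
    import jsonlines

    with jsonlines.open("dataset.jsonl", "r") as reader:
        for row in reader:
            # each record carries fold, image, study, original, report, patient, reason, text
            print(row["fold"], row["study"], row["text"])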
containers/jupyter/Dockerfile
ADDED
@@ -0,0 +1,28 @@
+FROM python:3.9-buster
+
+RUN \
+    apt-get update && \
+    apt-get -y upgrade && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
+
+RUN useradd --create-home app
+WORKDIR /home/app
+
+COPY requirements.txt /home/app/
+
+RUN \
+    chown app:app /home/app/requirements.txt && \
+    chmod 0755 /home/app/requirements.txt
+
+USER app
+
+ENV VIRTUAL_ENV=/home/app/venv
+RUN python3 -m venv $VIRTUAL_ENV
+ENV PATH="$VIRTUAL_ENV/bin:$PATH"
+
+RUN \
+    pip install --upgrade pip && \
+    pip install -r requirements.txt
+
+CMD ["jupyter", "notebook", "--notebook-dir=/opt/notebooks", "--ip='*'", "--port=8888", "--no-browser"]
containers/jupyter/requirements.txt
ADDED
@@ -0,0 +1 @@
+jupyter>=1.0.0,<1.1
containers/physionet/Dockerfile
ADDED
@@ -0,0 +1,11 @@
+FROM debian:buster
+
+RUN apt-get update -y && \
+    apt-get -y install parallel wget && \
+    apt-get -y autoclean && \
+    apt-get -y autoremove && \
+    rm -rf /var/lib/apt/lists/*
+
+COPY entrypoint.sh /opt/entrypoint.sh
+
+ENTRYPOINT ["/bin/bash", "/opt/entrypoint.sh"]
containers/physionet/entrypoint.sh
ADDED
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+if [ $# -eq 0 ]
+then
+    echo no download requested
+else
+    cd /opt/physionet
+    wget -A .txt -r -nc -c -np --user $PHYSIONET_USER --password $PHYSIONET_PASSWORD https://physionet.org/files/mimic-cxr/2.0.0/files/
+    seq 10 19 | parallel -j4 wget -A .jpg -r -nc -c -np --user $PHYSIONET_USER --password $PHYSIONET_PASSWORD https://physionet.org/files/mimic-cxr-jpg/2.0.0/files/p{}/
+fi
containers/physionet/run.sh
ADDED
@@ -0,0 +1,6 @@
+
+if ${1:-false}; then
+    # get the JPG
+    # spread out over 4 cores
+    echo True
+fi
containers/prerad
ADDED
@@ -0,0 +1 @@
+Subproject commit 9c1e2f3995808260287d58217a0522592f78aed0
containers/streamlit/example.txt
ADDED
@@ -0,0 +1,16 @@
+FINAL REPORT
+EXAMINATION: CHEST (PORTABLE AP)
+
+INDICATION: ___ year old man with episodic ___ weakness // r/o infection
+
+TECHNIQUE: CHEST (PORTABLE AP)
+
+COMPARISON: None.
+
+IMPRESSION:
+
+Heart size and mediastinum are mildly enlarged. The patient is after median
+sternotomy and CABG. Lung volumes are preserved. Mild interstitial changes
+are noted bilaterally, potentially representing chronic changes but mild
+interstitial edema is a possibility. No definitive focal consolidations to
+suggest infectious process demonstrated. No pleural effusion or pneumothorax.
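
To trace how a report like this moves through the ETL stage, a sketch that applies the same normalization as extract_transform and then calls construct_report; the local file path is hypothetical and the parsed output depends on the regex handling above:

    # sketch: run the example report through the ETL parsing
    import re
    from common import construct_report  # from containers/etl

    with open("example.txt", "r") as f:  # hypothetical local copy
        original = f.read()

    # the same normalization applied in extract_transform
    transformed = re.sub(" +", " ", original.replace("FINAL REPORT", "").strip().replace("\n \n", ".").replace("\n", " ")).replace(" . ", " ").replace("..", ".").replace("CHEST RADIOGRAPHS.", " ").strip()
    reason, text = construct_report(transformed)
    # note: construct_report returns (None, None) for reports lacking both a
    # findings and an impression section
    print(reason, text)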
containers/train/Dockerfile
ADDED
@@ -0,0 +1,31 @@
+FROM python:3.9-buster
+
+RUN \
+    apt-get update && \
+    apt-get -y upgrade && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
+
+RUN useradd --create-home app
+WORKDIR /home/app
+
+COPY requirements.txt /home/app/
+COPY run.py /home/app/
+
+RUN \
+    chown app:app /home/app/requirements.txt && \
+    chmod 0755 /home/app/requirements.txt && \
+    chown app:app /home/app/run.py && \
+    chmod 0755 /home/app/run.py
+
+USER app
+
+ENV VIRTUAL_ENV=/home/app/venv
+RUN python3 -m venv $VIRTUAL_ENV
+ENV PATH="$VIRTUAL_ENV/bin:$PATH"
+
+RUN \
+    pip install --upgrade pip && \
+    pip install -r requirements.txt
+
+CMD ["python", "run.py", "worker", "-l", "info"]
containers/train/requirements.txt
ADDED
@@ -0,0 +1,3 @@
+transformers[torch]>=4.26.0,<4.27
+pillow>=9.4.0,<9.5
+datasets>=2.9.0,<2.10
containers/train/run.py
ADDED
@@ -0,0 +1,83 @@
+import pandas as pd
+from datasets import Dataset, Image
+import torch
+from transformers import Trainer, TrainingArguments
+from transformers import DataCollatorForLanguageModeling
+from transformers import BlipProcessor, BlipForConditionalGeneration
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# initialize the model from pretrained weights
+repo = "Salesforce/blip-image-captioning-large"
+processor = BlipProcessor.from_pretrained(repo)
+tokenizer = processor.tokenizer
+model = BlipForConditionalGeneration.from_pretrained(repo)
+
+# load the data configuration and split into train/valid
+dt = pd.read_json("dataset.jsonl", lines=True).dropna()
+dt["train"] = dt["fold"].apply(lambda x: 0 if x in ['p19'] else 1)  # 10% of data
+dt["patient"] = dt["patient"].apply(lambda x: x[0:5])
+train = dt[dt.train == 1]
+valid = dt[dt.train == 0]
+
+# create datasets
+train_dataset = Dataset.from_dict({
+    "image": train["image"].to_list(),
+    "fold": train["fold"].to_list(),
+    "text": train["text"].to_list(),
+    "reason": train["reason"].to_list(),
+    "id": [x.split("/")[-1].replace(".jpg", "") for x in train["image"].to_list()]
+}).cast_column("image", Image())
+
+valid_dataset = Dataset.from_dict({
+    "image": valid["image"].to_list(),
+    "fold": valid["fold"].to_list(),
+    "text": valid["text"].to_list(),
+    "reason": valid["reason"].to_list(),
+    "id": [x.split("/")[-1].replace(".jpg", "") for x in valid["image"].to_list()]
+}).cast_column("image", Image())
+
+def transform(example_batch):
+    return processor(
+        images=[image for image in example_batch["image"]],
+        text=[text for text in example_batch["text"]],
+        return_tensors="np",
+        padding='max_length',
+        max_length=512
+    )
+
+# apply the transform on the fly
+train_prepared = train_dataset.shuffle(seed=42).with_transform(transform)
+valid_prepared = valid_dataset.shuffle(seed=42).with_transform(transform)
+
+# " ".join(processor.batch_decode(train_prepared[0]["input_ids"])).replace(" ##", "")
+training_args = TrainingArguments(
+    num_train_epochs=5,
+    evaluation_strategy="epoch",
+    save_steps=1000,
+    logging_steps=100,
+    per_device_eval_batch_size=2,
+    per_device_train_batch_size=2,
+    gradient_accumulation_steps=8,
+    lr_scheduler_type='cosine_with_restarts',
+    warmup_ratio=0.1,
+    learning_rate=5e-5,
+    save_total_limit=1,
+    output_dir="/opt/models/generate-cxr-checkpoints"
+)
+
+data_collator = DataCollatorForLanguageModeling(
+    tokenizer=tokenizer,
+    mlm=False
+)
+
+trainer = Trainer(
+    model=model,
+    tokenizer=processor,
+    args=training_args,
+    train_dataset=train_prepared,
+    eval_dataset=valid_prepared,
+    data_collator=data_collator,
+)
+
+trainer.train()
+trainer.save_model("/opt/models/generate-cxr")
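
As a closing sketch, generation with the saved model; the image path and generation settings are illustrative, not part of this commit:

    # sketch: caption a single image with the fine-tuned checkpoint
    from PIL import Image
    from transformers import BlipProcessor, BlipForConditionalGeneration

    processor = BlipProcessor.from_pretrained("/opt/models/generate-cxr")
    model = BlipForConditionalGeneration.from_pretrained("/opt/models/generate-cxr")

    image = Image.open("scan.jpg").convert("RGB")  # hypothetical image path
    inputs = processor(images=image, return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=128)
    print(processor.decode(output[0], skip_special_tokens=True))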