linoyts (HF Staff) committed
Commit af28872 · verified · 1 Parent(s): 91e8cef

Create app.py

Files changed (1)
  1. app.py +268 -0
app.py ADDED
@@ -0,0 +1,268 @@
+import gradio as gr
+import os
+import tempfile
+import zipfile
+from pathlib import Path
+from datasets import Image as HFImage, load_dataset
+from huggingface_hub import HfApi
+from PIL import Image
+
+
+def process_hf_dataset(dataset_name: str, image_col: str, caption_col: str, lora_name: str, progress=gr.Progress()):
+    """Process a HuggingFace dataset and create image/txt pairs."""
+    if not dataset_name.strip():
+        return None, "Please enter a dataset name", lora_name
+
+    try:
+        progress(0, desc="Loading dataset...")
+        ds = load_dataset(dataset_name, split="train")
+
+        # Create temp directory for output
+        output_dir = tempfile.mkdtemp()
+
+        # Detect columns if not specified
+        if not image_col:
+            for col in ds.column_names:
+                if isinstance(ds.features[col], HFImage) or "image" in col.lower():
+                    image_col = col
+                    break
+
+        if not caption_col:
+            for col in ds.column_names:
+                if "text" in col.lower() or "caption" in col.lower() or "prompt" in col.lower():
+                    caption_col = col
+                    break
+
+        if not image_col or not caption_col:
+            return None, f"Could not detect columns. Available: {ds.column_names}", lora_name
+
+        progress(0.1, desc=f"Processing {len(ds)} images...")
+
+        for i, item in enumerate(ds):
+            progress((i + 1) / len(ds), desc=f"Processing image {i+1}/{len(ds)}")
+
+            # Save image
+            img = item[image_col]
+            if not isinstance(img, Image.Image):
+                img = Image.open(img)
+
+            img_filename = f"{i:05d}.png"
+            txt_filename = f"{i:05d}.txt"
+
+            img.save(os.path.join(output_dir, img_filename))
+
+            # Save caption
+            caption = item[caption_col]
+            with open(os.path.join(output_dir, txt_filename), "w", encoding="utf-8") as f:
+                f.write(str(caption))
+
+        return output_dir, f"Processed {len(ds)} images from {dataset_name}", lora_name
+
+    except Exception as e:
+        return None, f"Error: {str(e)}", lora_name
+
+
+def process_uploaded_images(images: list, caption: str, lora_name: str, progress=gr.Progress()):
+    """Process uploaded images with a shared caption."""
+    if not images:
+        return None, "Please upload some images", lora_name
+
+    output_dir = tempfile.mkdtemp()
+
+    for img_data in progress.tqdm(images, desc="Processing images"):
+        # Gallery returns (filepath, caption) pairs or bare filepaths
+        if isinstance(img_data, (list, tuple)):
+            img_path = img_data[0]
+        else:
+            img_path = img_data
+
+        img = Image.open(img_path)
+
+        # Use original filename without extension
+        orig_name = Path(img_path).stem
+        img_filename = f"{orig_name}.png"
+        txt_filename = f"{orig_name}.txt"
+
+        img.save(os.path.join(output_dir, img_filename))
+
+        with open(os.path.join(output_dir, txt_filename), "w", encoding="utf-8") as f:
+            f.write(caption if caption else "")
+
+    return output_dir, f"Processed {len(images)} images", lora_name
+
+
+def create_zip(output_dir: str, lora_name: str | None = None):
+    """Create a zip file from the output directory."""
+    if not output_dir or not os.path.exists(output_dir):
+        return None
+
+    # Use lora_name for the zip filename if provided
+    if lora_name and lora_name.strip():
+        zip_filename = f"{lora_name.strip().replace(' ', '_')}.zip"
+        zip_path = os.path.join(tempfile.gettempdir(), zip_filename)
+    else:
+        # mkstemp instead of the deprecated tempfile.mktemp
+        fd, zip_path = tempfile.mkstemp(suffix=".zip")
+        os.close(fd)
+
+    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
+        for file in os.listdir(output_dir):
+            zf.write(os.path.join(output_dir, file), file)
+
+    return zip_path
+
+
+def push_to_hub(output_dir: str, repo_name: str, token: str, private: bool, progress=gr.Progress()):
+    """Push the processed dataset to HuggingFace Hub."""
+    if not output_dir or not os.path.exists(output_dir):
+        return "No data to push. Process a dataset first."
+
+    if not repo_name or not repo_name.strip():
+        return "Please enter a repository name (or provide a LoRA name when processing)"
+
+    if not token or not token.strip():
+        return "Please enter your HuggingFace token"
+
+    try:
+        progress(0, desc="Logging in...")
+        api = HfApi(token=token)
+
+        progress(0.2, desc="Creating repository...")
+        api.create_repo(repo_name, repo_type="dataset", private=private, exist_ok=True)
+
+        progress(0.4, desc="Uploading files...")
+        api.upload_folder(
+            folder_path=output_dir,
+            repo_id=repo_name,
+            repo_type="dataset",
+        )
+
+        return f"Successfully pushed to https://huggingface.co/datasets/{repo_name}"
+
+    except Exception as e:
+        return f"Error: {str(e)}"
+
+
+# Global state for the most recent output directory.
+# Note: module-level state is shared across all sessions of the app,
+# so concurrent users can overwrite each other's results.
+current_output_dir = {"path": None, "lora_name": None}
+
+
+def process_dataset_wrapper(dataset_name, image_col, caption_col, lora_name, progress=gr.Progress()):
+    output_dir, msg, lora = process_hf_dataset(dataset_name, image_col, caption_col, lora_name, progress)
+    current_output_dir["path"] = output_dir
+    current_output_dir["lora_name"] = lora
+    zip_path = create_zip(output_dir, lora) if output_dir else None
+    return msg, zip_path
+
+
+def process_images_wrapper(images, caption, lora_name, progress=gr.Progress()):
+    output_dir, msg, lora = process_uploaded_images(images, caption, lora_name, progress)
+    current_output_dir["path"] = output_dir
+    current_output_dir["lora_name"] = lora
+    zip_path = create_zip(output_dir, lora) if output_dir else None
+    return msg, zip_path
+
+
+def push_wrapper(repo_name, token, private, progress=gr.Progress()):
+    # Fall back to the LoRA name as the repo name if none is given
+    final_repo_name = repo_name.strip() if repo_name.strip() else None
+    if not final_repo_name and current_output_dir["lora_name"]:
+        final_repo_name = current_output_dir["lora_name"].strip().replace(" ", "_")
+    return push_to_hub(current_output_dir["path"], final_repo_name, token, private, progress)
+
+
+# Build the Gradio interface
+with gr.Blocks(title="AI Toolkit Dataset Converter") as demo:
+    gr.Markdown("# AI Toolkit Dataset Converter")
+    gr.Markdown("""Convert your datasets to the format expected by [ostris AI Toolkit](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#dataset-preparation). You can either:
+1. provide a dataset name from the Hub, or
+2. upload your images directly
+""")
+
+    with gr.Tabs():
+        # Tab 1: HuggingFace Dataset
+        with gr.Tab("From HuggingFace Dataset"):
+            dataset_name = gr.Textbox(
+                label="Dataset Name",
+                placeholder="e.g., Norod78/Yarn-art-style"
+            )
+            lora_name_ds = gr.Textbox(
+                label="LoRA Name (optional)",
+                placeholder="e.g., my-lora-style",
+                info="Used for the ZIP filename and the Hub dataset name"
+            )
+            with gr.Row():
+                image_col = gr.Textbox(
+                    label="Image Column (leave empty to auto-detect)",
+                    placeholder="image"
+                )
+                caption_col = gr.Textbox(
+                    label="Caption Column (leave empty to auto-detect)",
+                    placeholder="text"
+                )
+            process_ds_btn = gr.Button("Process Dataset", variant="primary")
+            ds_status = gr.Textbox(label="Status", interactive=False)
+            ds_download = gr.File(label="Download ZIP")
+
+            process_ds_btn.click(
+                process_dataset_wrapper,
+                inputs=[dataset_name, image_col, caption_col, lora_name_ds],
+                outputs=[ds_status, ds_download]
+            )
+
+        # Tab 2: Upload Images
+        with gr.Tab("From Uploaded Images"):
+            images_input = gr.Gallery(
+                label="Upload Images",
+                file_types=["image"],
+                interactive=True,
+                columns=4,
+                height="auto"
+            )
+            lora_name_img = gr.Textbox(
+                label="LoRA Name (optional)",
+                placeholder="e.g., my-lora-style",
+                info="Used for the ZIP filename and the Hub dataset name"
+            )
+            shared_caption = gr.Textbox(
+                label="Caption for all images",
+                placeholder="Enter a caption to use for all uploaded images",
+                lines=3
+            )
+            process_img_btn = gr.Button("Process Images", variant="primary")
+            img_status = gr.Textbox(label="Status", interactive=False)
+            img_download = gr.File(label="Download ZIP")
+
+            process_img_btn.click(
+                process_images_wrapper,
+                inputs=[images_input, shared_caption, lora_name_img],
+                outputs=[img_status, img_download]
+            )
+
+    # Push to Hub section
+    gr.Markdown("---")
+    gr.Markdown("### Push to HuggingFace Hub")
+
+    with gr.Row():
+        repo_name = gr.Textbox(
+            label="Repository Name",
+            placeholder="username/dataset-name (uses LoRA name if empty)",
+            info="Leave empty to use the LoRA name as the dataset name"
+        )
+        hf_token = gr.Textbox(
+            label="HuggingFace Token",
+            type="password",
+            placeholder="hf_..."
+        )
+
+    private_repo = gr.Checkbox(label="Private Repository", value=False)
+    push_btn = gr.Button("Push to Hub", variant="secondary")
+    push_status = gr.Textbox(label="Push Status", interactive=False)
+
+    push_btn.click(
+        push_wrapper,
+        inputs=[repo_name, hf_token, private_repo],
+        outputs=[push_status]
+    )
+
+
+if __name__ == "__main__":
+    demo.launch()
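
For reference, the converter writes a flat folder of image/caption pairs (e.g. 00000.png alongside 00000.txt, or the original filenames for uploads), which is the layout ai-toolkit consumes. Below is a minimal sketch, not part of app.py, for sanity-checking an extracted ZIP locally; the "dataset" folder name is a hypothetical extraction path.

# Verify that every image in the extracted folder has a matching caption file, and vice versa.
from pathlib import Path

dataset_dir = Path("dataset")  # hypothetical path where the ZIP was extracted
images = {p.stem for p in dataset_dir.glob("*.png")}
captions = {p.stem for p in dataset_dir.glob("*.txt")}

print(f"{len(images)} images, {len(captions)} captions")
print("images missing captions:", sorted(images - captions) or "none")
print("captions missing images:", sorted(captions - images) or "none")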