import gradio as gr
import os
import tempfile
import zipfile
from pathlib import Path

from datasets import load_dataset
from datasets import Image as HFImage
from huggingface_hub import HfApi
from PIL import Image


def process_hf_dataset(dataset_name: str, image_col: str, caption_col: str, lora_name: str, progress=gr.Progress()):
    """Process a HuggingFace dataset and create image/txt pairs."""
    if not dataset_name.strip():
        return None, "Please enter a dataset name", lora_name

    try:
        progress(0, desc="Loading dataset...")
        ds = load_dataset(dataset_name, split="train")

        # Create a temp directory for the output pairs
        output_dir = tempfile.mkdtemp()

        # Auto-detect columns if not specified. datasets marks image columns
        # with the Image feature type; fall back to a name-based guess.
        if not image_col:
            for col in ds.column_names:
                if isinstance(ds.features[col], HFImage) or "image" in col.lower():
                    image_col = col
                    break
        if not caption_col:
            for col in ds.column_names:
                if "text" in col.lower() or "caption" in col.lower() or "prompt" in col.lower():
                    caption_col = col
                    break
        if not image_col or not caption_col:
            return None, f"Could not detect columns. Available: {ds.column_names}", lora_name

        progress(0.1, desc=f"Processing {len(ds)} images...")
        for i, item in enumerate(ds):
            progress((i + 1) / len(ds), desc=f"Processing image {i + 1}/{len(ds)}")

            # Save the image; open it first if the column stores a file path
            # rather than a decoded PIL image
            img = item[image_col]
            if not isinstance(img, Image.Image):
                img = Image.open(img)
            img.save(os.path.join(output_dir, f"{i:05d}.png"))

            # Save the caption next to the image with the same base name
            caption = item[caption_col]
            with open(os.path.join(output_dir, f"{i:05d}.txt"), "w", encoding="utf-8") as f:
                f.write(str(caption))

        return output_dir, f"Processed {len(ds)} images from {dataset_name}", lora_name
    except Exception as e:
        return None, f"Error: {e}", lora_name


def process_uploaded_images(images: list, caption: str, lora_name: str, progress=gr.Progress()):
    """Process uploaded images with a shared caption."""
    if not images:
        return None, "Please upload some images", lora_name

    output_dir = tempfile.mkdtemp()
    for img_data in progress.tqdm(images, desc="Processing images"):
        # Gallery values arrive as (filepath, caption) pairs or bare filepaths
        img_path = img_data[0] if isinstance(img_data, (tuple, list)) else img_data
        img = Image.open(img_path)

        # Keep the original filename (without extension) as the pair name
        orig_name = Path(img_path).stem
        img.save(os.path.join(output_dir, f"{orig_name}.png"))
        with open(os.path.join(output_dir, f"{orig_name}.txt"), "w", encoding="utf-8") as f:
            f.write(caption if caption else "")

    return output_dir, f"Processed {len(images)} images", lora_name


def create_zip(output_dir: str, lora_name: str = None):
    """Create a zip file from the output directory."""
    if not output_dir or not os.path.exists(output_dir):
        return None

    # Use the LoRA name for the zip filename if provided
    if lora_name and lora_name.strip():
        zip_filename = f"{lora_name.strip().replace(' ', '_')}.zip"
        zip_path = os.path.join(tempfile.gettempdir(), zip_filename)
    else:
        # tempfile.mktemp is deprecated; create a real file and reuse its name
        with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp:
            zip_path = tmp.name

    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for file in os.listdir(output_dir):
            zf.write(os.path.join(output_dir, file), arcname=file)
    return zip_path
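# For reference, the helpers above produce the flat image/caption-pair layout
# that AI Toolkit's dataset-preparation docs describe. For a three-image
# dataset the ZIP would look roughly like this (names illustrative):
#
#   my-lora-style.zip
#   |- 00000.png
#   |- 00000.txt   <- caption for 00000.png
#   |- 00001.png
#   |- 00001.txt
#   |- 00002.png
#   `- 00002.txt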
def push_to_hub(output_dir: str, repo_name: str, token: str, private: bool, progress=gr.Progress()):
    """Push the processed dataset to HuggingFace Hub."""
    if not output_dir or not os.path.exists(output_dir):
        return "No data to push. Process a dataset first."
    if not repo_name or not repo_name.strip():
        return "Please enter a repository name (or provide a LoRA name when processing)"
    if not token or not token.strip():
        return "Please enter your HuggingFace token"

    try:
        progress(0, desc="Authenticating...")
        api = HfApi(token=token)

        progress(0.2, desc="Creating repository...")
        api.create_repo(repo_name, repo_type="dataset", private=private, exist_ok=True)

        progress(0.4, desc="Uploading files...")
        api.upload_folder(
            folder_path=output_dir,
            repo_id=repo_name,
            repo_type="dataset",
        )
        return f"Successfully pushed to https://huggingface.co/datasets/{repo_name}"
    except Exception as e:
        return f"Error: {e}"


# Global state for the most recently processed output directory.
# NOTE: module-level state is shared across concurrent users of the app;
# fine for a single-user Space, but not multi-tenant safe.
current_output_dir = {"path": None, "lora_name": None}


def process_dataset_wrapper(dataset_name, image_col, caption_col, lora_name, progress=gr.Progress()):
    output_dir, msg, lora = process_hf_dataset(dataset_name, image_col, caption_col, lora_name, progress)
    current_output_dir["path"] = output_dir
    current_output_dir["lora_name"] = lora
    zip_path = create_zip(output_dir, lora) if output_dir else None
    return msg, zip_path


def process_images_wrapper(images, caption, lora_name, progress=gr.Progress()):
    output_dir, msg, lora = process_uploaded_images(images, caption, lora_name, progress)
    current_output_dir["path"] = output_dir
    current_output_dir["lora_name"] = lora
    zip_path = create_zip(output_dir, lora) if output_dir else None
    return msg, zip_path


def push_wrapper(repo_name, token, private, progress=gr.Progress()):
    # Fall back to the LoRA name when no repository name was given
    final_repo_name = repo_name.strip() if repo_name.strip() else None
    if not final_repo_name and current_output_dir["lora_name"]:
        final_repo_name = current_output_dir["lora_name"].strip().replace(" ", "_")
    return push_to_hub(current_output_dir["path"], final_repo_name, token, private, progress)


# Build the Gradio interface
with gr.Blocks(title="AI Toolkit Dataset Converter") as demo:
    gr.Markdown("# AI Toolkit Dataset Converter")
    gr.Markdown(
        """Convert your datasets to the format expected by
[ostris AI Toolkit](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#dataset-preparation).
You can either:
1. provide a dataset name from the Hub, or
2. upload your images directly"""
    )
    with gr.Tabs():
        # Tab 1: convert an existing HuggingFace dataset
        with gr.Tab("From HuggingFace Dataset"):
            dataset_name = gr.Textbox(
                label="Dataset Name",
                placeholder="e.g., Norod78/Yarn-art-style",
            )
            lora_name_ds = gr.Textbox(
                label="LoRA Name (optional)",
                placeholder="e.g., my-lora-style",
                info="Used for ZIP filename and Hub dataset name",
            )
            with gr.Row():
                image_col = gr.Textbox(
                    label="Image Column (leave empty to auto-detect)",
                    placeholder="image",
                )
                caption_col = gr.Textbox(
                    label="Caption Column (leave empty to auto-detect)",
                    placeholder="text",
                )
            process_ds_btn = gr.Button("Process Dataset", variant="primary")
            ds_status = gr.Textbox(label="Status", interactive=False)
            ds_download = gr.File(label="Download ZIP")

            process_ds_btn.click(
                process_dataset_wrapper,
                inputs=[dataset_name, image_col, caption_col, lora_name_ds],
                outputs=[ds_status, ds_download],
            )

        # Tab 2: upload individual images with a shared caption
        with gr.Tab("From Uploaded Images"):
            images_input = gr.Gallery(
                label="Upload Images",
                file_types=["image"],
                interactive=True,
                columns=4,
                height="auto",
            )
            lora_name_img = gr.Textbox(
                label="LoRA Name (optional)",
                placeholder="e.g., my-lora-style",
                info="Used for ZIP filename and Hub dataset name",
            )
            shared_caption = gr.Textbox(
                label="Caption for all images",
                placeholder="Enter a caption to use for all uploaded images",
                lines=3,
            )
            process_img_btn = gr.Button("Process Images", variant="primary")
            img_status = gr.Textbox(label="Status", interactive=False)
            img_download = gr.File(label="Download ZIP")

            process_img_btn.click(
                process_images_wrapper,
                inputs=[images_input, shared_caption, lora_name_img],
                outputs=[img_status, img_download],
            )

    # Push to Hub section
    gr.Markdown("---")
    gr.Markdown("### Push to HuggingFace Hub")
    with gr.Row():
        repo_name = gr.Textbox(
            label="Repository Name",
            placeholder="username/dataset-name (uses LoRA name if empty)",
            info="Leave empty to use LoRA name as dataset name",
        )
        hf_token = gr.Textbox(
            label="HuggingFace Token",
            type="password",
            placeholder="hf_...",
        )
    private_repo = gr.Checkbox(label="Private Repository", value=False)
    push_btn = gr.Button("Push to Hub", variant="secondary")
    push_status = gr.Textbox(label="Push Status", interactive=False)

    push_btn.click(
        push_wrapper,
        inputs=[repo_name, hf_token, private_repo],
        outputs=[push_status],
    )

if __name__ == "__main__":
    demo.launch()
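# Headless usage sketch (no UI) with the helpers above. The dataset name is
# the same illustrative example as the placeholder text, and this assumes
# gr.Progress() degrades to a no-op when called outside a Gradio event
# (worth verifying for your Gradio version):
#
#   out_dir, msg, _ = process_hf_dataset("Norod78/Yarn-art-style", "", "", "yarn-art")
#   zip_path = create_zip(out_dir, "yarn-art")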