import gradio as gr
import json
import os
import zipfile

def load_data(filepath):
    """Loads data from a JSON file."""
    with open(filepath, 'r') as f:
        data = json.load(f)
    return data

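# Note: the rest of the app assumes each input JSON file holds a list of objects
# with "prompt", "finetuned_output", and "base_output" keys (see the .get() calls
# further down); a "choice" key is added to each entry as the user picks a winner.
# A hypothetical entry might look like:
#   {"prompt": "a ballad about the sea", "finetuned_output": "...", "base_output": "..."}
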
def create_comparison_app(file_paths):
    """Creates the Gradio app that compares LLM responses side by side, file by file, and offers the results as a browser download at the end."""

    # Mutable state shared by the nested event handlers below.
    all_data = {}             # filepath -> raw data loaded from the JSON file
    current_file_index = 0    # index into file_paths
    current_prompt_index = 0  # index into the current file's prompts
    current_filepath = ""
    results_data = {}         # filepath -> data annotated with the user's choices

    def initialize_data(filepath):
        """Loads a file on first use and resets the prompt index."""
        nonlocal all_data, current_prompt_index, current_filepath, results_data
        if filepath not in all_data:
            all_data[filepath] = load_data(filepath)
            results_data[filepath] = list(all_data[filepath])
        current_filepath = filepath
        current_prompt_index = 0

    def get_progress_text():
        """Returns a short progress summary shown at the top of the app."""
        files_left = len(file_paths) - (current_file_index + 1)
        if current_filepath:
            prompts_left = max(len(results_data[current_filepath]) - current_prompt_index, 0)
            return f"File {current_file_index + 1}/{len(file_paths)} - {prompts_left} prompts left in this file, {files_left} files remaining."
        else:
            return "No file loaded."

    def display_prompt_and_responses(filepath, index):
        """Displays the prompt and responses for a given index within the current file."""
        if not filepath or filepath not in results_data:
            return "No file loaded.", "", "", get_progress_text(), None

        data = results_data[filepath]
        if 0 <= index < len(data):
            item = data[index]
            prompt_text = item.get("prompt", "No prompt available")
            finetuned_output_text = item.get("finetuned_output", "No finetuned output")
            base_output_text = item.get("base_output", "No base output")
            return prompt_text, finetuned_output_text, base_output_text, get_progress_text(), None
        else:
            return "File finished! Please proceed to the next file.", "", "", get_progress_text(), None

    def record_choice(choice):
        """Records the user's choice, then advances to the next prompt, the next file, or the final download."""
        nonlocal current_prompt_index, current_file_index

        if not current_filepath:
            return "No file loaded.", "", "", get_progress_text(), None

        data = results_data[current_filepath]
        if 0 <= current_prompt_index < len(data):
            # Annotate the current entry with the chosen model.
            if choice == "finetuned":
                data[current_prompt_index]["choice"] = "finetuned"
            elif choice == "base":
                data[current_prompt_index]["choice"] = "base"
            current_prompt_index += 1

        if current_prompt_index < len(data):
            # More prompts left in the current file.
            return display_prompt_and_responses(current_filepath, current_prompt_index)
        elif current_file_index < len(file_paths) - 1:
            # Current file is finished; move on to the next one.
            current_file_index += 1
            initialize_data(file_paths[current_file_index])
            return display_prompt_and_responses(current_filepath, current_prompt_index)
        else:
            # All files are finished; offer the zipped results for download.
            zip_filepath = create_zip_archive(results_data)
            return (
                "Comparison finished for all files! Please download the results ('Download results' button).",
                "", "",
                "Comparison finished for all files!",
                gr.update(visible=True, value=zip_filepath, label="Download results"),
            )

    def create_zip_archive(results_data):
        """Creates a zip archive of all result files."""
        zip_filepath = "/tmp/results.zip"
        with zipfile.ZipFile(zip_filepath, 'w') as zipf:
            for filepath, data in results_data.items():
                results_filename = os.path.basename(filepath).replace(".json", "_results.json")
                results_json_string = json.dumps(data, indent=2)
                zipf.writestr(results_filename, results_json_string)
        return zip_filepath

    with gr.Blocks() as iface:
        progress_markdown = gr.Markdown(get_progress_text())
        gr.Markdown("# LLM song lyrics generation ranking")
        gr.Markdown("There are 5 files (each with 50 prompts) to compare. For each prompt, choose the better lyrics between Model A and Model B. After you complete all files and prompts, you can download the results.")
        prompt_output = gr.Textbox(label="Lyrics description", lines=3, interactive=False, max_lines=3)

        with gr.Row():
            with gr.Column():
                finetuned_output_box = gr.Textbox(label="Model A", lines=10, interactive=False, max_lines=10)
            with gr.Column():
                base_output_box = gr.Textbox(label="Model B", lines=10, interactive=False, max_lines=10)

        with gr.Row():
            finetuned_button = gr.Button("Model A is better")
            base_button = gr.Button("Model B is better")

        file_download_output = gr.DownloadButton(label="Download results", visible=False)

        def load_initial_file(files):
            """Loads the first file and displays its first prompt when the app starts."""
            if files:
                filepath = files[0]
                initialize_data(filepath)
                return display_prompt_and_responses(current_filepath, current_prompt_index)
            return "No file loaded.", "", "", get_progress_text(), None

        iface.load(
            load_initial_file,
            inputs=[gr.State(file_paths)],
            outputs=[prompt_output, finetuned_output_box, base_output_box, progress_markdown, file_download_output]
        )

        finetuned_button.click(
            fn=record_choice,
            inputs=[gr.State("finetuned")],
            outputs=[prompt_output, finetuned_output_box, base_output_box, progress_markdown, file_download_output],
            api_name="choose_finetuned"
        )
        base_button.click(
            fn=record_choice,
            inputs=[gr.State("base")],
            outputs=[prompt_output, finetuned_output_box, base_output_box, progress_markdown, file_download_output],
            api_name="choose_base"
        )

    return iface

if __name__ == '__main__':
    json_files = [
        "./Qwen2.5-0.5B-song-lyrics-generation.json",
        "./SmolLM2-135M-song-lyrics-generation.json",
        "./SmolLM2-135M-Instruct-song-lyrics-generation.json",
        "./SmolLM2-360M-song-lyrics-generation.json",
        "./SmolLM2-360M-Instruct-song-lyrics-generation.json",
    ]
    app = create_comparison_app(json_files)
    app.launch()
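
# To try the app (assuming the five JSON files listed above exist in the working
# directory), run this script with Python; launch() serves the Gradio UI locally,
# at http://127.0.0.1:7860 by default.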