Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import os | |
| from utils import ( | |
| plot_distances_tsne, | |
| plot_distances_umap, | |
| cluster_languages_hdbscan, | |
| cluster_languages_kmeans, | |
| plot_mst, | |
| cluster_languages_by_families, | |
| cluster_languages_by_subfamilies, | |
| filter_languages_by_families, | |
| ) | |
| from functools import partial | |
| import datasets | |
# Load the precomputed language-distance data from the Hugging Face Hub.
# NOTE(review): trust_remote_code=True lets the dataset repo run its own
# loading code — keep the repo pinned/trusted.
dataset = datasets.load_dataset(
    "mshamrai/language-metric-data", split="train", trust_remote_code=True
)
# All data lives in the first (presumably only) row of the split.
languages = dataset["languages_list"][0]
average_distances_matrix = np.array(dataset["average_distances_matrix"][0])

# Fetch the nested distances record once. Indexing ``dataset[...]`` by column
# may re-decode the whole column, so it must not be repeated inside the
# comprehension below.
_distances_record = dataset["distances_matrices"][0]
DATASETS = _distances_record["dataset_name"]
# Model names are taken from the first dataset entry; the comprehension below
# assumes every dataset lists its models in this same order.
MODELS = _distances_record["models"][0]["model_name"]

# distance_matrices[dataset_name][model_name] -> square NumPy distance matrix
# indexed like ``languages`` (NaN where a language has no data).
distance_matrices = {
    dataset_name: {
        model_name: np.array(_distances_record["models"][i]["matrix"][j])
        for j, model_name in enumerate(MODELS)
    }
    for i, dataset_name in enumerate(DATASETS)
}
def filter_languages_nan(model, dataset, use_average):
    """Drop languages that have no distances for the chosen model/dataset.

    Args:
        model: Model name, a key into ``distance_matrices[dataset]``. Ignored
            when ``use_average`` is true.
        dataset: Dataset name, a key into ``distance_matrices``. Ignored when
            ``use_average`` is true.
        use_average: If true, use ``average_distances_matrix`` instead.

    Returns:
        ``(matrix, languages)``: the square sub-matrix restricted to the kept
        languages, and a NumPy array of those language names in the same order.
    """
    if use_average:
        matrix = average_distances_matrix
    else:
        matrix = distance_matrices[dataset][model]

    # A language is kept when its entry in a reference row is not NaN.
    # Always using row 0 would drop *every* language whenever languages[0]
    # itself is missing. Use the row with the most non-NaN entries instead.
    # argmax returns the first such row, so row 0 is still used whenever it
    # is among the most complete rows.
    valid_per_row = (~np.isnan(matrix)).sum(axis=1)
    reference_row = matrix[int(np.argmax(valid_per_row))]
    keep = ~np.isnan(reference_row)

    updated_languages = np.array(languages)[keep]
    updated_matrix = matrix[np.ix_(keep, keep)]
    return updated_matrix, updated_languages
def get_similar_languages(model, dataset, selected_language, use_average, n):
    """Return the ``n`` languages closest to ``selected_language``.

    Args:
        model: Model name, used when ``use_average`` is false.
        dataset: Dataset name, used when ``use_average`` is false.
        selected_language: Name of the query language; must be in ``languages``.
        use_average: If true, read ``average_distances_matrix`` instead.
        n: Number of rows to return. Coerced to ``int`` because it comes from
            a UI slider.

    Returns:
        A DataFrame with columns ``index`` (0-based rank), ``Language`` and
        ``Distance`` (rounded to 4 decimals), sorted by ascending distance.
        The query language and languages with no distance (NaN) are excluded.

    Raises:
        ValueError: if ``selected_language`` is not in ``languages``.
    """
    if use_average:
        matrix = average_distances_matrix
    else:
        matrix = distance_matrices[dataset][model]

    selected_language_index = languages.index(selected_language)
    distances = matrix[selected_language_index]

    df = pd.DataFrame({"Language": languages, "Distance": distances})
    sorted_distances = df.sort_values(by="Distance")
    # Index labels survive sort_values, so this drops the query language.
    sorted_distances = sorted_distances.drop(index=selected_language_index)
    # Languages without data for this model/dataset would otherwise show up
    # as NaN rows in the table.
    sorted_distances = sorted_distances.dropna(subset=["Distance"])
    # Renumber 0..N-1, then expose that rank as an "index" column.
    sorted_distances = sorted_distances.reset_index(drop=True).reset_index()
    sorted_distances["Distance"] = sorted_distances["Distance"].round(4)
    return sorted_distances.head(int(n))
def update_languages(model, dataset):
    """Return the languages that have distance data for ``model``/``dataset``.

    Delegates to ``filter_languages_nan`` instead of repeating its NaN-mask
    logic. The dropdown choices therefore always match the languages that the
    plotting functions (which also call ``filter_languages_nan``) display.

    Returns:
        A list of language names, in the order of ``languages``.
    """
    _, available_languages = filter_languages_nan(model, dataset, use_average=False)
    return list(available_languages)
def update_language_options(model, dataset, language, use_average):
    """Rebuild the "Language" dropdown for the current model/dataset selection.

    With averaged distances every language is offered; otherwise only those
    returned by ``update_languages``. The current ``language`` is kept if it is
    still offered, else the first offered language is selected.
    """
    choices = languages if use_average else update_languages(model, dataset)
    selected = language if language in choices else choices[0]
    return gr.Dropdown(label="Language", choices=choices, value=selected)
def toggle_inputs(use_average):
    """Show/enable the model and dataset inputs only when not averaging.

    Returns two ``gr.update`` objects, one each for the model and the dataset
    dropdowns.
    """
    enabled = not use_average
    model_update = gr.update(interactive=enabled, visible=enabled)
    dataset_update = gr.update(interactive=enabled, visible=enabled)
    return model_update, dataset_update
# Single output file for the most recently rendered plot. Every plotting
# handler below overwrites it and offers it via a DownloadButton.
plot_path = "plots/last_plot.pdf"
# Ensure the parent directory ("plots") exists before anything is saved.
os.makedirs(os.path.dirname(plot_path), exist_ok=True)
def _cluster_languages(matrix, langs, cluster_method, cluster_method_param):
    """Assign clusters with ``cluster_method``; return (matrix, langs, clusters, legends).

    HDBSCAN/KMeans return their own matrix/language list (possibly filtered)
    and no legends. Family/Subfamily keep the inputs as-is and return legends.

    Raises:
        ValueError: if ``cluster_method`` is not one of the four known methods.
    """
    if cluster_method == "HDBSCAN":
        # The value comes from a UI slider and may arrive as a float. The
        # clustering helpers presumably forward it to HDBSCAN/KMeans, which
        # expect an integer.
        matrix, langs, clusters = cluster_languages_hdbscan(
            matrix, langs, min_cluster_size=int(cluster_method_param)
        )
        return matrix, langs, clusters, None
    if cluster_method == "KMeans":
        matrix, langs, clusters = cluster_languages_kmeans(
            matrix, langs, n_clusters=int(cluster_method_param)
        )
        return matrix, langs, clusters, None
    if cluster_method == "Family":
        clusters, legends = cluster_languages_by_families(langs)
        return matrix, langs, clusters, legends
    if cluster_method == "Subfamily":
        clusters, legends = cluster_languages_by_subfamilies(langs)
        return matrix, langs, clusters, legends
    raise ValueError("Invalid cluster method")


def plot_distances(
    model,
    dataset,
    use_average,
    cluster_method,
    cluster_method_param,
    figsize_h,
    figsize_w,
    plot_fn,
):
    """Cluster the available languages and draw them with ``plot_fn``.

    ``plot_fn`` is the projection/plot routine chosen by the caller (t-SNE,
    UMAP or MST). It is called as
    ``plot_fn(matrix, languages, clusters, legends, fig_size=(w, h))`` and must
    return a matplotlib-style figure.

    Args:
        model, dataset, use_average: Select the distance matrix
            (see ``filter_languages_nan``).
        cluster_method: "HDBSCAN", "KMeans", "Family" or "Subfamily".
        cluster_method_param: Minimum cluster size (HDBSCAN) or number of
            clusters (KMeans); ignored for Family/Subfamily.
        figsize_h, figsize_w: Figure height and width.
        plot_fn: Plotting routine (see above).

    Returns:
        ``(fig, DownloadButton)``. The figure is also saved as a PDF to
        ``plot_path``, and the button points at that file.

    Raises:
        ValueError: for an unknown ``cluster_method``.
    """
    updated_matrix, updated_languages = filter_languages_nan(
        model, dataset, use_average
    )
    filtered_matrix, filtered_languages, clusters, legends = _cluster_languages(
        updated_matrix, updated_languages, cluster_method, cluster_method_param
    )
    fig = plot_fn(
        filtered_matrix,
        filtered_languages,
        clusters,
        legends,
        fig_size=(figsize_w, figsize_h),
    )
    fig.tight_layout()
    fig.savefig(plot_path, format="pdf")
    return fig, gr.DownloadButton(label="Download Plot", value=plot_path)
def plot_families_subfamilies(
    families, model, dataset, use_average, figsize_h, figsize_w
):
    """Draw an MST of the languages in ``families``, coloured by subfamily.

    Languages without data for the chosen model/dataset are removed first,
    then the rest are restricted to ``families``. The figure is saved as a PDF
    to ``plot_path``.

    Returns:
        ``(fig, DownloadButton)``, where the button points at ``plot_path``.
    """
    matrix, langs = filter_languages_nan(model, dataset, use_average)
    matrix, langs = filter_languages_by_families(matrix, langs, families)
    clusters, legends = cluster_languages_by_subfamilies(langs)

    fig = plot_mst(matrix, langs, clusters, legends, fig_size=(figsize_w, figsize_h))
    fig.tight_layout()
    fig.savefig(plot_path, format="pdf")
    download = gr.DownloadButton(label="Download Plot", value=plot_path)
    return fig, download
with gr.Blocks() as demo:
    gr.Markdown("## Language Distance Explorer")
    average_checkbox = gr.Checkbox(label="Use Average Distances", value=False)
    with gr.Row():
        model_input = gr.Dropdown(label="Model", choices=MODELS, value=MODELS[0])
        dataset_input = gr.Dropdown(
            label="Dataset", choices=DATASETS, value=DATASETS[0]
        )

    with gr.Tab(label="Closest Languages Table"):
        with gr.Row():
            language_input = gr.Dropdown(
                label="Language", choices=languages, value=languages[0]
            )
            top_n_input = gr.Slider(
                label="Top N", minimum=1, maximum=30, step=1, value=10
            )
        output_table = gr.Dataframe(label="Similar Languages")

        # Argument lists shared by every listener of the same function.
        language_option_inputs = [
            model_input,
            dataset_input,
            language_input,
            average_checkbox,
        ]
        table_inputs = [
            model_input,
            dataset_input,
            language_input,
            average_checkbox,
            top_n_input,
        ]

        # Registration order matches the original wiring, so listeners on the
        # same trigger are still registered in the same order.
        # NOTE: model/dataset changes refresh the dropdown and the table as
        # independent listeners. If the dropdown value changes as a result,
        # language_input.change recomputes the table again.
        for trigger in (model_input, dataset_input):
            trigger.change(
                fn=update_language_options,
                inputs=language_option_inputs,
                outputs=language_input,
            )
        for trigger in (language_input, model_input, dataset_input, top_n_input):
            trigger.change(
                fn=get_similar_languages,
                inputs=table_inputs,
                outputs=output_table,
            )
        average_checkbox.change(
            fn=toggle_inputs,
            inputs=[average_checkbox],
            outputs=[model_input, dataset_input],
        )
        average_checkbox.change(
            fn=update_language_options,
            inputs=language_option_inputs,
            outputs=language_input,
        )
        average_checkbox.change(
            fn=get_similar_languages,
            inputs=table_inputs,
            outputs=output_table,
        )

    with gr.Tab(label="Distance Plot"):
        with gr.Row():
            cluster_method_input = gr.Dropdown(
                label="Cluster Method",
                choices=["HDBSCAN", "KMeans", "Family", "Subfamily"],
                value="HDBSCAN",
            )
            clusters_input = gr.Slider(
                label="Minimum Elements in a Cluster",
                minimum=2,
                maximum=10,
                step=1,
                value=2,
            )

        def update_clusters_input_option(cluster_method):
            """Reconfigure (or hide) the cluster-parameter slider for ``cluster_method``."""
            if cluster_method == "HDBSCAN":
                return gr.Slider(
                    label="Minimum Elements in a Cluster",
                    minimum=2,
                    maximum=10,
                    step=1,
                    value=2,
                    visible=True,
                    interactive=True,
                )
            elif cluster_method == "KMeans":
                return gr.Slider(
                    label="Number of Clusters",
                    minimum=2,
                    maximum=20,
                    step=1,
                    value=2,
                    visible=True,
                    interactive=True,
                )
            else:
                # Family/Subfamily clustering takes no numeric parameter.
                return gr.update(interactive=False, visible=False)

        cluster_method_input.change(
            fn=update_clusters_input_option,
            inputs=[cluster_method_input],
            outputs=clusters_input,
        )
        with gr.Row():
            plot_tsne_button = gr.Button("Plot t-SNE")
            plot_umap_button = gr.Button("Plot UMAP")
            plot_mst_button = gr.Button("Plot MST")
        with gr.Row():
            plot_figsize_dist_h_input = gr.Slider(
                label="Figure Height", minimum=5, maximum=30, step=1, value=15
            )
            plot_figsize_dist_w_input = gr.Slider(
                label="Figure Width", minimum=5, maximum=30, step=1, value=15
            )
        with gr.Row():
            download_plot_button = gr.DownloadButton("Download Plot")
        with gr.Row():
            plot_output = gr.Plot(label="Distance Plot")

        distance_plot_inputs = [
            model_input,
            dataset_input,
            average_checkbox,
            cluster_method_input,
            clusters_input,
            plot_figsize_dist_h_input,
            plot_figsize_dist_w_input,
        ]
        # One click handler per plotting routine; partial binds plot_fn eagerly.
        for button, plot_fn in (
            (plot_tsne_button, plot_distances_tsne),
            (plot_umap_button, plot_distances_umap),
            (plot_mst_button, plot_mst),
        ):
            button.click(
                fn=partial(plot_distances, plot_fn=plot_fn),
                inputs=distance_plot_inputs,
                outputs=[plot_output, download_plot_button],
            )

    with gr.Tab(label="Language Families Subplot"):
        checked_families_input = gr.CheckboxGroup(
            label="Language Families",
            choices=[
                "Afroasiatic",
                "Austroasiatic",
                "Austronesian",
                "Constructed",
                "Creole",
                "Dravidian",
                "Germanic",
                "Indo-European",
                "Japonic",
                "Kartvelian",
                "Koreanic",
                "Language Isolate",
                "Niger-Congo",
                "Northeast Caucasian",
                "Romance",
                "Sino-Tibetan",
                "Turkic",
                "Uralic",
            ],
            value=["Indo-European"],
        )
        with gr.Row():
            plot_family_button = gr.Button("Plot Families")
            plot_figsize_h_input = gr.Slider(
                label="Figure Height", minimum=5, maximum=30, step=1, value=15
            )
            plot_figsize_w_input = gr.Slider(
                label="Figure Width", minimum=5, maximum=30, step=1, value=15
            )
        with gr.Row():
            # No initial file, matching the Distance Plot tab. Pointing at
            # plot_path at build time would either reference a file that may
            # not exist yet or serve a stale plot from another tab. The click
            # handler sets the value once a families plot has been rendered.
            download_families_plot_button = gr.DownloadButton("Download Plot")
        plot_family_output = gr.Plot(label="Families Plot")
        plot_family_button.click(
            fn=plot_families_subfamilies,
            inputs=[
                checked_families_input,
                model_input,
                dataset_input,
                average_checkbox,
                plot_figsize_h_input,
                plot_figsize_w_input,
            ],
            outputs=[plot_family_output, download_families_plot_button],
        )

demo.launch(share=True)