#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import time
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import Any, Optional, Union

import streamlit as st
import torchaudio
from conette import CoNeTTEModel, conette, __version__
from conette.utils.collections import dict_list_to_list_dict
from st_audiorec import st_audiorec
from streamlit.runtime.uploaded_file_manager import UploadedFile
from torch import Tensor

ALLOW_REP_MODES = ("stopwords", "all", "none")
DEFAULT_TASK = "audiocaps"
MAX_BEAM_SIZE = 20
MAX_PRED_SIZE = 30
MAX_BATCH_SIZE = 16
RECORD_AUDIO_FNAME = "microphone_conette_record.wav"
DEFAULT_THRESHOLD = 0.3
THRESHOLD_PRECISION = 100
MIN_AUDIO_DURATION_SEC = 0.3
MAX_AUDIO_DURATION_SEC = 60
HASH_PREFIX = "hash_"
TMP_FILE_PREFIX = "audio_tmp_file_"
SECOND_BEFORE_CLEAR_CACHE = 10 * 60


@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    # Cached so the model is built only once across Streamlit reruns,
    # instead of being reloaded on every interaction.
    return conette(*args, **kwargs)


def format_candidate(candidate: str) -> str:
    if len(candidate) == 0:
        return ""
    else:
        return f"{candidate[0].title()}{candidate[1:]}."


def format_tags(tags: Optional[list[str]]) -> str:
    if tags is None or len(tags) == 0:
        return "None."
    else:
        return ", ".join(tags)


def get_result_hash(audio_fname: str, generate_kwds: dict[str, Any]) -> str:
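    """Build the session state key caching the result for one audio file and one set of generation arguments."""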
| return f"{HASH_PREFIX}{audio_fname}-{generate_kwds}" | |


def get_results(
    model: CoNeTTEModel,
    audio_files: dict[str, bytes],
    generate_kwds: dict[str, Any],
) -> dict[str, Union[dict[str, Any], str]]:
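    """Run the model on the given audio bytes and return, for each filename, either the model output dict or an error message string."""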
    # Collect the audio that still needs processing. Microphone records always
    # reuse the same filename, so they are re-processed every time.
    audio_to_predict: dict[str, tuple[str, bytes]] = {}
    for audio_fname, audio in audio_files.items():
        result_hash = get_result_hash(audio_fname, generate_kwds)
        if result_hash not in st.session_state or audio_fname == RECORD_AUDIO_FNAME:
            audio_to_predict[result_hash] = (audio_fname, audio)
    # Save audio to be processed
    tmp_files: dict[str, _TemporaryFileWrapper] = {}
    for result_hash, (audio_fname, audio) in audio_to_predict.items():
        tmp_file = NamedTemporaryFile(delete=False, prefix=TMP_FILE_PREFIX)
        tmp_file.write(audio)
        tmp_file.close()
        metadata = torchaudio.info(tmp_file.name)  # type: ignore
        duration = metadata.num_frames / metadata.sample_rate

        if duration < MIN_AUDIO_DURATION_SEC:
            error_msg = f"""
            ##### Result for "{audio_fname}"
            Audio file is too short. (found {duration:.2f}s but the model expects audio in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}])
            """
            st.session_state[result_hash] = error_msg
        elif duration > MAX_AUDIO_DURATION_SEC:
            error_msg = f"""
            ##### Result for "{audio_fname}"
            Audio file is too long. (found {duration:.2f}s but the model expects audio in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}])
            """
            st.session_state[result_hash] = error_msg
        else:
            tmp_files[result_hash] = tmp_file
    # Generate predictions and store them in session state
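    # Files are processed in chunks of MAX_BATCH_SIZE, presumably to bound memory usage.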
    for start in range(0, len(tmp_files), MAX_BATCH_SIZE):
        end = min(start + MAX_BATCH_SIZE, len(tmp_files))
        result_hashes_j = list(tmp_files.keys())[start:end]
        tmp_files_j = list(tmp_files.values())[start:end]
        tmp_paths_j = [tmp_file.name for tmp_file in tmp_files_j]

        outputs_j = model(
            tmp_paths_j,
            **generate_kwds,
        )
        outputs_lst = dict_list_to_list_dict(outputs_j)  # type: ignore
        for result_hash, output_i in zip(result_hashes_j, outputs_lst):
            st.session_state[result_hash] = output_i
    # Get outputs
    outputs = {}
    for audio_fname in audio_files.keys():
        result_hash = get_result_hash(audio_fname, generate_kwds)
        output_i = st.session_state[result_hash]
        outputs[audio_fname] = output_i

    for tmp_file in tmp_files.values():
        os.remove(tmp_file.name)

    return outputs


def show_results(outputs: dict[str, Union[dict[str, Any], str]]) -> None:
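    """Render each caption with its probability, the alternative beam candidates, and the predicted tags."""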
    # Show the most recent results first.
    keys = list(outputs.keys())[::-1]
    outputs = {key: outputs[key] for key in keys}

    # Inject the highlight CSS once, rather than once per result.
    st.markdown(
        """
        <style>
        .big-font {
            font-size: 22px !important;
            background-color: rgba(0, 255, 0, 0.1);
            padding: 10px;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )

    st.divider()
    for audio_fname, output in outputs.items():
        if isinstance(output, str):
            st.error(output)
            st.divider()
            continue
        cand: str = output["cands"]
        lprobs: Tensor = output["lprobs"]
        tags_lst = output.get("tags")
        mult_cands: list[str] = output["mult_cands"]
        mult_lprobs: Tensor = output["mult_lprobs"]

        cand = format_candidate(cand)
        prob = lprobs.exp().tolist()
        tags = format_tags(tags_lst)
        mult_cands = [format_candidate(cand_i) for cand_i in mult_cands]
        mult_probs = mult_lprobs.exp()
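        # Sort the beam candidates by probability and skip the top one,
        # which presumably matches the main candidate displayed above.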
        indexes = mult_probs.argsort(descending=True)[1:]
        mult_probs = mult_probs[indexes].tolist()
        mult_cands = [mult_cands[idx] for idx in indexes]
        if audio_fname == RECORD_AUDIO_FNAME:
            header = "##### Result for microphone input:"
        else:
            header = f'##### Result for "{audio_fname}"'

        lines = [
            header,
            f'<center><p class="space"><p class="big-font">"{cand}"</p></p></center>',
        ]
| content = "<br>".join(lines) | |
| st.markdown(content, unsafe_allow_html=True) | |
| lines = [ | |
| f"- **Probability**: {prob*100:.1f}%", | |
| ] | |
| if len(mult_cands) > 0: | |
| msg = f"- **Other descriptions:**" | |
| lines.append(msg) | |
| for cand_i, prob_i in zip(mult_cands, mult_probs): | |
| msg = f' - "{cand_i}" ({prob_i*100:.1f}%)' | |
| lines.append(msg) | |
| msg = f"- **Tags:** {tags}" | |
| lines.append(msg) | |
| content = "\n".join(lines) | |
| st.markdown(content, unsafe_allow_html=False) | |
| st.divider() | |


def main() -> None:
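    """Build the page: audio inputs, generation options, results, and periodic cleanup of cached results."""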
    model = load_conette(model_kwds=dict(device="cpu"))

    st.header("Describe audio content with CoNeTTE")
    st.markdown(
        "This interface allows you to generate a short description of the sound events of any recording using an Audio Captioning system. You can try it from your microphone or upload a file below."
    )
    st.markdown(
        "Use '**Start Recording**' and '**Stop**' to record audio from your microphone."
    )
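    # st_audiorec returns the recording as WAV bytes, or None while nothing has been recorded.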
    record_data = st_audiorec()
| with st.expander("Or upload audio files here:"): | |
| audio_files: Optional[list[UploadedFile]] = st.file_uploader( | |
| f"Audio files are automatically resampled to 32 kHz.\nTheir duration must be in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}] seconds.", | |
| type=["wav", "flac", "mp3", "ogg", "avi"], | |
| accept_multiple_files=True, | |
| help="Supports wav, flac, mp3, ogg and avi files.", | |
| ) | |
| with st.expander("Model options"): | |
| if DEFAULT_TASK in model.tasks: | |
| default_task_idx = list(model.tasks).index(DEFAULT_TASK) | |
| else: | |
| default_task_idx = 0 | |
| task = st.selectbox("Task embedding input", model.tasks, default_task_idx) | |
| allow_rep_mode = st.selectbox("Allow repetition of words", ALLOW_REP_MODES, 0) | |
| beam_size: int = st.select_slider( # type: ignore | |
| "Beam size", | |
| list(range(1, MAX_BEAM_SIZE + 1)), | |
| model.config.beam_size, | |
| ) | |
| min_pred_size, max_pred_size = st.slider( | |
| "Minimal and maximal number of words", | |
| 1, | |
| MAX_PRED_SIZE, | |
| (model.config.min_pred_size, model.config.max_pred_size), | |
| ) | |
        threshold = st.select_slider(
            "Tags threshold",
            [(i / THRESHOLD_PRECISION) for i in range(THRESHOLD_PRECISION + 1)],
            DEFAULT_THRESHOLD,
        )
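    # Map the user-facing "allow repetition" choice to the model's forbid_rep_mode argument.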
    if allow_rep_mode == "all":
        forbid_rep_mode = "none"
    elif allow_rep_mode == "none":
        forbid_rep_mode = "all"
    elif allow_rep_mode == "stopwords":
        forbid_rep_mode = "content_words"
    else:
        msg = f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})"
        raise ValueError(msg)
    del allow_rep_mode
    generate_kwds: dict[str, Any] = dict(
        task=task,
        beam_size=beam_size,
        min_pred_size=min_pred_size,
        max_pred_size=max_pred_size,
        forbid_rep_mode=forbid_rep_mode,
        threshold=threshold,
    )
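    # Gather the uploaded files and the microphone recording into a single batch.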
    audios: dict[str, bytes] = {}
    if audio_files is not None:
        audios |= {audio.name: audio.getvalue() for audio in audio_files}
    if record_data is not None:
        audios |= {RECORD_AUDIO_FNAME: record_data}
    if len(audios) > 0:
        with st.spinner("Generating descriptions..."):
            outputs = get_results(model, audios, generate_kwds)
        st.header("Results:")
        show_results(outputs)
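    # Clear cached results when no generation has happened for SECOND_BEFORE_CLEAR_CACHE seconds.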
    current = time.perf_counter()
    last_generation = st.session_state.get("last_generation", current)
    if current > last_generation + SECOND_BEFORE_CLEAR_CACHE:
        print("Removing result cache...")
        # Copy the keys first: deleting entries while iterating over session_state raises an error.
        for key in list(st.session_state.keys()):
            if isinstance(key, str) and key.startswith(HASH_PREFIX):
                del st.session_state[key]
    st.session_state["last_generation"] = current
| content = f"""CoNeTTE version {__version__}. <a href="https://github.com/Labbeti/conette-audio-captioning/">Source code on GitHub</a>. <a href="https://ieeexplore.ieee.org/document/10603439">Academic Paper</a>.""" | |
| st.divider() | |
| st.markdown(content, unsafe_allow_html=True) | |
| if __name__ == "__main__": | |
| main() | |