# Spaces:
# Sleeping
# Sleeping
"""
Synthex Medical Text Generator - MVP Streamlit App
Deploy this on Hugging Face Spaces for free hosting
"""
import streamlit as st
import json
import time
from datetime import datetime
import pandas as pd
import os
import sys
import logging

# Setup logging. basicConfig does nothing once the root logger already has
# handlers, so re-running this script does not stack handlers.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add the src directory to the Python path. Streamlit re-executes this whole
# script on every user interaction while sys.path is process-global, so only
# append the entry once instead of adding a duplicate on every rerun.
_SRC_DIR = os.path.join(os.path.dirname(__file__), 'src')
if _SRC_DIR not in sys.path:
    sys.path.append(_SRC_DIR)

# Import the medical generator.
# NOTE(review): this import uses the `src.` package prefix, so it resolves via
# the script's own directory rather than _SRC_DIR; the path entry above only
# matters if modules inside src/ import each other without that prefix —
# confirm before removing it.
from src.generation.medical_generator import MedicalTextGenerator, DEFAULT_GEMINI_API_KEY
# Page config, issued before any other Streamlit command in this script.
# The icon was previously mis-decoded bytes ("π₯") rather than an emoji;
# restored to the hospital emoji matching the app's medical theme.
st.set_page_config(
    page_title="Synthex Medical Text Generator",
    page_icon="🏥",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Custom CSS, injected as a raw <style> element on every script run.
# `<style>` must stay at column 0 so Markdown treats the string as a raw HTML
# block; indentation inside the block is harmless.
# NOTE(review): `.sidebar .sidebar-content` and `.streamlit-expanderHeader`
# look like class names from older Streamlit releases — confirm they still
# match the DOM of the deployed Streamlit version.
st.markdown("""
<style>
/* Main container styling */
.main {
    padding: 2rem;
    background-color: #f8f9fa;
}
/* Header styling (text colour is white over the gradient background) */
.main-header {
    font-size: 2.5rem;
    font-weight: bold;
    text-align: center;
    margin-bottom: 1rem;
    padding: 1rem;
    background: linear-gradient(135deg, #1f77b4 0%, #2c9cdb 100%);
    color: white;
    border-radius: 10px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.sub-header {
    font-size: 1.2rem;
    color: #666;
    text-align: center;
    margin-bottom: 2rem;
    padding: 0.5rem;
}
/* Card styling */
.record-container {
    background-color: white;
    padding: 1.5rem;
    border-radius: 10px;
    border-left: 4px solid #1f77b4;
    margin: 1rem 0;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
    transition: transform 0.2s;
}
.record-container:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
/* Stats container styling */
.stats-container {
    background-color: white;
    padding: 1.5rem;
    border-radius: 10px;
    margin: 1rem 0;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}
/* Button styling */
.stButton>button {
    width: 100%;
    border-radius: 5px;
    height: 3em;
    font-weight: bold;
    transition: all 0.3s;
}
.stButton>button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
/* Metric styling */
.stMetric {
    background-color: #f8f9fa;
    padding: 1rem;
    border-radius: 5px;
    text-align: center;
}
/* Sidebar styling */
.sidebar .sidebar-content {
    background-color: #f8f9fa;
}
/* Progress bar styling */
.stProgress > div > div {
    background-color: #1f77b4;
}
/* Success message styling */
.stSuccess {
    padding: 1rem;
    border-radius: 5px;
    background-color: #d4edda;
    color: #155724;
    margin: 1rem 0;
}
/* Error message styling */
.stError {
    padding: 1rem;
    border-radius: 5px;
    background-color: #f8d7da;
    color: #721c24;
    margin: 1rem 0;
}
/* Expander styling */
.streamlit-expanderHeader {
    font-size: 1.1rem;
    font-weight: bold;
    color: #1f77b4;
}
/* Text area styling */
.stTextArea textarea {
    font-family: monospace;
    font-size: 0.9rem;
    line-height: 1.5;
}
</style>
""", unsafe_allow_html=True)
# Seed per-session state. Streamlit reruns the script on every interaction,
# so only keys that are still missing get their starting value.
for _state_key, _state_default in (
    ('generated_records', []),
    ('total_generated', 0),
    ('generator', None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Page header, styled by the .main-header / .sub-header CSS classes.
# The header emoji was previously mis-decoded bytes ("π₯"); restored to 🏥.
st.markdown('<div class="main-header">🏥 Synthex Medical Text Generator</div>', unsafe_allow_html=True)
st.markdown('<div class="sub-header">Generate synthetic medical records for AI training and testing</div>', unsafe_allow_html=True)

# Single-slot placeholder for status messages; each .info/.success/.error call
# on it replaces whatever it showed before.
status_area = st.empty()
# Sidebar: all user-tunable settings. The values read here (gemini_api_key,
# record_type, quantity, use_gemini, include_metadata, export_format) are used
# further down in the same script run.
with st.sidebar:
    st.markdown("### ⚙️ Configuration")

    # API Key section.
    # SECURITY: the field is intentionally NOT pre-filled with the server's
    # GEMINI_API_KEY. A widget's value is sent to the browser (type="password"
    # only masks it on screen), so pre-filling would expose the deployment's
    # key to every visitor. An empty field falls back to the env var here, on
    # the server side, so downstream code still receives the configured key.
    with st.expander("🔑 API Settings", expanded=False):
        gemini_api_key = st.text_input(
            "Gemini API Key",
            value="",
            type="password",
            help="Enter your Google Gemini API key for better generation quality"
        ) or os.getenv('GEMINI_API_KEY', '')

    # Record settings
    st.markdown("### 📋 Record Settings")
    record_type = st.selectbox(
        "Select Record Type",
        ["clinical_note", "discharge_summary", "lab_report", "prescription", "patient_intake"],
        # Display "clinical_note" as "Clinical Note"; the raw value is returned.
        format_func=lambda x: x.replace("_", " ").title()
    )
    quantity = st.slider("Number of Records", 1, 20, 5)

    # Generation settings
    st.markdown("### 🤖 Generation Settings")
    use_gemini = st.checkbox(
        "Use Gemini API",
        value=False,
        help="Uses Google Gemini API for better quality generation"
    )

    # Advanced options
    with st.expander("⚡ Advanced Options"):
        include_metadata = st.checkbox("Include Metadata", value=True)
        export_format = st.selectbox("Export Format", ["JSON", "CSV", "TXT"])
# Main layout: a wide column for generating and browsing records and a narrow
# column (col2, filled further down) for statistics and export.
col1, col2 = st.columns([2, 1])

with col1:
    st.markdown("### 📝 Generate Records")

    if st.button("🚀 Generate Records", type="primary", use_container_width=True):
        status_area.info("Initializing generator...")

        # The generator is cached in session state. Rebuild it when none
        # exists yet OR when the sidebar API key differs from the key it was
        # built with; otherwise a key entered after the first generation would
        # be silently ignored for the rest of the session.
        needs_init = (
            st.session_state.generator is None
            or st.session_state.get('generator_api_key') != gemini_api_key
        )
        if needs_init:
            try:
                with st.spinner("Initializing medical text generator..."):
                    new_generator = MedicalTextGenerator(gemini_api_key=gemini_api_key)
            except Exception as e:
                logger.exception("Error initializing generator")
                status_area.error(f"Error initializing generator: {str(e)}")
                st.stop()  # raises; nothing below runs on failure
            else:
                # Only replace the cached generator once construction succeeded.
                st.session_state.generator = new_generator
                st.session_state.generator_api_key = gemini_api_key
                status_area.success("Generator initialized successfully!")

        # Generate records one at a time with visible progress.
        progress_bar = st.progress(0)
        status_text = st.empty()
        generated_records = []

        for i in range(quantity):
            status_text.text(f"Generating record {i+1} of {quantity}...")
            try:
                record = st.session_state.generator.generate_record(record_type, use_gemini=use_gemini)
            except Exception as e:
                logger.exception("Failed to generate record %d", i + 1)
                status_area.error(f"Failed to generate record {i+1}: {str(e)}")
            else:
                generated_records.append(record)
            # Advance the bar only after record i+1 has been attempted, so it
            # does not show 100% while the last record is still running.
            progress_bar.progress((i + 1) / quantity)
            # Rate limiting between Gemini calls (failed calls count too);
            # skipped after the last record since no request follows.
            if use_gemini and i < quantity - 1:
                time.sleep(1)

        # Update session state and report the outcome honestly.
        failed = quantity - len(generated_records)
        if generated_records:
            st.session_state.generated_records.extend(generated_records)
            st.session_state.total_generated += len(generated_records)
            status_text.text("✅ Generation complete!")
            if failed:
                status_area.warning(
                    f"Generated {len(generated_records)} of {quantity} records; "
                    f"{failed} failed (see logs)."
                )
            else:
                status_area.success(f"Successfully generated {len(generated_records)} medical records!")
        else:
            # status_area still shows the last per-record error here.
            status_text.text("Generation failed: no records were produced.")

    # Browse generated records.
    if st.session_state.generated_records:
        st.markdown("### 📋 Generated Records")

        col_filter1, col_filter2 = st.columns(2)
        with col_filter1:
            filter_type = st.selectbox(
                "Filter by Type",
                # Sorted so the option order is stable and alphabetical.
                ["All"] + sorted(
                    {r.get('type', 'Unknown') for r in st.session_state.generated_records},
                    key=str,
                )
            )
        with col_filter2:
            records_per_page = st.selectbox("Records per page", [5, 10, 20, 50])

        # Keep each record's position in the full session list: it gives the
        # text areas below widget keys that are unique and stable across
        # filters and pages (a per-page index would repeat on every page).
        filtered_records = [
            (idx, r)
            for idx, r in enumerate(st.session_state.generated_records)
            if filter_type == "All" or r.get('type', 'Unknown') == filter_type
        ]

        # Pagination (ceiling division; filtered_records is non-empty because
        # filter options are derived from the existing records).
        total_records = len(filtered_records)
        total_pages = (total_records - 1) // records_per_page + 1
        if total_pages > 1:
            page = st.selectbox("Page", range(1, total_pages + 1))
            start_idx = (page - 1) * records_per_page
            page_records = filtered_records[start_idx:start_idx + records_per_page]
        else:
            page_records = filtered_records

        for idx, record in page_records:
            type_label = record.get('type', 'Unknown').replace('_', ' ').title()
            with st.expander(f"Record {record.get('id', 'Unknown')} - {type_label}"):
                if include_metadata:
                    col_meta1, col_meta2, col_meta3 = st.columns(3)
                    with col_meta1:
                        st.metric("Type", type_label)
                    with col_meta2:
                        st.metric("Generated", record.get('timestamp', 'N/A'))
                    with col_meta3:
                        st.metric("Source", record.get('source', 'Hugging Face'))
                # The former '<div class="record-container">' / '</div>'
                # markdown pair was removed: each st.markdown call renders as a
                # separate element, so the div could never wrap this text area
                # and only produced an empty styled box.
                st.text_area(
                    "Content",
                    record.get('text', 'No content available'),
                    height=200,
                    key=f"record_{idx}",
                )
with col2:
    st.markdown("### 📊 Statistics")
    # The former '<div class="stats-container">' ... '</div>' markdown pair was
    # dropped: Streamlit renders each st.markdown call as its own element, so
    # the opening div could not wrap the widgets below and only rendered an
    # empty styled box.

    st.metric("Total Records Generated", st.session_state.total_generated)

    records = st.session_state.generated_records

    if records:
        # Number of records per type, for the distribution chart.
        type_counts = pd.Series([r.get('type', 'Unknown') for r in records]).value_counts()
        st.markdown("#### Record Type Distribution")
        st.bar_chart(type_counts)

    st.markdown("#### 💾 Export Data")
    if records:
        # One timestamp per run, shared by whichever export format is chosen.
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        if export_format == "JSON":
            payload = json.dumps(records, indent=2)
            extension, mime = "json", "application/json"
        elif export_format == "CSV":
            payload = pd.DataFrame(records).to_csv(index=False)
            extension, mime = "csv", "text/csv"
        else:  # "TXT" — the only remaining option in the sidebar selectbox
            payload = "\n\n".join(
                f"Record {r.get('id', 'Unknown')} ({r.get('type', 'Unknown')}):\n{r.get('text', 'No content available')}"
                for r in records
            )
            extension, mime = "txt", "text/plain"

        st.download_button(
            f"📥 Download {export_format}",
            payload,
            file_name=f"medical_records_{timestamp}.{extension}",
            mime=mime,
            use_container_width=True,
        )
# Footer. The heart was previously mis-decoded bytes ("β€οΈ"); restored to ❤️.
st.markdown("---")
st.markdown("""
<div style='text-align: center; color: #666;'>
<p>Built with ❤️ using Streamlit | Synthex Medical Text Generator</p>
</div>
""", unsafe_allow_html=True)