import streamlit as st import pandas as pd import numpy as np import tensorflow as tf import seaborn as sns import matplotlib.pyplot as plt from tensorflow.keras.models import load_model # Configure styling sns.set_theme(style="whitegrid") st.set_page_config( page_title="Federated Learning for Anomaly Detection in IOT Environments", page_icon="๐ก๏ธ", layout="wide", initial_sidebar_state="expanded" ) # Load the pre-trained model @st.cache_resource def load_intrusion_model(): return load_model('intrusion_model.h5') # Define attack type labels ATTACK_TYPES = { 0: 'Normal', 1: 'Backdoor', 2: 'DDoS_HTTP', 3: 'DDoS_ICMP', 4: 'DDoS_TCP', 5: 'DDoS_UDP', 6: 'Fingerprinting', 7: 'MITM', 8: 'Password', 9: 'Port_Scanning', 10: 'Ransomware', 11: 'SQL_injection', 12: 'Uploading', 13: 'Vulnerability_scanner', 14: 'XSS' } # Critical attacks that trigger alerts CRITICAL_ATTACKS = { 'DDoS_HTTP', 'DDoS_ICMP', 'DDoS_TCP', 'DDoS_UDP', 'Ransomware', 'SQL_injection', 'Port_Scanning' } # Create the Streamlit app def main(): # Sidebar with model information st.sidebar.header("About") st.sidebar.markdown(""" **Federated Learning for Anomaly Detection in IOT Environments** This system detects and classifies cyber attacks on IoT networks using deep learning. The model achieves 93.6% accuracy on validation data. """) st.sidebar.subheader("Attack Types") for code, name in ATTACK_TYPES.items(): st.sidebar.caption(f"{code}: {name}") st.sidebar.subheader("Attack Severity") st.sidebar.markdown(""" - ๐ด **Critical**: DDoS, Ransomware, SQL Injection - ๐ **High**: Port Scanning, Backdoor - ๐ข **Medium**: Other attacks - โช **Normal**: Benign traffic """) st.sidebar.divider() st.sidebar.info("๏นซ2025") st.sidebar.download_button( label="Download Sample CSV", data=pd.DataFrame(columns=range(1, 250)).to_csv(index=False), file_name="sample_features.csv", mime="text/csv" ) # Main content st.title("๐ก๏ธ Federated Learning for Anomaly Detection in IOT Environments") st.caption("Detect and classify security threats in IoT network traffic") # Initialize session state if 'predictions' not in st.session_state: st.session_state.predictions = None if 'critical_alerts' not in st.session_state: st.session_state.critical_alerts = [] # Load model try: model = load_intrusion_model() except Exception as e: st.error(f"Error loading model: {str(e)}") st.stop() # Alert banner area at top alert_placeholder = st.empty() # Prediction section tab1, tab2 = st.tabs(["๐ Batch Prediction", "๐ Single Prediction"]) with tab1: st.subheader("Batch Prediction from CSV") uploaded_file = st.file_uploader("Upload IoT device data (CSV)", type="csv") if uploaded_file: try: df = pd.read_csv(uploaded_file) st.success(f"Successfully loaded {len(df)} records") # Show sample data if st.checkbox("Show data preview"): st.dataframe(df.head()) # Validate features if len(df.columns) != 249: st.warning(f"Data should have 249 features. Found {len(df.columns)} columns.") st.info("Ensure your CSV has exactly 249 columns representing the model features") else: # Make predictions if st.button("Run Predictions", type="primary"): with st.spinner("Analyzing network traffic..."): # Preprocess and predict X = df.values.astype('float32') pred_probs = model.predict(X, verbose=0) pred_classes = np.argmax(pred_probs, axis=1) confidence_scores = np.max(pred_probs, axis=1) # Add predictions to dataframe df['Predicted_Attack'] = [ATTACK_TYPES[c] for c in pred_classes] df['Prediction_Confidence'] = confidence_scores # Store in session state st.session_state.predictions = df st.session_state.critical_alerts = df[ df['Predicted_Attack'].isin(CRITICAL_ATTACKS) ] except Exception as e: st.error(f"Error processing file: {str(e)}") # Display results if available if st.session_state.predictions is not None: df = st.session_state.predictions # Critical attack alert if not st.session_state.critical_alerts.empty: critical_count = len(st.session_state.critical_alerts) with alert_placeholder.container(): st.error(f"๐จ **CRITICAL THREAT DETECTED!** - {critical_count} critical attacks identified", icon="โ ๏ธ") st.subheader("Prediction Results") # Summary stats normal_count = len(df[df['Predicted_Attack'] == 'Normal']) attack_count = len(df) - normal_count critical_count = len(st.session_state.critical_alerts) col1, col2, col3 = st.columns(3) col1.metric("Total Records", len(df)) col2.metric("Attack Traffic", f"{attack_count} ({attack_count/len(df):.1%})") col3.metric("Critical Threats", critical_count, f"{critical_count/attack_count:.1%}" if attack_count else "0%") # Visualization section st.subheader("Attack Analysis") # Tabs for different visualizations viz_tab1, viz_tab2, viz_tab3, viz_tab4 = st.tabs([ "Attack Distribution", "Confidence Analysis", "Threat Severity", "Detailed Results" ]) with viz_tab1: col1, col2 = st.columns([3, 2]) with col1: # Attack type bar chart st.markdown("**Attack Type Distribution**") attack_counts = df['Predicted_Attack'].value_counts() fig, ax = plt.subplots(figsize=(10, 6)) sns.barplot( x=attack_counts.values, y=attack_counts.index, palette="viridis", ax=ax ) plt.xlabel("Count") plt.ylabel("Attack Type") plt.title("Attack Frequency Distribution") st.pyplot(fig) with col2: # Attack type pie chart st.markdown("**Attack Proportion**") normal_attack = df['Predicted_Attack'] != 'Normal' attack_ratio = normal_attack.value_counts(normalize=True) fig, ax = plt.subplots(figsize=(8, 6)) attack_ratio.plot.pie( autopct='%1.1f%%', labels=['Normal', 'Attack'], colors=['#2ca02c', '#d62728'], startangle=90, ax=ax ) plt.title("Normal vs Attack Traffic") plt.ylabel("") st.pyplot(fig) with viz_tab2: col1, col2 = st.columns(2) with col1: # Confidence histogram st.markdown("**Confidence Distribution**") fig, ax = plt.subplots(figsize=(10, 6)) sns.histplot( df['Prediction_Confidence'], bins=20, kde=True, color='#1f77b4', ax=ax ) plt.axvline(x=0.9, color='r', linestyle='--', label='High Confidence') plt.xlabel("Confidence Score") plt.ylabel("Frequency") plt.title("Prediction Confidence Distribution") plt.legend() st.pyplot(fig) with col2: # Confidence by attack type st.markdown("**Confidence by Attack Type**") fig, ax = plt.subplots(figsize=(10, 6)) sns.boxplot( x=df['Prediction_Confidence'], y=df['Predicted_Attack'], palette="Set3", ax=ax ) plt.xlabel("Confidence Score") plt.ylabel("Attack Type") plt.title("Confidence Distribution per Attack Type") st.pyplot(fig) with viz_tab3: # Define severity levels severity_map = { 'Normal': 'Normal', 'DDoS_HTTP': 'Critical', 'DDoS_ICMP': 'Critical', 'DDoS_TCP': 'Critical', 'DDoS_UDP': 'Critical', 'Ransomware': 'Critical', 'SQL_injection': 'Critical', 'Port_Scanning': 'High', 'Backdoor': 'High', 'Fingerprinting': 'Medium', 'MITM': 'Medium', 'Password': 'Medium', 'Uploading': 'Medium', 'Vulnerability_scanner': 'Medium', 'XSS': 'Medium' } df['Severity'] = df['Predicted_Attack'].map(severity_map) col1, col2 = st.columns(2) with col1: # Severity pie chart st.markdown("**Threat Severity Distribution**") severity_counts = df['Severity'].value_counts() fig, ax = plt.subplots(figsize=(8, 8)) colors = {'Critical': '#d62728', 'High': '#ff7f0e', 'Medium': '#e377c2', 'Normal': '#2ca02c'} severity_counts.plot.pie( autopct='%1.1f%%', colors=[colors[s] for s in severity_counts.index], startangle=90, ax=ax ) plt.title("Threat Severity Levels") plt.ylabel("") st.pyplot(fig) with col2: # Severity count plot st.markdown("**Threat Severity Counts**") fig, ax = plt.subplots(figsize=(10, 6)) sns.countplot( x=df['Severity'], order=['Critical', 'High', 'Medium', 'Normal'], palette=list(colors.values()), ax=ax ) plt.xlabel("Severity Level") plt.ylabel("Count") plt.title("Threat Severity Distribution") st.pyplot(fig) with viz_tab4: # Detailed results table st.dataframe(df[['Predicted_Attack', 'Prediction_Confidence', 'Severity']].head(50)) # Download results st.divider() csv = df.to_csv(index=False) st.download_button( label="Download Full Predictions", data=csv, file_name="intrusion_predictions.csv", mime="text/csv", type="primary" ) with tab2: st.subheader("Single Prediction") st.markdown("Enter feature values manually for real-time threat detection") # Create input form with st.form("single_prediction"): # Generate sample input features sample_features = [0.0] * 249 inputs = [] st.info("For demonstration, only the first 10 features are shown. Others are set to default values.") # Split into 3 columns for better layout col1, col2, col3 = st.columns(3) cols = [col1, col2, col3] # Only show first 10 features to save space features_to_show = 10 for i in range(features_to_show): with cols[i % 3]: inputs.append( st.number_input( f"Feature {i+1}", value=sample_features[i], key=f"feature_{i}", step=0.001 ) ) # Fill remaining features with default values inputs += sample_features[features_to_show:] submit = st.form_submit_button("Analyze Traffic", type="primary") if submit: try: # Prepare input data input_array = np.array([inputs], dtype='float32') # Make prediction pred_prob = model.predict(input_array, verbose=0) pred_class = np.argmax(pred_prob, axis=1)[0] confidence = np.max(pred_prob) attack_name = ATTACK_TYPES[pred_class] # Check if critical is_critical = attack_name in CRITICAL_ATTACKS # Display alert if is_critical: with alert_placeholder.container(): st.error(f"๐จ **CRITICAL THREAT DETECTED!** - {attack_name} attack identified", icon="โ ๏ธ") # Display results st.subheader("Analysis Result") # Create columns for results col1, col2 = st.columns([1, 2]) with col1: # Attack type card severity = "Critical" if is_critical else "Normal" if attack_name == "Normal" else "Warning" color = "#d62728" if is_critical else "#2ca02c" if attack_name == "Normal" else "#ff7f0e" st.markdown(f"""
Threat Level: {severity}
Confidence: {confidence:.2%}