# app.py
"""Streamlit app that labels a news article as real or disinformation.

The user's text is vectorised with the saved preprocessing function and TF-IDF
vectorizer, attached as an extra node to a k-NN graph over the KNN model's
training features, and classified by the saved GraphSAGE model.
"""
import os

import dill
import gdown
import joblib
import numpy as np
import streamlit as st
import torch

KNN_MODEL_PATH = "knn_model.pkl"
KNN_MODEL_URL = "https://drive.google.com/uc?id=166HWcckEVofU1TzVpZPNzbHdjxV_SqpT"


# Load assets with caching
@st.cache_resource
def load_assets():
    """Load the text preprocessor, TF-IDF vectorizer and GraphSAGE model.

    Returns:
        ``(preprocess_text, tfidf_vectorizer, sage_model)`` as unpickled from
        ``preprocess_function.pkl``, ``tfidf_vectorizer.pkl`` and
        ``sage_model.pkl`` in the working directory.
    """
    # SECURITY: dill/joblib unpickle arbitrary objects, which can execute code.
    # Only deploy these files from a trusted source.
    with open("preprocess_function.pkl", "rb") as f:
        preprocess_text = dill.load(f)
    tfidf = joblib.load("tfidf_vectorizer.pkl")
    model = joblib.load("sage_model.pkl")
    return preprocess_text, tfidf, model


@st.cache_resource
def ensure_knn_model():
    """Return the fitted KNN model, downloading it from Google Drive if absent.

    Cached so Streamlit does not re-load the pickle on every script rerun.

    Raises:
        RuntimeError: if the download does not produce a file.
    """
    knn_path = KNN_MODEL_PATH
    if not os.path.exists(knn_path):
        # Download under a temporary name and move it into place only on
        # success, so an interrupted download never leaves a truncated
        # knn_model.pkl behind for later runs to trip over.
        tmp_path = knn_path + ".part"
        downloaded = gdown.download(KNN_MODEL_URL, tmp_path, quiet=False)
        if not downloaded or not os.path.exists(tmp_path):
            raise RuntimeError(f"Failed to download the KNN model from {KNN_MODEL_URL}")
        os.replace(tmp_path, knn_path)
    # SECURITY: this file is fetched from a remote URL and joblib.load unpickles
    # it, which can execute arbitrary code. Verify a checksum if the source is
    # not fully trusted.
    return joblib.load(knn_path)


@st.cache_resource
def _training_graph(_knn):
    """Build the k-NN graph over the KNN model's training rows once.

    The leading underscore on ``_knn`` tells Streamlit not to hash the model.

    Returns:
        ``(features, edge_index)``: a float tensor of the training feature rows
        and a ``(2, E)`` long tensor of ``[node, neighbour]`` edges with
        self-loops removed.
    """
    # _fit_X is a private scikit-learn attribute holding the fitted data;
    # assumed to be a dense array here — TODO confirm.
    features = torch.tensor(_knn._fit_X, dtype=torch.float)
    neighbors = _knn.kneighbors(features.numpy(), return_distance=False)
    # Row-major expansion of the (n_nodes, k) neighbour table into edge pairs.
    src = np.repeat(np.arange(neighbors.shape[0]), neighbors.shape[1])
    dst = neighbors.ravel()
    keep = src != dst  # each training row finds itself; drop those self-loops
    edge_index = torch.from_numpy(np.stack([src[keep], dst[keep]])).long()
    return features, edge_index


def _predict(text):
    """Classify *text*; return ``(predicted_class, probability_of_that_class)``."""
    # STEP 1: Preprocess user input into a single TF-IDF row
    cleaned_text = preprocess_text(text)
    tfidf_vector = tfidf_vectorizer.transform([cleaned_text])
    input_feature = torch.tensor(tfidf_vector.toarray(), dtype=torch.float)

    # STEP 2: Reuse the cached training features and graph
    original_features, train_edge_index = _training_graph(knn_model)
    query_node = original_features.shape[0]  # user input becomes the last node

    # STEP 3: Combine input with training data features
    combined_features = torch.cat([original_features, input_feature], dim=0)

    # STEP 4: Connect the new node to its k nearest training nodes in BOTH
    # directions. PyG message passing (default flow="source_to_target")
    # aggregates over edges whose *target* is the node; with only
    # [query, neighbour] edges the query node would receive no neighbour
    # messages, so the graph would not influence its prediction.
    # NOTE(review): assumes the model uses PyG's default flow — confirm.
    nbrs = torch.as_tensor(
        knn_model.kneighbors(input_feature.numpy(), return_distance=False)[0],
        dtype=torch.long,
    )
    query = torch.full_like(nbrs, query_node)
    edge_index = torch.cat(
        [train_edge_index, torch.stack([query, nbrs]), torch.stack([nbrs, query])],
        dim=1,
    )

    # STEP 5: Run inference on the last node (user input)
    sage_model.eval()
    with torch.no_grad():
        logits = sage_model(combined_features, edge_index)
    pred_node_logits = logits[-1]
    prediction = torch.argmax(pred_node_logits).item()
    # softmax yields a valid probability whether the model emits raw logits or
    # log-probabilities (softmax is shift-invariant, so
    # softmax(log_softmax(x)) == softmax(x)); exp() is only valid for the latter.
    confidence = torch.softmax(pred_node_logits, dim=-1)[prediction].item()
    return prediction, confidence


# Load models
preprocess_text, tfidf_vectorizer, sage_model = load_assets()
knn_model = ensure_knn_model()

# App UI
st.title("🧠 Disinformation Detection")
st.write("This app predicts whether a given news article is **real** or **disinformation** using a trained GraphSAGE model.")

# User input
user_input = st.text_area("📝 Enter a news article or headline:")

if st.button("Detect"):
    if not user_input.strip():
        st.warning("Please enter some text to analyze.")
    else:
        prediction, confidence = _predict(user_input)

        # STEP 6: Display result
        label = "🟢 Real News" if prediction == 1 else "🔴 Disinformation"
        st.markdown(f"### Prediction: {label}")
        st.markdown(f"**Confidence:** {confidence:.2%}")