Update app.py
Browse files
app.py
CHANGED
|
@@ -1,10 +1,14 @@
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import torch
|
| 3 |
import joblib
|
| 4 |
import dill
|
| 5 |
import numpy as np
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
# Load assets
|
| 8 |
@st.cache_resource
|
| 9 |
def load_assets():
|
| 10 |
with open("preprocess_function.pkl", "rb") as f:
|
|
@@ -13,31 +17,61 @@ def load_assets():
|
|
| 13 |
model = joblib.load("sage_model.pkl")
|
| 14 |
return preprocess_text, tfidf, model
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
preprocess_text, tfidf_vectorizer, sage_model = load_assets()
|
|
|
|
| 17 |
|
| 18 |
-
# App
|
| 19 |
st.title("π§ Disinformation Detection")
|
| 20 |
st.write("This app predicts whether a given news article is **real** or **disinformation** using a trained GraphSAGE model.")
|
| 21 |
|
| 22 |
-
#
|
| 23 |
user_input = st.text_area("π Enter a news article or headline:")
|
| 24 |
|
| 25 |
if st.button("Detect"):
|
| 26 |
if user_input.strip() == "":
|
| 27 |
st.warning("Please enter some text to analyze.")
|
| 28 |
else:
|
| 29 |
-
# Preprocess input
|
| 30 |
cleaned_text = preprocess_text(user_input)
|
| 31 |
tfidf_vector = tfidf_vectorizer.transform([cleaned_text])
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
#
|
| 35 |
sage_model.eval()
|
| 36 |
with torch.no_grad():
|
| 37 |
-
logits = sage_model(
|
| 38 |
-
|
| 39 |
-
|
|
|
|
| 40 |
|
|
|
|
| 41 |
label = "π’ Real News" if prediction == 1 else "π΄ Disinformation"
|
| 42 |
st.markdown(f"### Prediction: {label}")
|
| 43 |
-
st.markdown(f"**Confidence:** {
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
|
| 3 |
import streamlit as st
|
| 4 |
import torch
|
| 5 |
import joblib
|
| 6 |
import dill
|
| 7 |
import numpy as np
|
| 8 |
+
import gdown
|
| 9 |
+
import os
|
| 10 |
|
| 11 |
+
# Load assets with caching
|
| 12 |
@st.cache_resource
|
| 13 |
def load_assets():
|
| 14 |
with open("preprocess_function.pkl", "rb") as f:
|
|
|
|
| 17 |
model = joblib.load("sage_model.pkl")
|
| 18 |
return preprocess_text, tfidf, model
|
| 19 |
|
| 20 |
+
# Download KNN model from Google Drive if not present
|
| 21 |
+
def ensure_knn_model():
|
| 22 |
+
knn_path = "knn_model.pkl"
|
| 23 |
+
if not os.path.exists(knn_path):
|
| 24 |
+
gdown.download(
|
| 25 |
+
"https://drive.google.com/uc?id=166HWcckEVofU1TzVpZPNzbHdjxV_SqpT",
|
| 26 |
+
knn_path,
|
| 27 |
+
quiet=False
|
| 28 |
+
)
|
| 29 |
+
return joblib.load(knn_path)
|
| 30 |
+
|
| 31 |
+
# Load models
|
| 32 |
preprocess_text, tfidf_vectorizer, sage_model = load_assets()
|
| 33 |
+
knn_model = ensure_knn_model()
|
| 34 |
|
| 35 |
+
# App UI
|
| 36 |
st.title("π§ Disinformation Detection")
|
| 37 |
st.write("This app predicts whether a given news article is **real** or **disinformation** using a trained GraphSAGE model.")
|
| 38 |
|
| 39 |
+
# User input
|
| 40 |
user_input = st.text_area("π Enter a news article or headline:")
|
| 41 |
|
| 42 |
if st.button("Detect"):
|
| 43 |
if user_input.strip() == "":
|
| 44 |
st.warning("Please enter some text to analyze.")
|
| 45 |
else:
|
| 46 |
+
# STEP 1: Preprocess user input
|
| 47 |
cleaned_text = preprocess_text(user_input)
|
| 48 |
tfidf_vector = tfidf_vectorizer.transform([cleaned_text])
|
| 49 |
+
input_feature = torch.tensor(tfidf_vector.toarray(), dtype=torch.float)
|
| 50 |
+
|
| 51 |
+
# STEP 2: Get original feature set from KNN model
|
| 52 |
+
original_features = torch.tensor(knn_model._fit_X, dtype=torch.float)
|
| 53 |
+
|
| 54 |
+
# STEP 3: Combine input with training data features
|
| 55 |
+
combined_features = torch.cat([original_features, input_feature], dim=0)
|
| 56 |
+
|
| 57 |
+
# STEP 4: Build edge index using k-NN
|
| 58 |
+
neighbors = knn_model.kneighbors(combined_features, return_distance=False)
|
| 59 |
+
edge_list = []
|
| 60 |
+
for idx, nbrs in enumerate(neighbors):
|
| 61 |
+
for nbr in nbrs:
|
| 62 |
+
if idx != nbr:
|
| 63 |
+
edge_list.append([idx, nbr])
|
| 64 |
+
edge_index = torch.tensor(np.array(edge_list).T, dtype=torch.long)
|
| 65 |
|
| 66 |
+
# STEP 5: Run inference on the last node (user input)
|
| 67 |
sage_model.eval()
|
| 68 |
with torch.no_grad():
|
| 69 |
+
logits = sage_model(combined_features, edge_index)
|
| 70 |
+
pred_node_logits = logits[-1] # Last node is the user input
|
| 71 |
+
prediction = torch.argmax(pred_node_logits).item()
|
| 72 |
+
confidence = torch.exp(pred_node_logits)[prediction].item()
|
| 73 |
|
| 74 |
+
# STEP 6: Display result
|
| 75 |
label = "π’ Real News" if prediction == 1 else "π΄ Disinformation"
|
| 76 |
st.markdown(f"### Prediction: {label}")
|
| 77 |
+
st.markdown(f"**Confidence:** {confidence:.2%}")
|