MrUtakata committed on
Commit
5d3091e
·
verified ·
1 Parent(s): ac77892

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -10
app.py CHANGED
@@ -1,10 +1,14 @@
 
 
1
  import streamlit as st
2
  import torch
3
  import joblib
4
  import dill
5
  import numpy as np
 
 
6
 
7
- # Load assets
8
  @st.cache_resource
9
  def load_assets():
10
  with open("preprocess_function.pkl", "rb") as f:
@@ -13,31 +17,61 @@ def load_assets():
13
  model = joblib.load("sage_model.pkl")
14
  return preprocess_text, tfidf, model
15
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  preprocess_text, tfidf_vectorizer, sage_model = load_assets()
 
17
 
18
- # App title
19
  st.title("🧠 Disinformation Detection")
20
  st.write("This app predicts whether a given news article is **real** or **disinformation** using a trained GraphSAGE model.")
21
 
22
- # Input text
23
  user_input = st.text_area("πŸ“ Enter a news article or headline:")
24
 
25
  if st.button("Detect"):
26
  if user_input.strip() == "":
27
  st.warning("Please enter some text to analyze.")
28
  else:
29
- # Preprocess input
30
  cleaned_text = preprocess_text(user_input)
31
  tfidf_vector = tfidf_vectorizer.transform([cleaned_text])
32
- features = torch.tensor(tfidf_vector.toarray(), dtype=torch.float)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- # Predict
35
  sage_model.eval()
36
  with torch.no_grad():
37
- logits = sage_model(features, torch.empty((2, 0), dtype=torch.long)) # dummy edge_index
38
- prediction = torch.argmax(logits, dim=1).item()
39
- prob = torch.exp(logits)[0, prediction].item()
 
40
 
 
41
  label = "🟒 Real News" if prediction == 1 else "πŸ”΄ Disinformation"
42
  st.markdown(f"### Prediction: {label}")
43
- st.markdown(f"**Confidence:** {prob:.2%}")
 
1
+ # app.py
2
+
3
  import streamlit as st
4
  import torch
5
  import joblib
6
  import dill
7
  import numpy as np
8
+ import gdown
9
+ import os
10
 
11
+ # Load assets with caching
12
  @st.cache_resource
13
  def load_assets():
14
  with open("preprocess_function.pkl", "rb") as f:
 
17
  model = joblib.load("sage_model.pkl")
18
  return preprocess_text, tfidf, model
19
 
20
+ # Download KNN model from Google Drive if not present
21
@st.cache_resource
def ensure_knn_model():
    """Return the fitted k-NN model, downloading it first if it is missing.

    The model is expected at ``knn_model.pkl`` in the working directory. When
    that file does not exist it is fetched from Google Drive with ``gdown``
    and then loaded with ``joblib``.

    The function is wrapped in ``st.cache_resource`` (like ``load_assets``)
    so the pickle is read from disk once per server process instead of on
    every Streamlit rerun.

    Returns:
        The object stored in ``knn_model.pkl``.

    Raises:
        RuntimeError: If the file is still missing after the download attempt.
    """
    knn_path = "knn_model.pkl"
    if not os.path.exists(knn_path):
        gdown.download(
            "https://drive.google.com/uc?id=166HWcckEVofU1TzVpZPNzbHdjxV_SqpT",
            knn_path,
            quiet=False,
        )
        # Check the file rather than gdown's return value so a failed download
        # gives a clear message instead of an opaque joblib error below.
        if not os.path.exists(knn_path):
            raise RuntimeError(
                f"Could not download the k-NN model to '{knn_path}'."
            )
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code. Only keep this if the Drive file is fully trusted.
    return joblib.load(knn_path)
30
+
31
+ # Load models
32
preprocess_text, tfidf_vectorizer, sage_model = load_assets()
knn_model = ensure_knn_model()


def _dense_float32(matrix):
    """Return *matrix* as a dense float32 NumPy array, densifying sparse input."""
    if hasattr(matrix, "toarray"):
        matrix = matrix.toarray()
    return np.asarray(matrix, dtype=np.float32)


def _build_knn_edge_index(neighbors):
    """Convert a (num_nodes, k) neighbour-index array into a (2, E) edge tensor.

    Row ``i`` of *neighbors* holds the neighbour ids of node ``i``. Every pair
    becomes a directed edge ``[i, neighbour]`` in row-major order, and
    self-loops are dropped. An input with no surviving edges yields an empty
    tensor of shape ``(2, 0)``.
    """
    neighbors = np.asarray(neighbors, dtype=np.int64)
    num_nodes, k = neighbors.shape
    sources = np.repeat(np.arange(num_nodes, dtype=np.int64), k)
    targets = neighbors.reshape(-1)
    keep = sources != targets
    return torch.from_numpy(np.stack([sources[keep], targets[keep]]))


# App UI
st.title("🧠 Disinformation Detection")
st.write("This app predicts whether a given news article is **real** or **disinformation** using a trained GraphSAGE model.")

# User input
user_input = st.text_area("📝 Enter a news article or headline:")

if st.button("Detect"):
    if not user_input.strip():
        st.warning("Please enter some text to analyze.")
    else:
        # STEP 1: Preprocess the user input into a single TF-IDF row.
        cleaned_text = preprocess_text(user_input)
        input_feature = _dense_float32(tfidf_vectorizer.transform([cleaned_text]))

        # STEP 2: Recover the features the k-NN model was fitted on.
        # NOTE: `_fit_X` is a private scikit-learn attribute and may change
        # between versions.
        original_features = _dense_float32(knn_model._fit_X)

        # STEP 3: Append the input as the LAST node of the graph.
        combined = np.concatenate([original_features, input_feature], axis=0)

        # STEP 4: Build the edge index from each node's k nearest neighbours.
        neighbors = knn_model.kneighbors(combined, return_distance=False)
        edge_index = _build_knn_edge_index(neighbors)
        combined_features = torch.from_numpy(combined)

        # STEP 5: Run inference and read the prediction for the last node.
        sage_model.eval()
        with torch.no_grad():
            logits = sage_model(combined_features, edge_index)
            pred_node_logits = logits[-1]  # Last node is the user input
            prediction = torch.argmax(pred_node_logits).item()
            # softmax equals exp() when the model emits log-probabilities and
            # still gives a valid probability when it emits raw logits.
            confidence = torch.softmax(pred_node_logits, dim=0)[prediction].item()

        # STEP 6: Display the result (class 1 is treated as real news).
        label = "🟢 Real News" if prediction == 1 else "🔴 Disinformation"
        st.markdown(f"### Prediction: {label}")
        st.markdown(f"**Confidence:** {confidence:.2%}")