MrUtakata committed on
Commit
ac77892
·
verified ·
1 Parent(s): 23ad8ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -53
app.py CHANGED
@@ -1,66 +1,43 @@
1
- # app.py
2
  import streamlit as st
3
- import joblib
4
- import nltk
5
  import torch
6
- import torch.nn.functional as F
 
7
  import numpy as np
8
 
9
- from nltk.corpus import stopwords
10
- from nltk.tokenize import RegexpTokenizer
11
- from sklearn.feature_extraction.text import TfidfVectorizer
12
-
13
- # ——— 1) NLTK setup ———
14
- nltk.download('stopwords')
15
- _STOP_WORDS = set(stopwords.words('english'))
16
- _TOKENIZER = RegexpTokenizer(r'\w+')
17
-
18
- def preprocess_text(text: str) -> str:
19
- tokens = _TOKENIZER.tokenize(text.lower())
20
- return " ".join([t for t in tokens if t not in _STOP_WORDS])
21
-
22
- # ——— 2) Load heavy resources once ———
23
  @st.cache_resource
24
- def load_resources():
25
- tfidf: TfidfVectorizer = joblib.load("tfidf_vectorizer.pkl")
26
- sage_model: torch.nn.Module = joblib.load("sage_model.pkl")
27
- sage_model.eval()
28
- return tfidf, sage_model
 
29
 
30
- tfidf, sage_model = load_resources()
31
 
32
- # ——— 3) Streamlit UI ———
33
- st.title("Disinformation Detection")
34
- st.write(
35
- """
36
- Paste in some text and click **Predict**.
37
- The model will output the probability it’s **True Information** vs. **Disinformation**.
38
- """
39
- )
40
 
41
- user_input = st.text_area("Your text here", height=200)
 
42
 
43
- if st.button("Predict"):
44
- if not user_input.strip():
45
- st.warning("Please enter some text first.")
46
  else:
47
- # Preprocess & vectorize
48
- cleaned = preprocess_text(user_input)
49
- vec = tfidf.transform([cleaned]).toarray()
50
- x = torch.from_numpy(vec).float() # shape [1, D]
51
-
52
- # Empty graph so GraphSAGE layers still run
53
- edge_index = torch.empty((2, 0), dtype=torch.long)
54
 
55
- # Inference
 
56
  with torch.no_grad():
57
- logits = sage_model(x, edge_index) # [1, 2]
58
- probs = torch.exp(logits).numpy()[0] # convert log‑softmax → probabilities
59
-
60
- # Display
61
- st.markdown("### Prediction probabilities")
62
- st.write(f"• 🔵 True information: {probs[1]:.2%}")
63
- st.write(f"• 🔴 Disinformation: {probs[0]:.2%}")
64
 
65
- verdict = "✅ Likely TRUE" if probs[1] > probs[0] else "❌ Likely DISINFORMATION"
66
- st.markdown(f"## **{verdict}**")
 
 
 
1
  import streamlit as st
 
 
2
  import torch
3
+ import joblib
4
+ import dill
5
  import numpy as np
6
 
7
+ # Load assets
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  @st.cache_resource
9
+ def load_assets():
10
+ with open("preprocess_function.pkl", "rb") as f:
11
+ preprocess_text = dill.load(f)
12
+ tfidf = joblib.load("tfidf_vectorizer.pkl")
13
+ model = joblib.load("sage_model.pkl")
14
+ return preprocess_text, tfidf, model
15
 
16
+ preprocess_text, tfidf_vectorizer, sage_model = load_assets()
17
 
18
+ # App title
19
+ st.title("🧠 Disinformation Detection")
20
+ st.write("This app predicts whether a given news article is **real** or **disinformation** using a trained GraphSAGE model.")
 
 
 
 
 
21
 
22
+ # Input text
23
+ user_input = st.text_area("📝 Enter a news article or headline:")
24
 
25
+ if st.button("Detect"):
26
+ if user_input.strip() == "":
27
+ st.warning("Please enter some text to analyze.")
28
  else:
29
+ # Preprocess input
30
+ cleaned_text = preprocess_text(user_input)
31
+ tfidf_vector = tfidf_vectorizer.transform([cleaned_text])
32
+ features = torch.tensor(tfidf_vector.toarray(), dtype=torch.float)
 
 
 
33
 
34
+ # Predict
35
+ sage_model.eval()
36
  with torch.no_grad():
37
+ logits = sage_model(features, torch.empty((2, 0), dtype=torch.long)) # dummy edge_index
38
+ prediction = torch.argmax(logits, dim=1).item()
39
+ prob = torch.exp(logits)[0, prediction].item()
 
 
 
 
40
 
41
+ label = "🟢 Real News" if prediction == 1 else "🔴 Disinformation"
42
+ st.markdown(f"### Prediction: {label}")
43
+ st.markdown(f"**Confidence:** {prob:.2%}")