dmanjate committed
Commit 0e778d7 · verified · 1 Parent(s): 5b5bd41

Upload app.py with huggingface_hub
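For reference, a minimal sketch of how an upload like this is typically done with huggingface_hub; the repo id, repo type, and token handling below are placeholders rather than values taken from this repository:

from huggingface_hub import HfApi

api = HfApi()  # picks up the token from HF_TOKEN or a prior `huggingface-cli login`
api.upload_file(
    path_or_fileobj="app.py",            # local file to push
    path_in_repo="app.py",               # destination path inside the repo
    repo_id="your-username/your-space",  # placeholder repo id
    repo_type="space",                   # assumed; use "model" or "dataset" as appropriate
    commit_message="Upload app.py with huggingface_hub",
)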

Files changed (1)
  1. app.py +195 -0
app.py ADDED
@@ -0,0 +1,195 @@
+ import streamlit as st
+ import json
+ import os
+ from datetime import datetime
+ from typing import Dict, List, Any
+
+ # ============== LLM / RAG deps ==============
+ from langchain_openai import ChatOpenAI
+ from langchain_community.vectorstores import Chroma
+ from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_core.tools import tool
+ from langchain.agents import create_tool_calling_agent, AgentExecutor
+
+ # Optional: Llama Guard (Groq)
+ try:
+     from groq import Groq
+     HAS_GROQ = True
+ except Exception:
+     HAS_GROQ = False
+
+ # ============== CONFIG ==============
+ # TODO: set via env or secrets in Streamlit Cloud
+ API_KEY = os.getenv("OPENAI_API_KEY", "")  # set OPENAI_API_KEY in the Space secrets; never hard-code keys here
+ API_BASE = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")  # or your compatible gateway
+ MODEL_NAME = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
+
+ # Your nutrition vectorstore (text hypotheticals)
+ # TODO: update if you used different names/locations
+ PERSIST_DIR = os.getenv("CHROMA_DIR", "./research_db_hypotheticals")
+ COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "hypothetical_questions_text")
+
+ # Optional Llama Guard key
+ GROQ_API_KEY = os.getenv("LLAMA_API_KEY", "")  # set LLAMA_API_KEY in the Space secrets; never hard-code keys here
+
+ # ============== LOAD VECTORSTORE ==============
+ # If you embedded with OpenAI, you can switch to an OpenAIEmbeddings. If you used gte/gte-large (HF), match that here.
+ # Using a robust open model by default.
+ embedding_model = SentenceTransformerEmbeddings(model_name="thenlper/gte-large")
+
+ vector_store = Chroma(
+     collection_name=COLLECTION_NAME,
+     persist_directory=PERSIST_DIR,
+     embedding_function=embedding_model
+ )
+
+ retriever = vector_store.as_retriever(
+     search_type="similarity",
+     search_kwargs={"k": 5}
+ )
+
+ # ============== LLM CLIENT ==============
+ llm = ChatOpenAI(
+     model=MODEL_NAME,
+     temperature=0,
+     max_retries=3,
+     api_key=API_KEY,
+     base_url=API_BASE,
+ )
+
+ # ============== SAFETY FILTER (Optional) ==============
+ def filter_input_with_llama_guard(user_input: str, model: str = "meta-llama/llama-guard-4-12b") -> str:
+     """
+     Use Llama Guard (Groq) to screen user input. Llama Guard returns a safe/unsafe
+     verdict rather than rewritten text, so pass the input through when it is safe and
+     return a canned refusal when it is flagged. Falls back to the original text if no
+     safety client is configured or the call fails.
+     """
+     if not HAS_GROQ or not GROQ_API_KEY:
+         return user_input  # no safety client configured
+
+     try:
+         client = Groq(api_key=GROQ_API_KEY)
+         resp = client.chat.completions.create(
+             model=model,
+             messages=[
+                 {"role": "user", "content": user_input},
+             ],
+         )
+         verdict = (resp.choices[0].message.content or "").strip().lower()
+         if verdict.startswith("unsafe"):
+             return "This request was flagged as unsafe and cannot be processed."
+         return user_input
+     except Exception:
+         return user_input
+
+ # ============== SIMPLE RAG PIPELINE ==============
+ RAG_SYSTEM_MESSAGE = """You are a medical-support assistant specializing in nutritional disorders.
+ Answer clearly, concisely, and factually. Use ONLY the provided context.
+ If the answer is not contained in the context, say: "I don't know."
+ Prefer listing symptoms, diagnosis criteria, risk factors, and treatments/dietary recommendations when relevant.
+ Avoid speculation; cite nothing to the user."""
+
+ RAG_USER_TEMPLATE = """###Context
+ {context}
+
+ ###Question
+ {question}
+ """
+
+ rag_prompt = ChatPromptTemplate.from_messages([
+     ("system", RAG_SYSTEM_MESSAGE),
+     ("user", RAG_USER_TEMPLATE),
+ ])
+
+ @tool("agentic_rag")
+ def agentic_rag(question: str) -> str:
+     """
+     Nutrition Disorder RAG: retrieves context from Chroma and answers strictly from it.
+     """
+     try:
+         docs = retriever.invoke(question)
+         if not docs:
+             return "I don't know."
+
+         context = "\n\n".join([d.page_content for d in docs])
+         chain = rag_prompt | llm
+         resp = chain.invoke({"context": context, "question": question})
+         return resp.content if hasattr(resp, "content") else str(resp)
+     except Exception as e:
+         return f"Error: {e}"
+
+ # ============== STREAMLIT UI ==============
+ st.set_page_config(page_title="Nutrition Disorder Agentic RAG", page_icon="🥦", layout="centered")
+
+ def ensure_state():
+     if "logged_in" not in st.session_state:
+         st.session_state.logged_in = False
+     if "user_name" not in st.session_state:
+         st.session_state.user_name = ""
+     if "history" not in st.session_state:
+         st.session_state.history = [{"role": "assistant", "content": "Welcome! I’m your Nutrition Disorder assistant. How can I help?"}]
+
+ def login_page():
+     st.title("Nutrition Disorder Agent — Login")
+     with st.form("login_form"):
+         name = st.text_input("Your name")
+         submitted = st.form_submit_button("Enter")
+         if submitted:
+             if name.strip():
+                 st.session_state.logged_in = True
+                 st.session_state.user_name = name.strip()
+                 st.success(f"Welcome, {name.strip()}!")
+                 st.rerun()
+             else:
+                 st.error("Please enter your name.")
+
+ def chat_page():
+     st.title("Nutrition Disorder Agent")
+     st.caption("Evidence-grounded answers about symptoms, diagnosis, and treatments for nutritional disorders.")
+
+     # Show chat history
+     for m in st.session_state.history:
+         with st.chat_message(m["role"]):
+             st.markdown(m["content"])
+
+     # Input
+     user_msg = st.chat_input("Ask about symptoms, diagnosis, or treatments…")
+     if user_msg:
+         # Optional safety filter
+         filtered = filter_input_with_llama_guard(user_msg)
+
+         # Append user
+         st.session_state.history.append({"role": "user", "content": user_msg})
+         with st.chat_message("user"):
+             st.markdown(user_msg)
+
+         # Call tool
+         with st.spinner("Thinking..."):
+             try:
+                 answer = agentic_rag.invoke(filtered)
+             except Exception as e:
+                 answer = f"Sorry, I hit an error: {e}"
+
+         st.session_state.history.append({"role": "assistant", "content": answer})
+         with st.chat_message("assistant"):
+             st.markdown(answer)
+
+     # Footer controls
+     cols = st.columns(3)
+     if cols[0].button("Reset Chat"):
+         st.session_state.history = [{"role": "assistant", "content": "Welcome! I’m your Nutrition Disorder assistant. How can I help?"}]
+         st.rerun()
+     with cols[1]:
+         st.write("")
+     if cols[2].button("Logout"):
+         st.session_state.clear()
+         st.rerun()
+
+ def main():
+     ensure_state()
+     if not st.session_state.logged_in:
+         login_page()
+     else:
+         chat_page()
+
+ if __name__ == "__main__":
+     main()
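Beyond the Streamlit UI, the agentic_rag tool defined above can be exercised directly as a quick smoke test. A minimal sketch, assuming the Chroma directory from the config exists and OPENAI_API_KEY is set; the example question is illustrative only:

# Importing app executes the module-level setup (embeddings, Chroma, LLM client),
# so run this in the same environment the Space uses.
from app import agentic_rag

print(agentic_rag.invoke("What are the typical symptoms of vitamin B12 deficiency?"))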