Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,9 +3,10 @@ from sentence_transformers import SentenceTransformer
|
|
| 3 |
import faiss
|
| 4 |
import numpy as np
|
| 5 |
import gradio as gr
|
|
|
|
| 6 |
|
| 7 |
# Load a small subset (10,000 rows)
|
| 8 |
-
dataset = load_dataset("wiki40b", "en", split="train[:
|
| 9 |
|
| 10 |
# Extract only text
|
| 11 |
docs = [d["text"] for d in dataset]
|
|
@@ -18,19 +19,29 @@ embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
|
| 18 |
# Convert texts to embeddings
|
| 19 |
embeddings = embed_model.encode(docs, show_progress_bar=True)
|
| 20 |
|
| 21 |
-
# Store in FAISS index
|
| 22 |
-
dimension = embeddings.shape[1]
|
| 23 |
-
index = faiss.IndexFlatL2(dimension)
|
| 24 |
-
index.add(np.array(embeddings))
|
| 25 |
|
| 26 |
-
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def search_wikipedia(query, top_k=3):
|
| 30 |
-
query_embedding = embed_model.encode([query])
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
|
| 35 |
# Gradio Interface
|
| 36 |
iface = gr.Interface(
|
|
|
|
| 3 |
import faiss
|
| 4 |
import numpy as np
|
| 5 |
import gradio as gr
|
| 6 |
+
import chromadb
|
| 7 |
|
| 8 |
# Load a small subset (10,000 rows)
|
| 9 |
+
dataset = load_dataset("wiki40b", "en", split="train[:1000]")
|
| 10 |
|
| 11 |
# Extract only text
|
| 12 |
docs = [d["text"] for d in dataset]
|
|
|
|
| 19 |
# Convert texts to embeddings
|
| 20 |
embeddings = embed_model.encode(docs, show_progress_bar=True)
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
# Initialize ChromaDB client
|
| 24 |
+
chroma_client = chromadb.PersistentClient(path="./chroma_db") # Stores data persistently
|
| 25 |
+
collection = chroma_client.get_or_create_collection(name="wikipedia_docs")
|
| 26 |
|
| 27 |
+
# Store embeddings in ChromaDB
|
| 28 |
+
for i, (doc, embedding) in enumerate(zip(docs, embeddings)):
|
| 29 |
+
collection.add(
|
| 30 |
+
ids=[str(i)], # Unique ID for each doc
|
| 31 |
+
embeddings=[embedding.tolist()], # Convert numpy array to list
|
| 32 |
+
documents=[doc]
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
print("Stored embeddings in ChromaDB!")
|
| 36 |
+
|
| 37 |
+
# Search function using ChromaDB
|
| 38 |
def search_wikipedia(query, top_k=3):
|
| 39 |
+
query_embedding = embed_model.encode([query]).tolist()
|
| 40 |
+
results = collection.query(
|
| 41 |
+
query_embeddings=query_embedding,
|
| 42 |
+
n_results=top_k
|
| 43 |
+
)
|
| 44 |
+
return "\n\n".join(results["documents"][0]) # Return top results
|
| 45 |
|
| 46 |
# Gradio Interface
|
| 47 |
iface = gr.Interface(
|