from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import hnswlib
import numpy as np
import datetime
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
# Select the device once at startup; fall back to CPU when no GPU is present.
# (The original hardcoded device = "cuda", which crashes on CPU-only hosts.)
if torch.cuda.is_available():
    device = "cuda"
    print("CUDA is available! Inference on GPU!")
else:
    device = "cpu"
    print("CUDA is not available. Inference on CPU.")

separator = "-HFSEP-"
base_name = "intfloat/e5-small-v2"
max_length = 512
max_batch_size = 500

# Load the e5-small-v2 encoder and its tokenizer once at startup.
tokenizer = AutoTokenizer.from_pretrained(base_name)
model = AutoModel.from_pretrained(base_name).to(device)
def current_timestamp():
    # UTC epoch timestamp, used below for coarse latency logging.
    return datetime.datetime.now(datetime.timezone.utc).timestamp()
def get_embeddings(input_texts):
    # Cap the batch so a single request cannot exhaust GPU memory.
    input_texts = input_texts[:max_batch_size]
    batch_dict = tokenizer(
        input_texts,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors='pt'
    ).to(device)
    with torch.no_grad():
        outputs = model(**batch_dict)
    # Mean-pool the token embeddings, then L2-normalize so the inner
    # product used by the index below is equivalent to cosine similarity.
    embeddings = _average_pool(
        outputs.last_hidden_state, batch_dict['attention_mask']
    )
    embeddings = F.normalize(embeddings, p=2, dim=1)
    embeddings_np = embeddings.cpu().numpy()
    if device == "cuda":
        del embeddings
        torch.cuda.empty_cache()
    return embeddings_np
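
# Hypothetical call, for illustration: e5-small-v2 produces
# 384-dimensional vectors, so this would return a (2, 384) array.
#
#   vecs = get_embeddings(["query: hello", "passage: hello world"])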
def _average_pool(last_hidden_states, attention_mask):
    # Zero out padding positions, then average the remaining token
    # embeddings to get one fixed-size vector per input text.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
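
# Sanity-check sketch for the pooling (toy tensors, kept as a comment
# so importing this module has no side effects): two texts, three token
# positions, hidden size 4; the padding position in the second text
# must not contribute to its mean.
#
#   hidden = torch.randn(2, 3, 4)
#   mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
#   pooled = _average_pool(hidden, mask)   # shape (2, 4)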
def create_hnsw_index(embeddings_np, space='ip', ef_construction=100, M=16):
    # Build an HNSW index over the passage embeddings. The 'ip' (inner
    # product) space matches the normalized vectors produced above.
    index = hnswlib.Index(space=space, dim=len(embeddings_np[0]))
    index.init_index(max_elements=len(embeddings_np), ef_construction=ef_construction, M=M)
    ids = np.arange(embeddings_np.shape[0])
    index.add_items(embeddings_np, ids)
    return index
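
# Illustrative round-trip through the index (384 dims assumed, matching
# e5-small-v2; commented out so it does not run at import time):
#
#   vecs = np.random.rand(100, 384).astype(np.float32)
#   vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)
#   idx = create_hnsw_index(vecs)
#   neighbour_ids, distances = idx.knn_query(vecs[0], k=5)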
def preprocess_texts(query, paragraphs):
    # e5 models expect 'query: ' / 'passage: ' prefixes at inference time.
    query = f'query: {query}'
    paragraphs = [f'passage: {p}' for p in paragraphs]
    return [query] + paragraphs
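
# Example of the prefixing, for reference:
#
#   preprocess_texts("best pizza", ["Pizza is baked.", "Dough rises."])
#   # -> ['query: best pizza', 'passage: Pizza is baked.', 'passage: Dough rises.']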
app = FastAPI()

# Request body: candidate paragraphs, the query, and how many matches to return.
class EmbeddingsSimilarityReq(BaseModel):
    paragraphs: List[str]
    query: str
    top_k: int
@app.post("/")  # route path assumed; a decorator is required to register the handler
async def find_similar_paragraphs(req: EmbeddingsSimilarityReq):
    print("Len of batches", len(req.paragraphs))
    print("creating embeddings", current_timestamp())
    inputs = preprocess_texts(req.query, req.paragraphs)
    embeddings_np = get_embeddings(inputs)
    # Row 0 is the query embedding; the remaining rows are the passages.
    query_embedding, chunks_embeddings = embeddings_np[0], embeddings_np[1:]
    print("creating index", current_timestamp())
    search_index = create_hnsw_index(chunks_embeddings)
    print("searching index", current_timestamp())
    # Clamp k so hnswlib never asks for more neighbours than exist.
    labels, _ = search_index.knn_query(query_embedding, k=min(int(req.top_k), len(chunks_embeddings)))
    labels = labels[0].tolist()
    return labels
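
# Usage sketch, assuming this file is saved as app.py and uvicorn is
# installed (the "/" path follows the route assumed above):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/ \
#     -H 'Content-Type: application/json' \
#     -d '{"query": "what is hnswlib?", "paragraphs": ["hnswlib is an ANN library.", "Tomatoes are red."], "top_k": 1}'
#
# The response is a list of paragraph indices ordered by similarity to the query.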