#!/usr/bin/env python3
import json
import os
import subprocess
import tempfile
from typing import Dict, List, Optional

import numpy as np
from safetensors import safe_open


class MLPProjector:
    """Two-layer MLP projector that maps hidden states into the embedding space."""

    def __init__(self, linear1_weight, linear2_weight):
        self.linear1_weight = linear1_weight
        self.linear2_weight = linear2_weight

    def __call__(self, x):
        # Linear 1
        x = x @ self.linear1_weight.T
        # ReLU
        x = np.maximum(0, x)
        # Linear 2
        x = x @ self.linear2_weight.T
        return x


def load_projector(projector_path: str) -> MLPProjector:
    """Load projector weights from a safetensors file."""
    with safe_open(projector_path, framework="numpy") as f:
        w0 = f.get_tensor("projector.0.weight")
        w2 = f.get_tensor("projector.2.weight")
    return MLPProjector(w0, w2)


def sanitize_input(text: str, special_tokens: Dict[str, str]) -> str:
    """Remove special tokens from input text."""
    for token in special_tokens.values():
        text = text.replace(token, "")
    return text


def format_docs_prompts_func(
    query: str,
    docs: List[str],
    instruction: Optional[str] = None,
    special_tokens: Optional[Dict[str, str]] = None,
) -> str:
    """Format the query and documents into a prompt for the model."""
    special_tokens = special_tokens or {}
    query = sanitize_input(query, special_tokens)
    docs = [sanitize_input(doc, special_tokens) for doc in docs]

    prefix = (
        "<|im_start|>system\n"
        "You are a search relevance expert who can determine a ranking of the passages based on how relevant they are to the query. "
        "If the query is a question, how relevant a passage is depends on how well it answers the question. "
        "If not, try to analyze the intent of the query and assess how well each passage satisfies the intent. "
        "If an instruction is provided, you should follow the instruction when determining the ranking."
        "<|im_end|>\n<|im_start|>user\n"
    )
    suffix = "<|im_end|>\n<|im_start|>assistant\n"

    doc_emb_token = special_tokens["doc_embed_token"]
    query_emb_token = special_tokens["query_embed_token"]

    prompt = (
        f"I will provide you with {len(docs)} passages, each indicated by a numerical identifier. "
" f"Rank the passages based on their relevance to query: {query}\n" ) if instruction: prompt += f'\n{instruction}\n\n' doc_prompts = [f'\n{doc}{doc_emb_token}\n' for i, doc in enumerate(docs)] prompt += "\n".join(doc_prompts) + "\n" prompt += f"\n{query}{query_emb_token}\n" return prefix + prompt + suffix class GGUFReranker: """GGUF-based implementation of jina-reranker-v3.""" def __init__(self, model_path: str = "jina-reranker-v3-BF16.gguf", projector_path: str = "projector.safetensors", llama_embedding_path: str = "/tmp/hanxiao-llama.cpp/build/bin/llama-embedding"): """Initialize GGUF-based reranker.""" self.model_path = model_path self.llama_embedding_path = llama_embedding_path self.projector = load_projector(projector_path) # Special tokens self.special_tokens = { "query_embed_token": "<|rerank_token|>", "doc_embed_token": "<|embed_token|>" } self.doc_embed_token_id = 151670 self.query_embed_token_id = 151671 def _get_hidden_states(self, prompt: str) -> np.ndarray: """Get per-token hidden states using llama-embedding CLI.""" with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: f.write(prompt) prompt_file = f.name try: result = subprocess.run( [ self.llama_embedding_path, '-m', self.model_path, '-f', prompt_file, '--pooling', 'none', '--embd-separator', '<#JINA_SEP#>', # Preserve internal newlines '--embd-normalize', '-1', '--embd-output-format', 'json', '--ubatch-size', '512', '--ctx-size', '8192', '--flash-attn', '-ngl', '99' ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True ) output = json.loads(result.stdout) embeddings = [item['embedding'] for item in output['data']] return np.array(embeddings) finally: os.unlink(prompt_file) def _tokenize(self, prompt: str) -> List[int]: """Tokenize prompt to find special token positions.""" with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: f.write(prompt) prompt_file = f.name try: result = subprocess.run( ['llama-tokenize', '-m', self.model_path, '-f', prompt_file], stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, check=True ) tokens = [] for line in result.stdout.strip().split('\n'): if '->' in line: token_id = int(line.split('->')[0].strip()) tokens.append(token_id) return tokens finally: os.unlink(prompt_file) def rerank( self, query: str, documents: List[str], top_n: Optional[int] = None, return_embeddings: bool = False, instruction: Optional[str] = None ) -> List[Dict]: """Rerank documents based on relevance to query.""" # Format prompt prompt = format_docs_prompts_func( query, documents, instruction=instruction, special_tokens=self.special_tokens ) # Get per-token hidden states using llama-embedding CLI embeddings = self._get_hidden_states(prompt) # Tokenize to find special token positions tokens = self._tokenize(prompt) tokens_array = np.array(tokens) query_embed_positions_in_tokens = np.where(tokens_array == self.query_embed_token_id)[0] doc_embed_positions_in_tokens = np.where(tokens_array == self.doc_embed_token_id)[0] if len(query_embed_positions_in_tokens) == 0: raise ValueError(f"Query embed token (ID {self.query_embed_token_id}) not found in input") if len(doc_embed_positions_in_tokens) == 0: raise ValueError(f"Document embed tokens (ID {self.doc_embed_token_id}) not found in input") # llama-embedding strips trailing newlines but preserves internal newlines (via --embd-separator) # Token positions map directly to embedding indices query_pos = query_embed_positions_in_tokens[0] doc_positions = doc_embed_positions_in_tokens # Extract embeddings 
        query_hidden = embeddings[query_pos:query_pos + 1]  # [1, hidden_size]
        doc_hidden = embeddings[doc_positions]               # [num_docs, hidden_size]

        # Project embeddings
        query_embeds = self.projector(query_hidden)  # [1, 512]
        doc_embeds = self.projector(doc_hidden)      # [num_docs, 512]

        # Compute cosine similarity scores
        # Broadcast the query embedding to match the document shape
        query_expanded = np.tile(query_embeds, (len(doc_embeds), 1))  # [num_docs, 512]

        # Cosine similarity
        dot_product = np.sum(doc_embeds * query_expanded, axis=-1)              # [num_docs]
        doc_norm = np.sqrt(np.sum(doc_embeds * doc_embeds, axis=-1))            # [num_docs]
        query_norm = np.sqrt(np.sum(query_expanded * query_expanded, axis=-1))  # [num_docs]
        scores = dot_product / (doc_norm * query_norm)                          # [num_docs]

        # Create results
        results = []
        for idx, (doc, score, embed) in enumerate(zip(documents, scores, doc_embeds)):
            result = {
                "index": idx,
                "relevance_score": float(score),
                "document": doc
            }
            if return_embeddings:
                result["embedding"] = embed.tolist()
            results.append(result)

        # Sort by score descending
        results.sort(key=lambda x: x["relevance_score"], reverse=True)

        # Return top_n if specified
        if top_n is not None:
            results = results[:top_n]

        return results


if __name__ == "__main__":
    # Test the reranker
    reranker = GGUFReranker()

    query = "What is the capital of France?"
    documents = [
        "Paris is the capital and largest city of France.",
        "Berlin is the capital of Germany.",
        "The Eiffel Tower is located in Paris."
    ]

    results = reranker.rerank(query, documents)
    for result in results:
        print(f"Doc {result['index']}: {result['relevance_score']:.4f} - {result['document'][:50]}...")