abdoelsayed commited on
Commit
b758709
Β·
verified Β·
1 Parent(s): 899d4fc

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +391 -0
README.md ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ license: mit
5
+ library_name: transformers
6
+ tags:
7
+ - reranking
8
+ - information-retrieval
9
+ - pointwise
10
+ - binary-cross-entropy
11
+ - llama
12
+ base_model: meta-llama/Llama-3.1-8B
13
+ datasets:
14
+ - Tevatron/msmarco-passage
15
+ - abdoelsayed/DeAR-COT
16
+ pipeline_tag: text-classification
17
+ ---
18
+
19
+ # DeAR-8B-Reranker-CE-v1
20
+
21
+ ## Model Description
22
+
23
+ **DeAR-8B-Reranker-CE-v1** is an 8B parameter neural reranker trained with Binary Cross-Entropy loss and knowledge distillation. This model uses a classification-based approach to document reranking and is optimized for both accuracy and inference speed.
24
+
25
+ ## Model Details
26
+
27
+ - **Model Type:** Pointwise Reranker (Binary Classification)
28
+ - **Base Model:** LLaMA-3.1-8B
29
+ - **Parameters:** 8 billion
30
+ - **Training Method:** Knowledge Distillation + Binary Cross-Entropy Loss
31
+ - **Teacher Model:** [LLaMA2-13B-RankLLaMA](https://huggingface.co/abdoelsayed/llama2-13b-rankllama-teacher)
32
+ - **Training Data:** MS MARCO
33
+ - **Precision:** BFloat16
34
+
35
+ ## Key Features
36
+
37
+ βœ… **Classification-based:** Binary relevance prediction with probabilistic outputs
38
+ βœ… **Fast Inference:** 2.2s average latency on standard GPU
39
+ βœ… **Strong Baseline:** Competitive performance across benchmarks
40
+ βœ… **CoT Enhanced:** Trained with Chain-of-Thought reasoning from teacher
41
+
42
+ ## Performance
43
+
44
+ | Benchmark | NDCG@10 |
45
+ |-----------|---------|
46
+ | TREC DL19 | 73.9 |
47
+ | TREC DL20 | 72.1 |
48
+ | BEIR (Avg) | 44.8 |
49
+ | MS MARCO Dev | 68.5 |
50
+
51
+ ## Usage
52
+
53
+ ### Quick Start
54
+
55
+ ```python
56
+ import torch
57
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
58
+
59
+ # Load model
60
+ model_path = "abdoelsayed/dear-8b-reranker-ce-v1"
61
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
62
+ model = AutoModelForSequenceClassification.from_pretrained(
63
+ model_path,
64
+ torch_dtype=torch.bfloat16
65
+ )
66
+ model.eval().cuda()
67
+
68
+ # Score a query-document pair
69
+ query = "What is llama?"
70
+ document = "The llama is a domesticated South American camelid..."
71
+
72
+ inputs = tokenizer(
73
+ f"query: {query}",
74
+ f"document: {document}",
75
+ return_tensors="pt",
76
+ truncation=True,
77
+ max_length=228,
78
+ padding="max_length"
79
+ )
80
+ inputs = {k: v.cuda() for k, v in inputs.items()}
81
+
82
+ with torch.no_grad():
83
+ score = model(**inputs).logits.squeeze().item()
84
+
85
+ print(f"Relevance score: {score}")
86
+ ```
87
+
88
+ ### Complete Reranking Example
89
+
90
+ ```python
91
+ import torch
92
+ from typing import List, Tuple
93
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
94
+
95
+ def load_reranker(model_path: str, device: str = "cuda"):
96
+ """Load the reranker model and tokenizer."""
97
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
98
+ model = AutoModelForSequenceClassification.from_pretrained(
99
+ model_path,
100
+ torch_dtype=torch.bfloat16
101
+ )
102
+
103
+ # Configure padding token
104
+ if tokenizer.pad_token is None:
105
+ tokenizer.pad_token = tokenizer.eos_token
106
+ tokenizer.pad_token_id = tokenizer.eos_token_id
107
+ tokenizer.padding_side = "right"
108
+
109
+ model.eval()
110
+ model.to(device)
111
+ return tokenizer, model
112
+
113
+ @torch.inference_mode()
114
+ def rerank(
115
+ tokenizer,
116
+ model,
117
+ query: str,
118
+ documents: List[Tuple[str, str]], # (title, text)
119
+ batch_size: int = 64
120
+ ) -> List[Tuple[int, float]]:
121
+ """
122
+ Rerank documents for a query.
123
+
124
+ Returns:
125
+ List of (doc_index, score) sorted by relevance (descending)
126
+ """
127
+ device = next(model.parameters()).device
128
+ scores = []
129
+
130
+ for i in range(0, len(documents), batch_size):
131
+ batch = documents[i:i + batch_size]
132
+
133
+ # Prepare batch
134
+ queries = [f"query: {query}"] * len(batch)
135
+ docs = [f"document: {title} {text}" for title, text in batch]
136
+
137
+ inputs = tokenizer(
138
+ queries,
139
+ docs,
140
+ return_tensors="pt",
141
+ truncation=True,
142
+ max_length=228,
143
+ padding=True,
144
+ return_attention_mask=True
145
+ )
146
+ inputs = {k: v.to(device) for k, v in inputs.items()}
147
+
148
+ # Score batch
149
+ logits = model(**inputs).logits.squeeze(-1)
150
+ scores.extend(logits.cpu().tolist())
151
+
152
+ # Rank by score
153
+ ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
154
+ return ranked
155
+
156
+
157
+ # Example
158
+ tokenizer, model = load_reranker("abdoelsayed/dear-8b-reranker-ce-v1")
159
+
160
+ query = "When did Thomas Edison invent the light bulb?"
161
+ documents = [
162
+ ("", "Lightning strike at Seoul National University"),
163
+ ("", "Thomas Edison tried to invent a device for car but failed"),
164
+ ("", "Coffee is good for diet"),
165
+ ("", "KEPCO fixes light problems"),
166
+ ("", "Thomas Edison invented the light bulb in 1879"),
167
+ ]
168
+
169
+ ranking = rerank(tokenizer, model, query, documents)
170
+ print(ranking)
171
+ # Output: [(4, -2.015625), (1, -5.6875), (2, -6.375), (0, -6.5), (3, -6.78125)]
172
+ # Document at index 4 is most relevant
173
+ ```
174
+
175
+ ## Training Details
176
+
177
+ ### Training Data
178
+ - **Primary Dataset:** MS MARCO Passage Ranking (~8M pairs)
179
+ - **CoT Dataset:** [DeAR-COT](https://huggingface.co/datasets/abdoelsayed/DeAR-COT)
180
+ - **Teacher Annotations:** Soft labels from 13B teacher model
181
+
182
+ ### Training Configuration
183
+ ```python
184
+ {
185
+ "base_model": "meta-llama/Llama-3.1-8B",
186
+ "teacher_model": "abdoelsayed/llama2-13b-rankllama-teacher",
187
+ "loss": "Binary Cross-Entropy",
188
+ "distillation": {
189
+ "temperature": 2.0,
190
+ "alpha": 0.1
191
+ },
192
+ "optimizer": "AdamW",
193
+ "learning_rate": 1e-4,
194
+ "batch_size": 2,
195
+ "gradient_accumulation": 2,
196
+ "epochs": 2,
197
+ "max_length": 228,
198
+ "q_max_len": 32,
199
+ "p_max_len": 196,
200
+ "warmup_ratio": 0.1,
201
+ "weight_decay": 0.01,
202
+ "bf16": true
203
+ }
204
+ ```
205
+
206
+ ### Hardware
207
+ - **GPUs:** 4x NVIDIA A100 (40GB)
208
+ - **Training Time:** ~34 hours
209
+ - **Framework:** DeepSpeed ZeRO Stage 2
210
+ - **Memory Usage:** ~38GB per GPU
211
+
212
+ ### Loss Function
213
+
214
+ **Binary Cross-Entropy** with Knowledge Distillation:
215
+
216
+ ```python
217
+ L_total = (1 - Ξ±) * BCE(y_pred, y_true) + Ξ± * KL(Οƒ(z_s/T), Οƒ(z_t/T))
218
+
219
+ where:
220
+ - BCE: Binary cross-entropy loss
221
+ - KL: KL divergence
222
+ - z_s: Student logits
223
+ - z_t: Teacher logits
224
+ - T: Temperature (2.0)
225
+ - Ξ±: Distillation weight (0.1)
226
+ - Οƒ: Sigmoid function
227
+ ```
228
+
229
+ ## Evaluation Results
230
+
231
+ ### TREC Deep Learning
232
+
233
+ | Dataset | NDCG@10 | NDCG@20 | MRR@10 | MAP |
234
+ |---------|---------|---------|--------|-----|
235
+ | DL19 | 73.90 | 69.82 | 87.3 | 44.92 |
236
+ | DL20 | 72.10 | 68.45 | 85.1 | 42.67 |
237
+
238
+ ### BEIR Benchmark
239
+
240
+ | Dataset | NDCG@10 | NDCG@100 |
241
+ |---------|---------|----------|
242
+ | MS MARCO | 68.5 | 75.2 |
243
+ | NQ | 51.8 | 69.4 |
244
+ | HotpotQA | 61.2 | 74.8 |
245
+ | FiQA | 46.8 | 62.3 |
246
+ | ArguAna | 58.9 | 71.5 |
247
+ | SciFact | 73.1 | 82.6 |
248
+ | TREC-COVID | 84.7 | 88.3 |
249
+ | NFCorpus | 39.4 | 51.7 |
250
+ | **Average** | **44.8** | **68.2** |
251
+
252
+ ### Efficiency Metrics
253
+
254
+ | Metric | Value |
255
+ |--------|-------|
256
+ | Inference Time (batch=64) | 2.2s |
257
+ | Throughput | ~45 docs/sec |
258
+ | GPU Memory (inference) | 18GB |
259
+ | Model Size (BF16) | 16GB |
260
+
261
+ ## Comparison
262
+
263
+ | Model | Loss | DL19 | DL20 | BEIR Avg | Speed (s) |
264
+ |-------|------|------|------|----------|-----------|
265
+ | **DeAR-8B-CE** | BCE | 73.9 | 72.1 | 44.8 | 2.2 |
266
+ | **DeAR-8B-RankNet** | RankNet | 74.5 | 72.8 | 45.2 | 2.2 |
267
+ | MonoT5-3B | - | 71.8 | 68.9 | 43.5 | 3.5 |
268
+ | Teacher-13B | - | 73.8 | 71.2 | 44.8 | 5.8 |
269
+
270
+ **Key Observations:**
271
+ - Slightly lower performance than RankNet variant
272
+ - Identical inference speed
273
+ - More stable training (simpler loss)
274
+ - Better for binary relevance tasks
275
+
276
+ ## Model Architecture
277
+
278
+ ```
279
+ Input Format: "query: [QUERY] document: [TITLE] [TEXT]"
280
+ ↓
281
+ Tokenization (max_length=228)
282
+ ↓
283
+ LLaMA-3.1-8B Transformer
284
+ ↓
285
+ [CLS] Token Pooling
286
+ ↓
287
+ Linear(hidden_size β†’ 1)
288
+ ↓
289
+ Sigmoid (optional)
290
+ ↓
291
+ Relevance Score
292
+ ```
293
+
294
+ ## When to Use This Model
295
+
296
+ **Best for:**
297
+ - βœ… Binary relevance classification
298
+ - βœ… Large-scale reranking (fast inference)
299
+ - βœ… General-purpose IR tasks
300
+ - βœ… Resource-constrained environments
301
+
302
+ **Consider alternatives for:**
303
+ - ❌ Listwise ranking (use DeAR-8B-Listwise)
304
+ - ❌ Maximum performance (use RankNet variant)
305
+ - ❌ Extreme low-latency (use 3B models)
306
+
307
+ ## Limitations
308
+
309
+ 1. **Document Truncation:** Limited to 196 tokens per document
310
+ 2. **Query Length:** Optimal for queries ≀32 tokens
311
+ 3. **Language:** English only
312
+ 4. **Domain:** Trained on MS MARCO (web documents)
313
+ 5. **Pointwise:** Does not model inter-document dependencies
314
+
315
+ ## Bias and Ethical Considerations
316
+
317
+ - **Training Data Bias:** Inherits biases from MS MARCO dataset
318
+ - **Representation Bias:** May perform differently across demographics
319
+ - **Language Bias:** Optimized for English; other languages not evaluated
320
+ - **Domain Bias:** Best performance on web-style documents
321
+
322
+ **Recommendations:**
323
+ - Evaluate fairness for your specific use case
324
+ - Test on diverse query sets
325
+ - Monitor for biased ranking patterns
326
+ - Consider domain-specific fine-tuning
327
+
328
+ ## Fine-tuning
329
+
330
+ To fine-tune on your own data:
331
+
332
+ ```python
333
+ from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
334
+
335
+ model = AutoModelForSequenceClassification.from_pretrained(
336
+ "abdoelsayed/dear-8b-reranker-ce-v1",
337
+ num_labels=1
338
+ )
339
+
340
+ training_args = TrainingArguments(
341
+ output_dir="./finetuned-model",
342
+ learning_rate=5e-6, # Lower LR for fine-tuning
343
+ per_device_train_batch_size=4,
344
+ num_train_epochs=1,
345
+ bf16=True,
346
+ logging_steps=100,
347
+ )
348
+
349
+ trainer = Trainer(
350
+ model=model,
351
+ args=training_args,
352
+ train_dataset=your_dataset,
353
+ )
354
+
355
+ trainer.train()
356
+ ```
357
+
358
+ ## Related Models
359
+
360
+ **DeAR Family (8B):**
361
+ - [DeAR-8B-RankNet](https://huggingface.co/abdoelsayed/dear-8b-reranker-ranknet-v1) - RankNet loss variant
362
+ - [DeAR-8B-Listwise](https://huggingface.co/abdoelsayed/dear-8b-reranker-listwise-v1) - Generative listwise reranker
363
+ - [DeAR-8B-CE-LoRA](https://huggingface.co/abdoelsayed/dear-8b-reranker-ce-lora-v1) - LoRA adapter version
364
+
365
+ **Other Sizes:**
366
+ - [DeAR-3B-CE](https://huggingface.co/abdoelsayed/dear-3b-reranker-ce-v1) - Faster 3B variant
367
+
368
+ **Resources:**
369
+ - [Teacher Model](https://huggingface.co/abdoelsayed/llama2-13b-rankllama-teacher)
370
+ - [DeAR-COT Dataset](https://huggingface.co/datasets/abdoelsayed/DeAR-COT)
371
+
372
+ ## Citation
373
+
374
+ ```bibtex
375
+ @article{abdallah2025dear,
376
+ title={DeAR: Dual-Stage Document Reranking with Reasoning Agents via LLM Distillation},
377
+ author={Abdallah, Abdelrahman and Mozafari, Jamshid and Piryani, Bhawna and Jatowt, Adam},
378
+ journal={arXiv preprint arXiv:2508.16998},
379
+ year={2025}
380
+ }
381
+ ```
382
+
383
+ ## License
384
+
385
+ MIT License
386
+
387
+ ## More Information
388
+
389
+ - **GitHub:** [DataScienceUIBK/DeAR-Reranking](https://github.com/DataScienceUIBK/DeAR-Reranking)
390
+ - **Paper:** [arXiv:2508.16998](https://arxiv.org/abs/2508.16998)
391
+ - **Collection:** [DeAR Model Collection](https://huggingface.co/collections/abdoelsayed/dear-reranking)