add compression rates
Browse files- modeling_provence.py +10 -1
modeling_provence.py
CHANGED
|
@@ -158,6 +158,9 @@ class Provence(DebertaV2PreTrainedModel):
|
|
| 158 |
reranking_scores = [
|
| 159 |
[None for j in range(len(contexts[i]))] for i in range(len(queries))
|
| 160 |
]
|
|
|
|
|
|
|
|
|
|
| 161 |
with torch.no_grad():
|
| 162 |
for batch_start in tqdm(
|
| 163 |
range(0, len(dataset), batch_size), desc="Pruning contexts..."
|
|
@@ -225,18 +228,24 @@ class Provence(DebertaV2PreTrainedModel):
|
|
| 225 |
)
|
| 226 |
else:
|
| 227 |
selected_contexts[i][j] = selected_contexts[i][j][0]
|
|
|
|
|
|
|
|
|
|
| 228 |
if reorder:
|
| 229 |
idxs = np.argsort(reranking_scores[i])[::-1][:top_k]
|
| 230 |
selected_contexts[i] = [selected_contexts[i][j] for j in idxs]
|
| 231 |
reranking_scores[i] = [reranking_scores[i][j] for j in idxs]
|
|
|
|
| 232 |
|
| 233 |
if type(context) == str:
|
| 234 |
selected_contexts = selected_contexts[0][0]
|
| 235 |
reranking_scores = reranking_scores[0][0]
|
|
|
|
| 236 |
|
| 237 |
return {
|
| 238 |
"pruned_context": selected_contexts,
|
| 239 |
-
"reranking_score": reranking_scores
|
|
|
|
| 240 |
}
|
| 241 |
|
| 242 |
|
|
|
|
| 158 |
reranking_scores = [
|
| 159 |
[None for j in range(len(contexts[i]))] for i in range(len(queries))
|
| 160 |
]
|
| 161 |
+
compressions = [
|
| 162 |
+
[0 for j in range(len(contexts[i]))] for i in range(len(queries))
|
| 163 |
+
]
|
| 164 |
with torch.no_grad():
|
| 165 |
for batch_start in tqdm(
|
| 166 |
range(0, len(dataset), batch_size), desc="Pruning contexts..."
|
|
|
|
| 228 |
)
|
| 229 |
else:
|
| 230 |
selected_contexts[i][j] = selected_contexts[i][j][0]
|
| 231 |
+
len_original = len(contexts[i][j])
|
| 232 |
+
len_compressed = len(selected_contexts[i][j])
|
| 233 |
+
compressions[i][j] = (len_original-len_compressed)/len_original * 100
|
| 234 |
if reorder:
|
| 235 |
idxs = np.argsort(reranking_scores[i])[::-1][:top_k]
|
| 236 |
selected_contexts[i] = [selected_contexts[i][j] for j in idxs]
|
| 237 |
reranking_scores[i] = [reranking_scores[i][j] for j in idxs]
|
| 238 |
+
compressions[i] = [compressions[i][j] for j in idxs]
|
| 239 |
|
| 240 |
if type(context) == str:
|
| 241 |
selected_contexts = selected_contexts[0][0]
|
| 242 |
reranking_scores = reranking_scores[0][0]
|
| 243 |
+
compressions = compressions[0][0]
|
| 244 |
|
| 245 |
return {
|
| 246 |
"pruned_context": selected_contexts,
|
| 247 |
+
"reranking_score": reranking_scores,
|
| 248 |
+
"compression_rate": compressions,
|
| 249 |
}
|
| 250 |
|
| 251 |
|