Commit
·
5f8e4b6
1
Parent(s):
c35a42b
fix: sentences as a str
Browse filesSigned-off-by: Meow <[email protected]>
- modeling_lora.py +9 -6
modeling_lora.py
CHANGED
|
@@ -393,17 +393,20 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
| 393 |
f"Alternatively, don't pass the `task_type` argument to disable LoRA."
|
| 394 |
)
|
| 395 |
adapter_mask = None
|
| 396 |
-
sentences = list(sentences) if isinstance(sentences, str) else sentences
|
| 397 |
if task_type:
|
| 398 |
task_id = self._adaptation_map[task_type]
|
|
|
|
| 399 |
adapter_mask = torch.full(
|
| 400 |
-
(
|
| 401 |
)
|
| 402 |
if task_type in ["query", "passage"]:
|
| 403 |
-
sentences
|
| 404 |
-
self._task_instructions[task_type] + " " +
|
| 405 |
-
|
| 406 |
-
|
|
|
|
|
|
|
|
|
|
| 407 |
return self.roberta.encode(
|
| 408 |
sentences, *args, adapter_mask=adapter_mask, **kwargs
|
| 409 |
)
|
|
|
|
| 393 |
f"Alternatively, don't pass the `task_type` argument to disable LoRA."
|
| 394 |
)
|
| 395 |
adapter_mask = None
|
|
|
|
| 396 |
if task_type:
|
| 397 |
task_id = self._adaptation_map[task_type]
|
| 398 |
+
num_examples = 1 if isinstance(sentences, str) else len(sentences)
|
| 399 |
adapter_mask = torch.full(
|
| 400 |
+
(num_examples,), task_id, dtype=torch.int32, device=self.device
|
| 401 |
)
|
| 402 |
if task_type in ["query", "passage"]:
|
| 403 |
+
if isinstance(sentences, str):
|
| 404 |
+
sentences = self._task_instructions[task_type] + " " + sentences
|
| 405 |
+
else:
|
| 406 |
+
sentences = [
|
| 407 |
+
self._task_instructions[task_type] + " " + sentence
|
| 408 |
+
for sentence in sentences
|
| 409 |
+
]
|
| 410 |
return self.roberta.encode(
|
| 411 |
sentences, *args, adapter_mask=adapter_mask, **kwargs
|
| 412 |
)
|