Upload sampling.py with huggingface_hub
sampling.py  +41 -21  CHANGED
@@ -16,26 +16,41 @@ def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
     """
     Apply top-k filtering to logits, with non-top-k values set to -inf
     """
+    if k is None or k <= 0:
+        return torch.full_like(logits, float("-inf"))
+    k = min(k, logits.size(-1))
+    top_k_values, top_k_indices = torch.topk(logits, k, dim=-1)
+    filtered = torch.full_like(logits, float("-inf"))
+    filtered.scatter_(-1, top_k_indices, top_k_values)
+    return filtered
 
 
-def apply_top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
+def apply_top_p_filtering(logits: torch.Tensor, p: float, min_tokens_to_keep: int = 1) -> torch.Tensor:
     """
     Apply top-p (nucleus) filtering to logits, with tokens beyond threshold set to -inf
     """
+    if p <= 0:
+        p = 1e-8
+    if p >= 1:
+        return logits
+
     sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
 
+    probs = torch.softmax(sorted_logits, dim=-1)
+    cumulative_probs = torch.cumsum(probs, dim=-1)
+
     sorted_indices_to_remove = cumulative_probs > p
+
+    if min_tokens_to_keep > 0:
+        sorted_indices_to_remove[..., :min_tokens_to_keep] = False
+
     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+    sorted_indices_to_remove[..., 0] = False
+
+    indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
+    indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
 
-    return logits.masked_fill(indices_to_remove, float('-inf'))
+    return logits.masked_fill(indices_to_remove, float("-inf"))
 
 
 @torch.no_grad()
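For context, a minimal usage sketch of the two filters above (assuming sampling.py is importable from the working directory; the toy logits are invented for illustration):

```python
import torch

from sampling import apply_top_k_filtering, apply_top_p_filtering

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])

# Top-k keeps the k largest logits and sets the rest to -inf.
print(apply_top_k_filtering(logits, k=2))
# tensor([[2., 1., -inf, -inf, -inf]])

# Top-p keeps tokens until their cumulative probability exceeds p; the
# right-shift of the removal mask means the token that crosses the
# threshold survives too, so at least one token is always kept.
filtered = apply_top_p_filtering(logits, p=0.9)
print(torch.softmax(filtered, dim=-1))  # mass renormalized over survivors
```

Masking with -inf rather than zeroing probabilities keeps the later softmax a proper renormalization over the surviving tokens.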
@@ -189,18 +204,23 @@ def diffusion_sample(
         # Fall back to positional argument
         model_output = model(tokens)
 
-    logits = model_output.logits
+    # Apply temperature scaling (if temperature == 0, treat as 1.0 for greedy)
+    logits = model_output.logits
+    if temperature > 0:
+        logits = logits / temperature
 
+    # Apply filtering only when not in greedy mode
+    # Order matches reference: top_p before top_k
+    if not greedy:
+        if top_p is not None and 0 < top_p < 1.0:
+            logits = apply_top_p_filtering(logits, top_p)
 
+        if top_k is not None and top_k > 0:
+            logits = apply_top_k_filtering(logits, top_k)
 
+    # Compute probabilities for sampling and metrics
+    probs = torch.softmax(logits, dim=-1)
+    logp = torch.log(probs + 1e-10)  # Add epsilon for numerical stability
 
     if greedy:
         pred_next = logp.argmax(-1)
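The hunk above scales the logits by temperature before filtering; a standalone sketch with made-up values shows the effect of that division:

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.0])

for temperature in (0.5, 1.0, 2.0):
    print(temperature, torch.softmax(logits / temperature, dim=-1))

# temperature < 1 sharpens the distribution toward the argmax;
# temperature > 1 flattens it toward uniform. The `if temperature > 0`
# guard above also avoids division by zero in the greedy case.
```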
@@ -208,10 +228,10 @@ def diffusion_sample(
         # Sample from categorical distribution with proper RNG handling
         if generator is not None:
             # Use multinomial with generator for reproducible sampling
-            probs = logp.exp()
             pred_next = torch.multinomial(probs.view(-1, probs.size(-1)), 1, generator=generator).squeeze(-1).view(probs.shape[:-1])
         else:
+            # Sample from categorical using probabilities
+            pred_next = torch.distributions.Categorical(probs=probs).sample()
 
         conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1)
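Finally, a small sketch (arbitrary shapes and seed) of why the generator branch matters for reproducibility:

```python
import torch

probs = torch.softmax(torch.randn(4, 8), dim=-1)

g1 = torch.Generator().manual_seed(0)
g2 = torch.Generator().manual_seed(0)

a = torch.multinomial(probs, 1, generator=g1).squeeze(-1)
b = torch.multinomial(probs, 1, generator=g2).squeeze(-1)
assert torch.equal(a, b)  # same seed, identical draws

# The else branch's Categorical(probs=probs).sample() draws from the same
# distribution but consumes the global RNG state instead of a local generator.
```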