athms committed · Commit d4073f8 · verified · 1 Parent(s): 139362b

Upload sampling.py with huggingface_hub

Files changed (1): sampling.py (+41 -21)
sampling.py CHANGED
@@ -16,26 +16,41 @@ def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
     """
     Apply top-k filtering to logits: with non-top-k values set to -inf
     """
-    top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
-    filtered_logits = torch.full_like(logits, float('-inf'))
-    filtered_logits.scatter_(-1, top_k_indices, top_k_values)
-    return filtered_logits
+    if k is None or k <= 0:
+        return torch.full_like(logits, float("-inf"))
+    k = min(k, logits.size(-1))
+    top_k_values, top_k_indices = torch.topk(logits, k, dim=-1)
+    filtered = torch.full_like(logits, float("-inf"))
+    filtered.scatter_(-1, top_k_indices, top_k_values)
+    return filtered
 
 
-def apply_top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
+def apply_top_p_filtering(logits: torch.Tensor, p: float, min_tokens_to_keep: int = 1) -> torch.Tensor:
     """
     Apply top-p (nucleus) filtering to logits: with tokens beyond threshold set to -inf
     """
+    if p <= 0:
+        p = 1e-8
+    if p >= 1:
+        return logits
+
     sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
 
-    # Remove tokens with cumulative probability above threshold
+    probs = torch.softmax(sorted_logits, dim=-1)
+    cumulative_probs = torch.cumsum(probs, dim=-1)
+
     sorted_indices_to_remove = cumulative_probs > p
-    sorted_indices_to_remove[..., 0] = False  # Keep at least one token
+
+    if min_tokens_to_keep > 0:
+        sorted_indices_to_remove[..., :min_tokens_to_keep] = False
+
     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+    sorted_indices_to_remove[..., 0] = False
+
+    indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
+    indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
 
-    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
-    return logits.masked_fill(indices_to_remove, float('-inf'))
+    return logits.masked_fill(indices_to_remove, float("-inf"))
 
 
 @torch.no_grad()
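
As a quick sanity check on the rewritten filters, here is a minimal sketch (not part of the commit; the tensor values are illustrative only), assuming the two functions are importable from sampling.py:

import torch

# Assumed import path for illustration
from sampling import apply_top_k_filtering, apply_top_p_filtering

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])

# Keeps the 2 largest logits; everything else becomes -inf:
# approximately [[2.0, 1.0, -inf, -inf]]
print(apply_top_k_filtering(logits, k=2))

# softmax(logits) is roughly [0.61, 0.22, 0.14, 0.03], cumulative
# [0.61, 0.83, 0.97, 1.00]; with p=0.9 the shifted mask keeps the
# first three sorted tokens: approximately [[2.0, 1.0, 0.5, -inf]]
print(apply_top_p_filtering(logits, p=0.9))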
@@ -189,18 +204,23 @@ def diffusion_sample(
         # Fall back to positional argument
         model_output = model(tokens)
 
-        safe_temperature = max(temperature, 1e-8)  # Prevent division by zero
-        logits = model_output.logits / safe_temperature
+        # Apply temperature scaling (if temperature == 0, treat as 1.0 for greedy)
+        logits = model_output.logits
+        if temperature > 0:
+            logits = logits / temperature
 
-        # Note: When both top_k and top_p are provided, they are applied sequentially:
-        # First top_k filters to k tokens, then top_p filters from those k tokens
-        if top_k is not None and top_k > 0:
-            logits = apply_top_k_filtering(logits, top_k)
+        # Apply filtering only when not in greedy mode
+        # Order matches reference: top_p before top_k
+        if not greedy:
+            if top_p is not None and 0 < top_p < 1.0:
+                logits = apply_top_p_filtering(logits, top_p)
 
-        if top_p is not None and 0 < top_p < 1.0:
-            logits = apply_top_p_filtering(logits, top_p)
+            if top_k is not None and top_k > 0:
+                logits = apply_top_k_filtering(logits, top_k)
 
-        logp = torch.log_softmax(logits, dim=-1)
+        # Compute probabilities for sampling and metrics
+        probs = torch.softmax(logits, dim=-1)
+        logp = torch.log(probs + 1e-10)  # Add epsilon for numerical stability
 
         if greedy:
             pred_next = logp.argmax(-1)
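
The switch from log_softmax to torch.log(probs + 1e-10) matters once filtering has zeroed out tokens: the epsilon turns what would be -inf log-probabilities into large but finite values. A small illustration (not from the commit):

import torch

probs = torch.tensor([0.7, 0.3, 0.0])  # 0.0 = token removed by top-k/top-p filtering
print(torch.log(probs))          # tensor([-0.3567, -1.2040,    -inf])
print(torch.log(probs + 1e-10))  # tensor([-0.3567, -1.2040, -23.0259])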
@@ -208,10 +228,10 @@ def diffusion_sample(
             # Sample from categorical distribution with proper RNG handling
             if generator is not None:
                 # Use multinomial with generator for reproducible sampling
-                probs = logp.exp()
                 pred_next = torch.multinomial(probs.view(-1, probs.size(-1)), 1, generator=generator).squeeze(-1).view(probs.shape[:-1])
             else:
-                pred_next = torch.distributions.Categorical(logits=logp).sample()
+                # Sample from categorical using probabilities
+                pred_next = torch.distributions.Categorical(probs=probs).sample()
 
         conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1)
 
 
16
  """
17
  Apply top-k filtering to logits: with non-top-k values set to -inf
18
  """
19
+ if k is None or k <= 0:
20
+ return torch.full_like(logits, float("-inf"))
21
+ k = min(k, logits.size(-1))
22
+ top_k_values, top_k_indices = torch.topk(logits, k, dim=-1)
23
+ filtered = torch.full_like(logits, float("-inf"))
24
+ filtered.scatter_(-1, top_k_indices, top_k_values)
25
+ return filtered
26
 
27
 
28
+ def apply_top_p_filtering(logits: torch.Tensor, p: float, min_tokens_to_keep: int = 1) -> torch.Tensor:
29
  """
30
  Apply top-p (nucleus) filtering to logits: with tokens beyond threshold set to -inf
31
  """
32
+ if p <= 0:
33
+ p = 1e-8
34
+ if p >= 1:
35
+ return logits
36
+
37
  sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
 
38
 
39
+ probs = torch.softmax(sorted_logits, dim=-1)
40
+ cumulative_probs = torch.cumsum(probs, dim=-1)
41
+
42
  sorted_indices_to_remove = cumulative_probs > p
43
+
44
+ if min_tokens_to_keep > 0:
45
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = False
46
+
47
  sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
48
+ sorted_indices_to_remove[..., 0] = False
49
+
50
+ indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
51
+ indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
52
 
53
+ return logits.masked_fill(indices_to_remove, float("-inf"))
 
54
 
55
 
56
  @torch.no_grad()
 
204
  # Fall back to positional argument
205
  model_output = model(tokens)
206
 
207
+ # Apply temperature scaling (if temperature == 0, treat as 1.0 for greedy)
208
+ logits = model_output.logits
209
+ if temperature > 0:
210
+ logits = logits / temperature
211
 
212
+ # Apply filtering only when not in greedy mode
213
+ # Order matches reference: top_p before top_k
214
+ if not greedy:
215
+ if top_p is not None and 0 < top_p < 1.0:
216
+ logits = apply_top_p_filtering(logits, top_p)
217
 
218
+ if top_k is not None and top_k > 0:
219
+ logits = apply_top_k_filtering(logits, top_k)
220
 
221
+ # Compute probabilities for sampling and metrics
222
+ probs = torch.softmax(logits, dim=-1)
223
+ logp = torch.log(probs + 1e-10) # Add epsilon for numerical stability
224
 
225
  if greedy:
226
  pred_next = logp.argmax(-1)
 
228
  # Sample from categorical distribution with proper RNG handling
229
  if generator is not None:
230
  # Use multinomial with generator for reproducible sampling
 
231
  pred_next = torch.multinomial(probs.view(-1, probs.size(-1)), 1, generator=generator).squeeze(-1).view(probs.shape[:-1])
232
  else:
233
+ # Sample from categorical using probabilities
234
+ pred_next = torch.distributions.Categorical(probs=probs).sample()
235
 
236
  conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1)
237
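For completeness, a sketch (again, not part of the commit) of why the generator branch gives reproducible draws, mirroring the multinomial call in the diff:

import torch

probs = torch.tensor([[0.1, 0.2, 0.7]])

g = torch.Generator().manual_seed(0)
first = torch.multinomial(probs.view(-1, probs.size(-1)), 1, generator=g).squeeze(-1).view(probs.shape[:-1])

g.manual_seed(0)  # reseed: the same draw comes back
second = torch.multinomial(probs.view(-1, probs.size(-1)), 1, generator=g).squeeze(-1).view(probs.shape[:-1])

assert torch.equal(first, second)

The Categorical fallback draws from the global RNG instead, so it is only reproducible if torch.manual_seed is set process-wide.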