fix: preserve spatial_toks when reasoning=true

#26
by jasonmoo - opened

Fixes a bug where the spatial_refs parameter fails when reasoning mode is enabled.

Root cause: Line 670 was replacing prompt_tokens with only the suffix,
discarding the spatial_toks that were added earlier. This caused a shape
mismatch error in _prefill_prompt() when trying to encode spatial refs.

Error was:
RuntimeError: shape mismatch: value tensor of shape [2, 2048]
cannot be broadcast to indexing result of shape [0, 2048]

Fix: Convert tensor back to list, append suffix, maintain 2D structure:
Before: prompt_tokens = [self.config.tokenizer.templates["query"]["suffix"]]
After: prompt_tokens = [prompt_tokens[0].tolist() + self.config.tokenizer.templates["query"]["suffix"]]

After _generate_reasoning() returns, prompt_tokens is a torch.Tensor, not a list,
so a list cannot be appended to it with +=. It must be converted to a list first,
the suffix appended, and the result wrapped in a list to keep the 2D batch structure.

Tested: spatial_refs now work with reasoning=true without errors.

I'm new to this PR workflow on HF. I meant to use the following as the PR description, so I'm adding it here as a comment:

Summary

Fixes a bug where using the spatial_refs parameter with reasoning=true causes a tensor shape mismatch error.

Problem

When calling model.query() with both spatial_refs and reasoning=true:

result = model.query(
    image=img,
    question="What is in the center?",
    reasoning=True,
    spatial_refs=[(0.25, 0.25, 0.75, 0.75)],
)

Error:

RuntimeError: shape mismatch: value tensor of shape [2, 2048]
cannot be broadcast to indexing result of shape [0, 2048]

Root Cause

In moondream.py line 670, when reasoning=True, the code replaces prompt_tokens with only the suffix, discarding the spatial_toks that were added in lines 650-662:

# Lines 650-662: spatial_toks are created and added to prompt_tokens
spatial_toks = []
if spatial_refs:
    for ref in spatial_refs:
        coord_id = self.config.tokenizer.coord_id
        size_id = self.config.tokenizer.size_id
        if len(ref) == 2:
            spatial_toks.extend([coord_id, coord_id])
        else:
            spatial_toks.extend([coord_id, coord_id, size_id])

prompt_tokens = [prompt_toks + spatial_toks + self.tokenizer.encode(question).ids]

# Lines 664-670: Reasoning generation, then prompt_tokens REPLACED
if reasoning:
    # ... reasoning generation ...
    prompt_tokens = [self.config.tokenizer.templates["query"]["suffix"]]  # ❌ Discards spatial_toks

Later, in _prefill_prompt() (line 361-369), the code tries to encode spatial refs into the prompt:

if spatial_refs:
    encoded_refs = encode_spatial_refs(spatial_refs, self.region)
    prompt_emb[prompt_tokens == self.config.tokenizer.coord_id] = encoded_refs["coords"]

The mask prompt_tokens == coord_id finds zero matches (because the spatial_toks were discarded), so the indexing result has shape [0, 2048], while encoded_refs["coords"] has shape [2, 2048] (two coordinates from the bounding box), causing the broadcast error.
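For illustration, here is a minimal standalone reproduction of that failure mode (a sketch with made-up values; coord_id = 0 and the token ids are hypothetical, only the shapes matter):

import torch

coord_id = 0                                   # hypothetical stand-in for the real coord token id
prompt_tokens = torch.tensor([1, 2, 3, 4, 5])  # no coord_id tokens present (they were discarded)
prompt_emb = torch.zeros(5, 2048)              # one embedding row per prompt token
coords = torch.randn(2, 2048)                  # two encoded coordinate embeddings for the bbox

# The boolean mask selects zero rows, so the indexing result has shape [0, 2048];
# assigning a [2, 2048] value tensor into it raises:
#   RuntimeError: shape mismatch: value tensor of shape [2, 2048]
#   cannot be broadcast to indexing result of shape [0, 2048]
prompt_emb[prompt_tokens == coord_id] = coords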

Solution

Convert tensor to list, append suffix, and maintain 2D structure:

# Before:
prompt_tokens = [self.config.tokenizer.templates["query"]["suffix"]]

# After:
prompt_tokens = [prompt_tokens[0].tolist() + self.config.tokenizer.templates["query"]["suffix"]]

Why this approach:

After _generate_reasoning() returns (line 668), prompt_tokens is a torch.Tensor, not a list. A simple prompt_tokens[0] += suffix would fail with:

TypeError: unsupported operand type(s) for +=: 'Tensor' and 'list'

The fix:

  1. Converts the first tensor element back to a list with .tolist()
  2. Appends the suffix tokens
  3. Wraps the result in a list to maintain the 2D [batch, tokens] structure
  4. On line 679, it is converted back to a tensor for _generate_answer()

This preserves the spatial_toks in prompt_tokens for the answer generation phase.
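For clarity, here is the conversion pattern in isolation (a sketch with illustrative token values, not the real vocabulary):

import torch

suffix = [7, 8]                            # stand-in for templates["query"]["suffix"]
prompt_tokens = torch.tensor([[1, 2, 3]])  # 2D [batch, tokens] tensor after reasoning

# Convert row 0 back to a Python list, append the suffix, re-wrap to keep [batch, tokens]
prompt_tokens = [prompt_tokens[0].tolist() + suffix]

# Later it is converted back to a tensor for answer generation; the original tokens survive
prompt_tokens = torch.tensor(prompt_tokens)  # shape [1, 5]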

Testing

Tested with the following code:

from transformers import AutoModelForCausalLM
from PIL import Image
import numpy as np
import torch

model = AutoModelForCausalLM.from_pretrained(
    "moondream/moondream3-preview",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

img = Image.fromarray(np.full((100, 100, 3), 255, dtype=np.uint8))

# This now works without error:
result = model.query(
    image=img,
    question="What is in the center?",
    reasoning=True,
    spatial_refs=[(0.25, 0.25, 0.75, 0.75)],
    stream=False
)

print(f"Answer: {result['answer']}")
print(f"Reasoning: {result['reasoning']['text']}")

Result: ✓ No errors, spatial_refs work correctly with reasoning enabled.

Checklist

  • Fix tested locally with spatial_refs + reasoning
  • No breaking changes to existing functionality
  • Follows existing code patterns (matches else branch on line 675)
  • Minimal change (single line)

Additional Context

This bug was discovered while building a production API wrapper around Moondream v3. We implemented a workaround that automatically disables reasoning when spatial_refs are provided (sketched below), but we would prefer to support both features simultaneously.
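For reference, the workaround looks roughly like this (a hypothetical wrapper sketch; safe_query is not part of this repo):

def safe_query(model, image, question, reasoning=False, spatial_refs=None, **kwargs):
    # On unpatched checkpoints, drop reasoning when spatial_refs are given,
    # since the combination triggers the shape mismatch described above.
    if spatial_refs and reasoning:
        reasoning = False
    return model.query(image=image, question=question, reasoning=reasoning,
                       spatial_refs=spatial_refs, **kwargs)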

Related: This issue affects use cases where users need both:

  • Spatial context (region-based queries)
  • Detailed reasoning (step-by-step explanation)

For example: "Explain what's happening in the center region of this image" with spatial_refs=[(0.25, 0.25, 0.75, 0.75)] and reasoning=true.

I'm running this fix in my fork and it's working as expected.

