High memory usage with mlx-lm
mlx_lm.benchmark with a relatively small prompt peaks at 24.4 GB of RAM. Is this normal?
❯ mlx_lm.benchmark --model mlx-community/granite-4.0-h-tiny-8bit -p 2048 -g 128
Running warmup..
Timing with prompt_tokens=2048, generation_tokens=128, batch_size=1.
Trial 1: prompt_tps=970.617, generation_tps=101.874, peak_memory=24.439
Trial 2: prompt_tps=965.406, generation_tps=101.726, peak_memory=24.440
Trial 3: prompt_tps=972.656, generation_tps=101.855, peak_memory=24.440
Trial 4: prompt_tps=953.336, generation_tps=101.399, peak_memory=24.440
Trial 5: prompt_tps=965.120, generation_tps=101.688, peak_memory=24.440
Averages: prompt_tps=965.427, generation_tps=101.709, peak_memory=24.440
Hm, this is a really interesting question. I did the mlx-lm implementation, but relied on a coding assistant (one downside of AI-assisted development), so it's possible something in my implementation is broken. It looks like there's an open PR that improves this, though: https://github.com/ml-explore/mlx-lm/pull/525.
Looking at the PR, I think the cause here is the use of the SSD (State Space Duality) algorithm for prefill. The PR bounds peak memory by processing the context in chunks.
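For intuition, here is a minimal sketch of the chunked-prefill idea, not the actual code in the PR: the SSD prefill allocates intermediates that grow with the number of tokens processed at once, so splitting the prompt into fixed-size chunks and carrying the recurrent state between them bounds peak memory. The ssm_step function below is a toy stand-in, not the granite kernel.

import mlx.core as mx

def ssm_step(tokens, state):
    # Toy placeholder recurrence; the real SSD prefill kernel is far more involved.
    return state + tokens.sum()

def chunked_prefill(tokens, state, chunk_size=512):
    # Process the prompt chunk by chunk so intermediates scale with
    # chunk_size rather than the full prompt length.
    for start in range(0, tokens.shape[0], chunk_size):
        state = ssm_step(tokens[start:start + chunk_size], state)
        mx.eval(state)  # materialize the state before moving to the next chunk
    return state

state = chunked_prefill(mx.arange(2048, dtype=mx.float32), mx.zeros(1))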
Thanks for the info! I should have checked GitHub again before posting.
Not at all, thanks for raising it here! I hadn't seen the PR until you asked, and this will be where many people look first.
Looks like the PR comes with a nice boost in prompt processing speed too! (M4 Pro, 48 GB, for reference)
❯ mlx_lm.benchmark --model mlx-community/granite-4.0-h-tiny-8bit -p 2048 -g 128
Running warmup..
Timing with prompt_tokens=2048, generation_tokens=128, batch_size=1.
Trial 1: prompt_tps=1559.782, generation_tps=99.254, peak_memory=9.745
Trial 2: prompt_tps=1552.299, generation_tps=98.935, peak_memory=9.745
Trial 3: prompt_tps=1556.283, generation_tps=98.678, peak_memory=9.746
Trial 4: prompt_tps=1553.376, generation_tps=98.632, peak_memory=9.746
Trial 5: prompt_tps=1550.414, generation_tps=98.808, peak_memory=9.746
Averages: prompt_tps=1554.431, generation_tps=98.862, peak_memory=9.746