William Brown

ParaLLM: 1600+ tok/s on a MacBook

Batched KV caching for fast parallel LLM inference in MLX.

June 23, 2024

Recently I’ve been doing some LLM finetuning experiments on my MacBook using MLX, and found that there wasn’t really a great way to take advantage of parallel inference for evaluating outputs locally. For single-stream applications like chat interfaces, this isn’t a big deal – both llama.cpp and MLXServer run quite fast on Apple devices. But if you’re trying to sample a large number of outputs at once, either for evaluating a training run or for “agent-flavored” applications, neither of them really offers a speedup in terms of total throughput (at least from what I’ve been able to test). If you’re on a CUDA machine, you’d use something like vLLM, which is a more “production-grade” solution for achieving high tok/s throughput with parallel requests, but it doesn’t work on a Mac.

The main feature we need to enable this in MLX is batched key-value caching. Borrowing heavily from the existing mlx_lm library, I extended the generate method to use a BatchedKVCache object and added a batch_generate method that decodes multiple sequences in parallel. For “small” models like Gemma-2B, this gets you to 1600+ tokens/sec in total throughput on a 128GB M3 Max.

from mlx_parallm.utils import load, generate, batch_generate

# fun trick for generating workloads
import random
import string

capital_letters = string.ascii_uppercase
# all 325 unordered pairs of distinct capital letters
distinct_pairs = [(a, b) for i, a in enumerate(capital_letters) for b in capital_letters[i + 1:]]
prompt_template = "Think of a real word containing both the letters {l1} and {l2}. Then, say 3 sentences which use the word."
prompts_raw = [prompt_template.format(l1=p[0], l2=p[1]) for p in random.sample(distinct_pairs, 325)]

model, tokenizer = load("google/gemma-1.1-2b-it")
responses = batch_generate(model, tokenizer, prompts=prompts_raw, max_tokens=100, verbose=True, temp=0.0)

1600+ tokens per second via Gemma-2B (float16)
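Under the hood, the idea is the standard batched decoding loop: tokenize all prompts to the same (left-padded) length, prefill them in a single forward pass, then decode one token per step for the whole batch against a shared cache. Here’s a minimal sketch of that loop – greedy decoding only, assuming the mlx_lm-style model call signature model(tokens, cache=...), and not the actual mlx_parallm code:

import mlx.core as mx

def batched_greedy_decode(model, input_ids, cache, max_tokens):
    # input_ids: (batch, padded_prompt_len) token ids for all prompts
    # cache: list of BatchedKVCache objects, one per transformer layer
    logits = model(input_ids, cache=cache)            # prefill every prompt at once
    tokens = mx.argmax(logits[:, -1, :], axis=-1)     # first generated token per prompt
    generated = [tokens]
    for _ in range(max_tokens - 1):
        logits = model(tokens[:, None], cache=cache)  # one decode step for the whole batch
        tokens = mx.argmax(logits[:, -1, :], axis=-1)
        generated.append(tokens)
    return mx.stack(generated, axis=1)                # (batch, max_tokens) token ids

The per-step cost of the model call is nearly the same whether the batch holds one sequence or dozens (you’re memory-bandwidth bound either way), which is where the total-throughput gains come from.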

The code is available on GitHub as mlx_parallm. I’ve tested with Gemma-2B, Phi-3-mini, and Llama3-8B, all of which get substantial throughput gains vs. single-stream generation, particularly as you increase the number of parallel requests.

Some features like repetition penalties and streaming outputs aren’t supported yet, but I’ll look into putting up a batch_generate PR for mlx_lm if I can get it to a point where it’d be non-breaking. In the meantime, it should be easy to add other models by copying the architecture file(s) from mlx_lm/models into mlx_parallm/models and replacing any KVCache references with BatchedKVCache.
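For reference, the batched cache is conceptually just the per-layer KV cache with a leading batch dimension: keys and values are preallocated as (batch_size, n_kv_heads, seq_len, head_dim) and grown in chunks as decoding proceeds. Something along these lines – a sketch modeled on mlx_lm’s KVCache, not the exact mlx_parallm implementation:

import mlx.core as mx

class BatchedKVCache:
    # Per-layer KV cache with a leading batch dimension (sketch only).
    def __init__(self, head_dim, n_kv_heads, batch_size):
        self.head_dim = head_dim
        self.n_kv_heads = n_kv_heads
        self.batch_size = batch_size
        self.keys = None
        self.values = None
        self.offset = 0      # number of tokens cached so far
        self.step = 256      # grow the buffers in chunks rather than every token

    def update_and_fetch(self, keys, values):
        # keys, values: (batch_size, n_kv_heads, num_new_tokens, head_dim)
        prev = self.offset
        if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            shape = (self.batch_size, self.n_kv_heads, n_steps * self.step, self.head_dim)
            new_k = mx.zeros(shape, keys.dtype)
            new_v = mx.zeros(shape, values.dtype)
            if self.keys is None:
                self.keys, self.values = new_k, new_v
            else:
                self.keys = mx.concatenate([self.keys[..., :prev, :], new_k], axis=2)
                self.values = mx.concatenate([self.values[..., :prev, :], new_v], axis=2)
        self.offset += keys.shape[2]
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

Since the attention layers already pass (batch, heads, tokens, head_dim) tensors into cache.update_and_fetch, the batch dimension mostly just flows through, which is why porting a model is largely a matter of swapping the cache class.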