June 23, 2024
Recently I’ve been doing some LLM finetuning experiments on my MacBook using MLX, and found that there wasn’t really a great way to take advantage of parallel inference for evaluating outputs locally. For single-stream applications like chat interfaces, this isn’t a big deal – both llama.cpp and MLXServer run quite fast on Apple devices. But if you’re trying to sample a large number of outputs at once, either for evaluating a training run or for “agent-flavored” applications, neither of them really offers a speedup in total throughput (at least from what I’ve been able to test). If you’re on a CUDA machine, you’d use something like vLLM, which is a more “production-grade” solution for achieving high tok/s throughput with parallel requests, but it doesn’t work on a Mac.
The main feature we need to enable this in MLX is batched key-value caching. Borrowing heavily from the existing mlx_lm library, I extended the generate method to make use of a BatchedKVCache object and to allow multiple decoding channels via a batch_generate method. For “small” models like Gemma-2B, this gets you to 1600+ tokens/sec in total throughput on a 128GB M3 Max.
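To give a sense of the idea, here’s a minimal sketch of a batched KV cache in the style of mlx_lm’s KVCache. This is not the actual mlx_parallm implementation – the class body, constructor arguments, and step size are all illustrative – but it shows the core trick: keys and values are stored with a leading batch dimension and grown in fixed-size chunks along the sequence axis, so every sequence in the batch advances together.

import mlx.core as mx

class BatchedKVCache:
    # Illustrative sketch, not the mlx_parallm source. Keys/values are stored
    # as (batch, n_kv_heads, seq_len, head_dim); the only real difference from
    # a single-stream cache is that batch can be > 1.
    def __init__(self, head_dim, n_kv_heads, batch_size, step=256):
        self.head_dim = head_dim
        self.n_kv_heads = n_kv_heads
        self.batch_size = batch_size
        self.step = step      # grow the buffer in chunks of `step` tokens
        self.keys = None
        self.values = None
        self.offset = 0       # tokens cached so far (the same for every sequence)

    def update_and_fetch(self, keys, values):
        # keys/values: (batch, n_kv_heads, num_new_tokens, head_dim)
        prev = self.offset
        num_new = keys.shape[2]
        if self.keys is None or (prev + num_new) > self.keys.shape[2]:
            # allocate (or extend) the buffer in multiples of `step`
            n_steps = (num_new + self.step - 1) // self.step
            shape = (self.batch_size, self.n_kv_heads, n_steps * self.step, self.head_dim)
            new_k = mx.zeros(shape, dtype=keys.dtype)
            new_v = mx.zeros(shape, dtype=values.dtype)
            if self.keys is None:
                self.keys, self.values = new_k, new_v
            else:
                self.keys = mx.concatenate([self.keys[..., :prev, :], new_k], axis=2)
                self.values = mx.concatenate([self.values[..., :prev, :], new_v], axis=2)
        self.offset += num_new
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

Using the actual library looks like this: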
from mlx_parallm.utils import load, generate, batch_generate
# fun trick for generating workloads
import random
import string
capital_letters = string.ascii_uppercase
distinct_pairs = [(a, b) for i, a in enumerate(capital_letters) for b in capital_letters[i + 1:]]
prompt_template = "Think of a real word containing both the letters {l1} and {l2}. Then, say 3 sentences which use the word."
prompts_raw = [prompt_template.format(l1=p[0], l2=p[1]) for p in random.sample(distinct_pairs, 325)]
model, tokenizer = load("google/gemma-1.1-2b-it")
responses = batch_generate(model, tokenizer, prompts=prompts_raw, max_tokens=100, verbose=True, temp=0.0)
The code is available on GitHub as mlx_parallm. I’ve tested with Gemma-2B, Phi-3-mini, and Llama3-8B, all of which get substantial throughput gains vs. single-stream generation, particularly as you increase the number of parallel requests.
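If you want to eyeball that difference yourself, a rough comparison looks something like the following. This is a sketch under a couple of assumptions: that the single-stream generate exported from mlx_parallm.utils mirrors mlx_lm’s generate(model, tokenizer, prompt=..., max_tokens=..., temp=...) signature, and the toy prompts are just placeholders.

import time
from mlx_parallm.utils import load, generate, batch_generate

model, tokenizer = load("google/gemma-1.1-2b-it")
prompts = [f"Write a haiku about the letter {c}." for c in "ABCDEFGH"]  # 8 toy prompts

# single-stream: decode one prompt at a time
t0 = time.time()
for p in prompts:
    generate(model, tokenizer, prompt=p, max_tokens=100, temp=0.0)
single_s = time.time() - t0

# batched: decode all prompts in parallel
t0 = time.time()
batch_generate(model, tokenizer, prompts=prompts, max_tokens=100, temp=0.0)
batch_s = time.time() - t0

print(f"single-stream: {single_s:.1f}s total, batched: {batch_s:.1f}s total")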
Some features like repetition penalties and streaming outputs aren’t supported yet, but I’ll look into putting up a batch_generate PR for mlx_lm if I can get it to a point where it’d be non-breaking. In the meantime, it should be easy to add other models by copying the architecture file(s) from mlx_lm/models into mlx_parallm/models and replacing any KVCache references with BatchedKVCache.
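The reason that port is mostly mechanical is that mlx_lm model files talk to the cache through a single update_and_fetch call inside each attention layer, and that call site doesn’t care about the batch size. Roughly (paraphrasing the pattern with illustrative names, not code copied from any particular model file):

def cached_attention_step(cache, keys, values):
    # keys/values: (batch, n_kv_heads, num_new_tokens, head_dim). The pattern is
    # the same whether `cache` is a single-stream KVCache (batch == 1) or a
    # BatchedKVCache (batch > 1), which is why swapping the cache type is the
    # main change needed when copying a model file over.
    if cache is not None:
        keys, values = cache.update_and_fetch(keys, values)
    return keys, values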