Gemma 4 inference sped up with fused kernels
Gemma 4 inference gets up to 17% faster by fusing separate GELU and tanh operations into single kernels, reducing computational overhead across prefill and generation phases.
Gemma 4 model inference runs significantly faster after this optimization. The change replaces sequences of separate elementwise operations — a tanh-approximated GELU followed by a multiplication, and a tanh followed by a scalar multiplication — with fused kernels that compute both steps in a single pass, avoiding an intermediate buffer and an extra trip through memory. Benchmarks show 10–17% improvements during prefill and roughly 5–15% during generation across model sizes from E2B to 31B parameters. These fused operations are particularly impactful in the MLP and Mixture-of-Experts layers, where activation functions dominate the compute budget. The changes are localized to the MLX runner and the Gemma 4 model implementation.
go run cmd/bench/bench.go -model gemma4:XXX-nvfp4 -prompt-tokens 2048 -max-tokens 128 -epochs 5 -warmup 1
┌──────┬────────────────┬──────────┬────────────────┬────────┐
│ Size │ Metric (tok/s) │ Baseline │ New (compiled) │ Δ      │
├──────┼────────────────┼──────────┼────────────────┼────────┤
│ E2B  │ prefill        │ 18,777   │ 21,901         │ +16.6% │
├──────┼────────────────┼──────────┼────────────────┼────────┤
│ E2B  │ gen            │ 153.8    │ 176.9          │ +15.0% │
├──────┼────────────────┼──────────┼────────────────┼────────┤
│ E4B  │ prefill        │ 6,980    │ 8,086          │ +15.8% │
├──────┼────────────────┼──────────┼────────────────┼────────┤
│ E4B  │ gen            │ 99.1     │ 110.6          │ +11.6% │
├──────┼────────────────┼──────────┼────────────────┼────────┤
│ 26B  │ prefill        │ 3,957    │ 4,372          │ +10.5% │
├──────┼────────────────┼──────────┼────────────────┼────────┤
│ 26B  │ gen            │ 101.3    │ 107.5          │ +6.1%  │
├──────┼────────────────┼──────────┼────────────────┼────────┤
│ 31B  │ prefill        │ 531      │ 593            │ +11.7% │
├──────┼────────────────┼──────────┼────────────────┼────────┤
│ 31B  │ gen            │ 21.4     │ 22.4           │ +4.8%  │
└──────┴────────────────┴──────────┴────────────────┴────────┘