
Training an LLM in Swift, Part 1: Taking Matrix Mult from Gflop/s to Tflop/s

Michael Sintim-Koree · May 2026

This project started from a desire to understand LLM training from the inside. Not calling a framework, not wrapping PyTorch, but actually implementing it. Swift seemed like the right language: it's fast, it runs on Apple Silicon, and the Metal compute API is genuinely approachable once you stop being intimidated by the shader syntax.

The first thing you hit is matrix multiplication. Everything in a transformer runs on matmul. Attention, the feed-forward layers, the output projection — all of it. If your matmul is slow, the rest is academic. So before writing a single line of the actual training loop, it's worth spending real time getting matrix multiplication fast enough to justify the rest of the project.

This is part one. It covers the journey from a naive implementation doing under 1 Gflop/s to a Metal compute shader hitting the low Tflop/s range on an M2 Max. Part two will cover autograd. But you can't get to autograd if your forward pass takes twenty minutes.


Why matmul performance is the whole game

A transformer forward pass for a modest model, say a 1B parameter architecture with sequence length 2048, involves hundreds of matrix multiplications per training step. The attention block alone does four weight matmuls per layer (the Q, K, and V projections plus the output projection), and two more activation-by-activation matmuls for QK^T and the weighted sum over V. The feed-forward block does two more weight matmuls per layer. At 24 layers, that's 144 weight matmuls per forward pass before you count the backward pass, which needs two matmuls for every forward one and so costs roughly twice as much again.
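A quick tally makes the count concrete. This is a sketch under the paragraph's own assumptions: six weight matmuls per layer (four in attention, two in the feed-forward block), 24 layers, and two backward matmuls per forward matmul (one for the input gradient, one for the weight gradient). It ignores the two activation-by-activation matmuls inside attention; the type name is illustrative.

```swift
// Counts weight matmuls for a transformer stack, under the assumptions above.
struct MatmulTally {
    let layers: Int
    let attentionMatmulsPerLayer = 4   // Q, K, V, and output projections
    let feedForwardMatmulsPerLayer = 2 // up and down projections

    var forward: Int {
        layers * (attentionMatmulsPerLayer + feedForwardMatmulsPerLayer)
    }

    // Each forward matmul Y = X·W needs dX = dY·Wᵀ and dW = Xᵀ·dY in backward.
    var backward: Int { 2 * forward }

    var trainingStep: Int { forward + backward }
}

let tally = MatmulTally(layers: 24)
print(tally.forward, tally.backward, tally.trainingStep) // 144 288 432
```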

The M2 Max has a theoretical peak of around 13.6 Tflop/s for FP32 on the GPU. Hitting even 30% utilization on that — roughly 4 Tflop/s — would make a Swift training loop viable for smaller models. Hitting 10% would not. The gap between a naive implementation and a fast one isn't a constant factor. It's two to three orders of magnitude.


Starting point: naive Swift on the CPU

The naive implementation is the obvious triple nested loop. For matrices A (M×K) and B (K×N) producing C (M×N):

```swift
// A is M×K, B is K×N, C is M×N, all stored row-major as flat [Float] arrays.
for i in 0..<M {
    for j in 0..<N {
        var sum: Float = 0
        for k in 0..<K {
            sum += A[i*K + k] * B[k*N + j]
        }
        C[i*N + j] = sum
    }
}
```

On a 1024×1024 matrix multiply, this runs at roughly 0.8 Gflop/s on an M2 Max CPU core. The theoretical peak for a single P-core is on the order of 100 Gflop/s for NEON-vectorized FP32 (four 128-bit FMA pipelines at roughly 3.5 GHz). We're leaving more than 99% of available compute on the floor.
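To reproduce numbers like these you need two things: the flop count of a matmul (2·M·N·K, one multiply and one add per inner-loop iteration) and a monotonic timer. A minimal sketch, assuming Dispatch's DispatchTime for timing and small sizes so it runs quickly; the function names are illustrative, not the author's harness.

```swift
import Dispatch

// A (m×k) · B (k×n) performs m·n·k multiply-adds, i.e. 2·m·n·k flops.
func gflops(m: Int, n: Int, k: Int, seconds: Double) -> Double {
    Double(2 * m * n * k) / seconds / 1e9
}

// The naive triple loop, wrapped as a function over row-major arrays.
func matmulNaive(_ a: [Float], _ b: [Float], m: Int, n: Int, k: Int) -> [Float] {
    var c = [Float](repeating: 0, count: m * n)
    for i in 0..<m {
        for j in 0..<n {
            var sum: Float = 0
            for p in 0..<k {
                sum += a[i * k + p] * b[p * n + j]  // strides through B by n floats
            }
            c[i * n + j] = sum
        }
    }
    return c
}

// Time one square multiply and report Gflop/s.
func benchmarkNaive(size: Int) -> Double {
    let a = [Float](repeating: 1, count: size * size)
    let b = [Float](repeating: 2, count: size * size)
    let start = DispatchTime.now().uptimeNanoseconds
    let c = matmulNaive(a, b, m: size, n: size, k: size)
    let elapsed = Double(DispatchTime.now().uptimeNanoseconds - start) / 1e9
    precondition(c[0] == Float(2 * size), "sanity check; also keeps the result live")
    return gflops(m: size, n: size, k: size, seconds: elapsed)
}

print("naive 256×256:", benchmarkNaive(size: 256), "Gflop/s")
```

Build with -O before trusting the numbers; a debug build spends most of its time on bounds checks.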

The reason is cache behavior. The inner loop on B[k*N + j] strides through memory with step N — that's a stride of 4096 bytes for a 1024-wide matrix of Float32. Every access is effectively a cache miss. The CPU has to wait for memory on every inner loop iteration, and the arithmetic units are idle almost the entire time.


Step one: transposing B

The single cheapest improvement is transposing B before the multiply, so the inner loop accesses both A and B sequentially. Instead of striding through columns of B, you walk through rows of B^T, which are contiguous in memory. The loop body doesn't change — the access pattern does.

This alone takes the same 1024×1024 multiply from 0.8 Gflop/s to roughly 8 Gflop/s. A 10x improvement from a layout change. The arithmetic didn't change. The hardware didn't change. The compiler didn't change. Memory access pattern is everything.

You pay a transpose cost upfront, but for weight matrices you can simply keep them stored transposed. The optimizer update is elementwise, so it doesn't care about layout, and the one-time cost amortizes to nothing over a training run.
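A minimal sketch of the transposed layout, using row-major [Float] arrays; the function names are illustrative.

```swift
// Transpose a row-major rows×cols matrix into a row-major cols×rows matrix.
func transpose(_ b: [Float], rows: Int, cols: Int) -> [Float] {
    var t = [Float](repeating: 0, count: rows * cols)
    for r in 0..<rows {
        for c in 0..<cols {
            t[c * rows + r] = b[r * cols + c]
        }
    }
    return t
}

// C = A·B, given bT = Bᵀ stored row-major as n×k.
// Both inner-loop reads now walk contiguous memory.
func matmulTransposedB(_ a: [Float], _ bT: [Float], m: Int, n: Int, k: Int) -> [Float] {
    var c = [Float](repeating: 0, count: m * n)
    for i in 0..<m {
        for j in 0..<n {
            var sum: Float = 0
            for p in 0..<k {
                sum += a[i * k + p] * bT[j * k + p]
            }
            c[i * n + j] = sum
        }
    }
    return c
}

let b: [Float] = [7, 8, 9, 10, 11, 12]  // 3×2
let c = matmulTransposedB([1, 2, 3, 4, 5, 6], transpose(b, rows: 3, cols: 2), m: 2, n: 2, k: 3)
print(c)  // [58.0, 64.0, 139.0, 154.0]
```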


Step two: tiling for cache reuse

Transposing B helps but doesn't fully solve cache utilization. For large matrices, the rows of A and B^T still don't fit in L1 or L2 cache simultaneously. The fix is tiling: process the matrices in blocks that do fit, maximizing reuse before eviction.

The standard approach is a three-level tiled loop. Choose a tile size so the blocks you're working on stay cache-resident: a 64×64 tile of Float32 values is 16KB, so a tile each of A, B, and C fits comfortably in a P-core's 128KB L1 data cache, with the large shared L2 behind it. Inside each tile, you can unroll the inner loop by hand or let the compiler unroll and vectorize it under -O (or -Ounchecked, which also drops array bounds checks). With 64×64 tiling on the CPU, the 1024×1024 benchmark moves to around 35–40 Gflop/s, respectable for plain Swift loops but still far below what Accelerate reaches.
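A sketch of the tiled loop under stated assumptions: row-major [Float] storage, a default tile of 64, and i-k-j order inside each tile, so the innermost loop walks a row of B and a row of C contiguously without a separate transpose. That loop order is a choice for this sketch, not necessarily the author's. Edge tiles are clipped with min, so sizes needn't be multiples of the tile.

```swift
// Cache-blocked matmul: C (m×n) = A (m×k) · B (k×n), all row-major.
// With tile = 64, each block of A, B, or C is 64×64 Floats = 16 KB.
func matmulTiled(_ a: [Float], _ b: [Float], m: Int, n: Int, k: Int, tile: Int = 64) -> [Float] {
    var c = [Float](repeating: 0, count: m * n)
    for i0 in stride(from: 0, to: m, by: tile) {
        let iEnd = min(i0 + tile, m)
        for p0 in stride(from: 0, to: k, by: tile) {
            let pEnd = min(p0 + tile, k)
            for j0 in stride(from: 0, to: n, by: tile) {
                let jEnd = min(j0 + tile, n)
                // Update one tile of C from one tile of A and one tile of B.
                for i in i0..<iEnd {
                    for p in p0..<pEnd {
                        let aip = a[i * k + p]
                        // Contiguous walk over a row of B and a row of C,
                        // which the optimizer can vectorize.
                        for j in j0..<jEnd {
                            c[i * n + j] += aip * b[p * n + j]
                        }
                    }
                }
            }
        }
    }
    return c
}
```

As with any of these loops, compile with -O or -Ounchecked; in a debug build the bounds checks dominate and the numbers mean little.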

At this point the CPU path is about as fast as it can get without dropping into assembly or calling Accelerate directly, which defeats the purpose.


Using Accelerate as an honest reference

Before moving to Metal, wiring up cblas_sgemm from Accelerate gives a real ceiling number. On an M2 Max using all P-cores, Accelerate hits around 800–900 Gflop/s on large matrix multiplies. That's not a typo: Accelerate dispatches to AMX, the matrix coprocessor that Apple ships on every Apple Silicon chip but doesn't publicly document. Hand-written CPU code won't beat that. The GPU's theoretical peak is roughly 15x higher again, and getting close to it is the goal.
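A minimal reference wrapper, assuming row-major storage. The argument order (layout, transposes, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc) is standard CBLAS. The function name is hypothetical, and the #if fallback exists only so the sketch compiles on machines without Accelerate. Recent SDKs may warn that this CBLAS interface is deprecated in favor of the one enabled by ACCELERATE_NEW_LAPACK.

```swift
#if canImport(Accelerate)
import Accelerate
#endif

// C = A·B for row-major Float matrices, via cblas_sgemm when Accelerate exists.
func referenceMatmul(_ a: [Float], _ b: [Float], m: Int, n: Int, k: Int) -> [Float] {
    var c = [Float](repeating: 0, count: m * n)
    #if canImport(Accelerate)
    // alpha = 1, beta = 0 gives C = A·B. For row-major data the leading
    // dimension of each matrix is its row length.
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                Int32(m), Int32(n), Int32(k),
                1.0, a, Int32(k),
                b, Int32(n),
                0.0, &c, Int32(n))
    #else
    // Plain loop fallback so the sketch still runs without Accelerate.
    for i in 0..<m {
        for j in 0..<n {
            var sum: Float = 0
            for p in 0..<k { sum += a[i * k + p] * b[p * n + j] }
            c[i * n + j] = sum
        }
    }
    #endif
    return c
}
```

Comparing every hand-written kernel against this on the same inputs catches indexing bugs early, before performance numbers start to matter.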


Moving to Metal: the basics

Metal compute shaders are written in Metal Shading Language, which is a C++14-based specification with GPU extensions and restrictions. You write a kernel function, compile it at app startup or ahead of time, create a MTLComputePipelineState from it, and dispatch it over a grid of thread groups.

The first Metal matmul kernel is straightforward: one thread per output element, each thread computing a full dot product over K. This is the GPU equivalent of the naive CPU triple loop, parallelized across the M×N output elements. On the 1024×1024 benchmark, this runs at around 150 Gflop/s — better than the CPU tiled version, but the M2 Max GPU is theoretically capable of 13.6 Tflop/s, so we're at about 1% utilization.
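A minimal host-plus-kernel sketch of that first GPU version, compiled from source at runtime. The kernel name matmul_naive, the buffer indices, and passing M, N, K through setBytes are choices made for this sketch. Dispatching with dispatchThreads over a grid that isn't a multiple of the threadgroup size relies on non-uniform threadgroup support, which Apple Silicon GPUs have. Error handling is minimal: the function returns an empty array if a Metal object can't be created.

```swift
#if canImport(Metal)
import Metal

// Metal Shading Language source: one thread per output element of C.
let naiveKernelSource = """
#include <metal_stdlib>
using namespace metal;

kernel void matmul_naive(device const float* A [[buffer(0)]],
                         device const float* B [[buffer(1)]],
                         device float*       C [[buffer(2)]],
                         constant uint&      M [[buffer(3)]],
                         constant uint&      N [[buffer(4)]],
                         constant uint&      K [[buffer(5)]],
                         uint2 gid [[thread_position_in_grid]]) {
    uint row = gid.y, col = gid.x;
    if (row >= M || col >= N) return;   // the grid may overhang the matrix
    float sum = 0.0f;
    for (uint k = 0; k < K; ++k) {
        sum += A[row * K + k] * B[k * N + col];
    }
    C[row * N + col] = sum;
}
"""

func gpuMatmulNaive(_ a: [Float], _ b: [Float], m: Int, n: Int, k: Int) throws -> [Float] {
    guard let device = MTLCreateSystemDefaultDevice(),
          let queue = device.makeCommandQueue() else { return [] }
    let library = try device.makeLibrary(source: naiveKernelSource, options: nil)
    guard let function = library.makeFunction(name: "matmul_naive") else { return [] }
    let pipeline = try device.makeComputePipelineState(function: function)

    let floatSize = MemoryLayout<Float>.stride
    guard let bufA = device.makeBuffer(bytes: a, length: a.count * floatSize, options: .storageModeShared),
          let bufB = device.makeBuffer(bytes: b, length: b.count * floatSize, options: .storageModeShared),
          let bufC = device.makeBuffer(length: m * n * floatSize, options: .storageModeShared),
          let commandBuffer = queue.makeCommandBuffer(),
          let encoder = commandBuffer.makeComputeCommandEncoder() else { return [] }

    var mm = UInt32(m), nn = UInt32(n), kk = UInt32(k)
    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(bufA, offset: 0, index: 0)
    encoder.setBuffer(bufB, offset: 0, index: 1)
    encoder.setBuffer(bufC, offset: 0, index: 2)
    encoder.setBytes(&mm, length: MemoryLayout<UInt32>.stride, index: 3)
    encoder.setBytes(&nn, length: MemoryLayout<UInt32>.stride, index: 4)
    encoder.setBytes(&kk, length: MemoryLayout<UInt32>.stride, index: 5)
    // One thread per element of C, 16×16 threads per threadgroup.
    encoder.dispatchThreads(MTLSize(width: n, height: m, depth: 1),
                            threadsPerThreadgroup: MTLSize(width: 16, height: 16, depth: 1))
    encoder.endEncoding()
    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()

    let out = bufC.contents().bindMemory(to: Float.self, capacity: m * n)
    return Array(UnsafeBufferPointer(start: out, count: m * n))
}
#endif
```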

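The problem is arithmetic intensity. Each thread loads K floats from A and K floats from B to compute a single output element: 2K floats (8K bytes) for K multiply-adds (2K flops), which works out to 0.25 flops per byte. The M2 Max pairs 13.6 Tflop/s with roughly 400 GB/s of memory bandwidth, so it needs around 34 flops per byte to stay busy. Even with the GPU's caches absorbing some of the repeated loads, the kernel is memory-bound, not compute-bound.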
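A back-of-the-envelope roofline sketch, assuming the M2 Max's roughly 400 GB/s of unified memory bandwidth and the 13.6 Tflop/s peak quoted earlier, and ignoring cache hits entirely (which is why the measured 150 Gflop/s can beat the naive kernel's cache-free bound of about 100 Gflop/s). A tile size of 1 is the naive kernel; the function names are illustrative.

```swift
// Flops per byte of device-memory traffic for a square-tiled Float32 matmul,
// ignoring caches. Each T×T output tile does 2·T²·K flops and, over K/T steps,
// loads two T×T tiles per step: 8·T·K bytes in total.
func arithmeticIntensity(tile t: Int, k: Int) -> Double {
    let flops = 2.0 * Double(t * t) * Double(k)
    let bytes = 8.0 * Double(t) * Double(k)
    return flops / bytes  // simplifies to t / 4
}

// Flops per byte the GPU needs before compute, not memory, is the limit.
func ridgePoint(peakTflops: Double, bandwidthGBs: Double) -> Double {
    peakTflops * 1e12 / (bandwidthGBs * 1e9)
}

let ridge = ridgePoint(peakTflops: 13.6, bandwidthGBs: 400)  // about 34
for t in [1, 32] {
    print("tile \(t): \(arithmeticIntensity(tile: t, k: 1024)) flops/byte vs ridge \(ridge)")
}
```

The T/4 result is the key takeaway: tiling buys intensity linearly in the tile edge, so a 32×32 tile reaches 8 flops per byte, still short of the ridge.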


Tiling in Metal: threadgroup shared memory

The fix is tiling again, but implemented using threadgroup shared memory: an explicit fast scratchpad local to each threadgroup. The pattern:

  • Assign each threadgroup to an output tile of size TILE×TILE.
  • Threads in the threadgroup cooperatively load a TILE×TILE block of A and a TILE×TILE block of B into threadgroup memory.
  • Synchronize with threadgroup_barrier(mem_flags::mem_threadgroup).
  • Each thread accumulates its partial dot product from the loaded tiles, synchronizes again so no thread overwrites a tile another thread is still reading, then advances to the next K-tile and repeats until K is exhausted.
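The steps above can be sketched as a Metal kernel held in a Swift string. Assumptions for this sketch: TILE is 16 (256 threads per threadgroup) rather than the article's 32, and it is dispatched with dispatchThreadgroups so every threadgroup is a full TILE×TILE block; the 32×32 variant is the same code with TILE set to 32, provided the pipeline's maxTotalThreadsPerThreadgroup allows 1024 threads. Out-of-range threads still take part in the loads and both barriers instead of returning early, because every thread in the threadgroup must reach each threadgroup_barrier. The kernel and helper names are hypothetical.

```swift
#if canImport(Metal)
import Metal
#endif

// MSL source for the threadgroup-tiled kernel.
// Dispatch with threadsPerThreadgroup = (TILE, TILE, 1) and
// threadgroups = (ceil(N / TILE), ceil(M / TILE), 1).
let tiledKernelSource = """
#include <metal_stdlib>
using namespace metal;

#define TILE 16

kernel void matmul_tiled(device const float* A [[buffer(0)]],
                         device const float* B [[buffer(1)]],
                         device float*       C [[buffer(2)]],
                         constant uint&      M [[buffer(3)]],
                         constant uint&      N [[buffer(4)]],
                         constant uint&      K [[buffer(5)]],
                         uint2 tid [[thread_position_in_threadgroup]],
                         uint2 gid [[thread_position_in_grid]]) {
    threadgroup float tileA[TILE][TILE];
    threadgroup float tileB[TILE][TILE];
    uint row = gid.y, col = gid.x;
    float sum = 0.0f;

    for (uint k0 = 0; k0 < K; k0 += TILE) {
        // Cooperative load: one element of each tile per thread, zero-filled past the edges.
        uint aCol = k0 + tid.x, bRow = k0 + tid.y;
        tileA[tid.y][tid.x] = (row < M && aCol < K) ? A[row * K + aCol] : 0.0f;
        tileB[tid.y][tid.x] = (bRow < K && col < N) ? B[bRow * N + col] : 0.0f;
        threadgroup_barrier(mem_flags::mem_threadgroup);  // both tiles fully loaded

        for (uint k = 0; k < TILE; ++k) {
            sum += tileA[tid.y][k] * tileB[k][tid.x];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);  // all reads done before the next load
    }
    // Only now do out-of-range threads skip their write.
    if (row < M && col < N) {
        C[row * N + col] = sum;
    }
}
"""

#if canImport(Metal)
// Compile the kernel into a pipeline state; error handling is left to the caller.
func makeTiledPipeline(device: MTLDevice) throws -> MTLComputePipelineState? {
    let library = try device.makeLibrary(source: tiledKernelSource, options: nil)
    guard let fn = library.makeFunction(name: "matmul_tiled") else { return nil }
    return try device.makeComputePipelineState(function: fn)
}
#endif
```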

With a 32×32 tile on Metal, the 1024×1024 benchmark hits around 1.8 Tflop/s. That's a 12x improvement over the naive GPU kernel. Each value loaded from device memory now serves 32 threads instead of 1, so arithmetic intensity rises from 0.25 to about 8 flops per byte (T/4 for a T×T tile). That's still below the roughly 34 flops per byte the GPU needs to be compute-bound, which is why there's more to gain.


Register tiling and simdgroup_matrix: where the hardware actually wants to run

1.8 Tflop/s is usable, but there's more on the table. The next level is register-level tiling: each thread computes a small output submatrix instead of a single element, reusing loaded values across multiple output positions. If each thread computes a 4×4 output block, each step along K loads 4 values of A and 4 of B into registers and performs 16 multiply-adds: 2 multiply-adds per loaded value instead of 0.5.
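The reuse arithmetic generalizes: a thread computing an r×c block of C loads r values of A and c values of B per step along K and performs r·c multiply-adds. A tiny sketch; the function name is illustrative.

```swift
// Multiply-adds per register load for a thread that owns an r×c block of C.
// Per step along K: loads r values of A and c values of B, does r·c FMAs.
func multiplyAddsPerLoad(rows r: Int, cols c: Int) -> Double {
    Double(r * c) / Double(r + c)
}

print(multiplyAddsPerLoad(rows: 1, cols: 1))  // 0.5: one output per thread
print(multiplyAddsPerLoad(rows: 4, cols: 4))  // 2.0: 16 FMAs over 8 loads
print(multiplyAddsPerLoad(rows: 8, cols: 8))  // 4.0
```

Reuse grows with the block size, but so does register pressure per thread, which eventually reduces how many threads the GPU can keep in flight; that trade-off is what caps the block size in practice.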

Metal also exposes SIMD groups (32 threads that execute in lockstep, the equivalent of a CUDA warp) and simdgroup_matrix types that let a whole SIMD group cooperate on one small matrix. On Apple Silicon, simdgroup_matrix<float, 8, 8> lets you load 8×8 tiles with simdgroup_load and multiply-accumulate them with simdgroup_multiply_accumulate, with the data spread across the group's registers in a layout the GPU executes efficiently. This is the approach that finally makes the kernel feel like it's running with the hardware rather than against it.

Using simdgroup_matrix with a 64×64 threadgroup tile and register tiling within each SIMD group, the 1024×1024 benchmark moves to around 4.2 Tflop/s on an M2 Max. That's 31% of theoretical peak — about as good as you can expect from a hand-written kernel without model-specific tuning. MLX, Apple's own ML framework, sits in roughly the same range for this matrix size.
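The simplest possible simdgroup_matrix kernel: one SIMD group per threadgroup, each computing a single 8×8 tile of C straight from device memory. It shows the load, multiply-accumulate, store pattern, not the 64×64 register-tiled kernel measured above, which additionally stages tiles through threadgroup memory and keeps several 8×8 accumulators per SIMD group. Assumptions: M, N, and K are multiples of 8; the kernel is dispatched as threadgroups = (N/8, M/8, 1) with 32 threads per threadgroup; and the GPU is Apple7-family or later (M1, M2). The kernel and helper names are hypothetical.

```swift
#if canImport(Metal)
import Metal
#endif

// MSL source: each 32-thread SIMD group computes one 8×8 tile of C = A·B.
let simdKernelSource = """
#include <metal_stdlib>
#include <metal_simdgroup_matrix>
using namespace metal;

kernel void matmul_simd8x8(device const float* A [[buffer(0)]],
                           device const float* B [[buffer(1)]],
                           device float*       C [[buffer(2)]],
                           constant uint&      M [[buffer(3)]],
                           constant uint&      N [[buffer(4)]],
                           constant uint&      K [[buffer(5)]],
                           uint2 tg [[threadgroup_position_in_grid]]) {
    uint row0 = tg.y * 8, col0 = tg.x * 8;
    if (row0 >= M || col0 >= N) return;  // uniform across the SIMD group

    simdgroup_float8x8 acc = make_filled_simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_float8x8 a;
    simdgroup_float8x8 b;
    for (uint k = 0; k < K; k += 8) {
        // Third argument: elements per row of the source matrix.
        simdgroup_load(a, A + row0 * K + k, K);
        simdgroup_load(b, B + k * N + col0, N);
        simdgroup_multiply_accumulate(acc, a, b, acc);  // acc = a * b + acc
    }
    simdgroup_store(acc, C + row0 * N + col0, N);
}
"""

#if canImport(Metal)
// Compile the kernel into a pipeline state; error handling is left to the caller.
func makeSimdPipeline(device: MTLDevice) throws -> MTLComputePipelineState? {
    let library = try device.makeLibrary(source: simdKernelSource, options: nil)
    guard let fn = library.makeFunction(name: "matmul_simd8x8") else { return nil }
    return try device.makeComputePipelineState(function: fn)
}
#endif
```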


What 4.2 Tflop/s actually buys you

4.2 Tflop/s on matmul means a training step for a small transformer (100M parameters, batch size 8, sequence length 512) takes on the order of half a second to a second: the usual 6 × parameters × tokens estimate puts it at about 2.5 Tflop, before attention-score matmuls and non-matmul work. Not fast enough for serious pretraining. Fast enough to iterate on architecture experiments, verify gradient flow, and validate the autograd implementation against PyTorch.
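A back-of-the-envelope check of that step time, using the common approximation that training costs about 6 flops per parameter per token (2 forward, roughly 4 backward). It assumes the 4.2 Tflop/s is sustained across every matmul and ignores attention-score matmuls and all non-matmul work, so treat the result as a floor rather than a prediction. The function name is illustrative.

```swift
// Rough lower bound on one training step's time from the 6·P·T flop estimate.
func stepSeconds(parameters: Double, batch: Double, sequence: Double, tflops: Double) -> Double {
    let tokens = batch * sequence
    let flops = 6 * parameters * tokens  // about 2.46e12 for the example below
    return flops / (tflops * 1e12)
}

let estimate = stepSeconds(parameters: 100e6, batch: 8, sequence: 512, tflops: 4.2)
print(estimate)  // about 0.585 seconds
```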

That's the actual goal of this project. The aim isn't a production training stack. It's understanding what's happening at every level of a training run — forward pass, loss, backward pass, optimizer step — without framework magic hiding the details. The performance work was necessary to make the feedback loop short enough to be useful, not to compete with JAX.


Three mistakes worth avoiding

The tile size choice matters a lot, and the right answer depends on the specific matrix dimensions you're hitting most often. It's easy to over-optimize for square matrices and then find that attention's QK^T multiply is badly non-square: per head, its inner dimension is the head size (often 64) while its output is sequence length by sequence length, which hurts tile efficiency. Profile your actual shapes before committing to a tile size.
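One cheap first step toward profiling your actual shapes is simply listing them. A sketch with hypothetical names and GPT-2-small-like dimensions chosen for illustration (d_model 768, 12 heads, FFN width 3072); it assumes Q, K, and V come from one fused projection.

```swift
// One matmul as (m×k) · (k×n).
struct MatmulShape: CustomStringConvertible {
    let name: String
    let m: Int, k: Int, n: Int
    var description: String { "\(name): (\(m)×\(k)) · (\(k)×\(n))" }
}

// Matmul shapes a single transformer layer produces in the forward pass.
func layerShapes(batch: Int, seq: Int, dModel: Int, heads: Int, dFF: Int) -> [MatmulShape] {
    let tokens = batch * seq
    let dHead = dModel / heads
    return [
        MatmulShape(name: "fused QKV projection", m: tokens, k: dModel, n: 3 * dModel),
        MatmulShape(name: "QK^T per head", m: seq, k: dHead, n: seq),
        MatmulShape(name: "scores·V per head", m: seq, k: seq, n: dHead),
        MatmulShape(name: "output projection", m: tokens, k: dModel, n: dModel),
        MatmulShape(name: "FFN up", m: tokens, k: dModel, n: dFF),
        MatmulShape(name: "FFN down", m: tokens, k: dFF, n: dModel),
    ]
}

for shape in layerShapes(batch: 8, seq: 512, dModel: 768, heads: 12, dFF: 3072) {
    print(shape)
}
```

The QK^T line is the one that stands out: a 64-wide inner dimension means a kernel tuned for deep K loops spends proportionally more time loading tiles and less time accumulating.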

In practice, reaching for MPSMatrixMultiplication earlier is worth considering. Metal Performance Shaders ships with Apple's own highly optimized matmul, and using it as a reference — and a fallback for shapes your kernel handles poorly — is pragmatic. Writing a kernel from scratch is about understanding, not permanently avoiding the optimized path.
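A minimal sketch of using MPSMatrixMultiplication as the reference path, assuming row-major Float32 data already sits in MTLBuffers. The initializer and encode(commandBuffer:leftMatrix:rightMatrix:resultMatrix:) are the MPS API; the wrapper name and computing rowBytes as columns × 4 (rather than a padded stride) are assumptions of this sketch.

```swift
#if canImport(MetalPerformanceShaders)
import Metal
import MetalPerformanceShaders

// Encode C = A·B with MPS, for row-major Float32 matrices stored in buffers.
func encodeMPSMatmul(device: MTLDevice, commandBuffer: MTLCommandBuffer,
                     a: MTLBuffer, b: MTLBuffer, c: MTLBuffer,
                     m: Int, n: Int, k: Int) {
    let rowBytes = { (cols: Int) in cols * MemoryLayout<Float>.stride }
    let descA = MPSMatrixDescriptor(rows: m, columns: k, rowBytes: rowBytes(k), dataType: .float32)
    let descB = MPSMatrixDescriptor(rows: k, columns: n, rowBytes: rowBytes(n), dataType: .float32)
    let descC = MPSMatrixDescriptor(rows: m, columns: n, rowBytes: rowBytes(n), dataType: .float32)

    // alpha = 1, beta = 0: C = A·B, overwriting whatever C held.
    let matmul = MPSMatrixMultiplication(device: device,
                                         transposeLeft: false, transposeRight: false,
                                         resultRows: m, resultColumns: n, interiorColumns: k,
                                         alpha: 1.0, beta: 0.0)
    matmul.encode(commandBuffer: commandBuffer,
                  leftMatrix: MPSMatrix(buffer: a, descriptor: descA),
                  rightMatrix: MPSMatrix(buffer: b, descriptor: descB),
                  resultMatrix: MPSMatrix(buffer: c, descriptor: descC))
}
#endif
```

Running your own kernel and this one on the same inputs and comparing the outputs element by element gives both a correctness check and a performance target.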

The threadgroup memory bank conflict problem is a common source of lost time. If your tile dimensions cause multiple threads to access the same threadgroup memory bank simultaneously, you get serialized accesses and performance falls off a cliff. Padding the threadgroup arrays by one element in the inner dimension — allocating [TILE][TILE + 1] instead of [TILE][TILE] — resolves most of it. It's worth checking that first before hunting elsewhere for the cause of unexpected slowdowns.
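Why a single element of padding helps can be shown with a toy bank model. This model is an assumption, not documented Apple behavior: 32 banks, each 4 bytes wide, with bank = word index mod 32. With TILE = 32, when 32 lanes each read one element from the same column, an unpadded row pitch of 32 floats puts every lane in the same bank, while the padded pitch of 33 spreads them across all 32. The function name is illustrative.

```swift
// Toy bank model (assumed): 32 banks, 4 bytes each, bank = word index mod 32.
let bankCount = 32

// Worst-case number of SIMD lanes landing in one bank when lane i reads
// element [i][col] of a row-major threadgroup array with the given row pitch.
func maxLanesPerBank(rowPitch: Int, col: Int = 0) -> Int {
    var hits = [Int](repeating: 0, count: bankCount)
    for lane in 0..<32 {
        let wordIndex = lane * rowPitch + col
        hits[wordIndex % bankCount] += 1
    }
    return hits.max()!
}

print(maxLanesPerBank(rowPitch: 32))  // 32: every lane hits one bank, fully serialized
print(maxLanesPerBank(rowPitch: 33))  // 1: the +1 pad spreads lanes across all banks
```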


Where this leaves us going into part two

At this point: a Metal-backed matrix type in Swift with operator overloading for +, -, and *, a matmul kernel hitting ~4 Tflop/s on M2 Max for typical transformer shapes, and a thin Swift wrapper around MTLCommandBuffer that handles encoding and synchronization. About 600 lines of code total.
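A hypothetical, CPU-only sketch of that operator-overloading surface. The article's type is Metal-backed; here * runs a plain loop so the shape checks and operator design can be shown without a GPU. All names are illustrative.

```swift
// Row-major matrix with shape-checked +, -, and * operators.
struct Matrix: Equatable {
    let rows: Int
    let cols: Int
    var data: [Float]

    init(rows: Int, cols: Int, data: [Float]) {
        precondition(data.count == rows * cols, "data size must match shape")
        self.rows = rows
        self.cols = cols
        self.data = data
    }

    static func + (lhs: Matrix, rhs: Matrix) -> Matrix {
        precondition(lhs.rows == rhs.rows && lhs.cols == rhs.cols, "shape mismatch")
        return Matrix(rows: lhs.rows, cols: lhs.cols, data: zip(lhs.data, rhs.data).map { $0 + $1 })
    }

    static func - (lhs: Matrix, rhs: Matrix) -> Matrix {
        precondition(lhs.rows == rhs.rows && lhs.cols == rhs.cols, "shape mismatch")
        return Matrix(rows: lhs.rows, cols: lhs.cols, data: zip(lhs.data, rhs.data).map { $0 - $1 })
    }

    // Matrix product; a GPU-backed version would encode a kernel here instead.
    static func * (lhs: Matrix, rhs: Matrix) -> Matrix {
        precondition(lhs.cols == rhs.rows, "inner dimensions must match")
        var out = [Float](repeating: 0, count: lhs.rows * rhs.cols)
        for i in 0..<lhs.rows {
            for p in 0..<lhs.cols {
                let a = lhs.data[i * lhs.cols + p]
                for j in 0..<rhs.cols {
                    out[i * rhs.cols + j] += a * rhs.data[p * rhs.cols + j]
                }
            }
        }
        return Matrix(rows: lhs.rows, cols: rhs.cols, data: out)
    }
}
```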

Part two is autograd. The challenge there isn't performance — it's correctness. The plan is to check gradients numerically at every layer, compare against PyTorch for each operation, and be systematic about which operations need custom backward implementations versus what falls out of the chain rule automatically. It remains to be seen whether the Swift type system will make this cleaner or more painful than doing it in Python.
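A minimal central-difference gradient checker, sketching the numerical half of that plan (the PyTorch comparison isn't shown). It uses Double because finite differences in Float32 lose too many digits; the function names and the 1e-5 step are assumptions, and a real check would run this per layer against each backward implementation.

```swift
// Numerical gradient via central differences: (f(x + h) - f(x - h)) / 2h.
func numericalGradient(_ f: ([Double]) -> Double, at x: [Double], h: Double = 1e-5) -> [Double] {
    var grad = [Double](repeating: 0, count: x.count)
    for i in x.indices {
        var plus = x
        var minus = x
        plus[i] += h
        minus[i] -= h
        grad[i] = (f(plus) - f(minus)) / (2 * h)
    }
    return grad
}

// Largest relative error between two gradients; tiny values suggest
// the analytic backward pass agrees with the numerical one.
func maxRelativeError(_ a: [Double], _ b: [Double]) -> Double {
    zip(a, b).map { abs($0 - $1) / max(abs($0) + abs($1), 1e-12) }.max() ?? 0
}

// Example: f(x) = sum of x_i squared, whose analytic gradient is 2x.
let x: [Double] = [0.5, -1.25, 3.0]
let f: ([Double]) -> Double = { v in v.reduce(0) { acc, xi in acc + xi * xi } }
let analytic = x.map { 2 * $0 }
print(maxRelativeError(analytic, numericalGradient(f, at: x)))
```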


The bank conflict issue and the non-square QK^T shapes are two things that seem to cost the most time and that have almost nothing written about them for Metal specifically. If you've hit either of those — or found a better tile size strategy for attention's actual matrix dimensions — it would be genuinely useful to hear how you handled it.