Skip to content

Rewrite Triton normalization backward kernel_1 (#499)#546

Merged
jlamypoirier merged 8 commits into
mainfrom
jlp_norm_kernel1_rewrite
Jun 22, 2026
Merged

Rewrite Triton normalization backward kernel_1 (#499)#546
jlamypoirier merged 8 commits into
mainfrom
jlp_norm_kernel1_rewrite

Conversation

@jlamypoirier

@jlamypoirier jlamypoirier commented Jun 18, 2026

Copy link
Copy Markdown
Collaborator

Summary

Closes the backward-pass gap in layer_norm/rms_norm (issue #499). On H100 the Triton backward trailed apex and torch-compiled by ~1.1–1.6× at most hidden sizes — worst on tall-narrow shapes. kernel_1 was both the bottleneck and the source of an oversized partial-reduction buffer for kernel_2.

kernel_1 rewrite

  1. Decouple the register tile from n_cols. A block_size_row × block_size_col tile grid-strides the columns, so occupancy no longer collapses as the hidden size grows. Rows wider than one chunk use a two-pass scheme (reduce the per-row corrections, then re-read to write grad_input and the partials); narrower rows stay single-pass with no re-read.
  2. Bound the partial-reduction work, the way apex does. apex (both the general fused path and the hand-tuned fast path) does not avoid a second reduction kernel — it bounds the number of partial rows to a small constant via row grid-striding. Previously kernel_1 emitted one partial row per block_size_row input rows, so the buffer kernel_2 reduces grew with the row count (4096 rows at 32768×1024 → kernel_2 ran at ~10% of bandwidth). Single-pass now grid-strides the rows with a program count fixed at multi_processor_count × 2, folding many row tiles into one fp32-accumulated partial. The buffer is then independent of the row count. Two waves per SM is the measured knee — one wave starves grad_input latency-hiding; more only re-inflates kernel_2.

Config tuning

An offline sweep of the kernel_1 config space (single-pass threshold, block_size_col, block_size_row, num_warps, num_stages), validated per shape against the prior config, drives the launch heuristic:

  • Single-pass is extended to wide rows. A wide row stays single-pass when a warp-saturated one-row tile spanning the whole row fits in registers — which avoids the two-pass column re-read entirely. It fits up to the block-size cap without bias and half of it with bias (bias roughly doubles live registers per element), and wins once there are enough rows to fill the SMs. This is the main lever for wide hidden sizes.
  • The remaining two-pass path is tuned (wider column chunk, more warps), and num_stages is threaded through the launch.

Parameter-grad partials reduce in fp32 (the store casts to the buffer dtype) — reducing in bf16 measurably degrades the parameter gradients.

Results (H100, bf16)

Backward µs against apex_fast (layer_norm only — apex has no rms_norm variant) and torch-compiled (max-autotune). apex's general fused path is omitted: it loses on every shape (1.5–3× slower on backward). Bold ratio = Triton matches or beats both; ratio is against the faster of the two. Triton match-or-beats on 9/15 LN and 10/15 RMS shapes.

layer_norm

shape ours apex_fast compiled-max ratio
1024×8192 36.2 49.8 47.7 0.76×
2048×4096 32.3 35.8 45.3 0.90×
8192×1024 27.6 30.2 41.4 0.91×
4096×8192 119 147.9 128.7 0.93×
2048×8192 67.7 83.7 71.8 0.94×
16384×1024 49.1 50.7 59.9 0.97×
4096×4096 56.2 56.9 66.7 0.99×
4096×2048 29.0 29.7 29.4 0.99×
32768×1024 89.4 88.6 102.8 1.01×
512×16384 41.5 53.1 39.2 1.06×
8192×2048 51.5 47.8 59.6 1.08×
1024×16384 77.6 88.0 70.3 1.10×
16384×2048 89.4 81.6 90.2 1.10×
2048×16384 148 155.8 127.8 1.16×
8192×4096 97.2 97.7 83.4 1.17×

rms_norm (no apex_fast variant)

shape ours compiled-max ratio
4096×4096 46.5 62.5 0.74×
4096×8192 86.6 117 0.74×
2048×8192 48.2 62.2 0.77×
8192×1024 24.0 30.7 0.78×
2048×4096 26.8 31.6 0.85×
1024×8192 28.1 32.5 0.86×
32768×1024 78.5 86.9 0.90×
16384×1024 43.1 47.4 0.91×
16384×2048 82.0 83.8 0.98×
2048×16384 114 115 0.99×
1024×16384 64.1 62.3 1.03×
8192×4096 83.2 79.6 1.05×
512×16384 34.7 31.9 1.09×
8192×2048 45.9 40.6 1.13×
4096×2048 25.9 22.3 1.16×

The wide hidden sizes improved up to 1.4× from the config tuning (e.g. rms_norm 4096×8192 121→87 µs) by removing the re-read where it can be afforded, with no regression elsewhere. kernel_2 is no longer a factor (2–5 µs on single-pass shapes, was up to 47 µs).

Remaining sub-parity shapes are a mix of (a) narrow shapes (n_cols ≤ 4096) where apex_fast/compiled are very tight, and (b) layer_norm at the widest hidden size (16384), where the bias term spills the wide single-pass tile so the kernel must fall back to the two-pass re-read. Closing (b) would need a bias-aware shared-memory single-pass; it is the natural follow-up.

Forward is at parity across implementations and is unchanged.

Default implementation

Now that the Triton backward is competitive with or faster than apex, NormalizationConfig.implementation = auto resolves to Triton when Triton is enabled (or required, for zero-centered weights) and falls back to PyTorch otherwise — apex is dropped from the auto path. The apex fast/fused implementations stay available via explicit selection.

Benchmark harness

tools/benchmark/triton_kernels:

  • Isolated, cold-L2 backward timing (forward untimed, L2 flushed, then the backward timed) — training-representative. The prior fwd_bwd − fwd number had a warm-L2 confound: the forward left the saved output partly resident in L2, flattering the backward in a way real training never sees.
  • Per-kernel device-time breakdown, so kernel_1 and kernel_2 can be attributed separately.

Validation

tests/layers/ and tests/tools/test_triton_benchmark.py: 733 passed, 27 skipped (H100). Parameter-grad precision is bit-equivalent to the previous kernel (grad_weight rel-rms ≈ 2.8–2.9e-3).


Authored by Claude Opus 4.8 (Claude Code).

🤖 Generated with Claude Code

jlamypoirier and others added 8 commits June 18, 2026 16:14
The backward of `layer_norm`/`rms_norm` trailed apex and torch-compiled by
1.1-1.6x at most hidden sizes, worst on tall-narrow shapes. kernel_1 was the
bottleneck and over-produced grad_weight/grad_bias partials.

kernel_1:
- Decouple the register tile from `n_cols`: a `block_size_row x block_size_col`
  tile grid-strides the columns, so occupancy no longer collapses as hidden size
  grows. Rows wider than one chunk use a two-pass scheme (reduce per-row
  corrections, then re-read to write grad_input and the partials); narrower rows
  stay single pass.
- Bound the partial-reduction work like apex: single pass grid-strides the rows
  with a program count fixed at `multi_processor_count x 2`, folding many row
  tiles into one fp32-accumulated partial. The partial buffer kernel_2 reduces is
  then independent of the row count instead of growing with it (e.g. 4096 -> ~260
  rows at 32768x1024), which was the dominant remaining cost. Two waves per SM is
  the measured knee: one starves grad_input latency-hiding, more only re-inflates
  kernel_2.

Parameter-grad partials reduce in fp32 (the store casts to the buffer dtype);
reducing in bf16 degraded the parameter gradients.

Result (H100, bf16): tall-narrow shapes go from ~1.3-1.6x behind to parity or
better against the fastest alternative (apex_fast / torch-compiled-max), and
apex's general fused path is beaten across the board. Wide hidden sizes
(two-pass) remain ~1.1-1.3x behind, bounded by the column re-read.

Benchmark harness (tools/benchmark/triton_kernels):
- Measure backward in isolation with a cold L2 (forward untimed, L2 flushed, then
  the backward timed), which is training-representative. The prior fwd_bwd-minus-fwd
  number had a warm-L2 confound: the forward left the saved output partly resident,
  flattering the backward in a way real training never sees.
- Add a per-kernel device-time breakdown so kernel_1 and kernel_2 can be attributed
  separately.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Offline-sweep the kernel_1 config space (single-pass threshold, block_size_col,
block_size_row, num_warps, num_stages) per shape, validated against the prior
config, and fold the result into the launch heuristic:

- Extend single-pass to wide rows. A wide row can stay single-pass when a
  warp-saturated one-row tile spanning the whole row fits in registers, which
  avoids the two-pass column re-read. It fits up to the block-size cap without
  bias and half of it with bias (bias roughly doubles live registers per
  element), and wins once there are enough rows to fill the SMs.
- Tune the remaining two-pass path (wider column chunk, more warps).
- Thread num_stages through the launch.

Result (H100, bf16): wide hidden sizes improve up to 1.4x (e.g. rms_norm
4096x8192 121->87us) with no regression elsewhere, by removing the re-read where
it can be afforded. The remaining sub-parity shapes are now the narrow ones,
where apex's per-hidden-size kernel is hard to match.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Now that the Triton backward is competitive with or faster than apex, drop apex
from the `auto` resolution: it picks Triton when Triton is enabled (or required,
for zero-centered weights) and falls back to PyTorch otherwise. The apex `fast`
and `fused` implementations remain available via explicit selection.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The single-pass launch queried the CUDA SM count unconditionally to bound the
program count, which broke `triton_normalization_backward` under
`TRITON_INTERPRET=1` on CPU (no CUDA). The bound is a GPU-occupancy heuristic, so
skip it off-GPU and use one program per tile. Also cover the two-pass path in the
kernel test (it was only exercising single-pass) and drop the now-constant
num_stages from the launch config.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add a (4096, 8192) case — wide single-pass for both bias settings — skipped under
the Triton interpreter where the size is prohibitive. Keep the effective weight
near 1 so the backward's `(output - bias) / weight` recovery stays well
conditioned: random near-zero weights amplify fp32 error and intermittently
diverge from the reference at wide n_cols, unrelated to the kernel.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The cache stored a single SM count and reused the first device's value for every
device. Key it by device index so the `device` argument is honored.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add the `typing.Callable[[], None]` return annotation to the two triton backward
builders, matching their fwd/fwd_bwd siblings and the Variant.backward field.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@jlamypoirier jlamypoirier merged commit 27f4747 into main Jun 22, 2026
3 checks passed
@jlamypoirier jlamypoirier deleted the jlp_norm_kernel1_rewrite branch June 22, 2026 19:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant