perf(qwen3next): drop q/k/v/a/b contiguous copies in GDN fused_recurrent decode #1349

Open
sufubao wants to merge 1 commit into ModelTC:main from sufubao:perf/fused_recurrent

sufubao (Collaborator) commented Jun 13, 2026

What

The Qwen3-Next / Qwen3.5 gated-delta-rule (GDN) decode path passes q/k/v/a_raw/b_raw into FusedRecurrentFunction, which previously called .contiguous() on each before launching the Triton kernel. These tensors are column views of a single fused projection output, so the copies are pure per-decode-step overhead (extra allocations + copy kernels).

This PR teaches the kernel to read those views directly: each tensor's per-token element stride is passed to the kernel, which addresses token i as base + i * token_stride. A small helper (_ensure_token_strided) derives the stride for the decode [tokens,1,H,D], varlen [1,tokens,H,D], and 2D a/b layouts, and falls back to .contiguous() only when a tensor genuinely can't be addressed that way.
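The stride derivation can be sketched in plain Python over shape and stride tuples (the values that `tensor.shape` and `tensor.stride()` report). This is an illustration under stated assumptions, not the PR's `_ensure_token_strided`: the name `token_stride`, its signature, and the exact addressability rule (head and feature dims packed within each token) are hypothetical.

```python
from typing import Optional, Sequence


def token_stride(shape: Sequence[int], stride: Sequence[int]) -> Optional[int]:
    """Return the per-token element stride, or None if a copy would be needed.

    Hypothetical helper: assumes a tensor is addressable as
    base + i * token_stride when everything inside one token is packed.
    """
    if len(shape) == 2:
        # 2D a/b layout [tokens, H]: heads must be adjacent in memory.
        packed = shape[1] == 1 or stride[1] == 1
        return stride[0] if packed else None

    if len(shape) == 4:
        heads, dim = shape[2], shape[3]
        # Within a token, [H, D] must be laid out as one packed block.
        packed = (dim == 1 or stride[3] == 1) and (heads == 1 or stride[2] == dim)
        if not packed:
            return None
        if shape[1] == 1:
            # Decode layout [tokens, 1, H, D]: tokens live on dim 0.
            return stride[0]
        if shape[0] == 1:
            # Varlen layout [1, tokens, H, D]: tokens live on dim 1.
            return stride[1]

    # Any other layout: the caller would fall back to a contiguous copy.
    return None


# A q view taking 4 of the 12 columns of a fused projection row:
# consecutive tokens are 12 elements apart, not 4.
q_stride = token_stride((16, 1, 2, 2), (12, 4, 2, 1))
```

A `None` result marks the case where the caller would call `.contiguous()` and then use the packed stride instead.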

Correctness

Adds unit_tests/models/qwen3next/test_fused_recurrent_strided.py, which asserts bit-exact parity (torch.equal) of both the output and the written SSM state between the strided and the old contiguous path:

  • decode layout, batch 1 / 2 / 16
  • varlen layout with 2D ssm_state_indices
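The idea behind the parity check can be modelled without a GPU. The sketch below uses a flat Python list as a stand-in for the fused projection buffer and compares strided reads against reads from a packed copy. The function names, the 12-column row, and the q/k/v column offsets are all hypothetical; the actual test compares kernel outputs and SSM state with `torch.equal`.

```python
def read_tokens(buf, base, token_stride, n_tokens, width):
    """Read `width` packed elements per token; token i starts at base + i * token_stride."""
    return [
        buf[base + i * token_stride : base + i * token_stride + width]
        for i in range(n_tokens)
    ]


def contiguous_copy(buf, base, token_stride, n_tokens, width):
    """Model of the old path: gather the view into a freshly packed buffer."""
    packed = []
    for i in range(n_tokens):
        start = base + i * token_stride
        packed.extend(buf[start : start + width])
    return packed


n_tokens, row_width, view_width = 16, 12, 4
fused = list(range(n_tokens * row_width))  # stand-in for the fused projection output

# Assumed column layout: q at columns 0..3, k at 4..7, v at 8..11.
for col_offset in (0, 4, 8):
    strided = read_tokens(fused, col_offset, row_width, n_tokens, view_width)
    packed = contiguous_copy(fused, col_offset, row_width, n_tokens, view_width)
    via_copy = read_tokens(packed, 0, view_width, n_tokens, view_width)
    assert strided == via_copy  # same values, no copy needed
```

Both paths see identical values for every token, which is why exact equality is a reasonable bar for the real test rather than a tolerance-based comparison.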

Also hardens the 2D index-stride handling with explicit contiguous-row asserts (the kernel advances read/write indices with token-stride 1).

Performance

Static decode throughput, Qwen3.5-122B-A10B, TP8 on H200, output_len=256 (decode tok/s = mean of steady-state steps 100/200/255, optimized vs. baseline on the same commit):

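| batch | baseline (tok/s) | optimized (tok/s) | speedup |
|-------|------------------|-------------------|---------|
| 2     | 285.7            | 297.1             | +3.97%  |
| 8     | 989.3            | 1056.7            | +6.82%  |
| 16    | 1788.2           | 1860.8            | +4.06%  |
| 32    | 3026.9           | 3255.7            | +7.56%  |
| 64    | 5185.9           | 5451.6            | +5.12%  |
| 128   | 8423.5           | 8600.6            | +2.10%  |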

Consistent improvement at every batch size (mean ~+5%). The prefill path is unchanged.

Benchmarked at input_len=256 so the full [2..128] batch sweep runs in one model load; GDN decode is sequence-length-independent (fixed-size recurrent state), so this is a faithful decode-throughput proxy.
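For reference, the speedup figures follow from the throughput numbers as (optimized / baseline - 1) * 100. The snippet below recomputes them from the rounded values reported above, so the last digit can differ slightly from the reported speedups, which were presumably derived from unrounded measurements.

```python
# batch size -> (baseline tok/s, optimized tok/s), copied from the results above
rows = {
    2: (285.7, 297.1),
    8: (989.3, 1056.7),
    16: (1788.2, 1860.8),
    32: (3026.9, 3255.7),
    64: (5185.9, 5451.6),
    128: (8423.5, 8600.6),
}

# Relative decode-throughput gain in percent for each batch size.
speedups = {batch: (opt / base - 1) * 100 for batch, (base, opt) in rows.items()}
mean_speedup = sum(speedups.values()) / len(speedups)

for batch, pct in speedups.items():
    print(f"batch {batch:>3}: +{pct:.2f}%")
print(f"mean: +{mean_speedup:.2f}%")
```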

…ent decode

The gated-delta-rule decode path passed q/k/v/a_raw/b_raw to
FusedRecurrentFunction, which copied each to a contiguous buffer before the
Triton kernel. These are column views of one fused projection output, so the
copies are pure per-step overhead.

Instead, pass the per-token element stride of each tensor into the kernel so it
reads the column views directly (base + i * token_stride). A helper derives the
stride for the decode [tokens,1,H,D], varlen [1,tokens,H,D], and 2D a/b layouts,
and falls back to .contiguous() only when a tensor can't be addressed that way.

Adds unit_tests/models/qwen3next/test_fused_recurrent_strided.py asserting
bit-exact (torch.equal) parity of output and SSM state between the strided and
contiguous paths, for decode (bs 1/2/16) and the varlen layout.

Static decode throughput (Qwen3.5-122B-A10B, TP8 / H200, output_len=256):
consistent +2-8% across batch 2..128 (mean ~+5%); prefill path unchanged.

gemini-code-assist Bot (Contributor) left a comment

Code Review

This pull request optimizes the fused recurrent gated delta rule kernel by allowing it to read non-contiguous column views of projection outputs (q, k, v, a_raw, b_raw) directly via per-token strides, eliminating the need for contiguous copies. It introduces helper functions to calculate token strides and adds comprehensive unit tests to verify correctness. The review feedback suggests replacing assert statements with explicit ValueError exceptions for validating tensor strides, ensuring these critical runtime checks are not stripped out when Python is run with optimization flags.

    elif ssm_state_indices.ndim == 1:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
    else:
        assert ssm_state_indices.stride(-1) == 1, "2D ssm_state_indices must have contiguous rows"

medium

Using assert statements for runtime validation of tensor properties can be risky because assertions are stripped out when Python is run with optimization flags (-O). If these checks are bypassed, it could lead to silent correctness issues or out-of-bounds memory accesses in the Triton kernel. It is safer to raise a ValueError instead.

Suggested change
-        assert ssm_state_indices.stride(-1) == 1, "2D ssm_state_indices must have contiguous rows"
+        if ssm_state_indices.stride(-1) != 1:
+            raise ValueError("2D ssm_state_indices must have contiguous rows")
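The behaviour the review describes is easy to reproduce. The sketch below runs the same failing check in a child interpreter, once as an `assert` and once as an explicit `raise`, with and without `-O`. The snippet strings and the `exit_code` helper are hypothetical and use a plain integer in place of a tensor stride.

```python
import subprocess
import sys

# The same failing validation written two ways (stride 2 stands in for a
# non-contiguous row).
ASSERT_CHECK = "stride = 2\nassert stride == 1, 'rows must be contiguous'"
RAISE_CHECK = (
    "stride = 2\n"
    "if stride != 1:\n"
    "    raise ValueError('rows must be contiguous')"
)


def exit_code(snippet: str, optimize: bool) -> int:
    """Run `snippet` in a fresh interpreter, optionally with -O, and return its exit code."""
    flags = ["-O"] if optimize else []
    result = subprocess.run(
        [sys.executable, *flags, "-c", snippet], capture_output=True
    )
    return result.returncode


print("assert, normal:", exit_code(ASSERT_CHECK, optimize=False))
print("assert, -O    :", exit_code(ASSERT_CHECK, optimize=True))
print("raise,  normal:", exit_code(RAISE_CHECK, optimize=False))
print("raise,  -O    :", exit_code(RAISE_CHECK, optimize=True))
```

Only the `assert` variant exits cleanly under `-O`, meaning the invalid stride would reach the kernel unchecked.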

    elif ssm_state_write_indices.ndim == 1:
        stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride(0), 1
    else:
        assert ssm_state_write_indices.stride(-1) == 1, "2D ssm_state_write_indices must have contiguous rows"

medium

Using assert statements for runtime validation of tensor properties can be risky because assertions are stripped out when Python is run with optimization flags (-O). If these checks are bypassed, it could lead to silent correctness issues or out-of-bounds memory accesses in the Triton kernel. It is safer to raise a ValueError instead.

Suggested change
-        assert ssm_state_write_indices.stride(-1) == 1, "2D ssm_state_write_indices must have contiguous rows"
+        if ssm_state_write_indices.stride(-1) != 1:
+            raise ValueError("2D ssm_state_write_indices must have contiguous rows")
