Coverage for src/flag_gems/runtime/backend/_arm/ops/rms_norm.py: 0%
3 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1"""
ARM CPU fused_add_rms_norm wrapper.

Wraps the _arm/fused/fused_add_rms_norm.py Triton kernel so it can be used
as a drop-in replacement for flag_gems.fused_add_rms_norm on ARM64 CPU.

Standalone rms_norm (without residual add) was removed: A/B measurement on
Qwen3-1.7B INT8 decode showed no measurable benefit over ATen's native
Qwen3RMSNorm.forward. See _arm/fused/fused_add_rms_norm.py for the note.
10"""
12from flag_gems.runtime.backend._arm.fused.fused_add_rms_norm import (
13 fused_add_rms_norm as _arm_fused_add_rms_norm,
14)
def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
    """Fused residual-add + RMS norm for ARM CPU.

    Stand-in for ``flag_gems.fused_add_rms_norm``: every argument is
    forwarded positionally, unchanged, to the ARM fused kernel wrapper
    ``_arm_fused_add_rms_norm`` and that call's result is returned as-is.

    The operation is in-place:
    ``residual = x + residual; x = rms_norm(residual) * weight``.

    Args:
        x: Input that is overwritten with the normalized, weighted result.
        residual: Residual that is overwritten with ``x + residual``.
        normalized_shape: Passed straight through to the kernel wrapper
            (presumably the trailing dims to normalize over -- confirm
            against the kernel).
        weight: Scale multiplied into the normalized residual.
        eps: Passed straight through; defaults to ``1e-5`` (presumably the
            numerical-stability term of the RMS -- confirm against the
            kernel).

    Returns:
        ``(x, residual)`` as produced by the wrapped kernel call.
    """
    # Keep the call strictly positional so it does not depend on the
    # parameter names used by the wrapped implementation.
    forwarded = (x, residual, normalized_shape, weight, eps)
    return _arm_fused_add_rms_norm(*forwarded)