Coverage for src/flag_gems/experimental_ops/rmsnorm.py: 15%
47 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1# Generated by KernelGen v1.0
2# Source: Triton
3# Performance: 2.6x vs Native on A100
4# License: Apache-2.0
7import torch
8import triton
9import triton.language as tl
@triton.jit
def rmsnorm(
    input_ptr,  # *Pointer* to the input tensor flattened to 2D [M, N]
    weight_ptr,  # *Pointer* to the weight tensor [N]
    output_ptr,  # *Pointer* to the output tensor flattened to 2D [M, N]
    M,  # Number of rows (prod of leading dims)
    N,  # Hidden size (last dimension)
    eps,  # Epsilon for numerical stability
    BLOCK_SIZE: tl.constexpr,
):
    # RMSNorm over one row: y = x / sqrt(mean(x^2) + eps) * weight.
    # Grid is 1D with one program per row; rows are assumed densely packed
    # (row r starts at element r * N), which the wrapper guarantees by
    # passing a contiguous [M, N] view.
    pid = tl.program_id(axis=0)  # each program handles one row
    row_start = pid * N
    # First pass: compute sum of squares across the row, walking the row in
    # BLOCK_SIZE-wide tiles. Accumulation is in float32 regardless of the
    # input dtype to avoid precision loss for fp16/bf16 inputs.
    col_start = 0
    acc = tl.zeros([1], dtype=tl.float32)
    while col_start < N:
        offs = col_start + tl.arange(0, BLOCK_SIZE)
        mask = offs < N
        # other=0.0 makes masked-out (past-end) lanes contribute nothing to
        # the sum of squares.
        x = tl.load(input_ptr + row_start + offs, mask=mask, other=0.0)
        x = x.to(tl.float32)
        acc += tl.sum(x * x, axis=0)
        col_start += BLOCK_SIZE
    mean_sq = acc / N
    inv_rms = 1.0 / tl.sqrt(mean_sq + eps)
    # Second pass: re-read the row, normalize by the row's inverse RMS, and
    # apply the per-feature weight. inv_rms has shape [1] and broadcasts
    # across the BLOCK_SIZE-wide tiles.
    col_start = 0
    while col_start < N:
        offs = col_start + tl.arange(0, BLOCK_SIZE)
        mask = offs < N
        x = tl.load(input_ptr + row_start + offs, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(weight_ptr + offs, mask=mask, other=0.0).to(tl.float32)
        y = x * inv_rms * w
        # tl.store casts y (float32) back to the output pointer's element type.
        tl.store(output_ptr + row_start + offs, y, mask=mask)
        col_start += BLOCK_SIZE
# Keep a handle to the Triton kernel before defining the Python wrapper with
# the same name; the wrapper below launches the kernel via this alias.
rmsnorm_kernel = rmsnorm
def rmsnorm(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Apply RMS normalization over the last dimension of ``input_tensor``.

    Treats the input as a 2D [M, N] matrix (N = hidden size, M = product of
    the leading dims) and launches the Triton kernel with one program per row.

    Args:
        input_tensor: CUDA tensor of shape (..., hidden_size).
        weight: CUDA tensor of shape (hidden_size,).
        eps: Constant added to the mean of squares for numerical stability.

    Returns:
        A new tensor with the same shape and dtype as ``input_tensor``.

    Raises:
        ValueError: If the tensors are not on CUDA, or ``weight``'s length
            does not match the last dimension of ``input_tensor``.
    """
    # Explicit raises instead of asserts: asserts are stripped under `python -O`.
    if not (input_tensor.is_cuda and weight.is_cuda):
        raise ValueError("Tensors must be on CUDA device.")
    if input_tensor.shape[-1] != weight.numel():
        raise ValueError(
            "weight must have shape (hidden_size,) matching the last dim of input_tensor."
        )

    hidden_size = input_tensor.shape[-1]
    N = hidden_size
    M = input_tensor.numel() // N

    # Ensure contiguous memory along the last dimension for the kernel.
    x_2d = input_tensor.contiguous().view(M, N)
    # BUG FIX: allocate the output from the contiguous 2D view, not from the
    # original tensor. `torch.empty_like(input_tensor)` preserves the input's
    # strides (memory_format=preserve_format), so a subsequent `.view(M, N)`
    # would fail for non-contiguous inputs (e.g. a transposed tensor) that the
    # input path otherwise handles via `.contiguous()`.
    out_2d = torch.empty_like(x_2d)
    w_c = weight.contiguous()

    # Choose a reasonable BLOCK_SIZE (power-of-two clamped to [128, 4096]).
    def next_pow2(v: int) -> int:
        """Smallest power of two >= v (1 for v <= 1)."""
        return 1 if v <= 1 else 1 << ((v - 1).bit_length())

    BLOCK_SIZE = min(4096, max(128, next_pow2(N)))
    # Wider tiles get more warps to keep occupancy up.
    num_warps = 4 if BLOCK_SIZE <= 1024 else 8

    grid = lambda meta: (M,)  # one program per row
    rmsnorm_kernel[grid](
        x_2d,
        w_c,
        out_2d,
        M,
        N,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    # Restore the caller's original shape (out_2d is contiguous, so view is safe).
    return out_2d.view(input_tensor.shape)