Coverage for src/flag_gems/experimental_ops/rmsnorm.py: 13%
46 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1# Generated by KernelGen v1.0
2# Source: Triton
3# Performance: 2.6x vs Native on A100
4# License: Apache-2.0
7import torch
8import triton
9import triton.language as tl
@triton.jit
def rmsnorm_kernel(
    input_ptr,  # *Pointer* to the input tensor flattened to 2D [M, N]
    weight_ptr,  # *Pointer* to the weight tensor [N]
    output_ptr,  # *Pointer* to the output tensor flattened to 2D [M, N]
    M,  # Number of rows (prod of leading dims)
    N,  # Hidden size (last dimension)
    eps,  # Epsilon for numerical stability
    BLOCK_SIZE: tl.constexpr,  # Elements processed per inner-loop iteration
):
    """RMS-normalize one row per program: y = x / sqrt(mean(x^2) + eps) * w.

    The row is walked twice in BLOCK_SIZE chunks: pass 1 accumulates the sum
    of squares, pass 2 applies the normalization and the per-column weight.
    All arithmetic is done in float32 regardless of the input dtype; the final
    store casts back to the output pointer's element type.
    """
    pid = tl.program_id(axis=0)  # each program handles one row
    # Rows are contiguous in memory (host flattens to a contiguous [M, N]),
    # so the row's first element sits at offset pid * N.
    row_start = pid * N

    # First pass: compute sum of squares across the row.
    col_start = 0
    acc = tl.zeros([1], dtype=tl.float32)  # [1]-shaped so `+=` broadcasts cleanly
    while col_start < N:
        offs = col_start + tl.arange(0, BLOCK_SIZE)
        mask = offs < N  # guard the ragged tail when N % BLOCK_SIZE != 0
        # `other=0.0` makes masked-out lanes contribute nothing to the sum.
        x = tl.load(input_ptr + row_start + offs, mask=mask, other=0.0)
        x = x.to(tl.float32)
        acc += tl.sum(x * x, axis=0)
        col_start += BLOCK_SIZE

    mean_sq = acc / N
    inv_rms = 1.0 / tl.sqrt(mean_sq + eps)

    # Second pass: normalize and scale.
    col_start = 0
    while col_start < N:
        offs = col_start + tl.arange(0, BLOCK_SIZE)
        mask = offs < N
        x = tl.load(input_ptr + row_start + offs, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(weight_ptr + offs, mask=mask, other=0.0).to(tl.float32)
        y = x * inv_rms * w
        # Masked store: only in-bounds lanes are written back.
        tl.store(output_ptr + row_start + offs, y, mask=mask)
        col_start += BLOCK_SIZE
def rmsnorm(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Apply RMS normalization over the last dimension of ``input_tensor``.

    Computes ``y = x / sqrt(mean(x^2, dim=-1) + eps) * weight`` via the
    Triton kernel ``rmsnorm_kernel``, one program per row.

    Args:
        input_tensor: CUDA tensor of shape ``(..., hidden_size)``.
        weight: CUDA tensor of shape ``(hidden_size,)``.
        eps: Small constant added to the mean of squares for stability.

    Returns:
        A new tensor with the same shape and dtype as ``input_tensor``.

    Raises:
        AssertionError: If the tensors are not on CUDA or the shapes disagree.
    """
    assert input_tensor.is_cuda and weight.is_cuda, "Tensors must be on CUDA device."
    assert (
        input_tensor.shape[-1] == weight.numel()
    ), "weight must have shape (hidden_size,) matching the last dim of input_tensor."

    hidden_size = input_tensor.shape[-1]
    N = hidden_size
    M = input_tensor.numel() // hidden_size

    # Flatten to 2D with a contiguous last dimension. The output is allocated
    # from the *contiguous* 2D view rather than the original tensor: the old
    # `torch.empty_like(input_tensor).view(M, N)` raised for non-contiguous
    # inputs because empty_like preserves the source's strides and .view()
    # requires contiguity.
    x_2d = input_tensor.contiguous().view(M, N)
    out_2d = torch.empty_like(x_2d)
    w_c = weight.contiguous()

    # Choose a reasonable BLOCK_SIZE: power of two covering N when possible,
    # clamped to [128, 4096].
    def _next_pow2(v: int) -> int:
        return 1 if v <= 1 else 1 << ((v - 1).bit_length())

    BLOCK_SIZE = min(4096, max(128, _next_pow2(N)))
    num_warps = 4 if BLOCK_SIZE <= 1024 else 8

    # One program per row of the flattened input.
    grid = lambda meta: (M,)

    rmsnorm_kernel[grid](
        x_2d,
        w_c,
        out_2d,
        M,
        N,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    # out_2d is contiguous, so viewing back to the caller's shape is always valid.
    return out_2d.view(input_tensor.shape)