Coverage for src/flag_gems/experimental_ops/rmsnorm.py: 15%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1# Generated by KernelGen v1.0 

2# Source: Triton 

3# Performance: 2.6x vs Native on A100 

4# License: Apache-2.0 

5 

6 

7import torch 

8import triton 

9import triton.language as tl 

10 

11 

@triton.jit
def rmsnorm(
    input_ptr,  # *Pointer* to the input tensor flattened to 2D [M, N]
    weight_ptr,  # *Pointer* to the weight tensor [N]
    output_ptr,  # *Pointer* to the output tensor flattened to 2D [M, N]
    M,  # Number of rows (prod of leading dims)
    N,  # Hidden size (last dimension)
    eps,  # Epsilon for numerical stability
    BLOCK_SIZE: tl.constexpr,  # Tile width; each pass walks the row in steps of this size
):
    # RMSNorm over the last dimension: y = x / sqrt(mean(x*x) + eps) * weight.
    # One program instance handles exactly one row; rows are assumed contiguous
    # in memory (row `pid` starts at flat offset pid * N), which the Python
    # wrapper below guarantees via .contiguous().
    pid = tl.program_id(axis=0)  # each program handles one row
    row_start = pid * N

    # First pass: accumulate the sum of squares across the row.
    # Accumulation is done in float32 regardless of the input dtype, so
    # fp16/bf16 inputs do not lose precision in the reduction.
    col_start = 0
    acc = tl.zeros([1], dtype=tl.float32)
    while col_start < N:
        offs = col_start + tl.arange(0, BLOCK_SIZE)
        mask = offs < N  # guards the tail tile when N % BLOCK_SIZE != 0
        x = tl.load(input_ptr + row_start + offs, mask=mask, other=0.0)
        x = x.to(tl.float32)
        acc += tl.sum(x * x, axis=0)
        col_start += BLOCK_SIZE

    mean_sq = acc / N
    inv_rms = 1.0 / tl.sqrt(mean_sq + eps)

    # Second pass: re-read the row, normalize, and scale by the weight.
    col_start = 0
    while col_start < N:
        offs = col_start + tl.arange(0, BLOCK_SIZE)
        mask = offs < N
        x = tl.load(input_ptr + row_start + offs, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(weight_ptr + offs, mask=mask, other=0.0).to(tl.float32)
        y = x * inv_rms * w
        # tl.store converts y (float32) back to the output pointer's element type.
        tl.store(output_ptr + row_start + offs, y, mask=mask)
        col_start += BLOCK_SIZE

49 

50 

# Keep a handle to the Triton kernel before defining the Python wrapper with the
# same name below; the wrapper launches the kernel through this alias.
rmsnorm_kernel = rmsnorm

53 

54 

def rmsnorm(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    """Apply RMS normalization over the last dimension of ``input_tensor``.

    Computes ``x / sqrt(mean(x*x, dim=-1) + eps) * weight`` using the Triton
    kernel above. All leading dimensions are flattened into rows and each row
    is normalized independently.

    Args:
        input_tensor: CUDA tensor of shape ``(..., hidden_size)``. Non-contiguous
            inputs are copied to contiguous memory before the launch.
        weight: CUDA tensor with ``hidden_size`` elements (the learned scale).
        eps: Stability constant added to the mean of squares.

    Returns:
        A new tensor with the same shape and dtype as ``input_tensor``.
    """
    assert input_tensor.is_cuda and weight.is_cuda, "Tensors must be on CUDA device."
    assert (
        input_tensor.shape[-1] == weight.numel()
    ), "weight must have shape (hidden_size,) matching the last dim of input_tensor."

    x = input_tensor
    w = weight
    hidden_size = x.shape[-1]

    # Nothing to normalize; also avoids a zero-size grid launch and the
    # M = numel() // 0 error when a degenerate dimension is present.
    if x.numel() == 0:
        return torch.empty_like(x)

    M = x.numel() // hidden_size
    N = hidden_size

    # Flatten to 2D with rows contiguous in memory, as the kernel assumes.
    x_2d = x.contiguous().view(M, N)
    # BUG FIX: allocate the output from the contiguous 2D view, NOT from `x`.
    # The old `torch.empty_like(x).view(M, N)` raised for non-contiguous inputs
    # (e.g. a transposed view), because empty_like preserves the input's strides
    # and .view() then fails — even though the input path already handled
    # non-contiguity via .contiguous().
    out_2d = torch.empty_like(x_2d)
    w_c = w.contiguous()

    # Choose a reasonable BLOCK_SIZE: power of two, clamped to [128, 4096].
    def next_pow2(v: int) -> int:
        return 1 if v <= 1 else 1 << ((v - 1).bit_length())

    BLOCK_SIZE = min(4096, max(128, next_pow2(N)))
    num_warps = 4 if BLOCK_SIZE <= 1024 else 8  # wider tiles get more warps

    # One program per row.
    rmsnorm_kernel[(M,)](
        x_2d,
        w_c,
        out_2d,
        M,
        N,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    # out_2d is contiguous, so view() back to the original shape always succeeds.
    return out_2d.view(x.shape)