Coverage for src/flag_gems/experimental_ops/rmsnorm.py: 13%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1# Generated by KernelGen v1.0 

2# Source: Triton 

3# Performance: 2.6x vs Native on A100 

4# License: Apache-2.0 

5 

6 

7import torch 

8import triton 

9import triton.language as tl 

10 

11 

# Row-wise RMSNorm kernel: y = x / sqrt(mean(x^2) + eps) * w.
# Launched with a 1-D grid of M programs; program `pid` handles row `pid` of a
# contiguous row-major [M, N] input and writes the same row of the output.
# x^2 is accumulated in float32 regardless of the input dtype; the float32
# result is passed to tl.store, which converts it to the output pointer's
# element type.
@triton.jit
def rmsnorm_kernel(
    input_ptr,  # pointer to the input, viewed as contiguous [M, N]
    weight_ptr,  # pointer to the weight vector of length N
    output_ptr,  # pointer to the output, viewed as contiguous [M, N]
    M,  # number of rows (not read here; the grid size encodes it)
    N,  # hidden size, i.e. the length of each row
    eps,  # added to mean(x^2) before the square root
    BLOCK_SIZE: tl.constexpr,  # columns processed per loop iteration
):
    pid = tl.program_id(axis=0)  # one program per row
    # Widen to int64 before multiplying: pid * N in int32 overflows once the
    # tensor has >= 2**31 elements, which would address the wrong rows.
    row_start = pid.to(tl.int64) * N
    in_row = input_ptr + row_start
    out_row = output_ptr + row_start
    cols = tl.arange(0, BLOCK_SIZE)

    # Pass 1: per-lane float32 partial sums of x^2, reduced once after the loop.
    acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for col_start in range(0, N, BLOCK_SIZE):
        offs = col_start + cols
        mask = offs < N
        x = tl.load(in_row + offs, mask=mask, other=0.0).to(tl.float32)
        acc += x * x  # masked lanes loaded 0.0 and contribute nothing
    mean_sq = tl.sum(acc, axis=0) / N
    inv_rms = 1.0 / tl.sqrt(mean_sq + eps)

    # Pass 2: reload the row, normalize, scale by the weight and store.
    for col_start in range(0, N, BLOCK_SIZE):
        offs = col_start + cols
        mask = offs < N
        x = tl.load(in_row + offs, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(weight_ptr + offs, mask=mask, other=0.0).to(tl.float32)
        tl.store(out_row + offs, x * inv_rms * w, mask=mask)

49 

50 

def rmsnorm(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    """Apply RMSNorm over the last dimension of ``input_tensor`` with Triton.

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps) * weight`` row by row, one
    kernel program per row of the input flattened to ``[M, hidden_size]``.

    Args:
        input_tensor: CUDA tensor of shape ``(..., hidden_size)``; any strides.
        weight: CUDA tensor with exactly ``hidden_size`` elements, read in
            flattened (contiguous) order.
        eps: Value added to the mean of squares before the square root.

    Returns:
        A new contiguous tensor with the same shape and dtype as
        ``input_tensor``.

    Raises:
        AssertionError: if either tensor is not on a CUDA device, or if
            ``weight`` does not have ``hidden_size`` elements.
    """
    assert input_tensor.is_cuda and weight.is_cuda, "Tensors must be on CUDA device."
    assert (
        input_tensor.shape[-1] == weight.numel()
    ), "weight must have shape (hidden_size,) matching the last dim of input_tensor."

    # Nothing to normalize. Returning here also avoids dividing by a zero
    # hidden size below and launching a zero-sized grid.
    if input_tensor.numel() == 0:
        return torch.empty_like(input_tensor)

    N = input_tensor.shape[-1]
    M = input_tensor.numel() // N

    # The kernel addresses element (row, col) as row * N + col, so both the
    # input and the output must be contiguous row-major buffers. Allocate the
    # output from the contiguous 2-D input rather than from the (possibly
    # strided) original, whose layout empty_like would otherwise copy.
    x_2d = input_tensor.contiguous().view(M, N)
    out_2d = torch.empty_like(x_2d)
    w_c = weight.contiguous()

    # Power-of-two tile width, clamped to [128, 4096]; wider tiles get more warps.
    BLOCK_SIZE = min(4096, max(128, triton.next_power_of_2(N)))
    num_warps = 4 if BLOCK_SIZE <= 1024 else 8

    rmsnorm_kernel[(M,)](
        x_2d,
        w_c,
        out_2d,
        M,
        N,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return out_2d.view(input_tensor.shape)