Coverage for src/flag_gems/runtime/backend/_nvidia/fused/fused_add_rms_norm.py: 32%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import logging 

2import math 

3 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10 

11logger = logging.getLogger(__name__) 

12 

13 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def fused_add_rms_norm_kernel(
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Fused residual-add + RMSNorm over one row per program.

    Each program handles a single row of N elements:
      1. x = X[row] + R[row]            (computed in float32)
      2. R[row] = x                      (residual updated in place)
      3. X[row] = rmsnorm(x) * W         (normalized output written in place)

    BLOCK_SIZE must be >= N (one block covers the whole row); lanes past N
    are masked off. Requires eps > -var for a finite rsqrt.
    """
    # Program id selects the row; advance both base pointers to it.
    pid = tle.program_id(0)
    X += pid * x_stride_r
    R += pid * r_stride_r

    # Single-block row: mask off lanes beyond the logical row length N.
    mask = tl.arange(0, BLOCK_SIZE) < N
    cols = tl.arange(0, BLOCK_SIZE)
    # Accumulate in float32 regardless of the storage dtype for accuracy.
    x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
    r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)

    x += r
    # write back to residual
    tl.store(R + cols * r_stride_c, x, mask=mask)

    # Mean of squares; dividing each term by N inside the sum avoids a
    # separate mean step (masked lanes contribute 0.0).
    var = tl.sum(x * x / N, axis=0)
    rrms = 1 / tl.sqrt(var + eps)

    # Weight is loaded at its storage dtype; the normalized value is cast
    # back to X's element type *before* the multiply by w.
    w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0)
    y = (x * rrms).to(X.dtype.element_ty) * w
    # write back to input
    tl.store(X + cols * x_stride_c, y, mask=mask)

48 

49 

def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
    """Fused residual-add followed by RMS normalization.

    Computes ``residual_out = x + residual`` and
    ``x_out = rms_norm(x + residual) * weight``, launching one kernel
    program per row.

    Args:
        x: input tensor whose trailing dims match ``normalized_shape``.
        residual: tensor of the same shape as ``x`` to be added in.
        normalized_shape: shape of the trailing dimensions to normalize over.
        weight: scale tensor of shape ``normalized_shape``.
        eps: small constant added to the mean square for numerical stability.

    Returns:
        Tuple ``(x, residual)``: the normalized output and the updated
        residual sum. Results are written into the (possibly copied, see
        NOTE below) contiguous tensors that are returned.
    """
    # Bug fix: leftover debug print() wrote to stdout on every call; route it
    # through the module logger with lazy %-formatting instead.
    logger.debug(
        "test for mutibackend specific fused op %s", "fused_add_rms_norm"
    )
    # Split the shape into leading (batch) dims and the normalized tail:
    # flatten to M rows of N elements each.
    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])
    N = math.prod(normalized_shape)

    # One block spans an entire row, so round N up to a power of two.
    BLOCK_SIZE = triton.next_power_of_2(N)
    # NOTE(review): .contiguous() copies when the input is non-contiguous,
    # in which case the caller's original tensors are NOT updated in place —
    # only the returned tensors hold the results.
    x = x.contiguous()
    residual = residual.contiguous()
    weight = weight.contiguous()

    with torch_device_fn.device(x.device):
        # Grid of M programs (one per row). Contiguous layout gives row
        # stride N and column stride 1 for both x and residual.
        fused_add_rms_norm_kernel[M,](
            x, residual, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
        )
    return x, residual