Coverage for src/flag_gems/runtime/backend/_ascend/fused/fused_add_rms_norm.py: 0%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import logging 

2import math 

3 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@libentry()
@triton.jit(do_not_specialize=["eps"])
# Fused residual-add + RMS-norm kernel: one program handles one row.
# Computes x = X[row] + R[row], then writes the summed row back into R and
# the RMS-normalized, weight-scaled row (y = x / rms(x) * W) back into X.
def fused_add_rms_norm_kernel(
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weight
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # pid selects the row; advance both base pointers to that row.
    pid = tl.program_id(0)
    X += pid * x_stride_r
    R += pid * r_stride_r

    # Pass 1: accumulate sum((x + r)^2) / N block-wise in fp32.
    _var_base = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # NOTE(review): these loads index X + cols / R + cols directly,
        # ignoring x_stride_c / r_stride_c, while the stores below multiply
        # by the column stride — consistent only when both column strides
        # are 1 (which is what the wrapper passes). Confirm before reusing
        # this kernel with non-unit column strides.
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols, mask, other=0.0).to(tl.float32)
        x += r
        # Dividing each block's contribution by N keeps the accumulator at
        # mean-of-squares scale.
        _var_base += x * x / N
    var = tl.sum(_var_base)

    # Reciprocal root-mean-square; eps guards against division by zero.
    rrms = 1 / tl.sqrt(var + eps)

    # Pass 2: recompute x + r (each iteration loads its columns before any
    # store touches them), then write both outputs.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols, mask, other=0.0).to(tl.float32)
        x += r
        w = tl.load(W + cols, mask, other=0.0)
        # Cast the normalized value back to the input element type BEFORE
        # multiplying by the weight (order matters for reduced-precision
        # dtypes such as fp16/bf16).
        y = (x * rrms).to(X.dtype.element_ty) * w
        # write back to residual and input
        tl.store(R + cols * r_stride_c, x, mask=mask)
        tl.store(X + cols * x_stride_c, y, mask=mask)

55 

56 

def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
    """
    This function performs fused residual addition and RMS normalization **in-place**.
    Both `x` and `residual` tensors will be modified. Use with caution if these tensors
    are reused elsewhere or require gradients.

    After the call, ``residual`` holds ``x + residual`` and ``x`` holds the
    RMS-normalized, weight-scaled sum over the trailing ``normalized_shape``
    dimensions.

    NOTE(review): non-contiguous inputs are first copied via ``.contiguous()``;
    in that case the *returned* tensors carry the results and the caller's
    original tensors are left untouched — rely on the return values.

    Args:
        x: input tensor; trailing dims must match ``normalized_shape``.
        residual: residual tensor, same shape as ``x``.
        normalized_shape: trailing dimensions to normalize over.
        weight: per-element scale of shape ``normalized_shape``.
        eps: epsilon added to the mean square for numerical stability.

    Returns:
        Tuple ``(x, residual)`` of the updated (possibly copied) tensors.
    """
    logger.debug("GEMS_ASCEND FUSED_ADD_RMS_NORM FORWARD")
    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])  # total number of rows to normalize
    N = math.prod(normalized_shape)  # elements per row

    BLOCK_SIZE = min(triton.next_power_of_2(N), 8192)
    x = x.contiguous()
    residual = residual.contiguous()
    weight = weight.contiguous()

    # BUGFIX: the previous code launched min(M, 65535) programs, silently
    # leaving every row past index 65534 un-normalized (and its residual
    # un-updated). The 1-D grid is still capped at 65535 programs, so launch
    # the kernel in row chunks over flattened views instead; each chunk's
    # base pointer advances by row0 * N elements (row stride is N).
    MAX_GRID = 65535
    x_flat = x.view(-1)
    r_flat = residual.view(-1)
    with torch_device_fn.device(x.device):
        for row0 in range(0, M, MAX_GRID):
            rows = min(MAX_GRID, M - row0)
            fused_add_rms_norm_kernel[rows,](
                x_flat[row0 * N :],
                r_flat[row0 * N :],
                weight,
                N,
                1,
                N,
                1,
                N,
                eps,
                BLOCK_SIZE,
            )
    return x, residual

77 return x, residual