Coverage for src/flag_gems/runtime/backend/_ascend/fused/skip_layernorm.py: 0%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_kernel(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    B,  # pointer to the biases
    y_stride_r,  # output pointer increment for moving 1 row
    y_stride_c,  # output pointer increment for moving 1 column
    x_stride_r,  # pointer increment for moving 1 row
    x_stride_c,  # pointer increment for moving 1 column
    r_stride_r,  # pointer increment for moving 1 row in residual
    r_stride_c,  # pointer increment for moving 1 column in residual
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Fused skip-connection + LayerNorm kernel: each program normalizes one row,
    # computing Y[pid] = W * ((X[pid] + R[pid]) - mean) * rstd + B, where mean and
    # rstd are taken over the N columns of the summed row. All arithmetic is done
    # in float32 regardless of the input dtype; the result is cast back to Y's dtype.
    pid = tle.program_id(0)
    # Number of BLOCK_SIZE-wide tiles needed to cover one row of N columns.
    loops = tl.cdiv(N, BLOCK_SIZE)

    # Initialize accumulators for sum of x and sum of squared x (for mean/variance)
    sum_x = tl.zeros((), dtype=tl.float32)  # Explicitly specify as float32
    sum_sq = tl.zeros((), dtype=tl.float32)  # For variance calculation

    # Pointer offsets for current row (based on program ID)
    X += pid * x_stride_r
    R += pid * r_stride_r
    Y += pid * y_stride_r

    # This partitioning is special: need to load entire row data for computation, so N is required
    # First pass: compute sum(x) and sum(x**2) in one traversal (optimized from 2 passes)
    for process in range(loops):
        cols = process * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Load input and residual, compute x = X + residual
        # (masked lanes load 0.0, which contributes nothing to either accumulator)
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        # Accumulate sum of x and sum of x squared
        sum_x += tl.sum(x, axis=0)
        sum_sq += tl.sum(x * x, axis=0)

    # Compute mean and variance from accumulated sums
    mean = sum_x / N
    var = (sum_sq / N) - (mean * mean)  # Equivalent to E[x**2] - (E[x])**2
    rstd = 1 / tl.sqrt(var + eps)  # Reciprocal of standard deviation

    # Second pass: compute final output (y = w * (x-mean)/std + b).
    # x + residual is re-loaded rather than cached: a full row may not fit on chip.
    for process in range(loops):
        cols = process * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Load weights and biases (W and B are assumed contiguous: indexed by cols
        # directly, with no stride arguments — confirm against callers)
        w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
        # Re-load x (X + residual) for normalization
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        # Apply layer norm and linear transformation
        x_hat = (x - mean) * rstd
        y = w * x_hat + b
        # Cast to output dtype and store
        y = y.to(Y.dtype.element_ty)
        tl.store(Y + cols * y_stride_c, y, mask=mask)

80 

81 

class SkipLayerNorm(torch.autograd.Function):
    """Autograd wrapper launching the fused skip-connection + LayerNorm kernel.

    Only the forward pass is implemented; no backward is defined in this file.
    """

    @staticmethod
    def forward(ctx, x, residual, normalized_shape, weight, bias, eps=1e-5):
        """Compute LayerNorm(x + residual) * weight + bias on the trailing dims.

        Args:
            ctx: autograd context (unused — no backward here).
            x: input tensor.
            residual: tensor added to ``x`` before normalization.
            normalized_shape: trailing shape over which statistics are taken.
            weight: per-element scale of size prod(normalized_shape).
            bias: per-element shift of size prod(normalized_shape).
            eps: numerical-stability term added to the variance.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        logger.debug("GEMS_ASCEND SKIP LAYERNORM FORWARD")
        batch_dims = x.ndim - len(normalized_shape)
        # One program per row; clamped to 65535, presumably a grid-size limit on
        # this backend.
        # NOTE(review): rows beyond 65535 would be left untouched in the
        # empty_like() output — confirm inputs never exceed the grid limit.
        M = min(math.prod(x.shape[:batch_dims]), 65535)
        N = math.prod(normalized_shape)
        # Tile width: next power of two covering a row, capped at 4096.
        BLOCK_SIZE = min(triton.next_power_of_2(N), 4096)
        # The kernel is launched with row stride N / column stride 1, so all
        # operands must be contiguous.
        x = x.contiguous()
        residual = residual.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        y = torch.empty_like(x)
        grid = (M,)
        with torch_device_fn.device(x.device):
            skip_layer_norm_kernel[grid](
                y, x, residual, weight, bias, N, 1, N, 1, N, 1, N, eps, BLOCK_SIZE
            )
        return y

100 

101 

def skip_layer_norm(x, residual, normalized_shape, weight, bias, eps=1e-5):
    """Functional entry point: LayerNorm applied to ``x + residual``.

    Thin convenience wrapper around :class:`SkipLayerNorm`; see its ``forward``
    for parameter semantics.
    """
    result = SkipLayerNorm.apply(x, residual, normalized_shape, weight, bias, eps)
    return result