Coverage for src/flag_gems/fused/skip_layernorm.py: 60%

50 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_kernel(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    B,  # pointer to the biases
    y_stride_r,  # row stride of Y
    y_stride_c,  # column stride of Y
    x_stride_r,  # row stride of X
    x_stride_c,  # column stride of X
    r_stride_r,  # row stride of R
    r_stride_c,  # column stride of R
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Fused (x + residual) -> LayerNorm; each program normalizes one row.
    row = tle.program_id(0)
    Y += row * y_stride_r
    X += row * x_stride_r
    R += row * r_stride_r

    cols = tl.arange(0, BLOCK_SIZE)
    col_mask = cols < N

    # Load input and residual, accumulate their sum in fp32 for stability.
    x = tl.load(X + cols * x_stride_c, col_mask, other=0.0).to(tl.float32)
    x += tl.load(R + cols * r_stride_c, col_mask, other=0.0).to(tl.float32)

    # Row mean and biased variance; out-of-range lanes are zeroed so they
    # do not perturb the reductions.
    mean = tl.sum(x, axis=0) / N
    centered = tl.where(col_mask, x - mean, 0.0)
    var = tl.sum(centered * centered, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # Affine transform: y = w * x_hat + b, cast back to the output dtype.
    w = tl.load(W + cols, mask=col_mask, other=0.0).to(tl.float32)
    b = tl.load(B + cols, mask=col_mask, other=0.0).to(tl.float32)
    y = w * ((x - mean) * rstd) + b
    tl.store(Y + cols * y_stride_c, y.to(Y.dtype.element_ty), mask=col_mask)

60 

61 

class SkipLayerNorm(torch.autograd.Function):
    """Autograd entry point for the fused skip-connection + LayerNorm kernel."""

    @staticmethod
    def forward(ctx, x, residual, normalized_shape, weight, bias, eps=1e-5):
        logger.debug("GEMS SKIP LAYERNORM FORWARD")
        # Leading dims form M rows; the normalized dims collapse into N columns.
        batch_ndim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:batch_ndim])
        N = math.prod(normalized_shape)
        BLOCK_SIZE = triton.next_power_of_2(N)

        # Force a contiguous 2D-like layout so strides are simply (N, 1).
        x, residual, weight, bias = (
            t.contiguous() for t in (x, residual, weight, bias)
        )
        y = torch.empty_like(x)

        # One program per row; strides passed as (row, col) = (N, 1) for Y, X, R.
        with torch_device_fn.device(x.device):
            skip_layer_norm_kernel[M,](
                y, x, residual, weight, bias, N, 1, N, 1, N, 1, N, eps, BLOCK_SIZE
            )
        return y

82 

83 

def skip_layer_norm(x, residual, normalized_shape, weight, bias, eps=1e-5):
    """Return LayerNorm(x + residual) computed by the fused Triton kernel."""
    args = (x, residual, normalized_shape, weight, bias, eps)
    return SkipLayerNorm.apply(*args)