Coverage for src/flag_gems/experimental_ops/glu.py: 0%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import math # noqa: F401 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7 

@triton.jit
def glu_kernel(
    x_ptr,  # *Pointer* to input tensor data (flattened, contiguous).
    y_ptr,  # *Pointer* to output tensor data (flattened, contiguous).
    n_out_elements,  # Number of elements in the output tensor.
    inner_size,  # Product of sizes of dims after 'dim' in output shape.
    half_size,  # Size along 'dim' in output shape (i.e., original dim size // 2).
    outer_elems,  # Number of elements per 'outer' slice in the input: (2*half_size)*inner_size.
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise GLU over a contiguous layout: the input is split in two
    # halves along 'dim' and y = first_half * sigmoid(second_half).
    # Each program instance handles one BLOCK_SIZE-wide chunk of output
    # indices; grid size is cdiv(n_out_elements, BLOCK_SIZE).
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_out_elements  # guards the ragged final block

    idx = offsets
    s = half_size
    inner = inner_size
    outer_inc = outer_elems

    # Map each output index to the corresponding input indices.
    # For contiguous tensors:
    #   - output shape: [..., s, ...]; n_out = outer * s * inner
    #   - input shape:  [..., 2*s, ...]
    # Let:
    #   o = idx // (s * inner)   (outer coordinate)
    #   r = idx %  (s * inner)
    #   d = r // inner           (coordinate along 'dim')
    #   i = r %  inner           (inner coordinate)
    # Then:
    #   x_left_index  = o * (2*s*inner) + d * inner + i
    #   x_right_index = x_left_index + s * inner
    denom = s * inner
    o = idx // denom
    r = idx % denom
    d = r // inner
    i = r % inner

    x_index = o * outer_inc + d * inner + i
    # Second half along 'dim' provides the gate values.
    gate_index = x_index + s * inner

    x_val = tl.load(x_ptr + x_index, mask=mask, other=0.0)
    g_val = tl.load(x_ptr + gate_index, mask=mask, other=0.0)

    # Compute in float32 for accuracy, then cast back to the input dtype
    # (input may be float16/bfloat16).
    x_f = x_val.to(tl.float32)
    g_f = g_val.to(tl.float32)
    gate = 1.0 / (1.0 + tl.exp(-g_f))  # sigmoid
    y = x_f * gate
    y_cast = y.to(x_val.dtype)

    tl.store(y_ptr + idx, y_cast, mask=mask)

59 

60 

61def _normalize_dim(dim: int, ndim: int) -> int: 

62 if dim < 0: 

63 dim += ndim 

64 if not (0 <= dim < ndim): 

65 actual_dim = dim - ndim if dim >= ndim else dim 

66 raise IndexError( 

67 f"Dimension out of range (expected to be in range of " 

68 f"[{-ndim}, {ndim - 1}], but got {actual_dim})" 

69 ) 

70 return dim 

71 

72 

73def _check_dtype_supported(dtype: torch.dtype): 

74 if dtype not in (torch.float16, torch.bfloat16, torch.float32): 

75 raise TypeError( 

76 f"Unsupported dtype {dtype}. Supported dtypes are: float16, bfloat16, float32." 

77 ) 

78 

79 

def _glu_launch(x: torch.Tensor, dim: int, out: torch.Tensor = None) -> torch.Tensor:
    """Run the Triton GLU kernel over ``x`` along dimension ``dim``.

    The input is split in two halves along ``dim``; the result is
    ``first_half * sigmoid(second_half)``. When ``out`` is provided it must
    be a contiguous CUDA tensor with the result shape and dtype; otherwise a
    fresh tensor is allocated.

    Raises:
        AssertionError: If ``x`` (or a provided ``out``) is not on CUDA.
        TypeError: If the dtype is not float16/bfloat16/float32.
        IndexError: If ``dim`` is out of range.
        RuntimeError: If the size along ``dim`` is odd, or a provided ``out``
            has the wrong shape/dtype or is non-contiguous.
    """
    if not x.is_cuda:
        raise AssertionError("Input tensor must be on CUDA device.")
    x = x.contiguous()
    _check_dtype_supported(x.dtype)

    rank = x.dim()
    dim = _normalize_dim(dim, rank)
    dim_extent = x.size(dim)
    if dim_extent % 2 != 0:
        raise RuntimeError(
            f"glu: dimension {dim} size must be even, but got {dim_extent}."
        )

    half = dim_extent // 2

    # Output shape: same as input, halved along `dim`.
    result_shape = list(x.shape)
    result_shape[dim] = half

    if out is None:
        out = torch.empty(result_shape, device=x.device, dtype=x.dtype)
    else:
        # Validate the caller-provided output buffer.
        if not out.is_cuda:
            raise AssertionError("Output tensor must be on CUDA device.")
        if tuple(out.shape) != tuple(result_shape):
            raise RuntimeError(
                f"glu_out: provided out has wrong shape. Expected {tuple(result_shape)}, got {tuple(out.shape)}."
            )
        if out.dtype != x.dtype:
            raise RuntimeError(
                f"glu_out: dtype mismatch. out.dtype={out.dtype}, expected {x.dtype}."
            )
        if not out.is_contiguous():
            raise RuntimeError("glu_out: output tensor must be contiguous.")
        out = out.contiguous()

    # For a contiguous layout, `inner` is the product of output dims to the
    # right of `dim` (math.prod of an empty slice is 1).
    inner = math.prod(result_shape[dim + 1 :])

    total_out = out.numel()
    per_outer = (2 * half) * inner  # input elements per 'outer' slice

    block = 1024
    grid = (triton.cdiv(total_out, block),)

    glu_kernel[grid](
        x,
        out,
        total_out,
        inner,
        half,
        per_outer,
        BLOCK_SIZE=block,
    )
    return out

140 

141 

def glu(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Gated Linear Unit.

    Splits ``input`` in two halves along ``dim`` and returns
    ``first_half * sigmoid(second_half)`` as a newly allocated tensor.
    """
    return _glu_launch(input, dim)

144 

145 

def glu_out(
    input: torch.Tensor, dim: int = -1, out: torch.Tensor = None
) -> torch.Tensor:
    """GLU variant that writes the result into a caller-provided ``out``.

    Unlike :func:`glu`, the ``out`` tensor is mandatory here; passing
    ``None`` raises a ``RuntimeError``.
    """
    if out is not None:
        return _glu_launch(input, dim, out=out)
    raise RuntimeError("glu_out: 'out' tensor must be provided.")