Coverage for src/flag_gems/runtime/backend/_mthreads/ops/index_add.py: 0%

79 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.runtime import torch_device_fn 

7from flag_gems.utils import dim_compress, libentry 

8from flag_gems.utils import triton_lang_extension as tle 

9 

# Module-level logger, namespaced under the flag_gems mthreads backend
# using the bare module name (last dotted component of __name__).
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rsplit(".", 1)[-1]
)

13 

14 

def cfggen():
    """Generate autotune configurations for the index_add kernel.

    Enumerates (BLOCK_M, BLOCK_N, num_warps) combinations, skipping any
    tile whose element count exceeds 16384, and returns the resulting
    list of triton.Config objects.
    """
    configs = []
    for tile_m in (1, 2, 4, 8, 16):
        for tile_n in (64, 128, 256, 512, 1024, 2048):
            for warp_count in (4, 8, 16):
                # Skip tiles that would be too large for the hardware.
                if tile_m * tile_n > 16384:
                    continue
                configs.append(
                    triton.Config(
                        {"BLOCK_M": tile_m, "BLOCK_N": tile_n},
                        num_warps=warp_count,
                    )
                )
    return configs

28 

29 

@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def index_add_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    src_ptr,
    M,
    N,
    alpha,
    inp_len,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Kernel for index_add operation with autotune.

    After dim_compress, tensors are reshaped so that:
    - inp has shape (M, inp_len) where inp_len is the size of target dimension
    - src has shape (M, N) where N is the size of index

    For each row m and each index position n:
        out[m, index[n]] += alpha * src[m, n]

    The 2-D launch grid tiles rows (axis 0, BLOCK_M) and index positions
    (axis 1, BLOCK_N); each program handles one BLOCK_M x BLOCK_N tile.

    NOTE(review): the update is a non-atomic load/add/store on
    out[m, index[n]]. If `index` contains duplicate values, two lanes (or
    two programs) read the same location before either writes back, so
    contributions are lost instead of accumulated — confirm callers
    guarantee unique indices, or consider tl.atomic_add.
    NOTE(review): index values are used unchecked; anything outside
    [0, inp_len), including negatives, reads/writes out of bounds.
    """
    pid_m = tle.program_id(axis=0)
    pid_n = tle.program_id(axis=1)

    # Calculate row and column offsets; [:, None] / [None, :] broadcast to
    # a (BLOCK_M, BLOCK_N) tile of coordinates.
    rows_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    cols_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]

    # Create masks so partial edge tiles do not touch out-of-range elements
    rows_mask = rows_offset < M
    cols_mask = cols_offset < N
    block_mask = rows_mask & cols_mask

    # Load indices for this block of columns (masked lanes read index 0,
    # but block_mask prevents them from contributing below)
    cur_indices = tl.load(index_ptr + cols_offset, mask=cols_mask, other=0)

    # Calculate offsets into inp/out (which has shape M x inp_len)
    inp_off = rows_offset * inp_len + cur_indices

    # Load current values from input
    cur_inp = tl.load(inp_ptr + inp_off, mask=block_mask, other=0.0)

    # Calculate offsets into src (which has shape M x N)
    src_off = rows_offset * N + cols_offset

    # Load source values
    cur_src = tl.load(src_ptr + src_off, mask=block_mask, other=0.0)

    # Compute: out = inp + alpha * src
    result = cur_inp + alpha * cur_src

    # Store result (non-atomic; see duplicate-index note above)
    tl.store(out_ptr + inp_off, result, mask=block_mask)

87 

88 

def index_add(inp, dim, index, src, alpha=1):
    """
    Optimized index_add for mthreads backend.

    self.index_add_(dim, index, source, alpha=1) -> Tensor

    For a 3-D tensor the output is:
        self[index[i], :, :] += alpha * src[i, :, :]  # if dim == 0
        self[:, index[i], :] += alpha * src[:, i, :]  # if dim == 1
        self[:, :, index[i]] += alpha * src[:, :, i]  # if dim == 2

    Args:
        inp: input tensor (not modified; a copy is returned).
        dim: dimension along which to index (may be negative).
        index: 1-D tensor of indices into `dim`.
        src: tensor of values to add, matching `inp` except along `dim`,
            where its size equals index.numel().
        alpha: scalar multiplier applied to `src` before adding.

    Returns:
        A new tensor equal to `inp` with the indexed updates applied.
    """
    logger.debug("GEMS_MTHREADS INDEX ADD")

    # Make inputs contiguous
    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()

    # Normalize dimension
    dim = dim % inp.ndim
    inp_len = inp.size(dim)
    N = index.numel()

    # Empty index: nothing to add. Return a copy early; this also avoids
    # a ZeroDivisionError in the M computation below.
    if N == 0:
        return inp.clone()

    M = src.numel() // N

    # Move target dim to last position for coalesced memory access
    final_dim = inp.ndim - 1
    if dim != final_dim:
        inp = dim_compress(inp, dim)
        src = dim_compress(src, dim)

    # Clone input for output; the kernel only overwrites indexed positions
    out = inp.clone()

    # Grid is chosen per autotuned block sizes
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )

    with torch_device_fn.device(inp.device):
        index_add_kernel[grid](inp, out, index, src, M, N, alpha, inp_len)

    # Restore original dimension order if needed (dim_compress moved `dim`
    # to the last axis; permute it back to position `dim`)
    if dim != final_dim:
        order = list(range(out.ndim - 1))
        order.insert(dim, final_dim)
        return out.permute(order).contiguous()
    else:
        return out

138 

139 

def index_add_(inp, dim, index, src, alpha=1):
    """
    In-place version of index_add.

    Modifies `inp` so that, along dimension `dim`:
        inp[..., index[i], ...] += alpha * src[..., i, ...]

    Args:
        inp: tensor to update in place.
        dim: dimension along which to index (may be negative).
        index: 1-D tensor of indices into `dim`.
        src: tensor of values to add, matching `inp` except along `dim`,
            where its size equals index.numel().
        alpha: scalar multiplier applied to `src` before adding.

    Returns:
        `inp`, updated in place.
    """
    logger.debug("GEMS_MTHREADS INDEX ADD_")

    # Make index and src contiguous
    index = index.contiguous()
    src = src.contiguous()

    # Normalize dimension
    dim = dim % inp.ndim
    inp_len = inp.size(dim)
    N = index.numel()

    # Empty index: nothing to add. Return early; this also avoids a
    # ZeroDivisionError in the M computation below.
    if N == 0:
        return inp

    M = src.numel() // N

    # Move target dim to last position
    final_dim = inp.ndim - 1

    # Grid is chosen per autotuned block sizes (shared by both branches)
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )

    if dim != final_dim:
        # Need to work on a permuted copy, then permute back and copy in
        inp_work = dim_compress(inp.clone().contiguous(), dim)
        src_work = dim_compress(src, dim)

        with torch_device_fn.device(inp.device):
            # Passing inp_work as both input and output updates it in place
            index_add_kernel[grid](
                inp_work, inp_work, index, src_work, M, N, alpha, inp_len
            )

        # Restore original dimension order and copy back
        order = list(range(inp_work.ndim - 1))
        order.insert(dim, final_dim)
        inp.copy_(inp_work.permute(order).contiguous())
    else:
        # Can work directly on input if already contiguous; otherwise
        # .contiguous() yields a copy we write to and copy back below
        inp_contig = inp.contiguous()

        with torch_device_fn.device(inp.device):
            index_add_kernel[grid](
                inp_contig, inp_contig, index, src, M, N, alpha, inp_len
            )

        # Copy back if input wasn't contiguous (kernel wrote to a copy)
        if not inp.is_contiguous():
            inp.copy_(inp_contig)

    return inp