Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/index_add.py: 0%

80 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import dim_compress, libentry 

9 

10logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

11# def cfggen(): 

12# block_m = [1, 2, 4] 

13# block_n = [128, 1024, 2048, 4096] 

14# configs = [ 

15# triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_warps=4) 

16# for m in block_m 

17# for n in block_n 

18# ] 

19# return configs 

20 

21 

@libentry()
# @triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.heuristics(runtime.get_heuristic_config("index_add"))
# @triton.autotune(
#     configs=[], generate_configs="index_add", op_affiliation="cluster", row_sign="M", col_sign="N",
#     key=["M", "N"],
# )
@triton.jit
def index_add_kernel(
    inp,  # pointer: input data to read from; dim-compressed so the indexed dim is last
    inp_cont,  # pointer: output data to write to (the in-place variant passes the same buffer for both)
    index,  # pointer: 1-D integer tensor of N column indices into the last (compressed) dim of `inp`
    src,  # pointer: values to add; dim-compressed to shape (M, N)
    M: tl.constexpr,  # number of rows = src.numel() // N (product of non-indexed dims)
    N: tl.constexpr,  # number of indices = size of `src` along the indexed dim
    alpha,  # scalar multiplier applied to `src` before accumulation
    inp_len,  # size of `inp` along the indexed (last, compressed) dim; row stride of `inp`
    BLOCK_M: tl.constexpr,  # tile height (rows per program)
    BLOCK_N: tl.constexpr,  # tile width (indices per program)
):
    # One program handles a (BLOCK_M, BLOCK_N) tile of the (M, N) src layout:
    # it reads inp[row, index[col]], adds alpha * src[row, col], and writes the
    # result to inp_cont[row, index[col]].
    pid_x = tl.program_id(axis=0)  # tile row id
    pid_y = tl.program_id(axis=1)  # tile column id
    # rows_offsets has shape (BLOCK_M, 1) so later arithmetic broadcasts against
    # the (BLOCK_N,) column vectors into (BLOCK_M, BLOCK_N) tiles.
    rows_offsets = (
        pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    )  # block_x * BLOCK_M + tl.arange(0, BLOCK_M)
    cols_offsets = pid_y * BLOCK_N + tl.arange(
        0, BLOCK_N
    )  # block_y * BLOCK_N + tl.arange(0, BLOCK_N)

    rows_mask = (
        rows_offsets < M
    )  # (BLOCK_M, 1) mask: valid rows of the compressed layout
    index_mask = (
        cols_offsets < N
    )  # (BLOCK_N,) mask: valid positions within `index`
    # NOTE(review): tensor `and` inside @triton.jit is lowered as an elementwise
    # logical-and with broadcasting, yielding a (BLOCK_M, BLOCK_N) mask — confirm
    # this holds on the Kunlunxin Triton toolchain in use.
    block_mask = rows_mask and index_mask  # block_mask = rows_mask and index_mask

    # Gather the target column for each tile column; masked lanes read index 0
    # but are never stored thanks to block_mask.
    cur_indices = tl.load(
        index + cols_offsets, mask=index_mask, other=0
    )  # cur_indices = tl.load(index + cols_offsets, mask=index_mask, other=0)
    # Flat offset of inp[row, index[col]]; row stride is inp_len, not N.
    inp_off = (
        rows_offsets * inp_len + cur_indices[None, :]
    )  # inp_off = (block_x * BLOCK_M + tl.arange(0, BLOCK_M)) * M + cur_indices
    cur_inp = tl.load(
        inp + inp_off, mask=block_mask, other=0.0
    )  # cur_inp = tl.load(inp + inp_off, mask=block_mask, other=0.0)
    # Flat offset of src[row, col]; src is contiguous (M, N).
    src_off = (
        rows_offsets * N + cols_offsets[None, :]
    )  # src_off = (block_x * BLOCK_M + tl.arange(0, BLOCK_M)) * N + block_y * BLOCK_N + tl.arange(0, BLOCK_N)
    cur_src = tl.load(
        src + src_off, mask=block_mask, other=0.0
    )  # cur_src = tl.load(src + src_off, mask=block_mask, other=0.0)
    cur_inp += alpha * cur_src

    # Non-atomic read-modify-write: if `index` contains duplicate values, the
    # loads above all see the pre-update value and later writers overwrite
    # earlier ones (lost updates), unlike torch.index_add which accumulates.
    # NOTE(review): presumably callers guarantee unique indices — verify.
    tl.store(inp_cont + inp_off, cur_inp, mask=block_mask)

77 

78 

def index_add(inp, dim, index, src, alpha=1):
    """Out-of-place index_add: return a copy of `inp` where, along dimension
    `dim`, the slices selected by `index` have `alpha * src` added to them.

    Args:
        inp: input tensor (left unmodified; a clone receives the result).
        dim: dimension to index along; may be negative.
        index: 1-D integer tensor of positions along `dim`.
        src: tensor of values to add; same ndim as `inp`, matching sizes on
            every dimension except `dim`, where its size equals index.numel().
        alpha: scalar multiplier applied to `src` (default 1).

    Returns:
        A new contiguous tensor with the accumulated values.
    """
    logger.debug("GEMS INDEX ADD")
    # Bounds check done device-agnostically (this is the Kunlunxin backend, so
    # a hardcoded "cuda" comparison tensor would sit on the wrong device).
    assert torch.all(
        (0 <= index) & (index < inp.size(dim))
    ), "0 <= index < self.size(dim)"
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    # all(...) so the per-dimension check is actually evaluated; a bare
    # generator expression is always truthy and would never fail.
    assert all(
        (inp.size(i) == src.size(i)) or i == dim for i in range(0, inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"

    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()

    dim = dim % inp.ndim  # normalize negative dims
    inp_len = inp.size(dim)
    N = index.numel()
    M = src.numel() // N  # rows of the compressed 2-D view
    fine_dim = inp.ndim - 1
    # The kernel addresses data as (M, indexed_dim); move `dim` to the last
    # position when it is not already there.
    if dim != fine_dim:
        inp = dim_compress(inp, dim)
        src = dim_compress(src, dim)
    inp_cont = inp.clone()  # out-of-place: kernel writes into the clone

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
    index_add_kernel[grid](inp, inp_cont, index, src, M, N, alpha, inp_len)
    if dim != fine_dim:
        # Undo the dim_compress permutation so the result matches inp's layout.
        order = [i for i in range(inp_cont.ndim - 1)]
        order.insert(dim, fine_dim)
        return inp_cont.permute(order).contiguous()
    else:
        return inp_cont

120 

121 

def index_add_(inp, dim, index, src, alpha=1):
    """In-place index_add: along dimension `dim`, add `alpha * src` to the
    slices of `inp` selected by `index`, and return `inp`.

    Args:
        inp: tensor modified in place (also the return value).
        dim: dimension to index along; may be negative.
        index: 1-D integer tensor of positions along `dim`.
        src: tensor of values to add; same ndim as `inp`, matching sizes on
            every dimension except `dim`, where its size equals index.numel().
        alpha: scalar multiplier applied to `src` (default 1).

    Returns:
        `inp`, after accumulation.
    """
    logger.debug("GEMS INDEX ADD_")
    # Bounds check done device-agnostically (this is the Kunlunxin backend, so
    # a hardcoded "cuda" comparison tensor would sit on the wrong device).
    assert torch.all(
        (0 <= index) & (index < inp.size(dim))
    ), "0 <= index < self.size(dim)"
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    # all(...) so the per-dimension check is actually evaluated; a bare
    # generator expression is always truthy and would never fail.
    assert all(
        (inp.size(i) == src.size(i)) or i == dim for i in range(0, inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"

    # Work on a contiguous clone; the caller's `inp` is updated via copy_ at
    # the end so arbitrary input layouts are supported.
    inp_cont = inp.clone()
    inp_cont = inp_cont.contiguous()
    index = index.contiguous()
    src = src.contiguous()

    dim = dim % inp_cont.ndim  # normalize negative dims
    inp_len = inp_cont.size(dim)
    N = index.numel()
    M = src.numel() // N  # rows of the compressed 2-D view
    fine_dim = inp_cont.ndim - 1
    # The kernel addresses data as (M, indexed_dim); move `dim` to the last
    # position when it is not already there.
    if dim != fine_dim:
        inp_cont = dim_compress(inp_cont, dim)
        src = dim_compress(src, dim)

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
    # Same buffer for read and write: the kernel accumulates in place.
    index_add_kernel[grid](inp_cont, inp_cont, index, src, M, N, alpha, inp_len)
    if dim != fine_dim:
        # Undo the dim_compress permutation before copying back into inp.
        order = [i for i in range(inp_cont.ndim - 1)]
        order.insert(dim, fine_dim)
        inp_cont = inp_cont.permute(order).contiguous()
        inp.copy_(inp_cont)
        return inp
    else:
        inp.copy_(inp_cont)
        return inp