Coverage for src/flag_gems/runtime/backend/_cambricon/ops/triu.py: 0%

140 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils.shape_utils import can_use_int32_index 

11 

12from ..utils import TOTAL_CORE_NUM 

13 

# Module logger nested under the shared "flag_gems" root logger;
# lstrip(".") drops any leading dots so getChild gets a clean relative name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

15 

16 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("triu"), key=["M", "N"])
@triton.jit(do_not_specialize=["diagonal"])
def triu_kernel(
    X,  # pointer to the input matrix, dense row-major (M, N)
    Y,  # pointer to the output matrix, dense row-major (M, N)
    M,
    N,
    diagonal,  # diagonal offset, same meaning as torch.triu's `diagonal`
    M_BLOCK_SIZE: tl.constexpr,
    N_BLOCK_SIZE: tl.constexpr,
    NEED_LOOP: tl.constexpr,  # True when N_BLOCK_SIZE < N, i.e. columns must be tiled
    INT64_INDEX: tl.constexpr = False,  # use 64-bit offsets when int32 could overflow
):
    """Write triu(X, diagonal) into Y for a single 2D matrix.

    Programs stride over row blocks (grid-stride loop on axis 0), so a grid
    capped at the core count still covers arbitrarily large M.
    """
    pid = tl.program_id(0)
    if INT64_INDEX:
        # Promote before any offset arithmetic so row * N cannot overflow int32.
        pid = pid.to(tl.int64)
    num_jobs = tl.num_programs(0)
    m_block_step = M_BLOCK_SIZE * num_jobs

    # Grid-stride loop over row blocks.
    for m_offset in range(pid * M_BLOCK_SIZE, M, m_block_step):
        if NEED_LOOP:
            # Rows are too wide for one block: tile over columns as well.
            row = m_offset + tl.arange(0, M_BLOCK_SIZE)[:, None]
            m_mask = row < M
            PX = X + row * N
            PY = Y + row * N
            for n_offset in range(0, N, N_BLOCK_SIZE):
                cols = n_offset + tl.arange(0, N_BLOCK_SIZE)[None, :]
                n_mask = cols < N
                mask = m_mask and n_mask

                x = tl.load(PX + cols, mask, other=0.0)
                # Keep elements on/above the shifted diagonal, zero the rest.
                y = tl.where(row + diagonal <= cols, x, 0.0)
                tl.store(PY + cols, y, mask=mask)
        else:
            # A full row fits in one block: load each row with the triangular
            # mask applied at load time (masked-out positions read as 0.0),
            # then store the assembled block in a single write.
            write = tl.empty([M_BLOCK_SIZE, N_BLOCK_SIZE], X.dtype.element_ty)
            cols = tl.arange(0, N_BLOCK_SIZE)
            # Number of valid rows in this (possibly final, partial) block.
            repeat_num = tl.minimum(M_BLOCK_SIZE, M - m_offset)
            for i in tl.range(repeat_num, num_stages=0):
                cur_row = m_offset + i
                PX = X + cur_row * N
                rmask = cols >= cur_row + diagonal
                write[i, :] = tl.load(PX + cols, mask=rmask, other=0.0)

            row = m_offset + tl.arange(0, M_BLOCK_SIZE)
            offset = cols[None, :] + row[:, None] * N
            n_mask = row[:, None] < M
            tl.store(Y + offset, write, mask=n_mask)

65 

66 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("triu_batch"),
    key=["batch", "MN", "N", "diagonal"],
)
@triton.jit(do_not_specialize=["diagonal"])
def triu_batch_kernel(
    X,  # pointer to input, dense (batch, M*N)
    Y,  # pointer to output, dense (batch, M*N)
    batch,
    MN,  # M * N: number of elements per matrix
    N,  # number of columns per matrix row
    diagonal,  # diagonal offset, same meaning as torch.triu's `diagonal`
    BATCH_BLOCK_SIZE: tl.constexpr,
    MN_BLOCK_SIZE: tl.constexpr,
    INT64_INDEX: tl.constexpr = False,  # use 64-bit offsets when int32 could overflow
):
    """Batched triu over matrices flattened to rows of length M*N.

    2D launch grid: axis 0 tiles the batch dimension, axis 1 tiles the
    flattened matrix. The (row, col) of each flat element is recovered by
    div/mod with N to evaluate the triangular predicate.
    """
    batch_id = tl.program_id(0)
    mn_id = tl.program_id(1)
    if INT64_INDEX:
        # Promote before offset arithmetic so row * MN cannot overflow int32.
        batch_id = batch_id.to(tl.int64)
        mn_id = mn_id.to(tl.int64)
    row = batch_id * BATCH_BLOCK_SIZE + tl.arange(0, BATCH_BLOCK_SIZE)[:, None]
    batch_mask = row < batch
    # Advance base pointers to this block's matrices.
    X += row * MN
    Y += row * MN

    cols = mn_id * MN_BLOCK_SIZE + tl.arange(0, MN_BLOCK_SIZE)[None, :]
    mn_mask = cols < MN
    mask = batch_mask and mn_mask
    x = tl.load(X + cols, mask, other=0.0)
    # Recover 2D coordinates from the flat index, then apply the triu predicate.
    m = cols // N
    n = cols % N
    y = tl.where(m + diagonal <= n, x, 0.0)
    tl.store(Y + cols, y, mask=mask)

102 

103 

104def _check_batch_contiguous(tensor, allow_zero_stride=True): 

105 if tensor.is_contiguous(): 

106 return True, tensor 

107 

108 dims = tensor.dim() 

109 

110 if dims >= 2: 

111 n = tensor.size(-1) 

112 stride_row, stride_col = tensor.stride(-2), tensor.stride(-1) 

113 

114 if not (stride_col == 1 and stride_row == n): 

115 return False, tensor.contiguous() 

116 

117 if allow_zero_stride and dims <= 3: 

118 return True, tensor 

119 

120 expected_stride = tensor.size(-1) * tensor.size(-2) 

121 for i in range(dims - 3, -1, -1): 

122 if ( 

123 allow_zero_stride 

124 and i == 0 

125 and (tensor.stride(i) == 0 or tensor.size(i) == 1) 

126 ): 

127 continue 

128 

129 if tensor.stride(i) != expected_stride: 

130 return False, tensor.contiguous() 

131 

132 expected_stride *= tensor.size(i) 

133 

134 return True, tensor 

135 

136 

def triu(A, diagonal=0):
    """Return a new tensor with the upper-triangular part of ``A``.

    Mirrors ``torch.triu``: elements below the ``diagonal``-th diagonal of
    the last two dimensions are zeroed; any leading dimensions are treated as
    a batch.

    Args:
        A: input tensor of shape ``(..., M, N)``; must be at least 2D.
        diagonal: diagonal offset (0 = main, >0 above, <0 below).

    Returns:
        A contiguous tensor with the same shape and dtype as ``A``.
    """
    logger.debug("GEMS_CAMBRICON TRIU")

    assert len(A.shape) > 1, "Input tensor must have at least 2 dimensions"

    # The kernels index the input as a dense row-major buffer; fall back to a
    # contiguous copy when A's layout does not qualify.
    _, A_input = _check_batch_contiguous(A, allow_zero_stride=False)

    out = torch.empty(
        A.shape, dtype=A.dtype, device=A.device, memory_format=torch.contiguous_format
    )

    M, N = A_input.shape[-2:]
    use_int64_index = not can_use_int32_index(A_input)
    with torch_device_fn.device(A_input.device):
        if len(A_input.shape) == 2:
            grid = lambda meta: (
                min(triton.cdiv(M, meta["M_BLOCK_SIZE"]), TOTAL_CORE_NUM),
            )
            # A large column block can exhaust MLU on-chip resources and make
            # compilation fail, so cap the per-iteration block at 256 KiB of
            # elements; when the cap is hit the kernel tiles over N instead.
            elements_bytes = A_input.element_size()
            n_block = min(256 * 1024 // elements_bytes, N)
            need_loop = n_block < N
            triu_kernel[grid](
                A_input,
                out,
                M,
                N,
                diagonal,
                N_BLOCK_SIZE=n_block,
                NEED_LOOP=need_loop,
                INT64_INDEX=use_int64_index,
            )
        else:
            # Collapse leading dims into a single batch dimension; the batched
            # kernel recovers (row, col) from the flat per-matrix index.
            batch = int(torch.numel(A_input) / M / N)
            B = A_input.view(batch, -1)
            grid = lambda meta: (
                triton.cdiv(batch, meta["BATCH_BLOCK_SIZE"]),
                triton.cdiv(M * N, meta["MN_BLOCK_SIZE"]),
            )
            triu_batch_kernel[grid](
                B, out, batch, M * N, N, diagonal, INT64_INDEX=use_int64_index
            )
    return out

184 

185 

def triu_(A, diagonal=0):
    """In-place version of ``triu``: zero elements of ``A`` below the
    ``diagonal``-th diagonal of the last two dimensions.

    If ``A``'s layout is not dense enough for the kernels, the result is
    computed into a contiguous temporary and copied back into ``A``.

    Args:
        A: tensor of shape ``(..., M, N)``, modified in place; must be at
            least 2D.
        diagonal: diagonal offset (0 = main, >0 above, <0 below).

    Returns:
        ``A`` itself.
    """
    logger.debug("GEMS_CAMBRICON TRIU_(inplace)")

    assert len(A.shape) > 1, "Input tensor must have at least 2 dimensions"
    diagonal = int(diagonal)
    M, N = A.shape[-2:]

    def _launch(src, dst, use_int64_index):
        # Dispatch the 2D or batched kernel, writing triu(src) into dst.
        # Both src and dst must be dense row-major (..., M, N) buffers.
        if len(A.shape) == 2:
            grid = lambda meta: (
                min(triton.cdiv(M, meta["M_BLOCK_SIZE"]), TOTAL_CORE_NUM),
            )
            # A large column block can exhaust MLU on-chip resources and make
            # compilation fail, so cap the per-iteration block at 256 KiB of
            # elements; when the cap is hit the kernel tiles over N instead.
            elements_bytes = src.element_size()
            n_block = min(256 * 1024 // elements_bytes, N)
            need_loop = n_block < N
            triu_kernel[grid](
                src,
                dst,
                M,
                N,
                diagonal,
                N_BLOCK_SIZE=n_block,
                NEED_LOOP=need_loop,
                INT64_INDEX=use_int64_index,
            )
        else:
            batch = int(torch.numel(A) / M / N)
            grid = lambda meta: (
                triton.cdiv(batch, meta["BATCH_BLOCK_SIZE"]),
                triton.cdiv(M * N, meta["MN_BLOCK_SIZE"]),
            )
            triu_batch_kernel[grid](
                src.view(batch, -1),
                dst.view(batch, -1),
                batch,
                M * N,
                N,
                diagonal,
                INT64_INDEX=use_int64_index,
            )

    can_use_directly, A_to_use = _check_batch_contiguous(A, allow_zero_stride=True)

    if can_use_directly:
        # Layout is kernel-friendly: compute directly into A.
        use_int64_index = not can_use_int32_index(A)
        with torch_device_fn.device(A.device):
            _launch(A, A, use_int64_index)
    else:
        logger.debug(
            "Input tensor does not satisfy contiguity requirements, "
            "using temporary tensor for computation"
        )
        # Compute into a contiguous scratch tensor, then copy back so the
        # original (possibly strided) storage of A is updated correctly.
        result_temp = torch.empty_like(A_to_use, memory_format=torch.contiguous_format)
        use_int64_index = not can_use_int32_index(A_to_use)
        with torch_device_fn.device(A.device):
            _launch(A_to_use, result_temp, use_int64_index)
        A.copy_(result_temp)
    return A