Coverage for src/flag_gems/runtime/backend/_cambricon/ops/vector_norm.py: 0%

308 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry, tl_extra_shim 

10 

11from ..utils import TOTAL_CORE_NUM, cfggen_reduce_op, prune_reduce_config 

12 

# Module-level logger nested under the shared "flag_gems" logger hierarchy.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Device-side pow from the backend shim; NOTE(review): this shadows the
# builtin `pow` inside this module — kernels below rely on this name.
pow = tl_extra_shim.pow

15 

16 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def l2_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise L2 norm: Out[m] = sqrt(sum(X[m, :] ** 2)) for M rows of length N.

    Each program processes BLOCK_M rows per iteration and strides over the
    cdiv(M, BLOCK_M) row-tiles so any launch grid covers all rows. Accumulation
    is done in float32 regardless of the input dtype.
    """
    num_prog = tl.num_programs(0)
    task_num = tl.cdiv(M, BLOCK_M)
    # tl.cdiv already rounds up, so this many iterations cover every row-tile;
    # the former "+1 if remainder" extra iteration was fully masked out and is
    # dropped as redundant work.
    iter_num = tl.cdiv(task_num, num_prog)
    for i in range(0, iter_num):
        # Column vector of row indices for this tile; rows >= M are masked off.
        pid = (i * num_prog + tl.program_id(0)) * BLOCK_M + tl.arange(0, BLOCK_M)[
            :, None
        ]
        X_ptr = X + pid * N
        Out_ptr = Out + pid
        row_mask = pid < M

        # Running sum of squares, reduced over column tiles of width BLOCK_N.
        _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            mask = row_mask and col_mask

            # other=0.0 keeps masked lanes neutral for the sum.
            a = tl.load(X_ptr + cols, mask, other=0.0).to(tl.float32)
            _sum += a * a
        sum = tl.sum(_sum, axis=1)

        out = tl.sqrt(sum)[:, None]
        tl.store(Out_ptr, out, row_mask)

47 

48 

@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["M"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
    reset_to_zero=["Out"],
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    },
)
@triton.jit
def l2_norm_kernel_1(
    X, Out, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    """Stage 1 of a global L2 norm over a flat tensor of M elements.

    Each program reduces its share of X to a scalar sum of squares and folds
    it into the single-element Out via tl.atomic_add. The sqrt is applied
    afterwards by l2_norm_kernel_2. Out is zeroed between autotune runs via
    reset_to_zero above.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE

    mid = 0.0
    if ONE_TILE_PER_CTA:
        # Whole input fits in one tile per program: single masked load.
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < M
        x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
        mid = tl.sum(x * x)
    else:
        # Grid-stride style loop: each program strides by num_jobs * BLOCK_SIZE.
        _tmp = tl.zeros([BLOCK_SIZE], tl.float32)
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        for block_start_offset in range(block_start, M, step):
            offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offsets < M
            x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
            _tmp = _tmp + x * x
        mid = tl.sum(_tmp)

    # Accumulate this program's partial sum into the global scalar.
    tl.atomic_add(Out, mid.to(tl.float32))

87 

88 

@libentry()
@triton.jit
def l2_norm_kernel_2(
    Out,
):
    """Stage 2 of the global L2 norm: Out[0] = sqrt(Out[0]).

    Launched on a (1,) grid after l2_norm_kernel_1 has accumulated the
    sum of squares into the single-element Out.
    """
    out = tl.load(Out)
    out = tl.sqrt(out)
    tl.store(Out, out)

97 

98 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def max_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise +inf norm: Out[m] = max(|X[m, :]|) for M rows of length N.

    Each program processes BLOCK_M rows per iteration and strides over the
    cdiv(M, BLOCK_M) row-tiles so any launch grid covers all rows.
    """
    num_prog = tl.num_programs(0)
    task_num = tl.cdiv(M, BLOCK_M)
    # tl.cdiv already rounds up, so this many iterations cover every row-tile;
    # the former "+1 if remainder" extra iteration was fully masked out and is
    # dropped as redundant work.
    iter_num = tl.cdiv(task_num, num_prog)
    for i in range(0, iter_num):
        pid = (i * num_prog + tl.program_id(0)) * BLOCK_M + tl.arange(0, BLOCK_M)[
            :, None
        ]
        X_ptr = X + pid * N
        Out_ptr = Out + pid
        row_mask = pid < M

        # |x| >= 0, so a zero-initialized running max is a safe identity
        # (masked lanes load 0.0 and never win the maximum).
        _max = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            mask = row_mask and col_mask

            a = tl.load(X_ptr + cols, mask, other=0.0).to(tl.float32)
            _max = tl.maximum(tl.abs(a), _max)

        max = tl.max(_max, axis=1)
        out = max[:, None]
        tl.store(Out_ptr, out, row_mask)

129 

130 

@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["M"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    },
)
@triton.jit
def max_norm_kernel_1(
    X, Out, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    """Global +inf norm (max of |x|) over a flat tensor of M elements.

    Each program reduces its slice to a scalar and folds it into the
    single-element Out via tl.atomic_max; the host pre-initializes Out to 0
    (see vector_norm), which is the identity since |x| >= 0.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE

    mid = 0.0
    if ONE_TILE_PER_CTA:
        # Whole input fits in one tile per program: single masked load.
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < M
        x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
        mid = tl.max(tl.abs(x))
    else:
        # Zero init is the identity for max over absolute values.
        _tmp = tl.zeros([BLOCK_SIZE], tl.float32)
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        for block_start_offset in range(block_start, M, step):
            offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offsets < M
            x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
            _x = tl.abs(x)
            # Elementwise running maximum.
            _tmp = tl.where(_tmp > _x, _tmp, _x)
        mid = tl.max(_tmp)

    tl.atomic_max(Out, mid.to(tl.float32))

169 

170 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def min_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise -inf norm: Out[m] = min(|X[m, :]|) for M rows of length N.

    Each program processes BLOCK_M rows per iteration and strides over the
    cdiv(M, BLOCK_M) row-tiles so any launch grid covers all rows.
    """
    num_prog = tl.num_programs(0)
    task_num = tl.cdiv(M, BLOCK_M)
    # tl.cdiv already rounds up, so this many iterations cover every row-tile;
    # the former "+1 if remainder" extra iteration was fully masked out and is
    # dropped as redundant work.
    iter_num = tl.cdiv(task_num, num_prog)
    for i in range(0, iter_num):
        pid = (i * num_prog + tl.program_id(0)) * BLOCK_M + tl.arange(0, BLOCK_M)[
            :, None
        ]
        X_ptr = X + pid * N
        Out_ptr = Out + pid
        row_mask = pid < M

        # +inf is the identity for a running minimum; masked lanes also load
        # +inf below so they never win.
        _min = tl.full([BLOCK_M, BLOCK_N], value=float("inf"), dtype=tl.float32)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            mask = row_mask and col_mask

            a = tl.load(X_ptr + cols, mask, other=float("inf")).to(tl.float32)
            _min = tl.minimum(tl.abs(a), _min)

        min = tl.min(_min, axis=1)
        out = min[:, None]
        tl.store(Out_ptr, out, row_mask)

201 

202 

@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["M"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    },
)
@triton.jit
def min_norm_kernel_1(
    X, Out, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    """Global -inf norm (min of |x|) over a flat tensor of M elements.

    Each program reduces its slice to a scalar and folds it into the
    single-element Out via tl.atomic_min; the host pre-initializes Out to
    float32 max (see vector_norm).
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE

    if ONE_TILE_PER_CTA:
        # Whole input fits in one tile per program: single masked load, with
        # +inf padding so masked lanes never win the minimum.
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < M
        x = tl.load(X + offsets, mask, other=float("inf")).to(tl.float32)
        mid = tl.min(tl.abs(x))
    else:
        # BUG FIX: the running-minimum accumulator must start at +inf, the
        # identity for min. The previous tl.zeros init made
        # `tl.where(_tmp < _x, _tmp, _x)` always keep 0, so the -inf norm was
        # wrongly 0 whenever this multi-tile path was taken.
        _tmp = tl.full([BLOCK_SIZE], value=float("inf"), dtype=tl.float32)
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        for block_start_offset in range(block_start, M, step):
            offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offsets < M
            x = tl.load(X + offsets, mask, other=float("inf")).to(tl.float32)
            _x = tl.abs(x)
            # Elementwise running minimum.
            _tmp = tl.where(_tmp < _x, _tmp, _x)
        mid = tl.min(_tmp)

    tl.atomic_min(Out, mid.to(tl.float32))

240 

241 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def l0_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise L0 "norm": Out[m] = count of nonzero entries in X[m, :].

    Each program processes BLOCK_M rows per iteration and strides over the
    cdiv(M, BLOCK_M) row-tiles so any launch grid covers all rows.
    """
    num_prog = tl.num_programs(0)
    task_num = tl.cdiv(M, BLOCK_M)
    # tl.cdiv already rounds up, so this many iterations cover every row-tile;
    # the former "+1 if remainder" extra iteration was fully masked out and is
    # dropped as redundant work.
    iter_num = tl.cdiv(task_num, num_prog)
    for i in range(0, iter_num):
        pid = (i * num_prog + tl.program_id(0)) * BLOCK_M + tl.arange(0, BLOCK_M)[
            :, None
        ]
        X_ptr = X + pid * N
        Out_ptr = Out + pid
        row_mask = pid < M

        _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            mask = row_mask and col_mask

            # Masked lanes load 0 and thus do not count as nonzero.
            a = tl.load(X_ptr + cols, mask, other=0).to(tl.float32)
            _sum += tl.where(a != 0, 1, 0)
        sum = tl.sum(_sum, axis=1)
        out = sum[:, None]
        tl.store(Out_ptr, out, row_mask)

271 

272 

@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["M"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
    reset_to_zero=["Out"],
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    },
)
@triton.jit
def l0_norm_kernel_1(
    X, Out, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    """Global L0 "norm": count of nonzero elements in a flat tensor of M items.

    Each program counts nonzeros in its slice and folds the partial count
    into the single-element Out via tl.atomic_add. Out is zeroed between
    autotune runs via reset_to_zero above.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE

    if ONE_TILE_PER_CTA:
        # Whole input fits in one tile per program: single masked load;
        # other=0.0 keeps masked lanes out of the nonzero count.
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < M
        x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
        mid = tl.sum((x != 0).to(tl.float32))
    else:
        # Grid-stride style loop accumulating per-lane nonzero counts.
        _tmp = tl.zeros([BLOCK_SIZE], tl.float32)
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        for block_start_offset in range(block_start, M, step):
            offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offsets < M
            x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
            _tmp = _tmp + (x != 0).to(tl.float32)
        mid = tl.sum(_tmp)

    tl.atomic_add(Out, mid.to(tl.float32))

310 

311 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit(do_not_specialize=["ord"])
def v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise generic p-norm: Out[m] = (sum(|X[m, :]| ** ord)) ** (1/ord).

    ord is a runtime scalar (do_not_specialize avoids recompiling per value).
    Each program processes BLOCK_M rows per iteration and strides over the
    cdiv(M, BLOCK_M) row-tiles; tl.cdiv rounds up, so all rows are covered.
    """
    num_prog = tl.num_programs(0)
    task_num = tl.cdiv(M, BLOCK_M)
    iter_num = tl.cdiv(task_num, num_prog)

    for i in range(0, iter_num):
        # Column vector of row indices for this tile; rows >= M are masked off.
        pid = (i * num_prog + tl.program_id(0)) * BLOCK_M + tl.arange(0, BLOCK_M)[
            :, None
        ]
        X_ptr = X + pid * N
        Out_ptr = Out + pid
        row_mask = pid < M

        # Running sum of |x|**ord in float32, reduced over column tiles.
        _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            mask = row_mask and col_mask

            a = tl.load(X_ptr + cols, mask, other=0.0).to(tl.float32)
            # MLU libdevice pow; assumes ord > 0 so other=0.0 stays neutral —
            # NOTE(review): ord <= 0 is routed to other kernels by vector_norm.
            _sum += tl.extra.mlu.libdevice.pow(tl.abs(a), ord)
        sum = tl.sum(_sum, axis=1)
        out = tl.extra.mlu.libdevice.pow(sum, 1 / ord)[:, None]
        tl.store(Out_ptr, out, row_mask)

340 

341 

@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["M"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
    reset_to_zero=["Out"],
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    },
)
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_1(
    X, Out, M, ord, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    """Stage 1 of a global generic p-norm: accumulate sum(|x| ** ord) into Out.

    Despite the "l1" name, this handles any runtime ord not covered by the
    specialized kernels (see vector_norm's dispatch). Each program reduces its
    slice and folds the partial into the single-element Out via tl.atomic_add;
    l1_norm_kernel_2 applies the final 1/ord power.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE

    mid = 0.0
    if ONE_TILE_PER_CTA:
        # Whole input fits in one tile per program: single masked load.
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < M
        x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
        # `pow` is the module-level tl_extra_shim.pow device function.
        mid = tl.sum(pow(tl.abs(x), ord))
    else:
        # Grid-stride style loop accumulating |x|**ord per lane.
        _tmp = tl.zeros([BLOCK_SIZE], tl.float32)
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        for block_start_offset in range(block_start, M, step):
            offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offsets < M
            x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
            _tmp = _tmp + pow(tl.abs(x), ord)
        mid = tl.sum(_tmp)

    tl.atomic_add(Out, mid.to(tl.float32))

380 

381 

@libentry()
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_2(
    Out,
    ord,
):
    """Stage 2 of the global generic p-norm: Out[0] = Out[0] ** (1/ord).

    Launched on a (1,) grid after l1_norm_kernel_1 has accumulated
    sum(|x| ** ord) into the single-element Out.
    """
    out = tl.load(Out)
    # `pow` is the module-level tl_extra_shim.pow device function.
    out = pow(out, 1 / ord)
    tl.store(Out, out)

391 

392 

def vector_norm(x, ord=2, dim=None, keepdim=False, dtype=None):
    """Compute the vector norm of `x`, reducing over `dim`.

    Mirrors torch.linalg.vector_norm: ord in {0, 2, +inf, -inf} dispatches to
    specialized kernels; any other value uses the generic p-norm kernels.
    A None/empty `dim` (or one naming every axis) reduces the whole tensor.

    Args:
        x: input tensor (float16/float32/bfloat16 after dtype resolution).
        ord: norm order (default 2).
        dim: None, an int, or a list/tuple of dims to reduce.
        keepdim: keep reduced dims as size-1 axes when True.
        dtype: optional output dtype (torch.dtype or its string name).

    Returns:
        Tensor of norms with the reduced shape.

    Raises:
        NotImplementedError: for unsupported dtypes.
    """
    logger.debug("GEMS_CAMBRICON VECTOR NORM")
    if dtype is not None:
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        elif not isinstance(dtype, torch.dtype):
            # Unrecognized dtype spec: fall back to fp32 (best-effort, kept
            # from the original behavior rather than raising).
            dtype = torch.float32
    else:
        dtype = x.dtype
    if dtype not in [torch.float16, torch.float32, torch.bfloat16]:
        raise NotImplementedError(f"vector_norm not implemented for {dtype}")

    # Accept a bare int dim the way torch.linalg.vector_norm does. Without
    # this, len(dim) below raises for ints — and dim=0 is falsy, which would
    # silently reduce over ALL dims instead of just dim 0.
    if isinstance(dim, int):
        dim = [dim]

    with torch_device_fn.device(x.device):
        if (not dim) or len(dim) == x.ndim:
            # Full reduction: flatten and use the scalar-output kernels.
            dim = list(range(x.ndim))
            shape = [1] * x.ndim
            x = dim_compress(x, dim)
            M = x.numel()

            # Cap the launch grid at the core count; kernels stride as needed.
            grid = lambda meta: (
                min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),
            )
            out = torch.zeros(shape, dtype=torch.float, device=x.device)
            if ord == 2:
                l2_norm_kernel_1[grid](x, out, M)
                l2_norm_kernel_2[(1,)](out)
            elif ord == float("inf"):
                max_norm_kernel_1[grid](x, out, M)
            elif ord == -float("inf"):
                # atomic_min needs a large initial value, not 0.
                out = torch.full(
                    shape,
                    fill_value=torch.finfo(torch.float32).max,
                    dtype=torch.float,
                    device=x.device,
                )
                min_norm_kernel_1[grid](x, out, M)
            elif ord == 0:
                l0_norm_kernel_1[grid](x, out, M)
            else:
                l1_norm_kernel_1[grid](x, out, M, ord)
                l1_norm_kernel_2[(1,)](
                    out,
                    ord,
                )
            # Scalar path accumulates in fp32; cast to the requested dtype.
            out = out.to(dtype)
        else:
            # Partial reduction: move reduced dims last, collapse them to N,
            # and run the row-wise kernels over M rows.
            shape = list(x.shape)
            dim = [d % x.ndim for d in dim]
            x = dim_compress(x, dim)
            N = 1
            for i in dim:
                N *= shape[i]
                shape[i] = 1
            M = x.numel() // N
            out = torch.empty(shape, dtype=dtype, device=x.device)
            grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
            if ord == 2:
                l2_norm_kernel[grid](x, out, M, N)
            elif ord == float("inf"):
                max_norm_kernel[grid](x, out, M, N)
            elif ord == -float("inf"):
                min_norm_kernel[grid](x, out, M, N)
            elif ord == 0:
                l0_norm_kernel[grid](x, out, M, N)
            else:
                v_norm_kernel[grid](x, out, M, N, ord)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out