Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/vector_norm.py: 0%

263 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import builtins 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8# from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry, tl_extra_shim 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Package-scoped logger, parented under the "flag_gems" root logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Backend pow shim for use inside Triton kernels below; presumably an
# elementwise pow provided by the kunlunxin tl_extra layer — confirm.
# NOTE: this intentionally shadows the builtin `pow` at module level.
pow = tl_extra_shim.pow

16 

def heur_block_m(args):
    """Heuristic BLOCK_M: split the M rows over ~12 clusters, rounded up to a
    power of two."""
    rows_per_cluster = triton.cdiv(args["M"], 12)
    return triton.next_power_of_2(rows_per_cluster)

19 

20 

def heur_block_n(args):
    """Heuristic BLOCK_N: cover the whole row, capped at 8192 columns."""
    n = args["N"]
    return n if n <= 8192 else 8192

23 

24 

@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def l2_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise L2 norm over a contiguous (M, N) view: Out[m] = sqrt(sum(X[m, :]**2))."""
    # Each program handles BLOCK_M consecutive rows. program_id is widened to
    # int64 before `pid * N` so the row-base offset cannot overflow int32.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # Accumulate squares in fp32 regardless of the input dtype.
    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): Python `and` between Triton tensors relies on the
        # compiler lowering it elementwise; the conventional spelling is `&` —
        # confirm the target Triton build accepts this.
        mask = row_mask and col_mask

        # Masked lanes load 0.0 and contribute nothing to the sum.
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _sum += a * a
    sum = tl.sum(_sum, axis=1)

    out = tl.sqrt(sum)[:, None]
    tl.store(Out, out, row_mask)

52 

53 

@libentry()
@triton.jit
def l2_norm_kernel_1(
    X, Mid, M, BLOCK_SIZE: tl.constexpr, buffer_size_limit: tl.constexpr
):
    """Full-reduction L2 norm, stage 1: Mid[pid] = sum of squares of this
    program's BLOCK_SIZE slice of the flattened M-element input.

    buffer_size_limit is unused in the body — presumably a tuning hint read by
    the kunlunxin compiler; confirm.
    """
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # fp32 accumulation; sqrt is deferred to stage 2 over the partial sums.
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.sum(x * x)
    tl.store(Mid, mid)

68 

69 

@libentry()
@triton.jit
def l2_norm_kernel_2(
    Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr, buffer_size_limit: tl.constexpr
):
    """Full-reduction L2 norm, stage 2 (single program):
    Out = sqrt(sum(Mid[:MID_SIZE])) over the stage-1 partial sums.
    """
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    # other=0.0 keeps padded lanes neutral for the sum.
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.sqrt(tl.sum(mid))
    tl.store(Out, out)

81 

82 

@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def max_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise inf-norm over a contiguous (M, N) view: Out[m] = max(|X[m, :]|)."""
    # Each program handles BLOCK_M rows; int64 widening guards `pid * N`.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # Zero is a valid identity for a running max of absolute values.
    _max = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): `and` on tensors depends on elementwise lowering by the
        # Triton build (conventionally `&`) — confirm.
        mask = row_mask and col_mask

        # Masked lanes load 0.0, which cannot raise the running max.
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _max = tl.maximum(tl.abs(a), _max)

    max = tl.max(_max, axis=1)
    out = max[:, None]
    tl.store(Out, out, row_mask)

110 

111 

@libentry()
@triton.jit
def max_norm_kernel_1(
    X, Mid, M, BLOCK_SIZE: tl.constexpr, buffer_size_limit: tl.constexpr
):
    """Full-reduction inf-norm, stage 1: Mid[pid] = max(|x|) over this
    program's BLOCK_SIZE slice of the flattened input.

    buffer_size_limit is unused in the body — presumably a backend tuning
    hint; confirm.
    """
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # other=0.0 is neutral for a max of absolute values.
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.max(tl.abs(x))
    tl.store(Mid, mid)

126 

127 

@libentry()
@triton.jit
def max_norm_kernel_2(
    Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr, buffer_size_limit: tl.constexpr
):
    """Full-reduction inf-norm, stage 2 (single program):
    Out = max(Mid[:MID_SIZE]) over the stage-1 partial maxima.
    """
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    # Partial maxima are >= 0, so other=0.0 is a safe padding value.
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.max(mid)
    tl.store(Out, out)

139 

140 

@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def min_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise -inf "norm" over a contiguous (M, N) view: Out[m] = min(|X[m, :]|)."""
    # Each program handles BLOCK_M rows; int64 widening guards `pid * N`.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # +inf is the identity for a running min.
    _min = tl.full([BLOCK_M, BLOCK_N], value=float("inf"), dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): `and` on tensors depends on elementwise lowering by the
        # Triton build (conventionally `&`) — confirm.
        mask = row_mask and col_mask

        # Masked lanes load +inf, which cannot lower the running min.
        a = tl.load(X + cols, mask, other=float("inf")).to(tl.float32)
        _min = tl.minimum(tl.abs(a), _min)

    min = tl.min(_min, axis=1)
    out = min[:, None]
    tl.store(Out, out, row_mask)

168 

169 

@libentry()
@triton.jit
def min_norm_kernel_1(
    X, Mid, M, BLOCK_SIZE: tl.constexpr, buffer_size_limit: tl.constexpr
):
    """Full-reduction -inf "norm", stage 1: Mid[pid] = min(|x|) over this
    program's BLOCK_SIZE slice of the flattened input.

    buffer_size_limit is unused in the body — presumably a backend tuning
    hint; confirm.
    """
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # other=+inf is neutral for the min reduction.
    x = tl.load(X, mask=mask, other=float("inf")).to(tl.float32)
    mid = tl.min(tl.abs(x))
    tl.store(Mid, mid)

184 

185 

@libentry()
@triton.jit
def min_norm_kernel_2(
    Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr, buffer_size_limit: tl.constexpr
):
    """Full-reduction -inf "norm", stage 2 (single program):
    Out = min(Mid[:MID_SIZE]) over the stage-1 partial minima.
    """
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    # other=+inf keeps padded lanes neutral for the min.
    mid = tl.load(Mid, mask=mask, other=float("inf")).to(tl.float32)
    out = tl.min(mid)
    tl.store(Out, out)

197 

198 

@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def l0_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise L0 "norm" over a contiguous (M, N) view:
    Out[m] = count of non-zero elements in X[m, :].
    """
    # Fix: widen program_id to int64 before `pid * N`, matching every sibling
    # row-wise kernel in this file; without it the row-base offset arithmetic
    # can overflow int32 for large M * N.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # Count non-zeros in fp32 (matches the output buffer's float dtype).
    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): `and` on tensors depends on elementwise lowering by the
        # Triton build (conventionally `&`) — confirm.
        mask = row_mask and col_mask

        # Masked lanes load 0 and so are not counted.
        a = tl.load(X + cols, mask, other=0).to(tl.float32)
        _sum += tl.where(a != 0, 1, 0)
    sum = tl.sum(_sum, axis=1)
    out = sum[:, None]
    tl.store(Out, out, row_mask)

225 

226 

@libentry()
@triton.jit
def l0_norm_kernel_1(
    X, Mid, M, BLOCK_SIZE: tl.constexpr, buffer_size_limit: tl.constexpr
):
    """Full-reduction L0 "norm", stage 1: Mid[pid] = non-zero count of this
    program's BLOCK_SIZE slice of the flattened input.

    buffer_size_limit is unused in the body — presumably a backend tuning
    hint; confirm.
    """
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # Masked lanes load 0.0 and thus are excluded from the count.
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    cnt = (x != 0).to(tl.float32)
    mid = tl.sum(cnt)
    tl.store(Mid, mid)

242 

243 

@libentry()
@triton.jit
def l0_norm_kernel_2(
    Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr, buffer_size_limit: tl.constexpr
):
    """Full-reduction L0 "norm", stage 2 (single program):
    Out = sum(Mid[:MID_SIZE]) of the stage-1 partial counts.
    """
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    # other=0.0 keeps padded lanes neutral for the sum.
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.sum(mid)
    tl.store(Out, out)

255 

256 

@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit(do_not_specialize=["ord"])
def v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise general p-norm over a contiguous (M, N) view:
    Out[m] = (sum(|X[m, :]| ** ord)) ** (1 / ord).

    `ord` is a runtime scalar; do_not_specialize avoids recompiling per value.
    """
    ord = ord.to(tl.float32)
    # Each program handles BLOCK_M rows; int64 widening guards `pid * N`.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): `and` on tensors depends on elementwise lowering by the
        # Triton build (conventionally `&`) — confirm.
        mask = row_mask and col_mask

        # Masked lanes load 0.0; pow(0, ord) contributes 0 for positive ord.
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _sum += pow(tl.abs(a), ord)
    sum = tl.sum(_sum, axis=1)
    out = pow(sum, 1 / ord)[:, None]
    tl.store(Out, out, row_mask)

284 

285 

@libentry()
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_1(
    X, Mid, ord, M, BLOCK_SIZE: tl.constexpr, buffer_size_limit: tl.constexpr
):
    """Full-reduction general p-norm, stage 1: Mid[pid] = sum(|x| ** ord) over
    this program's BLOCK_SIZE slice of the flattened input.

    Name is historic — `ord` is a runtime scalar, so this handles any p, not
    just p=1. buffer_size_limit is unused in the body — presumably a backend
    tuning hint; confirm.
    """
    ord = ord.to(tl.float32)
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    # Partial power sum; the final 1/ord root is applied in stage 2.
    mid = tl.sum(pow(tl.abs(x), ord))
    tl.store(Mid, mid)

301 

302 

@libentry()
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_2(
    Mid, Out, ord, MID_SIZE, BLOCK_MID: tl.constexpr, buffer_size_limit: tl.constexpr
):
    """Full-reduction general p-norm, stage 2 (single program):
    Out = (sum(Mid[:MID_SIZE])) ** (1 / ord) over the stage-1 power sums.
    """
    ord = ord.to(tl.float32)
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    # other=0.0 keeps padded lanes neutral for the sum.
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = pow(tl.sum(mid), 1 / ord)
    tl.store(Out, out)

315 

316 

def vector_norm(x, ord=2, dim=None, keepdim=False, dtype=None):
    """Vector p-norm of `x`, dispatching to kunlunxin Triton kernels.

    Mirrors torch.linalg.vector_norm: `ord` may be 2, +/-inf, 0, or any other
    number (generic p-norm path); `dim` selects the reduced dimensions. Only
    float16/float32/bfloat16 are supported.
    """
    logger.debug("GEMS VECTOR NORM")
    if dtype is not None:
        # NOTE(review): torch.dtype is not normally constructible — calling it
        # raises TypeError for most inputs. Confirm what callers pass here.
        dtype = torch.dtype(dtype)
    else:
        dtype = x.dtype
    if dtype not in [torch.float16, torch.float32, torch.bfloat16]:
        raise NotImplementedError(f"vector_norm not implemented for {dtype}")

    with torch_device_fn.device(x.device):
        # Full reduction when no dim is given or all dims are reduced.
        # NOTE(review): an integer dim=0 is falsy and would be routed to the
        # full-reduction branch — callers presumably pass `dim` as a list of
        # ints; confirm.
        if (not dim) or len(dim) == x.ndim:
            dim = list(range(x.ndim))
            shape = [1] * x.ndim
            x = dim_compress(x, dim)  # make the reduced data contiguous
            M = x.numel()
            # Two-stage reduction: MID_SIZE partial results, then one program
            # combines them. cluster_num=12 matches the heur_block_m split.
            cluster_num = 12
            # Cap BLOCK_SIZE so one block's input fits in 64 KiB.
            BLOCK_SIZE = min(
                triton.next_power_of_2(triton.cdiv(M, cluster_num)),
                int(1024 * 64 / x.element_size()),
            )
            MID_SIZE = triton.cdiv(M, BLOCK_SIZE)
            BLOCK_MID = triton.next_power_of_2(MID_SIZE)

            mid = torch.empty([MID_SIZE], dtype=dtype, device=x.device)
            out = torch.empty(shape, dtype=dtype, device=x.device)
            if ord == 2:
                l2_norm_kernel_1[(MID_SIZE,)](
                    x, mid, M, BLOCK_SIZE, buffer_size_limit=2048
                )
                l2_norm_kernel_2[(1,)](
                    mid, out, MID_SIZE, BLOCK_MID, buffer_size_limit=2048
                )
            elif ord == float("inf"):
                max_norm_kernel_1[(MID_SIZE,)](
                    x, mid, M, BLOCK_SIZE, buffer_size_limit=2048
                )
                max_norm_kernel_2[(1,)](
                    mid, out, MID_SIZE, BLOCK_MID, buffer_size_limit=2048
                )
            elif ord == -float("inf"):
                min_norm_kernel_1[(MID_SIZE,)](
                    x, mid, M, BLOCK_SIZE, buffer_size_limit=2048
                )
                min_norm_kernel_2[(1,)](
                    mid, out, MID_SIZE, BLOCK_MID, buffer_size_limit=2048
                )
            elif ord == 0:
                l0_norm_kernel_1[(MID_SIZE,)](
                    x, mid, M, BLOCK_SIZE, buffer_size_limit=2048
                )
                l0_norm_kernel_2[(1,)](
                    mid, out, MID_SIZE, BLOCK_MID, buffer_size_limit=2048
                )
            else:
                # Generic p-norm ("l1" kernels take ord as a runtime scalar).
                l1_norm_kernel_1[(MID_SIZE,)](
                    x, mid, ord, M, BLOCK_SIZE, buffer_size_limit=2048
                )
                l1_norm_kernel_2[(1,)](
                    mid, out, ord, MID_SIZE, BLOCK_MID, buffer_size_limit=2048
                )
        else:
            # Partial reduction: move reduced dims last (dim_compress) so each
            # row of the resulting (M, N) view is one reduction of N elements.
            shape = list(x.shape)
            dim = [d % x.ndim for d in dim]  # normalize negative dims
            x = dim_compress(x, dim)
            N = 1
            for i in dim:
                N *= shape[i]
                shape[i] = 1  # reduced dims become size-1 in the output shape
            M = x.numel() // N
            out = torch.empty(shape, dtype=dtype, device=x.device)
            grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
            if ord == 2:
                l2_norm_kernel[grid](x, out, M, N)
            elif ord == float("inf"):
                max_norm_kernel[grid](x, out, M, N)
            elif ord == -float("inf"):
                min_norm_kernel[grid](x, out, M, N)
            elif ord == 0:
                l0_norm_kernel[grid](x, out, M, N)
            else:
                # isCloseUnrollControl: kunlunxin-specific launch option —
                # presumably disables aggressive loop unrolling; confirm.
                v_norm_kernel[grid](x, out, M, N, ord, isCloseUnrollControl=True)
        if not keepdim:
            out = out.squeeze(dim=dim)
    return out