Coverage for src/flag_gems/runtime/backend/_ascend/ops/vector_norm.py: 0%

267 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry, tl_extra_shim 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

14 

15 

# Select the pow implementation once at import time: prefer the Ascend
# libdevice pow when torch_npu is importable, otherwise fall back to the
# portable shim. `except Exception` (rather than the previous bare `except:`)
# still covers ImportError, AttributeError on the tl.extra lookup, and native
# library load failures, but no longer swallows KeyboardInterrupt/SystemExit.
try:
    import torch_npu  # noqa: F401

    pow = tl.extra.ascend.libdevice.pow
except Exception:
    pow = tl_extra_shim.pow

22 

23 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def l2_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise L2 norm over an (M, N) view of X: Out[m] = sqrt(sum(X[m, :]**2)).

    Each program handles BLOCK_M consecutive rows, scanning the N columns in
    BLOCK_N-wide tiles and accumulating squared values in float32.
    """
    # Row indices covered by this program, shaped (BLOCK_M, 1). Casting to
    # int64 keeps the pid * N pointer offset from overflowing on large inputs.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N  # base pointer of each row
    Out = Out + pid  # one output scalar per row
    row_mask = pid < M

    # Per-lane partial sums of squares; reduced across columns after the loop.
    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask  # broadcasts to (BLOCK_M, BLOCK_N)

        # other=0.0: padded lanes contribute nothing to the sum of squares.
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _sum += a * a
    sum = tl.sum(_sum, axis=1)

    out = tl.sqrt(sum)[:, None]
    tl.store(Out, out, row_mask)

45 

46 

@libentry()
@triton.jit
def l2_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr, BLOCK_SIZE_SUB: tl.constexpr):
    """Stage 1 of the full-reduction L2 norm.

    Each program reduces one BLOCK_SIZE-sized slice of the flattened input
    (M elements total) to a partial sum of squares and writes it to Mid[pid].
    """
    pid = tl.program_id(0).to(tl.int64)

    # Scalar float32 accumulator for this program's slice.
    total_sum = 0.0

    # Walk the slice in BLOCK_SIZE_SUB-wide chunks to bound tile size.
    for off in range(0, BLOCK_SIZE, BLOCK_SIZE_SUB):
        offsets = pid * BLOCK_SIZE + off + tl.arange(0, BLOCK_SIZE_SUB)
        mask = offsets < M  # guard the tail past the end of the input
        x = tl.load(X + offsets, mask=mask, other=0.0).to(tl.float32)
        total_sum += tl.sum(x * x)

    tl.store(Mid + pid, total_sum)

61 

62 

@libentry()
@triton.jit
def l2_norm_kernel_2(
    Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr, BLOCK_MID_SUB: tl.constexpr
):
    """Stage 2 of the full-reduction L2 norm.

    A single program (launched with grid (1,)) sums the MID_SIZE partial sums
    produced by stage 1 and stores sqrt of the total into Out.
    """
    pid = tl.program_id(0).to(tl.int64)  # grid is (1,), so pid == 0

    total_sum = 0.0

    for off in range(0, MID_SIZE, BLOCK_MID_SUB):
        # pid is 0 here, so pid * MID_SIZE contributes nothing to the offset.
        offsets = pid * MID_SIZE + off + tl.arange(0, BLOCK_MID_SUB)
        mask = offsets < MID_SIZE
        x = tl.load(Mid + offsets, mask=mask, other=0.0).to(tl.float32)
        total_sum += tl.sum(x)
    out = tl.sqrt(total_sum)
    tl.store(Out, out)

79 

80 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def max_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise inf-norm over an (M, N) view of X: Out[m] = max(|X[m, :]|)."""
    # Row indices for this program, shaped (BLOCK_M, 1); int64 keeps pid * N
    # from overflowing for large tensors.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # 0 is the identity for a running max of absolute values.
    _max = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask  # broadcasts to (BLOCK_M, BLOCK_N)

        # other=0.0 is safe: |x| >= 0, so padded lanes cannot win the max.
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _max = tl.maximum(tl.abs(a), _max)

    max = tl.max(_max, axis=1)
    out = max[:, None]
    tl.store(Out, out, row_mask)

102 

103 

@libentry()
@triton.jit
def max_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    """Stage 1 of the full-reduction inf-norm.

    Each program reduces one BLOCK_SIZE slice of the flattened input to the
    maximum absolute value and writes it to Mid[pid].
    """
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # other=0.0 is safe: |x| >= 0, so padded lanes cannot win the max.
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.max(tl.abs(x))
    tl.store(Mid, mid)

116 

117 

@libentry()
@triton.jit
def max_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of the full-reduction inf-norm: single program takes the max
    of the MID_SIZE per-program maxima and stores the scalar result."""
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    # other=0.0: stage-1 outputs are non-negative, so 0 cannot win the max.
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.max(mid)
    tl.store(Out, out)

127 

128 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def min_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise -inf "norm" over an (M, N) view of X: Out[m] = min(|X[m, :]|)."""
    # Row indices for this program, shaped (BLOCK_M, 1); int64 keeps pid * N
    # from overflowing for large tensors.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # +inf is the identity for a running min.
    _min = tl.full([BLOCK_M, BLOCK_N], value=float("inf"), dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask  # broadcasts to (BLOCK_M, BLOCK_N)

        # other=inf: padded lanes cannot win the min.
        a = tl.load(X + cols, mask, other=float("inf")).to(tl.float32)
        _min = tl.minimum(tl.abs(a), _min)

    min = tl.min(_min, axis=1)
    out = min[:, None]
    tl.store(Out, out, row_mask)

150 

151 

@libentry()
@triton.jit
def min_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    """Stage 1 of the full-reduction -inf "norm".

    Each program reduces one BLOCK_SIZE slice of the flattened input to the
    minimum absolute value and writes it to Mid[pid].
    """
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # other=inf: padded lanes cannot win the min.
    x = tl.load(X, mask=mask, other=float("inf")).to(tl.float32)
    mid = tl.min(tl.abs(x))
    tl.store(Mid, mid)

164 

165 

@libentry()
@triton.jit
def min_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of the full-reduction -inf "norm": single program takes the
    min of the MID_SIZE per-program minima and stores the scalar result."""
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    # other=inf: padded lanes cannot win the min.
    mid = tl.load(Mid, mask=mask, other=float("inf")).to(tl.float32)
    out = tl.min(mid)
    tl.store(Out, out)

175 

176 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def l0_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise L0 "norm" over an (M, N) view of X: Out[m] = count of
    non-zero entries in X[m, :].

    Each program handles BLOCK_M consecutive rows, scanning the N columns in
    BLOCK_N-wide tiles and accumulating the non-zero count in float32.
    """
    # Row indices for this program, shaped (BLOCK_M, 1). Cast to int64 before
    # scaling — as the sibling l2/max/min kernels do — so pid * N cannot
    # overflow int32 for large tensors.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask  # broadcasts to (BLOCK_M, BLOCK_N)

        # other=0: padded lanes read as zero and therefore are not counted.
        a = tl.load(X + cols, mask, other=0).to(tl.float32)
        _sum += tl.where(a != 0, 1, 0)
    sum = tl.sum(_sum, axis=1)
    out = sum[:, None]
    tl.store(Out, out, row_mask)

197 

198 

@libentry()
@triton.jit
def l0_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    """Stage 1 of the full-reduction L0 "norm".

    Each program counts the non-zero entries in one BLOCK_SIZE slice of the
    flattened input and writes the count to Mid[pid].
    """
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # other=0.0: padded lanes read as zero and are not counted.
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    cnt = (x != 0).to(tl.float32)
    mid = tl.sum(cnt)
    tl.store(Mid, mid)

212 

213 

@libentry()
@triton.jit
def l0_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of the full-reduction L0 "norm": single program sums the
    MID_SIZE partial counts and stores the scalar total."""
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.sum(mid)
    tl.store(Out, out)

223 

224 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit(do_not_specialize=["ord"])
def v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise general p-norm over an (M, N) view of X:
    Out[m] = (sum |X[m, :]| ** ord) ** (1 / ord).

    Used for any ord not covered by the specialized 2 / +-inf / 0 kernels.
    ord is marked do_not_specialize so each value does not trigger a recompile.
    """
    # Row indices for this program, shaped (BLOCK_M, 1); int64 keeps pid * N
    # from overflowing for large tensors.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask  # broadcasts to (BLOCK_M, BLOCK_N)

        # other=0.0: assumes ord > 0 so 0**ord == 0 — padded lanes add nothing.
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _sum += pow(tl.abs(a), ord)
    sum = tl.sum(_sum, axis=1)
    out = pow(sum, 1 / ord)[:, None]
    tl.store(Out, out, row_mask)

245 

246 

@libentry()
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_1(X, Mid, ord, M, BLOCK_SIZE: tl.constexpr):
    """Stage 1 of the full-reduction general p-norm.

    Despite the "l1" name this handles any ord not covered by the specialized
    kernels: each program reduces one BLOCK_SIZE slice of the flattened input
    to a partial sum of |x| ** ord and writes it to Mid[pid].
    """
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # other=0.0: assumes ord > 0 so padded lanes contribute 0**ord == 0.
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.sum(pow(tl.abs(x), ord))
    tl.store(Mid, mid)

259 

260 

@libentry()
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_2(Mid, Out, ord, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of the full-reduction general p-norm: single program sums the
    MID_SIZE partial sums and stores (total) ** (1 / ord)."""
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = pow(tl.sum(mid), 1 / ord)
    tl.store(Out, out)

270 

271 

def vector_norm(x, ord=2, dim=None, keepdim=False, dtype=None):
    """Compute the vector norm of ``x`` on the Ascend backend.

    Args:
        x: input tensor; must be float16, float32 or bfloat16.
        ord: norm order. 2, +inf, -inf and 0 dispatch to dedicated kernels;
            any other value uses the generic p-norm kernels.
        dim: dimension(s) to reduce — an int, a sequence of ints, or ``None``.
            ``None`` (or all dims) reduces the whole flattened tensor.
        keepdim: keep reduced dimensions with size 1 in the output.
        dtype: optional output dtype (a ``torch.dtype``); defaults to
            ``x.dtype``.

    Returns:
        Tensor holding the requested norm(s).

    Raises:
        NotImplementedError: if the effective dtype is unsupported.
    """
    logger.debug("GEMS_ASCEND VECTOR NORM")
    # BUGFIX: previously this called torch.dtype(dtype), but torch.dtype is
    # not instantiable, so every call with an explicit dtype raised TypeError.
    # A torch.dtype argument is now accepted as-is; the membership check below
    # still rejects unsupported values.
    if dtype is None:
        dtype = x.dtype
    if dtype not in [torch.float16, torch.float32, torch.bfloat16]:
        raise NotImplementedError(f"vector_norm not implemented for {dtype}")
    # Accept a bare int dim, as torch.linalg.vector_norm does.
    if isinstance(dim, int):
        dim = [dim]

    with torch_device_fn.device(x.device):
        if (not dim) or len(dim) == x.ndim:
            # Full reduction: flatten x and run a two-stage reduce
            # (per-block partials into `mid`, then a single-program combine).
            dim = list(range(x.ndim))
            shape = [1] * x.ndim
            x = dim_compress(x, dim)
            M = x.numel()

            # Stage-1 block size ~ sqrt(M) so both stages do comparable work,
            # capped to keep a tile within on-chip memory.
            MAX_BLOCK_SIZE = 32768
            BLOCK_SIZE = min(
                triton.next_power_of_2(math.ceil(math.sqrt(M))), MAX_BLOCK_SIZE
            )
            MID_SIZE = triton.cdiv(M, BLOCK_SIZE)
            BLOCK_MID = triton.next_power_of_2(MID_SIZE)
            # Sub-tile width for the kernels that chunk their slice.
            BLOCK_MID_SUB = 512 if BLOCK_MID >= 512 else 1
            mid = torch.empty([MID_SIZE], dtype=dtype, device=x.device)
            out = torch.empty(shape, dtype=dtype, device=x.device)
            if ord == 2:
                l2_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE, BLOCK_MID_SUB)
                l2_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID, BLOCK_MID_SUB)
            elif ord == float("inf"):
                max_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
                max_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
            elif ord == -float("inf"):
                min_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
                min_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
            elif ord == 0:
                l0_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
                l0_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
            else:
                # Generic p-norm ("l1" kernels handle arbitrary ord).
                l1_norm_kernel_1[(MID_SIZE,)](x, mid, ord, M, BLOCK_SIZE)
                l1_norm_kernel_2[(1,)](mid, out, ord, MID_SIZE, BLOCK_MID)
        else:
            # Partial reduction: move the reduced dims to the end so the
            # kernels can treat x as a contiguous (M, N) matrix, N being the
            # product of the reduced extents.
            shape = list(x.shape)
            dim = [d % x.ndim for d in dim]  # normalize negative dims
            x = dim_compress(x, dim)
            N = 1
            for i in dim:
                N *= shape[i]
                shape[i] = 1
            M = x.numel() // N
            out = torch.empty(shape, dtype=dtype, device=x.device)
            grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
            if ord == 2:
                l2_norm_kernel[grid](x, out, M, N)
            elif ord == float("inf"):
                max_norm_kernel[grid](x, out, M, N)
            elif ord == -float("inf"):
                min_norm_kernel[grid](x, out, M, N)
            elif ord == 0:
                l0_norm_kernel[grid](x, out, M, N)
            else:
                v_norm_kernel[grid](x, out, M, N, ord)
        if not keepdim:
            # NOTE: squeeze with a list of dims requires torch >= 2.0.
            out = out.squeeze(dim=dim)
        return out