Coverage for src/flag_gems/ops/vector_norm.py: 44%

259 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry, tl_extra_shim 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Device-side pow shim (backend-specific implementation of elementwise power).
pow = tl_extra_shim.pow
# Module-level logger, named after this module per project convention.
logger = logging.getLogger(__name__)

15 

16 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def l2_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise L2 norm: X is an (M, N) row-major matrix, one norm per row.

    Each program handles BLOCK_M rows, accumulating sum(x*x) over the N
    columns in fp32 tiles of width BLOCK_N, then stores sqrt(sum) per row.
    """
    # int64 row indices: pid * N could overflow int32 for large tensors.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        # other=0.0 so out-of-bounds lanes contribute nothing to the sum.
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _sum += a * a
    sum = tl.sum(_sum, axis=1)

    out = tl.sqrt(sum)[:, None]
    tl.store(Out, out, row_mask)

38 

39 

@libentry()
@triton.jit
def l2_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    """Stage 1 of full-reduction L2 norm: per-block partial sum of squares.

    Program i reduces X[i*BLOCK_SIZE : (i+1)*BLOCK_SIZE) to a single fp32
    sum(x*x) and writes it to Mid[i]; stage 2 combines Mid and takes sqrt.
    """
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # other=0.0 keeps out-of-range lanes neutral for the additive reduction.
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.sum(x * x)
    tl.store(Mid, mid)

52 

53 

@libentry()
@triton.jit
def l2_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of full-reduction L2 norm: combine partial sums, take sqrt.

    Launched with a single program; BLOCK_MID must be >= MID_SIZE so the
    whole Mid buffer is read in one shot.
    """
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.sqrt(tl.sum(mid))
    tl.store(Out, out)

63 

64 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def max_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise inf-norm: max(|x|) along the last dim of an (M, N) matrix.

    Same tiling scheme as l2_norm_kernel; accumulator starts at 0, which is
    the identity for max over absolute values.
    """
    # int64 row indices: pid * N could overflow int32 for large tensors.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    _max = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        # other=0.0: |0| never beats a real |x| in the running max.
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _max = tl.maximum(tl.abs(a), _max)

    max = tl.max(_max, axis=1)
    out = max[:, None]
    tl.store(Out, out, row_mask)

86 

87 

@libentry()
@triton.jit
def max_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    """Stage 1 of full-reduction inf-norm: per-block max of |x| into Mid."""
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # other=0.0 is neutral for max over absolute values.
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.max(tl.abs(x))
    tl.store(Mid, mid)

100 

101 

@libentry()
@triton.jit
def max_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of full-reduction inf-norm: max over the partial maxima.

    Single-program launch; BLOCK_MID must be >= MID_SIZE.
    """
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.max(mid)
    tl.store(Out, out)

111 

112 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def min_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise -inf "norm": min(|x|) along the last dim of an (M, N) matrix.

    Mirror image of max_norm_kernel: the accumulator and masked loads use
    +inf, the identity element for min.
    """
    # int64 row indices: pid * N could overflow int32 for large tensors.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    _min = tl.full([BLOCK_M, BLOCK_N], value=float("inf"), dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        # other=+inf so out-of-bounds lanes never win the running min.
        a = tl.load(X + cols, mask, other=float("inf")).to(tl.float32)
        _min = tl.minimum(tl.abs(a), _min)

    min = tl.min(_min, axis=1)
    out = min[:, None]
    tl.store(Out, out, row_mask)

134 

135 

@libentry()
@triton.jit
def min_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    """Stage 1 of full-reduction -inf norm: per-block min of |x| into Mid."""
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # other=+inf is neutral for min over absolute values.
    x = tl.load(X, mask=mask, other=float("inf")).to(tl.float32)
    mid = tl.min(tl.abs(x))
    tl.store(Mid, mid)

148 

149 

@libentry()
@triton.jit
def min_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of full-reduction -inf norm: min over the partial minima.

    Single-program launch; BLOCK_MID must be >= MID_SIZE.
    """
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=float("inf")).to(tl.float32)
    out = tl.min(mid)
    tl.store(Out, out)

159 

160 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def l0_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise L0 "norm": count of non-zero elements along the last dim.

    X is treated as an (M, N) row-major matrix; one fp32 count per row is
    written to Out, guarded by row_mask.
    """
    # Cast the program id to int64 before forming pid * N, which can
    # overflow int32 for large tensors; this also matches the sibling
    # row-wise kernels (l2/max/min/v_norm), which all do the same cast.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        # other=0: masked-out lanes read as zero and are not counted.
        a = tl.load(X + cols, mask, other=0).to(tl.float32)
        _sum += tl.where(a != 0, 1, 0)
    sum = tl.sum(_sum, axis=1)
    out = sum[:, None]
    tl.store(Out, out, row_mask)

181 

182 

@libentry()
@triton.jit
def l0_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    """Stage 1 of full-reduction L0 norm: per-block non-zero count into Mid."""
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # other=0.0: out-of-range lanes read as zero and are not counted.
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    cnt = (x != 0).to(tl.float32)
    mid = tl.sum(cnt)
    tl.store(Mid, mid)

196 

197 

@libentry()
@triton.jit
def l0_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of full-reduction L0 norm: sum the per-block counts.

    Single-program launch; BLOCK_MID must be >= MID_SIZE.
    """
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.sum(mid)
    tl.store(Out, out)

207 

208 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit(do_not_specialize=["ord"])
def v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise general p-norm: (sum |x|**ord) ** (1/ord) per row.

    Handles any finite non-zero ord (the 0/2/+-inf cases have dedicated
    kernels). ord is marked do_not_specialize so each new value does not
    trigger a recompile.
    """
    # int64 row indices: pid * N could overflow int32 for large tensors.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        # other=0.0: 0**ord contributes nothing for positive ord.
        # NOTE(review): for negative ord, pow(0, ord) of the padding lanes
        # would be inf — presumably callers only pass positive ord; verify.
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _sum += pow(tl.abs(a), ord)
    sum = tl.sum(_sum, axis=1)
    out = pow(sum, 1 / ord)[:, None]
    tl.store(Out, out, row_mask)

229 

230 

@libentry()
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_1(X, Mid, ord, M, BLOCK_SIZE: tl.constexpr):
    """Stage 1 of full-reduction p-norm: per-block sum of |x|**ord into Mid.

    Despite the name, this handles any finite non-zero ord, not just 1
    (it is the fallback branch in vector_norm's full-reduction path).
    """
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # other=0.0: padding lanes add 0**ord (0 for positive ord).
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.sum(pow(tl.abs(x), ord))
    tl.store(Mid, mid)

243 

244 

@libentry()
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_2(Mid, Out, ord, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of full-reduction p-norm: (sum of partials) ** (1/ord).

    Single-program launch; BLOCK_MID must be >= MID_SIZE. Like stage 1,
    this serves any finite non-zero ord, not only ord=1.
    """
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = pow(tl.sum(mid), 1 / ord)
    tl.store(Out, out)

254 

255 

def vector_norm(x, ord=2, dim=None, keepdim=False, dtype=None):
    """Compute the vector norm of `x`, dispatching to Triton kernels.

    Mirrors torch.linalg.vector_norm: ord selects the norm (2, +-inf, 0,
    or a generic p), dim the reduced dimensions, keepdim whether reduced
    dims are kept as size-1, dtype the computation/output dtype.

    Two paths:
      * full reduction (dim is None/empty or covers all dims): a two-stage
        kernel pair reduces the flattened tensor via a Mid buffer;
      * partial reduction: dim_compress moves the reduced dims to the end
        so a row-wise (M, N) kernel can reduce the last axis.

    Raises NotImplementedError for dtypes outside fp16/fp32/bf16.
    """
    logger.debug("GEMS VECTOR NORM")
    if dtype is not None:
        if isinstance(dtype, str):
            # Accept dtype given by name, e.g. "float32".
            dtype = getattr(torch, dtype)
        elif not isinstance(dtype, torch.dtype):
            # NOTE(review): any non-str, non-dtype value is silently coerced
            # to float32 rather than rejected — confirm this is intended.
            dtype = torch.float32
    else:
        dtype = x.dtype
    if dtype not in [torch.float16, torch.float32, torch.bfloat16]:
        raise NotImplementedError(f"vector_norm not implemented for {dtype}")

    with torch_device_fn.device(x.device):
        # NOTE(review): this assumes dim is None or a list/tuple of ints;
        # a bare int dim (which torch accepts) would break len(dim) — and
        # dim=0 would be truthy-false here. Presumably the registration
        # layer normalizes dim before calling; verify.
        if (not dim) or len(dim) == x.ndim:
            # --- Full reduction over every element of x. ---
            dim = list(range(x.ndim))
            shape = [1] * x.ndim
            x = dim_compress(x, dim)
            M = x.numel()
            # ~sqrt(M)-sized blocks balance stage-1 parallelism against
            # the stage-2 single-program combine over MID_SIZE partials.
            BLOCK_SIZE = triton.next_power_of_2(math.ceil(math.sqrt(M)))
            MID_SIZE = triton.cdiv(M, BLOCK_SIZE)
            BLOCK_MID = triton.next_power_of_2(MID_SIZE)

            # NOTE(review): partial results are stored in the output dtype;
            # for fp16/bf16 this rounds the fp32 partials computed in-kernel
            # — confirm the precision loss between stages is acceptable.
            mid = torch.empty([MID_SIZE], dtype=dtype, device=x.device)
            out = torch.empty(shape, dtype=dtype, device=x.device)
            if ord == 2:
                l2_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
                l2_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
            elif ord == float("inf"):
                max_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
                max_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
            elif ord == -float("inf"):
                min_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
                min_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
            elif ord == 0:
                l0_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
                l0_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
            else:
                # Generic p-norm fallback (kernels named l1_* historically).
                l1_norm_kernel_1[(MID_SIZE,)](x, mid, ord, M, BLOCK_SIZE)
                l1_norm_kernel_2[(1,)](mid, out, ord, MID_SIZE, BLOCK_MID)
        else:
            # --- Partial reduction over a subset of dims. ---
            shape = list(x.shape)
            # Normalize negative dims.
            dim = [d % x.ndim for d in dim]
            # Move reduced dims to the end so the kernels see an (M, N)
            # matrix with N = product of reduced extents.
            x = dim_compress(x, dim)
            N = 1
            for i in dim:
                N *= shape[i]
                shape[i] = 1
            M = x.numel() // N
            out = torch.empty(shape, dtype=dtype, device=x.device)
            # One program per BLOCK_M rows; BLOCK_M chosen by autotune.
            grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
            if ord == 2:
                l2_norm_kernel[grid](x, out, M, N)
            elif ord == float("inf"):
                max_norm_kernel[grid](x, out, M, N)
            elif ord == -float("inf"):
                min_norm_kernel[grid](x, out, M, N)
            elif ord == 0:
                l0_norm_kernel[grid](x, out, M, N)
            else:
                v_norm_kernel[grid](x, out, M, N, ord)
        if not keepdim:
            # Drop the size-1 reduced dims to match keepdim=False semantics.
            out = out.squeeze(dim=dim)
    return out