# Coverage for src/flag_gems/ops/mm_streamk.py: 22% (234 statements)
# Report generated by coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils import libentry
from flag_gems.utils import triton_lang_extension as tle

# Module-level logger, one per module per the package's logging convention.
logger = logging.getLogger(__name__)
@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly less than a:
    # ceil(a / b) * b is the smallest multiple >= a; stepping back one
    # block of b lands strictly below a in every case.
    ceil_blocks = tl.cdiv(a, b)
    return (ceil_blocks - 1) * b
@triton.jit()
def swizzle_tile(
    tile_id,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # Map a linear tile id to (pid_m, pid_n) in group-swizzled order: tiles are
    # visited in bands of GROUP_M rows so that programs launched close together
    # reuse nearby rows of A, improving L2 hit rates.
    num_tiles_m = tl.cdiv(M, BLOCK_M)
    num_tiles_n = tl.cdiv(N, BLOCK_N)
    tiles_per_group = GROUP_M * num_tiles_n
    group = tile_id // tiles_per_group
    # The last band may hold fewer than GROUP_M rows.
    rows_in_group = tl.minimum(num_tiles_m - group * GROUP_M, GROUP_M)
    pid_m = group * GROUP_M + (tile_id % rows_in_group)
    pid_n = (tile_id % tiles_per_group) // rows_in_group
    return pid_m, pid_n
@triton.jit()
def linear_tile(
    tile_id,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # Plain linear mapping: consecutive tile ids walk along the N axis first,
    # then advance one tile-row. M, BLOCK_M and GROUP_M are accepted only to
    # keep the signature interchangeable with swizzle_tile.
    tiles_per_row = tl.cdiv(N, BLOCK_N)
    pid_m = tile_id // tiles_per_row
    pid_n = tile_id % tiles_per_row
    return pid_m, pid_n
@triton.jit(
    do_not_specialize=[
        "iters_per_pid",
        "iters_remaining",
        "iters_per_tile",
        "start_iter",
        "end_iter",
    ]
)
def mac_loop(
    A,
    B,
    C,
    P,
    M,
    N,
    K,
    locks,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    iters_per_pid,
    iters_remaining,
    iters_per_tile,
    start_iter,
    end_iter,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # Stream-K MAC loop: accumulate the K-iterations [start_iter, end_iter) of one
    # output tile. A program that does NOT own the tile's first segment spills its
    # partial accumulator into the scratch buffer P and raises its slot in `locks`;
    # the program owning the first segment spin-waits on those locks, folds the
    # partials into its accumulator, and writes the finished tile to C.
    # where are we in the grid
    pid = tle.program_id(0)
    tile_id = start_iter // iters_per_tile
    pid_m, pid_n = swizzle_tile(tile_id, M, N, BLOCK_M, BLOCK_N, GROUP_M)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # When the innermost stride is 1, tell the compiler the indices are contiguous
    # so the loads can be vectorized.
    if stride_am == 1:
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if stride_bk == 1:
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N
    # pointers
    A_base = A + ram[:, None] * stride_am
    B_base = B + rbn[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    if end_iter % iters_per_tile != 0:
        # This segment stops before the tile's last K-block, so every load is a
        # full block and no K-boundary masking is needed.
        for current_iter in range(start_iter, end_iter):
            k_offset_in_tile = (current_iter % iters_per_tile) * BLOCK_K
            a = tl.load(A_base + (k_offset_in_tile + rk[None, :]) * stride_ak)
            b = tl.load(B_base + (k_offset_in_tile + rk[:, None]) * stride_bk)
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
    else:
        # This segment includes the tile's final K-block: peel the last iteration
        # so only it pays for the K-boundary mask.
        prev_multiple = prev_multiple_of(K, BLOCK_K)
        for current_iter in range(start_iter, end_iter - 1):
            k_offset_in_tile = (current_iter % iters_per_tile) * BLOCK_K
            a = tl.load(A_base + (k_offset_in_tile + rk[None, :]) * stride_ak)
            b = tl.load(B_base + (k_offset_in_tile + rk[:, None]) * stride_bk)
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
        # handle the last iter
        rk = prev_multiple + tl.arange(0, BLOCK_K)
        mask_k = rk < K
        a = tl.load(A_base + rk[None, :] * stride_ak, mask=mask_k[None, :])
        b = tl.load(B_base + rk[:, None] * stride_bk, mask=mask_k[:, None])
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
    rm1 = tl.arange(0, BLOCK_M)
    rn1 = tl.arange(0, BLOCK_N)
    # the first situation: not the starting part. only need to store the data on P
    if start_iter % iters_per_tile != 0:
        P_ptr = P + pid * BLOCK_M * BLOCK_N + (rm1[:, None] * BLOCK_N + rn1[None, :])
        tl.store(P_ptr, acc, cache_modifier=".cg")
        # tl.debug_barrier()
        # Publish the partial result so the tile owner can pick it up.
        tl.atomic_xchg(locks + pid, 1)
    else:  # the first part of certain grids: should read and merge the partials
        next_pid = pid + 1
        stop_loading_iter = start_iter + iters_per_tile
        end = end_iter
        # Spin until each successor program that worked on this tile has published
        # its partial sum, then fold it into acc.
        while end < stop_loading_iter:
            while tl.atomic_cas(locks + next_pid, 1, 1) != 1:
                pass
            P_ptr = (
                P
                + next_pid * BLOCK_M * BLOCK_N
                + (rm1[:, None] * BLOCK_N + rn1[None, :])
            )
            acc += tl.load(P_ptr, cache_modifier=".cg")
            # Each pid owns iters_per_pid iterations, plus one extra for the first
            # `iters_remaining` pids.
            end += iters_per_pid + (next_pid < iters_remaining)
            next_pid += 1
        # acc = acc.to(C.dtype.element_ty) #
        C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        tl.store(C_, acc, mask=mask)
@libentry()
@triton.jit(
    do_not_specialize=[
        "iters_per_pid",
        "iters_remaining",
        "iters_per_tile",
    ],
)
def first_wave(
    A,
    B,
    C,
    M,
    N,
    K,
    locks,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    iters_per_pid,
    iters_remaining,
    iters_per_tile,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    # Stream-K "first wave": one program per SM. The K-iterations of the
    # cooperative tiles are split evenly across programs, so a program may finish
    # one tile's tail and then start the next tile's head. The program that
    # computes a tile's LAST K-segment stores its partial sum to C and raises the
    # tile's lock; programs holding earlier segments spin on that lock and then
    # merge into C with atomic adds.
    pid = tle.program_id(0)  # pid range from 0 to sm_count
    # The first `iters_remaining` programs each take one extra iteration.
    start_iter = pid * iters_per_pid + tl.minimum(pid, iters_remaining)
    last_iter = (pid + 1) * iters_per_pid + tl.minimum(pid + 1, iters_remaining)
    while start_iter < last_iter:
        iter_offset_in_tile = start_iter % iters_per_tile
        # Iterate over the K axis. Recalculate end_iter as M/N may change during the iteration.
        end_iter = tl.minimum(
            start_iter + (iters_per_tile - iter_offset_in_tile), last_iter
        )
        tile_id = start_iter // iters_per_tile
        pid_m, pid_n = swizzle_tile(tile_id, M, N, BLOCK_M, BLOCK_N, GROUP_M)
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        rk = tl.arange(0, BLOCK_K)
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
        # Base pointers are pre-advanced to this segment's first K-block.
        A_base = (
            A
            + ram[:, None] * stride_am
            + rk[None, :] * stride_ak
            + BLOCK_K * stride_ak * iter_offset_in_tile
        )
        B_base = (
            B
            + rk[:, None] * stride_bk
            + rbn[None, :] * stride_bn
            + BLOCK_K * stride_bk * iter_offset_in_tile
        )
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for current_iter in range(start_iter, end_iter):
            if EVEN_K:
                a = tl.load(A_base)
                b = tl.load(B_base)
            else:
                # K not divisible by BLOCK_K: mask out the K tail.
                k_offset_in_tile = (current_iter % iters_per_tile) * BLOCK_K
                k_mask = (k_offset_in_tile + rk) < K
                a = tl.load(A_base, mask=k_mask[None, :], other=0.0)
                b = tl.load(B_base, mask=k_mask[:, None], other=0.0)
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            A_base += BLOCK_K * stride_ak
            B_base += BLOCK_K * stride_bk
        # last iteration of the tile always happens before its start on another SM
        if end_iter % iters_per_tile == 0:
            # Owner of the tile's final segment: initialize C with a plain store,
            # then unlock so earlier segments may merge on top.
            C_ptr = C + (
                rm[:, None] * stride_cm + rn[None, :] * stride_cn
            )  # compute inside the if/else to avoid spilling!
            mask = (rm < M)[:, None] & (rn < N)[None, :]
            tl.store(C_ptr, acc, mask=mask)
            if iter_offset_in_tile != 0:  # only if tile has been partially processed
                tl.atomic_xchg(locks + tile_id, 1)
        else:
            # Earlier segment: wait until the final-segment owner has stored C,
            # then accumulate into it atomically.
            while tl.atomic_cas(locks + tile_id, 1, 1) != 1:
                pass
            C_ptr = C + (
                rm[:, None] * stride_cm + rn[None, :] * stride_cn
            )  # compute inside the if/else to avoid spilling!
            mask = (rm < M)[:, None] & (rn < N)[None, :]
            tl.atomic_add(C_ptr, acc, mask=mask, sem="relaxed")
        # next round
        start_iter = end_iter
@libentry()
@triton.jit(
    do_not_specialize=[
        "iters_per_pid",
        "iters_remaining",
        "iters_per_tile",
    ],
)
def first_wave_for_bf16(
    A,
    B,
    C,
    P,
    M,
    N,
    K,
    locks,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    iters_per_pid,
    iters_remaining,
    iters_per_tile,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    # Stream-K first wave for bf16 outputs. Same work partitioning as
    # `first_wave`, but partial tiles are merged through the fp32 scratch buffer P
    # (one BLOCK_M x BLOCK_N slab per program) instead of atomic adds on C, and
    # `locks` is indexed by program id (the producer) rather than by tile id.
    pid = tle.program_id(0)  # pid range from 0 to sm_count
    # The first `iters_remaining` programs each take one extra iteration.
    start_iter = pid * iters_per_pid + tl.minimum(pid, iters_remaining)
    last_iter = (pid + 1) * iters_per_pid + tl.minimum(pid + 1, iters_remaining)
    while start_iter < last_iter:
        iter_offset_in_tile = start_iter % iters_per_tile
        # Iterate over the K axis. Recalculate end_iter as M/N may change during the iteration.
        end_iter = tl.minimum(
            start_iter + (iters_per_tile - iter_offset_in_tile), last_iter
        )
        tile_id = start_iter // iters_per_tile
        pid_m, pid_n = swizzle_tile(tile_id, M, N, BLOCK_M, BLOCK_N, GROUP_M)
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        rk = tl.arange(0, BLOCK_K)
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
        # Base pointers are pre-advanced to this segment's first K-block.
        A_base = (
            A
            + ram[:, None] * stride_am
            + rk[None, :] * stride_ak
            + BLOCK_K * stride_ak * iter_offset_in_tile
        )
        B_base = (
            B
            + rk[:, None] * stride_bk
            + rbn[None, :] * stride_bn
            + BLOCK_K * stride_bk * iter_offset_in_tile
        )
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for current_iter in range(start_iter, end_iter):
            if EVEN_K:
                a = tl.load(A_base)
                b = tl.load(B_base)
            else:
                # K not divisible by BLOCK_K: mask out the K tail.
                k_offset_in_tile = (current_iter % iters_per_tile) * BLOCK_K
                k_mask = (k_offset_in_tile + rk) < K
                a = tl.load(A_base, mask=k_mask[None, :], other=0.0)
                b = tl.load(B_base, mask=k_mask[:, None], other=0.0)
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            A_base += BLOCK_K * stride_ak
            B_base += BLOCK_K * stride_bk
        rm1 = tl.arange(0, BLOCK_M)
        rn1 = tl.arange(0, BLOCK_N)
        # the first situation: not the starting part. only need to store the data on P
        if start_iter % iters_per_tile != 0:
            P_ptr = (
                P + pid * BLOCK_M * BLOCK_N + (rm1[:, None] * BLOCK_N + rn1[None, :])
            )
            tl.store(P_ptr, acc, cache_modifier=".cg")
            # tl.debug_barrier()
            # Publish the partial result so the tile owner can pick it up.
            tl.atomic_xchg(locks + pid, 1)
        else:  # the first part of certain grids: should read and merge the partials
            next_pid = pid + 1
            stop_loading_iter = start_iter + iters_per_tile
            end = end_iter
            # Spin until each successor program that worked on this tile has
            # published its partial sum, then fold it into acc.
            while end < stop_loading_iter:
                while tl.atomic_cas(locks + next_pid, 1, 1) != 1:
                    pass
                P_ptr = (
                    P
                    + next_pid * BLOCK_M * BLOCK_N
                    + (rm1[:, None] * BLOCK_N + rn1[None, :])
                )
                acc += tl.load(P_ptr, cache_modifier=".cg")
                # Each pid owns iters_per_pid iterations, plus one extra for the
                # first `iters_remaining` pids.
                end += iters_per_pid + (next_pid < iters_remaining)
                next_pid += 1
            # acc = acc.to(C.dtype.element_ty) #
            C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
            mask = (rm < M)[:, None] & (rn < N)[None, :]
            tl.store(C_, acc, mask=mask)
        start_iter = end_iter
@libentry()
@triton.jit
def classic_mm(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    total_tiles_streamk,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # Classic data-parallel GEMM for the tiles NOT handled by the cooperative
    # stream-k wave: each program computes one whole output tile over the full K.
    # first wave has done more tiles than there are SMs, we adjust pid
    tile_id = tle.program_id(0) + total_tiles_streamk
    pid_m, pid_n = swizzle_tile(tile_id, M, N, BLOCK_M, BLOCK_N, GROUP_M)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # pointers
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    prev_multiple = prev_multiple_of(K, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Main loop covers the full K-blocks; the final (possibly partial) block is
    # peeled below so only it pays for the boundary mask.
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = start_k + tl.arange(0, BLOCK_K)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        # Mixed-precision inputs: promote both operands to the output dtype
        # before the dot so tl.dot sees matching types.
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
    # loop peeling
    rk = prev_multiple + tl.arange(0, BLOCK_K)
    mask_k = rk < K
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak), mask=mask_k[None, :]
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn), mask=mask_k[:, None]
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    tl.store(C, acc, mask=mask)
def streamk_mm(a, b, c, M, N, K, sm_count=108):
    # Stream-K matmul: writes a @ b into the preallocated output `c` and returns it.
    # Full waves of output tiles run as classic data-parallel GEMM; the leftover
    # (partial) wave is computed cooperatively across all SMs by a first_wave
    # kernel so no SM sits idle.
    # NOTE(review): the default sm_count=108 matches an A100; callers should pass
    # the actual device SM count — TODO confirm callers do.
    logger.debug(
        "GEMS MM, [mm scenario]: streamk, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    # TODO: change the hard code to tuning config
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128
    num_stages = 3
    num_warps = 8
    GROUP_M = 8
    number_blocks_m = triton.cdiv(M, BLOCK_M)
    number_blocks_n = triton.cdiv(N, BLOCK_N)
    total_tiles = number_blocks_m * number_blocks_n
    iters_per_tile = triton.cdiv(K, BLOCK_K)
    tiles_per_wave = sm_count
    # Tiles in the last, partially-filled wave become the cooperative tiles.
    number_cooperative_tiles = total_tiles % tiles_per_wave
    number_other_tiles = total_tiles - number_cooperative_tiles
    # Heuristic rebalancing: a small remainder wave (<50% occupancy) borrows one
    # full wave of tiles to give the stream-k kernel enough work; a nearly-full
    # remainder (>80%) is efficient as a plain classic wave, so skip stream-k.
    if number_other_tiles > 0 and number_cooperative_tiles < sm_count * 0.5:
        number_cooperative_tiles = number_cooperative_tiles + tiles_per_wave
    elif number_other_tiles > 0 and number_cooperative_tiles > sm_count * 0.8:
        number_cooperative_tiles = 0
    if number_cooperative_tiles > 0:
        # mini wave
        total_iters_streamk = number_cooperative_tiles * iters_per_tile
        iters_per_pid = total_iters_streamk // tiles_per_wave
        iters_remaining = total_iters_streamk % tiles_per_wave
        even_k = K % BLOCK_K == 0
        if a.dtype == torch.bfloat16:
            # bf16 path: merge partial tiles through the fp32 scratch buffer P
            # (one BLOCK_M x BLOCK_N slab per program); locks are per program id.
            locks = torch.zeros((tiles_per_wave,), device=a.device, dtype=torch.int32)
            P = torch.empty(
                (tiles_per_wave, BLOCK_M, BLOCK_N), device=a.device, dtype=torch.float32
            )
            # with torch_device_fn.device(a.device):
            first_wave_for_bf16[(tiles_per_wave,)](
                a,
                b,
                c,
                P,
                M,
                N,
                K,
                locks,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                iters_per_pid=iters_per_pid,
                iters_remaining=iters_remaining,
                iters_per_tile=iters_per_tile,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
                BLOCK_K=BLOCK_K,
                GROUP_M=GROUP_M,
                EVEN_K=even_k,
                num_stages=num_stages,
                num_warps=num_warps,
            )
            # logger.debug(f"{k1.n_regs} registers used, {k1.n_spills} spills")
            # logger.debug(f"shared memory: {k1.metadata.shared} bytes")
        else:
            # Other dtypes merge directly on C with atomic adds; one lock per
            # cooperative tile.
            locks = torch.zeros(
                (number_cooperative_tiles,), device=a.device, dtype=torch.int32
            )
            first_wave[(tiles_per_wave,)](
                a,
                b,
                c,
                M,
                N,
                K,
                locks,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                iters_per_pid=iters_per_pid,
                iters_remaining=iters_remaining,
                iters_per_tile=iters_per_tile,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
                BLOCK_K=BLOCK_K,
                GROUP_M=GROUP_M,
                EVEN_K=even_k,
                num_stages=num_stages,
                num_warps=num_warps,
            )
            # logger.debug(f"{k1.n_regs} registers used, {k1.n_spills} spills")
            # logger.debug(f"shared memory: {k1.metadata.shared} bytes")
    # All remaining tiles (every full wave) are plain data-parallel GEMM tiles.
    classic_mm[(total_tiles - number_cooperative_tiles,)](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        total_tiles_streamk=number_cooperative_tiles,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        GROUP_M=GROUP_M,
        num_stages=num_stages,
        num_warps=num_warps,
    )
    # logger.debug(f"{k2.n_regs} registers used, {k2.n_spills} spills")
    # logger.debug(f"shared memory: {k2.metadata.shared} bytes")
    return c