Coverage for src/flag_gems/runtime/backend/_ascend/ops/bmm.py: 0%

62 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module logger, namespaced under the Ascend ops hierarchy using the last
# dotted component of this module's name.
logger = logging.getLogger("flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1])

13 

14 

# NOTE(review): a truncated "# avoid" comment stood here; its intent is unknown.
#
# Batched GEMM kernel: O[b] = A[b] @ B[b].
# Each batch slice is indexed as a dense row-major matrix (A: M x K,
# B: K x N, O: M x N), so the caller must pass contiguous tensors.
# Launched on a 3-D grid: axis 0 tiles M, axis 1 tiles N, axis 2 is the batch.
# TILE_*, GROUP_M and DIVISIBLE_* are not passed by the caller; they come from
# the autotune configs / heuristics below. DIVISIBLE_* are accepted but unused:
# every load and store is bounds-masked unconditionally.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K"],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
):
    # Move the base pointers to this program's batch slice.
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)
    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # Reorder CTAs: GROUP_M consecutive M-tiles sweep the N-tiles together.
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last group holds only gridx % GROUP_M M-tiles when it is partial.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # The last M/N tile may extend past the matrix edge; those rows/columns
    # must be neither loaded (they lie outside this batch slice, and past the
    # allocation for the last batch) nor stored.
    mask_m = offs_m < M
    mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for i in range(num_iters):
        k_remaining = K - i * TILE_K
        mask_a = mask_m[:, None] & (offs_k[None, :] < k_remaining)
        mask_b = (offs_k[:, None] < k_remaining) & mask_n[None, :]
        # other=0.0: masked-off lanes must contribute exactly zero to the dot
        # product; without it their values are undefined.
        a = tl.load(a_ptrs, mask=mask_a, other=0.0)
        b = tl.load(b_ptrs, mask=mask_b, other=0.0)

        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        o += tl.dot(a, b, allow_tf32=False)

    mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)

89 

90 

def bmm(A, B):
    """Batched matrix multiply on the Ascend backend: out[b] = A[b] @ B[b].

    Args:
        A: 3-D tensor of shape (batch, M, K).
        B: 3-D tensor of shape (batch, K, N) with the same dtype as ``A``.

    Returns:
        A new tensor of shape (batch, M, N) with ``A``'s dtype and device.

    Raises:
        ValueError: if ``A`` or ``B`` is not 3-D (raised by shape unpacking).
        RuntimeError: if the batch sizes, inner dimensions or dtypes of
            ``A`` and ``B`` disagree.
    """
    logger.debug("GEMS_ASCEND BMM")
    batch, M, K = A.shape
    batch_b, K_b, N = B.shape
    # The kernel indexes B with A's batch size and K as strides, so a
    # mismatch would silently read the wrong (or out-of-bounds) memory.
    if batch_b != batch or K_b != K:
        raise RuntimeError(
            f"bmm: expected B of shape ({batch}, {K}, N) to match A of shape "
            f"{tuple(A.shape)}, but got {tuple(B.shape)}"
        )
    if A.dtype != B.dtype:
        raise RuntimeError(
            f"bmm: expected A and B to have the same dtype, "
            f"but got {A.dtype} and {B.dtype}"
        )

    out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)
    # Nothing to compute: avoid launching a zero-size grid.
    if out.numel() == 0:
        return out
    # An empty inner dimension sums over nothing, so the product is all zeros.
    if K == 0:
        return out.zero_()

    # The kernel assumes dense row-major batch slices.
    A = A.contiguous()
    B = B.contiguous()

    grid_fn = lambda meta: (
        triton.cdiv(meta["M"], meta["TILE_M"]),
        triton.cdiv(meta["N"], meta["TILE_N"]),
        batch,
    )

    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](A, B, out, M, N, K)
    return out