Coverage for src/flag_gems/runtime/backend/_ascend/ops/addmm.py: 0%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import broadcastable_to, libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module-level logger; the f-string resolves to "flag_gems.runtime._ascend.ops.addmm"
# (the literal prefix plus the last segment of this module's dotted name).
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

13 

14 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C = alpha * (A @ B) + beta * I.

    A is (M, K), B is (K, N), I (the bias) and C are (M, N); each tensor is
    addressed through its explicit row/column strides. The 2D launch grid maps
    program id 0 to the M tile index and program id 1 to the N tile index.
    Accumulation is in float32 (tl.dot with allow_tf32=False); the result is
    cast back to the bias tensor's dtype before the masked store.
    """
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)

    # Row/column offsets of this program's output tile, plus the K-tile offsets.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # BUGFIX: the loads must mask the M/N edges as well as the K tail.
        # The grid is cdiv(M, BLOCK_SIZE_M) x cdiv(N, BLOCK_SIZE_N), so edge
        # tiles exist whenever M or N is not a multiple of the block size;
        # the original masks only covered offs_k, issuing out-of-bounds
        # global loads on those tiles (the store was masked, the loads not).
        # Out-of-range lanes read 0.0 and contribute nothing to the dot.
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance both tile pointers one K-block along the reduction axis.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    # Bias is loaded with the same in-bounds mask used for the store.
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)
    bias1 = bias.to(accumulator.dtype)
    # Epilogue: scale the product, add the scaled bias, cast to bias dtype.
    accumulator = accumulator * alpha + bias1 * beta
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)

78 

79 

def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Return ``alpha * (mat1 @ mat2) + beta * bias`` via the tiled Triton kernel.

    Mirrors ``torch.addmm``: ``mat1`` is (M, K), ``mat2`` is (K, N), and
    ``bias`` must be broadcastable to (M, N). A new (M, N) tensor with
    ``mat1``'s dtype on ``mat1``'s device is returned.
    """
    logger.debug("GEMS_ASCEND ADDMM")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    # Densify the operands so the strides handed to the kernel describe
    # plain row-major layouts; bias is materialized at full output shape.
    mat1, mat2 = mat1.contiguous(), mat2.contiguous()
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    bias = bias.broadcast_to(out.shape).contiguous()

    def grid(meta):
        # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile; the block
        # sizes come from the autotuner's selected config.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
        )
    return out