Coverage for src/flag_gems/runtime/backend/_ascend/ops/addmm.py: 0%
51 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import broadcastable_to, libentry
10from flag_gems.utils import triton_lang_extension as tle
# Module-scoped logger; the name resolves to "flag_gems.runtime._ascend.ops.addmm"
# so per-op log output can be filtered under the flag_gems logger hierarchy.
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.autotune(
    # Block-size configs come from the runtime's tuned-config table for "addmm";
    # autotune re-selects whenever the (M, N, K) problem shape changes.
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Tiled GEMM with fused bias: c = alpha * (a @ b) + beta * i.

    Each program instance on the 2D launch grid computes one
    (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile, accumulating partial products
    over the K dimension in float32 (tf32 disabled on the dot).
    a is (M, K), b is (K, N), i (the bias) and c are (M, N), all addressed
    through their explicit strides.
    """
    # Grid position of this tile: axis 0 indexes M-blocks, axis 1 N-blocks.
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)

    # Element offsets of this tile's rows (M), columns (N), and the K window.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Accumulate in float32 regardless of the input dtype for precision.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # NOTE(review): these load masks cover only the K tail; rows >= M of a
        # and columns >= N of b are loaded unmasked. Presumably the grid/tile
        # sizes or allocation padding make that safe on this backend — confirm.
        a = tl.load(
            a_ptrs,
            mask=(offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance both tile pointers one K-block along the reduction axis.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Output/bias addressing; the M/N bounds mask guards both the bias load
    # and the final store.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)
    # Scale in float32, then cast the result back to the bias dtype for storage.
    bias1 = bias.to(accumulator.dtype)
    accumulator = accumulator * alpha + bias1 * beta
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)
def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` with a Triton kernel.

    Mirrors ``torch.addmm``: ``mat1`` is (M, K), ``mat2`` is (K, N), and
    ``bias`` must be broadcastable to (M, N). Returns a new (M, N) tensor
    on ``mat1``'s device with ``mat1``'s dtype.
    """
    logger.debug("GEMS_ASCEND ADDMM")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    N = mat2.shape[1]

    # The kernel assumes densely packed operands; bias is materialized at the
    # full output shape so its strides address every (row, col) directly.
    mat1, mat2 = mat1.contiguous(), mat2.contiguous()
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    bias = bias.broadcast_to(out.shape).contiguous()

    def grid(meta):
        # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
        )
    return out