Coverage for src/flag_gems/runtime/backend/_cambricon/fused/outer.py: 0%
75 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import logging
3import torch
4import triton
5from triton import language as tl
7from flag_gems.utils import libentry
9from ..ops import mv
10from ..utils import TOTAL_CORE_NUM
# Module-level logger named after this module, per the standard logging convention.
logger = logging.getLogger(__name__)
15# The outer kernel requires 3 parameters to determine the splitting method,
16# but during actual tuning, you only need to determine the total size of the split blocks.
17# Based on the second input length N and the total size of the split blocks,
18# the 3 parameters that determine the splitting method can be calculated.
19# Therefore, the conversion between these two is achieved through early_config_prune.
def early_config_prune(configs, named_args, **kwargs):
    """Rewrite each autotune config's total ``tile_size`` into the
    (BLOCK_M, BLOCK_N, NEED_LOOP_N) triple that ``outer_kernel`` expects.

    BLOCK_N is capped at N so a single tile covers the whole second
    operand whenever it fits; BLOCK_M absorbs the remaining budget.
    NEED_LOOP_N records whether the kernel must still tile over N.
    """
    N = kwargs["N"] if "N" in kwargs else named_args["N"]

    pruned = []
    for cfg in configs:
        budget = cfg.kwargs["tile_size"]
        block_n = min(budget, N)
        block_m = triton.cdiv(budget, block_n)
        pruned.append(
            triton.Config(
                {"BLOCK_M": block_m, "BLOCK_N": block_n, "NEED_LOOP_N": block_n < N},
                num_stages=cfg.num_stages,
                num_warps=cfg.num_warps,
            )
        )
    return pruned
@libentry()
@triton.autotune(
    configs=[
        # Only the total tile budget is tuned here; early_config_prune
        # converts each tile_size into concrete BLOCK_M/BLOCK_N/NEED_LOOP_N.
        triton.Config({"tile_size": 1024}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 2048}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 4096}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 8192}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 16384}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 21760}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 32768}, num_stages=3, num_warps=1),
    ],
    key=["M", "N"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def outer_kernel(
    lhs,  # pointer to the first 1-D operand, length M
    rhs,  # pointer to the second 1-D operand, length N
    res,  # pointer to the M x N (row-major, contiguous) output
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NEED_LOOP_N: tl.constexpr,  # True when BLOCK_N < N, i.e. N must be tiled too
):
    # Each program strides over the flattened (M-tile, N-tile) task space,
    # computing res[m, n] = lhs[m] * rhs[n] for its tiles.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    m_tasks_num = tl.cdiv(M, BLOCK_M)
    n_tasks_num = tl.cdiv(N, BLOCK_N)
    total_tasks_num = m_tasks_num * n_tasks_num
    if NEED_LOOP_N:
        # General case: tile over both M and N; rhs is reloaded per tile.
        for task_id in range(pid, total_tasks_num, num_jobs):
            start_m = task_id // n_tasks_num
            start_n = task_id % n_tasks_num
            offset_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M
            lhs_val = tl.load(lhs + offset_m, mask=offset_m < M)
            offset_n = tl.arange(0, BLOCK_N) + start_n * BLOCK_N
            rhs_val = tl.load(rhs + offset_n, mask=offset_n < N)
            # Outer product of the two tiles via broadcasting.
            res_val = lhs_val[:, None] * rhs_val[None, :]
            # Row-major linear offsets into the M x N output.
            offset_r = offset_m[:, None] * N + offset_n[None, :]
            tl.store(
                res + offset_r,
                res_val,
                mask=(offset_m[:, None] < M) & (offset_n[None, :] < N),
            )
    else:
        # Fast path: BLOCK_N covers all of N (n_tasks_num == 1), so rhs is
        # loaded once, unmasked, and reused for every M tile.
        offset_n = tl.arange(0, BLOCK_N)
        rhs_val = tl.load(rhs + offset_n)
        for task_id in range(pid, total_tasks_num, num_jobs):
            start_m = task_id // n_tasks_num
            offset_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M
            lhs_val = tl.load(lhs + offset_m, mask=offset_m < M)
            res_val = lhs_val[:, None] * rhs_val[None, :]
            offset_r = offset_m[:, None] * N + offset_n[None, :]
            tl.store(
                res + offset_r,
                res_val,
                mask=(offset_m[:, None] < M) & (offset_n[None, :] < N),
            )
def outer_(lhs, rhs):
    """Compute the outer product of two 1-D tensors via ``outer_kernel``.

    Args:
        lhs: 1-D tensor of length M.
        rhs: 1-D tensor of length N.

    Returns:
        A new (M, N) tensor with ``res[i, j] = lhs[i] * rhs[j]``, on the
        same device and with the same dtype as ``lhs``.
    """
    m = lhs.shape[0]
    n = rhs.shape[0]
    # Allocate on lhs's device rather than a hard-coded "mlu" string so the
    # result lands on the same device (including device index) as the input.
    res = torch.empty((m, n), dtype=lhs.dtype, device=lhs.device)
    # Launch at most TOTAL_CORE_NUM programs; each one strides over the
    # (M-tile, N-tile) task space inside the kernel.
    grid = lambda META: (
        min(
            triton.cdiv(m, META["BLOCK_M"]) * triton.cdiv(n, META["BLOCK_N"]),
            TOTAL_CORE_NUM,
        ),
    )
    outer_kernel[grid](lhs, rhs, res, m, n)
    return res
class Outer(torch.autograd.Function):
    """Autograd wrapper around the fused outer-product kernel.

    Forward computes ``inp ⊗ weight``; backward reduces the incoming
    gradient with matrix-vector products (``mv``) from this backend.
    """

    @staticmethod
    def forward(ctx, inp, weight):
        logger.debug("GEMS_CAMBRICON OUTER")
        assert inp.ndim == 1 and weight.ndim == 1, "Invalid input"
        out = outer_(inp, weight)
        # Both operands are needed to form the VJPs in backward.
        ctx.save_for_backward(inp, weight)
        return out

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("GEMS_CAMBRICON OUTER VJP")
        # Fixed typo in the assertion message ("invalide" -> "invalid").
        assert out_grad.ndim == 2, "invalid out_grad shape"

        inp, weight = ctx.saved_tensors

        # For out = inp ⊗ weight: d(inp) = out_grad @ weight,
        # d(weight) = out_gradᵀ @ inp.
        inp_grad = mv(out_grad, weight)
        weight_grad = mv(out_grad.t(), inp)

        return inp_grad, weight_grad
def outer(inp, weight):
    """Compute the outer product of two 1-D tensors with autograd support."""
    result = Outer.apply(inp, weight)
    return result