Coverage for src/flag_gems/runtime/backend/_ascend/ops/outer.py: 0%
74 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import logging
3import torch
4import triton
5from triton import language as tl
7from flag_gems.ops.mv import mv
8from flag_gems.utils import libentry
10logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
# The outer kernel is parameterized by three values (BLOCK_M, BLOCK_N,
# NEED_LOOP_N), but tuning only needs to pick a single total tile size.
# Given the second input's length N and that total tile size, all three
# kernel parameters can be derived, so this prune hook performs the
# conversion from `tile_size` configs to full kernel configs.
def early_config_prune(configs, named_args, **kwargs):
    """Rewrite `tile_size`-only autotune configs into full kernel configs.

    For each candidate config, BLOCK_N is clamped to N, BLOCK_M takes the
    remaining budget (tile_size / BLOCK_N, rounded up), and NEED_LOOP_N
    records whether a single BLOCK_N tile covers all of rhs.
    """
    N = kwargs["N"] if "N" in kwargs else named_args["N"]

    derived = []
    for cfg in configs:
        budget = cfg.kwargs["tile_size"]
        bn = min(budget, N)
        bm = triton.cdiv(budget, bn)
        derived.append(
            triton.Config(
                {"BLOCK_M": bm, "BLOCK_N": bn, "NEED_LOOP_N": bn < N},
                num_stages=cfg.num_stages,
                num_warps=cfg.num_warps,
            )
        )
    return derived
@libentry()
@triton.autotune(
    # Each config only fixes the total tile size; early_config_prune rewrites
    # it into the (BLOCK_M, BLOCK_N, NEED_LOOP_N) parameters the kernel uses.
    configs=[
        triton.Config({"tile_size": 1024}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 2048}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 4096}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 8192}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 16384}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 21760}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 32768}, num_stages=3, num_warps=1),
    ],
    key=["M", "N"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def outer_kernel(
    lhs,  # pointer to the 1-D left vector (length M)
    rhs,  # pointer to the 1-D right vector (length N)
    res,  # pointer to the (M, N) row-major output
    M,
    N,
    BLOCK_M: tl.constexpr,  # rows handled per tile
    BLOCK_N: tl.constexpr,  # columns handled per tile
    NEED_LOOP_N: tl.constexpr,  # True when BLOCK_N < N (row split across tiles)
):
    # Outer product: res[i, j] = lhs[i] * rhs[j].
    # Programs stride over all (BLOCK_M x BLOCK_N) tiles: program `pid`
    # handles tasks pid, pid + num_jobs, pid + 2*num_jobs, ...
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)

    m_tasks_num = tl.cdiv(M, BLOCK_M)
    n_tasks_num = tl.cdiv(N, BLOCK_N)
    total_tasks_num = m_tasks_num * n_tasks_num

    if NEED_LOOP_N:
        # General path: each task maps to a 2-D tile index (start_m, start_n),
        # so both lhs and rhs slices are reloaded per task, with bounds masks.
        for task_id in range(pid, total_tasks_num, num_jobs):
            start_m = task_id // n_tasks_num
            start_n = task_id % n_tasks_num

            offset_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M
            lhs_val = tl.load(lhs + offset_m, mask=offset_m < M)

            offset_n = tl.arange(0, BLOCK_N) + start_n * BLOCK_N
            rhs_val = tl.load(rhs + offset_n, mask=offset_n < N)

            # Broadcast multiply: (BLOCK_M, 1) * (1, BLOCK_N) -> (BLOCK_M, BLOCK_N).
            res_val = lhs_val[:, None] * rhs_val[None, :]

            # Row-major flattening of the 2-D tile into the output buffer.
            offset_r = offset_m[:, None] * N + offset_n[None, :]
            tl.store(
                res + offset_r,
                res_val,
                mask=(offset_m[:, None] < M) & (offset_n[None, :] < N),
            )
    else:
        # Fast path: NEED_LOOP_N is False only when one BLOCK_N tile covers all
        # of rhs (see early_config_prune), so rhs is loaded once, unmasked, and
        # reused for every row tile; only lhs is reloaded per task.
        offset_n = tl.arange(0, BLOCK_N)
        rhs_val = tl.load(rhs + offset_n)
        for task_id in range(pid, total_tasks_num, num_jobs):
            start_m = task_id // n_tasks_num

            offset_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M
            lhs_val = tl.load(lhs + offset_m, mask=offset_m < M)

            res_val = lhs_val[:, None] * rhs_val[None, :]

            offset_r = offset_m[:, None] * N + offset_n[None, :]
            tl.store(
                res + offset_r,
                res_val,
                mask=(offset_m[:, None] < M) & (offset_n[None, :] < N),
            )
def outer_(lhs, rhs):
    """Launch the Triton outer-product kernel.

    Args:
        lhs: 1-D tensor of length m.
        rhs: 1-D tensor of length n.

    Returns:
        An (m, n) tensor where out[i, j] = lhs[i] * rhs[j], allocated with
        lhs's dtype on lhs's device.
    """
    m = lhs.shape[0]
    n = rhs.shape[0]
    # Allocate on the input's device instead of hard-coding "npu": identical
    # behavior on this Ascend backend (inputs already live on npu), but robust
    # if the device string ever changes or carries an index (e.g. "npu:1").
    res = torch.empty((m, n), dtype=lhs.dtype, device=lhs.device)
    # One program per tile, capped at 65535 grid entries; the kernel loops
    # over any remaining tiles with a stride of num_programs.
    grid = lambda META: (
        min(
            triton.cdiv(m, META["BLOCK_M"]) * triton.cdiv(n, META["BLOCK_N"]),
            65535,
        ),
    )
    outer_kernel[grid](lhs, rhs, res, m, n)
    return res
class Outer(torch.autograd.Function):
    """Autograd-aware outer product of two 1-D tensors."""

    @staticmethod
    def forward(ctx, inp, weight):
        """Return the (len(inp), len(weight)) outer-product matrix.

        Both inputs must be 1-D; they are saved for the backward pass.

        Raises:
            AssertionError: if either input is not 1-D.
        """
        logger.debug("GEMS_ASCEND OUTER")
        assert inp.ndim == 1 and weight.ndim == 1, "Invalid input"
        out = outer_(inp, weight)
        ctx.save_for_backward(inp, weight)
        return out

    @staticmethod
    def backward(ctx, out_grad):
        """VJP of the outer product.

        grad_inp    = out_grad @ weight   (matrix-vector product)
        grad_weight = out_grad.T @ inp
        """
        logger.debug("GEMS_ASCEND OUTER VJP")
        # Fixed typo in the assertion message ("invalide" -> "invalid").
        assert out_grad.ndim == 2, "invalid out_grad shape"

        inp, weight = ctx.saved_tensors

        inp_grad = mv(out_grad, weight)
        weight_grad = mv(out_grad.t(), inp)

        return inp_grad, weight_grad
def outer(inp, weight):
    """Public entry point: outer product of 1-D tensors with autograd support."""
    result = Outer.apply(inp, weight)
    return result