Coverage for src/flag_gems/runtime/backend/_metax/ops/outer.py: 0%
58 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7# from flag_gems.ops.mul import mul
8from flag_gems.ops.mv import mv
10logger = logging.getLogger("flag_gems." + __name__)
@triton.jit
def mul_outer_kernel(
    inp,
    weight,
    out,
    M,
    N,
    stride_m,  # NOTE(review): currently unused; kept for launch-signature compatibility
    stride_n,  # NOTE(review): currently unused; kept for launch-signature compatibility
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Outer-product kernel: out[m, n] = inp[m] * weight[n].

    Grid layout: axis 0 tiles the M rows in steps of BLOCK_SIZE_M, axis 1
    tiles the N columns in steps of BLOCK_SIZE_N.  `inp`, `weight`, and `out`
    are assumed contiguous (the launcher makes them so).
    """
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)
    n_range = tl.arange(0, BLOCK_SIZE_N)
    weight_block_start = pid_y * BLOCK_SIZE_N
    # Shape (1, BLOCK_SIZE_N): one tile of columns, broadcastable over rows.
    weight_offsets = weight_block_start + n_range[None, :]
    mask_2 = weight_offsets < N
    # `other=0.0` makes masked-off lanes well-defined instead of undefined.
    weight_data = tl.load(weight + weight_offsets, mask=mask_2, other=0.0)
    for i in range(0, BLOCK_SIZE_M):
        inp_offsets = pid_x * BLOCK_SIZE_M + i
        mask_1 = inp_offsets < M
        output_offsets = (pid_x * BLOCK_SIZE_M + i) * N + weight_offsets
        inp_data = tl.load(inp + inp_offsets, mask=mask_1, other=0.0)
        inp_bd, weight_bd = tl.broadcast(inp_data, weight_data)
        output = inp_bd * weight_bd
        # BUG FIX: the store must also be guarded by the row bound (mask_1).
        # With only mask_2, the last grid row wrote past the end of `out`
        # whenever M was not a multiple of BLOCK_SIZE_M.
        tl.store(out + output_offsets, output, mask=mask_1 & mask_2)
def mul(inp, weight):
    """Compute the outer product of a column vector and a row vector.

    Args:
        inp: 2-D tensor of shape (M, 1).
        weight: 2-D tensor of shape (1, N).

    Returns:
        Tensor of shape (M, N) with out[m, n] = inp[m, 0] * weight[0, n],
        allocated on `inp`'s device with `inp`'s dtype.
    """
    assert inp.ndim == 2 and weight.ndim == 2, "Invalid input"
    assert inp.shape[1] == 1 and weight.shape[0] == 1, "Invalid input"
    M = inp.shape[0]
    N = weight.shape[1]
    out = torch.empty((M, N), device=inp.device, dtype=inp.dtype)
    num_warps = 1
    BLOCK_SIZE_M = 8
    BLOCK_SIZE_N = 512
    # The grid depends only on fixed block sizes, so a plain tuple is enough;
    # the original META-taking lambda ignored its argument.
    grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))
    # Ensure the launch happens on the device that owns the tensors.
    with torch.cuda.device(inp.device):
        mul_outer_kernel[grid](
            inp,
            weight,
            out,
            M,
            N,
            inp.stride(0),
            weight.stride(1),
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            num_warps=num_warps,
        )
    return out
class Outer(torch.autograd.Function):
    """Autograd wrapper for the outer product of two 1-D tensors.

    Forward computes inp ⊗ weight via the Triton `mul` launcher; backward
    uses matrix-vector products: d(inp) = grad @ weight, d(weight) = gradᵀ @ inp.
    """

    @staticmethod
    def forward(ctx, inp, weight):
        logger.debug("METAX GEMS OUTER")
        assert inp.ndim == 1 and weight.ndim == 1, "Invalid input"
        # Reshape to (M, 1) and (1, N) so the 2-D kernel launcher applies;
        # contiguity is required by the kernel's flat-offset addressing.
        inp1 = inp[:, None]
        weight1 = weight[None, :]
        inp1 = inp1.contiguous()
        weight1 = weight1.contiguous()
        out = mul(inp1, weight1)
        ctx.save_for_backward(inp, weight)
        return out

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("METAX GEMS OUTER VJP")
        # Fixed typo in the assertion message ("invalide" -> "invalid").
        assert out_grad.ndim == 2, "invalid out_grad shape"

        inp, weight = ctx.saved_tensors

        # For out = inp ⊗ weight: d(inp)[m] = Σ_n grad[m,n]·weight[n],
        # d(weight)[n] = Σ_m grad[m,n]·inp[m].
        inp_grad = mv(out_grad, weight)
        weight_grad = mv(out_grad.t(), inp)

        return inp_grad, weight_grad
def outer(inp, weight):
    """Outer product of two 1-D tensors, with autograd support."""
    result = Outer.apply(inp, weight)
    return result