Coverage for src/flag_gems/runtime/backend/_metax/ops/outer.py: 0%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7# from flag_gems.ops.mul import mul 

8from flag_gems.ops.mv import mv 

9 

10logger = logging.getLogger("flag_gems." + __name__) 

11 

12 

@triton.jit
def mul_outer_kernel(
    inp,
    weight,
    out,
    M,
    N,
    stride_m,  # unused inside the kernel; kept so the launch signature is unchanged
    stride_n,  # unused inside the kernel; kept so the launch signature is unchanged
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Outer-product kernel: out[m, n] = inp[m] * weight[n].

    Grid axis 0 tiles the M rows (one program covers BLOCK_SIZE_M rows,
    processed one row per loop iteration); grid axis 1 tiles the N columns.
    `out` is addressed as a contiguous (M, N) buffer via row * N + col.
    """
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)
    n_range = tl.arange(0, BLOCK_SIZE_N)
    weight_block_start = pid_y * BLOCK_SIZE_N
    weight_offsets = weight_block_start + n_range[None, :]
    mask_n = weight_offsets < N
    weight_data = tl.load(weight + weight_offsets, mask=mask_n)
    for i in range(0, BLOCK_SIZE_M):
        inp_offsets = pid_x * BLOCK_SIZE_M + i
        mask_m = inp_offsets < M
        output_offsets = inp_offsets * N + weight_offsets
        inp_data = tl.load(inp + inp_offsets, mask=mask_m)
        inp_bd, weight_bd = tl.broadcast(inp_data, weight_data)
        output = inp_bd * weight_bd
        # BUGFIX: the store must be guarded by the row mask as well. Masking
        # only on columns let the last row-block (when M is not a multiple of
        # BLOCK_SIZE_M) write past the end of `out` at offsets >= M * N.
        tl.store(out + output_offsets, output, mask=mask_m & mask_n)

41 

42 

def mul(inp, weight):
    """Elementwise product of a column vector and a row vector.

    `inp` must be shaped (M, 1) and `weight` (1, N); the broadcasted product
    is the (M, N) outer-product matrix, computed by `mul_outer_kernel`.
    """
    assert inp.ndim == 2 and weight.ndim == 2, "Invalid input"
    assert inp.shape[1] == 1 and weight.shape[0] == 1, "Invalid input"

    M, N = inp.shape[0], weight.shape[1]
    out = torch.empty((M, N), device=inp.device, dtype=inp.dtype)

    # Fixed launch configuration: each program handles an 8 x 512 tile.
    BLOCK_SIZE_M = 8
    BLOCK_SIZE_N = 512
    num_warps = 1
    grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))

    with torch.cuda.device(inp.device):
        mul_outer_kernel[grid](
            inp,
            weight,
            out,
            M,
            N,
            inp.stride(0),
            weight.stride(1),
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            num_warps=num_warps,
        )
    return out

67 

68 

class Outer(torch.autograd.Function):
    """Autograd-aware outer product of two 1-D tensors."""

    @staticmethod
    def forward(ctx, inp, weight):
        """Return the (len(inp), len(weight)) outer-product matrix.

        Both arguments must be 1-D; they are reshaped into a contiguous
        column and row vector and multiplied by the Triton-backed `mul`.
        The original 1-D tensors are saved for the backward pass.
        """
        logger.debug("METAX GEMS OUTER")
        assert inp.ndim == 1 and weight.ndim == 1, "Invalid input"
        inp1 = inp[:, None].contiguous()
        weight1 = weight[None, :].contiguous()
        out = mul(inp1, weight1)
        ctx.save_for_backward(inp, weight)
        return out

    @staticmethod
    def backward(ctx, out_grad):
        """Backward pass for the outer product.

        inp_grad    = out_grad @ weight      (matrix-vector product)
        weight_grad = out_grad.T @ inp
        """
        logger.debug("METAX GEMS OUTER VJP")
        # BUGFIX: corrected typo in the assertion message ("invalide" -> "invalid").
        assert out_grad.ndim == 2, "invalid out_grad shape"

        inp, weight = ctx.saved_tensors

        inp_grad = mv(out_grad, weight)
        weight_grad = mv(out_grad.t(), inp)

        return inp_grad, weight_grad

93 

94 

def outer(inp, weight):
    """Outer product of two 1-D tensors, dispatched through the autograd op."""
    result = Outer.apply(inp, weight)
    return result