Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/prod.py: 0%

90 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7# from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12from ..utils.block_size_utils import get_block_size_1d 

13 

# Child logger under the package-wide "flag_gems" logger; the lstrip(".")
# presumably guards against a leading dot in relative module names
# (__name__ normally has none) — TODO confirm intent.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

15 

16 

@triton.jit
def reduce_mul(a, b):
    # Combine function for tl.reduce: folds two partial products into one.
    product = a * b
    return product

20 

21 

@libentry()
@triton.jit
def prod_kernel_mid(
    inp,  # pointer to the flattened input tensor
    mid,  # pointer to the partial-product buffer, one slot per program
    M,  # total number of input elements
    BLOCK_SIZE: tl.constexpr,  # number of elements each program reduces
):
    # Stage 1 of the two-stage full reduction: each program multiplies
    # together BLOCK_SIZE consecutive elements and writes one partial product.
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M
    # Out-of-range lanes load 1.0 — the multiplicative identity — so they do
    # not perturb the product. Math is done in float32 regardless of input dtype.
    inp_val = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32)
    mid_value = tl.reduce(inp_val, axis=0, combine_fn=reduce_mul)
    mid_ptr = mid + pid
    # inp_val was cast to float32 above, so this .to() is effectively a no-op;
    # the real dtype conversion happens on store into the `mid` buffer.
    tl.store(mid_ptr, mid_value.to(inp_val.dtype))

38 

39 

@libentry()
@triton.jit
def prod_kernel_result(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    # Stage 2 of the two-stage full reduction: a single program multiplies all
    # partial products from stage 1 into the final scalar.
    # BLOCK_MID is the next power of two >= mid_size, so one block covers
    # the whole `mid` buffer.
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < mid_size
    # Padding lanes load the multiplicative identity 1.0.
    mid_val = tl.load(mid_ptrs, mask=mask, other=1.0).to(tl.float32)
    prod_val = tl.reduce(mid_val, axis=0, combine_fn=reduce_mul)
    # prod_val is float32; the store converts to `out`'s dtype.
    tl.store(out, prod_val)

49 

50 

def prod(inp, *, dtype=None):
    """Product of all elements of ``inp`` (full reduction, like torch.prod).

    Two-stage reduction: ``prod_kernel_mid`` writes one partial product per
    block into ``mid``; ``prod_kernel_result`` then multiplies the partials
    into a 0-d result. If a single block suffices, the second launch is
    skipped entirely.

    Args:
        inp: input tensor; reduced over all of its elements.
        dtype: optional output dtype; defaults to ``inp.dtype``.

    Returns:
        A 0-d tensor containing the product.
    """
    logger.debug("GEMS PROD")
    if dtype is None:
        dtype = inp.dtype

    M = inp.numel()
    # Backend-tuned 1-D block size (replaces the sqrt(M) heuristic).
    block_size = get_block_size_1d(M, inp.element_size())
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        prod_kernel_mid[(mid_size, 1, 1)](
            inp, mid, M, block_size, buffer_size_limit=2048
        )
        # A single partial product already is the answer: return it without
        # allocating `out` or launching the second kernel. (Previously `out`
        # was allocated unconditionally and wasted on this path.)
        if mid_size == 1:
            return mid.reshape([])
        out = torch.empty([], dtype=dtype, device=inp.device)
        prod_kernel_result[(1, 1, 1)](
            mid, out, mid_size, block_mid, buffer_size_limit=2048
        )
    return out

75 

76 

def heur_m_block_size(args):
    # Split the M rows evenly across the 12 clusters of the target device,
    # then round the per-cluster share up to a power of two.
    rows_per_cluster = triton.cdiv(args["M"], 12)
    return triton.next_power_of_2(rows_per_cluster)

79 

80 

def heur_n_block_size(args):
    # Power-of-two column block covering N, capped at 8192 elements.
    # builtins.min is used explicitly (kept from the original — presumably to
    # dodge a name clash in the heuristics context; verify before simplifying).
    import builtins

    n_block = triton.next_power_of_2(args["N"])
    return builtins.min(n_block, 8192)

85 

86 

@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def prod_kernel(
    inp,  # pointer to the input, indexed as a row-major (M, N) view
    out,  # pointer to the (M,)-shaped per-row products
    M,  # number of rows (product of the non-reduced dimensions)
    N,  # length of the reduced dimension
    BLOCK_M: tl.constexpr,  # rows handled per program (from heuristic)
    BLOCK_N: tl.constexpr,  # columns handled per inner-loop step (from heuristic)
):
    # Row-wise product: each program owns BLOCK_M rows and sweeps the N
    # columns in BLOCK_N-wide tiles, accumulating a running product.
    # set offset
    pid_m = tle.program_id(0)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Accumulator starts at the multiplicative identity; math in float32.
    acc = tl.full((BLOCK_M, BLOCK_N), value=1.0, dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]

        # set mask
        # NOTE(review): relies on Triton's JIT lowering `and` on tensors to an
        # elementwise logical-and — confirm for the Triton version in use.
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        inp_ptrs = inp + offset
        # Masked-out positions load 1.0 so they don't affect the product.
        inp_vals = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32)
        acc *= inp_vals
    # Fold the BLOCK_N lanes of each row's accumulator into a single value.
    result_index = tl.reduce(acc, axis=1, combine_fn=reduce_mul)

    offset_index = m_offset
    out_ptrs = out + offset_index
    mask1 = m_offset < M
    tl.store(out_ptrs, result_index, mask=mask1)

123 

124 

def prod_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Product of ``inp`` along one dimension (like torch.prod(inp, dim)).

    The reduced dimension is moved to be innermost via ``dim_compress`` so
    ``prod_kernel`` can reduce contiguous rows of length N.

    Args:
        inp: input tensor.
        dim: dimension to reduce; may be negative. Must satisfy
            -inp.ndim <= dim < inp.ndim.
        keepdim: keep the reduced dimension with size 1 when True.
        dtype: optional output dtype; defaults to ``inp.dtype``.

    Returns:
        Tensor of per-slice products.
    """
    logger.debug("GEMS PROD DIM")

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    out_shape = list(inp.shape)
    dim = dim % inp.ndim  # normalize negative dims
    inp = dim_compress(inp, dim)  # reduced dim becomes innermost
    N = out_shape[dim]
    out_shape[dim] = 1
    M = inp.numel() // N  # number of independent rows to reduce

    if dtype is None:
        dtype = inp.dtype
    out = torch.empty(out_shape, dtype=dtype, device=inp.device)
    if not keepdim:
        out = torch.squeeze(out, dim)

    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]),)

    with torch_device_fn.device(inp.device):
        prod_kernel[grid](inp, out, M, N, buffer_size_limit=2048)

    return out