Coverage for src/flag_gems/ops/prod.py: 64%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12 

logger = logging.getLogger(name=__name__)  # module-scoped logger for the GEMS PROD ops

14 

15 

@triton.jit
def reduce_mul(a, b):
    # Combine function handed to tl.reduce: folds two partial products into one.
    product = a * b
    return product

19 

20 

# Stage 1 of the full-tensor product: program `block_id` multiplies the
# BLOCK_SIZE elements of the flat input starting at block_id * BLOCK_SIZE
# and writes that single partial product to mid[block_id].
@libentry()
@triton.jit
def prod_kernel_mid(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M
    # Lanes past the end of the input read 1.0 (the multiplicative identity),
    # so they leave the product unchanged; the math is done in float32.
    vals = tl.load(inp + idx, mask=in_bounds, other=1.0).to(tl.float32)
    partial = tl.reduce(vals, axis=0, combine_fn=reduce_mul)
    # tl.store casts the value to mid's element type.
    tl.store(mid + block_id, partial.to(vals.dtype))

37 

38 

# Stage 2 of the full-tensor product: one program multiplies the `mid_size`
# partial products held in `mid` and stores the scalar result to `out`.
@libentry()
@triton.jit
def prod_kernel_result(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    lanes = tl.arange(0, BLOCK_MID)
    # BLOCK_MID is padded past mid_size; padding lanes read the identity 1.0.
    partials = tl.load(mid + lanes, mask=lanes < mid_size, other=1.0)
    total = tl.reduce(partials.to(tl.float32), axis=0, combine_fn=reduce_mul)
    tl.store(out, total)

48 

49 

def prod(inp, *, dtype=None):
    """Return the product of all elements of ``inp`` as a 0-dim tensor.

    The reduction runs in two kernel launches: ``prod_kernel_mid`` multiplies
    chunks of about sqrt(numel) elements into partial products, then
    ``prod_kernel_result`` multiplies those partial products into the output.

    Args:
        inp: input tensor of any shape and memory layout.
        dtype: dtype of the returned tensor; defaults to ``inp.dtype``.

    Returns:
        A 0-dim tensor of ``dtype`` on ``inp.device``. An empty input yields 1
        (the empty product), matching ``torch.prod``.
    """
    logger.debug("GEMS PROD")
    if dtype is None:
        dtype = inp.dtype

    M = inp.numel()
    if M == 0:
        # Empty product is the multiplicative identity. Without this guard
        # next_power_of_2(0) == 0 and triton.cdiv(M, 0) raises ZeroDivisionError.
        return torch.ones([], dtype=dtype, device=inp.device)

    # The kernels address the input as a flat buffer (inp + offset), so strided
    # or expanded views must be materialized; a no-op for contiguous inputs.
    inp = inp.contiguous()

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    # Keep the partial products in float32, the kernels' compute type, so a
    # low-precision output dtype is rounded only once (at the final store)
    # and intermediate partials cannot overflow that narrower type.
    mid = torch.empty((mid_size,), dtype=torch.float32, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        prod_kernel_mid[(mid_size, 1, 1)](inp, mid, M, block_size)
        prod_kernel_result[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out

67 

68 

def heur_block_n(args):
    """Pick BLOCK_N as the smallest power of two >= ``args["N"]`` (via triton)."""
    n = args["N"]
    return triton.next_power_of_2(n)

71 

72 

# Row-wise product over a row-major (M, N) view of `inp`: program pid covers
# rows [pid * BLOCK_M, (pid + 1) * BLOCK_M) and writes one product per row to
# `out` (tl.store casts to out's element type). BLOCK_M / BLOCK_N are not
# passed by the caller in this file; presumably they come from the
# "naive_reduction" libtuner configs.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def prod_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    row_ok = rows < M

    # Element-wise running product over successive BLOCK_N-wide column tiles,
    # seeded with the identity 1.0 and collapsed along axis 1 after the loop.
    running = tl.full((BLOCK_M, BLOCK_N), value=1.0, dtype=tl.float32)
    for col_start in range(0, N, BLOCK_N):
        cols = col_start + tl.arange(0, BLOCK_N)
        tile_mask = row_ok[:, None] & (cols[None, :] < N)
        ptrs = inp + (rows[:, None] * N + cols[None, :])
        # Masked-off elements read 1.0 so they do not change the product.
        tile = tl.load(ptrs, mask=tile_mask, other=1.0).to(tl.float32)
        running *= tile
    row_prod = tl.reduce(running, axis=1, combine_fn=reduce_mul)

    tl.store(out + rows, row_prod, mask=row_ok)

107 

108 

def prod_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the product of ``inp`` along dimension ``dim``.

    ``dim_compress`` rearranges the input around ``dim`` so it can be read as
    a row-major (M, N) matrix, where N is the size of ``dim``; ``prod_kernel``
    then computes one product per row.

    Args:
        inp: input tensor.
        dim: dimension to reduce; negative values count from the end. Despite
            the ``None`` default, an int is required (``None`` fails the range
            check with a TypeError).
        keepdim: if True, keep the reduced dimension with size 1.
        dtype: output dtype; defaults to ``inp.dtype``.

    Returns:
        Tensor of ``dtype`` on ``inp.device``. Reducing over an empty
        dimension yields ones (the empty product), matching ``torch.prod``.

    Raises:
        AssertionError: if ``dim`` is out of range (assert-based, so the check
            is skipped under ``python -O``).
    """
    logger.debug("GEMS PROD DIM")

    # NOTE(review): 0-dim inputs are rejected here (ndim == 0), whereas
    # torch.prod accepts dim in {-1, 0} for scalars.
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = list(inp.shape)
    dim = dim % inp.ndim
    N = shape[dim]
    shape[dim] = 1

    if dtype is None:
        dtype = inp.dtype
    out = torch.empty(shape, dtype=dtype, device=inp.device)
    if not keepdim:
        out = torch.squeeze(out, dim)

    if inp.numel() == 0:
        # Either the reduced dim is empty (every output element is the empty
        # product, 1) or another dim is empty (out has no elements); fill_
        # handles both. Without this, `inp.numel() // N` divides by zero.
        return out.fill_(1)

    inp = dim_compress(inp, dim)
    M = inp.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        prod_kernel[grid](inp, out, M, N)

    return out