Coverage for src/flag_gems/ops/div.py: 59%

172 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8from flag_gems.utils.triton_lang_extension import div_rn, div_rz, fmod, trunc 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func(x, y):
    # Elementwise tensor / tensor true division; INT_TO_FLOAT promotes
    # integer operands to floating point before dividing.
    return x / y

17 

18 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_tensor_scalar(x, y):
    # Elementwise tensor / scalar true division (y is a host scalar);
    # integer operands are promoted to floating point.
    return x / y

23 

24 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_scalar_tensor(x, y):
    # Elementwise scalar / tensor true division (x is a host scalar);
    # integer operands are promoted to floating point.
    return x / y

29 

30 

def true_divide(A, B):
    """Elementwise true (floating-point) division with kind-based dispatch.

    Tensor/tensor, tensor/scalar and scalar/tensor inputs are routed to the
    matching Triton kernel; two Python scalars fall back to plain Python
    division wrapped in a tensor.
    """
    logger.debug("GEMS TRUE_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor: compute on the Python scalars directly.
    return torch.tensor(A / B)

42 

43 

def true_divide_out(A, B, out):
    """True division writing the result into *out* (torch.div(..., out=out)).

    Dispatches on which operands are tensors; for two Python scalars the
    result is written into *out* via fill_ when one is provided, otherwise
    a fresh tensor is returned.
    """
    logger.debug("GEMS TRUE_DIVIDE OUT")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B, out0=out)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B, out0=out)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B, out0=out)
    # Both scalar: reuse the provided buffer when there is one.
    if out is None:
        return torch.tensor(A / B)
    return out.fill_(A / B)

55 

56 

def true_divide_(A, B):
    """In-place true division A /= B; A must be a tensor (it receives out0)."""
    logger.debug("GEMS TRUE_DIVIDE_")
    if isinstance(B, torch.Tensor):
        kernel = true_div_func
    else:
        kernel = true_div_func_tensor_scalar
    return kernel(A, B, out0=A)

63 

64 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y):
    # tensor / tensor division with the quotient rounded toward zero
    # (div_rz divides rounding toward zero; trunc drops any fraction).
    return trunc(div_rz(x, y))

69 

70 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    # tensor / scalar division rounded toward zero (y is a host scalar).
    return trunc(div_rz(x, y))

75 

76 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    # scalar / tensor division rounded toward zero (x is a host scalar).
    return trunc(div_rz(x, y))

81 

82 

def trunc_divide(A, B):
    """Elementwise division with the quotient rounded toward zero.

    Implements torch.div(..., rounding_mode="trunc"), dispatching
    tensor/tensor, tensor/scalar and scalar/tensor inputs to the matching
    Triton kernel. Two Python scalars are computed directly on the host.
    """
    logger.debug("GEMS TRUNC_DIVIDE")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return trunc_div_func(A, B)
    elif isinstance(A, torch.Tensor):
        return trunc_div_func_tensor_scalar(A, B)
    elif isinstance(B, torch.Tensor):
        return trunc_div_func_scalar_tensor(A, B)
    else:
        # Both scalar. The quotient must be truncated toward zero to agree
        # with the kernels above; the previous fallback returned the raw
        # (untruncated) quotient.
        return torch.tensor(A / B).trunc()

94 

95 

def trunc_divide_(A, B):
    """In-place truncating division: A.div_(B, rounding_mode="trunc")."""
    logger.debug("GEMS TRUNC_DIVIDE_")
    if isinstance(B, torch.Tensor):
        kernel = trunc_div_func
    else:
        kernel = trunc_div_func_tensor_scalar
    return kernel(A, B, out0=A)

102 

103 

@triton.jit
def _int_floordiv(x, y):
    # Integer floor division with PyTorch/NumPy semantics.
    # TODO: request Triton to add an integer remainder builtin
    # The semantic of Triton floordiv differs from Pytorch/Numpy:
    # Triton floordiv equates to
    #   (x - np.fmod(x, y)) / y
    # whereas Pytorch floordiv is
    #   (x - np.remainder(x, y)) / y
    # The results show a one-off difference when
    # C1) x and y have opposite signs
    # and C2) x is not a multiple of y.
    # Apart from the above, there's an erroneous case: x // 0 returns -1,
    # whereas in Pytorch x // 0 returns -1 if x >= 0 and -2 if x < 0;
    # but this special case is coalesced into the C1 and C2 check, so
    # no extra handling is needed.
    r = x % y
    c1 = r != 0
    c2 = (x < 0) ^ (y < 0)
    return tl.where(c1 & c2, x // y - 1, x // y)

123 

124 

# To be consistent with python, numpy and torch, we have to implement it in the
# following way.
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    # Floating-point floor division following the CPython/NumPy/torch recipe
    # linked above: compute the quotient of the exactly divisible part, then
    # snap it to an integral value, with special cases for a zero quotient
    # and for division by zero.
    # NOTE: fmod's sign is the same as the dividend
    remainder = fmod(x, y)
    imperfect = remainder != 0.0
    different_sign = (x < 0) ^ (y < 0)

    # Quotient of the exactly divisible part.
    # NOTE: we have to use div_rn explicitly here
    q = div_rn(x - remainder, y)
    # Opposite signs with a nonzero remainder need one step down to floor.
    q = tl.where(imperfect & different_sign, q - 1, q)

    # Snap q to the nearest integral value: floor() may land one below when
    # q sits just under an integer, so bump back up when the gap exceeds 0.5.
    floor_q = tl.math.floor(q)
    c = q - floor_q > 0.5
    floor_q = tl.where(c, floor_q + 1.0, floor_q)

    # A zero quotient carries the sign of the operands (-0.0 when they differ).
    q_is_zeros = q == 0.0
    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)

    # Division by zero falls back to IEEE x / y (+-inf or nan).
    is_div_by_zero = y == 0.0
    float_division = x / y
    out = tl.where(is_div_by_zero, float_division, floor_q)
    return out

155 

156 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y):
    # tensor // tensor floor division: integer operands take the exact
    # integer path, anything else the float path.
    # Fix: the original tested x's dtype twice; the intent is both operands.
    # (With DEFAULT promotion both sides share a dtype, so the duplicated
    # check happened to give the same answer, but it was a typo.)
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

164 

165 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    # tensor // scalar floor division (y is a host scalar); integer operands
    # take the exact integer path, anything else the float path.
    # Fix: the original tested x's dtype twice; the intent is both operands.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

173 

174 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    # scalar // tensor floor division (x is a host scalar); integer operands
    # take the exact integer path, anything else the float path.
    # Fix: the original tested x's dtype twice; the intent is both operands.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

182 

183 

def floor_divide(A, B):
    """Elementwise floor division (torch.floor_divide) with kind-based dispatch.

    Tensor operands go to the matching Triton kernel; two Python scalars use
    Python's // operator, which already floors.
    """
    logger.debug("GEMS FLOOR_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return floor_div_func(A, B)
    if a_is_tensor:
        return floor_div_func_tensor_scalar(A, B)
    if b_is_tensor:
        return floor_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor: Python's // floors toward -inf already.
    return torch.tensor(A // B)

195 

196 

def floor_divide_(A, B):
    """In-place floor division A //= B; A must be a tensor (it receives out0)."""
    logger.debug("GEMS FLOOR_DIVIDE_")
    if isinstance(B, torch.Tensor):
        kernel = floor_div_func
    else:
        kernel = floor_div_func_tensor_scalar
    return kernel(A, B, out0=A)

203 

204 

def div_mode(A, B, rounding_mode=None):
    """Division with an optional rounding mode, mirroring torch.div.

    rounding_mode=None performs true division, "trunc" rounds the quotient
    toward zero, and "floor" rounds toward negative infinity. Any other
    value raises ValueError.
    """
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )

215 

216 

def div_mode_(A, B, rounding_mode=None):
    """In-place counterpart of div_mode: A.div_(B, rounding_mode=...).

    rounding_mode=None performs true division, "trunc" rounds toward zero,
    and "floor" rounds toward negative infinity. Any other value raises
    ValueError.
    """
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )

227 

228 

@triton.jit
def _remainder(x, y):
    # Python/torch-style modulo: the result takes the sign of the divisor y.
    # Triton's % follows C fmod (sign of the dividend), so shift the residue
    # by y whenever it is nonzero and the operand signs differ.
    res = x % y
    nonzero = res != 0
    sign_differs = (x < 0) ^ (y < 0)
    return tl.where(nonzero & sign_differs, res + y, res)

235 

236 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y):
    # tensor % tensor remainder with Python/torch sign semantics.
    return _remainder(x, y)

241 

242 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_ts(x, y):
    # tensor % scalar remainder (y is a host scalar).
    return _remainder(x, y)

247 

248 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_st(x, y):
    # scalar % tensor remainder (x is a host scalar).
    return _remainder(x, y)

253 

254 

def remainder(A, B):
    """Elementwise remainder (torch.remainder) with kind-based dispatch.

    Tensor operands go to the matching Triton kernel; two Python scalars use
    Python's % operator, whose result already carries the divisor's sign.
    """
    logger.debug("GEMS REMAINDER")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return rem_tt(A, B)
    if a_is_tensor:
        return rem_ts(A, B)
    if b_is_tensor:
        return rem_st(A, B)
    # Neither operand is a tensor: compute on the Python scalars directly.
    return torch.tensor(A % B)

266 

267 

def remainder_(A, B):
    """In-place remainder A %= B; A must be a tensor (it receives out0)."""
    logger.debug("GEMS REMAINDER_")
    if isinstance(B, torch.Tensor):
        kernel = rem_tt
    else:
        kernel = rem_ts
    return kernel(A, B, out0=A)