Coverage for src/flag_gems/runtime/backend/_iluvatar/ops/div.py: 0%

166 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic, tl_extra_shim

# Module-level logger, named after this module per the package convention.
logger = logging.getLogger(__name__)
# Aliases for backend-specific math shims (presumably mapping to device
# intrinsics on this backend -- see tl_extra_shim):
div_rn = tl_extra_shim.div_rn  # division, round to nearest
div_rz = tl_extra_shim.div_rz  # division, round toward zero
fmod = tl_extra_shim.fmod  # remainder whose sign follows the dividend
trunc = tl_extra_shim.trunc  # truncate toward zero

14 

15 

@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func(x, y):
    # Elementwise true division, tensor / tensor.  The INT_TO_FLOAT
    # promotion rule makes integer inputs produce a floating result.
    quotient = x / y
    return quotient

20 

21 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_tensor_scalar(x, y):
    # Elementwise true division, tensor / scalar (y is a Python scalar).
    quotient = x / y
    return quotient

26 

27 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_scalar_tensor(x, y):
    # Elementwise true division, scalar / tensor (x is a Python scalar).
    quotient = x / y
    return quotient

32 

33 

def true_divide(A, B):
    """True (floating) division, dispatching on which operands are tensors.

    Mirrors ``torch.true_divide``: tensor/tensor, tensor/scalar and
    scalar/tensor go to the matching Triton kernel; two Python scalars
    fall back to plain Python division wrapped in a 0-d tensor.
    """
    logger.debug("GEMS TRUE_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor.
    return torch.tensor(A / B)

45 

46 

def true_divide_(A, B):
    """In-place true division: writes A / B back into A (via out0=A).

    A is assumed to be a tensor; B may be a tensor or a Python scalar.
    """
    logger.debug("GEMS TRUE_DIVIDE_")
    kernel = true_div_func if isinstance(B, torch.Tensor) else true_div_func_tensor_scalar
    return kernel(A, B, out0=A)

53 

54 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y):
    # Elementwise division rounded toward zero, tensor / tensor.
    # Fix for consistency: the tensor-scalar and scalar-tensor variants use
    # div_rz (round-toward-zero division) before trunc; plain `x / y` rounds
    # the quotient to nearest first, which can land on the wrong side of an
    # exact integer before truncation.
    return trunc(div_rz(x, y))

59 

60 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    # Division rounded toward zero, tensor / scalar: round-toward-zero
    # division followed by truncation.
    rounded = div_rz(x, y)
    return trunc(rounded)

65 

66 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    # Division rounded toward zero, scalar / tensor: round-toward-zero
    # division followed by truncation.
    rounded = div_rz(x, y)
    return trunc(rounded)

71 

72 

def trunc_divide(A, B):
    """Division rounded toward zero, as torch.div(..., rounding_mode="trunc").

    Dispatches to the tensor-tensor, tensor-scalar or scalar-tensor kernel
    depending on which operands are tensors; two Python scalars are handled
    on the host.
    """
    logger.debug("GEMS TRUNC_DIVIDE")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return trunc_div_func(A, B)
    elif isinstance(A, torch.Tensor):
        return trunc_div_func_tensor_scalar(A, B)
    elif isinstance(B, torch.Tensor):
        return trunc_div_func_scalar_tensor(A, B)
    else:
        # Both scalar.  Bug fix: the previous `torch.tensor(A / B)` returned
        # the UN-truncated quotient (e.g. 2.5 for 5/2); round toward zero to
        # match the tensor paths and torch's "trunc" rounding mode.
        return torch.div(torch.tensor(A), torch.tensor(B), rounding_mode="trunc")

84 

85 

def trunc_divide_(A, B):
    """In-place truncating division: writes the result back into A (out0=A).

    A is assumed to be a tensor; B may be a tensor or a Python scalar.
    """
    logger.debug("GEMS TRUNC_DIVIDE_")
    kernel = trunc_div_func if isinstance(B, torch.Tensor) else trunc_div_func_tensor_scalar
    return kernel(A, B, out0=A)

92 

93 

@triton.jit
def _int_floordiv(x, y):
    """Integer floor division with Python/NumPy/PyTorch semantics.

    Triton's own floordiv behaves like ``(x - fmod(x, y)) / y`` (truncation
    toward zero), whereas Python/NumPy/PyTorch compute
    ``(x - remainder(x, y)) / y`` (floor toward negative infinity).  The two
    differ by exactly one when the operands have opposite signs and y does
    not divide x evenly, so we subtract 1 in precisely that case.

    TODO: ask Triton for an integer remainder builtin.
    NOTE(review): the original comment also described an x // 0 edge case
    (Triton yields -1; PyTorch reportedly yields -1 for x >= 0 and -2 for
    x < 0) that is claimed to be absorbed by the same correction -- confirm
    on this backend.
    """
    rem = x % y
    inexact = rem != 0
    opposite_signs = (x < 0) ^ (y < 0)
    quotient = x // y
    return tl.where(inexact & opposite_signs, quotient - 1, quotient)

113 

114 

# To be consistent with python, numpy and torch, we have to implement it in the
# following way.
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    # Floating-point floor division following the CPython/NumPy/Torch recipe
    # linked above: compute the truncated quotient from (x - fmod) / y,
    # correct it downward for opposite-sign inexact cases, snap to an
    # integer, and special-case zero quotients and division by zero.
    # NOTE: fmod's sign is the same as the dividend
    remainder = fmod(x, y)
    imperfect = remainder != 0.0
    different_sign = (x < 0) ^ (y < 0)

    # NOTE: we have to use div_rn explicitly here
    q = div_rn(x - remainder, y)
    # Inexact division with opposite signs: the quotient computed above is
    # one greater than the floored quotient, so step it down.
    q = tl.where(imperfect & different_sign, q - 1, q)

    # Snap q to the nearest integer: floor it, then bump up by one when the
    # fractional part exceeds 0.5 (q should already be integral up to
    # rounding error).
    floor_q = tl.math.floor(q)
    c = q - floor_q > 0.5
    floor_q = tl.where(c, floor_q + 1.0, floor_q)

    # A zero quotient keeps IEEE signed-zero semantics: -0.0 when the
    # operands' signs differ, +0.0 otherwise.
    q_is_zeros = q == 0.0
    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)

    # Division by zero falls back to plain IEEE x / 0.0 (inf/-inf/nan).
    is_div_by_zero = y == 0.0
    float_division = x / y
    out = tl.where(is_div_by_zero, float_division, floor_q)
    return out

145 

146 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y):
    # Elementwise floor division, tensor / tensor.  Compile-time dispatch on
    # element type: integer operands take the corrected integer path, anything
    # else the float path.
    # Bug fix: the original tested x's type twice (`x... & x...`); check both
    # operands.  Likely latent here since promotion should give x and y the
    # same dtype, but y is the intended second check.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

154 

155 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    # Elementwise floor division, tensor / scalar, with the same type-based
    # dispatch as floor_div_func.
    # Bug fix: the original tested x's type twice; the second check belongs
    # to y (the scalar divisor).
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

163 

164 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    # Elementwise floor division, scalar / tensor, with the same type-based
    # dispatch as floor_div_func.
    # Bug fix: the original tested x's type twice; the second check belongs
    # to y (the tensor divisor).
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

172 

173 

def floor_divide(A, B):
    """Floor division (round toward negative infinity), like torch's
    ``rounding_mode="floor"``.

    Dispatches on which operands are tensors; two Python scalars use the
    host's `//` operator, whose semantics already match.
    """
    logger.debug("GEMS FLOOR_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return floor_div_func(A, B)
    if a_is_tensor:
        return floor_div_func_tensor_scalar(A, B)
    if b_is_tensor:
        return floor_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor.
    return torch.tensor(A // B)

185 

186 

def floor_divide_(A, B):
    """In-place floor division: writes the result back into A (out0=A).

    A is assumed to be a tensor; B may be a tensor or a Python scalar.
    """
    logger.debug("GEMS FLOOR_DIVIDE_")
    kernel = floor_div_func if isinstance(B, torch.Tensor) else floor_div_func_tensor_scalar
    return kernel(A, B, out0=A)

193 

194 

def div_mode(A, B, rounding_mode=None):
    """Division with an optional rounding mode, mirroring ``torch.div``.

    rounding_mode: ``None`` -> true division, ``"trunc"`` -> round toward
    zero, ``"floor"`` -> round toward negative infinity.  Any other value
    raises ``ValueError``.
    """
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    if rounding_mode is None:
        return true_divide(A, B)
    msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    raise ValueError(msg)

205 

206 

def div_mode_(A, B, rounding_mode=None):
    """In-place division with an optional rounding mode (writes into A).

    Same rounding_mode contract as ``div_mode``; dispatches to the in-place
    variants.  Any value other than None/"trunc"/"floor" raises ValueError.
    """
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    if rounding_mode is None:
        return true_divide_(A, B)
    msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    raise ValueError(msg)

217 

218 

@triton.jit
def _remainder(x, y):
    """Python-style modulo: the result's sign follows the divisor y.

    The raw ``x % y`` keeps the dividend's sign convention, so when it is
    nonzero and the operands disagree in sign, shift it by y to move the
    result into the divisor's sign range.
    """
    raw = x % y
    needs_shift = (raw != 0) & ((x < 0) ^ (y < 0))
    return tl.where(needs_shift, raw + y, raw)

225 

226 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y):
    # Remainder, tensor % tensor; sign convention handled by _remainder.
    result = _remainder(x, y)
    return result

231 

232 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_ts(x, y):
    # Remainder, tensor % scalar; sign convention handled by _remainder.
    result = _remainder(x, y)
    return result

237 

238 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_st(x, y):
    # Remainder, scalar % tensor; sign convention handled by _remainder.
    result = _remainder(x, y)
    return result

243 

244 

def remainder(A, B):
    """Elementwise Python-style modulo (result sign follows the divisor).

    Dispatches to the tensor-tensor / tensor-scalar / scalar-tensor kernel
    depending on which operands are tensors; two Python scalars use the
    host's `%`, whose semantics already match.
    """
    # Bug fix: this previously logged "GEMS FLOOR_DIVIDE" (copy-paste).
    logger.debug("GEMS REMAINDER")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return rem_tt(A, B)
    elif isinstance(A, torch.Tensor):
        return rem_ts(A, B)
    elif isinstance(B, torch.Tensor):
        return rem_st(A, B)
    else:
        # Both scalar
        return torch.tensor(A % B)

256 

257 

def remainder_(A, B):
    """In-place remainder: writes A % B back into A (via out0=A).

    A is assumed to be a tensor; B may be a tensor or a Python scalar.
    """
    logger.debug("GEMS REMAINDER_")
    kernel = rem_tt if isinstance(B, torch.Tensor) else rem_ts
    return kernel(A, B, out0=A)