Coverage for src/flag_gems/runtime/backend/_cambricon/ops/div.py: 0%

182 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

import logging
import math

import torch
import triton
import triton.language as tl

from flag_gems.utils import tl_extra_shim

from ..utils.pointwise_dynamic import pointwise_dynamic

10 

11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

12div_rn = tl_extra_shim.div_rn 

13div_rz = tl_extra_shim.div_rz 

14fmod = tl_extra_shim.fmod 

15trunc = tl_extra_shim.trunc 

16 

17 

@pointwise_dynamic(
    is_tensor=[True, True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")]
)
@triton.jit
def true_div_func(x, y, inplace):
    # Elementwise true division of two tensors. Integer inputs are promoted
    # to floating point by the INT_TO_FLOAT promotion rule declared above.
    # `inplace` is a non-tensor flag threaded through by the Python wrappers;
    # it is not read here -- output placement is handled by pointwise_dynamic
    # via the out0= keyword at the call site.
    return x / y

24 

25 

@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")]
)
@triton.jit
def true_div_func_tensor_scalar(x, y, inplace):
    # True division of a tensor by a Python scalar.
    # Cast the scalar to the tensor's dtype first so the division happens
    # in a single, predictable type.
    y = y.to(x.dtype)
    return x / y

33 

34 

@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")]
)
@triton.jit
def true_div_func_scalar_tensor(x, y, inplace):
    # True division of a Python scalar by a tensor.
    # Cast the scalar to the tensor's dtype first so the division happens
    # in a single, predictable type.
    x = x.to(y.dtype)
    return x / y

42 

43 

def true_divide(A, B):
    """Elementwise true division A / B (torch.div with rounding_mode=None).

    Dispatches to the tensor/tensor, tensor/scalar or scalar/tensor Triton
    kernel; when neither operand is a tensor the quotient is computed on the
    host and wrapped in a tensor.
    """
    logger.debug("GEMS_CAMBRICON TRUE_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B, False)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B, False)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B, False)
    # Neither operand is a tensor: plain Python division suffices.
    return torch.tensor(A / B)

55 

56 

def true_divide_out(A, B, out):
    """True division with an explicit output tensor (torch.div.out variant)."""
    logger.debug("GEMS_CAMBRICON TRUE_DIVIDE OUT")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B, False, out0=out)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B, False, out0=out)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B, False, out0=out)
    # Neither operand is a tensor: compute on the host.
    if out is None:
        return torch.tensor(A / B)
    return out.fill_(A / B)

68 

69 

def true_divide_(A, B):
    """In-place true division: writes A / B back into A via out0=A."""
    logger.debug("GEMS_CAMBRICON TRUE_DIVIDE_")
    if isinstance(B, torch.Tensor):
        kernel = true_div_func
    else:
        kernel = true_div_func_tensor_scalar
    return kernel(A, B, True, out0=A)

76 

77 

@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y, inplace):
    # Truncating division (round toward zero) of two tensors:
    # divide with round-to-nearest (div_rn), then drop the fraction (trunc).
    # `inplace` is an unused non-tensor flag; output placement is handled
    # by pointwise_dynamic via out0= at the call site.
    return trunc(div_rn(x, y))

82 

83 

@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def trunc_div_func_tensor_scalar(x, y, inplace):
    # Truncating division of a tensor by a Python scalar; same round-toward-zero
    # computation as trunc_div_func.
    return trunc(div_rn(x, y))

90 

91 

@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def trunc_div_func_scalar_tensor(x, y, inplace):
    # Truncating division of a Python scalar by a tensor; same round-toward-zero
    # computation as trunc_div_func.
    return trunc(div_rn(x, y))

98 

99 

def trunc_divide(A, B):
    """Divide A by B, rounding toward zero (torch.div rounding_mode="trunc").

    Tensor/tensor, tensor/scalar and scalar/tensor inputs are dispatched to
    the corresponding Triton kernels; when neither operand is a tensor the
    quotient is computed on the host.
    """
    logger.debug("GEMS_CAMBRICON TRUNC_DIVIDE")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return trunc_div_func(A, B, False)
    elif isinstance(A, torch.Tensor):
        return trunc_div_func_tensor_scalar(A, B, False)
    elif isinstance(B, torch.Tensor):
        return trunc_div_func_scalar_tensor(A, B, False)
    else:
        # Both scalar. The previous code returned the untruncated true
        # quotient A / B here, which contradicts the "trunc" rounding mode
        # implemented by every tensor branch above; round toward zero to
        # match. Keep a float result when either input is a float so the
        # resulting tensor dtype mirrors the input promotion.
        q = math.trunc(A / B)
        if isinstance(A, float) or isinstance(B, float):
            q = float(q)
        return torch.tensor(q)

111 

112 

def trunc_divide_(A, B):
    """In-place truncating division: writes trunc(A / B) back into A."""
    logger.debug("GEMS_CAMBRICON TRUNC_DIVIDE_")
    if isinstance(B, torch.Tensor):
        kernel = trunc_div_func
    else:
        kernel = trunc_div_func_tensor_scalar
    return kernel(A, B, True, out0=A)

119 

120 

@triton.jit
def _int_floordiv(x, y):
    # TODO: request Triton to add an integer remainder builtin.
    # The semantics of Triton's floordiv differ from PyTorch/NumPy:
    # Triton floordiv equates to
    #   (x - np.fmod(x, y)) / y
    # whereas PyTorch floordiv is
    #   (x - np.remainder(x, y)) / y
    # The results show a one-off difference when
    #   C1) x and y have opposite signs, and
    #   C2) x is not a multiple of y.
    # Apart from the above, there is an erroneous case: x // 0 returns -1
    # in Triton, whereas PyTorch's x // 0 returns -1 if x >= 0 and -2 if
    # x < 0; the c1 & c2 correction below would mis-handle that case, so
    # c3 adds the extra compensation for (x < 0, y == 0).
    r = x % y
    c1 = r != 0
    c2 = (x < 0) ^ (y < 0)
    c3 = (x < 0) & (y == 0)
    c = c1 & c2
    # int16 division is widened to int32 and narrowed back; presumably a
    # workaround for 16-bit integer divide on this backend -- TODO confirm.
    if x.dtype == tl.int16 and y.dtype == tl.int16:
        return (x.to(tl.int32) // y.to(tl.int32)).cast(tl.int16) - c - c3
    return x // y - c - c3

144 

145 

# To be consistent with Python, NumPy and torch, we have to implement it in the

147# following way. 

148# CPython 

149# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636 

150# numpy 

151# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532 

152# torch 

153# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23 

@triton.jit
def _float_floordiv(x, y):
    # Floor division for floats following the CPython/NumPy/torch recipe
    # referenced above: floor((x - fmod(x, y)) / y) with corrections so the
    # result rounds toward -inf and division by zero falls back to IEEE x / y.
    # NOTE: fmod's sign is the same as the dividend
    remainder = fmod(x, y)
    imperfect = remainder != 0.0
    different_sign = (x < 0) ^ (y < 0)

    # NOTE: we have to use div_rn (round-to-nearest division) explicitly here.
    # (x - remainder) is an exact multiple of y, so q is integral up to
    # rounding error.
    q = div_rn(x - remainder, y)
    # Inexact quotient with opposite-sign operands: the truncated quotient
    # is one above the floored quotient, so subtract 1.
    q = tl.where(imperfect & different_sign, q - 1, q)

    # Snap q to the nearest integral value, compensating for the case where
    # floating-point error left q just below the intended integer.
    floor_q = tl.math.floor(q)
    c = q - floor_q > 0.5
    floor_q = tl.where(c, floor_q + 1.0, floor_q)

    # Preserve IEEE signed zero: a zero quotient is -0.0 when the operands
    # have different signs, +0.0 otherwise.
    q_is_zeros = q == 0.0
    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)

    # x / 0.0 must yield +/-inf or nan rather than a floored value.
    is_div_by_zero = y == 0.0
    float_division = x / y
    out = tl.where(is_div_by_zero, float_division, floor_q)
    return out

176 

177 

@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y, inplace):
    # Floor division (round toward -inf) of two tensors with PyTorch
    # semantics; dispatch on the element dtypes to the integer or float
    # implementation.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

185 

186 

@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def floor_div_func_tensor_scalar(x, y, inplace):
    # Floor division of a tensor by a Python scalar; same dtype dispatch
    # as floor_div_func.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

196 

197 

@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def floor_div_func_scalar_tensor(x, y, inplace):
    # Floor division of a Python scalar by a tensor; same dtype dispatch
    # as floor_div_func.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

207 

208 

def floor_divide(A, B):
    """Elementwise floor division A // B (round toward -inf), torch.floor_divide."""
    logger.debug("GEMS_CAMBRICON FLOOR_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return floor_div_func(A, B, False)
    if a_is_tensor:
        return floor_div_func_tensor_scalar(A, B, False)
    if b_is_tensor:
        return floor_div_func_scalar_tensor(A, B, False)
    # Neither operand is a tensor: Python's // already floors.
    return torch.tensor(A // B)

220 

221 

def floor_divide_(A, B):
    """In-place floor division: writes A // B back into A via out0=A."""
    logger.debug("GEMS_CAMBRICON FLOOR_DIVIDE_")
    if isinstance(B, torch.Tensor):
        kernel = floor_div_func
    else:
        kernel = floor_div_func_tensor_scalar
    return kernel(A, B, True, out0=A)

228 

229 

def div_mode(A, B, rounding_mode=None):
    """Dispatch division according to torch.div's rounding_mode argument.

    None selects true division, "trunc" rounds toward zero, "floor" rounds
    toward -inf; any other value raises ValueError.
    """
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )

240 

241 

def div_mode_(A, B, rounding_mode=None):
    """In-place counterpart of div_mode: writes the quotient into A.

    None selects true division, "trunc" rounds toward zero, "floor" rounds
    toward -inf; any other value raises ValueError.
    """
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )

252 

253 

@triton.jit
def _remainder(x, y):
    # Python/PyTorch-style remainder: the result takes the sign of the
    # divisor y. The raw `%` here keeps the dividend's sign (C semantics),
    # so when the remainder is nonzero and x, y have opposite signs we
    # shift it into the divisor's sign range by adding y.
    r = x % y
    c1 = r != 0
    c2 = (x < 0) ^ (y < 0)
    return tl.where(c1 & c2, r + y, r)

260 

261 

@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y, inplace):
    # Remainder of two tensors (sign follows the divisor; see _remainder).
    # `inplace` is an unused non-tensor flag; output placement is handled
    # by pointwise_dynamic via out0= at the call site.
    return _remainder(x, y)

266 

267 

@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def rem_ts(x, y, inplace):
    # Remainder of a tensor by a Python scalar (see _remainder).
    return _remainder(x, y)

274 

275 

@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def rem_st(x, y, inplace):
    # Remainder of a Python scalar by a tensor (see _remainder).
    return _remainder(x, y)

282 

283 

def remainder(A, B):
    """Elementwise modulo with the result's sign following the divisor (torch.remainder)."""
    logger.debug("GEMS_CAMBRICON REMAINDER")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return rem_tt(A, B, False)
    if a_is_tensor:
        return rem_ts(A, B, False)
    if b_is_tensor:
        return rem_st(A, B, False)
    # Neither operand is a tensor: Python's % already has these semantics.
    return torch.tensor(A % B)

295 

296 

def remainder_(A, B):
    """In-place modulo: writes A % B back into A via out0=A."""
    logger.debug("GEMS_CAMBRICON REMAINDER_")
    if isinstance(B, torch.Tensor):
        kernel = rem_tt
    else:
        kernel = rem_ts
    return kernel(A, B, True, out0=A)