Coverage for src/flag_gems/runtime/backend/_metax/ops/exponential_.py: 0%

176 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils.random_utils import ( 

9 philox_backend_seed_offset, 

10 uint_to_uniform_float, 

11) 

12 

# Module logger under the flag_gems namespace.
logger = logging.getLogger("flag_gems." + __name__)

# Machine epsilons indexed by dtype: [float64, float32, float16, bfloat16].
# Used to clamp uniform samples away from 1.0 so log(u) never returns 0
# (which would produce a zero exponential sample).
eps: tl.constexpr = [
    2.220446049250313e-16,
    1.1920928955078125e-07,
    0.0009765625,
    0.0078125,
]  # eps for double, float, float16, bfloat16
# Derived clamp tables: eps_1[i] = -eps[i] / 2 is the value log(u) is clamped
# to, and eps_2[i] = 1 + eps_1[i] is the threshold above which u is clamped.
eps_1: tl.constexpr = [-0.5 * x for x in eps]
eps_2: tl.constexpr = [1.0 + x for x in eps_1]

# 1 / log2(e): multiplying a base-2 logarithm by this scale converts it to a
# natural logarithm (used by the lambd == 1 transform helpers below).
trans_scale: tl.constexpr = 1.0 / 1.4426950408889634

26 

27 

def heur_block(args):
    """Choose the Triton launch BLOCK size from the element count ``args["N"]``."""
    n = args["N"]
    if n <= 512:
        return 256
    if n <= 1024:
        return 512
    return 1024

35 

36 

def heur_num_warps(args):
    """Choose the warp count for the launch from the element count ``args["N"]``."""
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

44 

45 

# Launch configuration is picked per call from the element count N.
@triton.heuristics(
    {
        "BLOCK": heur_block,
        "num_warps": heur_num_warps,
    }
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,  # output tensor data pointer
    N,  # number of elements to generate
    is_double,  # truthy when the output dtype is float64
    lambd,  # rate parameter of the exponential distribution
    eps,  # dtype epsilon, forwarded to transform_exponential for clamping
    philox_seed,  # Philox RNG seed
    philox_offset,  # Philox RNG counter offset
    BLOCK: tl.constexpr,
):
    # General-lambd path: draw uniform randoms with the Philox counter-based
    # RNG, then map them to Exponential(lambd) via transform_exponential.
    # (The lambd == 1 case is handled by fused_exponential_kernel_opt.)
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into the two 32-bit counter words Philox consumes.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # Per-lane counter index so each lane draws an independent 4-word stream.
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0  # zero vector reused for the remaining counter/key words
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)  # 4 random 32-bit words per lane
    if is_double:
        # float64 path: fuse two 32-bit words per sample -> 2 outputs per lane.
        d0 = uint_to_uniform_float(paste_u64(r0, r2))
        d1 = uint_to_uniform_float(paste_u64(r1, r3))
        y0 = transform_exponential(d0, lambd, eps)
        y1 = transform_exponential(d1, lambd, eps)
        UNROLL = 2
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        # evict_first: streaming stores; the written data is not re-read here.
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    else:
        # Narrower dtypes: one sample per random word -> 4 outputs per lane.
        f0 = uint_to_uniform_float(r0)
        f1 = uint_to_uniform_float(r1)
        f2 = uint_to_uniform_float(r2)
        f3 = uint_to_uniform_float(r3)
        y0 = transform_exponential(f0, lambd, eps)
        y1 = transform_exponential(f1, lambd, eps)
        y2 = transform_exponential(f2, lambd, eps)
        y3 = transform_exponential(f3, lambd, eps)
        UNROLL = 4
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        off_2 = off_1 + BLOCK
        off_3 = off_2 + BLOCK
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")

101 

102 

# Specialized kernel for lambd == 1: the division by lambd is dropped and the
# per-dtype clamp constants are baked in via the transform_exponential_* helpers.
@triton.heuristics(
    {
        "BLOCK": heur_block,
        "num_warps": heur_num_warps,
    }
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_opt(
    out_ptr,  # output tensor data pointer
    N,  # number of elements to generate
    dtype,  # dtype code: 0=float64, 1=float32, 2=float16, else bfloat16 (see exponential_)
    philox_seed,  # Philox RNG seed
    philox_offset,  # Philox RNG counter offset
    BLOCK: tl.constexpr,
):
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into the two 32-bit counter words Philox consumes.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # Per-lane counter index so each lane draws an independent 4-word stream.
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0  # zero vector reused for the remaining counter/key words
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)  # 4 random 32-bit words per lane
    if dtype == 0:
        # float64 path: fuse two 32-bit words per sample -> 2 outputs per lane.
        d0 = uint_to_uniform_float(paste_u64(r0, r2))
        d1 = uint_to_uniform_float(paste_u64(r1, r3))
        y0 = transform_exponential_double(d0)
        y1 = transform_exponential_double(d1)
        UNROLL = 2
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        # evict_first: streaming stores; the written data is not re-read here.
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    else:
        # Narrower dtypes: one sample per random word -> 4 outputs per lane.
        f0 = uint_to_uniform_float(r0)
        f1 = uint_to_uniform_float(r1)
        f2 = uint_to_uniform_float(r2)
        f3 = uint_to_uniform_float(r3)
        # Dispatch to the transform specialized for the dtype's eps constants.
        if dtype == 1:
            y0 = transform_exponential_float(f0)
            y1 = transform_exponential_float(f1)
            y2 = transform_exponential_float(f2)
            y3 = transform_exponential_float(f3)
        elif dtype == 2:
            y0 = transform_exponential_float16(f0)
            y1 = transform_exponential_float16(f1)
            y2 = transform_exponential_float16(f2)
            y3 = transform_exponential_float16(f3)
        else:
            y0 = transform_exponential_bfloat16(f0)
            y1 = transform_exponential_bfloat16(f1)
            y2 = transform_exponential_bfloat16(f2)
            y3 = transform_exponential_bfloat16(f3)

        UNROLL = 4
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        off_2 = off_1 + BLOCK
        off_3 = off_2 + BLOCK
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")

169 

170 

@triton.jit
def paste_u64(hi, lo):
    """Concatenate two 32-bit words into one uint64, with ``hi`` in the upper half."""
    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)

176 

177 

@triton.jit
def transform_exponential(u, lambd, eps):
    """Map a uniform sample u to Exponential(lambd) via -log(u) / lambd."""
    eps1 = -0.5 * eps
    # When u is within eps/2 of 1.0, clamp log(u) to eps1 instead of letting
    # it reach 0, so the returned sample stays strictly positive.
    is_min = u >= 1.0 + eps1
    log = tl.where(is_min, eps1, tl.math.log(u))
    v = -1.0 / lambd * log

    return v

186 

187 

@triton.jit
def transform_exponential_double(u):
    """Exponential transform -log(u), specialized for float64 and lambd == 1."""
    # Index 0 of the module-level eps tables is the float64 entry.
    eps1 = eps_1[0]
    # Clamp log(u) to eps1 near u == 1 so the sample stays strictly positive.
    is_min = u >= eps_2[0]
    log = tl.where(is_min, eps1, tl.math.log(u))
    v = -1.0 * log

    return v

196 

197 

@triton.jit
def transform_exponential_float(u):
    """Exponential transform -log(u), specialized for float32 and lambd == 1."""
    # Index 1 of the module-level eps tables is the float32 entry.
    eps1 = eps_1[1]
    # Clamp log(u) to eps1 near u == 1 so the sample stays strictly positive.
    # log(u) is computed as log2(u) * (1 / log2(e)) to use the hardware log2.
    is_min = u >= eps_2[1]
    log = tl.where(is_min, eps1, tl.math.log2(u) * trans_scale)
    v = -1.0 * log

    return v

206 

207 

@triton.jit
def transform_exponential_float16(u):
    """Exponential transform -log(u), specialized for float16 and lambd == 1."""
    # Index 2 of the module-level eps tables is the float16 entry.
    eps1 = eps_1[2]
    # Clamp log(u) to eps1 near u == 1 so the sample stays strictly positive.
    # log(u) is computed as log2(u) * (1 / log2(e)) to use the hardware log2.
    is_min = u >= eps_2[2]
    log = tl.where(is_min, eps1, tl.math.log2(u) * trans_scale)
    v = -1.0 * log

    return v

216 

217 

@triton.jit
def transform_exponential_bfloat16(u):
    """Exponential transform -log(u), specialized for bfloat16 and lambd == 1."""
    # Index 3 of the module-level eps tables is the bfloat16 entry.
    eps1 = eps_1[3]
    # Clamp log(u) to eps1 near u == 1 so the sample stays strictly positive.
    # log(u) is computed as log2(u) * (1 / log2(e)) to use the hardware log2.
    is_min = u >= eps_2[3]
    log = tl.where(is_min, eps1, tl.math.log2(u) * trans_scale)
    v = -1.0 * log

    return v

226 

227 

def exponential_(x, lambd: float = 1.0, *, generator=None):
    """Fill *x* in place with samples from Exponential(lambd) and return it.

    Supports float64/float32/float16/bfloat16 tensors. Non-contiguous tensors
    are sampled into a contiguous scratch buffer and copied back.
    """
    logger.debug("METAX GEMS EXPONENTIAL_")
    dtype = x.dtype
    device = x.device
    write_in_place = x.is_contiguous()
    supported = [torch.float64, torch.float32, torch.float16, torch.bfloat16]
    assert dtype in supported
    is_double = dtype == torch.float64
    # float64 consumes two 32-bit random words per sample, so its kernels
    # produce 2 outputs per lane; the narrower dtypes produce 4.
    unroll = 2 if is_double else 4
    n = x.numel()

    def grid(meta):
        return (triton.cdiv(n, meta["BLOCK"] * unroll),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    philox_seed, philox_offset = philox_backend_seed_offset(
        triton.cdiv(n, unroll), generator=generator
    )
    eps = torch.finfo(dtype).eps
    out = x if write_in_place else torch.empty(x.size(), dtype=dtype, device=device)
    type_index = supported.index(dtype)
    with torch_device_fn.device(device):
        if lambd == 1.0:
            # Fast path with per-dtype constants baked into the kernel.
            fused_exponential_kernel_opt[grid](
                out, n, type_index, philox_seed, philox_offset
            )
        else:
            fused_exponential_kernel[grid](
                out, n, is_double, lambd, eps, philox_seed, philox_offset
            )
    if not write_in_place:
        x.copy_(out)
    return x