Coverage for src/flag_gems/ops/exponential_.py: 26%

120 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import device, torch_device_fn 

8from flag_gems.utils import libentry, libtuner 

9from flag_gems.utils.random_utils import ( 

10 philox_backend_seed_offset, 

11 uint_to_uniform_float, 

12) 

13 

14logger = logging.getLogger(__name__) 

15 

16 

@triton.jit
def safe_fast_log_f32(x):
    """Fast approximate natural log for float32 inputs in (0, 1].

    The input is clamped into [smallest normal f32, largest f32 below 1.0]
    first, so the bit manipulation below never sees zeros, subnormals,
    negatives, or values >= 1.0.
    """
    # `x * 0.0 + const` materializes a constant with x's shape/dtype.
    min_normal = (x * 0.0 + 1.17549435e-38).to(tl.float32)  # smallest normal float32
    max_u = x * 0.0 + 0.99999994  # largest float32 strictly below 1.0
    x = tl.minimum(tl.maximum(x, min_normal), max_u)
    # Decompose x = 2**exponent * mantissa, mantissa in [1, 2).
    bits = x.to(tl.int32, bitcast=True)
    exponent = (bits >> 23) - 127  # unbias the 8-bit exponent field
    mantissa = (bits & 0x7FFFFF).to(tl.float32) * (1.0 / 8388608.0) + 1.0  # 8388608 = 2**23
    m1 = mantissa - 1.0  # in [0, 1)
    # log(1 + m1) via a 4-term polynomial, plus exponent * ln(2).
    return (
        m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333 - m1 * 0.25)))
        + exponent.to(tl.float32) * 0.6931471805599453
    )

30 

31 

@triton.jit
def safe_fast_log_f64(x):
    """Fast approximate natural log for float64 inputs in (0, 1].

    Same scheme as safe_fast_log_f32, using the IEEE-754 double layout:
    11-bit exponent (bias 1023) and 52-bit mantissa.
    """
    # `x * 0.0 + const` materializes a constant with x's shape/dtype.
    min_normal = x * 0.0 + 2.2250738585072014e-308  # smallest normal float64
    max_u = x * 0.0 + (1.0 - 2.220446049250313e-16)  # largest float64 below 1.0
    x = tl.minimum(tl.maximum(x, min_normal), max_u)
    # Decompose x = 2**exponent * mantissa, mantissa in [1, 2).
    bits = x.to(tl.int64, bitcast=True)
    exponent = (bits >> 52) - 1023  # unbias the 11-bit exponent field
    mantissa = (bits & 0x000FFFFFFFFFFFFF).to(tl.float64) * (
        1.0 / 4503599627370496.0  # 2**52
    ) + 1.0
    m1 = mantissa - 1.0  # in [0, 1)
    # log(1 + m1) via a 4-term polynomial, plus exponent * ln(2).
    return (
        m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333333333 - m1 * 0.25)))
        + exponent.to(tl.float64) * 0.6931471805599453
    )

47 

48 

@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    """Concatenate two 32-bit words into one uint64 (hi is the upper half)."""
    upper = hi.to(tl.uint64) << 32
    lower = lo.to(tl.uint64)
    return upper | lower

52 

53 

@triton.jit
def transform_exponential_f32_precise(u, inv_lambd, eps_minus):
    """Map a uniform sample u to an exponential variate via inverse CDF.

    Uses the hardware log. eps_minus is a tiny negative number; samples at
    the very top of the uniform range are replaced by it so log never
    returns exactly 0 and the result stays strictly positive.
    """
    at_top = u >= 1.0 + eps_minus
    log_u = tl.where(at_top, eps_minus, tl.math.log(u))
    return -inv_lambd * log_u

59 

60 

@triton.jit
def transform_exponential_f32_fast(u, inv_lambd, eps_minus):
    """Map a uniform sample u to an exponential variate via inverse CDF.

    Same contract as transform_exponential_f32_precise but uses the
    polynomial safe_fast_log_f32 instead of the hardware log.
    """
    at_top = u >= 1.0 + eps_minus
    log_u = tl.where(at_top, eps_minus, safe_fast_log_f32(u))
    return -inv_lambd * log_u

66 

67 

# Iluvatar devices take the precise (hardware log) path; everyone else gets
# the fast polynomial approximation.
transform_exponential_f32 = (
    transform_exponential_f32_precise
    if device.vendor_name == "iluvatar"
    else transform_exponential_f32_fast
)

72 

73 

@triton.jit
def transform_exponential_f64(u, inv_lambd, eps_minus):
    """Map a float64 uniform sample u to an exponential variate (inverse CDF)."""
    # Top-of-range samples are clamped to the tiny negative eps_minus so the
    # log never evaluates to exactly 0.
    at_top = u >= 1.0 + eps_minus
    log_u = tl.where(at_top, eps_minus, safe_fast_log_f64(u))
    return -inv_lambd * log_u

78 

79 

@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 1024}, num_warps=8, num_stages=3),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f32(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    """Fill out_ptr[0:N] with exponential samples; each program writes 4 * BLOCK values.

    One philox invocation yields four uint32 per lane, so each lane produces
    four independent samples (the UNROLL=4 assumed by the launcher).
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit philox offset into low/high 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    i = pid * BLOCK + tl.arange(0, BLOCK)
    c0 += i  # give every lane a distinct counter value
    z = c0 * 0  # zero vector for the remaining counter inputs
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)

    # Transform each uniform uint32 into an exponential variate.
    y0 = transform_exponential_f32(uint_to_uniform_float(r0), inv_lambd, eps_minus)
    y1 = transform_exponential_f32(uint_to_uniform_float(r1), inv_lambd, eps_minus)
    y2 = transform_exponential_f32(uint_to_uniform_float(r2), inv_lambd, eps_minus)
    y3 = transform_exponential_f32(uint_to_uniform_float(r3), inv_lambd, eps_minus)

    # Store the four sub-blocks contiguously; masks clip the final program.
    start = pid.to(tl.uint64) * BLOCK * 4
    off0 = start + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK
    off2 = off1 + BLOCK
    off3 = off2 + BLOCK

    tl.store(out_ptr + off0, y0, mask=off0 < N)
    tl.store(out_ptr + off1, y1, mask=off1 < N)
    tl.store(out_ptr + off2, y2, mask=off2 < N)
    tl.store(out_ptr + off3, y3, mask=off3 < N)

121 

122 

@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f64(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    """Fill out_ptr[0:N] with float64 exponential samples; 2 * BLOCK values per program.

    Each philox call yields four uint32 per lane; pairs are pasted into
    uint64 values to build two double-precision uniforms per lane (the
    UNROLL=2 assumed by the launcher).
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit philox offset into low/high 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    i = pid * BLOCK + tl.arange(0, BLOCK)
    c0 += i  # give every lane a distinct counter value
    z = c0 * 0  # zero vector for the remaining counter inputs
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)

    # Combine two 32-bit words into one 64-bit uniform per sample.
    u0 = uint_to_uniform_float(paste_u64(r0, r2))
    u1 = uint_to_uniform_float(paste_u64(r1, r3))

    y0 = transform_exponential_f64(u0, inv_lambd, eps_minus)
    y1 = transform_exponential_f64(u1, inv_lambd, eps_minus)

    # Store the two sub-blocks contiguously; masks clip the final program.
    start = pid.to(tl.uint64) * BLOCK * 2
    off0 = start + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK

    tl.store(out_ptr + off0, y0, mask=off0 < N)
    tl.store(out_ptr + off1, y1, mask=off1 < N)

160 

161 

def exponential_(x, lambd: float = 1.0, *, generator=None):
    """Fill *x* in-place with samples from an Exponential(lambd) distribution.

    Mirrors ``torch.Tensor.exponential_``. Supported dtypes: float16,
    bfloat16, float32, float64. float64 uses a dedicated kernel consuming
    two philox words per sample (UNROLL=2); all other dtypes go through the
    float32 kernel (UNROLL=4).

    Args:
        x: tensor to fill; returned for chaining.
        lambd: rate parameter, must be > 0.
        generator: optional torch generator for the philox seed/offset.

    Raises:
        ValueError: if lambd is not strictly positive.
    """
    logger.debug("GEMS EXPONENTIAL_")

    dtype = x.dtype
    # Renamed from `device` to avoid shadowing the module-level `device`
    # imported from flag_gems.runtime.
    x_device = x.device
    inplace = x.is_contiguous()
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    # lambd == 0 would otherwise raise ZeroDivisionError below, and a
    # negative lambd would silently produce negative "samples".
    if lambd <= 0:
        raise ValueError(f"exponential_ expects lambd > 0.0, but found lambd={lambd}")

    N = x.numel()
    inv_lambd = 1.0 / lambd
    # Tiny negative clamp used by the transforms to keep log(u) != 0.
    eps_minus = -0.5 * torch.finfo(dtype).eps

    # Non-contiguous tensors are filled through a scratch buffer, then
    # copied back respecting strides.
    out = x if inplace else torch.empty_like(x)

    # The two launch paths differ only in kernel and samples-per-lane.
    if dtype is torch.float64:
        kernel, UNROLL = fused_exponential_kernel_f64, 2
    else:
        kernel, UNROLL = fused_exponential_kernel_f32, 4

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    with torch_device_fn.device(x_device):
        kernel[grid](out, N, inv_lambd, eps_minus, philox_seed, philox_offset)

    if not inplace:
        x.copy_(out)
    return x