Coverage for src/flag_gems/ops/exponential_.py: 41% (120 statements)
Generated by coverage.py v7.6.9 at 2026-03-20 02:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import device, torch_device_fn
8from flag_gems.utils import libentry, libtuner
9from flag_gems.utils.random_utils import (
10 philox_backend_seed_offset,
11 uint_to_uniform_float,
12)
14logger = logging.getLogger(__name__)
@triton.jit
def safe_fast_log_f32(x):
    """Fast polynomial approximation of log(x) for float32 inputs in (0, 1).

    The input is clamped into [FLT_MIN, largest float32 < 1.0] so the
    bit-level decomposition below never sees zeros, denormals, or values
    at/above 1.0.
    """
    # `x * 0.0 + c` broadcasts the scalar constant to x's shape.
    lower = (x * 0.0 + 1.17549435e-38).to(tl.float32)  # smallest normal f32
    upper = x * 0.0 + 0.99999994  # largest float32 strictly below 1.0
    clamped = tl.minimum(tl.maximum(x, lower), upper)
    # Split into exponent and mantissa via the IEEE-754 binary32 layout.
    raw = clamped.to(tl.int32, bitcast=True)
    exp = (raw >> 23) - 127
    frac = (raw & 0x7FFFFF).to(tl.float32) * (1.0 / 8388608.0) + 1.0
    t = frac - 1.0
    # log(1 + t) ~= t - t^2/2 + t^3/3 - t^4/4 in Horner form,
    # plus exponent * ln(2) for the power-of-two part.
    poly = t * (1.0 + t * (-0.5 + t * (0.3333333333 - t * 0.25)))
    return poly + exp.to(tl.float32) * 0.6931471805599453
@triton.jit
def safe_fast_log_f64(x):
    """Fast polynomial approximation of log(x) for float64 inputs in (0, 1)."""
    # `x * 0.0 + c` broadcasts the scalar constant to x's shape/dtype.
    min_normal = x * 0.0 + 2.2250738585072014e-308  # smallest normal float64
    max_u = x * 0.0 + (1.0 - 2.220446049250313e-16)  # largest float64 < 1.0
    # Clamp so the bit decomposition never sees zero/denormal or >= 1 inputs.
    x = tl.minimum(tl.maximum(x, min_normal), max_u)
    bits = x.to(tl.int64, bitcast=True)
    exponent = (bits >> 52) - 1023  # unbias the IEEE-754 binary64 exponent
    mantissa = (bits & 0x000FFFFFFFFFFFFF).to(tl.float64) * (
        1.0 / 4503599627370496.0  # 2**-52: scale mantissa bits into [0, 1)
    ) + 1.0
    m1 = mantissa - 1.0
    # 4-term series of log(1 + m1) in Horner form, plus exponent * ln(2).
    return (
        m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333333333 - m1 * 0.25)))
        + exponent.to(tl.float64) * 0.6931471805599453
    )
@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    """Concatenate two 32-bit words into one 64-bit value laid out as hi:lo."""
    upper = hi.to(tl.uint64) << 32
    return upper | lo.to(tl.uint64)
@triton.jit
def transform_exponential_f32_precise(u, inv_lambd, eps_minus):
    """Map a uniform sample u to an exponential sample via the inverse CDF.

    eps_minus is a small negative value (-eps/2); samples at or above
    1.0 + eps_minus substitute eps_minus for log(u) so that inputs that
    round to 1.0 never reach the log.
    """
    safe_log = tl.where(u >= 1.0 + eps_minus, eps_minus, tl.math.log(u))
    return -inv_lambd * safe_log
@triton.jit
def transform_exponential_f32_fast(u, inv_lambd, eps_minus):
    """Inverse-CDF exponential transform using the fast polynomial log.

    Same contract as transform_exponential_f32_precise, but with
    safe_fast_log_f32 in place of tl.math.log.
    """
    safe_log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f32(u))
    return -inv_lambd * safe_log
# Select the float32 transform at import time: Iluvatar devices use the
# precise tl.math.log path, every other vendor the fast bit-trick log.
transform_exponential_f32 = (
    transform_exponential_f32_precise
    if device.vendor_name == "iluvatar"
    else transform_exponential_f32_fast
)
@triton.jit
def transform_exponential_f64(u, inv_lambd, eps_minus):
    """Inverse-CDF exponential transform for float64 uniform samples."""
    # Substitute eps_minus for log(u) when u has rounded up to ~1.0.
    clamped_log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f64(u))
    return -inv_lambd * clamped_log
@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 1024}, num_warps=8, num_stages=3),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f32(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    # Fill out_ptr[0:N] with exponential samples. Each program instance
    # produces BLOCK * 4 values: one philox call yields four 32-bit words.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit philox offset into two 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    i = pid * BLOCK + tl.arange(0, BLOCK)
    # Offset the low counter per element so every lane gets its own stream.
    c0 += i
    z = c0 * 0  # zero tensor with the counter's shape/dtype
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)

    # Random bits -> uniform (0, 1) -> exponential via inverse CDF.
    y0 = transform_exponential_f32(uint_to_uniform_float(r0), inv_lambd, eps_minus)
    y1 = transform_exponential_f32(uint_to_uniform_float(r1), inv_lambd, eps_minus)
    y2 = transform_exponential_f32(uint_to_uniform_float(r2), inv_lambd, eps_minus)
    y3 = transform_exponential_f32(uint_to_uniform_float(r3), inv_lambd, eps_minus)

    # Each program writes four consecutive BLOCK-sized slices starting here;
    # masks keep the tail program from writing past N.
    start = pid.to(tl.uint64) * BLOCK * 4
    off0 = start + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK
    off2 = off1 + BLOCK
    off3 = off2 + BLOCK

    tl.store(out_ptr + off0, y0, mask=off0 < N)
    tl.store(out_ptr + off1, y1, mask=off1 < N)
    tl.store(out_ptr + off2, y2, mask=off2 < N)
    tl.store(out_ptr + off3, y3, mask=off3 < N)
@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f64(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    # Fill out_ptr[0:N] with float64 exponential samples. Each program
    # produces BLOCK * 2 values: the four 32-bit philox words are paired
    # into two 64-bit uniforms.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit philox offset into two 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    i = pid * BLOCK + tl.arange(0, BLOCK)
    # Offset the low counter per element so every lane gets its own stream.
    c0 += i
    z = c0 * 0  # zero tensor with the counter's shape/dtype
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)

    # Glue 32-bit word pairs into 64-bit values before the uniform mapping.
    u0 = uint_to_uniform_float(paste_u64(r0, r2))
    u1 = uint_to_uniform_float(paste_u64(r1, r3))

    y0 = transform_exponential_f64(u0, inv_lambd, eps_minus)
    y1 = transform_exponential_f64(u1, inv_lambd, eps_minus)

    # Two consecutive BLOCK-sized output slices per program; masks guard N.
    start = pid.to(tl.uint64) * BLOCK * 2
    off0 = start + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK

    tl.store(out_ptr + off0, y0, mask=off0 < N)
    tl.store(out_ptr + off1, y1, mask=off1 < N)
def exponential_(x, lambd: float = 1.0, *, generator=None):
    """Fill ``x`` in place with samples from an Exp(lambd) distribution.

    Mirrors ``torch.Tensor.exponential_`` using a fused Philox-based Triton
    kernel. Non-contiguous tensors are sampled into a contiguous scratch
    tensor and copied back.

    Args:
        x: tensor of dtype float16, bfloat16, float32, or float64.
        lambd: rate parameter of the distribution; must be strictly positive.
        generator: optional ``torch.Generator`` used to derive the Philox
            seed/offset.

    Returns:
        ``x`` itself, modified in place.

    Raises:
        ValueError: if ``lambd`` is not strictly positive.
    """
    logger.debug("GEMS EXPONENTIAL_")

    # Validate up front: lambd == 0 would otherwise surface as an opaque
    # ZeroDivisionError from 1.0 / lambd below.
    if lambd <= 0.0:
        raise ValueError(
            f"exponential_ expects lambda > 0.0, but found lambda={lambd}"
        )

    dtype = x.dtype
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)

    # Write directly into x when it is contiguous; otherwise sample into a
    # scratch tensor and copy back afterwards.
    inplace = x.is_contiguous()
    out = x if inplace else torch.empty_like(x)

    N = x.numel()
    inv_lambd = 1.0 / lambd
    # -eps/2: threshold the kernels use to clamp log(u) for u rounding to 1.
    eps_minus = -0.5 * torch.finfo(dtype).eps

    # f64 yields 2 samples per philox call (two 32-bit words per sample);
    # all other dtypes yield 4. Everything else is identical between paths.
    if dtype is torch.float64:
        kernel, unroll = fused_exponential_kernel_f64, 2
    else:
        kernel, unroll = fused_exponential_kernel_f32, 4

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * unroll),)
    increment = triton.cdiv(N, unroll)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    with torch_device_fn.device(x.device):
        kernel[grid](out, N, inv_lambd, eps_minus, philox_seed, philox_offset)

    if not inplace:
        x.copy_(out)
    return x