Coverage for src/flag_gems/runtime/backend/_metax/ops/exponential_.py: 0%
176 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils.random_utils import (
9 philox_backend_seed_offset,
10 uint_to_uniform_float,
11)
logger = logging.getLogger("flag_gems." + __name__)

# Machine epsilons indexed by dtype code 0..3: [float64, float32, float16, bfloat16].
# The same index order is used by `exponential_` (via `lst.index(dtype)`) and the
# per-dtype transform helpers below.
eps: tl.constexpr = [
    2.220446049250313e-16,
    1.1920928955078125e-07,
    0.0009765625,
    0.0078125,
]  # eps for double, float, float16, bfloat16
eps_1: tl.constexpr = [-0.5 * x for x in eps]  # -eps/2 per dtype: clamp value for log near u == 1
eps_2: tl.constexpr = [1.0 + x for x in eps_1]  # 1 - eps/2 per dtype: threshold above which u is clamped

# trans_scale = ln(2) = 1 / log2(e). Multiplying log2(u) by this converts a
# base-2 logarithm back to a natural logarithm (used where tl.math.log2 is
# cheaper than tl.math.log on this backend).
trans_scale: tl.constexpr = 1.0 / 1.4426950408889634
def heur_block(args):
    """Choose the Triton block size from the total element count ``args["N"]``.

    Small problems get small blocks so more programs can run; anything past
    1024 elements uses the maximum block size of 1024.
    """
    n = args["N"]
    if n <= 512:
        return 256
    return 512 if n <= 1024 else 1024
def heur_num_warps(args):
    """Choose the warp count from the total element count ``args["N"]``.

    Scales with problem size in lockstep with :func:`heur_block`:
    4 warps up to 512 elements, 8 up to 1024, 16 beyond that.
    """
    n = args["N"]
    if n <= 512:
        return 4
    return 8 if n <= 1024 else 16
@triton.heuristics(
    {
        "BLOCK": heur_block,
        "num_warps": heur_num_warps,
    }
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,  # pointer to the output buffer (contiguous, N elements)
    N,  # total number of samples to generate
    is_double,  # truthy when the output dtype is float64
    lambd,  # rate parameter of the exponential distribution
    eps,  # machine epsilon of the output dtype (clamp for log near u == 1)
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    # Fill out_ptr[0:N] with Exp(lambd) samples drawn from the Philox
    # counter-based RNG. One Philox call yields four 32-bit words, so each
    # program writes BLOCK * UNROLL outputs: UNROLL = 2 for double (two words
    # pasted into one 64-bit uniform) and UNROLL = 4 otherwise.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit philox_offset into low/high 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    # Per-lane counter offset so every lane consumes a distinct Philox counter.
    c0 += i4
    _O = c0 * 0  # zero vector for the remaining counter/key words
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    if is_double:
        # float64: combine two 32-bit words into each 64-bit uniform sample.
        d0 = uint_to_uniform_float(paste_u64(r0, r2))
        d1 = uint_to_uniform_float(paste_u64(r1, r3))
        y0 = transform_exponential(d0, lambd, eps)
        y1 = transform_exponential(d1, lambd, eps)
        UNROLL = 2
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        # evict_first: the output is write-once, no point keeping it cached.
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    else:
        # 32-bit-or-narrower dtypes: one uniform per 32-bit Philox word.
        f0 = uint_to_uniform_float(r0)
        f1 = uint_to_uniform_float(r1)
        f2 = uint_to_uniform_float(r2)
        f3 = uint_to_uniform_float(r3)
        y0 = transform_exponential(f0, lambd, eps)
        y1 = transform_exponential(f1, lambd, eps)
        y2 = transform_exponential(f2, lambd, eps)
        y3 = transform_exponential(f3, lambd, eps)
        UNROLL = 4
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        off_2 = off_1 + BLOCK
        off_3 = off_2 + BLOCK
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")
# Specialized kernel for the common case lambd == 1: the division by lambd and
# the runtime eps argument are folded away via the per-dtype transform helpers.
@triton.heuristics(
    {
        "BLOCK": heur_block,
        "num_warps": heur_num_warps,
    }
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_opt(
    out_ptr,  # pointer to the output buffer (contiguous, N elements)
    N,  # total number of samples to generate
    dtype,  # dtype code: 0=float64, 1=float32, 2=float16, else bfloat16
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    # Same Philox counter setup as fused_exponential_kernel: split the 64-bit
    # offset into two 32-bit words and give each lane a distinct counter.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0  # zero vector for the remaining counter/key words
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    if dtype == 0:
        # float64: two 32-bit Philox words per 64-bit uniform, UNROLL = 2.
        d0 = uint_to_uniform_float(paste_u64(r0, r2))
        d1 = uint_to_uniform_float(paste_u64(r1, r3))
        y0 = transform_exponential_double(d0)
        y1 = transform_exponential_double(d1)
        UNROLL = 2
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        # evict_first: write-once output, don't pollute the cache.
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    else:
        # Narrower dtypes: one uniform per 32-bit Philox word, UNROLL = 4.
        f0 = uint_to_uniform_float(r0)
        f1 = uint_to_uniform_float(r1)
        f2 = uint_to_uniform_float(r2)
        f3 = uint_to_uniform_float(r3)
        # Dispatch on the dtype code so each branch uses the matching
        # dtype-specific eps clamp (see eps_1/eps_2 tables).
        if dtype == 1:
            y0 = transform_exponential_float(f0)
            y1 = transform_exponential_float(f1)
            y2 = transform_exponential_float(f2)
            y3 = transform_exponential_float(f3)
        elif dtype == 2:
            y0 = transform_exponential_float16(f0)
            y1 = transform_exponential_float16(f1)
            y2 = transform_exponential_float16(f2)
            y3 = transform_exponential_float16(f3)
        else:
            y0 = transform_exponential_bfloat16(f0)
            y1 = transform_exponential_bfloat16(f1)
            y2 = transform_exponential_bfloat16(f2)
            y3 = transform_exponential_bfloat16(f3)

        UNROLL = 4
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        off_2 = off_1 + BLOCK
        off_3 = off_2 + BLOCK
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")
@triton.jit
def paste_u64(hi, lo):
    """Concatenate two 32-bit words into one 64-bit value, ``hi`` in the top half."""
    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)
@triton.jit
def transform_exponential(u, lambd, eps):
    """Map a uniform sample ``u`` to an Exp(lambd) sample via the inverse CDF.

    Samples within eps/2 of 1.0 are clamped so the logarithm never sees a
    value that would round to exactly 1 and produce 0.
    """
    half_eps = -0.5 * eps
    clamped_log = tl.where(u >= 1.0 + half_eps, half_eps, tl.math.log(u))
    return -1.0 / lambd * clamped_log
@triton.jit
def transform_exponential_double(u):
    """Exponential transform for float64 with lambd fixed to 1.

    Uses the precomputed float64 clamp constants (index 0 of eps_1/eps_2).
    """
    clamp_val = eps_1[0]
    safe_log = tl.where(u >= eps_2[0], clamp_val, tl.math.log(u))
    return -1.0 * safe_log
@triton.jit
def transform_exponential_float(u):
    """Exponential transform for float32 with lambd fixed to 1.

    Computes ln(u) as log2(u) * ln(2) and clamps samples near 1.0 using the
    float32 constants (index 1 of eps_1/eps_2).
    """
    clamp_val = eps_1[1]
    safe_log = tl.where(u >= eps_2[1], clamp_val, tl.math.log2(u) * trans_scale)
    return -1.0 * safe_log
@triton.jit
def transform_exponential_float16(u):
    """Exponential transform for float16 with lambd fixed to 1.

    Computes ln(u) as log2(u) * ln(2) and clamps samples near 1.0 using the
    float16 constants (index 2 of eps_1/eps_2).
    """
    clamp_val = eps_1[2]
    safe_log = tl.where(u >= eps_2[2], clamp_val, tl.math.log2(u) * trans_scale)
    return -1.0 * safe_log
@triton.jit
def transform_exponential_bfloat16(u):
    """Exponential transform for bfloat16 with lambd fixed to 1.

    Computes ln(u) as log2(u) * ln(2) and clamps samples near 1.0 using the
    bfloat16 constants (index 3 of eps_1/eps_2).
    """
    clamp_val = eps_1[3]
    safe_log = tl.where(u >= eps_2[3], clamp_val, tl.math.log2(u) * trans_scale)
    return -1.0 * safe_log
def exponential_(x, lambd: float = 1.0, *, generator=None):
    """Fill ``x`` in place with samples from an exponential distribution.

    Args:
        x: tensor to fill; must be float64, float32, float16, or bfloat16.
           Non-contiguous tensors are filled via a contiguous scratch buffer
           and copied back.
        lambd: rate parameter of the distribution. ``lambd == 1.0`` takes a
            specialized kernel with the division folded away.
        generator: optional RNG passed through to the Philox seed/offset
            helper; ``None`` uses the backend default state.

    Returns:
        ``x`` (modified in place).

    Raises:
        TypeError: if ``x`` has an unsupported dtype.
    """
    logger.debug("METAX GEMS EXPONENTIAL_")
    dtype = x.dtype
    device = x.device
    inplace = x.is_contiguous()
    lst = [torch.float64, torch.float32, torch.float16, torch.bfloat16]
    if dtype not in lst:
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        raise TypeError(f"exponential_ does not support dtype {dtype}")
    is_double = dtype in (torch.float64,)
    # Each Philox call yields four 32-bit words: 2 doubles or 4 narrower values.
    UNROLL = 2 if is_double else 4
    N = x.numel()
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    eps = torch.finfo(dtype).eps
    # Kernels assume a contiguous layout; stage through a scratch buffer otherwise.
    x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device)
    type_index = lst.index(dtype)  # dtype code consumed by the opt kernel
    with torch_device_fn.device(device):
        if lambd == 1.0:
            fused_exponential_kernel_opt[grid_fn](
                x_, N, type_index, philox_seed, philox_offset
            )
        else:
            fused_exponential_kernel[grid_fn](
                x_, N, is_double, lambd, eps, philox_seed, philox_offset
            )
    if not inplace:
        x.copy_(x_)
    return x