Coverage for src/flag_gems/runtime/backend/_hygon/ops/exponential_.py: 0%
93 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils.random_utils import (
9 philox_backend_seed_offset,
10 uint_to_uniform_float,
11)
13logger = logging.getLogger(__name__)
# Smallest positive normal float32 value (2**-126).
# NOTE(review): neither constant below is referenced by name in the visible
# code; safe_fast_log repeats the same two literals inline.
MIN_NORMAL_F32 = 1.17549435e-38
# Largest float32 value strictly below 1.0, so that log(1) = 0 is never hit
# (that edge would be harmless anyway).
MAX_U_F32 = 0.99999994  # nextafter(1.0, 0.0) in float32
@triton.jit
def safe_fast_log(x):
    """Approximate the natural log of ``x`` after clamping it below 1.0.

    The input is clamped to ``[1.17549435e-38, 0.99999994]``, its bit pattern
    is split into an unbiased exponent and a fractional mantissa ``t`` in
    ``[0, 1)``, and the result is ``P(t) + exponent * ln(2)`` where ``P`` is
    the 4-term series ``t - t**2/2 + t**3/3 - t**4/4`` evaluated in Horner
    form.

    NOTE(review): the bitcast to int32 assumes ``x`` is a 32-bit float --
    TODO confirm what float64 callers are expected to get.
    NOTE(review): the truncated series tops out near 0.583 as ``t``
    approaches 1, whereas ln(2) is about 0.693, so accuracy degrades for
    large mantissas -- confirm this is acceptable for the sampler.
    """
    # Build the clamp bounds from x so they carry x's shape and dtype.
    lower = x * 0.0 + 1.17549435e-38
    upper = x * 0.0 + 0.99999994
    clamped = tl.minimum(tl.maximum(x, lower), upper)

    raw = clamped.to(tl.int32, bitcast=True)
    # Biased exponent sits above the 23 mantissa bits; 127 is the bias.
    unbiased_exp = (raw >> 23) - 127
    # 8388608 == 1 << 23: scale the 23 mantissa bits into [0, 1), then add
    # the implicit leading one.
    significand = (raw & 0x7FFFFF).to(tl.float32) * (1.0 / 8388608) + 1.0

    t = significand - 1.0
    series = t * (1.0 + t * (-0.5 + t * (0.3333333333 - t * 0.25)))
    return series + unbiased_exp.to(tl.float32) * 0.6931471805599453
# Fused Philox + exponential-transform kernel. BLOCK is the only constexpr
# parameter; is_double is an ordinary argument that picks the branch and is
# also part of the autotune key.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK": block}, num_warps=warps, num_stages=stages)
        for block, warps, stages in (
            (64, 2, 2),
            (128, 2, 2),
            (256, 4, 2),
            (512, 4, 3),
            (1024, 8, 3),
            (1024, 16, 3),
            (2048, 16, 4),
        )
    ],
    key=["N", "is_double"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,
    N,
    is_double,
    inv_lambd,
    eps_minus,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Generate Philox random words and store exponential samples to out_ptr.

    Each program draws four 32-bit words per lane for BLOCK lanes. When
    ``is_double`` is true the words are paired into two 64-bit values per
    lane (2 * BLOCK outputs per program); otherwise all four words are used
    separately (4 * BLOCK outputs per program). Every value goes through
    ``uint_to_uniform_float`` (presumably yielding a uniform float -- verify
    against the helper) and then ``transform_exponential``.

    Args:
        out_ptr: Destination pointer; written at linear offsets below ``N``.
        N: Upper bound on written offsets (used in the store masks).
        is_double: Selects the 64-bit branch when true.
        inv_lambd: Forwarded to ``transform_exponential``.
        eps_minus: Forwarded to ``transform_exponential``.
        philox_seed: Seed passed to ``tl.philox`` after a cast to int64.
        philox_offset: 64-bit offset, split into two 32-bit counter words.
        BLOCK: Number of lanes per program (compile-time constant).
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Low / high 32-bit halves of the offset become the first two counters.
    ctr_lo = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    ctr_hi = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    lane = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    # Only the low counter word is advanced per lane; the high word is left
    # as taken from the offset.
    ctr_lo += lane
    zero_ctr = ctr_lo * 0
    w0, w1, w2, w3 = tl.philox(philox_seed, ctr_lo, ctr_hi, zero_ctr, zero_ctr)
    if is_double:
        dbl0 = uint_to_uniform_float(paste_u64(w0, w2))
        dbl1 = uint_to_uniform_float(paste_u64(w1, w3))
        e0 = transform_exponential(dbl0, inv_lambd, eps_minus)
        e1 = transform_exponential(dbl1, inv_lambd, eps_minus)
        UNROLL = 2
        base = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        idx0 = base + tl.arange(0, BLOCK)
        idx1 = idx0 + BLOCK
        tl.store(out_ptr + idx0, e0, mask=idx0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + idx1, e1, mask=idx1 < N, eviction_policy="evict_first")
    else:
        flt0 = uint_to_uniform_float(w0)
        flt1 = uint_to_uniform_float(w1)
        flt2 = uint_to_uniform_float(w2)
        flt3 = uint_to_uniform_float(w3)
        e0 = transform_exponential(flt0, inv_lambd, eps_minus)
        e1 = transform_exponential(flt1, inv_lambd, eps_minus)
        e2 = transform_exponential(flt2, inv_lambd, eps_minus)
        e3 = transform_exponential(flt3, inv_lambd, eps_minus)

        UNROLL = 4
        base = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        # Four consecutive BLOCK-sized stripes per program.
        idx0 = base + tl.arange(0, BLOCK)
        idx1 = idx0 + BLOCK
        idx2 = idx1 + BLOCK
        idx3 = idx2 + BLOCK
        tl.store(out_ptr + idx0, e0, mask=idx0 < N, eviction_policy="evict_last")
        tl.store(out_ptr + idx1, e1, mask=idx1 < N, eviction_policy="evict_last")
        tl.store(out_ptr + idx2, e2, mask=idx2 < N, eviction_policy="evict_last")
        tl.store(out_ptr + idx3, e3, mask=idx3 < N, eviction_policy="evict_last")
@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    """Combine two 32-bit words into one uint64: ``hi`` in the upper half, ``lo`` in the lower."""
    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)
@triton.jit
def transform_exponential(u, inv_lambd, eps_minus):
    """Map a uniform sample ``u`` to ``-inv_lambd * log(u)``.

    Where ``u >= 1.0 + eps_minus`` the logarithm is replaced by ``eps_minus``
    itself; everywhere else it comes from ``safe_fast_log``. The host wrapper
    in this file passes ``eps_minus = -0.5 * eps`` of the output dtype.
    """
    near_one = u >= 1.0 + eps_minus
    log_u = tl.where(near_one, eps_minus, safe_fast_log(u))
    return -inv_lambd * log_u
def exponential_(x, lambd: float = 1.0, *, generator=None):
    """Fill ``x`` in place with samples produced by the fused kernel and return it.

    Each element is computed as ``-(1 / lambd) * log(u)`` for a
    Philox-generated ``u`` (see ``fused_exponential_kernel``).

    Args:
        x: Tensor to overwrite. Must be float16, bfloat16, float32 or
            float64.
        lambd: Rate parameter; the kernel receives its reciprocal.
        generator: Forwarded to ``philox_backend_seed_offset`` to obtain the
            Philox seed and offset.

    Returns:
        The same tensor object ``x``.

    Raises:
        AssertionError: If ``x`` has an unsupported dtype.
        ZeroDivisionError: If ``lambd`` is zero.
    """
    logger.debug("GEMS EXPONENTIAL_")
    dtype = x.dtype
    device = x.device
    assert dtype in (
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    ), f"exponential_ does not support dtype {dtype}"

    N = x.numel()
    if N == 0:
        # Nothing to fill: skip the Philox state update and the autotuned
        # launch, which would otherwise run with a zero-sized grid.
        return x

    is_double = dtype == torch.float64
    # Outputs produced per Philox counter: the kernel stores two values per
    # lane in the float64 branch and four otherwise.
    unroll = 2 if is_double else 4

    def grid(meta):
        return (triton.cdiv(N, meta["BLOCK"] * unroll),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the
    # caller, hence we cannot obtain the per thread offset as in Pytorch.
    increment = triton.cdiv(N, unroll)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    eps_minus = -0.5 * torch.finfo(dtype).eps
    inv_lambd = 1.0 / lambd

    # The kernel writes linear offsets 0..N-1, so it can only target x
    # directly when x is contiguous; otherwise fill a scratch buffer and
    # copy it back.
    contiguous = x.is_contiguous()
    out = x if contiguous else torch.empty(x.size(), dtype=dtype, device=device)
    with torch_device_fn.device(device):
        fused_exponential_kernel[grid](
            out, N, is_double, inv_lambd, eps_minus, philox_seed, philox_offset
        )
    if not contiguous:
        x.copy_(out)
    return x