Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/exponential_.py: 0%
88 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from triton.language.extra.xpu.libdevice import log2
8# from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils.random_utils import (
11 philox_backend_seed_offset,
12 uint_to_uniform_float,
13)
# Module logger nested under the package-wide "flag_gems" logger.
# NOTE(review): str.lstrip(".") strips leading dot *characters*, not a
# prefix; presumably harmless here since __name__ rarely starts with a
# dot — confirm against how this module is imported.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# def heur_block(args):
#     if args["N"] <= 512:
#         return 512
#     else:
#         return 1024
def heur_block(args):
    """Pick the per-program BLOCK size for the kernel launch.

    Splits the N elements across the 12 hardware clusters and rounds the
    per-cluster share up to the next power of two (Triton block sizes
    must be powers of two).
    """
    n_elements = args["N"]
    per_cluster = triton.cdiv(n_elements, 12)  # CLUSTER_NUM = 12
    return triton.next_power_of_2(per_cluster)
def heur_num_warps(args):
    """Choose num_warps from the element count N: 4 for small
    launches (N <= 512), 8 for medium (N <= 1024), 16 otherwise."""
    n = args["N"]
    for bound, warps in ((512, 4), (1024, 8)):
        if n <= bound:
            return warps
    return 16
@triton.heuristics(
    {
        "BLOCK": heur_block,
        "num_warps": heur_num_warps,
    }
)
# @triton.heuristics(runtime.get_heuristic_config("exponential_"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,
    N,
    is_double: tl.constexpr,
    lambd,
    eps,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Fill out_ptr[0:N] with Exponential(lambd) samples.

    Uses counter-based Philox RNG: each program generates 4 uint32 words
    per lane, yielding 4 float samples per lane (UNROLL=4) or, for
    float64, 2 double samples per lane (UNROLL=2).
    """
    # Split the 64-bit Philox offset into two 32-bit counter words.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # Per-lane global index; added into the low counter word so every
    # lane draws from a distinct Philox counter.
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0  # zero vector with the counters' shape/dtype
    # Philox produces 4 independent uint32 words per counter.
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    if is_double:
        # Pair up uint32 words into uint64 for full double precision.
        d0 = uint_to_uniform_float(paste_u64(r0, r2))
        d1 = uint_to_uniform_float(paste_u64(r1, r3))
        y0 = transform_exponential(d0, lambd, eps)
        y1 = transform_exponential(d1, lambd, eps)
        UNROLL = 2
        # Each program owns BLOCK*UNROLL contiguous output slots.
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        # Mask the tail; evict_first since the output is write-once.
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    else:
        # One uniform float per 32-bit Philox word.
        f0 = uint_to_uniform_float(r0)
        f1 = uint_to_uniform_float(r1)
        f2 = uint_to_uniform_float(r2)
        f3 = uint_to_uniform_float(r3)
        y0 = transform_exponential(f0, lambd, eps)
        y1 = transform_exponential(f1, lambd, eps)
        y2 = transform_exponential(f2, lambd, eps)
        y3 = transform_exponential(f3, lambd, eps)
        UNROLL = 4
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        off_2 = off_1 + BLOCK
        off_3 = off_2 + BLOCK
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")
@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    """Concatenate two uint32 words into one uint64 (hi in the upper 32 bits)."""
    hi = hi.to(tl.uint64) << 32
    x = hi | lo.to(tl.uint64)
    return x
@triton.jit
def transform_exponential(u, lambd, eps):
    """Map a uniform sample u to Exponential(lambd) via inverse CDF: -ln(u)/lambd.

    When u is within eps/2 of 1.0 (where ln(u) ~ 0), ln(u) is replaced by
    -eps/2 so the returned sample stays strictly positive.
    """
    eps1 = -0.5 * eps
    is_min = u >= 1.0 + eps1  # i.e. u >= 1 - eps/2
    # 1/log2(e) == ln(2): converts the base-2 log below to a natural log.
    trans_scale = 1.0 / 1.4426950408889634
    log = tl.where(is_min, eps1, log2(u) * trans_scale)
    v = -1.0 / lambd * log
    return v
def exponential_(x: torch.Tensor, lambd: float = 1.0, *, generator=None) -> torch.Tensor:
    """In-place fill of `x` with samples from Exponential(lambd).

    Mirrors `torch.Tensor.exponential_`: supports float16/bfloat16/
    float32/float64 tensors. Non-contiguous tensors are filled via a
    contiguous scratch buffer and copied back. Returns `x`.
    """
    logger.debug("GEMS EXPONENTIAL_")
    dtype = x.dtype
    device = x.device
    # The kernel writes linearly, so it can target x directly only when
    # x is contiguous.
    inplace = x.is_contiguous()
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    is_double = dtype in (torch.float64,)
    # Must match the kernel: 2 doubles or 4 floats per lane per draw.
    UNROLL = 2 if is_double else 4
    N = x.numel()
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    # Advance the global Philox state by one counter per UNROLL group.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    # eps feeds transform_exponential's clamp that keeps samples > 0.
    eps = torch.finfo(dtype).eps
    x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device)
    with torch_device_fn.device(device):
        fused_exponential_kernel[grid_fn](
            x_, N, is_double, lambd, eps, philox_seed, philox_offset
        )
    if not inplace:
        x.copy_(x_)
    return x