Coverage for src/flag_gems/runtime/backend/_ascend/ops/exponential_.py: 0%
86 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils.random_utils import (
10 philox_backend_seed_offset,
11 uint_to_uniform_float,
12)
14logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
# Fill out_ptr[0:N] with Exponential(lambd) samples using the counter-based
# Philox RNG. The launch grid is fixed-size; each program ("worker") loops over
# "tasks" of BLOCK * UNROLL elements in a round-robin (grid-stride) pattern so
# any N can be covered. BLOCK is supplied by the heuristic config registered
# under "exponential_".
@triton.heuristics(runtime.get_heuristic_config("exponential_"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,  # pointer to the output buffer (N elements)
    N,  # total number of elements to produce
    is_double,  # nonzero when the output dtype is float64 (needs 64 random bits/sample)
    lambd,  # rate parameter of the exponential distribution
    eps,  # dtype epsilon, used to keep samples strictly inside (0, 1] before log()
    philox_seed,  # Philox key, from philox_backend_seed_offset on the host
    philox_offset,  # 64-bit Philox counter base for this launch
    UNROLL,  # samples produced per lane per task: 2 for double, 4 otherwise
    BLOCK: tl.constexpr,
):
    n_workers = tl.num_programs(0)
    pid = tl.program_id(0)
    # Each task covers BLOCK * UNROLL outputs; workers pick up tasks
    # pid, pid + n_workers, pid + 2*n_workers, ...
    n_tasks = tl.cdiv(N, BLOCK * UNROLL)
    tasks_per_worker = tl.cdiv(n_tasks, n_workers)
    for task_index in range(tasks_per_worker):
        task_id = pid + task_index * n_workers
        philox_seed = philox_seed.to(tl.int64)
        philox_offset_64 = philox_offset.to(tl.int64)
        # Split the 64-bit counter base into the two 32-bit Philox counter words.
        c0 = (philox_offset_64 & 0xFFFFFFFF).to(tl.uint32)
        c1 = ((philox_offset_64 >> 32) & 0xFFFFFFFF).to(tl.uint32)
        # Per-lane counter increment: every (task, lane) pair gets a distinct
        # counter, hence an independent 4x32-bit random draw.
        i4 = task_id * BLOCK + tl.arange(0, BLOCK)
        c0 += i4
        _O = c0 * 0  # zero vector for the two unused counter inputs
        r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
        if is_double:
            # float64 path (UNROLL == 2): fuse two 32-bit words per sample so
            # each uniform has full 64-bit entropy before the log transform.
            d0 = uint_to_uniform_float(paste_u64(r0, r2))
            d1 = uint_to_uniform_float(paste_u64(r1, r3))
            y0 = transform_exponential(d0, lambd, eps)
            y1 = transform_exponential(d1, lambd, eps)
            # UNROLL = 2: two contiguous BLOCK-sized stores per task.
            start = task_id.to(tl.int64) * BLOCK * 2
            off_0 = start + tl.arange(0, BLOCK)
            off_1 = off_0 + BLOCK
            # Masks handle the ragged tail; the last worker may also run
            # task_ids past n_tasks, which these masks turn into no-ops.
            tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
            tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
        else:
            # 16/32-bit float path (UNROLL == 4): each 32-bit word yields one sample.
            f0 = uint_to_uniform_float(r0)
            f1 = uint_to_uniform_float(r1)
            f2 = uint_to_uniform_float(r2)
            f3 = uint_to_uniform_float(r3)
            y0 = transform_exponential(f0, lambd, eps)
            y1 = transform_exponential(f1, lambd, eps)
            y2 = transform_exponential(f2, lambd, eps)
            y3 = transform_exponential(f3, lambd, eps)
            # UNROLL = 4: four contiguous BLOCK-sized stores per task.
            start = task_id.to(tl.int64) * BLOCK * 4
            off_0 = start + tl.arange(0, BLOCK)
            off_1 = off_0 + BLOCK
            off_2 = off_1 + BLOCK
            off_3 = off_2 + BLOCK
            tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
            tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
            tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
            tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")
@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    # Concatenate two 32-bit words into one 64-bit value: `hi` fills the
    # upper half, `lo` the lower half.
    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)
@triton.jit
def transform_exponential(u, lambd, eps):
    # Inverse-CDF transform: map a uniform sample u to Exponential(lambd) via
    # -log(u) / lambd. Samples at (or within half an epsilon of) u == 1 are
    # clamped so log() never returns exactly 0 there, matching the dtype's
    # resolution near 1.
    half_neg_eps = eps * -0.5
    at_upper_edge = u >= 1.0 + half_neg_eps
    log_u = tl.where(at_upper_edge, half_neg_eps, tl.math.log(u))
    return -1.0 / lambd * log_u
def exponential_(x, lambd: float = 1.0, *, gen=None):
    """Fill tensor ``x`` in place with Exponential(``lambd``) samples.

    Floating dtypes only (fp16/bf16/fp32/fp64). Non-contiguous tensors are
    sampled into a contiguous scratch buffer and copied back. Returns ``x``.
    NOTE(review): ``gen`` is accepted for API parity but is not forwarded to
    the Philox seed/offset helper — confirm whether that is intentional.
    """
    logger.debug("GEMS_ASCEND EXPONENTIAL_")
    dtype = x.dtype
    device = x.device
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    is_double = dtype in (torch.float64,)
    # float64 consumes two 32-bit Philox words per sample, so it unrolls by 2;
    # the narrower dtypes get one sample per word (unroll 4).
    UNROLL = 2 if is_double else 4
    N = x.numel()
    contiguous = x.is_contiguous()

    def grid_fn(meta):
        # Cap the launch at 240 programs; the kernel grid-strides over the rest.
        return (min(triton.cdiv(N, meta["BLOCK"] * UNROLL), 240),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    philox_seed, philox_offset = philox_backend_seed_offset(triton.cdiv(N, UNROLL))
    eps = torch.finfo(dtype).eps
    out = x if contiguous else torch.empty(x.size(), dtype=dtype, device=device)
    with torch_device_fn.device(device):
        fused_exponential_kernel[grid_fn](
            out, N, is_double, lambd, eps, philox_seed, philox_offset, UNROLL
        )
    if not contiguous:
        x.copy_(out)
    return x