Coverage for src/flag_gems/runtime/backend/_cambricon/ops/exponential_.py: 0%
69 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from triton.language.extra.mlu.libdevice import philox as _philox
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils.random_utils import (
11 philox_backend_seed_offset,
12 uint_to_uniform_float,
13)
15from ..utils import TOTAL_CORE_NUM
# Module-level logger, namespaced under the package root "flag_gems";
# lstrip(".") strips any leading dots so getChild receives a valid name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.heuristics(runtime.get_heuristic_config("exponential_"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,
    N,
    is_double: tl.constexpr,
    lambd,
    eps,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Fill out_ptr[0:N] with Exponential(lambd)-distributed samples.

    Uniform random bits come from the Philox counter-based RNG; each call
    produces 128 bits, i.e. 4 x 32-bit uniforms (float path) or 2 x 64-bit
    uniforms (double path).  The uniforms are mapped to the exponential
    distribution by ``transform_exponential`` (essentially -log(u) / lambd).
    """
    if is_double:
        UNROLL: tl.constexpr = 2  # philox generates 128 random bits at a time
    else:
        UNROLL: tl.constexpr = 4  # philox generates 128 random bits at a time
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    # Per-program starting position in the philox counter space; advanced at
    # the end of every loop iteration so counters are never reused.
    i4_start = pid * BLOCK
    block_start = pid * UNROLL * BLOCK
    step = num_jobs * BLOCK * UNROLL
    # Grid-stride loop: each program covers UNROLL * BLOCK elements per pass.
    for block_offset in range(block_start, N, step):
        # Split the 64-bit seed/offset into the 32-bit halves philox expects.
        sl = (philox_seed & 0xFFFFFFFF).to(tl.uint32)
        sh = ((philox_seed >> 32) & 0xFFFFFFFF).to(tl.uint32)
        c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
        c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
        # 10 philox rounds; i4_start offsets the low counter word per program.
        r = _philox(BLOCK, sl, sh, c0 + i4_start, c1, 0, 0, 10)
        r = tl.reshape(r, [UNROLL * BLOCK], can_reorder=True)
        off = block_offset + tl.arange(0, UNROLL * BLOCK)
        if is_double:
            # Reinterpret pairs of 32-bit words as 64-bit integers for doubles.
            r = r.to(tl.uint64, bitcast=True)
            f = uint_to_uniform_float(r)
        else:
            f = uint_to_uniform_float(r)
        y = transform_exponential(f, lambd, eps)
        tl.store(out_ptr + off, y, mask=off < N)
        i4_start += num_jobs * BLOCK
@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    """Concatenate two 32-bit halves into a single 64-bit value (hi:lo)."""
    word = (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)
    return word
@triton.jit
def transform_exponential(u, lambd, eps):
    """Map a uniform sample ``u`` to Exponential(lambd) via -log(u) / lambd.

    Samples within half an eps of 1.0 are clamped so the logarithm never
    evaluates to exactly zero (which would yield a zero, not an
    exponentially distributed value).
    """
    half_neg_eps = -0.5 * eps
    at_top = u >= 1.0 + half_neg_eps
    safe_log = tl.where(at_top, half_neg_eps, tl.math.log(u))
    return -1.0 / lambd * safe_log
def exponential_(x, lambd: float = 1.0, *, generator=None):
    """Fill ``x`` in place with Exponential(lambd)-distributed samples.

    Only floating dtypes are supported (fp16/bf16/fp32/fp64).  When ``x`` is
    not contiguous, the kernel writes a contiguous scratch tensor which is
    copied back into ``x`` afterwards.  Returns ``x``.
    """
    logger.debug("GEMS_CAMBRICON EXPONENTIAL_")
    dtype = x.dtype
    device = x.device
    can_write_directly = x.is_contiguous()
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    is_double = dtype in (torch.float64,)
    # One 128-bit philox draw yields 2 doubles or 4 single-precision uniforms.
    unroll = 2 if is_double else 4
    numel = x.numel()

    def grid_fn(meta):
        return (min(triton.cdiv(numel, meta["BLOCK"] * unroll), TOTAL_CORE_NUM),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    increment = triton.cdiv(numel, unroll)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    eps = torch.finfo(dtype).eps
    out = x if can_write_directly else torch.empty(x.size(), dtype=dtype, device=device)
    with torch_device_fn.device(device):
        fused_exponential_kernel[grid_fn](
            out,
            numel,
            is_double,
            lambd,
            eps,
            philox_seed,
            philox_offset,
            num_warps=1,
            num_stages=3,
        )
    if not can_write_directly:
        x.copy_(out)
    return x