Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/uniform.py: 0%
41 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems import runtime
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils.random_utils import (
9 philox_backend_seed_offset,
10 uint_to_uniform_float,
11)
12from flag_gems.utils.shape_utils import volume
14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.heuristics(runtime.get_heuristic_config("uniform"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def uniform_kernel(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    from_,
    to,
    BLOCK: tl.constexpr,
):
    """Write Philox-generated values, rescaled by ``from_``/``to``, to ``out_ptr``.

    Each program covers ``4 * BLOCK`` consecutive output positions: a single
    ``tl.philox`` call produces four result streams, and stream ``k`` is
    stored to the ``k``-th ``BLOCK``-wide segment of the program's range.
    Positions at or beyond ``N`` are masked out of the stores.

    Args:
        out_ptr: Base pointer of the output buffer.
        N: Number of elements to write.
        philox_seed: Philox seed (cast to int64 here).
        philox_offset: 64-bit Philox offset; split into two 32-bit counter
            words below.
        from_: Added to every scaled sample.
        to: Together with ``from_`` gives the scale ``to - from_``.
        BLOCK: Compile-time number of lanes per program.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)

    # Split the 64-bit offset into low / high 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    lane = pid * BLOCK + tl.arange(0, BLOCK)

    # Give every lane its own counter by adding its global index to the low word.
    # NOTE(review): nothing here carries an overflow of c0 into c1 - confirm
    # that is acceptable for the offsets in use.
    c0 = c0 + lane
    zeros = c0 * 0  # zero-valued counter words with the same shape as c0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, zeros, zeros)

    # uint_to_uniform_float presumably maps raw bits to a unit-interval float
    # (verify against random_utils); the result is then scaled and shifted.
    span = to - from_
    r0 = uint_to_uniform_float(r0) * span + from_
    r1 = uint_to_uniform_float(r1) * span + from_
    r2 = uint_to_uniform_float(r2) * span + from_
    r3 = uint_to_uniform_float(r3) * span + from_

    # Four back-to-back BLOCK-wide segments inside this program's range.
    seg0 = pid * BLOCK * 4 + tl.arange(0, BLOCK)
    seg1 = seg0 + BLOCK
    seg2 = seg1 + BLOCK
    seg3 = seg2 + BLOCK

    tl.store(out_ptr + seg0, r0, mask=seg0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + seg1, r1, mask=seg1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + seg2, r2, mask=seg2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + seg3, r3, mask=seg3 < N, eviction_policy="evict_first")
# Number of values each kernel lane produces; uniform_kernel writes
# 4 * BLOCK elements per program, so the launch grid is sized accordingly.
UNROLL = 4


def uniform_(self, from_=0.0, to=1.0, *, generator=None):
    """Fill ``self`` in place via ``uniform_kernel`` and return it.

    Args:
        self: Tensor to overwrite; its ``shape`` and ``device`` are read.
        from_: Forwarded to the kernel as the additive shift / range start.
        to: Forwarded to the kernel as the range end.
        generator: Optional generator handed to ``philox_backend_seed_offset``.

    Returns:
        The same ``self`` object that was passed in.
    """
    logger.debug("GEMS UNIFORM")
    numel = volume(self.shape)  # presumably the total element count of self

    def grid_fn(meta):
        # One program per BLOCK * UNROLL output elements, rounded up.
        return (triton.cdiv(numel, meta["BLOCK"] * UNROLL),)

    # Request ceil(numel / UNROLL) as the increment, since every Philox call
    # in the kernel yields UNROLL values.
    philox_seed, philox_offset = philox_backend_seed_offset(
        triton.cdiv(numel, UNROLL), generator=generator
    )

    with torch_device_fn.device(self.device):
        uniform_kernel[grid_fn](self, numel, philox_seed, philox_offset, from_, to)
    return self