Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/rand.py: 0%
131 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import device, torch_device_fn
9from flag_gems.utils.random_utils import (
10 philox_backend_seed_offset,
11 uint_to_uniform_float,
12)
13from flag_gems.utils.shape_utils import volume
# Child of the package-wide "flag_gems" logger, named after this module
# (any leading dots in __name__ are stripped before it is used as the child name).
15logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Alias for the imported runtime `device` object: `rand` below takes a keyword
# parameter that is also called `device`, which shadows the import inside the
# function body, so the default-device lookup goes through this name instead.
16device_ = device
@triton.heuristics(runtime.get_heuristic_config("rand"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def rand_kernel(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Fill ``out_ptr[0:N]`` with values derived from one Philox call per lane.

    Every program draws a single 4-word Philox result per lane, converts each
    word with ``uint_to_uniform_float`` and writes the four results to four
    consecutive BLOCK-wide slices, so one program covers ``4 * BLOCK`` output
    elements. Stores are masked against ``N``.

    Args:
        out_ptr: pointer to the output buffer.
        N: number of elements to write.
        philox_seed: Philox key (cast to int64 before use).
        philox_offset: 64-bit counter base; its low/high 32-bit halves become
            the first two counter words.
        BLOCK: lanes per program (compile-time constant; presumably supplied
            by the "rand" heuristic config -- confirm against the runtime).
    """
    pid = tl.program_id(0)
    lane = tl.arange(0, BLOCK)

    seed = philox_seed.to(tl.int64)
    offset = philox_offset.to(tl.int64)

    # Split the 64-bit offset into two 32-bit counter words; only the low word
    # is advanced per lane.
    ctr_hi = ((offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    ctr_lo = (offset & 0xFFFFFFFF).to(tl.uint32)
    ctr_lo += pid * BLOCK + lane
    zero = ctr_lo * 0  # zero vector with the same shape/dtype as the counters

    w0, w1, w2, w3 = tl.philox(seed, ctr_lo, ctr_hi, zero, zero)
    v0 = uint_to_uniform_float(w0)
    v1 = uint_to_uniform_float(w1)
    v2 = uint_to_uniform_float(w2)
    v3 = uint_to_uniform_float(w3)

    # Four BLOCK-wide output slices owned by this program.
    idx0 = pid * BLOCK * 4 + lane
    idx1 = idx0 + BLOCK
    idx2 = idx1 + BLOCK
    idx3 = idx2 + BLOCK

    tl.store(out_ptr + idx0, v0, mask=idx0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + idx1, v1, mask=idx1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + idx2, v2, mask=idx2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + idx3, v3, mask=idx3 < N, eviction_policy="evict_first")
def choose_unroll(N, core=64, clusters=12):
    """Pick the unroll factor (16 or 1) for a launch over ``N`` elements.

    Returns 16 when ``ceil(N / (clusters * 16))`` is at least ``core``,
    otherwise 1. The previous implementation also evaluated the same test for
    an unroll of 1, but both outcomes of that test returned 1, so only the
    16-wide check influences the result.

    Args:
        N: total number of output elements.
        core: minimum per-cluster workload (after unrolling) required to
            select the 16-wide unroll.
        clusters: divisor applied to ``N`` together with the unroll factor
            (presumably the device's cluster count -- confirm for the target).

    Returns:
        ``16`` or ``1``.
    """
    wide_unroll = 16
    divisor = clusters * wide_unroll
    # Same ceiling-division formula as triton.cdiv, written out so this
    # host-side helper is plain integer arithmetic.
    per_cluster = (N + divisor - 1) // divisor
    return wide_unroll if per_cluster >= core else 1
# NOTE: no @triton.heuristics decorator here; BLOCK is passed explicitly by
# the launcher (see `rand`).
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def rand_kernel_1(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
    UNROLL: tl.constexpr,
):
    """Write one converted Philox word per lane into ``out_ptr``.

    Each program issues a single masked store of BLOCK elements starting at
    ``program_id * BLOCK * UNROLL``. Because only the first BLOCK elements of
    every ``BLOCK * UNROLL`` span are written, the whole range ``[0, N)`` is
    covered only when ``UNROLL == 1``.

    Args:
        out_ptr: pointer to the output buffer.
        N: number of elements to write.
        philox_seed: Philox key (cast to int64 before use).
        philox_offset: 64-bit counter base, split into two 32-bit words.
        BLOCK: lanes per program (compile-time constant).
        UNROLL: stride multiplier between consecutive programs' output spans.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0
    # tl.philox yields four words per counter (the sibling kernels unpack it
    # that way). Unpack explicitly and keep only the first word, instead of
    # handing the whole 4-tuple to uint_to_uniform_float.
    r0, _r1, _r2, _r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    r0 = uint_to_uniform_float(r0)
    off_0 = tl.program_id(0) * BLOCK * UNROLL + tl.arange(0, BLOCK)
    tl.store(out_ptr + off_0, r0, mask=off_0 < N, eviction_policy="evict_first")
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def rand_kernel_2(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
    UNROLL: tl.constexpr,
):
    """Write sixteen converted Philox words per lane into ``out_ptr``.

    Each lane performs four Philox calls (four words each) and the program
    stores the sixteen results to sixteen consecutive BLOCK-wide slices that
    start at ``program_id * BLOCK * UNROLL``. The sixteen stores are
    hard-coded, so the output spans of neighbouring programs tile ``[0, N)``
    without gaps or overlap only when ``UNROLL == 16``.

    Counter layout: the first two counter words come from
    ``philox_offset + program_id * BLOCK + lane``; the third word carries the
    index (0..3) of the draw within the lane. Previously the four draws used
    ``c0``, ``c0 + 1``, ``c0 + 2``, ``c0 + 3`` in the first word, which
    collided with the counters of the next three lanes (their ``c0`` differs
    by exactly 1 per lane) and therefore produced repeated values in the
    output. Keeping ``c0`` fixed per lane and varying the third word makes all
    counters distinct while consuming the same range of ``c0`` values.

    Args:
        out_ptr: pointer to the output buffer.
        N: number of elements to write.
        philox_seed: Philox key (cast to int64 before use).
        philox_offset: 64-bit counter base, split into two 32-bit words.
        BLOCK: lanes per program (compile-time constant).
        UNROLL: number of BLOCK-wide slices per program; expected to be 16.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0
    # Draw index goes into the third counter word so that no two
    # (lane, draw) pairs ever share a counter.
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    r4, r5, r6, r7 = tl.philox(philox_seed, c0, c1, _O + 1, _O)
    r8, r9, r10, r11 = tl.philox(philox_seed, c0, c1, _O + 2, _O)
    r12, r13, r14, r15 = tl.philox(philox_seed, c0, c1, _O + 3, _O)
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)
    r4 = uint_to_uniform_float(r4)
    r5 = uint_to_uniform_float(r5)
    r6 = uint_to_uniform_float(r6)
    r7 = uint_to_uniform_float(r7)
    r8 = uint_to_uniform_float(r8)
    r9 = uint_to_uniform_float(r9)
    r10 = uint_to_uniform_float(r10)
    r11 = uint_to_uniform_float(r11)
    r12 = uint_to_uniform_float(r12)
    r13 = uint_to_uniform_float(r13)
    r14 = uint_to_uniform_float(r14)
    r15 = uint_to_uniform_float(r15)
    # Sixteen consecutive BLOCK-wide output slices owned by this program.
    off_0 = tl.program_id(0) * BLOCK * UNROLL + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK
    off_4 = off_3 + BLOCK
    off_5 = off_4 + BLOCK
    off_6 = off_5 + BLOCK
    off_7 = off_6 + BLOCK
    off_8 = off_7 + BLOCK
    off_9 = off_8 + BLOCK
    off_10 = off_9 + BLOCK
    off_11 = off_10 + BLOCK
    off_12 = off_11 + BLOCK
    off_13 = off_12 + BLOCK
    off_14 = off_13 + BLOCK
    off_15 = off_14 + BLOCK
    tl.store(out_ptr + off_0, r0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, r1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, r2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, r3, mask=off_3 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_4, r4, mask=off_4 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_5, r5, mask=off_5 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_6, r6, mask=off_6 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_7, r7, mask=off_7 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_8, r8, mask=off_8 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_9, r9, mask=off_9 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_10, r10, mask=off_10 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_11, r11, mask=off_11 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_12, r12, mask=off_12 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_13, r13, mask=off_13 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_14, r14, mask=off_14 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_15, r15, mask=off_15 < N, eviction_policy="evict_first")
def rand(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Allocate a tensor of shape ``size`` and fill it with the Philox kernels.

    Args:
        size: shape of the result.
        dtype: element type; defaults to ``torch.get_default_dtype()``.
        layout: accepted for signature compatibility; not used here.
        device: target device; defaults to the runtime device (``device_.name``).
        pin_memory: accepted for signature compatibility; not used here.

    Returns:
        The newly allocated tensor. For an empty ``size`` the tensor is
        returned immediately, without launching a kernel or advancing the
        Philox generator state.
    """
    logger.debug("GEMS RAND")
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device(device_.name)

    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)
    if N == 0:
        # With N == 0 the block size below evaluates to 0 and the grid
        # computation divides by it; there is nothing to fill anyway.
        return out

    cluster_num = 12
    # Pass the cluster count explicitly so the unroll choice and the block
    # size below are always derived from the same value.
    UNROLL = choose_unroll(N, clusters=cluster_num)
    BLOCK_SIZE = min(triton.next_power_of_2(triton.cdiv(N, cluster_num * UNROLL)), 1024)
    num_programs = triton.cdiv(N, BLOCK_SIZE * UNROLL)
    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(increment)
    with torch_device_fn.device(device):
        # choose_unroll yields 1 or 16: the single-store kernel handles the
        # former, the sixteen-store kernel the latter.
        if UNROLL <= 4:
            rand_kernel_1[(num_programs,)](
                out, N, philox_seed, philox_offset, BLOCK_SIZE, UNROLL
            )
        else:
            rand_kernel_2[(num_programs,)](
                out, N, philox_seed, philox_offset, BLOCK_SIZE, UNROLL
            )
    return out