Coverage for src/flag_gems/ops/rand.py: 54%
48 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import device, torch_device_fn
9from flag_gems.utils.random_utils import (
10 philox_backend_seed_offset,
11 uint_to_uniform_float,
12)
13from flag_gems.utils.shape_utils import volume
logger = logging.getLogger(name=__name__)
# ``rand`` below takes a ``device`` keyword argument that shadows the imported
# ``device`` object, so keep a module-level alias to the imported one.
device_ = device
# Fill ``out_ptr[0:N]`` with random floats produced by the Philox counter RNG.
#
# Each lane makes one ``tl.philox`` call, which yields four random words
# (r0..r3). The launcher in this module reserves ceil(N / 4) counters
# (``triton.cdiv(N, UNROLL)``). Therefore only lanes whose global index is
# below L = ceil(N / 4) may write output, so the counters actually consumed
# stay inside the reserved range and consecutive calls do not reuse counters.
# Lane ``j`` writes r0..r3 to ``j``, ``j + L``, ``j + 2L`` and ``j + 3L``,
# which keeps each of the four stores contiguous across a block.
# BLOCK is supplied by the "rand" heuristic config.
@triton.heuristics(runtime.get_heuristic_config("rand"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def rand_kernel(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into the two low Philox counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    pid = tl.program_id(0)
    i4 = pid * BLOCK + tl.arange(0, BLOCK)
    # NOTE(review): c0 is uint32, so this add wraps modulo 2**32 without
    # carrying into c1.
    c0 += i4
    _O = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    # uint_to_uniform_float is a project helper (random_utils); presumably it
    # maps the raw words to uniform floats in [0, 1) -- confirm there.
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)

    # L = ceil(N / 4). Writing it as (N - 1) // 4 + 1 avoids the int32
    # overflow that N + 3 could hit. N == 0 gives an empty grid in the caller.
    num_lanes = (N - 1) // 4 + 1
    # Store offsets are computed in int64 so that tensors with more than
    # 2**31 elements are indexed correctly.
    lane = pid.to(tl.int64) * BLOCK + tl.arange(0, BLOCK)
    active = lane < num_lanes
    off_0 = lane
    off_1 = off_0 + num_lanes
    off_2 = off_1 + num_lanes
    off_3 = off_2 + num_lanes
    tl.store(
        out_ptr + off_0, r0, mask=active & (off_0 < N), eviction_policy="evict_first"
    )
    tl.store(
        out_ptr + off_1, r1, mask=active & (off_1 < N), eviction_policy="evict_first"
    )
    tl.store(
        out_ptr + off_2, r2, mask=active & (off_2 < N), eviction_policy="evict_first"
    )
    tl.store(
        out_ptr + off_3, r3, mask=active & (off_3 < N), eviction_policy="evict_first"
    )
# Output elements per kernel lane: each tl.philox call returns four random
# words (r0..r3), and the launch grid assigns BLOCK * UNROLL elements per program.
UNROLL: int = 4
def rand(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a new tensor of shape ``size`` filled by ``rand_kernel``.

    Args:
        size: Output shape; passed to ``torch.empty`` and to ``volume``.
        dtype: Output dtype. Defaults to ``torch.get_default_dtype()`` when None.
        layout: Not used by this implementation.
        device: Target device. Defaults to ``torch.device(device_.name)`` when None.
        pin_memory: Not used by this implementation.

    Returns:
        The freshly allocated tensor written by the kernel.
    """
    logger.debug("GEMS RAND")
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device(device_.name) if device is None else device

    result = torch.empty(size, device=device, dtype=dtype)
    numel = volume(size)

    def grid(meta):
        # One program per BLOCK * UNROLL output elements.
        return (triton.cdiv(numel, meta["BLOCK"] * UNROLL),)

    # TODO: BLOCK is picked by the Triton heuristics at launch time, so it is
    # not known here and the per-thread offset cannot be derived as in PyTorch.
    # Reserve one Philox counter per UNROLL outputs.
    counters_needed = triton.cdiv(numel, UNROLL)
    seed, offset = philox_backend_seed_offset(counters_needed)
    with torch_device_fn.device(device):
        rand_kernel[grid](result, numel, seed, offset)
    return result