Coverage for src/flag_gems/runtime/backend/_mthreads/ops/rand.py: 0%
48 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import device, torch_device_fn
9from flag_gems.utils.random_utils import (
10 philox_backend_seed_offset,
11 uint_to_uniform_float,
12)
13from flag_gems.utils.shape_utils import volume
# Module logger namespaced under the mthreads backend hierarchy; the f-string
# resolves to ".../ops.rand" for this module.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)
# Alias the backend device descriptor so the `device` keyword parameter of
# rand() below does not shadow it.
device_ = device
@triton.heuristics(runtime.get_heuristic_config("rand"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def rand_kernel(
    out_ptr,  # pointer to the output buffer
    N,  # total number of elements to fill
    philox_seed,  # 64-bit Philox seed (kept unspecialized to avoid recompiles)
    philox_offset,  # 64-bit Philox counter offset
    BLOCK: tl.constexpr,  # lanes per program; each program writes 4 * BLOCK values
):
    # Fill out_ptr with uniform random floats using the Philox counter-based
    # RNG; tl.philox yields four 32-bit outputs per counter, so each program
    # instance produces 4 * BLOCK values.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into the low/high 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # Per-lane counter increment so every element draws from a distinct
    # Philox counter value.
    i4 = (tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)).to(tl.uint32)
    c0 += i4
    _O = c0 * 0  # zero vector of matching shape for the unused counter words
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    # Map each raw 32-bit integer to a uniform float (see uint_to_uniform_float).
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)
    # Each program owns a contiguous span of 4 * BLOCK output slots; the four
    # Philox outputs land in consecutive BLOCK-sized sub-ranges of that span.
    off_0 = ((tl.program_id(0) * BLOCK * 4).to(tl.int64) + tl.arange(0, BLOCK)).to(
        tl.int64
    )
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK
    # evict_first: the output is write-only here, so hint the cache to drop it.
    tl.store(out_ptr + off_0, r0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, r1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, r2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, r3, mask=off_3 < N, eviction_policy="evict_first")
54UNROLL = 4
def rand(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a tensor of shape `size` filled with uniform random values.

    Mirrors the ``torch.rand`` signature; ``layout`` and ``pin_memory`` are
    accepted for API compatibility but are not consumed here. When ``dtype``
    or ``device`` is omitted, the torch default dtype and the backend device
    are used.
    """
    logger.debug("GEMS_MTHREADS RAND")
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device(device_.name) if device is None else device

    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)

    # Grid size accounts for the 4-way unroll inside the kernel.
    def grid_fn(meta):
        return (triton.cdiv(N, meta["BLOCK"] * UNROLL),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(increment)
    with torch_device_fn.device(device):
        rand_kernel[grid_fn](out, N, philox_seed, philox_offset)
    return out