Coverage for src/flag_gems/runtime/backend/_mthreads/ops/randn.py: 0%
59 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import device, torch_device_fn
9from flag_gems.utils.random_utils import (
10 philox_backend_seed_offset,
11 uint_to_uniform_float,
12)
13from flag_gems.utils.shape_utils import volume
# Use Triton's own Box-Muller helper when this Triton build exports it;
# otherwise fall back to an equivalent local implementation.
pair_uniform_to_normal = getattr(tl, "pair_uniform_to_normal", None)
if pair_uniform_to_normal is None:

    @triton.jit
    def pair_uniform_to_normal(u1, u2):
        """Box-Muller: turn two uniform samples into two normal samples.

        Computes radius = sqrt(-2 * ln(u1)) and angle = 2 * pi * u2, and
        returns (radius * cos(angle), radius * sin(angle)). ``u1`` is first
        clamped to at least 1e-7 so the logarithm never sees zero.
        """
        u1_clamped = tl.maximum(1.0e-7, u1)
        radius = tl.sqrt(-2.0 * tl.log(u1_clamped))
        angle = 6.283185307179586 * u2  # 2 * pi
        return radius * tl.cos(angle), radius * tl.sin(angle)
# `randn` below takes a keyword argument named `device`, which shadows the
# imported runtime `device` inside that function; keep a module-level alias.
device_ = device
# Logger name: this backend's ops namespace plus the last dotted component
# of the module name.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rpartition(".")[2]
)
@triton.heuristics(runtime.get_heuristic_config("randn"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def randn_kernel(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Fill ``out_ptr[0:N]`` with standard-normal samples.

    Lane ``j`` of program ``p`` runs one Philox call keyed by ``philox_seed``
    at the 64-bit counter ``philox_offset + p * BLOCK + j``. It maps the four
    returned words to uniforms and, through two Box-Muller pairs, to four
    normals. Program ``p`` owns the output span
    ``[4 * p * BLOCK, 4 * (p + 1) * BLOCK)``; lane ``j`` writes positions
    ``j``, ``j + BLOCK``, ``j + 2 * BLOCK`` and ``j + 3 * BLOCK`` of that span,
    so each of the four stores is contiguous across lanes. Positions ``>= N``
    are masked off.

    Args:
        out_ptr: Output buffer with room for at least ``N`` elements.
        N: Number of elements to write.
        philox_seed: Philox key (widened to int64 here).
        philox_offset: Base 64-bit Philox counter for this launch.
        BLOCK: Lanes per program, supplied by the "randn" heuristics.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)

    # Promote the program id BEFORE multiplying: in int32, pid * BLOCK * 4
    # wraps once N exceeds 2**31 - 1, producing negative store offsets.
    pid = tl.program_id(0).to(tl.int64)
    lane = tl.arange(0, BLOCK)

    # Build the full 64-bit counter first and split it afterwards, so an
    # overflow of the low word carries into c1 instead of wrapping c0 and
    # colliding with counters handed out for earlier offsets. When no carry
    # occurs this equals the previous (c0 = lo + lane, c1 = hi) encoding.
    counter = philox_offset + pid * BLOCK + lane
    c0 = (counter & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((counter >> 32) & 0xFFFFFFFF).to(tl.uint32)
    zero = c0 * 0  # uint32 zeros shaped like c0 for the unused counter words

    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, zero, zero)
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)
    n0, n1 = pair_uniform_to_normal(r0, r1)
    n2, n3 = pair_uniform_to_normal(r2, r3)

    off_0 = pid * (4 * BLOCK) + lane
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK
    # Output is written once and not read back here; evict_first presumably
    # keeps it from displacing other cached data — NOTE(review): confirm.
    tl.store(out_ptr + off_0, n0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, n1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, n2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, n3, mask=off_3 < N, eviction_policy="evict_first")
# Normal samples produced per kernel lane (one Philox call -> 4 values).
UNROLL = 4


def randn(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a new tensor of shape ``size`` filled with N(0, 1) samples.

    Follows the ``torch.randn`` factory signature; ``layout`` and
    ``pin_memory`` are accepted for compatibility but ignored.

    Args:
        size: Output shape, forwarded to ``torch.empty`` and ``volume``.
        dtype: Output dtype; defaults to ``torch.get_default_dtype()``.
        layout: Ignored.
        device: Target device; defaults to this backend's device.
        pin_memory: Ignored.

    Returns:
        The freshly allocated and filled tensor.
    """
    logger.debug("GEMS_MTHREADS RANDN")
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device(device_.name)
    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)
    if N == 0:
        # Nothing to fill: skip the zero-size launch and leave the RNG state alone.
        return out

    def grid_fn(meta):
        # One program covers BLOCK lanes x UNROLL outputs.
        return (triton.cdiv(N, meta["BLOCK"] * UNROLL),)

    # Philox counters to reserve for this launch. Lane t = pid * BLOCK + j
    # uses counter philox_offset + t and stores element 4 * pid * BLOCK + j
    # (its first output), so every lane that writes anything has
    # t <= 4 * pid * BLOCK + j < N. BLOCK is picked by the heuristics at
    # launch time and is unknown here, and cdiv(N, UNROLL) under-reserves
    # when the last program is partially filled (e.g. N = BLOCK + 1 activates
    # BLOCK lanes). The next call would then reuse counters and repeat
    # samples. N is the smallest bound that holds for every BLOCK.
    # Assumes philox_backend_seed_offset returns the current offset and
    # advances the generator by at least `increment` — verify in random_utils.
    increment = N
    philox_seed, philox_offset = philox_backend_seed_offset(increment)
    with torch_device_fn.device(device):
        randn_kernel[grid_fn](out, N, philox_seed, philox_offset)
    return out