Coverage for src/flag_gems/ops/randn.py: 39%
77 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import device, torch_device_fn
8from flag_gems.utils.random_utils import (
9 philox_backend_seed_offset,
10 uint_to_uniform_float,
11)
12from flag_gems.utils.shape_utils import volume
@triton.jit
def high_precision_fast_sin_cos(x):
    """Return ``(sin(x), cos(x))`` using short polynomial approximations.

    ``x`` is wrapped to [-pi, pi) and then folded into [-pi/2, pi/2] with
    ``sin(pi - x) = sin(x)`` / ``cos(pi - x) = -cos(x)`` (and the mirrored
    identities for x < -pi/2). On the folded range the Taylor polynomials
    below (sin through y**13, cos through y**12) have truncation error below
    ~1e-8, i.e. under float32 rounding. Evaluating such short Taylor
    polynomials directly on [-pi, pi] would leave errors of order 1e-3 near
    |x| = pi.
    """
    pi = 3.141592653589793
    half_pi = 1.5707963267948966
    two_pi = 6.283185307179586

    # Wrap to [-pi, pi).
    x = x - two_pi * tl.floor(x / two_pi + 0.5)

    # Fold into [-pi/2, pi/2]. In both folded branches sin is unchanged and
    # cos changes sign:
    #   x >  pi/2:  y =  pi - x
    #   x < -pi/2:  y = -pi - x
    flip = tl.abs(x) > half_pi
    y = tl.where(x > half_pi, pi - x, x)
    y = tl.where(x < -half_pi, -(pi + x), y)
    y2 = y * y

    # sin(y) = y * P(y^2); Taylor coefficients (-1)^k / (2k+1)!, k = 0..6.
    s_c0 = 1.0
    s_c1 = -0.16666666666666666
    s_c2 = 0.008333333333333333
    s_c3 = -0.0001984126984126984
    s_c4 = 2.755731922398589e-06
    s_c5 = -2.505210838544172e-08
    s_c6 = 1.6059043836821613e-10
    sin_p = s_c5 + y2 * s_c6
    sin_p = s_c4 + y2 * sin_p
    sin_p = s_c3 + y2 * sin_p
    sin_p = s_c2 + y2 * sin_p
    sin_p = s_c1 + y2 * sin_p
    sin_x = y * (s_c0 + y2 * sin_p)

    # cos(y) = Q(y^2); Taylor coefficients (-1)^k / (2k)!, k = 0..6.
    c_c0 = 1.0
    c_c1 = -0.5
    c_c2 = 0.041666666666666664
    c_c3 = -0.001388888888888889
    c_c4 = 2.48015873015873e-05
    c_c5 = -2.755731922398589e-07
    c_c6 = 2.08767569878681e-09
    cos_p = c_c5 + y2 * c_c6
    cos_p = c_c4 + y2 * cos_p
    cos_p = c_c3 + y2 * cos_p
    cos_p = c_c2 + y2 * cos_p
    cos_p = c_c1 + y2 * cos_p
    cos_y = c_c0 + y2 * cos_p
    # Undo the sign change introduced by folding.
    cos_x = tl.where(flip, -cos_y, cos_y)

    return sin_x, cos_x
@triton.jit
def pair_uniform_to_normal_fast(u1, u2):
    """Map a pair of uniforms to a pair of normals via the Box-Muller transform.

    ``u1`` is clamped to at least 1e-7 so ``log(u1)`` stays finite; ``u2``
    selects the angle ``2*pi*u2``. Returns ``(r*cos(angle), r*sin(angle))``
    with ``r = sqrt(-2*log(u1))`` (standard normals when u1, u2 are
    independent uniforms).
    """
    radius = tl.sqrt(-2.0 * tl.log(tl.maximum(1.0e-7, u1)))
    angle = 6.283185307179586 * u2
    sin_angle, cos_angle = high_precision_fast_sin_cos(angle)
    return radius * cos_angle, radius * sin_angle
# Module-level handle on the runtime device: randn() below takes a keyword
# parameter named ``device`` that shadows the imported name.
device_ = device
logger = logging.getLogger(__name__)

# Autotuning search space for randn_kernel, as (BLOCK, num_warps, num_stages).
# A runtime heuristic (runtime.get_heuristic_config("randn")) is not used here.
configs = [
    triton.Config({"BLOCK": block}, num_warps=warps, num_stages=stages)
    for block, warps, stages in (
        (256, 8, 2),
        (512, 4, 2),
        (512, 8, 3),
        (1024, 4, 2),
        (1024, 8, 3),
        (1024, 8, 4),
    )
]
@triton.autotune(configs=configs, key=["N"])
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def randn_kernel(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Write normal samples to ``out_ptr[0:N]``.

    Each program runs one Philox call per lane (``BLOCK`` lanes), yielding
    four random words per lane. These become four normals through two
    Box-Muller pairs. The four results go to four consecutive ``BLOCK``-wide
    chunks of a ``4 * BLOCK`` tile, so one program covers ``4 * BLOCK``
    output elements.

    Args:
        out_ptr: output buffer, written with ``tl.store``.
        N: number of elements to write; stores at or past ``N`` are masked.
        philox_seed: Philox key (excluded from value specialization).
        philox_offset: 64-bit base counter (excluded from value specialization).
        BLOCK: lanes per program, chosen by the autotuner.
    """
    pid = tl.program_id(0)
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit base offset into the first two 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # One Philox counter per lane: low word = offset low word + global lane index.
    # NOTE(review): an overflow of c0 is not carried into c1 — confirm the
    # offset contract keeps this in range or that wrap-around is acceptable.
    i4 = pid * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0  # zeros shaped like c0, for the two unused counter words
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    # Presumably maps each 32-bit word to a uniform float in [0, 1) — see random_utils.
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)
    n0, n1 = pair_uniform_to_normal_fast(r0, r1)
    n2, n3 = pair_uniform_to_normal_fast(r2, r3)
    # Tile start in int64: pid * BLOCK * 4 in int32 overflows once N exceeds
    # 2**31, yielding negative offsets that would pass the `< N` mask.
    off_0 = pid.to(tl.int64) * (BLOCK * 4) + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK

    tl.store(out_ptr + off_0, n0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, n1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, n2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, n3, mask=off_3 < N, eviction_policy="evict_first")
UNROLL = 4


def _randn_philox_increment(n, block_sizes, unroll=UNROLL):
    """Return how many Philox counters one ``randn_kernel`` launch may consume.

    ``randn_kernel`` gives every lane its own counter and writes that lane's
    ``unroll`` outputs ``BLOCK`` elements apart. A launch over ``n`` elements
    therefore touches ``cdiv(n, unroll * BLOCK) * BLOCK`` counters, which can
    be far more than ``cdiv(n, unroll)`` for small ``n``. The autotuner hides
    which ``BLOCK`` runs, so the maximum over all candidates is reserved.
    Falls back to ``cdiv(n, unroll)`` when ``block_sizes`` is empty.
    """
    return max(
        (-(-n // (unroll * block)) * block for block in block_sizes),
        default=-(-n // unroll),
    )


def randn(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a new tensor of shape ``size`` filled by ``randn_kernel``.

    Args:
        size: output shape.
        dtype: output dtype; defaults to ``torch.get_default_dtype()``.
        layout: accepted for signature compatibility; unused.
        device: output device; defaults to ``torch.device(device_.name)``.
        pin_memory: accepted for signature compatibility; unused.

    Returns:
        The output tensor.
    """
    logger.debug("GEMS RANDN")
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device(device_.name)
    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)
    if N == 0:
        # Nothing to generate: skip the autotuned launch and RNG state update.
        return out
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    # BLOCK is picked by the autotuner at launch time, so reserve counters for
    # the largest candidate. Reserving only cdiv(N, UNROLL) would let the next
    # call start on counters this launch already used and repeat its samples.
    increment = _randn_philox_increment(
        N, [cfg.kwargs["BLOCK"] for cfg in configs]
    )
    philox_seed, philox_offset = philox_backend_seed_offset(increment)
    with torch_device_fn.device(device):
        randn_kernel[grid_fn](out, N, philox_seed, philox_offset)
    return out