Coverage for src/flag_gems/ops/randn.py: 39%

77 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import device, torch_device_fn 

8from flag_gems.utils.random_utils import ( 

9 philox_backend_seed_offset, 

10 uint_to_uniform_float, 

11) 

12from flag_gems.utils.shape_utils import volume 

13 

14 

@triton.jit
def high_precision_fast_sin_cos(x):
    # Compute sin(x) and cos(x) with polynomial approximations after
    # range reduction. Coefficients sit very close to the Taylor-series
    # values (1, -1/3!, 1/5!, ... and 1, -1/2!, 1/4!, ...), lightly
    # perturbed — presumably error-tuned for [-pi, pi]; the original
    # author quotes a max error around 1.5e-9 (unverified here).

    # Step 1: reduce the argument into [-pi, pi] by subtracting the
    # nearest whole multiple of 2*pi.
    TWO_PI = 6.283185307179586
    x = x - TWO_PI * tl.floor(x / TWO_PI + 0.5)
    xx = x * x

    # Step 2: sin(x) ~ x * P(x^2), Horner evaluation built up
    # innermost-coefficient first (same arithmetic order as a single
    # nested expression).
    p = -2.505210838544172e-8
    p = 2.755731922398589e-6 + xx * p
    p = -0.00019841269841269616 + xx * p
    p = 0.00833333333333332876 + xx * p
    p = -0.16666666666666666654 + xx * p
    p = 0.99999999999999999999 + xx * p
    sin_val = x * p

    # Step 3: cos(x) ~ Q(x^2), same scheme over the even-power
    # coefficients.
    q = -2.755731922398581e-7
    q = 2.4801587301587299e-5 + xx * q
    q = -0.00138888888888888742 + xx * q
    q = 0.04166666666666666636 + xx * q
    q = -0.49999999999999999983 + xx * q
    q = 1.0 + xx * q
    cos_val = q

    return sin_val, cos_val

46 

47 

@triton.jit
def pair_uniform_to_normal_fast(u1, u2):
    # Box-Muller transform: map two independent uniform samples to two
    # independent standard-normal samples.
    # Clamp u1 away from zero so tl.log never receives 0.
    u1 = tl.maximum(1.0e-7, u1)
    angle = 6.283185307179586 * u2  # 2*pi * u2
    radius = tl.sqrt(-2.0 * tl.log(u1))
    s, c = high_precision_fast_sin_cos(angle)
    return radius * c, radius * s

55 

56 

# Keep a module-level alias to the runtime device descriptor, since the
# `device` keyword parameter of randn() below shadows the imported name.
device_ = device
logger = logging.getLogger(__name__)


# @triton.heuristics(runtime.get_heuristic_config("randn"))
# Autotuning search space for randn_kernel; the best config is cached
# per distinct value of the autotune key N.
configs = [
    triton.Config({"BLOCK": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK": 1024}, num_warps=8, num_stages=4),
]

70 

71 

@triton.autotune(configs=configs, key=["N"])
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def randn_kernel(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Fill out_ptr[0:N] with standard-normal samples.

    Each program instance advances BLOCK Philox counters; every counter
    yields four uint32 words, which are turned into four uniforms and
    then four normals via Box-Muller, so one program writes up to
    4 * BLOCK output elements.
    """
    # Split the 64-bit philox offset into the two 32-bit counter words.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # One distinct counter per lane across the whole grid.
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    # NOTE(review): adding into the low counter word can wrap without
    # carrying into c1 for very large offsets/N — confirm this matches
    # PyTorch's philox offset contract.
    c0 += i4
    # Zero vector with the same shape/dtype as c0, used for the two
    # unused counter words.
    _O = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    # Map raw uint32 bits to uniform floats (presumably in (0, 1) —
    # see flag_gems.utils.random_utils.uint_to_uniform_float).
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)
    # Box-Muller: each pair of uniforms becomes a pair of normals.
    n0, n1 = pair_uniform_to_normal_fast(r0, r1)
    n2, n3 = pair_uniform_to_normal_fast(r2, r3)
    # Each program owns 4 contiguous BLOCK-sized slices of the output.
    off_0 = tl.program_id(0) * BLOCK * 4 + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK

    # Masked stores cover the ragged tail when N is not a multiple of
    # 4 * BLOCK; evict_first hints that these streaming writes should
    # not linger in cache.
    tl.store(out_ptr + off_0, n0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, n1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, n2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, n3, mask=off_3 < N, eviction_policy="evict_first")

104 

105 

# Outputs produced per lane per kernel program: tl.philox yields four
# uint32 words, so randn_kernel writes 4 * BLOCK values per program.
UNROLL = 4

107 

108 

def randn(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a tensor of shape ``size`` filled with standard-normal samples.

    Mirrors ``torch.randn``'s factory signature; ``layout`` and
    ``pin_memory`` are accepted for API compatibility but unused here.
    ``dtype`` and ``device`` fall back to the torch default dtype and
    the flag_gems runtime device respectively.
    """
    logger.debug("GEMS RANDN")
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device(device_.name) if device is None else device

    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)

    # One program per 4 * BLOCK elements (the kernel unrolls by UNROLL).
    def grid_fn(meta):
        return (triton.cdiv(N, meta["BLOCK"] * UNROLL),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(increment)
    with torch_device_fn.device(device):
        randn_kernel[grid_fn](out, N, philox_seed, philox_offset)
    return out