Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/rand.py: 0%

131 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import device, torch_device_fn 

9from flag_gems.utils.random_utils import ( 

10 philox_backend_seed_offset, 

11 uint_to_uniform_float, 

12) 

13from flag_gems.utils.shape_utils import volume 

14 

# Module logger: a child of the package-wide "flag_gems" logger, named after this module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Module-level alias for the imported `device` descriptor: `rand` takes a
# keyword parameter named `device` that shadows the imported name, so it reads
# the backend default through `device_` instead.
device_ = device

17 

18 

@triton.heuristics(runtime.get_heuristic_config("rand"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def rand_kernel(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Fill ``out_ptr[0:N]`` with uniform floats, four values per program lane.

    Lane ``i = pid * BLOCK + lane`` runs one Philox call on the counter
    ``(offset_lo + i, offset_hi, 0, 0)``. Its four output words are converted
    with ``uint_to_uniform_float`` and written to the four consecutive
    BLOCK-wide slices of this program's ``4 * BLOCK`` span; positions ``>= N``
    are masked off.
    """
    pid = tl.program_id(0)
    lanes = tl.arange(0, BLOCK)

    seed = philox_seed.to(tl.int64)
    offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into the low / high 32-bit Philox counter words.
    ctr_hi = ((offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    ctr_lo = (offset & 0xFFFFFFFF).to(tl.uint32)
    ctr_lo += pid * BLOCK + lanes
    zeros = ctr_lo * 0

    w0, w1, w2, w3 = tl.philox(seed, ctr_lo, ctr_hi, zeros, zeros)

    # Slice k of this program's span starts at pid * 4 * BLOCK + k * BLOCK.
    span = pid * BLOCK * 4 + lanes
    tl.store(
        out_ptr + span,
        uint_to_uniform_float(w0),
        mask=span < N,
        eviction_policy="evict_first",
    )
    span += BLOCK
    tl.store(
        out_ptr + span,
        uint_to_uniform_float(w1),
        mask=span < N,
        eviction_policy="evict_first",
    )
    span += BLOCK
    tl.store(
        out_ptr + span,
        uint_to_uniform_float(w2),
        mask=span < N,
        eviction_policy="evict_first",
    )
    span += BLOCK
    tl.store(
        out_ptr + span,
        uint_to_uniform_float(w3),
        mask=span < N,
        eviction_policy="evict_first",
    )

48 

49 

def choose_unroll(N, core=64, clusters=12):
    """Pick the per-lane unroll factor used by ``rand``.

    Returns 16 when ``ceil(N / (clusters * 16)) >= core``; otherwise 1.

    Only the factor 16 actually needs testing: the fallback candidate 1 is
    returned whether or not it would pass the same check.
    """
    wide = 16
    return wide if triton.cdiv(N, clusters * wide) >= core else 1

55 

56 

@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def rand_kernel_1(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
    UNROLL: tl.constexpr,
):
    """Fill ``out_ptr[0:N]`` with uniform floats, one value per program lane.

    Lane ``i = pid * BLOCK + lane`` runs one Philox call on the counter
    ``(offset_lo + i, offset_hi, 0, 0)`` and keeps only the first of the four
    returned words.

    Each program stores BLOCK elements starting at ``pid * BLOCK * UNROLL``,
    so the stores tile the output without gaps only when ``UNROLL == 1``.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into the low / high 32-bit Philox counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0
    # tl.philox returns four words; unpack them instead of binding the whole
    # tuple to r0 (uint_to_uniform_float needs a single tensor).
    r0, _r1, _r2, _r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    r0 = uint_to_uniform_float(r0)
    off_0 = tl.program_id(0) * BLOCK * UNROLL + tl.arange(0, BLOCK)
    tl.store(out_ptr + off_0, r0, mask=off_0 < N, eviction_policy="evict_first")

78 

79 

@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def rand_kernel_2(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
    UNROLL: tl.constexpr,
):
    """Fill ``out_ptr[0:N]`` with uniform floats, sixteen values per program lane.

    Lane ``i = pid * BLOCK + lane`` makes four Philox calls on the counters
    ``(offset_lo + i, offset_hi, k, 0)`` for ``k = 0..3``. The 16 resulting
    words go to 16 consecutive BLOCK-wide slices starting at
    ``pid * BLOCK * UNROLL``; positions ``>= N`` are masked off. The stores
    tile the output without gaps or overlap only when ``UNROLL == 16``.

    Every lane uses a single low counter word, so a launch of ``grid``
    programs consumes exactly ``grid * BLOCK`` low-counter values.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into the low / high 32-bit Philox counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0
    # The four calls differ in the third counter word. Offsetting c0 instead
    # (c0 + 1, c0 + 2, c0 + 3) would collide with the c0 of lanes i + 1..i + 3,
    # making neighbouring lanes write identical values.
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    r4, r5, r6, r7 = tl.philox(philox_seed, c0, c1, _O + 1, _O)
    r8, r9, r10, r11 = tl.philox(philox_seed, c0, c1, _O + 2, _O)
    r12, r13, r14, r15 = tl.philox(philox_seed, c0, c1, _O + 3, _O)
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)
    r4 = uint_to_uniform_float(r4)
    r5 = uint_to_uniform_float(r5)
    r6 = uint_to_uniform_float(r6)
    r7 = uint_to_uniform_float(r7)
    r8 = uint_to_uniform_float(r8)
    r9 = uint_to_uniform_float(r9)
    r10 = uint_to_uniform_float(r10)
    r11 = uint_to_uniform_float(r11)
    r12 = uint_to_uniform_float(r12)
    r13 = uint_to_uniform_float(r13)
    r14 = uint_to_uniform_float(r14)
    r15 = uint_to_uniform_float(r15)
    off_0 = tl.program_id(0) * BLOCK * UNROLL + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK
    off_4 = off_3 + BLOCK
    off_5 = off_4 + BLOCK
    off_6 = off_5 + BLOCK
    off_7 = off_6 + BLOCK
    off_8 = off_7 + BLOCK
    off_9 = off_8 + BLOCK
    off_10 = off_9 + BLOCK
    off_11 = off_10 + BLOCK
    off_12 = off_11 + BLOCK
    off_13 = off_12 + BLOCK
    off_14 = off_13 + BLOCK
    off_15 = off_14 + BLOCK
    tl.store(out_ptr + off_0, r0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, r1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, r2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, r3, mask=off_3 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_4, r4, mask=off_4 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_5, r5, mask=off_5 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_6, r6, mask=off_6 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_7, r7, mask=off_7 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_8, r8, mask=off_8 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_9, r9, mask=off_9 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_10, r10, mask=off_10 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_11, r11, mask=off_11 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_12, r12, mask=off_12 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_13, r13, mask=off_13 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_14, r14, mask=off_14 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_15, r15, mask=off_15 < N, eviction_policy="evict_first")


def rand(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a new tensor of shape ``size`` filled with uniform random floats.

    ``layout`` and ``pin_memory`` are accepted for signature compatibility but
    are not used.

    Args:
        size: output shape.
        dtype: output dtype; defaults to ``torch.get_default_dtype()``.
        layout: unused.
        device: output device; defaults to ``torch.device(device_.name)``.
        pin_memory: unused.

    Returns:
        The freshly allocated, filled output tensor.

    Raises:
        RuntimeError: if ``choose_unroll`` yields a factor no kernel supports.
    """
    logger.debug("GEMS RAND")
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device(device_.name)

    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)
    if N == 0:
        # Nothing to fill. Returning early also avoids BLOCK_SIZE == 0
        # (next_power_of_2(0) is 0) and the division by zero it would cause.
        return out

    # NOTE(review): 12 presumably matches the device's cluster count; it is
    # passed to choose_unroll so both use the same value.
    cluster_num = 12
    UNROLL = choose_unroll(N, clusters=cluster_num)
    BLOCK_SIZE = min(
        triton.next_power_of_2(triton.cdiv(N, cluster_num * UNROLL)), 1024
    )
    grid = triton.cdiv(N, BLOCK_SIZE * UNROLL)

    # Both kernels give lane `pid * BLOCK_SIZE + j` the single low counter
    # word `offset + pid * BLOCK_SIZE + j`, so the launch touches exactly
    # `grid * BLOCK_SIZE` counter values. BLOCK_SIZE is chosen here (no
    # autotuner), so reserve that exact count rather than an estimate from N.
    # Assumes philox_backend_seed_offset advances the generator's offset by at
    # least `increment`, so later random ops do not reuse these counters.
    increment = grid * BLOCK_SIZE
    philox_seed, philox_offset = philox_backend_seed_offset(increment)

    # Each kernel hard-codes how many values a lane writes, so UNROLL must
    # match it exactly; a mismatch would leave gaps or overlapping writes.
    if UNROLL == 1:
        kernel = rand_kernel_1
    elif UNROLL == 16:
        kernel = rand_kernel_2
    else:
        raise RuntimeError(f"rand: unsupported UNROLL={UNROLL}")

    with torch_device_fn.device(device):
        kernel[(grid,)](out, N, philox_seed, philox_offset, BLOCK_SIZE, UNROLL)
    return out